In [1]:
# Load IPython's autoreload extension and use mode 2, so edited project modules
# are re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
from datetime import datetime
from types import SimpleNamespace
import os
from glob import glob
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import wandb
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import xarray
import pandas as pd
from dataset.dataset import ImageDataset, SingleImageDataset, pickle_read, SeviriDataset
from dataset.normalization import MinMax, ZeroMinMax
from dataset.station_dataset import GroundstationDataset
from train import get_dataloaders
from lightning.pytorch import Trainer, LightningDataModule
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from models.ConvResNet_Jiang import ConvResNet, ConvResNet_dropout, ConvResNet_batchnormMLP
from models.LightningModule import LitEstimatorPoint
from torch.utils.data import DataLoader, Dataset, Subset
from torchmetrics import MeanSquaredError
from tqdm import tqdm
from utils.plotting import prediction_error_plot, plot_station_locations
from cartopy.mpl.gridliner import (
    LongitudeFormatter,
    LatitudeFormatter,
    LongitudeLocator,
    LatitudeLocator,
)
# Turn off LaTeX-based text rendering for Matplotlib figures.
plt.rcParams['text.usetex'] = False

KeyboardInterrupt: 

In [None]:
## Set Up Model

# Emulator checkpoint that the rest of this notebook starts from.
chkpt_fn = '/scratch/snx3000/kschuurm/irradiance_estimation/train/SIS_point_estimation/jup3gn3n/checkpoints/epoch=1-val_loss=0.00705.ckpt'

# Alternative checkpoint (disabled):
# chkpt_fn = '/scratch/snx3000/kschuurm/irradiance_estimation/train/SIS_point_estimation/4nbyae30/checkpoints/epoch=7-val_loss=0.01023.ckpt'

# Inference output lives in an `inference/` folder next to the run's `checkpoints/` folder.
inference_fn = chkpt_fn.split('checkpoints')[0] + 'inference/'
# exist_ok=True keeps the cell re-runnable and removes the gap between checking
# for the directory and creating it; missing parent folders are created as well.
os.makedirs(inference_fn, exist_ok=True)

# Run configuration shared by the dataset constructors, data loaders and
# trainers below. Built directly as a namespace so fields are read as attributes.
config = SimpleNamespace(
    batch_size=512,
    patch_size={"x": 15, "y": 15, "stride_x": 1, "stride_y": 1},
    x_vars=[
        "channel_1", "channel_2", "channel_3", "channel_4",
        "channel_5", "channel_6", "channel_7", "channel_8",
        "channel_9", "channel_10", "channel_11",
        "DEM",
    ],
    y_vars=["SIS"],
    # Not used at the moment: 'sat_SZA', 'sat_AZI', 'coscatter_angle'
    x_features=["dayofyear", "lat", "lon", "SZA", "AZI"],
    transform=ZeroMinMax(),
    target_transform=ZeroMinMax(),
    binned=True,
    max_epochs=10,
    # Compute related
    num_workers=24,
    ACCELERATOR="gpu",
    DEVICES=-1,
    NUM_NODES=1,
    STRATEGY="ddp",
    PRECISION="32",
    EarlyStopping={"patience": 2},
    ModelCheckpoint={"every_n_epochs": 1, "save_top_k": 1},
)

In [None]:
# Find the DWD station archives on disk and parse the numeric station id out of
# each file name, then restrict the station index to those stations.
zarr_fns = glob('../../ZARR/DWD/DWD_SOLAR_*.zarr')
station_names = []
for zarr_path in zarr_fns:
    station_code = os.path.basename(zarr_path).split('SOLAR_')[-1].split('.')[0]
    station_names.append(int(station_code))
index = xarray.open_dataset('/scratch/snx3000/kschuurm/DATA/DWD/netcdf/DWD_SOLAR_index.nc')
index = index.sel(station_id=station_names)

# The split below was frozen from a one-off random 80/20 draw:
# train_id, valid_id = torch.utils.data.random_split(station_names, [.8, .2])
# print(list(train_id), list(valid_id))

train_id = [15000, 2638, 662, 342, 691, 4104, 1684, 5426, 1766, 3167, 596, 880, 1346, 4271, 1550, 3196, 5792, 2485, 856, 1468, 3287, 4336, 701, 3126, 891, 1078, 4393, 963, 5705, 5546, 7368, 4887, 164, 704, 2261, 656, 2559, 6197, 3513, 3032, 7351, 430, 1443, 2907, 5856, 5404, 6163, 2483, 3268, 2601, 15444, 13674, 7374, 5480, 7367, 4745, 2014, 4625, 5100, 3761, 460, 7369, 3086, 3366, 282, 591, 1639, 232, 4177, 7370, 2667, 4928, 2712, 4466, 5397, 5516, 1975, 1503, 2115, 1605]
valid_id = [1757, 5109, 953, 3028, 2290, 5906, 2171, 427, 2932, 2812, 5839, 1691, 3811, 1420, 5142, 4911, 3660, 3730, 1048]

index_train = index.sel(station_id=train_id)
index_valid = index.sel(station_id=valid_id)

# Station map: training stations in blue, validation stations in red.
a = index_train.plot.scatter(
    x='lon',
    y='lat',
    c='b',
    subplot_kws={'projection': ccrs.PlateCarree()},
    transform=ccrs.PlateCarree(),
)
index_valid.plot.scatter(
    x='lon',
    y='lat',
    c='r',
    subplot_kws={'projection': ccrs.PlateCarree()},
    transform=ccrs.PlateCarree(),
)
station_map = a.axes
station_map.gridlines()
station_map.stock_img()
station_map.set_extent([-5, 25, 40, 60])
plt.show()

In [None]:
# One GroundstationDataset per validation station (no `binned` argument is
# passed here), exposed afterwards as a single concatenated dataset.
valid_datasets = []
for valid_station in tqdm(valid_id):
    valid_datasets.append(
        GroundstationDataset(
            f'../../ZARR/DWD/DWD_SOLAR_{str(valid_station).zfill(5)}.zarr',
            config.y_vars,
            config.x_vars,
            config.x_features,
            config.patch_size['x'],
            config.transform,
            config.target_transform,
        )
    )
valid_dataset = torch.utils.data.ConcatDataset(valid_datasets)


In [12]:
# One GroundstationDataset per training station, created with the `binned`
# setting from the config, then concatenated into a single training dataset.
train_datasets = []
for train_station in tqdm(train_id):
    train_datasets.append(
        GroundstationDataset(
            f'../../ZARR/DWD/DWD_SOLAR_{str(train_station).zfill(5)}.zarr',
            config.y_vars,
            config.x_vars,
            config.x_features,
            config.patch_size['x'],
            config.transform,
            config.target_transform,
            binned=config.binned,
        )
    )
train_dataset = torch.utils.data.ConcatDataset(train_datasets)


100%|██████████| 80/80 [02:31<00:00,  1.90s/it]


In [None]:
# Station codes of every MeteoSwiss archive found on disk.
zarr_fns = glob("../../ZARR/METEOSWISS/METEOSWISS_SOLAR_*.zarr")
station_names_meteoswiss = []
for zarr_path in zarr_fns:
    archive_name = os.path.basename(zarr_path)
    station_names_meteoswiss.append(archive_name.split("SOLAR_")[-1].split(".")[0])

# Fixed selection of MeteoSwiss stations that is loaded below.
train_sample_meteoswiss = [
    "KOP", "HAI", "ELM", "JUN", "MAH", "BAS", "MRP", "CDF", "AND", "GLA",
    "EBK", "RUE", "PMA", "DIS", "CRM", "GOR", "VLS", "MOE", "SHA", "CHA",
    "SBE", "EIN", "GEN", "PIL", "GIH", "NAS", "BOL", "GVE", "BUS", "NEU",
    "ORO", "MTR", "VIS", "SBO", "PLF", "GRH", "UEB", "LEI", "ABO", "ARO",
    "SIM", "CGI", "MUB", "BRL", "WFJ", "MAG", "BER", "EVO", "LAT", "LUZ",
]

# One dataset per selected station, using the same arguments as the DWD training set.
datasets_meteoswisss = []
for meteoswiss_station in tqdm(train_sample_meteoswiss):
    datasets_meteoswisss.append(
        GroundstationDataset(
            f"../../ZARR/METEOSWISS/METEOSWISS_SOLAR_{str(meteoswiss_station)}.zarr",
            config.y_vars,
            config.x_vars,
            config.x_features,
            config.patch_size["x"],
            config.transform,
            config.target_transform,
            binned=config.binned,
        )
    )

train_dataset_meteoswiss = torch.utils.data.ConcatDataset(datasets_meteoswisss)
# Merging into the main training set is currently disabled:
# train_dataset = torch.utils.data.ConcatDataset([train_dataset, train_dataset_meteoswiss])

In [None]:
# Station codes of every IEA-PVPS archive on disk, in glob order.
zarr_fns = glob('../../ZARR/IEA_PVPS/IEA_PVPS_*.zarr')
station_names_bsrn = []
for zarr_path in zarr_fns:
    station_names_bsrn.append(os.path.basename(zarr_path).split('IEA_PVPS_')[-1].split('.')[0])

# One dataset per station, created with `sarah_idx_only=True`; the list order
# follows `station_names_bsrn`.
bsrn_datasets = []
for bsrn_station in tqdm(station_names_bsrn):
    bsrn_datasets.append(
        GroundstationDataset(
            f'../../ZARR/IEA_PVPS/IEA_PVPS_{bsrn_station}.zarr',
            config.y_vars,
            config.x_vars,
            config.x_features,
            config.patch_size['x'],
            config.transform,
            config.target_transform,
            sarah_idx_only=True,
        )
    )
bsrn_dataset = torch.utils.data.ConcatDataset(bsrn_datasets)

In [None]:
# The two stations grouped under the name `enermena`, read from the IEA-PVPS archives.
enermena = ["OUJ", "TAT"]

enermena_datasets = []
for enermena_station in tqdm(enermena):
    enermena_datasets.append(
        GroundstationDataset(
            f"../../ZARR/IEA_PVPS/IEA_PVPS_{enermena_station}.zarr",
            config.y_vars,
            config.x_vars,
            config.x_features,
            config.patch_size["x"],
            config.transform,
            config.target_transform,
            sarah_idx_only=True,
        )
    )
enermena_dataset = torch.utils.data.ConcatDataset(enermena_datasets)

In [None]:
# Move the OUJ station out of the BSRN evaluation pool and into the training set.
# The station is found by name instead of by position in the glob-ordered list
# (the earlier `bsrn_datasets.pop(-3)` depended on file ordering and could only
# be executed once), and no list is mutated, so this cell is safe to re-run on a
# fresh kernel.
ouj_position = station_names_bsrn.index('OUJ')  # raises ValueError if the archive is missing
OUJ = bsrn_datasets[ouj_position]
bsrn_dataset = torch.utils.data.ConcatDataset(
    [station_ds for pos, station_ds in enumerate(bsrn_datasets) if pos != ouj_position]
)
# Rebuild from the per-station list so repeated runs do not nest ConcatDatasets
# or append OUJ more than once.
train_dataset = torch.utils.data.ConcatDataset([*train_datasets, OUJ])

In [13]:

class DataModule(LightningDataModule):
    """Data module with one training loader and two validation loaders.

    Args:
        train_dataset: dataset served by ``train_dataloader`` (shuffled).
        val_dataset: first validation dataset (validation loader index 0).
        bsrn_dataset: second validation dataset (validation loader index 1).
        batch_size: batch size used by every loader.
        num_workers: worker processes per loader. ``None`` (default) falls back
            to the notebook-level ``config.num_workers``.

    To train on the BSRN data, pass that dataset as ``train_dataset``; the
    constructor no longer silently substitutes ``bsrn_dataset`` for it.
    """

    def __init__(self, train_dataset, val_dataset, bsrn_dataset, batch_size=2, num_workers=None):
        super().__init__()
        # Previously this stored `bsrn_dataset`, which ignored the argument.
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.bsrn_dataset = bsrn_dataset
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader for ``dataset`` with the module's batch size and workers."""
        workers = config.num_workers if self.num_workers is None else self.num_workers
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=workers)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return ``[val_dataset loader, bsrn_dataset loader]``, both unshuffled.

        The list order matters: metrics monitored elsewhere in the notebook are
        keyed by loader position (``dataloader_idx_0`` / ``dataloader_idx_1``).
        """
        return [
            self._make_loader(self.val_dataset, shuffle=False),
            self._make_loader(self.bsrn_dataset, shuffle=False),
        ]

# Data module handed to the `trainer.fit` calls below; batch size comes from the shared config.
dm = DataModule(train_dataset, valid_dataset, bsrn_dataset, config.batch_size)


In [None]:
# Build the point estimator with an MSE metric. No weights are read from disk
# here; resuming is requested later via `trainer.fit(..., ckpt_path='last')`.
estimator = LitEstimatorPoint(learning_rate=1e-5, config=config, metric=MeanSquaredError())

# Groundstations only 

In [None]:

# Weights & Biases run for the ground-station-only training.
wandb_logger = WandbLogger(
    project="SIS_point_estimation_groundstation",
    name="groundstations_only",
)

# Keep the top 3 checkpoints according to the loss of validation loader 0,
# plus the most recent one.
mc_sarah = ModelCheckpoint(
    monitor="val_loss/dataloader_idx_0",
    save_top_k=3,
    save_last=True,
    # filename='{epoch}-{val_loss:.5f}'
)

# At most 20 epochs or 2 hours, validating twice per epoch, no sanity validation pass.
trainer = Trainer(
    accelerator="gpu",
    precision=config.PRECISION,
    logger=wandb_logger,
    callbacks=[mc_sarah],
    max_epochs=20,
    max_time="00:2:00:00",
    val_check_interval=0.5,
    num_sanity_val_steps=0,
    log_every_n_steps=50,
)


In [None]:
# Optional validation pass before training (disabled):
# trainer.validate(estimator, dm)

# Train with the data module, asking Lightning to resume from the last checkpoint.
trainer.fit(model=estimator, datamodule=dm, ckpt_path="last")

In [None]:
# Close the Weights & Biases run behind this logger before a new logger is created below.
wandb_logger.experiment.finish()

# Finetuning the Emulator

In [14]:
# Checkpoint from a previous ground-station run, used as the starting point for fine-tuning.
# NOTE(review): the '/' characters inside the file name presumably come from the
# monitored metric keys (val_loss/dataloader_idx_N), so the .ckpt file sits in
# nested sub-directories of `checkpoints/` — confirm before moving or globbing it.
chkpt_fn = '/scratch/snx3000/kschuurm/irradiance_estimation/train/SIS_point_estimation_groundstation/4j9y9tqb/checkpoints/epoch=3-bsrnval_loss/dataloader_idx_1=0.01907-dwdval_loss/dataloader_idx_0=0.01901.ckpt'

# Rebinds `estimator`, replacing the instance created earlier in the notebook.
estimator = LitEstimatorPoint.load_from_checkpoint(chkpt_fn)

/scratch/snx3000/kschuurm/lightning-env/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'metric' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['metric'])`.


In [15]:
# Optional freezing strategies, currently disabled:
#
# for par in estimator.model.mlp.parameters():
#     par.requires_grad = False
#
# for name, par in estimator.model.named_parameters():
#     if 'bias' in name:
#         par.requires_grad = False
#     else:
#         continue

# Snapshot the loaded weights as detached copies and register them as the
# reference parameters, then enable the parameter loss with alpha = 5000.
# NOTE(review): presumably alpha weights a penalty on drift away from the
# reference weights — confirm in LitEstimatorPoint.
reference_parameters = []
for weights in estimator.model.parameters():
    reference_parameters.append(weights.clone().detach())
estimator.set_reference_parameters(reference_parameters)
estimator.parameter_loss = True
estimator.alpha = 5000

In [16]:
# List every parameter of the MLP head with its shape and trainable flag.
for layer_name, weights in estimator.model.mlp.named_parameters():
    print(layer_name, weights.shape, weights.requires_grad)

0.weight torch.Size([256, 261]) True
0.bias torch.Size([256]) True
1.weight torch.Size([256]) True
1.bias torch.Size([256]) True
3.weight torch.Size([64, 256]) True
3.bias torch.Size([64]) True
4.weight torch.Size([64]) True
4.bias torch.Size([64]) True
6.weight torch.Size([1, 64]) True
6.bias torch.Size([1]) True


In [17]:
# Lower the learning rate for fine-tuning.
# NOTE(review): the constructor argument used earlier is `learning_rate`; confirm that
# LitEstimatorPoint actually reads `self.lr` when it builds the optimizer.
estimator.lr = 1e-6

In [18]:
from torchmetrics import MeanAbsoluteError, MeanMetric
from train import get_dataloaders

# Weights & Biases run for the fine-tuning experiment.
wandb_logger = WandbLogger(name='finetuned on dwd -> meteoswiss -> bsrn, par loss only', project="SIS_point_estimation_groundstation")

# Keep the top 3 checkpoints ranked by the loss of validation loader 1, plus the last one.
# The monitored metric keys contain '/', and with the default
# `auto_insert_metric_name=True` the key is copied into the file name, which
# splits every checkpoint path into nested directories. Metric labels are
# therefore written out by hand and automatic insertion is switched off, so
# each checkpoint is a single flat file.
mc_sarah = ModelCheckpoint(
        monitor='val_loss/dataloader_idx_1',
        save_top_k = 3,
        filename='epoch={epoch}-bsrn_val_loss={val_loss/dataloader_idx_1:.5f}-dwd_val_loss={val_loss/dataloader_idx_0:.5f}',
        auto_insert_metric_name=False,
        save_last=True,
    )

# At most 10 epochs or 3 hours; validation runs on Lightning's default schedule.
trainer = Trainer(
    logger=wandb_logger,
    max_epochs=10,
    accelerator='gpu',
    precision=config.PRECISION,
    callbacks=[mc_sarah],
    log_every_n_steps=500,
    num_sanity_val_steps=0,
    max_time="00:3:00:00",
    # val_check_interval=1,
    # check_val_every_n_epoch=4,
)



/scratch/snx3000/kschuurm/lightning-env/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/snx3000/kschuurm/lightning-env/lib/python3. ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Optional validation pass before fine-tuning (disabled):
# trainer.validate(estimator, val_dataloader)

# Fine-tune the loaded estimator on the data module.
trainer.fit(model=estimator, datamodule=dm)

In [None]:
# `Trainer.validate` takes data loaders (or a data module). Passing the raw
# dataset would be iterated sample by sample without batching, so wrap it in
# an unshuffled DataLoader with the same settings as the other loaders.
meteoswiss_loader = DataLoader(
    train_dataset_meteoswiss,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
)
trainer.validate(estimator, dataloaders=meteoswiss_loader)

In [None]:
# Close the Weights & Biases run that belongs to the fine-tuning logger.
wandb_logger.experiment.finish()

In [None]:
# Finish whichever wandb run is still active at module level.
# NOTE(review): `wandb` is already imported in the first import cell, and the cell
# above finishes the logger's run — this cell looks redundant; confirm before removing.
import wandb
wandb.finish()

In [115]:
# NOTE(review): a bare `autoreload` only works through IPython's automagic (it
# resolves to the %autoreload magic) and raises NameError when automagic is off.
# The execution count shows it was run out of sequence; it looks like a leftover
# and autoreload is already configured in the first cell.
autoreload