## Single step test



In [None]:
import fsspec
import xarray as xr
import matplotlib.pyplot as plt
import sys

import torch
import math
import numpy as np
import wandb
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn as nn

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import TQDMProgressBar

import pyqg_explorer.dataset.forcing_dataset as forcing_dataset
import pyqg_explorer.models.base_model as base_model
import pyqg_explorer.util.pbar as pbar

In [None]:
# use GPUs if available: pick the compute device once and report which one it is
use_cuda = torch.cuda.is_available()
print("CUDA Available" if use_cuda else 'CUDA Not Available')
device = torch.device('cuda' if use_cuda else 'cpu')

In [None]:
lev=0
data_full=xr.open_zarr(fsspec.get_mapper('/scratch/zanna/data/pyqg/publication/eddy/forcing1.zarr'), consolidated=True)


def _select_snapshots(field):
    """Take layer `lev` of `field` and flatten (run, time) into one `snapshot` axis.

    Returns the field with dims ordered (snapshot, y, x).
    """
    return (field.isel(lev=lev)
                 .stack(snapshot=("run", "time"))
                 .transpose("snapshot", "y", "x"))


## Every field is taken from the same layer `lev`, so changing `lev` keeps them consistent
data_dqbar = _select_snapshots(data_full.dqbar_dt)
data_forcing = _select_snapshots(data_full.q_subgrid_forcing)
data_q = _select_snapshots(data_full.q)

## One figure per field: DataArray.plot() draws on the current axes, so without
## plt.figure() the three fields would be drawn on top of each other
for field, snap in ((data_dqbar, 200), (data_forcing, 20), (data_q, 20)):
    plt.figure()
    field.isel(snapshot=snap).plot()

In [None]:
## PV tendency on layer `lev` (same layer as the forcing), with (run, time) flattened to `snapshot`
data_dqbar=data_full.dqbar_dt.isel(lev=lev)
data_dqbar=data_dqbar.stack(snapshot=("run","time"))
data_dqbar=data_dqbar.transpose("snapshot","y","x")
data_dqbar.isel(snapshot=200).plot()

In [None]:
## Subgrid forcing on layer `lev`, flattened to (snapshot, y, x), then one snapshot plotted
data_forcing = (
    data_full.q_subgrid_forcing
    .isel(lev=lev)
    .stack(snapshot=("run", "time"))
    .transpose("snapshot", "y", "x")
)
data_forcing.isel(snapshot=20).plot()

In [None]:
## PV on layer `lev` (same layer as the forcing), with (run, time) flattened to `snapshot`
data_q=data_full.q.isel(lev=lev)
data_q=data_q.stack(snapshot=("run","time"))
data_q=data_q.transpose("snapshot","y","x")
data_q.isel(snapshot=20).plot()

## Drop the full dataset handle; cells that read data_full must be re-run after reopening it
del data_full

In [None]:
## Build a single-step dataset
class SingleStepDataset(Dataset):
    """
    Single-step dataset pairing each snapshot with the one that follows it.

    Inputs ``x_data`` stack three channels for snapshot i, ordered
    [pv, dqbar_dt, s]. Targets ``y_data`` hold the PV of snapshot i+1.
    Inputs and targets are each divided by a single scalar standard deviation,
    kept in ``x_renorm`` / ``y_renorm`` so predictions can be mapped back.
    """
    def __init__(self,pv,dqbar_dt,s,seed=42,train_ratio=0.75,valid_ratio=0.25,test_ratio=0.0):
        """
        pv:          xarray of the PV field (anything with .to_numpy()); snapshot axis first
        dqbar_dt:    xarray of PV tendency, same layout as pv
        s:           xarray of the subgrid forcing field, same layout as pv
        seed:        random seed used to create train/valid/test splits
        train_ratio: proportion of dataset to use as training data
        valid_ratio: proportion of dataset to use as validation data
        test_ratio:  proportion of dataset to use as test data

        Raises ValueError if there are fewer than two snapshots, or if the
        split ratios request more samples than the dataset holds.
        """
        super().__init__()

        def _as_channel(field):
            # assumes a 3-D (snapshot, y, x) array -> (snapshot, 1, y, x) tensor
            return torch.unsqueeze(torch.tensor(field.to_numpy()), dim=1)

        pv_all = _as_channel(pv)
        if pv_all.shape[0] < 2:
            raise ValueError("Need at least two snapshots to form (q_i, q_i+1) pairs")

        ## Target Q_i+1: plain slicing keeps index i aligned with snapshot i+1
        ## (rolling by +1 and then dropping the first entry would give back Q_i)
        self.pv_plusone=pv_all[1:, :, :, :]

        ## Drop last index, where we have no i+1
        self.pv=pv_all[:-1, :, :, :]
        self.dqbar_dt=_as_channel(dqbar_dt)[:-1, :, :, :]
        self.s=_as_channel(s)[:-1, :, :, :]

        ## NOTE(review): if the snapshot axis concatenates several independent runs,
        ## the pair (last snapshot of one run, first snapshot of the next) is not a
        ## real time step. Consider dropping those pairs; confirm how the inputs are stacked.

        ## Cat into x_data, channels ordered [pv, dqbar_dt, s]
        self.x_data=torch.cat((self.pv,self.dqbar_dt,self.s),1)
        self.y_data=self.pv_plusone

        self.train_ratio=train_ratio
        self.valid_ratio=valid_ratio
        self.test_ratio=test_ratio
        self.rng = np.random.default_rng(seed)

        ## One scalar std over all input channels.
        ## NOTE(review): if the channels differ greatly in magnitude, per-channel
        ## normalisation may be preferable. TODO confirm.
        self.x_renorm=torch.std(self.x_data)
        self.y_renorm=torch.std(self.y_data)
        self.x_data=self.x_data/self.x_renorm
        self.y_data=self.y_data/self.y_renorm
        self.len=len(self.x_data)

        assert len(self.x_data)==len(self.y_data), "Number of x and y samples should be the same"

        self._get_split_indices()

    def _get_split_indices(self):
        """Set train, valid and test indices as consecutive, disjoint slices of one random permutation."""

        ## Randomly shuffle indices of entire dataset
        rand_indices=self.rng.permutation(np.arange(self.len))

        ## Set number of train, valid and test points
        num_train=math.floor(self.len*self.train_ratio)
        num_valid=math.floor(self.len*self.valid_ratio)
        num_test=math.floor(self.len*self.test_ratio)

        ## Make sure we aren't overcounting
        if num_train+num_valid+num_test > self.len:
            raise ValueError(
                f"Split sizes {num_train}+{num_valid}+{num_test} exceed dataset size {self.len}")

        ## Pick train, valid and test indices as back-to-back slices of the shuffled list
        valid_end=num_train+num_valid
        self.train_idx=rand_indices[:num_train]
        self.valid_idx=rand_indices[num_train:valid_end]
        self.test_idx=rand_indices[valid_end:valid_end+num_test]

        ## Make sure no index appears in more than one split (checked pairwise)
        train_set, valid_set, test_set = (set(idx.tolist()) for idx in
                                          (self.train_idx, self.valid_idx, self.test_idx))
        assert not (train_set & valid_set or train_set & test_set or valid_set & test_set), (
                "Common elements in train, valid or test set")

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return the (x, y) sample(s) at idx; tensor indices are converted to lists first."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return (self.x_data[idx],self.y_data[idx])

In [None]:
## From Andrew/Pavel's code, function to create a CNN block
def make_block(in_channels: int, out_channels: int, kernel_size: int,
        ReLU = 'ReLU', batch_norm = True) -> list:
    '''
    Build one convolutional block as a list of layers.

    The block is a Conv2d with padding='same' and circular padding mode,
    optionally followed by an activation and then a BatchNorm2d.

    ReLU:       'ReLU', 'LeakyReLU' (negative slope 0.2), or 'False' for no
                activation. False and None are accepted as aliases of 'False'.
    batch_norm: if true, append BatchNorm2d(out_channels) at the end.

    Raises ValueError for any other ReLU value. An unknown value would otherwise
    silently produce a block without activation.
    '''
    block = [nn.Conv2d(in_channels, out_channels, kernel_size,
                       padding='same', padding_mode='circular')]
    if ReLU == 'ReLU':
        block.append(nn.ReLU())
    elif ReLU == 'LeakyReLU':
        block.append(nn.LeakyReLU(0.2))
    elif ReLU is None or ReLU is False or ReLU == 'False':
        pass  # no activation requested
    else:
        raise ValueError(
            f"Unknown ReLU option {ReLU!r}; expected 'ReLU', 'LeakyReLU' or 'False'")
    if batch_norm:
        block.append(nn.BatchNorm2d(out_channels))
    return block


class AndrewCNN(nn.Module):
    def __init__(self, n_in: int, n_out: int, x_renorm=None, y_renorm=None, ReLU = 'ReLU', lr=0.001) -> None:
        '''
        Fully convolutional network of 8 blocks built with make_block.

        The first block has n_in input channels and the last has n_out output
        channels. The last block has no activation and no batch norm.

        x_renorm, y_renorm: normalisation factors registered as buffers.
            Both default to tensor(1.). A fresh tensor is created per instance,
            so instances never share a buffer object: load_state_dict copies
            buffers in place and would otherwise leak values across models.
        ReLU: activation option forwarded to blocks 1-7.
        lr:   stored as self.lr; not used inside this class.
        '''
        super().__init__()
        self.lr=lr
        if x_renorm is None:
            x_renorm = torch.tensor(1.)
        if y_renorm is None:
            y_renorm = torch.tensor(1.)
        ## Register normalisation factors as buffers
        self.register_buffer('x_renorm', x_renorm)
        self.register_buffer('y_renorm', y_renorm)
        blocks = []
        blocks.extend(make_block(n_in,128,5,ReLU))                #1
        blocks.extend(make_block(128,64,5,ReLU))                  #2
        blocks.extend(make_block(64,32,3,ReLU))                   #3
        blocks.extend(make_block(32,32,3,ReLU))                   #4
        blocks.extend(make_block(32,32,3,ReLU))                   #5
        blocks.extend(make_block(32,32,3,ReLU))                   #6
        blocks.extend(make_block(32,32,3,ReLU))                   #7
        blocks.extend(make_block(32,n_out,3,'False',False))       #8
        self.conv = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply the convolutional stack; the renorm buffers are not applied here."""
        x = self.conv(x)
        return x

In [None]:
# Build the (q_i, dqbar_dt_i, S_i) -> q_{i+1} dataset from the stacked fields
single_dataset = SingleStepDataset(pv=data_q, dqbar_dt=data_dqbar, s=data_forcing)

In [None]:
## Wandb config file
config = dict(lev=lev, forcing=1, framework="Single-step loss")

wandb.init(project="pyqg_single_step", entity="chris-pedersen", config=config)

## Train and validation loaders sample disjoint index subsets of the same dataset
train_loader, valid_loader = (
    DataLoader(single_dataset, batch_size=64, sampler=SubsetRandomSampler(split_indices))
    for split_indices in (single_dataset.train_idx, single_dataset.valid_idx)
)

In [None]:
## Two networks sharing the dataset's normalisation factors.
## Presumably the positional args are (n_in, n_out, x_renorm, y_renorm); confirm against base_model.
renorm_factors = (single_dataset.x_renorm, single_dataset.y_renorm)
model_theta = base_model.AndrewCNN(1, 1, *renorm_factors)
model_beta = base_model.AndrewCNN(2, 1, *renorm_factors)

for net in (model_theta, model_beta):
    net.to(device)

wandb.watch([model_theta, model_beta], log_freq=1)

In [None]:
# optimizer parameters: one AdamW over both networks, LR cut by 0.3x after 10 stalled steps
beta1, beta2 = 0.5, 0.999
lr, wd = 0.01, 0.05

trainable = [*model_theta.parameters(), *model_beta.parameters()]
optimizer = torch.optim.AdamW(trainable, lr=lr, weight_decay=wd, betas=(beta1, beta2))
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=10)

criterion = nn.MSELoss()

In [None]:
def _two_stage_losses(x_batch, y_batch):
    """Run theta then beta on one batch; return (theta_loss, beta_loss, total_loss).

    x_batch channels are ordered [pv, dqbar_dt, s], as built by SingleStepDataset.
    theta maps PV to an estimate of S; beta maps (PV, F + S_hat) to q_{i+1}.
    """
    pv = x_batch[:, 0:1]
    tendency = x_batch[:, 1:2]
    s_true = x_batch[:, 2:3]

    s_hat = model_theta(pv)                    ## First network: PV -> S_hat
    f_plus_s_hat = s_hat + tendency            ## F+\hat{S}, the tendency + estimated forcing
    q_next_hat = model_beta(torch.cat((pv, f_plus_s_hat), 1))  ## Second network

    loss_theta = criterion(s_hat, s_true)
    loss_beta = criterion(q_next_hat, y_batch)
    return loss_theta, loss_beta, loss_theta + loss_beta


for epoch in range(150):  # loop over the dataset multiple times

    train_running_loss = 0.0
    train_samples = 0
    valid_running_loss = 0.0
    valid_samples = 0

    model_theta.train()
    model_beta.train()
    for x_data, y_data in train_loader:
        x_data = x_data.to(device)
        y_data = y_data.to(device)

        ## zero the parameter gradients
        optimizer.zero_grad()
        _, _, loss = _two_stage_losses(x_data, y_data)
        loss.backward()
        optimizer.step()

        ## Track loss for wandb: .item() drops the graph; weighting by batch size
        ## makes the epoch value a per-sample mean
        batch_size = x_data.shape[0]
        train_running_loss += loss.item() * batch_size
        train_samples += batch_size

    model_theta.eval()
    model_beta.eval()
    with torch.no_grad():
        for x_data, y_data in valid_loader:
            x_data = x_data.to(device)
            y_data = y_data.to(device)
            _, _, val_loss = _two_stage_losses(x_data, y_data)

            batch_size = x_data.shape[0]
            valid_running_loss += val_loss.item() * batch_size
            valid_samples += batch_size

    ## max(..., 1) guards against an empty split
    train_loss = train_running_loss / max(train_samples, 1)
    valid_loss = valid_running_loss / max(valid_samples, 1)

    ## ReduceLROnPlateau only acts when stepped with the monitored metric
    scheduler.step(valid_loss)

    log_dic = {}
    log_dic["training_loss"] = train_loss
    log_dic["valid_loss"] = valid_loss
    wandb.log(log_dic)

    # verbose
    print('%03d %.3e %.3e ' % (epoch, train_loss, valid_loss))

In [None]:
## Grab one training batch for interactive checks and move it to the models' device,
## so the forward passes below do not mix CPU and GPU tensors
x_train, y_train = next(iter(train_loader))
x_train, y_train = x_train.to(device), y_train.to(device)

In [None]:
x_train.size()  # display the batch shape

In [None]:
x_train[:, 1:2, :, :].shape  # single channel kept as its own axis

In [None]:
output_theta = model_theta(x_train[:, 0:1, :, :])  # theta takes PV (channel 0) and estimates S

In [None]:
output_theta.size()  # display theta output shape

In [None]:
F_plus_shat = output_theta + x_train[:, 1:2, :, :]  # tendency (channel 1) + estimated forcing

In [None]:
F_plus_shat.size()  # display F + S_hat shape

In [None]:
torch.cat((x_train[:, 0:1, :, :], F_plus_shat), 1).shape  # beta input: [PV, F + S_hat]

In [None]:
## Keep the prediction itself in output_beta (not its shape) and display the shape separately
output_beta = model_beta(torch.cat((x_train[:, 0:1, :, :], F_plus_shat), 1))
output_beta.shape

In [None]:
# (empty cell: a stray `c` here raised NameError when running the notebook top to bottom)

In [None]:
## Single-step MSE of beta's prediction against q_{i+1} for the sampled batch
beta_input = torch.cat((x_train[:, 0:1, :, :], F_plus_shat), 1)
loss = criterion(model_beta(beta_input), y_train)

In [None]:
## Reuse one WandbLogger instead of creating a second, unused one
logger = WandbLogger()
trainer = pl.Trainer(
    default_root_dir="/scratch/cp3759/pyqg_data/models",
    accelerator="auto",
    max_epochs=150,
    callbacks=[pbar.ProgressBar()],
    logger=logger,
)

# NOTE(review): `model` is not defined in this notebook, and trainer.fit needs a
# LightningModule. model_theta/model_beta are trained by the manual loop above.
# Confirm which model this cell is meant to fit.
trainer.fit(model, train_loader, valid_loader)

In [None]:
# NOTE(review): `model` is not defined in this notebook; the manually trained
# networks are model_theta and model_beta. Confirm which weights belong here.
checkpoint_path = '/scratch/cp3759/pyqg_data/models/cnn_1step_upper.torch'
weights = model.state_dict()
torch.save(weights, checkpoint_path)

In [None]:
def _channel_tensor(field):
    """Copy an xarray field into a tensor with a singleton channel axis at dim 1."""
    return torch.tensor(field.to_numpy()).unsqueeze(1)


pv, dqbar_dt, s = (_channel_tensor(f) for f in (data_q, data_dqbar, data_forcing))

In [None]:
# Rolling by +1 along dim 0 puts pv[i-1] at index i (the last snapshot wraps to index 0)
pv_plusone = pv.roll(1, dims=0)

In [None]:
## Pair q_i with q_{i+1} by slicing the same tensor. Take the target before trimming pv:
## the rolled tensor sliced with [1:] would reproduce pv[:-1], i.e. q_i itself.
pv_plusone=pv[1:, :, :, :]
pv=pv[:-1, :, :, :]

In [None]:
# Drop the final snapshot from the forcing and tendency so they line up with pv
s, dqbar_dt = s[:-1, :, :, :], dqbar_dt[:-1, :, :, :]

In [None]:
torch.cat([pv, dqbar_dt, s], dim=1).shape  # channels ordered [pv, dqbar_dt, s]