# Imports and Data

In [1]:
# Third-party

import xarray as xr

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch import Trainer

# Local imports

from weather_data_class_v1 import WeatherData

# load_dataset reads the whole file into memory and then closes the file
# handle, whereas open_dataset(...) followed by .load() leaves it open.
ds = xr.load_dataset('data_850/2022_850_SA_coarsen.nc')
ds


Cannot find the ecCodes library


# Model Setup

In [2]:
class SimpleMLP(L.LightningModule):
    """Three-layer MLP trained with an autoregressive rollout loss.

    The input window is flattened, concatenated with the forcing vector,
    passed through two ReLU hidden layers (128 and 64 units) and projected
    to ``output_size`` values, which ``forward`` reshapes to
    ``(batch, 1, lat, lon)``.

    Args:
        input_size: Number of features of the flattened input window.
        forcing_size: Number of forcing features concatenated to the input.
        output_size: Number of values produced per sample. It has to equal
            ``lat * lon`` for the reshape in ``forward`` to keep the batch
            dimension intact.
        lr: Learning rate handed to AdamW.
        steps: Number of autoregressive steps accumulated into the loss.
        lat: First grid dimension of one predicted field.
        lon: Second grid dimension of one predicted field.
        weigths: If true, re-initialise every linear layer with
            Xavier-uniform weights and zero biases. The spelling is kept
            because it is part of the constructor interface.
        intervals: Amount added to the hour forcing (modulo 24) after each
            rollout step.
    """

    def __init__(self,
                 input_size,
                 forcing_size,
                 output_size,
                 lr = 0.0001,
                 steps = 1,
                 lat = 16,
                 lon = 32,
                 weigths = False,
                 intervals = 1):

        super().__init__()

        # Records the constructor arguments in self.hparams.
        self.save_hyperparameters()

        self.fc1 = nn.Linear(input_size + forcing_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_size)

        self.loss_fn = nn.MSELoss()

        self.steps = steps

        self.lat = lat
        self.lon = lon

        self.lr = lr

        self.intervals = intervals

        if weigths:
            self.apply(self.init_weights)

    def init_weights(self, m):
        """Xavier-initialise ``m`` if it is a linear layer; zero its bias."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, X, F):
        """Predict one field from an input window and its forcings.

        Args:
            X: Input window; everything after the batch dimension is flattened.
            F: Forcings of shape ``(batch, forcing_size)``.

        Returns:
            Tensor of shape ``(batch, 1, self.lat, self.lon)``.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        flat_input = X.reshape(X.size(0), -1)

        inputs = torch.cat((flat_input, F), dim=1)

        x = torch.relu(self.fc1(inputs))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)

        # Use the configured grid size instead of a hard-coded 16 x 32 so the
        # output matches the target reshape performed in auto_rollout.
        return x.view(-1, 1, self.lat, self.lon)

    def training_step(self, batch, batch_idx):
        """Compute and log the rollout loss for one training batch."""
        x, F, y = batch

        loss = self.auto_rollout(x, F, y)

        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the rollout loss for one validation batch."""
        x, F, y = batch

        loss = self.auto_rollout(x, F, y)

        self.log("val_loss", loss, on_step=False, on_epoch=True)

        return loss

    def auto_rollout(self, x, F, y):
        """Sum the MSE over ``self.steps`` autoregressive predictions.

        After each step the oldest frame of the input window is dropped, the
        prediction is appended in its place, and the hour forcing is advanced
        by ``self.intervals`` modulo 24.

        Args:
            x: Initial input window. It has to be
                ``(batch, window, lat, lon)`` so that predictions can be
                concatenated along dimension 1.
            F: Forcings; column 0 is treated as the hour and column 1 as the
                month. Any further columns are dropped after the first step.
            y: Targets, indexed as ``y[:, step]``.

        Returns:
            The sum (not the mean) of the per-step MSE losses.
        """
        cumulative_loss = 0.0
        current_input = x.clone()
        current_F = F.clone()

        for step in range(self.steps):
            y_hat = self(current_input, current_F)

            target = y[:, step].reshape(-1, 1, self.lat, self.lon)
            cumulative_loss += self.loss_fn(y_hat, target)

            # Slide the window: drop the oldest frame, append the prediction.
            current_input = torch.cat((current_input[:, 1:], y_hat), dim=1)

            hour = current_F[:, 0]
            # The month is carried over unchanged; day/month rollover is not
            # modelled when the hour wraps past 24.
            month = current_F[:, 1]

            hour = (hour + self.intervals) % 24

            current_F = torch.stack((hour, month), dim=1).float()

        return cumulative_loss

    def configure_optimizers(self):
        """Return AdamW plus a ReduceLROnPlateau scheduler driven by ``val_loss``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)

        # The `verbose` argument is deprecated in PyTorch's LR schedulers, so
        # it is no longer passed here.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)

        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}

# Training

In [2]:
# Configuration shared by the dataset and the training cells below.
batch_size = 32
intervals = 3
window_size = 15
steps = 13

# Dataset wrapper; its input/forcing/output sizes are used to build the model.
model_class = WeatherData(
    ds,
    window_size=window_size,
    steps=steps,
    use_forcings=True,
    intervals=intervals,
)


Time intervals applied
Data split
Data normalized


In [8]:
ar_steps = 3

# Seed Python, NumPy and PyTorch so that weight initialisation and batch
# shuffling are repeatable across kernel restarts.
SEED = 42
L.seed_everything(SEED, workers=True)

model = SimpleMLP(
    model_class.input_size,
    model_class.forcing_size,
    model_class.output_size,
    steps=ar_steps,
    weigths=True,
    intervals=intervals,
)

# Both splits share the same windowing; only the split name and the shuffling
# differ.
dataset_kwargs = dict(window_size=window_size, steps=ar_steps, intervals=intervals)

train_loader = DataLoader(
    WeatherData(ds, data_split='train', **dataset_kwargs),
    batch_size=batch_size,
    shuffle=True,
)

val_loader = DataLoader(
    WeatherData(ds, data_split='val', **dataset_kwargs),
    batch_size=batch_size,
    shuffle=False,
)

# Keep only the checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath="checkpoints/",
    filename=f"best_model_{ar_steps}_steps",
    save_top_k=1,
    mode="min",
)

# NOTE(review): this patience equals the ReduceLROnPlateau patience set in
# SimpleMLP.configure_optimizers, so training will usually stop before the
# scheduler gets to lower the learning rate. Consider a larger value here.
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    patience=5,
    mode="min",
    verbose=True,
)

trainer = Trainer(
    max_epochs=500,
    callbacks=[checkpoint_callback, early_stopping_callback],
    check_val_every_n_epoch=1,
)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [25]:
# Pull a single training batch to sanity-check the value ranges of the
# inputs, forcings and targets.
batch = next(iter(train_loader))
x, F, y = batch

x.max(), x.min(), F.max(), F.min(), y.max(), y.min()

(tensor(5.6885),
 tensor(-1.5689),
 tensor(21.),
 tensor(0.),
 tensor(6.8905),
 tensor(-1.5646))

In [None]:
# Train the model, validating against val_loader.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [10]:
# Rebuild the dataset wrapper and a fresh network for evaluation. The shared
# configuration constants are reused (instead of repeating 15 / 13 / 3) so the
# evaluation setup cannot drift from the one used for training.
model_class = WeatherData(
    ds,
    window_size=window_size,
    steps=steps,
    use_forcings=True,
    intervals=intervals,
)

# Pass `intervals` explicitly: the default (1) would disagree with the stride
# of the data if the model's own rollout is ever used.
model = SimpleMLP(
    model_class.input_size,
    model_class.forcing_size,
    model_class.output_size,
    intervals=intervals,
)

In [11]:
model_class.assign_model(model)
# Forward slash instead of 'models\MLP_...': "\M" is an invalid escape
# sequence in a normal string literal, and a backslash separator only works
# on Windows, while "/" is accepted on every platform.
model_class.load_model('models/MLP_15to1_s32_multi_3.pth')

seed = 74

model_class.plot_pred_target(seed=seed, frame_rate=4, levels=10)

  F = torch.tensor(F).float()
