# Imports and Data

In [21]:
# Third-party

import xarray as xr

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger

import wandb

# Local imports

from weather_data_class_v3 import WeatherData

ds = xr.open_dataset('data_850/2020to2022_coarsened.nc')
ds.load()


# Model Setup

In [22]:
class SimpleMLP(L.LightningModule):
    def __init__(self, input_size, forcing_size, 
                 output_size, 
                 lr = 0.0001, 
                 steps = 1, 
                 lat = 16, 
                 lon = 32,
                 weigths = False,
                 intervals = 1):
        
        super(SimpleMLP, self).__init__()

        self.save_hyperparameters()

        self.fc1 = nn.Linear(input_size + forcing_size, 128)  
        self.fc2 = nn.Linear(128, 64) 
        self.fc3 = nn.Linear(64, output_size) 

        self.loss_fn = nn.MSELoss()

        self.steps = steps

        self.lat = lat
        self.lon = lon

        self.lr = lr

        self.intervals = intervals
        
        if weigths:
            self.apply(self.init_weights)

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)  
            if m.bias is not None:
                nn.init.constant_(m.bias, 0) 

    def forward(self, X, F):

        batch_size = X.size(0)
        X = X.reshape(batch_size, -1)  

        inputs = torch.cat((X, F), dim=1)  

        x = torch.relu(self.fc1(inputs)) 
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)  

        return x.reshape(-1, 1, self.lat, self.lon)
    
    def training_step(self, batch, batch_idx):
        x, F, y = batch  

        loss = self.auto_rollout(x, F, y)

        self.log("train_loss", loss, on_step=True, on_epoch=True) 
        return loss

    def validation_step(self, batch, batch_idx):
        x, F, y = batch  
 
        loss = self.auto_rollout(x, F, y)

        self.log("val_loss", loss, on_step=False, on_epoch=True) 
        
        return loss
    
    def auto_rollout(self, x, F, y):

        cumulative_loss = 0.0
        current_input = x.clone()
        current_F = F.clone()

        for step in range(self.steps): 
            # print('Step: ', step)
            
            y_hat = self(current_input, current_F)

            loss = self.loss_fn(y_hat, y[:, step].reshape(-1, 1, self.lat, self.lon))
            cumulative_loss += loss  
            
            current_input = torch.cat((current_input[:, 1:], y_hat), dim=1)

            hour = current_F[:, 0]
            month = current_F[:, 1]
            
            hour = (hour + self.intervals) % 24

            current_F = torch.stack((hour, month), dim=1).float()

        return cumulative_loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr= self.lr)

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, verbose=True)
        
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}
    

from UNetPyTorch import *

class build_res_unet_time_(L.LightningModule):
    def __init__(self, in_c, out_c, start_size = 64, lr = 0.0001, steps = 1, intervals = 1, lat = 16, lon = 32, weights = False):
        super(build_res_unet_time_, self).__init__()

        self.save_hyperparameters()

        self.filename = f"unet_in{in_c}_out{out_c}_size{start_size}_lr{lr}_steps{steps}_intervals{intervals}_lat{lat}_lon{lon}_weights{weights}.pth"

        self.steps = steps
        self.intervals = intervals

        self.lat = lat
        self.lon = lon

        self.lr = lr

        self.loss_fn = nn.MSELoss()

        ''' Encoder 1 '''
        self.c11 = nn.Conv2d(in_c + 2, start_size, kernel_size=3, padding=1)  # in_c + 2 to account for time inputs
        self.br1 = batchnorm_relu(start_size)
        self.c12 = nn.Conv2d(start_size, start_size, kernel_size=3, padding=1)
        self.c13 = nn.Conv2d(in_c + 2, start_size, kernel_size=1, padding=0)  # Shortcut feature

        """ Encoder 2 and 3 """
        self.r2 = residual_block(start_size, start_size * 2, stride=2)
        self.r3 = residual_block(start_size * 2, start_size * 4, stride=2)

        """ Bridge """
        self.r4 = residual_block(start_size * 4, start_size * 8, stride=2)

        """ Decoder """
        self.d1 = decoder_block(start_size * 8, start_size * 4)
        self.d2 = decoder_block(start_size * 4, start_size * 2)
        self.d3 = decoder_block(start_size * 2, start_size)

        """ Output """
        self.output = nn.Conv2d(start_size, out_c, kernel_size=1, padding=0)
        # self.sigmoid = nn.Sigmoid()

        if weights:
            self.apply(self.init_weights)

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)  
            if m.bias is not None:
                nn.init.constant_(m.bias, 0) 

    def forward(self, weather_data, time_input):
        # weather_data shape: (batch_size, window_size, lat, lon)
        # time_input shape: (batch_size, window_size, 2) where 2 corresponds to (hour of day, month of year)

        batch_size, window_size, lat, lon = weather_data.shape

        # Reshape weather_data to (batch_size, window_size, 1, lat, lon)
        weather_data = weather_data.view(batch_size, window_size, lat, lon)

        # Expand time_input to match the spatial dimensions of the weather data
        time_input = time_input.view(batch_size, 2)  # (batch_size, window_size, 2)

        # Expand to match spatial dimensions (lat, lon) and channel 2 for (hour, month)
        time_input_expanded = time_input.unsqueeze(-1).unsqueeze(-1)  # Add spatial dimensions
        time_input_expanded = time_input_expanded.expand(-1, -1, lat, lon)  # Now shape is (batch_size, 2, lat, lon)

        # Concatenate along the channel axis
        combined_input = torch.cat([weather_data, time_input_expanded], dim=1)  # New shape: (batch_size, in_c + 2, lat, lon)

        """ Encoder 1 """
        x = self.c11(combined_input)
        x = self.br1(x)
        x = self.c12(x)
        s = self.c13(combined_input)
        skip1 = x + s

        """ Encoder 2 and 3 """
        skip2 = self.r2(skip1)
        skip3 = self.r3(skip2)

        """ Bridge """
        b = self.r4(skip3)

        """ Decoder """
        d1 = self.d1(b, skip3)
        d2 = self.d2(d1, skip2)
        d3 = self.d3(d2, skip1)

        """ Output """
        output = self.output(d3)
        # output = self.sigmoid(output)

        return output
    
    def training_step(self, batch, batch_idx):
        x, F, y = batch  

        loss = self.auto_rollout(x, F, y)

        self.log("train_loss", loss, on_step=True, on_epoch=True) 
        return loss

    def validation_step(self, batch, batch_idx):
        x, F, y = batch  
 
        loss = self.auto_rollout(x, F, y)

        self.log("val_loss", loss, on_step=False, on_epoch=True) 
        
        return loss
    
    def auto_rollout(self, x, F, y):

        cumulative_loss = 0.0
        current_input = x.clone()
        current_F = F.clone()

        for step in range(self.steps): 
            # print('Step: ', step)
            
            y_hat = self(current_input, current_F)

            loss = self.loss_fn(y_hat, y[:, step].reshape(-1, 1, self.lat, self.lon))
            cumulative_loss += loss  
            
            current_input = torch.cat((current_input[:, 1:], y_hat), dim=1)

            hour = current_F[:, 0]
            month = current_F[:, 1]
            
            hour = (hour + self.intervals) % 24

            current_F = torch.stack((hour, month), dim=1).float()

        return cumulative_loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr= self.lr)

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, verbose=True)
        
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}
    

# Training

In [23]:
window_size = 5
steps = 13
batch_size = 32
intervals = 3

model_class = WeatherData(ds, window_size=window_size, steps=steps, use_forcings=True, intervals=intervals)
model_class.assign_model(build_res_unet_time_(model_class.window_size, 1))


Data split
Data normalized


In [24]:
model_class.test_class()

self.X_train: (18938, 16, 32) self.X_val: (4735, 16, 32) self.X_test: (2631, 16, 32)
self.F_train: (18938, 2) self.F_val: (4735, 2) self.F_test: (2631, 2)
self.X_train_t: torch.Size([18938, 16, 32]) self.X_val_t: torch.Size([4735, 16, 32]) self.X_test_t: torch.Size([2631, 16, 32])
self.F_train_t: torch.Size([18938, 2]) self.F_val_t: torch.Size([4735, 2]) self.F_test_t: torch.Size([2631, 2])
self.input_size: 2560 self.forcing_size: 2 self.output_size: 512


In [25]:
ar_steps = 4

# model = SimpleMLP(
#         model_class.input_size, 
#         model_class.forcing_size, 
#         model_class.output_size,
#         steps=ar_steps,
#         weigths=True,
#         intervals=intervals
#     )

# log_name = f"SimpleMLP_{ar_steps}_steps_1"

model = build_res_unet_time_(in_c=model_class.window_size, 
                             out_c = 1, 
                             start_size=16,
                             steps=ar_steps, 
                             intervals=intervals,
                             weights=True)

log_name = f"UNet_{ar_steps}_4steps_4"


train_loader = DataLoader(WeatherData(ds, 
                                    window_size=window_size, 
                                    steps=ar_steps, 
                                    intervals=intervals, 
                                    data_split='train'), 
                                    batch_size=batch_size, 
                                    shuffle=True,
                                    persistent_workers=True,
                                    num_workers=2,
                                    drop_last=True)

val_loader = DataLoader(WeatherData(ds,
                                    window_size=window_size,
                                    steps=ar_steps,
                                    intervals=intervals,
                                    data_split='val'), 
                                    batch_size=batch_size, 
                                    shuffle=False,
                                    persistent_workers=True,
                                    num_workers=2,
                                    drop_last=True)

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",      
    dirpath="checkpoints/",  
    filename=log_name,
    save_top_k=1,            
    mode="min"              
)

early_stopping_callback = EarlyStopping(
    monitor="val_loss",      
    patience=15,             
    mode="min",              
    verbose=True
)

wandb_logger = WandbLogger(project="weather-forecasting", name=log_name)

# import torch

# # Check if GPU is available
# if torch.cuda.is_available():
#     device = torch.device("cuda")
#     # Get GPU properties
#     gpu_stats = torch.cuda.get_device_properties(device)
#     print(f"GPU Name: {gpu_stats.name}")
#     print(f"Total Memory: {gpu_stats.total_memory / (1024 ** 2)} MB")
    
#     # Monitor utilization
#     while True:
#         utilization = torch.cuda.utilization()
#         print(f"GPU Utilization: {utilization}%")
#         if utilization > 90:  # Adjust the threshold as needed
#             print("Reducing workload...")
#             # Add logic to reduce batch size or adjust learning rate here
#             break
# else:
#     print("CUDA is not available.")


trainer = Trainer(
    max_epochs=300,
    callbacks=[checkpoint_callback, early_stopping_callback],
    check_val_every_n_epoch=1,
    logger=wandb_logger
)


Data split
Data normalized


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Data split
Data normalized


In [20]:
trainer.fit(model, train_loader, val_loader)
wandb.finish()

c:\Users\divanvdb\AppData\Local\Programs\Python\Python311\Lib\site-packages\lightning\pytorch\loggers\wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
c:\Users\divanvdb\AppData\Local\Programs\Python\Python311\Lib\site-packages\lightning\pytorch\callbacks\model_checkpoint.py:654: Checkpoint directory C:\Users\divanvdb\Documents\GitHub\1_WindSpeedForecasting\checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name    | Type           | Params | Mode 
----------------------------------------------------
0  | loss_fn | MSELoss        | 0      | train
1  | c11     | Conv2d         | 1.0 K  | train
2  | br1     | batchnorm_relu | 32     | train
3  | c12     | Conv2d         | 2.3 K  | train
4  | c13     | Conv2d         | 128    | train
5  | r2      | residual_block | 14.5 K | train
6  | r3      | residual_b

                                                                           

In [7]:
model_class = WeatherData(ds, window_size=15, steps=13, use_forcings=True, intervals=3)

model = SimpleMLP(
        model_class.input_size, 
        model_class.forcing_size, 
        model_class.output_size,
    )

Data split
Data normalized


In [18]:


model = build_res_unet_time_.load_from_checkpoint('checkpoints/UNet_1_steps_2.ckpt', 
                            in_c=model_class.window_size, 
                             out_c = 1, 
                             start_size=16,
                             steps=ar_steps, 
                             intervals=intervals)

model_class.assign_model(model.to('cpu'))


model_class.plot_pred_target(seed=0, frame_rate=8, levels=10)

  F = torch.tensor(F).float()


# Testing and development

In [43]:
import torch
from torch.utils.data import Dataset, DataLoader

class WindSpeedDataset(Dataset):
    def __init__(self, data, window_size=3, steps=2, interval=2):
        self.data = data
        self.window_size = window_size
        self.steps = steps
        self.interval = interval

    def __len__(self):
        # Calculate the number of batches
        return len(self.data) - (self.window_size + self.steps) * self.interval + self.interval

    def __getitem__(self, idx):
        # Calculate starting index for the window

        x = self.data[idx:idx + self.window_size * self.interval:self.interval]
        y = self.data[idx + self.window_size * self.interval:idx + self.window_size * self.interval + self.steps * self.interval:self.interval]
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

# Example data
data = torch.arange(0, 25)

# Create dataset and dataloader
dataset = WindSpeedDataset(data, window_size=5, steps=5, interval=2)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# Display the batches
for batch_idx, (x, y) in enumerate(dataloader):
    print(f"Batch {batch_idx+1}: x as {x.numpy().flatten()} and y as {y.numpy().flatten()}")


Batch 1: x as [ 2.  4.  6.  8. 10.] and y as [12. 14. 16. 18. 20.]
Batch 2: x as [ 6.  8. 10. 12. 14.] and y as [16. 18. 20. 22. 24.]
Batch 3: x as [1. 3. 5. 7. 9.] and y as [11. 13. 15. 17. 19.]
Batch 4: x as [ 3.  5.  7.  9. 11.] and y as [13. 15. 17. 19. 21.]
Batch 5: x as [ 5.  7.  9. 11. 13.] and y as [15. 17. 19. 21. 23.]
Batch 6: x as [0. 2. 4. 6. 8.] and y as [10. 12. 14. 16. 18.]
Batch 7: x as [ 4.  6.  8. 10. 12.] and y as [14. 16. 18. 20. 22.]


  return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)
