## Simple example - Temperature Only Analysis
This notebook shows how to train a model on temperature data only.

In [28]:
# Imports
import sys
import matplotlib.pyplot as plt
import cartopy
import torch

sys.path.append('../src/')
from DatasetUS import *
# from TrainDiffusion import *  # Disabled: temporarily using the U-Net for testing
# import Network  # Import Network module for EDMPrecond
from TrainUnet import *  # U-Net training utilities, used to test the weekly data
import Network  # Import Network module for UNet



This example can run on a laptop, but because we train on only a small subset of the data, the network will not be trained well.

In [29]:
## Select years to train and validate
train_year_start = 1953
train_year_end = 1955

valid_year_start = 1956
valid_year_end = 1957

Set up the training hyperparameters. We run for only 2 epochs, and we train on the GPU.
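The next cell sets `batch_size = 4` and `accum = 8`. Assuming `accum` is the number of mini-batches whose gradients are accumulated before each optimiser update (the `train_step` implementation in `TrainUnet` is not shown here), the arithmetic below gives the effective batch size and the number of updates per epoch. The helper names are hypothetical, and the "one extra step for leftover batches" rule is an assumption; if `train_step` drops the remainder, the count would be 22 instead of 23.

```python
import math

def effective_batch_size(batch_size: int, accum: int) -> int:
    """Samples that contribute to each optimiser update when gradients
    are accumulated over `accum` mini-batches (hypothetical helper)."""
    return batch_size * accum

def optimiser_steps_per_epoch(num_batches: int, accum: int) -> int:
    """Optimiser updates per epoch, assuming one step every `accum`
    batches plus a final step for any leftover batches."""
    return math.ceil(num_batches / accum)

# Values from this notebook: batch_size = 4, accum = 8, and the training
# DataLoader reports 180 batches per epoch.
print(effective_batch_size(4, 8))         # 32
print(optimiser_steps_per_epoch(180, 8))  # 23
```

With these settings, each update sees 32 samples even though only 4 fit in GPU memory at a time.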

In [30]:
## Select hyperparameters of training
batch_size = 4  # Reduced from 16 to fit RTX 4050 (6GB VRAM)
learning_rate = 1e-4
accum = 8

# Run training for small number of epochs 
num_epochs = 2        

# Define device
torch.cuda.set_device(0)  # Use the first CUDA GPU (index 0)
device = 'cuda'
# For a CPU fallback: device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Define the ML model: a U-Net, temporarily used instead of the diffusion model to test the weekly data
unet_model = UNet((256, 128), 3, 1, label_dim=2, use_diffuse=False)
unet_model.to(device)
# diffusion_model = Network.EDMPrecond((256, 128), 3, 1, label_dim=2)  # Disabled while testing the U-Net

# define the datasets
datadir = "../data/"
dataset_train = UpscaleDataset(datadir, year_start=train_year_start, year_end=train_year_end,
                               constant_variables=["lsm", "z"])

dataset_test = UpscaleDataset(datadir, year_start=valid_year_start, year_end=valid_year_end,
                              constant_variables=["lsm", "z"])

dataloader_train = torch.utils.data.DataLoader(
    dataset_train, batch_size=batch_size, shuffle=True, num_workers=4)
dataloader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=batch_size, shuffle=True, num_workers=4)

Opening files
All files accessed. Creating tensors
torch.Size([1])
tensor([259.5524]) tensor([310.5276])
Opening constant variables file (e.g. land-sea mask, topography)
Normalize z
Mean:4599.646526826994, Std6220.799692544967
Dataset initialized.
Opening files
All files accessed. Creating tensors
torch.Size([1])
tensor([259.5524]) tensor([310.5276])
Opening constant variables file (e.g. land-sea mask, topography)
Normalize z
Mean:4599.646526826994, Std6220.799692544967
Dataset initialized.
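The log above shows that each dataset loads the constant fields (land-sea mask `lsm` and topography `z`) and normalises `z`, and the next cell confirms that the inputs have 3 channels. The internals of `UpscaleDataset` are not shown, so the following is only a hypothetical sketch of how such an input tensor could be assembled: it assumes `z` is z-score normalised, `lsm` is left as-is, and the channel order is temperature, `lsm`, `z`. Random tensors stand in for the real fields.

```python
import torch

def zscore(field: torch.Tensor) -> torch.Tensor:
    """Normalise a field to zero mean and unit (population) standard deviation."""
    return (field - field.mean()) / field.std(unbiased=False)

def stack_inputs(temp: torch.Tensor, lsm: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Combine a batch of temperature fields (B, 1, H, W) with two static
    (H, W) fields into one (B, 3, H, W) input tensor (hypothetical helper)."""
    batch_size = temp.shape[0]
    # expand() adds the batch and channel dimensions without copying data
    lsm_b = lsm.expand(batch_size, 1, *lsm.shape)
    z_b = zscore(z).expand(batch_size, 1, *z.shape)
    # Assumed channel order: temperature, land-sea mask, normalised topography
    return torch.cat([temp, lsm_b, z_b], dim=1)

# Random stand-ins for the real fields, on the notebook's 128 x 256 grid
temp = torch.randn(4, 1, 128, 256)
lsm = (torch.rand(128, 256) > 0.7).float()
z = torch.rand(128, 256) * 5000.0
inputs = stack_inputs(temp, lsm, z)
print(inputs.shape)  # torch.Size([4, 3, 128, 256])
```

Normalising `z` matters because raw topography values are in the thousands (the log reports a mean of about 4600), while the land-sea mask is already between 0 and 1.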


In [31]:
print(len(dataloader_train), len(dataloader_test))

# Debug: check the actual data shapes
batch = next(iter(dataloader_train))
print("Input shape:", batch["inputs"].shape)
print("Target shape:", batch["targets"].shape)
print("Expected input channels: 3 (1 temp + 2 constants)")
print("Actual input channels:", batch["inputs"].shape[1])
print("Model expects:", 3, "channels")

# Check which channels we actually have
print("\nDataset configuration:")
print("- Temperature variables:", dataset_train.varnames)
print("- Number of temperature variables:", dataset_train.n_var)
print("- Constant variables:", ["lsm", "z"])
print("- Expected total channels:", dataset_train.n_var + 2, "(temp + constants)")

# WEEKLY DATA SUBSAMPLING INFO
# Note: the sample counts below are nominal (52 samples/year). The actual count is
# len(dataset_train); 180 batches of size 4 means 717 to 720 training samples.
print("\nWeekly Data Subsampling:")
print("- Original data: Daily (365 days/year)")
print("- Weekly data: Every 7th day (52 days/year)")
print("- Data reduction: 86% (365 → 52 days)")
print("- Expected training speed: 7x faster")
print("- Training samples (weekly):", (train_year_end - train_year_start + 1) * 52)
print("- Validation samples (weekly):", (valid_year_end - valid_year_start + 1) * 52)

180 90
Input shape: torch.Size([4, 3, 128, 256])
Target shape: torch.Size([4, 1, 128, 256])
Expected input channels: 3 (1 temp + 2 constants)
Actual input channels: 3
Model expects: 3 channels

Dataset configuration:
- Temperature variables: ['temp']
- Number of temperature variables: 1
- Constant variables: ['lsm', 'z']
- Expected total channels: 3 (temp + constants)

Weekly Data Subsampling:
- Original data: Daily (365 days/year)
- Weekly data: Every 7th day (52 days/year)
- Data reduction: 86% (365 → 52 days)
- Expected training speed: 7x faster
- Training samples (weekly): 156
- Validation samples (weekly): 104
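The printed subsampling figures assume exactly 52 samples per year. How the dataset actually picks its weekly days is not shown in this notebook, so the sketch below (hypothetical helpers) only illustrates that the count depends on the starting offset: taking every 7th day from index 0 keeps 53 days of a 365-day year, while starting from index 6 (the 7th day) keeps 52. It also reproduces the "86%" reduction.

```python
def weekly_indices(n_days: int, offset: int = 0, stride: int = 7) -> list[int]:
    """Day indices kept when taking every `stride`-th day from `offset`."""
    return list(range(offset, n_days, stride))

def reduction(n_days: int, n_kept: int) -> float:
    """Fraction of samples removed by subsampling."""
    return 1.0 - n_kept / n_days

# Starting at index 0 keeps days 0, 7, ..., 364 (53 days);
# starting at index 6 keeps days 6, 13, ..., 363 (52 days).
print(len(weekly_indices(365, offset=0)))  # 53
print(len(weekly_indices(365, offset=6)))  # 52
print(f"{reduction(365, 52):.1%}")         # 85.8%
```

Comparing these nominal counts with `len(dataset_train)` is a quick way to confirm which subsampling the dataset really applies.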


In [32]:
from torch.utils.tensorboard import SummaryWriter

# Release cached GPU memory
torch.cuda.empty_cache()

# Check GPU memory
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"GPU Memory Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
print(f"GPU Memory Cached: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")

# Gradient scaler for mixed-precision training (torch.amp API in newer PyTorch)
scaler = torch.amp.GradScaler('cuda')

# Define the optimiser for the U-Net (the diffusion model is temporarily disabled)
optimiser = torch.optim.AdamW(unet_model.parameters(), lr=learning_rate)
# optimiser = torch.optim.AdamW(diffusion_model.parameters(), lr=learning_rate)  # Diffusion model disabled

# Define the TensorBoard writer
writer = SummaryWriter("./runs_unet")
# writer = SummaryWriter("./runs_diffusion")  # Diffusion model disabled

# Define the loss function: MSE loss for the U-Net
loss_fn = torch.nn.MSELoss()
# loss_fn = EDMLoss()  # Diffusion model disabled

# Record the training loss of each epoch
losses = []

GPU Memory: 6.4 GB
GPU Memory Allocated: 3.66 GB
GPU Memory Cached: 4.16 GB


Start the training loop, using the U-Net on the weekly data. After each epoch, the generated plots show the coarse-resolution input, the prediction, and the truth for a few samples (here temperature is the only variable). At the start of training, the first two columns (coarse resolution and prediction) look similar. Towards the end of training, the last two columns (prediction and truth) should look similar.
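The loop below logs two numbers per epoch, `Error/base` and `Error/pred`, returned by `sample_model`. Its metric is not shown here; as a hypothetical sketch, assume it is the RMSE of the coarse input against the truth (the baseline) and of the prediction against the truth. Because inputs and targets share the same 128 x 256 grid, the coarse field can be compared with the truth directly. Training is working when `pred_error` drops below `base_error`.

```python
import torch

def rmse(a: torch.Tensor, b: torch.Tensor) -> float:
    """Root-mean-square error between two fields of the same shape."""
    return torch.sqrt(torch.mean((a - b) ** 2)).item()

def compare_to_baseline(coarse, pred, truth):
    """Return (base_error, pred_error): distance of the coarse input and of
    the model prediction from the truth (hypothetical helper)."""
    return rmse(coarse, truth), rmse(pred, truth)

# Toy fields: the baseline is off by 2 everywhere, the prediction by 0.5
truth = torch.zeros(1, 1, 4, 4)
coarse = torch.full((1, 1, 4, 4), 2.0)
pred = torch.full((1, 1, 4, 4), 0.5)
base_error, pred_error = compare_to_baseline(coarse, pred, truth)
print(base_error, pred_error)  # 2.0 0.5
```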

In [None]:
for step in range(num_epochs):
    # Clear GPU memory before each epoch
    torch.cuda.empty_cache()
    
    # Use U-Net training step for weekly data testing
    epoch_loss = train_step(
        unet_model, loss_fn, dataloader_train, optimiser,
        scaler, step, accum, writer, device=device)
    losses.append(epoch_loss)

    # Use U-Net sampling for weekly data testing
    (fig, ax), (base_error, pred_error) = sample_model(
        unet_model, dataloader_test, device=device)
    plt.show()

    writer.add_scalar("Error/base", base_error, step)
    writer.add_scalar("Error/pred", pred_error, step)
    
    # Clear memory after each epoch
    torch.cuda.empty_cache()


Train :: Epoch: 0:  47%|████▋     | 84/180 [26:53<1:19:06, 49.44s/it, Loss: 0.7798]
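The progress bar above can be unpacked with a little arithmetic. By default, tqdm displays a smoothed recent rate and bases its remaining-time estimate on it, so the average rate so far can differ from the displayed one. The hypothetical helper below converts tqdm's time strings to seconds and compares the two rates for this run.

```python
def parse_hms(text: str) -> int:
    """Convert a tqdm time string such as 'MM:SS' or 'H:MM:SS' to seconds."""
    seconds = 0
    for part in text.split(":"):
        seconds = seconds * 60 + int(part)
    return seconds

done, total = 84, 180
elapsed = parse_hms("26:53")      # 1613 s
remaining = parse_hms("1:19:06")  # 4746 s
print(round(elapsed / done, 1))                # 19.2: average s/it so far
print(round(remaining / (total - done), 2))    # 49.44: recent s/it behind the ETA
print(round((elapsed + remaining) / 3600, 2))  # 1.77: estimated hours for this epoch
```

Here the recent rate (about 49 s/it) is more than twice the average so far (about 19 s/it), so the iterations have slowed down during the epoch.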