# Set up

In [2]:
import random
import numpy as np
import xarray as xr
import torch
from torch.utils.data import DataLoader, TensorDataset
from unet import *
from train import *
from loss import *

# Root URI of the bucket folder that holds the train / val / zca Zarr stores
# opened further down in this cell.
base_path = "gs://leap-persistent/YueWang/SSH/data"

def open_zarr(path):
    """Open the Zarr store at ``path`` with xarray.

    The store is opened with ``consolidated=True``, i.e. xarray is asked to
    read the store's consolidated metadata.

    Args:
        path: Location of the Zarr store (the callers in this notebook pass
            ``gs://`` URIs).

    Returns:
        The object returned by ``xr.open_zarr`` for that store.
    """
    dataset = xr.open_zarr(path, consolidated=True)
    return dataset

# Open the three stores in order (train, val, zca) and pull each one fully
# into memory with .compute() before binding the results.
train, val, zca = (
    open_zarr(store_path).compute()
    for store_path in (
        f"{base_path}/train_80_sst.zarr",
        f"{base_path}/val_80_sst.zarr",
        f"{base_path}/zca_80_sst.zarr",
    )
)

In [4]:
def min_max_normalize(tensor, min_values=None, max_values=None, feature_range=(0, 1)):
    """Min-max scale every channel of ``tensor`` independently.

    Each channel ``c`` (axis 1) is mapped with
    ``(x - min[c]) / (max[c] - min[c]) * (hi - lo) + lo`` where
    ``(lo, hi) = feature_range``.

    Args:
        tensor: Tensor laid out as ``(N, C, ...)``; axis 1 is the channel axis.
            Integer tensors are converted to the default floating dtype first,
            so the scaled values are not truncated back to integers.
        min_values: Optional per-channel minima (one value per channel). When
            ``None`` they are computed from ``tensor``. Pass the values returned
            for the training set to scale validation/test data consistently.
        max_values: Optional per-channel maxima, handled like ``min_values``.
        feature_range: ``(lo, hi)`` target range of the scaled values.

    Returns:
        A tuple ``(normalized_tensor, min_values, max_values)``. Statistics that
        were passed in are returned unchanged; computed ones are 1-D tensors of
        length ``C`` with the same floating dtype and device as the data.

    Raises:
        ValueError: If supplied statistics do not hold exactly one value per
            channel.
    """
    # Writing float results into an integer buffer would truncate them, so
    # promote non-floating inputs before doing any arithmetic.
    if tensor.is_floating_point():
        work = tensor
    else:
        work = tensor.to(torch.get_default_dtype())

    num_channels = work.shape[1]
    # Reduce over every axis except the channel axis.
    reduce_dims = tuple(d for d in range(work.dim()) if d != 1)

    if min_values is None:
        min_values = work.amin(dim=reduce_dims)
    if max_values is None:
        max_values = work.amax(dim=reduce_dims)

    lo_stat = torch.as_tensor(min_values, dtype=work.dtype, device=work.device)
    hi_stat = torch.as_tensor(max_values, dtype=work.dtype, device=work.device)
    if lo_stat.numel() != num_channels or hi_stat.numel() != num_channels:
        raise ValueError(
            f"min_values/max_values must provide one value per channel "
            f"({num_channels}), got {lo_stat.numel()} and {hi_stat.numel()}"
        )

    # Shape (1, C, 1, ..., 1) so the statistics broadcast against the data.
    stat_shape = (1, num_channels) + (1,) * (work.dim() - 2)
    lo_stat = lo_stat.reshape(stat_shape)
    hi_stat = hi_stat.reshape(stat_shape)

    channel_range = hi_stat - lo_stat
    # Channels whose min equals max cannot be scaled; they are set to the lower
    # bound of feature_range below instead of dividing by zero.
    degenerate = channel_range == 0
    safe_range = torch.where(degenerate, torch.ones_like(channel_range), channel_range)

    scale = feature_range[1] - feature_range[0]
    normalized_tensor = (work - lo_stat) / safe_range * scale + feature_range[0]
    normalized_tensor = normalized_tensor.masked_fill(degenerate, feature_range[0])

    return normalized_tensor, min_values, max_values

In [5]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _to_channel(field):
    """Convert one dataset variable to a float32 tensor on `device`, inserting a size-1 axis at dim 1."""
    return torch.from_numpy(field.values).float().unsqueeze(1).to(device)


# --- Training inputs: ssh and sst concatenated along dim 1, then min-max scaled ---
x_train_channel_0 = _to_channel(train.ssh)
x_train_channel_1 = _to_channel(train.sst)
x_train = torch.cat((x_train_channel_0, x_train_channel_1), dim=1)
x_train_normalized, min_values, max_values = min_max_normalize(x_train)

# --- Training targets: ubm and zca_ubm concatenated along dim 1 (not normalized here) ---
y_train_channel_0 = _to_channel(train.ubm)
y_train_channel_1 = _to_channel(train.zca_ubm)
y_train = torch.cat((y_train_channel_0, y_train_channel_1), dim=1)

# --- Validation inputs: scaled with the min/max obtained from the training inputs ---
x_val_channel_0 = _to_channel(val.ssh)
x_val_channel_1 = _to_channel(val.sst)
x_val = torch.cat((x_val_channel_0, x_val_channel_1), dim=1)
x_val_normalized, _, _ = min_max_normalize(
    x_val, min_values=min_values, max_values=max_values
)

# --- Validation targets: same two variables as the training targets ---
y_val_channel_0 = _to_channel(val.ubm)
y_val_channel_1 = _to_channel(val.zca_ubm)
y_val = torch.cat((y_val_channel_0, y_val_channel_1), dim=1)

# Pair inputs with targets.
train_dataset = TensorDataset(x_train_normalized, y_train)
val_dataset = TensorDataset(x_val_normalized, y_val)


# Batch the datasets; only the training loader shuffles.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# ZCA arrays for ubm, moved to `device` as float32; they are handed to train_model later.
Vt = torch.from_numpy(zca.zca_Vt_ubm.values).float().to(device)
scale = torch.from_numpy(zca.zca_scale_ubm.values).float().to(device)
mean = torch.from_numpy(zca.zca_mean_ubm.values).float().to(device)

# Training

In [6]:
# Seed every RNG this run touches (Python, NumPy, PyTorch CPU and CUDA).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Turn off cuDNN autotuning and request deterministic cuDNN algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Two-channel-in / two-channel-out UNet, optimised with Adam.
model = UNet(in_channels=2, out_channels=2, initial_features=32, depth=4)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


# Only the MSE weight is non-zero in this run; the gradient and ZCA-NLL weights are 0.
# NOTE(review): the log printed under this cell shows Val Loss (~2.3) several orders of
# magnitude above Train Loss (~1e-4) — confirm inside train_model that both are
# computed with the same loss weighting.
train_model(
    model, train_loader, val_loader,
    Vt, scale, mean,
    optimizer, device,
    grad_loss_weight=0.0,      # gradient loss weight (physical space)
    mse_loss_weight=1.0,       # MSE loss weight (physical space)
    zca_nll_weight=0.0,        # ZCA Gaussian NLL loss weight
    save_path='/home/jovyan/GRL_ssh/checkpoints/mse_loss_only.pth',
    n_epochs=2000,
    patience=50,
)

Epoch 1, Train Loss: 8.71e-05 (ZCA-NLL: 5.93e-01, MSE-Phys: 8.71e-05, Grad-Phys: 2.34e-04), Val Loss: 2.32e+00, Epoch Time: 119.95s
Best model so far saved at epoch 1 (Val Loss: 2.324e+00)
Epoch 2, Train Loss: 7.12e-05 (ZCA-NLL: 5.83e-01, MSE-Phys: 7.12e-05, Grad-Phys: 2.21e-04), Val Loss: 2.32e+00, Epoch Time: 121.50s
Best model so far saved at epoch 2 (Val Loss: 2.321e+00)
Epoch 3, Train Loss: 6.65e-05 (ZCA-NLL: 5.61e-01, MSE-Phys: 6.65e-05, Grad-Phys: 2.15e-04), Val Loss: 2.32e+00, Epoch Time: 121.89s
Best model so far saved at epoch 3 (Val Loss: 2.319e+00)
Epoch 4, Train Loss: 6.38e-05 (ZCA-NLL: 5.51e-01, MSE-Phys: 6.38e-05, Grad-Phys: 2.12e-04), Val Loss: 2.32e+00, Epoch Time: 121.96s
Best model so far saved at epoch 4 (Val Loss: 2.316e+00)
Epoch 5, Train Loss: 6.08e-05 (ZCA-NLL: 5.41e-01, MSE-Phys: 6.08e-05, Grad-Phys: 2.09e-04), Val Loss: 2.31e+00, Epoch Time: 122.00s
Best model so far saved at epoch 5 (Val Loss: 2.315e+00)
Epoch 6, Train Loss: 5.82e-05 (ZCA-NLL: 5.33e-01, MSE-P