# Set up

In [1]:
import random
import numpy as np
import xarray as xr
import torch
from torch.utils.data import DataLoader, TensorDataset
from unet import *
from train import *
from loss import *

base_path = "gs://leap-persistent/YueWang/SSH/data"

def open_zarr(path):
    return xr.open_zarr(path, consolidated=True)

train = open_zarr(f"{base_path}/train_54_sst.zarr").compute()
val = open_zarr(f"{base_path}/val_54_sst.zarr").compute()
zca = open_zarr(f"{base_path}/zca_54_sst.zarr").compute()

In [2]:
def min_max_normalize(tensor, min_values=None, max_values=None, feature_range=(0, 1)):

    num_channels = tensor.shape[1]
    
    if min_values is None:
        min_values = torch.zeros(num_channels, device=tensor.device)
        for c in range(num_channels):
            min_values[c] = tensor[:, c, :, :].min()
    
    if max_values is None:
        max_values = torch.zeros(num_channels, device=tensor.device)
        for c in range(num_channels):
            max_values[c] = tensor[:, c, :, :].max()
    
    normalized_tensor = torch.zeros_like(tensor)
    scale = (feature_range[1] - feature_range[0])
    
    # Normalize each channel independently
    for c in range(num_channels):
        channel_range = max_values[c] - min_values[c]
        
        # Handle edge case where min and max are the same
        if channel_range == 0:
            normalized_tensor[:, c, :, :] = feature_range[0]
        else:
            # Apply min-max formula: (x - min) / (max - min) * scale + min_range
            normalized_tensor[:, c, :, :] = (
                (tensor[:, c, :, :] - min_values[c]) / channel_range
            ) * scale + feature_range[0]
    
    return normalized_tensor, min_values, max_values

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Prepare training data
x_train_channel_0 = torch.from_numpy(train.ssh.values).float().unsqueeze(1).to(device)
x_train_channel_1 = torch.from_numpy(train.sst.values).float().unsqueeze(1).to(device)
x_train = torch.cat([x_train_channel_0, x_train_channel_1], dim=1)
x_train_normalized, min_values, max_values = min_max_normalize(x_train)

y_train_channel_0 = torch.from_numpy(train.ubm.values).float().unsqueeze(1).to(device)
y_train_channel_1 = torch.from_numpy(train.zca_ubm.values).float().unsqueeze(1).to(device)
y_train = torch.cat([y_train_channel_0, y_train_channel_1], dim=1)


# Prepare validation data 
x_val_channel_0 = torch.from_numpy(val.ssh.values).float().unsqueeze(1).to(device)
x_val_channel_1 = torch.from_numpy(val.sst.values).float().unsqueeze(1).to(device)
x_val = torch.cat([x_val_channel_0, x_val_channel_1], dim=1)
x_val_normalized, _, _ = min_max_normalize(x_val, min_values=min_values, max_values=max_values)

y_val_channel_0 = torch.from_numpy(val.ubm.values).float().unsqueeze(1).to(device)
y_val_channel_1 = torch.from_numpy(val.zca_ubm.values).float().unsqueeze(1).to(device)  
y_val = torch.cat([y_val_channel_0, y_val_channel_1], dim=1)  

# Create datasets
train_dataset = TensorDataset(x_train_normalized, y_train)
val_dataset = TensorDataset(x_val_normalized, y_val)


# Create data loaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

Vt = torch.from_numpy(zca.zca_Vt_ubm.values).float().to(device)
scale = torch.from_numpy(zca.zca_scale_ubm.values).float().to(device)
mean = torch.from_numpy(zca.zca_mean_ubm.values).float().to(device)

# Training

In [4]:
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    

# Create model and optimizer
model = UNet(in_channels=2, out_channels=2, initial_features=32, depth=4)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


train_model(model, train_loader, val_loader,
            Vt, scale, mean,  
            optimizer, device,
            grad_loss_weight=0.0,      # Weight for gradient loss in physical space
            mse_loss_weight=0.0,       # Weight for MSE loss in physical space  
            zca_nll_weight=1.0,        # Weight for ZCA Gaussian NLL loss
            save_path='/home/jovyan/grl_final/checkpoints/zca_sst.pth',
            n_epochs=2000,
            patience=50)



Epoch 1, Train Loss: 3.39e-01 (ZCA-NLL: 3.39e-01, MSE-Phys: 1.05e-04, Grad-Phys: 2.60e-04), Val Loss: 1.57e+00, Epoch Time: 152.95s


RuntimeError: Parent directory /home/jovyan/grl_final/checkpoints does not exist.