In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Dataset
import os
import random
import xarray as xr

In [2]:
# U-Net model
class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        
        # Encoder
        self.encoder1 = self.conv_block(in_channels, 64)
        self.encoder2 = self.conv_block(64, 128)
        self.encoder3 = self.conv_block(128, 256)
        
        # Bottleneck
        self.bottleneck = self.conv_block(256, 512)
        
        # Decoder
        self.decoder1 = self.upconv_block(512, 256)
        self.decoder2 = self.upconv_block(256, 128)
        self.decoder3 = self.upconv_block(128, 64)
        
        # Output layer
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)
    
    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),  # padding=1 to preserve dimensions
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),  # padding=1 to preserve dimensions
            nn.ReLU(inplace=True)
        )
    
    def upconv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),  
            nn.ReLU(inplace=True),
            self.conv_block(out_channels, out_channels)
        )
    
    def forward(self, x):
        # Encoder
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(F.max_pool2d(enc1, 2))
        enc3 = self.encoder3(F.max_pool2d(enc2, 2))
        
        # Bottleneck
        bottleneck = self.bottleneck(F.max_pool2d(enc3, 2))
        
        # Decoder
        dec1 = self.decoder1(bottleneck)
        dec2 = self.decoder2(dec1)
        dec3 = self.decoder3(dec2)
        
        # Output
        out = self.out_conv(dec3)
        return out

In [3]:
# Hybrid U-Net with Auto-regressive Training
class HybridUNet(nn.Module):
    def __init__(self, in_channels, out_channels, lead_time):
        super(HybridUNet, self).__init__()
        self.unet = UNet(in_channels, out_channels)
        self.lead_time = lead_time
    
    def forward(self, x):
        predictions = []
        current_input = x
        
        for _ in range(self.lead_time):
            # Predict the next time step
            next_step = self.unet(current_input)
            
            # Print shapes before padding
            # print('1', current_input.shape)
            # print('2', next_step.shape)
            
            desired_height = 721
            desired_width = 1431
            desired_channels = 11 

            # Calculate padding for height and width
            height_padding = (desired_height - next_step.shape[2]) // 2
            width_padding = (desired_width - next_step.shape[3]) // 2

            # Ensure the padding is correct: add zero padding if dimensions are already correct
            height_padding_top = height_padding
            height_padding_bottom = desired_height - next_step.shape[2] - height_padding_top

            width_padding_left = width_padding
            width_padding_right = desired_width - next_step.shape[3] - width_padding_left

            # Apply padding
            padding = (width_padding_left, width_padding_right, height_padding_top, height_padding_bottom)
            next_step = F.pad(next_step, padding)  # Pad to match the desired height and width

            # Ensure the shape is (6, 11, 721, 1431)
            next_step = next_step.view(-1, desired_channels, desired_height, desired_width)

            predictions.append(next_step)
            # print('3', next_step.shape)

            # Update the input for the next prediction
            current_input = torch.cat([current_input[:, 11:, :, :], next_step], dim=1)  # Shift time steps
            # print('4', current_input.shape)

        # Stack predictions into a single tensor
        predictions = torch.stack(predictions, dim=1)  # Shape: (batch_size, lead_time, 8, 721, 1431)
        return predictions

In [4]:
class NetCDFDataset(Dataset):
    def __init__(self, file_list):
        self.file_list = file_list 
        self.means = xr.open_dataset("means.nc").to_array().values
        self.stds = xr.open_dataset("std.nc").to_array().values

    def __len__(self):
        return len(self.file_list)
    
    def normalize(self, ds):
        # Convert means and stds to PyTorch tensors
        for i, var in enumerate(ds.data_vars):
            # Get the data for the current variable (Shape: (lat, lon, time))
            data = ds[var]
            
            # Broadcast means and stds to match the shape of the data (lat, lon, time)
            mean = self.means[i]
            std = self.stds[i]
            
            # Normalize the data: (x - mean) / std
            normalized_data = (data - mean) / (std + 1e-8)  # Adding a small epsilon to avoid division by zero
            
            # Update the dataset with the normalized data
            ds[var] = normalized_data

        return ds


    def __getitem__(self, idx):
        # Load a batch (NetCDF file)
        file_path = self.file_list[idx]
        ds = xr.open_dataset(file_path)
        ds = ds.drop_vars(["tp_mask", "sst_mask", "pottmp_mask"])
        ds = ds.fillna(-1)
        # Apply normalization
        ds = self.normalize(ds)
        
        # Extract data as numpy arrays
        data = ds.to_array().values  # Shape: (variables, lat, lon, time)
        # print(data.shape)

        # correct shape: (variables, lat, lon, time)
        assert data.shape[-1] == 6, f"Expected 6 time steps, but got {data.shape[-1]} in {file_path}"

        data_tensor = torch.tensor(data, dtype=torch.float32)  # Shape: (11, 721, 1431, 6)

        # Extract inputs (first 3 time steps)
        inputs = data_tensor[..., :3]  # Shape: (11, 721, 1431, 3)
        inputs = inputs.reshape(3 * 11, inputs.shape[1], inputs.shape[2])  # Shape: (3*11, 721, 1431)

        # Extract targets (last 3 time steps)
        targets = data_tensor[..., 3:]  # Shape: (11, 721, 1431, 3)
        targets = targets.reshape(3 * 11, targets.shape[1], targets.shape[2])  # Shape: (3*11, 721, 1431)

        ds.close()
        # print(inputs.shape)
        return inputs, targets

In [5]:
# Define hyperparameters
batch_size = 6
in_channels = 3 * 11  # 3 time steps × 11 variables
out_channels = 11  # 1 time step × 11 variables
lead_time = 3  # Number of time steps to predict
height, width = 721, 1431 

# Create the model
model = HybridUNet(in_channels, out_channels, lead_time)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

# Define the folder where batches are stored
batch_folder = 'batches/'
batch_files = [os.path.join(batch_folder, f) for f in os.listdir(batch_folder) if f.endswith('.nc')]


# Define the ratios for splitting
train_ratio = 0.723   
val_ratio = 0.12    
test_ratio = 0.158  

# Calculate the split indices
num_batches = len(batch_files)
train_end_idx = int(num_batches * train_ratio)
val_end_idx = int(num_batches * (train_ratio + val_ratio))

# Split the list into training, validation, and test sets
train_batches = batch_files[:train_end_idx]
print('training batches:', len(train_batches))
val_batches = batch_files[train_end_idx:val_end_idx]
print('validation batches:', len(val_batches))
test_batches = batch_files[val_end_idx:]
print('test batches:', len(test_batches))

# Create dataset instances
train_dataset = NetCDFDataset(train_batches)
val_dataset = NetCDFDataset(val_batches)
test_dataset = NetCDFDataset(test_batches)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Training loop
num_epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

# Get mask arrays for sst, pottmp and tp
ds = xr.open_dataset("batches/batch-001-19820101-19820601.nc")

pottmp_mask=ds['pottmp_mask'].values
pottmp_mask=torch.tensor(pottmp_mask).to(device)

sst_mask=ds['sst_mask'].values
sst_mask=torch.tensor(sst_mask).to(device)

tp_mask=ds['tp_mask'].values
tp_mask=torch.tensor(tp_mask).to(device)

training batches: 360
validation batches: 60
test batches: 79
cpu


In [None]:
# training loop
for epoch in range(num_epochs):
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        # Forward pass
        predictions = model(inputs)  # Shape: (batch_size, lead_time, 11, 721, 1431)
        predictions = predictions.view(batch_size, lead_time * out_channels, height, width)  # Reshape to match targets

        # check for NaN or Inf in predictions or targets
        if torch.isnan(predictions).any() or torch.isinf(predictions).any():
            print(f"NaN/Inf detected in predictions at epoch {epoch+1}, batch {batch_idx}")
            
        if torch.isnan(targets).any() or torch.isinf(targets).any():
            print(f"NaN/Inf detected in targets at epoch {epoch+1}, batch {batch_idx}")

        
        # For other variables, we don't apply the mask.
        loss = 0
        for i in range(out_channels):
            if i == 0:  # Ocean variable
                loss += criterion(predictions[:, 0, :, :] * sst_mask, targets[:, 0, :, :] * sst_mask)
            elif i == 2:  # pottmp variable
                loss += criterion(predictions[:, 2, :, :] * pottmp_mask, targets[:, 2, :, :] * pottmp_mask)
            elif i == 10:  # tp variable
                loss += 100 * criterion(predictions[:, 10, :, :] * tp_mask, targets[:, 10, :, :] * tp_mask)
                # print('prc', 100 * criterion(predictions[:, 10, :, :] * tp_mask, targets[:, 10, :, :] * tp_mask))
            else:
                loss += criterion(predictions[:, i, :, :], targets[:, i, :, :])
                # print(criterion(predictions[:, i, :, :], targets[:, i, :, :]))
        
        # Check also for NaN/Inf in the loss itself
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"NaN/Inf detected in loss at epoch {epoch+1}, batch {batch_idx}")
        

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}")

    val_loss = 0
    num_val_batches = 0
    
    with torch.no_grad():  # Disable gradient computation
        for val_inputs, val_targets in val_loader:
            val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)

            # Forward pass
            val_predictions = model(val_inputs)
            val_predictions = val_predictions.view(batch_size, lead_time * out_channels, height, width)
            
            # Compute validation loss (same approach as training loss)
            batch_val_loss = 0
            for i in range(out_channels):
                if i == 0:  # Ocean variable
                    batch_val_loss += criterion(val_predictions[:, 0, :, :] * sst_mask, val_targets[:, 0, :, :] * sst_mask)
                elif i == 2:  # pottmp variable
                    batch_val_loss += criterion(val_predictions[:, 2, :, :] * pottmp_mask, val_targets[:, 2, :, :] * pottmp_mask)
                elif i == 10:  # tp variable
                    batch_val_loss += 100 * criterion(val_predictions[:, 10, :, :] * tp_mask, val_targets[:, 10, :, :] * tp_mask)
                else:
                    batch_val_loss += criterion(val_predictions[:, i, :, :], val_targets[:, i, :, :])

            val_loss += batch_val_loss.item()
            num_val_batches += 1

    # Compute average validation loss
    avg_val_loss = val_loss / num_val_batches

    # Print training and validation loss
    print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], "
        f"Train Loss: {loss.item():.4f}, Val Loss: {avg_val_loss:.4f}")

    # Define where to save the model
    model_path = f"weights/unet_model_epoch_{epoch+1}_train_loss_{loss.item():.2f}_val_loss_{avg_val_loss:.2f}.pth"

    # Save only the model weights (state_dict)
    torch.save(model.state_dict(), model_path)

    print(f"Model weights saved to {model_path}")

In [6]:
# Inference
# Load the saved weights
model.load_state_dict(torch.load("weights/unet_model_epoch_5_train_loss_5.80_val_loss_6.00.pth", map_location=device))

# Set model to evaluation mode
model.eval()
print("Model loaded successfully!")

Model loaded successfully!


In [13]:
# Get a sample batch from the test loader
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
test_inputs, test_targets = next(iter(test_loader))
test_inputs, test_targets = test_inputs.to(device), test_targets.to(device)

# Run inference
with torch.no_grad():
    predictions = model(test_inputs)  # Shape: (batch_size, lead_time, 11, 721, 1431)

# Reshape predictions to match targets
predictions = predictions.view(test_targets.shape)  # Ensure they have the same shape
loss = 0
for i in range(out_channels):
    if i == 0:  # Ocean variable
        loss += criterion(predictions[:, 0, :, :] * sst_mask, test_targets[:, 0, :, :] * sst_mask)
    elif i == 2:  # pottmp variable
        loss += criterion(predictions[:, 2, :, :] * pottmp_mask, test_targets[:, 2, :, :] * pottmp_mask)
    elif i == 10:  # tp variable
        loss += 100 * criterion(predictions[:, 10, :, :] * tp_mask, test_targets[:, 10, :, :] * tp_mask)
    else:
        loss += criterion(predictions[:, i, :, :], test_targets[:, i, :, :])

print(loss.item())

6.5511274337768555
