In [1]:
from dependencies import *
import torch.nn as nn

In [2]:
# Inputs (per the file names: 1 km, bicubic, Swiss features, masked) for precipitation
# (RhiresD) and temperature (TabsD). The directory is defined once so the data root
# only has to be changed in one place.
PROCESSED_DIR = "/work/FAC/FGSE/IDYST/tbeucler/downscaling/sasthana/Downscaling/Downscaling/data/processed"
precip_input = xr.open_dataset(f"{PROCESSED_DIR}/RhiresD_1km_bicubic_Swiss_features_masked.nc")
temp_input = xr.open_dataset(f"{PROCESSED_DIR}/TabsD_1km_bicubic_Swiss_features_masked.nc")

In [3]:
# Targets: the raw precipitation (RhiresD) and temperature (TabsD) files, 1971-2022 per
# the file names. The directory is defined once so the data root is easy to change.
RAW_DIR = "/work/FAC/FGSE/IDYST/tbeucler/downscaling/sasthana/Downscaling/Downscaling/data/raw"
precip_target = xr.open_dataset(f"{RAW_DIR}/RhiresD_1971_2022.nc")
temp_target = xr.open_dataset(f"{RAW_DIR}/TabsD_1971_2022.nc")

In [4]:
# Building a PyTorch dataset from these inputs and targets
from torch.utils.data import Dataset
import torch

In [5]:
class Downscaling_Dataset(Dataset):
    """Dataset pairing one input variable with one target variable, timestep by timestep.

    Item ``index`` is ``(input_img, target_img)``: the fields at time position ``index``
    of ``input_ds[var_name_inputs]`` and ``target_ds[var_name_targets]``, each returned
    as a float32 tensor with a new leading channel axis of size 1.

    Args:
        input_ds: dataset (e.g. an xarray Dataset) holding the input variable. The
            variable must expose a ``time`` coordinate and support ``.isel(time=...)``.
        target_ds: same, for the target variable.
        var_name_inputs: name of the input variable inside ``input_ds``.
        var_name_targets: name of the target variable inside ``target_ds``.
        transform: optional callable applied to the input tensor and, separately, to
            the target tensor. ``None`` (default) means no transform.

    Raises:
        ValueError: if input and target have a different number of timesteps. Pairing
            them anyway would mix up days or fail with IndexError mid-epoch.
    """

    def __init__(self, input_ds, target_ds, var_name_inputs, var_name_targets, transform=None):
        self.input = input_ds[var_name_inputs]
        self.target = target_ds[var_name_targets]
        self.transform = transform

        # Samples are matched purely by time position, so both series must have the
        # same length. NOTE(review): only the count is checked, not the dates.
        n_input, n_target = len(self.input.time), len(self.target.time)
        if n_input != n_target:
            raise ValueError(
                f"Input '{var_name_inputs}' has {n_input} timesteps but target "
                f"'{var_name_targets}' has {n_target}; they must be aligned one-to-one."
            )

    def __len__(self):
        """Return the number of timesteps (identical for input and target)."""
        return len(self.input.time)

    def __getitem__(self, index):
        """Return the ``(input_img, target_img)`` pair for time position ``index``.

        Both tensors are float32 copies of the source fields with shape
        ``(1, *field_shape)``. No NaN handling is done here; NOTE(review): masked pixels
        may be NaN, which would propagate into the loss. Confirm this is handled downstream.
        """
        # torch.tensor always copies, so the tensors never alias the source arrays.
        # Casting directly to float32 avoids an intermediate float64 tensor.
        input_img = torch.tensor(self.input.isel(time=index).values, dtype=torch.float32).unsqueeze(0)
        target_img = torch.tensor(self.target.isel(time=index).values, dtype=torch.float32).unsqueeze(0)

        # Applied independently to input and target. A random transform would therefore
        # draw different randomness for each.
        if self.transform is not None:
            input_img = self.transform(input_img)
            target_img = self.transform(target_img)

        return input_img, target_img


xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx

### Normalisation

**TODO:** write the normalisation code as required; min–max scaling is the planned approach.

xxxxxxxxxxxxxxxxxxxxxxxx

In [6]:
#Building the full dataset and then splitting it into train, validation and test datasets

In [7]:
# One dataset per variable; the string arguments are the variable names inside the
# input file and the target file respectively (no transform is applied).
torch_precip_dataset = Downscaling_Dataset(
    precip_input, precip_target, var_name_inputs="pr", var_name_targets="RhiresD"
)
torch_temp_dataset = Downscaling_Dataset(
    temp_input, temp_target, var_name_inputs="tas", var_name_targets="TabsD"
)

We will do spatial and temporal cross-validation later for different cases. For now, the samples of each decade are split 70% training, 20% validation and 10% testing. Samples are shuffled only within their own decade, never across decades.

In [8]:
from torch.utils.data import Subset
import pandas as pd
from collections import defaultdict

In [9]:
# Seed torch's global RNG, which drives the torch.randperm shuffles in the split below.
torch.manual_seed(42)

# Time axis of the precipitation inputs. The split indices derived from it are reused
# for the temperature datasets, so both inputs must share exactly the same time
# coordinate; otherwise each (pr, tas) sample would combine different days.
times = pd.to_datetime(precip_input['time'].values)
temp_times = pd.to_datetime(temp_input['time'].values)
if not times.equals(temp_times):
    raise ValueError(
        "precip_input and temp_input have different time coordinates; "
        "the shared train/val/test indices would pair different days."
    )

# Map samples to decades
def get_decade(year):
    """Return the first year of the decade containing ``year`` (e.g. 1975 -> 1970), as an int."""
    tens, _remainder = divmod(year, 10)
    return int(tens * 10)

decades = [get_decade(ts.year) for ts in times]

# Group sample positions (indices into the time axis) by decade. Dict insertion order
# therefore follows the order of `times`.
decade_to_indices = defaultdict(list)
for position, decade in enumerate(decades):
    decade_to_indices[decade].append(position)

# Shuffle inside each decade, then cut 70% train / 20% val / remainder test.
# Samples never leave their decade, so each split takes its share from every decade.
train_chunks, val_chunks, test_chunks = [], [], []
for decade, members in decade_to_indices.items():
    shuffled = torch.tensor(members)
    shuffled = shuffled[torch.randperm(len(shuffled))]

    train_end = int(0.7 * len(shuffled))
    val_end = train_end + int(0.2 * len(shuffled))

    train_chunks.append(shuffled[:train_end])
    val_chunks.append(shuffled[train_end:val_end])
    test_chunks.append(shuffled[val_end:])  # the int() rounding remainder lands in test

train_indices = torch.cat(train_chunks)
val_indices = torch.cat(val_chunks)
test_indices = torch.cat(test_chunks)

# The same index tensors are applied to both variables, so pr and tas samples stay aligned.
precip_train, precip_val, precip_test = (
    Subset(torch_precip_dataset, split) for split in (train_indices, val_indices, test_indices)
)
temp_train, temp_val, temp_test = (
    Subset(torch_temp_dataset, split) for split in (train_indices, val_indices, test_indices)
)



In [10]:
# Sanity check: fetch the first training sample of each variable and show the
# precipitation input/target shapes (recorded run: torch.Size([1, 265, 370]) for both).
x_precip, y_precip = precip_train[0]
x_temp, y_temp = temp_train[0]
print(f"{x_precip.shape} {y_precip.shape}")

torch.Size([1, 265, 370]) torch.Size([1, 265, 370])


In [11]:
print(f"{x_temp.shape} {y_temp.shape}")  # temperature input / target shapes

torch.Size([1, 265, 370]) torch.Size([1, 265, 370])


In [12]:
class Paired(Dataset):
    """Zip a precipitation dataset and a temperature dataset into multi-channel samples.

    Both wrapped datasets must yield ``(input, target)`` tensor pairs. Item ``idx``
    concatenates the two inputs, and separately the two targets, along dim 0 with
    precipitation first.
    """

    def __init__(self, precip_dataset, temp_dataset):
        assert len(precip_dataset) == len(temp_dataset), "Datasets must be the same length."
        self.precip_dataset, self.temp_dataset = precip_dataset, temp_dataset

    def __len__(self):
        return len(self.precip_dataset)

    def __getitem__(self, idx):
        pr_in, pr_out = self.precip_dataset[idx]
        tas_in, tas_out = self.temp_dataset[idx]
        # Channel order is (pr, tas); single-channel (1, H, W) parts give (2, H, W).
        return torch.cat((pr_in, tas_in), dim=0), torch.cat((pr_out, tas_out), dim=0)


In [13]:
from torch.utils.data import DataLoader

# Combine the per-variable splits into two-channel (pr, tas) datasets.
train_dataset, val_dataset, test_dataset = (
    Paired(precip_train, temp_train),
    Paired(precip_val, temp_val),
    Paired(precip_test, temp_test),
)

# Shared loader settings; only the training loader shuffles.
loader_kwargs = dict(batch_size=16, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)


In [14]:
import sys

# Make the project's models/ directory importable. The guard keeps re-runs of this
# cell from appending the same entry to sys.path again.
MODELS_DIR = "/work/FAC/FGSE/IDYST/tbeucler/downscaling/sasthana/Downscaling/Downscaling/models/"
if MODELS_DIR not in sys.path:
    sys.path.append(MODELS_DIR)
from UNet import UNet

# Two input channels and two output channels, matching the (pr, tas) stacking in Paired.
model_01_experimental = UNet(in_channels=2, out_channels=2)


In [15]:
# Loss: MSE averaged over all elements (the default reduction). Optimiser: Adam, lr = 1e-3.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model_01_experimental.parameters(), lr=1e-3)

## Training loop

In [17]:
num_epochs = 10

for epoch in range(num_epochs):
    # --- training pass: mean of the per-batch losses ---
    model_01_experimental.train()
    batch_losses = []
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        # NOTE(review): targets are passed to the forward pass here but not during
        # validation; presumably the UNet only uses them for output sizing. Confirm.
        outputs = model_01_experimental(inputs, targets)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    train_loss = sum(batch_losses) / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.3f}")

    # --- validation pass: no gradients, mean of the per-batch losses ---
    model_01_experimental.eval()
    with torch.no_grad():
        val_loss = sum(
            criterion(model_01_experimental(inputs), targets).item()
            for inputs, targets in val_loader
        ) / len(val_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Val Loss: {val_loss:.3f}")


KeyboardInterrupt: 

In [None]:
# Quick test flag: True runs a short smoke test on small subsets, False trains on the
# full train/val datasets.
quick_test = True
quick_test_max_batches = 3  # batches per train/val pass while quick_test is on

# Adjust dataset and loader for quick testing.
# NOTE(review): this rebinds the global train_loader / val_loader / num_epochs, so any
# earlier cell re-run afterwards will see these loaders instead of the originals.
if quick_test:
    # Use a small subset of the dataset
    small_train_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(train_dataset, range(100)),
        batch_size=32,
        shuffle=True
    )
    small_val_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(val_dataset, range(50)),
        batch_size=32,
        shuffle=False
    )
    train_loader = small_train_loader
    val_loader = small_val_loader
    num_epochs = 1  # Only one epoch for testing
else:
    # Use the full dataset.
    # NOTE(review): batch_size=32 with no worker processes differs from the loaders
    # built earlier (batch_size=16, num_workers=4). Confirm this is intended.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)
    num_epochs = 10  # or your desired number of epochs

# Training loop
for epoch in range(num_epochs):
    model_01_experimental.train()
    train_loss = 0.0
    n_train_batches = 0

    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model_01_experimental(inputs, targets)  # Forward pass (pass targets for resizing safety)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        n_train_batches += 1

        if quick_test and n_train_batches >= quick_test_max_batches:
            break

    # Batches are counted explicitly. Dividing by a leftover loop index would raise
    # NameError, or silently reuse a stale global index, when the loader yields nothing.
    train_loss = train_loss / n_train_batches if n_train_batches else float("nan")
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.3f}")

    # Validation step
    model_01_experimental.eval()
    val_loss = 0.0
    n_val_batches = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            outputs = model_01_experimental(inputs)  # No targets during validation
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            n_val_batches += 1

            if quick_test and n_val_batches >= quick_test_max_batches:
                break

    val_loss = val_loss / n_val_batches if n_val_batches else float("nan")
    print(f"Epoch [{epoch+1}/{num_epochs}], Val Loss: {val_loss:.3f}")
