In [12]:
import os 
import numpy as np 
import time 
from tqdm import tqdm
import xarray as xr
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Dataset, random_split

## Data Loading
In this block of code, we define a custom dataser class using pytorch `Dataset` class for our dataset. Following this, we create `Dataloaders` for pairwise data of `high_res_data` and `low_res_data`.

In [13]:
class MinMaxScaleTransform:
    def __init__(self, high_res_data, low_res_data, use_half=False):
        self.use_half = use_half

        # Compute min and max for each variable at each time point using numpy
        self.high_res_mins = np.amin(high_res_data, axis=(2, 3), keepdims=True)
        self.high_res_maxs = np.amax(high_res_data, axis=(2, 3), keepdims=True)
        self.low_res_mins = np.amin(low_res_data, axis=(2, 3), keepdims=True)
        self.low_res_maxs = np.amax(low_res_data, axis=(2, 3), keepdims=True)

    def __call__(self, sample):
        high_res, low_res = sample
        dtype = torch.float16 if self.use_half else torch.float32
        high_res = (high_res - self.high_res_mins) / (self.high_res_maxs - self.high_res_mins)
        low_res = (low_res - self.low_res_mins) / (self.low_res_maxs - self.low_res_mins)
        
        return torch.tensor(high_res, dtype=dtype), torch.tensor(low_res, dtype=dtype)


class WRFDataset(Dataset):
    def __init__(self, high_res_data, low_res_data, chunk_size_lat, chunk_size_long, transform=None):
        self.high_res_data = high_res_data
        self.low_res_data = low_res_data
        self.chunk_size_lat = chunk_size_lat
        self.chunk_size_long = chunk_size_long
        self.transform = transform
        
        # Ensure both datasets have the same shape
        assert high_res_data.shape == low_res_data.shape, "High-res and low-res data must have the same shape"
        
        # Calculate the number of chunks
        self.n_chunks_lat = high_res_data.shape[2] // chunk_size_lat
        self.n_chunks_long = high_res_data.shape[3] // chunk_size_long

        # Calculate the total number of chunks
        self.n_chunks = self.n_chunks_lat * self.n_chunks_long

    def __len__(self):
        return self.n_chunks

    def __getitem__(self, idx):
        # Calculate the chunk's starting indices for latitude and longitude
        lat_idx = idx // self.n_chunks_long
        long_idx = idx % self.n_chunks_long
        
        lat_start = lat_idx * self.chunk_size_lat
        lat_end = lat_start + self.chunk_size_lat
        long_start = long_idx * self.chunk_size_long
        long_end = long_start + self.chunk_size_long
        
        high_res_chunk = self.high_res_data[:, :, lat_start:lat_end, long_start:long_end]
        low_res_chunk = self.low_res_data[:, :, lat_start:lat_end, long_start:long_end]
        
        sample = (high_res_chunk, low_res_chunk)
        
        if self.transform:
            sample = self.transform(sample)
        
        return sample

def create_loaders(dataset, batch_size: int = 16):
    # Split indices
    total_size = len(dataset)
    train_size = int(0.8 * total_size)
    train_dataset, test_dataset = random_split(dataset, [train_size, total_size - train_size])

    valid_size = int(0.2 * len(train_dataset))
    train_size = len(train_dataset) - valid_size
    train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, valid_loader, test_loader

## Model Definition
Here, we define the model. The model is a custom modified model of SRCNN (https://arxiv.org/abs/1501.00092). It has been modified by using residual Connections abd Deep Residual Blocks as shown in the ESRGAN paper (https://arxiv.org/abs/1809.00219).

In [14]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, use_activation: bool, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, bias = True)
        self.activation = nn.LeakyReLU(0.2, inplace=True) if use_activation else nn.Identity()

    def forward(self, x):
        return self.activation(self.conv(x))
    
class DenseResidualBlock(nn.Module):
    def __init__(self, in_channels: int, channels = 32, beta: float = 0.2, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.beta = beta 
        self.conv = nn.ModuleList()

        for block_no in range(5):
            self.conv.append(ConvBlock(in_channels + channels * block_no, 
                                       channels if block_no < 4 else in_channels,
                                         use_activation=True if block_no < 4 else False,
                                           kernel_size=3, stride=1, padding=1))
            
            
    def forward(self, x):
        new_inputs = x
        for block in self.conv:
            out = block(new_inputs)
            new_inputs = torch.cat([new_inputs, out], dim=1)
        return self.beta * out + x
    

class RRDB(nn.Module):
    def __init__(self, in_channels, residual_beta, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.residual_beta = residual_beta
        self.rrdb = nn.Sequential(*[DenseResidualBlock(in_channels, beta=residual_beta) for _ in range(3)])

    def forward(self, x):
        return self.rrdb(x) * self.residual_beta + x
    

class ModifiedSRCNN(nn.Module):
    def __init__(self, in_channels: int, num_blocks: int, 
                 n1: int, n2: int, f1: int, f2: int, f3: int,
                 *args, **kwargs) -> None:
        """ Initialize the SRCNN with Dense Residual network model with the required layers 
         Below params are the hyperparameters for the SRCNN model without the 
         Bassic block which has been added extra other than the resisual connections.
        in_channels (int): Input number of channels
        num_blocks (int): Number of RRDB blocks
        n1 (int): Number of filters in the first convolutional layer
        n2 (int): Number of filters in the second convolutional layer
        f1 (int): Kernel size of the first convolutional layer
        f2 (int): Kernel size of the second convolutional layer
        f3 (int): Kernel size of the third convolutional layer
        residual_beta (float): Residual connection weight
        """
        super().__init__(*args, **kwargs)
        self.conv1 = ConvBlock(in_channels, n1, kernel_size=f1, stride=1, padding=(1, 1, 1), use_activation=True)
        self.bn1 = nn.BatchNorm3d(n1)
        self.blocks = nn.Sequential(*[RRDB(n1 + in_channels, residual_beta=0.5) for _ in range(num_blocks)])
        self.bn2 = nn.BatchNorm3d((n1 + in_channels))
        self.conv2 = ConvBlock(2 * (n1 + in_channels), n2, kernel_size=f2, stride=1, padding=(1, 1, 1), use_activation=True)
        self.bn3 = nn.BatchNorm3d(n2)
        self.conv3 = ConvBlock(n2 + n1 + in_channels, in_channels, kernel_size=f3, stride=1, padding=1, use_activation=False)
        self.bn4 = nn.BatchNorm3d(in_channels)
        self._initialize_weights()

    def forward(self, x):
        
        with torch.cuda.amp.autocast():
            initial = x 
            x = self.conv1(x)
            x = self.bn1(x)
            x = torch.concat([x, initial], dim = 1)
            initial = x
            x = self.blocks(x)
            x = self.bn2(x)
            x = torch.concat([x, initial], dim = 1)
            x = self.conv2(x) # Take feature maps here
            x = self.bn3(x)
            x = torch.concat([x, initial], dim=1)
            x = self.conv3(x)
            x = self.bn4(x)
        return x 
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

## Training phase
In this block of code, we write the training loop of the model. The model creates a folder called `Logs` during its training. In this folder, (1) it creates a `logs.log` file where it saves the validation losses and the training losses per epoch along with the epoch time and (2) It saves the best weights of the model.
An `Early stopping Callback` has been implemented with a patience of `p` for effective training to prevent overfitting of the model.

In [15]:
def train_model(model, train_loader, val_loader, criterion, optimizer, 
                num_epochs, device, log_folder, patience=5):
    """ Train the model using the specified data loaders and hyperparameters.
    Saves the best model weights based on the validation loss.
    Args:
        model (torch.nn.Module): Model to be trained
        train_loader (torch.utils.data.DataLoader): Training data loader
        val_loader (torch.utils.data.DataLoader): Validation data loader
        criterion (torch.nn.Module): Loss function
        optimizer (torch.optim.Optimizer): Optimizer
        num_epochs (int): Number of epochs to train the model
        device (torch.device): Device to run the model on
        log_folder (str): Folder to store logs and model weights
        patience (int): Number of epochs to wait before early stopping
    Returns:
        torch.nn.Module: Trained model
    """
    # Move model to the specified device
    model.to(device)
    
    # Create directories for storing artifacts
    os.makedirs(log_folder, exist_ok=True)
    
    log_file = os.path.join(log_folder, 'logs.log')
    best_weights_file = os.path.join(log_folder, 'best_weights.pth')
    
    best_loss = float('inf')
    patience_counter = 0
    
    with open(log_file, 'w') as log:
        log.write('Epoch,Train Loss,Val Loss,Epoch Time\n')
        
        for epoch in tqdm(range(num_epochs), desc="Training Epochs"):
            start_time = time.time()
            
            # Training phase
            model.train()
            train_losses = []
            for hr_images, lr_images in train_loader:
                hr_images, lr_images = hr_images.to(device), lr_images.to(device)
                optimizer.zero_grad()
                
                # Enable autocast context for mixed precision training
                with autocast():
                    sr_images = model(lr_images)
                    loss = criterion(sr_images, hr_images)
                
                # Backward pass
                loss.backward()
                
                # Clip gradients to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                
                optimizer.step()
                
                train_losses.append(loss.item())
            
            train_loss = np.mean(train_losses)
            
            # Validation phase
            model.eval()
            val_losses = []
            with torch.no_grad():
                for hr_images, lr_images in val_loader:
                    hr_images, lr_images = hr_images.to(device), lr_images.to(device)
                    with autocast():
                        sr_images = model(lr_images)
                        loss = criterion(sr_images, hr_images)
                    val_losses.append(loss.item())
            
            val_loss = np.mean(val_losses)
            
            epoch_time = time.time() - start_time
            
            print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Epoch Time: {epoch_time:.2f}s')
            log.write(f'{epoch+1},{train_loss},{val_loss},{epoch_time}\n')
            
            # Check for best validation loss
            if val_loss < best_loss:
                best_loss = val_loss
                patience_counter = 0
                torch.save(model.state_dict(), best_weights_file)
            else:
                patience_counter += 1
                        
            if patience_counter >= patience:
                print(f'Early stopping at epoch {epoch+1}')
                break
    
    return model

## Configurations 
Here, we define the model hyperparameters for easy and compact access throughout the training process.

In [16]:
NUM_EPOCHS: int = 1
PATIENCE: int = 15
BATCH_SIZE: int= 8
LOG_FOLDER: str = "Logs"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

######### Dataloader hyperparameters ########
LATITUDE_CHUNK_SIZE = 16
LONGITUDE_CHUNK_SIZE = 16

######### MODEL HYPERPARAMETERS #########
in_channels = 7
num_blocks = 2
n1 = 32
n2 = 128
f1 = 3
f2 = 3
f3 = 3
LEARNING_RATE: int = 3e-4

In [11]:
ozone_2011 = xr.open_dataset("/Volumes/Extreme SSD/PRL/data/high_res/WRF_2011.nc")
co_no2_2011 = xr.open_dataset("/Volumes/Extreme SSD/PRL/data/high_res/WRF_Archi_2011_CO_NO2.nc")
no_2011 = xr.open_dataset("/Volumes/Extreme SSD/PRL/data/high_res/WRF_Archi_2011_NO.nc")
humidity_2011 = xr.open_dataset("/Volumes/Extreme SSD/PRL/data/high_res/WRF_Archi_2011_SpecificHum.nc")
temp_2011 = xr.open_dataset("/Volumes/Extreme SSD/PRL/data/high_res/WRF_2011_Archi_T.nc")


PRESSURE_LEVEL = 2
high_res_data = np.array([ozone_2011["o3"].sel(bottom_top=PRESSURE_LEVEL), 
            ozone_2011["PM2_5_DRY"].sel(bottom_top=PRESSURE_LEVEL),
            co_no2_2011["co"].sel(bottom_top=PRESSURE_LEVEL), 
            co_no2_2011["no2"].sel(bottom_top=PRESSURE_LEVEL), 
            no_2011["no"].sel(bottom_top=PRESSURE_LEVEL), 
            humidity_2011["QVAPOR"].sel(bottom_top=PRESSURE_LEVEL),
            temp_2011["T2"]])

low_res_data = high_res_data + np.random.rand(*high_res_data.shape)
min_max_transform = MinMaxScaleTransform(high_res_data, low_res_data, use_half=True)

dataset = WRFDataset(high_res_data, low_res_data, 
                     LATITUDE_CHUNK_SIZE, 
                     LONGITUDE_CHUNK_SIZE,   
                     transform=min_max_transform)

train_loader, valid_loader, test_loader = create_loaders(dataset, BATCH_SIZE)
train_loader, valid_loader, test_loader = train_loader, valid_loader, test_loader

In [7]:
# TRAINING PART
model_srcnn = ModifiedSRCNN(in_channels=in_channels, num_blocks=num_blocks, n1=n1, n2=n2, f1=f1, f2=f2, f3=f3)
model_srcnn = model_srcnn.half().to(DEVICE)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model_srcnn.parameters(), lr=LEARNING_RATE)
# train_model(model = model_srcnn, train_loader=train_loader, val_loader=valid_loader, 
#             criterion=criterion, optimizer=optimizer, num_epochs=NUM_EPOCHS, 
#             log_folder=LOG_FOLDER, device=DEVICE, patience=PATIENCE)

In [10]:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
model_srcnn.eval()
for idx, (hr, lr) in enumerate(train_loader):
    hr, lr = hr.to(DEVICE), lr.to(DEVICE)
    
    with autocast():
            sr = model_srcnn(lr)
            loss = criterion(sr, hr)
    print(f"Epoch = {idx}, loss = {loss}")
    

OutOfMemoryError: CUDA out of memory. Tried to allocate 88.00 MiB. GPU 0 has a total capacty of 14.74 GiB of which 2.12 MiB is free. Process 6862 has 14.74 GiB memory in use. Of the allocated memory 14.21 GiB is allocated by PyTorch, and 409.27 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF