# Task 1

In this task, we need to attempt to create a Deep Learning model that, given a sequence of images of Hurricanes, it will try to predict how the hurricane will be like at a future time.

In [1]:
import os
import json
from PIL import Image
from torch.utils.data import Dataset, BatchSampler, DataLoader, Sampler

from torchvision.utils import make_grid
from torchvision.transforms import ToTensor
from torchsummary import summary
import torch

import matplotlib.pyplot as plt

In [2]:
def set_device(device="cpu", idx=0):
    if device != "cpu":
        if torch.cuda.device_count() > idx and torch.cuda.is_available():
            print("Cuda installed! Running on GPU {} {}!".format(idx, torch.cuda.get_device_name(idx)))
            device="cuda:{}".format(idx)
        elif torch.cuda.device_count() > 0 and torch.cuda.is_available():
            print("Cuda installed but only {} GPU(s) available! Running on GPU 0 {}!".format(torch.cuda.device_count(), torch.cuda.get_device_name()))
            device="cuda:0"
        else:
            device="cpu"
            print("No GPU available! Running on CPU")
    return device

device = set_device("cpu")

Below, we define the dataset. In the 

In [3]:
class GroupedStormImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.storms = []
        self.index_map = []

        for storm_dir in sorted(os.listdir(root_dir)):
            storm_path = os.path.join(root_dir, storm_dir)
            if os.path.isdir(storm_path):
                storm_data = []
                for image_file in sorted(os.listdir(storm_path)):
                    if image_file.endswith('.jpg'):
                        image_path = os.path.join(storm_path, image_file)
                        file_stem = image_file.split('.')[0]
                        features_json_path = os.path.join(storm_path, file_stem + '_features.json')
                        with open(features_json_path, 'r') as f:
                            features = json.load(f)
                        storm_data.append((image_path, features))
                self.storms.append(storm_data)
                self.index_map.extend([(len(self.storms) - 1, i) for i in range(len(storm_data))])

    def __len__(self):
        return sum(len(storm) for storm in self.storms)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            # Convert slice to a range of indices
            idx = range(*idx.indices(len(self)))  # Adjusts slice to fit dataset length

        if isinstance(idx, range) or isinstance(idx, list):
            images = []
            features = []
            for i in idx:
                storm_idx, local_idx = self.index_map[i]
                img_path, feature = self.storms[storm_idx][local_idx]
                image = Image.open(img_path).convert('L')
                if self.transform:
                    image = self.transform(image)
                images.append(image)
                features.append(feature)
            return torch.stack(images), features
        else:
            storm_idx, local_idx = self.index_map[idx]
            img_path, feature = self.storms[storm_idx][local_idx]
            image = Image.open(img_path).convert('L')
            if self.transform:
                image = self.transform(image)
            return image, feature

    def get_storm_sequence(self, storm_idx, seq_start, seq_length):
        # Retrieve a sequence of images from a specific storm
        images = []
        features = []
        for i in range(seq_start, seq_start + seq_length):
            img_path, feature = self.storms[storm_idx][i]
            image = Image.open(img_path).convert('L')
            if self.transform:
                image = self.transform(image)
            images.append(image)
            features.append(feature)
        return torch.stack(images), features  # Stack images to create a sequence tensor


In [4]:
# Example of what an element in the image dataset looks like.
dataset = GroupedStormImageDataset(root_dir='./Selected_Storms_curated_to_zip', transform=ToTensor())
dataset[0]

(tensor([[[0.2471, 0.3373, 0.4314,  ..., 0.0431, 0.0431, 0.0392],
          [0.2431, 0.3255, 0.4118,  ..., 0.0431, 0.0431, 0.0392],
          [0.2863, 0.3569, 0.4235,  ..., 0.0431, 0.0431, 0.0392],
          ...,
          [0.0588, 0.0627, 0.0667,  ..., 0.1647, 0.1843, 0.2039],
          [0.0627, 0.0667, 0.0706,  ..., 0.1725, 0.1882, 0.2039],
          [0.0588, 0.0667, 0.0706,  ..., 0.1647, 0.1804, 0.1922]]]),
 {'storm_id': 'bkh', 'relative_time': '0', 'ocean': '1'})

For our storm image dataset, each storm represents a different event with a series of images. We need to ensure that batches of data fed into the model are coherent and respect the boundaries of these events to ensure the model learns how a specific storm behaves. So, in order to not mix events, we use:

**GroupedStormSampler**

This custom sampler is designed to iterate over the dataset while maintaining the grouping by storm events. It ensures that the model sees all images from one storm before moving to the next, respecting the grouping in the data (the individual storms).

**GroupedBatchSampler**

Building on the GroupedStormSampler, the GroupedBatchSampler takes this a step further by making sure that each batch of data not only comes from the same storm but also follows a specified batch size (at most). This sampler respects the boundaries of each storm, ensuring that no batch contains data from two different storms. If we are at the boundary of two different storms, it will yield a batch size smaller than the maximum batch size.

See below for an example that showcases the use of these two classes.

In [5]:
class GroupedStormSampler(Sampler):
    """
    A custom sampler for iterating over items in a dataset grouped by storm events,
    ensuring that each iteration step progresses through a single storm's data.

    Attributes:
        data_source (Dataset): The dataset to sample from, expected to have a 'storms' attribute
                               that contains grouped data per storm.
        indices (list): A list of tuples where each tuple contains the start and end indices
                        for samples belonging to the same storm.
    """
    def __init__(self, data_source):
        self.data_source = data_source
        self.indices = [] # List to hold start and end indices for each storm's data
        idx = 0
        # Generate start and end indices for each storm
        for storm in self.data_source.storms:
            self.indices.append((idx, idx + len(storm))) # Append a tuple with start and end indices
            idx += len(storm) # Update index to the start of the next storm

    def __iter__(self):
        """
        Provides an iterator over the dataset indices, grouped by storm.

        Yields:
            int: The next index in the dataset, progressing through each storm's data sequentially.
        """
        for start, end in self.indices:
            yield from range(start, end) # Yield indices for the current storm

    def __len__(self):
        return sum(len(storm) for storm in self.data_source.storms)



In [6]:
class StormBatchSampler(Sampler):
    def __init__(self, data_source, sequence_length, batch_size, stride=1, drop_last=True):
        self.data_source = data_source
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.stride = stride
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for storm_idx, storm_data in enumerate(self.data_source.storms):
            # Start from the beginning of each storm
            seq_start = 0
            while seq_start + self.batch_size <= len(storm_data):
                # Find the global index for the start of this sequence
                global_start_index = next((i for i, (s_idx, l_idx) in enumerate(self.data_source.index_map) if s_idx == storm_idx and l_idx == seq_start), None)
                if global_start_index is not None:
                    # Collect the global indices for the entire sequence
                    sequence_indices = list(range(global_start_index, global_start_index + self.batch_size))
                    batch.append(sequence_indices)

                    if len(batch) == self.sequence_length:
                        yield batch
                        batch = []

                # Move to the start of the next sequence within the storm, based on the stride
                seq_start += self.stride

        # Handle the last batch if drop_last is False
        if not self.drop_last and len(batch) > 0:
            yield batch

    def __len__(self):
        # Length calculation needs to be implemented based on your dataset's structure
        raise NotImplementedError("Length calculation needs to be implemented.")



In [7]:
from torch.utils.data import DataLoader
sequence_length = 5  # Length of each sequence
batch_size = 6  # Number of sequences per batch
stride = 1  # Overlap between sequences
storm_sampler = GroupedStormSampler(dataset)

sequential_batch_sampler = StormBatchSampler(
    data_source=dataset,
    sequence_length=sequence_length,
    batch_size=batch_size,
    stride=stride,
    drop_last=True  # You can choose to drop the last incomplete batch if desired
)

storm_data_loader = DataLoader(dataset, batch_sampler=sequential_batch_sampler)

i = 0
for batch in storm_data_loader:
    if i == 130:
        images, data = batch
        print(images.shape)  # Should print torch.Size([M, N, C, H, W])
        print(data)
        break  # To only check the first batch
    i+=1

torch.Size([5, 6, 1, 366, 366])
[{'storm_id': ['blq', 'blq', 'dzw', 'dzw', 'dzw'], 'relative_time': ['655198', '658801', '0', '1799', '3600'], 'ocean': ['1', '1', '1', '1', '1']}, {'storm_id': ['blq', 'blq', 'dzw', 'dzw', 'dzw'], 'relative_time': ['658801', '660600', '1799', '3600', '5399'], 'ocean': ['1', '1', '1', '1', '1']}, {'storm_id': ['blq', 'blq', 'dzw', 'dzw', 'dzw'], 'relative_time': ['660600', '662400', '3600', '5399', '7200'], 'ocean': ['1', '1', '1', '1', '1']}, {'storm_id': ['blq', 'blq', 'dzw', 'dzw', 'dzw'], 'relative_time': ['662400', '664200', '5399', '7200', '10801'], 'ocean': ['1', '1', '1', '1', '1']}, {'storm_id': ['blq', 'blq', 'dzw', 'dzw', 'dzw'], 'relative_time': ['664200', '666000', '7200', '10801', '12601'], 'ocean': ['1', '1', '1', '1', '1']}, {'storm_id': ['blq', 'blq', 'dzw', 'dzw', 'dzw'], 'relative_time': ['666000', '669600', '10801', '12601', '14401'], 'ocean': ['1', '1', '1', '1', '1']}]


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvTLSTMCell(nn.Module):
    def __init__(self, input_channels, hidden_channels=8, kernel_size=3, output_channels=1, cuda_flag=False):
        super(ConvTLSTMCell, self).__init__()  # Fixed __init__ typo
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        self.cuda_flag = cuda_flag  # to device instead?
        self.output_channels = output_channels

        self.W_i = nn.Conv2d(input_channels + hidden_channels, hidden_channels, kernel_size, padding=self.padding)
        self.W_f = nn.Conv2d(input_channels + hidden_channels, hidden_channels, kernel_size, padding=self.padding)
        self.W_o = nn.Conv2d(input_channels + hidden_channels, hidden_channels, kernel_size, padding=self.padding)
        self.W_c = nn.Conv2d(input_channels + hidden_channels, hidden_channels, kernel_size, padding=self.padding)
        
        # Decay parameters for short-term memory
        self.W_d = nn.Conv2d(hidden_channels, hidden_channels, kernel_size, padding=self.padding)

        # Output transformation layer
        self.output_transform = nn.Conv2d(hidden_channels, output_channels, kernel_size=1)  # 1x1 conv to adjust channel size without changing spatial dimensions

    def forward(self, inputs, timestamps, state):
        h, c = state
        #print("inputs: ", inputs.shape)
        combined = torch.cat([inputs, h], dim=1)  # Combining input and hidden state for convolution

        # Compute gates and candidate memory cell with convolutional operations
        i = torch.sigmoid(self.W_i(combined))
        f = torch.sigmoid(self.W_f(combined))
        o = torch.sigmoid(self.W_o(combined))
        c_hat = torch.tanh(self.W_c(combined))

        # Time decay mechanism
        c_s = torch.tanh(self.W_d(c))
        decay = torch.exp(-timestamps).unsqueeze(1).unsqueeze(2).unsqueeze(3)  # Reshape decay to [batch_size, 1, 1, 1] so that it can be multiplied
        decay = decay.expand_as(c_s)  # Ensure decay is broadcastable to the shape of c_s

        c_l = c - c_s
        c_star = c_l + c_s

        # Current memory cell and hidden state
        c = f * c_star + i * c_hat
        h = o * torch.tanh(c)
        print('h:, ', h.shape)

        # Transform output to target image size
        output = self.output_transform(h)  # Transform the hidden state to match the target image size

        if self.cuda_flag:
            output = output.cuda()

        return output, (h, c)  # Return both the transformed output and the raw hidden states for potential further use


In [9]:
def train_epoch(model, data_loader, optimizer, criterion, device, max_batches=None):
    model.train()
    total_loss = 0
    batches_processed = 0


    for images, data in data_loader:
        if max_batches is not None and batches_processed >= max_batches:
            break  # Stop after processing max_batches
        # Extract relative_time for all items in the batch and convert to integers
        relative_times = [list(map(int, data_sequence['relative_time'])) for data_sequence in data]

        # Convert the list of lists into a 2D tensor of shape (batch_size, sequence_length)
        timestamps = torch.tensor(relative_times, dtype=torch.float32).to(device)

        images = images.to(device)
        #print(timestamps.shape)

        # Input is all but the last image in the sequence

        input_images = images[:-1, :, :, :, :]
        print("input_images:", input_images.shape)
        input_timestamps = timestamps[:, :-1]
        #print(input_timestamps.shape)

        # Target is the last image in the sequence for every sequence in batch
        target_image = images[-1, :, :, :, :]
        print("target imaages:", target_image.shape)


        optimizer.zero_grad()

        # Initialize hidden and cell states to None for the start of the sequence
        # Inside the train_epoch and evaluate_epoch functions, before the loop over t
        seq_len, batch_size, _, height, width = input_images.size()
        #print(batch_size, seq_len)
        hidden_channels = model.hidden_channels

        # Initialize hidden and cell states with zeros
        h = torch.zeros(batch_size, hidden_channels, height, width, device=device)
        c = torch.zeros(batch_size, hidden_channels, height, width, device=device)
        print("c", c.shape)

        state = (h, c)
        
        # Process the sequence through the ConvTLSTMCell. Go element by element from each batch in parallel.
        for t in range(input_images.size(0)):
            img_t = input_images[t, :, :, :, :]  # Get each image from each moment in time from batch
            print("img_t", img_t.shape)
            print(input_timestamps[:, t]) # All the timestamps from the same index from each batch in parallel
            # Update the state with each timestep; output is only used at the last timestep
            output, state = model(img_t, input_timestamps[:, t], state)

        # The final output is the prediction for the last image in the sequence for the whole batch
        # of size batch_size, 1, 366, 366.
        predicted_image = output

        # Compute loss
        loss = criterion(predicted_image, target_image)
        total_loss += loss.item()

        # Backpropagation
        loss.backward()
        optimizer.step()

        batches_processed += 1

    avg_loss = total_loss / len(data_loader.dataset)
    return avg_loss

def evaluate_epoch(model, data_loader, criterion, device, max_batches=None):
    model.eval()
    total_loss = 0
    batches_processed = 0

    with torch.no_grad():
        for images, data in data_loader:
            if max_batches is not None and batches_processed >= max_batches:
                break  # Stop after processing max_batches
            
            # Extract relative_time for all items in the batch and convert to integers
            relative_times = [list(map(int, data_sequence['relative_time'])) for data_sequence in data]

            # Convert the list of lists into a 2D tensor of shape (batch_size, sequence_length)
            timestamps = torch.tensor(relative_times, dtype=torch.float32).to(device)

            images = images.to(device)

            input_images = images[:-1, :, :, :, :]
            input_timestamps = timestamps[:, :-1]
            target_image = images[-1, :, :, :, :]

            # Initialize hidden and cell states to None for the start of the sequence
            # Inside the train_epoch and evaluate_epoch functions, before the loop over t
            seq_len, batch_size, _, height, width = input_images.size() 
            hidden_channels = model.hidden_channels

            # Initialize hidden and cell states with zeros
            h = torch.zeros(seq_len, hidden_channels, height, width, device=device)
            c = torch.zeros(seq_len, hidden_channels, height, width, device=device)

            state = (h, c)

            # Process the sequence through the ConvTLSTMCell
            for t in range(input_images.size(1)):
                img_t = input_images[:, t, :, :, :]  # Add channel dimension
                output, state = model(img_t, input_timestamps[t, :], state)

            predicted_image = output

            # Compute loss
            loss = criterion(predicted_image, target_image)
            total_loss += loss.item()

            batches_processed += 1


    avg_loss = total_loss / len(data_loader.dataset)
    return avg_loss

import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

def train_combined(model, train_loader, val_loader, lr, epochs, device):
    model.to(device)

    # Set up the optimizer and the loss function
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()  # Mean Squared Error loss
    scheduler = StepLR(optimizer, step_size=5, gamma=0.6)  # TODO: Adjust this

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device, max_batches=5)

        # Evaluation phase
        model.eval()
        val_loss = evaluate_epoch(model, val_loader, criterion, device, max_batches=5)

        # Scheduler step (if using learning rate decay)
        scheduler.step()

        # Print epoch summary
        print(f'Epoch {epoch + 1}/{epochs}: Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

        # Save model checkpoints periodically or based on certain conditions, e.g., every 5 epochs
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f'ConvTLSTM_checkpoint_epoch_{epoch + 1}.pth')
            print(f'Model checkpoint saved at epoch {epoch + 1}')

    # Optionally, return the model if you need to use it immediately after training
    return model


In [10]:
from torch.utils.data import DataLoader
sequence_length = 10  # Length of each sequence
batch_size = 6  # Number of sequences per batch
stride = 1  # Overlap between sequences
storm_sampler = GroupedStormSampler(dataset)

sequential_batch_sampler = StormBatchSampler(
    data_source=dataset,
    sequence_length=sequence_length,
    batch_size=batch_size,
    stride=stride,
    drop_last=True 
)

storm_data_loader = DataLoader(dataset, batch_sampler=sequential_batch_sampler)

In [11]:
model = ConvTLSTMCell(input_channels=1, hidden_channels=8, kernel_size=3, output_channels=1, cuda_flag=False)

train_loader = storm_data_loader
val_loader = storm_data_loader  
# TODO: Validation set

# Call the training routine
train_combined(model, train_loader, val_loader, lr=0.001, epochs=10, device=torch.device(device))


input_images: torch.Size([9, 6, 1, 366, 366])
target imaages: torch.Size([6, 1, 366, 366])
c torch.Size([6, 8, 366, 366])
img_t torch.Size([6, 1, 366, 366])
tensor([    0.,  1801.,  3600.,  5400.,  7200., 10802.])
h:,  torch.Size([6, 8, 366, 366])
img_t torch.Size([6, 1, 366, 366])
tensor([ 1801.,  3600.,  5400.,  7200., 10802., 12602.])
h:,  torch.Size([6, 8, 366, 366])
img_t torch.Size([6, 1, 366, 366])
tensor([ 3600.,  5400.,  7200., 10802., 12602., 14402.])
h:,  torch.Size([6, 8, 366, 366])
img_t torch.Size([6, 1, 366, 366])
tensor([ 5400.,  7200., 10802., 12602., 14402., 16202.])
h:,  torch.Size([6, 8, 366, 366])
img_t torch.Size([6, 1, 366, 366])
tensor([ 7200., 10802., 12602., 14402., 16202., 18002.])
h:,  torch.Size([6, 8, 366, 366])
img_t torch.Size([6, 1, 366, 366])
tensor([10802., 12602., 14402., 16202., 18002., 21602.])
h:,  torch.Size([6, 8, 366, 366])
img_t torch.Size([6, 1, 366, 366])
tensor([12602., 14402., 16202., 18002., 21602., 23402.])
h:,  torch.Size([6, 8, 366, 36

KeyboardInterrupt: 