In [1]:
import torch.nn.functional as F
import torch

from torch.utils import data
from torch import Tensor
from typing import Tuple

def crop_to_region(coords: Tuple[int, int], img: Tensor, crop_size: int=42) -> Tensor:
    """
    Return a square crop of ``img`` centred at ``coords``.

    The crop window covers ``crop_size`` rows and columns, starting
    ``crop_size // 2`` pixels above and to the left of the centre. Wherever the
    window extends past the image border, the missing pixels are filled by
    replicating the nearest edge pixel, so the result always has the requested
    size.

    Args:
        coords (Tuple[int, int]): The (y, x) coordinates the crop is centred on.
        img (Tensor): The input image of shape (C, H, W), e.g. 3x400x400,
            3x250x250 or 3x150x150.
        crop_size (int, optional): Side length of the returned crop.
            Defaults to 42.

    Returns:
        Tensor: The crop centred at (y, x), of shape
        (C, crop_size, crop_size).

    Raises:
        ValueError: If the crop window does not overlap the image at all, in
            which case there is no edge pixel to replicate.
    """
    _, H, W = img.shape
    y, x = coords
    half = crop_size // 2

    # Unclamped window. The far edge is derived from the near edge (not from
    # `centre + half`) so the window is exactly crop_size wide for odd sizes
    # too, instead of losing one real row/column to padding.
    top, left = y - half, x - half
    bottom, right = top + crop_size, left + crop_size

    # Clamp the window to the image bounds.
    y_min, x_min = max(0, top), max(0, left)
    y_max, x_max = min(H, bottom), min(W, right)
    if y_min >= y_max or x_min >= x_max:
        raise ValueError(
            f"Crop window centred at {(y, x)} with size {crop_size} does not "
            f"overlap an image of size {(H, W)}"
        )
    region = img[:, y_min:y_max, x_min:x_max]

    # Pad each side by however much the window overhangs it. F.pad lists the
    # padding from the last dimension backwards: (left, right, top, bottom).
    padding = (x_min - left, right - x_max, y_min - top, bottom - y_max)
    if any(padding):
        region = F.pad(region, padding, mode='replicate')
    return region

class MIT(data.Dataset):
    def __init__(self, dataset_path: str):
        """
        Given the dataset path, create the MIT dataset. Creates the
        variable self.dataset which is a list of dictionaries with three keys:
            1) X: For train the crop of image. This is of shape [3, 3, 42, 42]. The
                first dim represents the crop across each different scale
                (400x400, 250x250, 150x150), the second dim is the colour
                channels C, followed by H and W (42x42). For inference, this is
                the full size image of shape [3, H, W].
            2) y: The label for the crop. 1 = a fixation point, 0 = a
                non-fixation point. -1 = Unlabelled i.e. val and test
            3) file: The file name the crops were extracted from.

        If the dataset belongs to val or test, there are 4 additional keys:
            1) X_400: The image resized to 400x400
            2) X_250: The image resized to 250x250
            3) X_150: The image resized to 150x150
            4) spatial_coords: The centre coordinates of all 50x50 (2500) crops

        These additional keys help to load the different scales within the
        dataloader itself in a timely manner. Precomputing all crops requires too
        much storage for the lab machines, and resizing/cropping on the fly
        slows down the dataloader, so this is a happy balance.

        The mode is inferred from the path: any path containing 'train' is
        treated as training data, everything else as inference (val/test).
        Args:
            dataset_path (str): Path to train/val/test.pth.tar
        """
        self.dataset = torch.load(dataset_path, weights_only=True)
        # str() so path-like objects work with the substring check as well.
        self.mode = 'train' if 'train' in str(dataset_path) else 'inference'
        # Each inference image expands into 2500 crops; train samples are
        # already single precomputed crops.
        self.num_crops = 2500 if self.mode == 'inference' else 1

    def __getitem__(self, index) -> Tuple[Tensor, int]:
        """
        Given the index from the DataLoader, return the image crop(s) and label
        Args:
            index (int): the dataset index provided by the PyTorch DataLoader.
        Returns:
            Tuple[Tensor, int]: A two-element tuple consisting of:
                1) img (Tensor): The image crop of shape [3, 3, 42, 42]. The
                first dim represents the crop across each different scale
                (400x400, 250x250, 150x150), the second dim is the colour
                channels C, followed by H and W (42x42).
                2) label (int): The label for this crop. 1 = a fixation point,
                0 = a non-fixation point. -1 = Unlabelled i.e. val and test.
        """
        # The flat index enumerates every crop of every image in turn.
        sample_index = index // self.num_crops
        sample = self.dataset[sample_index]

        img = sample['X']

        # Inference crops are not precomputed due to file size, do here instead
        if self.mode == 'inference':
            _, H, W = img.shape
            crop_index = index % self.num_crops
            crop_y, crop_x = sample['spatial_coords'][crop_index]
            scales = []
            for size in ['X_400', 'X_250', 'X_150']:
                scaled_img = sample[size]
                y_ratio, x_ratio = scaled_img.shape[1] / H, scaled_img.shape[2] / W

                # Need to rescale the crops central coordinate.
                scaled_coords = (int(y_ratio * crop_y), int(x_ratio * crop_x))
                scales.append(crop_to_region(scaled_coords, scaled_img))

            # Stack along a new leading dim so the layout is
            # [scale, channel, H, W], matching the training crops. Stacking on
            # dim 1 would silently yield [channel, scale, H, W] because both
            # axes have size 3.
            img = torch.stack(scales, dim=0)

        label = sample['y']

        return img, label

    def __len__(self):
        """
        Returns the length of the dataset (length of the list of dictionaries * number
        of crops).
        __len__() always needs to be defined so that the DataLoader
            can create the batches
        Returns:
            len(self.dataset) (int): the length of the list of dictionaries * number of
            crops.
        """
        return len(self.dataset) * self.num_crops


# Load the three splits. MIT switches to 'train' mode only when the path
# contains 'train'; the test and val splits are loaded in 'inference' mode.
trainingdata = MIT("data/train_data.pth.tar")
testingdata = MIT("data/test_data.pth.tar")
valdata = MIT("data/val_data.pth.tar")


# Each element of `dataset` is a dictionary. For training, the CNN inputs are
# its 'X' entry (documented as a 3x3x42x42 stack of crops) and its 'y' entry
# (the label).


# Inspect the first test sample: the label is expected to be -1 (unlabelled)
# and 'X' is the full-size image rather than a stack of crops.
print(testingdata.dataset[0]['y'])
print(testingdata.dataset[0]['X'])

-1
tensor([[[-0.6191, -0.6474, -0.6616,  ..., -0.3076, -0.3784, -0.4351],
         [-0.6757, -0.6899, -0.6899,  ..., -0.3360, -0.3784, -0.4351],
         [-0.6757, -0.6757, -0.6757,  ..., -0.3360, -0.3643, -0.3784],
         ...,
         [-0.6191, -0.6757, -0.7465,  ..., -0.7040, -0.7182, -0.7324],
         [-0.6616, -0.6899, -0.7182,  ..., -0.7040, -0.7182, -0.7324],
         [-0.7182, -0.7040, -0.6899,  ..., -0.7040, -0.7182, -0.7324]],

        [[-0.7122, -0.7410, -0.7554,  ..., -0.4238, -0.4959, -0.5536],
         [-0.7698, -0.7842, -0.7842,  ..., -0.4526, -0.4959, -0.5536],
         [-0.7698, -0.7698, -0.7698,  ..., -0.4526, -0.4815, -0.4959],
         ...,
         [-0.6112, -0.6689, -0.7698,  ..., -0.7554, -0.7698, -0.7842],
         [-0.6545, -0.6833, -0.7410,  ..., -0.7554, -0.7698, -0.7842],
         [-0.7122, -0.6977, -0.7122,  ..., -0.7554, -0.7698, -0.7842]],

        [[-0.7344, -0.7617, -0.7754,  ..., -0.4611, -0.5294, -0.5841],
         [-0.7890, -0.8027, -0.8027,  ...,

In [8]:
from torch.utils.data import Dataset, DataLoader

class NormalisedDataset(Dataset):
    """Map-style dataset pairing already-normalised inputs with their labels."""

    def __init__(self, X, y):
        """
        Store the inputs and labels without copying them.

        Args:
            X: Indexable collection of inputs (e.g. a stacked tensor).
            y: Indexable collection of labels; its length defines the
                dataset length.
        """
        self.y = y
        self.X = X

    def __getitem__(self, idx):
        """Return the (input, label) pair stored at position ``idx``."""
        features = self.X[idx]
        target = self.y[idx]
        return features, target

    def __len__(self):
        """Return the number of samples, taken from the label collection."""
        return len(self.y)
    
# Stack the per-sample crops and labels into single tensors.
# Assumes every sample['X'] is laid out [scale, channel, H, W], as the MIT
# docstring describes for training data, so X_train is [N, scale, channel, H, W].
X_train = torch.stack([sample['X'] for sample in trainingdata.dataset])
y_train = torch.tensor([sample['y'] for sample in trainingdata.dataset])

# Calculate mean and std for normalization, one value per colour channel.
# The colour axis is dim 2 of the stacked tensor, so reduce over every other
# axis. (Reducing per dim 1 would give per-scale statistics, which would then
# be broadcast onto the colour axis when subtracted below.)
reduce_dims = (0, 1, 3, 4)
mean_train = X_train.mean(dim=reduce_dims)
std_train = X_train.std(dim=reduce_dims)

# Normalize X_train. Shaping the statistics as [C, 1, 1] right-aligns them
# with the trailing [C, H, W] axes, so they are applied along the same axis
# they were computed on.
normalised_X_train = (X_train - mean_train[:, None, None]) / std_train[:, None, None]

# Create the custom dataset
normalised_trainingdata = NormalisedDataset(normalised_X_train, y_train)

# Pass the dataset into DataLoader
train_loader = DataLoader(normalised_trainingdata, batch_size=128, shuffle=True)
print('all working')

<built-in method type of Tensor object at 0x7f9148dd8950>
all working


In [2]:
# Augment the training set with a horizontally flipped copy of every sample:
# the first half of the list holds the original samples, the second half the
# flipped ones.
#
# `list * 2` must not be used here: it repeats references to the SAME dicts,
# so assigning the flipped tensor to entry i + N would also overwrite entry i
# (and the sample inside trainingdata.dataset). Each flipped sample therefore
# gets its own dict; torch.flip returns a new tensor, so the originals are
# left untouched.
training_data_augmented = list(trainingdata.dataset) + [
    {**sample, 'X': torch.flip(sample['X'], dims=[3])}  # dim 3 is the last axis of the 4-D crop stack
    for sample in trainingdata.dataset
]

print('done')

done


In [7]:
import numpy as np
import statistics

# Collect, for every augmented sample, the slice of 'X' at each index of
# dim 1. Each list element covers all leading (scale) entries for one channel
# index, e.g. 3x42x42 for a 3x3x42x42 sample.
# NOTE(review): index 1 is named "blue" and index 2 "green"; under the usual
# RGB ordering these would be the other way round. The names are used
# consistently below, so only the labels are affected. Confirm channel order.
red_channel, blue_channel, green_channel = [], [], []
for sample in training_data_augmented:
    crops = sample['X']
    red_channel.append(crops[:, 0])
    blue_channel.append(crops[:, 1])
    green_channel.append(crops[:, 2])

# This binds a second name to the same list object; no copy is made.
normalised_training_data = training_data_augmented

# The per-channel mean and standard deviation are computed from these lists.

def channel_mean(x):
    """Return, as a Python float, the mean over all values of the equally-shaped tensors in ``x``."""
    stacked = torch.stack(x)
    return stacked.mean().item()

def channel_std(x):
    """Return, as a Python float, the standard deviation (torch's default estimator) over all values of the equally-shaped tensors in ``x``."""
    stacked = torch.stack(x)
    return stacked.std().item()

# def channel_mean(x):
#     channel_sum = 0 # sum is 0 to begin with
#     for i in range(0, len(x)): # go through each item in the list
#         item = x[i] # extract 3x42x42 list
#         for j in range(0, 3): # find the sum of each element of the 3x42x42 list and increment to the running sum
#             for k in range(0, 42):
#                 for l in range(0,42):
#                     channel_sum += item[j][k][l].item()
                    
#     return (channel_sum / (len(x)*3*42*42)) # mean of x i.e. total of all values / number of values 

# Mean and standard deviation of each channel across the whole augmented set.
red_mean = channel_mean(red_channel)
blue_mean = channel_mean(blue_channel)
green_mean = channel_mean(green_channel)

red_std = channel_std(red_channel)
blue_std = channel_std(blue_channel)
green_std = channel_std(green_channel)

print(red_mean, blue_mean, green_mean, red_std, blue_std, green_std)


def _normalise_channels(crops):
    """Return a new tensor with channels 0-2 (dim 1) of ``crops`` standardised."""
    return torch.stack(
        [
            (crops[:, 0] - red_mean) / red_std,
            (crops[:, 1] - blue_mean) / blue_std,
            (crops[:, 2] - green_mean) / green_std,
        ],
        dim=1,  # reassemble along the channel axis
    )


# Build a NEW list of NEW dicts holding NEW tensors. Writing the normalised
# values back into the existing tensors (slice assignment) would also change
# training_data_augmented and every other list sharing those tensors, would
# normalise a sample twice if its dict appears twice in the list, and would
# normalise again each time this cell is re-run.
normalised_training_data = [
    {**sample, 'X': _normalise_channels(sample['X'])}
    for sample in training_data_augmented
]

print('hello')

0.019924132153391838 -0.003527475520968437 -0.011872995644807816 0.9579979777336121 0.948402464389801 0.9503791928291321
hello


In [45]:
# Shape check on a test sample: in inference mode 'X' holds the full-size
# image ([3, H, W]), not a precomputed stack of crops.
testingdata.dataset[1]['X'].shape

torch.Size([3, 705, 1024])