In [None]:
import time
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

import os
import numpy as np
from glob import glob
import sklearn
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt

In [None]:
class PASTISSegmentation(Dataset):
    """
    Pixel-wise dataset built from a subset of the PASTIS benchmark:
    https://github.com/VSainteuf/pastis-benchmark

    Every ``S2_<name>.npy`` file found in ``<image_dir>/<split>`` is paired with
    ``<annotation_dir>/<split>/TARGET_<name>.npy``.  All files are loaded into
    memory, stacked, and flattened so that each pixel is one sample.

    Args:
        image_dir: directory containing one sub-directory per split with the
            ``S2_*.npy`` image files.
        annotation_dir: directory containing one sub-directory per split with
            the ``TARGET_*.npy`` annotation files.
        split: name of the sub-directory to read (e.g. "train", "val", "test").
        median_of_days: if True, take the median over axis 1 of the stacked
            images (the acquisition-day axis), which removes that axis.
        Xmean, Xstd: optional normalization statistics.  They are applied as
            ``(X - Xmean) / Xstd`` and therefore must broadcast against the
            (possibly median-reduced) image stack.
        binary_labels: if True, every label greater than 0 is set to 1.
        normalize: set to False to skip normalization even when ``Xmean`` and
            ``Xstd`` are given.  Without statistics nothing is normalized.
        transform: optional callable applied to each sample dict.

    Attributes:
        X: stacked images, ``(N, days, channels, H, W)`` or
            ``(N, channels, H, W)`` when ``median_of_days`` is True.
        y: stacked label maps ``(N, H, W)`` (first channel of each annotation).
        x_pixel, y_pixel: per-pixel features ``(N*H*W, F)`` and labels ``(N*H*W,)``.

    Raises:
        ValueError: if no image file is found for the requested split.
    """
    def __init__(
        self,
        image_dir: str,
        annotation_dir: str,
        split: str = "train",
        median_of_days: bool = False,
        Xmean = None,
        Xstd = None,
        binary_labels: bool = False,
        normalize: bool = True,
        transform = None
    ) -> None:
        self.split = split
        self.transform = transform

        # glob returns files in filesystem order; sort so that the sample order
        # is reproducible across machines and runs.
        images = sorted(glob(os.path.join(image_dir, split, 'S2_*.npy')))
        prefix = "S2_"
        annotations = []
        for im in images:
            stem = os.path.splitext(os.path.basename(im))[0]
            # The glob pattern guarantees the prefix; strip only that prefix.
            name = stem[len(prefix):]
            annotations.append(os.path.join(annotation_dir, split, f"TARGET_{name}.npy"))

        # Store in the class for future reference
        self.median_of_days = median_of_days
        self.binary_labels = binary_labels

        # Load data
        self.X = self.read_data(images)
        if median_of_days:
            self.X = np.median(self.X, axis=1)  # Median over the acquisition days
        # Normalize only when requested and when the statistics are provided
        if normalize and Xmean is not None and Xstd is not None:
            self.X = (self.X - Xmean) / Xstd

        self.y = self.read_data(annotations)
        self.y = self.y[:, 0]  # Only the first annotation channel (the class map) is used
        if binary_labels:
            self.y[self.y > 0] = 1  # Convert to binary labels
        self.x_pixel, self.y_pixel = self.pixelwise()
        self.pixelwise_test()

    def __len__(self):
        """Number of pixel samples."""
        return self.x_pixel.shape[0]

    def read_data(self, files):
        """
        Loads every ``.npy`` file in ``files`` and stacks them along a new
        leading axis.

        Raises:
            ValueError: if ``files`` is empty.
        """
        if len(files) == 0:
            raise ValueError(f"No .npy files found for split '{self.split}'")
        return np.stack([np.load(path) for path in files], axis=0)

    def pixelwise(self):
        """
        Flattens the images to individual pixels, so that each pixel can be
        treated as a sample and fed to any classifier.

        Returns:
            (features, labels): ``features`` has one row per pixel; the columns
            are the channels (median mode) or the day-major (day, channel)
            pairs.  ``labels`` is the flattened label map in the same order.
        """
        if self.median_of_days:
            # (N, C, H, W) -> (N, H, W, C) -> (N*H*W, C)
            channels = self.X.shape[1]
            features = np.transpose(self.X, (0, 2, 3, 1)).reshape(-1, channels)
        else:
            # (N, T, C, H, W) -> (N, H, W, T, C) -> (N*H*W, T*C)
            days, channels = self.X.shape[1:3]
            features = np.transpose(self.X, (0, 3, 4, 1, 2)).reshape(-1, days * channels)
        return features, self.y.reshape(-1)

    def pixelwise_test(self):
        """
        Short round-trip test for `pixelwise`: un-flattening its output must
        reproduce `self.X` and `self.y` exactly.
        """
        n_images = self.X.shape[0]
        height, width = self.X.shape[-2:]

        pX, py = self.pixelwise()
        if self.median_of_days:
            channels = self.X.shape[1]
            tX = np.transpose(pX.reshape(n_images, height, width, channels), (0, 3, 1, 2))
        else:
            days, channels = self.X.shape[1:3]
            tX = np.transpose(pX.reshape(n_images, height, width, days, channels), (0, 3, 4, 1, 2))
        tY = py.reshape(n_images, height, width)
        assert np.array_equal(tX, self.X) and np.array_equal(tY, self.y)
        print("All test passed!")

    def __getitem__(self, idx):
        """
        Returns the sample dict ``{'X': float32 features, 'y': label}`` for
        ``idx``.  The label keeps a leading axis of length 1 and is float32 for
        binary labels, int64 otherwise.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Convert in NumPy first so that any source dtype is accepted.
        label_dtype = np.float32 if self.binary_labels else np.int64
        features = np.asarray(self.x_pixel[idx], dtype=np.float32)
        label = np.asarray(self.y_pixel[idx], dtype=label_dtype)[None]
        sample = {'X': torch.from_numpy(features), 'y': torch.from_numpy(label)}

        if self.transform:
            sample = self.transform(sample)

        return sample

In [None]:
# Project root; the data is expected under <base_path>/data. Adjust as needed.
base_path = "../"

# Ten mean / standard-deviation values, one per image channel, shaped
# (1, 10, 1, 1) so they broadcast over an (N, 10, H, W) image stack.
# NOTE(review): presumably computed on the training split - confirm.
Xmean = np.array([
    596.57817383, 878.493514, 969.89764811, 1324.39628906, 2368.21767578,
    2715.68257243, 2886.70323486, 2977.03915609, 2158.25386556, 1462.10965169,
]).reshape((1, 10, 1, 1))
Xstd = np.array([
    251.33337853, 289.95055489, 438.725014, 398.7289996, 706.53781626,
    832.72503267, 898.14189979, 909.04165075, 661.66078257, 529.15340992,
]).reshape((1, 10, 1, 1))

# Everything except the split name is identical for the three datasets.
images_root = os.path.join(base_path, "data", "images")
annotations_root = os.path.join(base_path, "data", "annotations")
dataset_options = dict(
    median_of_days=True,
    Xmean=Xmean,
    Xstd=Xstd,
    binary_labels=False,
    transform=None,
)

p_train = PASTISSegmentation(images_root, annotations_root, split="train", **dataset_options)

p_val = PASTISSegmentation(images_root, annotations_root, split="val", **dataset_options)

p_test = PASTISSegmentation(images_root, annotations_root, split="test", **dataset_options)

In [None]:
# Wrap each dataset in a DataLoader so that we can iterate over mini-batches
# of pixels instead of single samples.
BATCH_SIZE = 4096
loader_options = dict(batch_size=BATCH_SIZE, num_workers=0)

# Only the training samples are shuffled; validation and test keep their order.
train_dataloader = DataLoader(p_train, shuffle=True, **loader_options)

val_dataloader = DataLoader(p_val, shuffle=False, **loader_options)

test_dataloader = DataLoader(p_test, shuffle=False, **loader_options)


In [None]:
# Sanity check on the first mini-batch only: print the shapes, then the dtypes,
# of the features and of the labels.
for d in train_dataloader:
    for attribute in ("shape", "dtype"):
        print(getattr(d['X'], attribute), getattr(d['y'], attribute))
    break

In [None]:
imd = 5  # index of the training image to inspect

# Channel index 3 of the selected image
fig, ax = plt.subplots()
ax.imshow(p_train.X[imd, 3])
plt.show()

# Label map of the same image, with a colorbar for the class values
fig, ax = plt.subplots()
label_image = ax.imshow(p_train.y[imd])
fig.colorbar(label_image, ax=ax)
plt.show()

# Neural Network

In [None]:
# Define our neural network

class MLP(nn.Module):
    """
    Fully connected feed-forward network.

    The network is ``Linear -> [activation]`` for every hidden layer, followed
    by a final ``Linear`` output layer and an optional output activation.

    Args:
        input_size: number of input features.
        hidden_sizes: sizes of the hidden layers, in order.  Any iterable of
            ints is accepted (list, tuple, ...); ``None`` or an empty sequence
            gives a single linear layer.
        output_size: number of outputs.
        hidden_activations: optional module inserted after every hidden layer.
            The same instance is reused at every position.
        output_activation: optional module applied after the output layer.
    """
    def __init__(self, input_size, hidden_sizes, output_size, hidden_activations=None, output_activation=None):
        super().__init__()

        # Normalize to a list so tuples and other iterables can be concatenated.
        hidden_sizes = [] if hidden_sizes is None else list(hidden_sizes)
        layer_sizes = [input_size] + hidden_sizes

        # Hidden layers: one Linear per consecutive pair of sizes
        layers = []
        for in_features, out_features in zip(layer_sizes[:-1], layer_sizes[1:]):
            layers.append(nn.Linear(in_features, out_features))
            if hidden_activations is not None:
                layers.append(hidden_activations)

        # Output layer
        layers.append(nn.Linear(layer_sizes[-1], output_size))
        if output_activation is not None:
            layers.append(output_activation)

        # Combine all layers into a sequential model
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Applies the network to ``x`` (last dimension must be ``input_size``)."""
        return self.model(x)

In [None]:
# Define the evaluation loop
def eval_loop(model, val_loader, criterion):
    """
    Computes the mean per-batch loss of ``model`` over ``val_loader``.

    Args:
        model: the network to evaluate.
        val_loader: iterable of batches; each batch is a dict with the
            features under 'X' and the targets under 'y'.
        criterion: loss function called as ``criterion(prediction, target)``.

    Returns:
        0-dim float32 tensor: the mean of the batch losses (NaN if the loader
        yields no batch).

    The model is put in eval mode for the duration of the call and is returned
    to whichever mode (train/eval) it was in before.
    """
    was_training = model.training
    epoch_loss_val = []
    model.eval()
    try:
        with torch.no_grad():
            for batch in val_loader:
                x, y_true = batch['X'].to(dtype=torch.float32), batch['y']
                # Drop singleton dimensions but never the batch dimension, so a
                # batch holding a single sample keeps its shape (1,).
                for dim in range(y_true.dim() - 1, 0, -1):
                    y_true = y_true.squeeze(dim)
                y_pred = model(x)

                # Calculate loss
                loss = criterion(y_pred, y_true)
                epoch_loss_val.append(loss.item())
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)
    return torch.tensor(epoch_loss_val, dtype=torch.float32).mean()

# Define the training loop
def train_loop(model, train_loader, val_loader, optimizer, criterion, epochs=50):
    """
    Trains ``model`` for ``epochs`` epochs and validates after every epoch.

    Args:
        model: the network to train.
        train_loader, val_loader: iterables of batches; each batch is a dict
            with the features under 'X' and the targets under 'y'.
        optimizer: optimizer that updates the parameters of ``model``.
        criterion: loss function called as ``criterion(prediction, target)``.
        epochs: number of passes over ``train_loader``.

    Returns:
        (model, train_loss, val_loss): the trained model and two lists with one
        0-dim tensor per epoch - the mean training batch loss and the value
        returned by ``eval_loop`` on ``val_loader``.
    """
    train_loss = []
    val_loss = []
    # Discard gradients left over from before this call so that they are not
    # accumulated into the first update.
    optimizer.zero_grad()
    for e in range(epochs):
        epoch_loss_train = []
        model.train()
        for batch in train_loader:
            x, y_true = batch['X'].to(dtype=torch.float32), batch['y']
            # Drop singleton dimensions but never the batch dimension, so a
            # batch holding a single sample keeps its shape (1,).
            for dim in range(y_true.dim() - 1, 0, -1):
                y_true = y_true.squeeze(dim)
            y_pred = model(x)

            # Calculate loss and back-propagate
            loss = criterion(y_pred, y_true)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            epoch_loss_train.append(loss.item())

        el = torch.tensor(epoch_loss_train, dtype=torch.float32).mean()
        print(f"Train loss for epoch {e}: {el}")
        train_loss.append(el)

        vel = eval_loop(model, val_loader, criterion)
        print(f"Validation loss for epoch {e}: {vel}")
        val_loss.append(vel)

    return model, train_loss, val_loss

In [None]:
# First model: three hidden layers (20, 30, 40 units) with ReLU activations,
# 10 input features and 20 outputs. No output activation is applied.
input_size = 10
hidden_sizes = [20, 30, 40]
output_size = 20

model = MLP(
    input_size,
    hidden_sizes,
    output_size,
    hidden_activations=nn.ReLU(),
    output_activation=None,
)

# Loss function and optimizer
lr = 0.0001  # The learning rate
criterion = nn.CrossEntropyLoss(reduction='mean')
# The optimizer uses the gradients to update the model weights.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

In [None]:
# Train the first model and report the wall-clock training time
epochs = 50
start_time = time.time()
model, train_loss, val_loss = train_loop(
    model,
    train_dataloader,
    val_dataloader,
    optimizer,
    criterion,
    epochs=epochs,
)
print(f"Trained for {epochs} epochs in {time.time()-start_time} seconds")

In [None]:
# Second model: same inputs and outputs, but only 2 units in each of the
# three hidden layers.
input_size = 10
hidden_sizes = [2, 2, 2]
output_size = 20

model2 = MLP(
    input_size,
    hidden_sizes,
    output_size,
    hidden_activations=nn.ReLU(),
    output_activation=None,
)
print(model2)

# Loss function and a fresh optimizer bound to the parameters of model2
lr = 0.0001  # The learning rate
criterion = nn.CrossEntropyLoss(reduction='mean')
# The optimizer uses the gradients to update the model weights.
optimizer = torch.optim.AdamW(model2.parameters(), lr=lr)

In [None]:
# Train the second model.
# Its loss histories are stored under their own names (train_loss2 / val_loss2)
# so that train_loss / val_loss keep the histories of the first model; both
# are needed to compare the two models in the loss plot.
start_time = time.time()
epochs = 50
model2, train_loss2, val_loss2 = train_loop(model2, train_dataloader, val_dataloader,  optimizer, criterion, epochs=epochs)
print(f"Trained for {epochs} epochs in {time.time()-start_time} seconds")

In [None]:
#Plot the training and validation loss as a function of epoch for the two models.

In [None]:
#Calculate accuracy, precision, and recall on the test set (or using test_dataloader).


In [None]:
#Comment on the accuracy, recall, and precision of the two MLP models.