### Task 1: Model Implementation (80% Marks)

Implement the model you want to submit by completing the following functions:
* `__init__`: The constructor for the `Model` class.
* `fit`: Fit/train the model using the input data. You may perform data handling and preprocessing here before training your model.
* `predict`: Predict using the model. If you perform data handling and preprocessing in `fit`, you will usually need to apply the same steps here.

#### Dependencies

Your model may depend on specific versions of Python and its packages, including:

* Python 3.10
* Numpy version 1.23
* Pandas version 1.4
* Scikit-Learn version 1.1
* PyTorch version 1.12
* Torchvision version 0.13

To avoid compatibility issues or unexpected errors when your code runs, make sure you are using these versions. See `environment.yml` for the full list of packages that are pre-installed in Coursemology and can be used by your model. If you use a library that is not installed on Coursemology, you might see an error like:

"Your code failed to evaluate correctly. There might be a syntax error, or perhaps execution failed to complete within the allocated time and memory limits."

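One way to compare a local environment against the list above is to print the installed versions. The sketch below uses `importlib.metadata` from the standard library. The helper name `installed_versions` and the distribution names in `PACKAGES` are assumptions for illustration, and a package that is not installed is reported as `None`.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names as published on PyPI (assumed; adjust to match environment.yml).
PACKAGES = ["numpy", "pandas", "scikit-learn", "torch", "torchvision"]


def installed_versions(packages):
    """Return {package: version string, or None if the package is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Comparing this output with the versions listed above before submitting may help explain an evaluation failure that does not occur locally.
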
#### Model Template

Copy and paste the code below *directly* into Coursemology for submission. Test the code in this notebook on your local machine before uploading it to Coursemology, since each upload uses up an attempt.

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

class CNN(nn.Module):
    
    def __init__(self, classes, drop_prob):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.bn1 = nn.BatchNorm2d(32)  
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.bn2 = nn.BatchNorm2d(64)  
        self.fc1 = nn.Linear(64 * 2 * 2, 64)
        self.bn3 = nn.BatchNorm1d(64)  
        self.fc2 = nn.Linear(64, classes)
        self.leaky_relu = nn.LeakyReLU(0.1)
        self.max_pool = nn.MaxPool2d(2, 2)
        self.drop = nn.Dropout2d(drop_prob)

    def forward(self, x):
        x = self.bn1(self.conv1(x))
        x = self.leaky_relu(x)
        x = self.max_pool(x)
        x = self.drop(x)
        x = self.bn2(self.conv2(x))
        x = self.leaky_relu(x)
        x = self.max_pool(x)
        x = self.drop(x)
        x = torch.flatten(x, 1)
        x = self.bn3(self.fc1(x))
        x = self.leaky_relu(x)
        x = self.fc2(x)
        return x
    

class DataLoaderHeler(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]
        label = self.labels[index]

        if self.transform:
            image = self.transform(image)

        return image, label

class Model:  
    """
    Wraps the CNN classifier defined above together with its
    preprocessing, training (`fit`), and prediction (`predict`) steps.
    """
    
    def __init__(self):
        self.model = CNN(classes=3, drop_prob=0.4)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.005)
        self.loss_fn = nn.CrossEntropyLoss()
    

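    def image_replacing_nan(self, images):
        # Replace NaN pixels with the mean of the non-NaN pixels in the same
        # image and channel. Note: this modifies `images` in place.
        nan_mask = np.isnan(images)
        means = np.nanmean(images, axis=(2, 3), keepdims=True)
        images[nan_mask] = np.broadcast_to(means, images.shape)[nan_mask]
        # Clamp pixel values to the [0, 255] range.
        clipped_images = np.clip(images, 0, 255)
        return clipped_images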
        
    
    def standardising_images(self, images):
        # Scale pixel values from [0, 255] to [0, 1].
        return images / 255.0
    
    def images_labels_filter_nan(self, images, labels):
        not_nan_indices = ~np.isnan(labels)
        filtered_images = images[not_nan_indices]
        filtered_labels = labels[not_nan_indices]
        return filtered_images, filtered_labels
    

    def data_processing(self, images, labels):
        images_with_no_nan = self.image_replacing_nan(images)
        images_standardized = self.standardising_images(images_with_no_nan)
        processed_images, processed_labels = self.images_labels_filter_nan(images_standardized, labels)
        
        return processed_images, processed_labels

    

    def fit(self, X, y):
        """
        Train the model using the input data.
        
        Parameters
        ----------
        X : ndarray of shape (n_samples, channel, height, width)
            Training data.
        y : ndarray of shape (n_samples,)
            Target values.
            
        Returns
        -------
        self : object
            Returns an instance of the trained model.
        """
        X_processed, y_processed = self.data_processing(X, y) 
        X_processed = torch.tensor(X_processed, dtype=torch.float32)
        y_processed = torch.tensor(y_processed, dtype=torch.long)
        
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.RandomAffine(degrees=5, translate=(0.1, 0.1))
        ])

        dataset = DataLoaderHeler(X_processed, y_processed, transform=transform)
        dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
        self.model.train()
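        losses = []
        for epoch in range(60):
            epoch_loss = 0.0  # running sum of batch losses for this epoch
            for batch_X, batch_y in dataloader:
                self.optimizer.zero_grad()
                output = self.model(batch_X)
                loss = self.loss_fn(output, batch_y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()
                # Accumulate a plain float so the batch loss tensor is not overwritten.
                epoch_loss += loss.item()

            epoch_loss = epoch_loss / len(dataloader)  # mean batch loss for the epoch
            losses.append(epoch_loss)
            print("Epoch: {}, Loss: {}".format(epoch, epoch_loss))

        return self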

    def predict(self, X):
        """
        Use the trained model to make predictions.
        
        Parameters
        ----------
        X : ndarray of shape (n_samples, channel, height, width)
            Input data.
            
        Returns
        -------
        ndarray of shape (n_samples,)
            Predicted target values per element in X.
           
        """
        # Apply the same image preprocessing as in `fit` (there are no labels to filter here).
        X = self.image_replacing_nan(X)
        X = self.standardising_images(X)
        X = torch.tensor(X, dtype=torch.float32)
        self.model.eval()  # disable dropout and use BatchNorm running statistics
        with torch.no_grad():
            predictions = self.model(X)
            # Return an ndarray of class indices, as documented above.
            return torch.argmax(predictions, dim=1).numpy()
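
The `64 * 2 * 2` input size of `fc1` in `CNN` depends on the spatial size of the input images, which this notebook does not state. The sketch below reproduces the size arithmetic for the two unpadded 3x3 convolutions and the 2x2 max-pools, assuming square inputs. The 16x16 input in the example is an assumption, chosen because it yields the 2x2 feature map that `fc1` expects, and the helper names are hypothetical.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(size, channels=64):
    """Number of features entering fc1 for a square input of side `size`."""
    size = conv_out(size)                      # conv1: 3x3, stride 1, no padding
    size = conv_out(size, kernel=2, stride=2)  # max_pool: 2x2, stride 2
    size = conv_out(size)                      # conv2: 3x3, stride 1, no padding
    size = conv_out(size, kernel=2, stride=2)  # max_pool: 2x2, stride 2
    return channels * size * size


print(flattened_features(16))  # 16 -> 14 -> 7 -> 5 -> 2, so 64 * 2 * 2 = 256
```

If the images have a different size, the first argument of `nn.Linear` in `fc1` would need to change to match the value this calculation gives.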

#### Local Evaluation

You may test your solution locally by running the following code. Note that the results may not reflect your performance in Coursemology. Do not submit the code below to Coursemology; it is meant only for local testing.

In [2]:
# Import packages
import pandas as pd
import numpy as np
import os
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split

In [3]:
# Load data
with open('data.npy', 'rb') as f:
    data = np.load(f, allow_pickle=True).item()
    X = data['image']
    y = data['label']

In [None]:
# Split train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)

# Filter test data that contains no labels
# In Coursemology, the test data is guaranteed to have labels
nan_indices = np.argwhere(np.isnan(y_test)).squeeze()
mask = np.ones(y_test.shape, bool)
mask[nan_indices] = False
X_test = X_test[mask]
y_test = y_test[mask]
# Train and predict
model = Model()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
print(y_pred)
# Evaluate model prediction
# Learn more: https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
print("F1 Score (macro): {0:.2f}".format(f1_score(y_test, y_pred, average='macro'))) # You may encounter errors; you are expected to figure out what the issue is.
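
For reference, macro-averaged F1 computes one F1 score per class and takes their unweighted mean, so every class counts equally regardless of how many samples it has. The pure-Python sketch below illustrates that definition on made-up labels. The function name `macro_f1` is hypothetical, and the code is an illustration of the formula, not scikit-learn's implementation.

```python
def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over all labels seen in y_true or y_pred."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = []
    for label in labels:
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        # F1 = 2TP / (2TP + FP + FN); the denominator is positive because
        # every label appears at least once in y_true or y_pred.
        scores.append(2 * tp / (2 * tp + fp + fn))
    return sum(scores) / len(scores)


# Made-up labels: class 0 is perfect, class 1 has one false positive, class 2 is missed.
print(round(macro_f1([0, 1, 2, 0], [0, 1, 1, 0]), 4))  # 0.5556
```

In this example the per-class scores are 1, 2/3, and 0, so a single class that is never predicted pulls the macro average down sharply.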