# Final Assessment Scratch Pad

## Instructions

1. Please use only this Jupyter notebook to work on your model, and **do not use any extra files**. If you need to define helper classes or functions, feel free to do so in this notebook.
2. This template is intended to be general, but it may not cover every use case. The sections are given so that it will be easier for us to grade your submission. If your specific use case isn't addressed, **you may add new Markdown or code blocks to this notebook**. However, please **don't delete any existing blocks**.
3. If you don't think a particular section of this template is necessary for your work, **you may skip it**. Be sure to explain clearly why you decided to do so.

## Report

##### Overview

##### 1. Descriptive Analysis
I plot the first 10 images to get an intuitive sense of what the images are.

##### 2. Detection and Handling of Missing Values
The next step is to understand the nature of the labels, namely:
1. How many NaNs are there?
2. What share of the whole dataset does each label make up? This affects whether under/oversampling (or another strategy) is needed.

Given that about 10% of true labels were NaNs, these images were excluded from the dataset to avoid misleading the model. However, since all dataset samples contained some NaNs in RGB values, rather than discarding these samples, NaNs were replaced with zeros to maintain neutrality and data integrity.
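
The two NaN rules above can be sketched on a small made-up array. The names `toy_images` and `toy_labels` are illustrative only and do not come from the dataset:

```python
import numpy as np

# Toy stand-ins for the real arrays: 4 images of shape (C=3, H=2, W=2).
toy_images = np.ones((4, 3, 2, 2), dtype=float)
toy_images[0, 0, 0, 0] = np.nan          # a missing pixel value
toy_labels = np.array([0.0, np.nan, 2.0, 1.0])

# Rule 1: drop samples whose label is NaN.
keep = ~np.isnan(toy_labels)
clean_images = toy_images[keep]
clean_labels = toy_labels[keep]

# Rule 2: keep samples with NaN pixels, but replace those pixels with 0.
clean_images = np.nan_to_num(clean_images, nan=0.0)

print(clean_images.shape)       # (3, 3, 2, 2)
print(clean_labels.tolist())    # [0.0, 2.0, 1.0]
```

The image mask is built from the labels before either array is filtered, so the two arrays stay aligned.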

##### 3. Detection and Handling of Outliers
I printed the min and max of the pixel values, which were -10000 and 10000. That is odd, since RGB values normally lie between 0 and 255. To probe further, I printed the percentage of pixel values that fall between 0 and 255, which turned out to be around 98%. This supports my suspicion that the values are supposed to lie between 0 and 255.

I cap the values to the range 0-255 instead of replacing them with 0, the mean, etc., so that some information is retained.
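
A minimal sketch of why capping retains more information than zeroing, using made-up pixel values (the array `pixels` is hypothetical):

```python
import numpy as np

# Made-up pixel values, including two out-of-range readings.
pixels = np.array([-10000.0, 12.0, 200.0, 10000.0])

# Capping keeps the direction of the outlier: too dark becomes 0, too bright becomes 255.
capped = np.clip(pixels, 0, 255)

# Zeroing, by contrast, maps both extremes to the same value.
zeroed = np.where((pixels < 0) | (pixels > 255), 0.0, pixels)

print(capped.tolist())  # [0.0, 12.0, 200.0, 255.0]
print(zeroed.tolist())  # [0.0, 12.0, 200.0, 0.0]
```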

##### 4. Detection and Handling of Class Imbalance 
The bar plot of label counts shows that label 0 outweighs the other classes.

Here are some of the potential strategies I considered:
1. Oversample the minority classes (labels 1 and 2)
2. Undersample the majority class (label 0)
3. Engineer features to add more variations of the minority classes to the dataset
4. Customize the loss function to add weights to the minority classes
5. Random sampling during training

I tried all five. For oversampling, I duplicated samples with labels 1 and 2 up to a larger proportion, to an equal proportion, and to a larger proportion than label 0.  
I also customized the loss function to add weights to the minority classes, either through an algorithm (based on each class's share of the samples) or hardcoded from empirical iterations. The goal is to "punish" the model more when it gets a label 1 or 2 sample wrong, since the model tends to predict 0 for everything.  
In the end, I found random sampling at the batching stage to be the most consistent and stable, so I dropped the other strategies.
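
The sampling idea can be illustrated without PyTorch. The sketch below is a NumPy analogue of a weighted random sampler, not the notebook's actual training code, and the label proportions are made up:

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up imbalanced labels: 80% class 0, 10% class 1, 10% class 2.
toy_labels = np.array([0] * 800 + [1] * 100 + [2] * 100)

# Weight every sample by the inverse of its class count.
class_counts = np.bincount(toy_labels)
sample_weights = 1.0 / class_counts[toy_labels]

# Draw one epoch's worth of indices with replacement, proportional to the weights.
probs = sample_weights / sample_weights.sum()
drawn = rng.choice(len(toy_labels), size=len(toy_labels), replace=True, p=probs)

# Each class now makes up roughly one third of the drawn samples.
print(np.bincount(toy_labels[drawn], minlength=3) / len(drawn))
```

Because minority samples are drawn repeatedly rather than copied into the dataset, the stored data stays unchanged while each batch is roughly balanced in expectation.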

##### 5. Understanding Relationship Between Variables
Since this is image data with RGB channels, there is not much to unpack about the relationships between variables: what RGB entails is already known from context. I do not intend to compress the images, so there is little need to study these relationships.
I investigated the mean and max of the RGB values, but gained little from it.
##### 6. Data Visualization
I plotted a few images for each label side by side to see whether I could spot differences between the labels by eye, which would help with contextual understanding. Unfortunately, they all look similar to the naked eye.
I also tried visualizing the RGB values in graphs, but I could not get any meaningful information out of them.
##### 7. General Preprocessing
After dealing with the NaN values and capping to 0-255, I also normalize the RGB values by dividing by 255 so that gradient descent behaves more smoothly.
##### 8. Feature Selection 
Initially, I thought the red pixels were the important ones and the rest were noise, so I tried masking the G and B channels to 0 and keeping only R. However, this brought no noticeable improvement to the model, presumably because the neural network already accounts for this. Thus, in the end, there is no feature selection.
##### 9. Feature Engineering
Engineering features for the minority class (like flips and rotations) didn't significantly enhance performance, leading to a decision against aggressive feature engineering.
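
A minimal sketch of flip and rotation augmentation for the minority classes. It uses plain NumPy as a stand-in for the torchvision transforms in the workings, the helper name `augment_minority` is hypothetical, and it assumes square images in `(N, C, H, W)` layout:

```python
import numpy as np

def augment_minority(images, labels, minority=(1, 2)):
    """Append a horizontally flipped and a 90-degree rotated copy of each minority image."""
    mask = np.isin(labels, minority)
    flipped = np.flip(images[mask], axis=3)               # mirror along the width axis
    rotated = np.rot90(images[mask], k=1, axes=(2, 3))    # rotate in the (H, W) plane
    new_images = np.concatenate([images, flipped, rotated])
    new_labels = np.concatenate([labels, labels[mask], labels[mask]])
    return new_images, new_labels

# Made-up data: 4 images of shape (3, 2, 2), two of them from minority classes.
toy_images = np.arange(4 * 3 * 2 * 2, dtype=float).reshape(4, 3, 2, 2)
toy_labels = np.array([0, 1, 0, 2])
aug_images, aug_labels = augment_minority(toy_images, toy_labels)
print(aug_images.shape)      # (8, 3, 2, 2)
print(aug_labels.tolist())   # [0, 1, 0, 2, 1, 2, 1, 2]
```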

##### 10. Creating Models
I started with around 2 convolution layers, following standard industry practices (pooling after convolution, ReLU activations, typical kernel sizes, etc.). For the filters, I experimented with small counts (4, 8) and gradually increased them by powers of 2 (32, 64, 128, 256, 512). Higher filter counts did not improve performance but increased training time considerably. Later on, I realized that the model was also prone to overfitting, which is why it could not predict many of the 1s and 2s correctly, so I added dropout layers. The dropout values were selected via hyperparameter search.  
For the optimizer, I used the industry recommendation (Adam) without much testing. The batch size was chosen arbitrarily.  
Throughout testing, I tweaked the layers multiple times, either by adding and removing layers or by changing their parameters.
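
When layers are added or removed, the input size of the first linear layer has to be recomputed. The sketch below traces the spatial size through the two conv/pool blocks used in the workings. The 16x16 input size is an assumption (it is what `nn.Linear(256, 32)` implies), and `conv_out` is a hypothetical helper:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after a convolution or pooling layer (floor division, no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 16  # assumed input height/width; not stated in the report
size = conv_out(size, kernel=3, stride=1, padding=1)  # conv keeps 16
size = conv_out(size, kernel=2, stride=2)             # max-pool halves to 8
size = conv_out(size, kernel=3, stride=1, padding=1)  # conv keeps 8
size = conv_out(size, kernel=2, stride=2)             # max-pool halves to 4

flattened = 16 * size * size  # 16 output channels from the second conv
print(flattened)  # 256
```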

##### 11. Model Evaluation
Focus was on the proportion of predictions for each label, aiming for a distribution close to the actual label distribution. The accuracy of predictions for each label was also scrutinized. These aspects influenced the weight assignment in the loss function.
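
The per-label checks described above can be sketched with a small confusion matrix. The helper `per_class_report` and the labels below are made up for illustration:

```python
import numpy as np

def per_class_report(y_true, y_pred, n_classes=3):
    """Return (true share, predicted share, recall) per class as three arrays."""
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    # confusion[i, j] counts samples of true class i predicted as class j.
    confusion = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(confusion, (y_true, y_pred), 1)
    true_share = confusion.sum(axis=1) / len(y_true)
    pred_share = confusion.sum(axis=0) / len(y_true)
    recall = np.diag(confusion) / np.maximum(confusion.sum(axis=1), 1)
    return true_share, pred_share, recall

# Made-up labels and predictions for illustration only.
y_true = [0, 0, 0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 0, 1, 1, 0, 2, 0]
true_share, pred_share, recall = per_class_report(y_true, y_pred)
print(true_share.tolist())  # [0.5, 0.25, 0.25]
print(pred_share.tolist())  # [0.625, 0.25, 0.125]
print(recall.tolist())      # [0.75, 0.5, 0.5]
```

A predicted share well above the true share for label 0 is the signature of the "predict 0 for everything" tendency.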
##### 12. Hyperparameters Search
Optimization included learning rate, dropout rates, and epochs. Averaging f1 scores over multiple random train-test splits was used to minimize variance and identify robust parameters. 
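
A minimal sketch of why the scores are averaged before picking parameters. The configuration names and scores are invented for illustration and are not results from this notebook:

```python
from statistics import mean

# Made-up macro-F1 scores per configuration, one score per random split.
scores_by_config = {
    "lr=0.0015, dropout=0.15": [0.60, 0.70, 0.65],
    "lr=0.0020, dropout=0.20": [0.50, 0.80, 0.62],
}

# Rank configurations by their mean score rather than by a single lucky split.
avg_by_config = {cfg: mean(scores) for cfg, scores in scores_by_config.items()}
best_config = max(avg_by_config, key=avg_by_config.get)
print(best_config)  # lr=0.0015, dropout=0.15
```

Here the second configuration has the best single split (0.80) but the lower average, so it is not selected.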

Lastly, I also wanted to investigate whether feature engineering would improve the model when combined with random sampling.

##### Conclusion
In the end, I selected the parameters that gave the highest average F1 score, and I realized that performance does not change even without oversampling and feature engineering. As a result, some of the functions left in the notebook are unused by the final model.
https://chat.openai.com/share/522e7233-6f35-41b7-b2fa-3e626972054f


---

# Workings (Not Graded)

You will do your working below. Note that anything below this section will not be graded, but we might counter-check what you wrote in the report above with your workings to make sure that you actually did what you claimed to have done. 

## Import Packages

Here, we import some packages necessary to run this notebook. In addition, you may import other packages as well. Do note that when submitting your model, you may only use packages that are available in Coursemology (see `main.ipynb`).

In [None]:
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split

## Load Dataset

The dataset `data/images.npy` is of size $(N, C, H, W)$, where $N$, $C$, $H$, and $W$ correspond to the number of data, image channels, image height, and image width, respectively.

A code snippet that loads the data is provided below.

### Load Image Data

In [None]:
with open('data.npy', 'rb') as f:
    data = np.load(f, allow_pickle=True).item()
    images = data['image']
    labels = data['label']
    
print('Shape:', images.shape)

In [None]:
# Create a figure with subplots
plt.figure(figsize=(15, 15))  # Adjust the size as needed

# Loop through the first 10 images
for i in range(10):
    # Access the image
    image = images[i]

    # Convert to uint8
    image = np.array(image, dtype='uint8')

    # Rearrange the axes from [channels, height, width] to [height, width, channels]
    image = np.transpose(image, (1, 2, 0))

    # Plot the image
    plt.subplot(2, 5, i + 1)  # Adjust the layout (rows, columns, index) as needed
    plt.imshow(image)
    plt.title(f"Label: {labels[i]}")
    plt.axis('off')

# Display the plot
plt.show()

In [None]:

# count the number of nans in labels
nan_count = 0
for label in labels:
    if np.isnan(label):
        nan_count += 1

total_count = len(labels)

print('NaN count:', nan_count)
# print nan count percentage
print('NaN count percentage:', nan_count / len(labels) * 100, "%")

# plot the label count, ignoring NaNs
# (labels itself is left untouched so it stays aligned with images for later filtering)
valid_labels = labels[~np.isnan(labels)]
label_count = {}
for label in valid_labels:
    if label not in label_count:
        label_count[label] = 0
    label_count[label] += 1

bars = plt.bar(label_count.keys(), label_count.values())

plt.title('Label Count')
plt.xlabel('Label')
plt.ylabel('Count')

# Add the percentage to each bar
for bar in bars:
    height = bar.get_height()
    percentage = f'{100 * height / total_count:.2f}%'
    plt.text(bar.get_x() + bar.get_width() / 2, height, percentage, ha='center', va='bottom')

plt.show()


## Data Exploration & Preparation

### 1. Descriptive Analysis

### 2. Detection and Handling of Missing Values

In [None]:
print('NaN count percentage:', nan_count / len(images) * 100, "%") # around 10%

# remove images where label is nan
images = images[~np.isnan(labels)]
labels = labels[~np.isnan(labels)]

print('Shape:', images.shape)

In [None]:
# count the max nan values in an image
max_nan_count = 0
for image in images:
    nan_count = np.isnan(image).sum()
    if nan_count > max_nan_count:
        max_nan_count = nan_count

print('Max NaN count:', max_nan_count)
# replace nan values with 0
images = np.nan_to_num(images)


### 3. Detection and Handling of Outliers

In [None]:
# range of values in images
print('Min:', images.min())
print('Max:', images.max())

# count the percentage of values between 0 and 255 in images
print('Percentage of values between 0 and 255:', np.sum((images >= 0) & (images <= 255)) / images.size * 100, "%")

# cap the values between 0 and 255
images = np.clip(images, 0, 255)

# count the percentage of values between 0 and 255 in images
print('Percentage of values between 0 and 255:', np.sum((images >= 0) & (images <= 255)) / images.size * 100, "%")


### 4. Detection and Handling of Class Imbalance

In [None]:
# oversample the data with label 1 to 300
label_1_indices = np.where(labels == 1)[0]
label_1_count = len(label_1_indices)
# number of images to add (0 if the class already has 300 or more)
add_count = max(300 - label_1_count, 0)
add_indices = np.random.choice(label_1_indices, add_count)
images = np.concatenate((images, images[add_indices]))
labels = np.concatenate((labels, labels[add_indices]))
# oversample the data with label 2 to 300
# get the indices of label 2
label_2_indices = np.where(labels == 2)[0]
# get the number of images with label 2
label_2_count = len(label_2_indices)
# get the number of images to add (0 if the class already has 300 or more)
add_count = max(300 - label_2_count, 0)
# get the indices to add
add_indices = np.random.choice(label_2_indices, add_count)
# add the images and labels
images = np.concatenate((images, images[add_indices]))
labels = np.concatenate((labels, labels[add_indices]))
print('New shape:', images.shape)

### 5. Understanding Relationship Between Variables

In [None]:
# for each label 0, 1, 2, plot 10 images with that label, side by side
for label in range(3):
    label_indices = np.where(labels == label)[0]
    for i in range(10):
        image = images[label_indices[i]]
        image = np.array(image, dtype='uint8')
        image = np.transpose(image, (1, 2, 0))
        plt.subplot(2, 5, i + 1)
        plt.imshow(image)
        plt.title(f"Label: {label}")
        plt.axis('off')
    plt.show()

### 6. Data Visualization

In [None]:
# i want to analyze the difference between the images with label 0 and 1
# get the indices of label 0 and 1
label_0_indices = np.where(labels == 0)[0]
label_1_indices = np.where(labels == 1)[0]
# get the images with label 0 and 1
label_0_images = images[label_0_indices]
label_1_images = images[label_1_indices]
# get the mean of each image (reduce over channel, height and width)
label_0_means = np.mean(label_0_images, axis=(1, 2, 3))
label_1_means = np.mean(label_1_images, axis=(1, 2, 3))
# get the max of each image (reduce over channel, height and width)
label_0_maxes = np.max(label_0_images, axis=(1, 2, 3))
label_1_maxes = np.max(label_1_images, axis=(1, 2, 3))

# plot the means
plt.scatter(label_0_means, label_0_maxes, label='Label 0')
plt.scatter(label_1_means, label_1_maxes, label='Label 1')
plt.title('Mean vs Max')
plt.xlabel('Mean')
plt.ylabel('Max')
plt.legend()
plt.show()



In [None]:
print('Shape before:', images.shape)

# replace all B with 0, G with 0
images[:, 1, :, :] = 0
images[:, 2, :, :] = 0

# plot the first 10 images
for i in range(10):
    image = images[i]
    image = np.array(image, dtype='uint8')
    image = np.transpose(image, (1, 2, 0))
    plt.subplot(2, 5, i + 1)
    plt.imshow(image)
    plt.title(f"Label: {labels[i]}")
    plt.axis('off')
plt.show()

## Data Preprocessing

### 7. General Preprocessing

### 8. Feature Selection

### 9. Feature Engineering

## Modeling & Evaluation

### 10. Creating models

In [None]:
from torch import nn
import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
from torchvision import transforms, datasets
from PIL import Image
from torch.utils.data import WeightedRandomSampler

class Model:  
    """
    This class represents an AI model.
    """
    
    def __init__(self, learning_rate=0.0015, dropout=0.15, epochs=20):
        """
        Constructor for Model class.
  
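        Parameters
        ----------
        learning_rate : float
            Step size for the Adam optimizer.
        dropout : float
            Dropout probability applied after each convolution block.
        epochs : int
            Number of passes over the training data.
        """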
        # initialize neural network sequence
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Flatten(),
            nn.Linear(256, 32),
            nn.ReLU(),
            nn.Linear(32, 3)  
        )

        # initialize hyperparameters
        self.learning_rate = learning_rate
        self.batch_size = 32
        self.epochs = epochs

    def fit(self, X, y):
        """
        Train the model using the input data.
        
        Parameters
        ----------
        X : ndarray of shape (n_samples, channel, height, width)
            Training data.
        y : ndarray of shape (n_samples,)
            Target values.
            
        Returns
        -------
        self : object
            Returns an instance of the trained model.
        """

        X, y = Model.preprocess(X, y)
        #X, y = Model.balance_dataset(X, y)
        #X, y = Model.feature_engineer(X, y)

        

        # Loss re-weighting for the minority classes was tried and dropped in favour of
        # the weighted sampler below; the inverse-frequency variant is kept for reference:
        #class_weights = torch.tensor([total_count / (len(class_counts) * class_count) for class_count in class_counts])

        Model.print_class_counts(y)

        # print percentage of each label
        Model.print_label_percentage(y)

        X_tensor = torch.tensor(X, dtype=torch.float32)
        y_tensor = torch.tensor(y, dtype=torch.long)

        # Calculate sampling weights: each sample is weighted by the inverse of its class count
        # (assumes labels are the consecutive integers 0..K-1, all present in y)
        class_sample_counts = torch.tensor([(y_tensor == t).sum() for t in torch.unique(y_tensor, sorted=True)])
        class_weights = 1. / class_sample_counts.float()
        print('Sampling weights per class:', class_weights)
        weights = class_weights[y_tensor.long()]

        # Create a weighted sampler to handle imbalanced classes
        sampler = WeightedRandomSampler(weights, len(weights))


        # Create a dataset and data loader
        dataset = TensorDataset(X_tensor, y_tensor)
        dataloader = DataLoader(dataset, batch_size=self.batch_size if self.batch_size else len(dataset), sampler=sampler)

        # Define loss function and optimizer for classification
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.cnn.parameters(), lr=self.learning_rate)

        # Make sure dropout is active during training
        self.cnn.train()
        for epoch in range(self.epochs):
            for inputs, targets in dataloader:
                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass
                outputs = self.cnn(inputs)
                loss = criterion(outputs, targets)

                # Backward and optimize
                loss.backward()
                optimizer.step()

            print(f'Epoch {epoch+1}/{self.epochs}, Loss: {loss.item()}')

        return self
    
    def predict(self, X):
        """
        Use the trained model to make predictions.
        
        Parameters
        ----------
        X : ndarray of shape (n_samples, channel, height, width)
            Input data.
            
        Returns
        -------
        ndarray of shape (n_samples,)
        Predicted target values per element in X.
           
        """
        X = Model.preprocess_predict(X)
        X_tensor = torch.tensor(X, dtype=torch.float32)
        dataset = TensorDataset(X_tensor)
        dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
        predictions = []
        # Switch to evaluation mode so dropout is disabled at inference time
        self.cnn.eval()
        with torch.no_grad():
            for inputs in dataloader:
                outputs = self.cnn(inputs[0])
                _, predicted = torch.max(outputs, 1)
                predictions += predicted.tolist()

        return np.array(predictions)
        
    
    @staticmethod
    def preprocess(images, labels):
        # remove images where label is nan
        images = images[~np.isnan(labels)]
        labels = labels[~np.isnan(labels)]
        
        # replace nan with 0
        images = np.nan_to_num(images)

        # cap min to 0 and max to 255
        images = np.clip(images, 0, 255)

        print('Shape:', images.shape)

        # normalize the images
        images = images / 255.0

        return images, labels
    
    @staticmethod
    def preprocess_predict(images):
        # replace nan values with 0
        images = np.nan_to_num(images)

        # cap min to 0 and max to 255
        images = np.clip(images, 0, 255)
        
        # normalize the images
        images = images / 255.0

        return images

    @staticmethod
    def feature_engineer(images, labels):
        T = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
        ])

        # create 10 different transformations
        T_list = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(20),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(40),
            transforms.RandomInvert(),
            transforms.ColorJitter(brightness=0.5),
            transforms.ColorJitter(contrast=0.5),
            transforms.ColorJitter(saturation=0.5),
            transforms.ColorJitter(hue=0.5),
            transforms.RandomGrayscale(p=0.1),
        ]

        print("Shape before:", images.shape)

        augmented_images = []
        augmented_labels = []

        # get images and labels where label is 1 or 2
        images_to_engineer = images[labels != 0]
        labels_to_engineer = labels[labels != 0]

        print("Number of images to engineer:", len(images_to_engineer))

        for (image, label) in zip(images_to_engineer, labels_to_engineer):
            image = image.transpose(1, 2, 0)  # Convert to HWC format for PIL
            for transform in T_list:
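                # NOTE: the uint8 cast assumes pixel values in 0-255; inputs already
                # scaled to [0, 1] would be truncated to 0 here
                img_pil = Image.fromarray(image.astype('uint8'), 'RGB')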

                # Apply transformation
                augmented_img = transform(img_pil)

                # Convert back to CHW format and append
                augmented_np = np.asarray(augmented_img).transpose(2, 0, 1)
                augmented_images.append(augmented_np)
                augmented_labels.append(label)

        images = np.concatenate((images, augmented_images), axis=0)
        labels = np.concatenate((labels, augmented_labels), axis=0)

        print('Shape:', images.shape)

        return images, labels
    
    @staticmethod
    def balance_dataset(images, labels, min_proportions=(0.1, 0.9)):
        unique_labels, counts = np.unique(labels, return_counts=True)
        total_samples = len(labels)
        
        # Determine minimum count for each label based on proportions
        min_counts = [int(total_samples * p) for p in min_proportions]
        
        # Sort labels by their count (ascending)
        sorted_indices = np.argsort(counts)
        
        for idx, min_count in zip(sorted_indices, min_counts):
            label = unique_labels[idx]
            current_count = counts[idx]
            
            if current_count < min_count:
                # Calculate the number of samples to add
                add_count = min_count - current_count
                
                # Get indices of the current label
                label_indices = np.where(labels == label)[0]
                
                # Randomly select indices to duplicate
                add_indices = np.random.choice(label_indices, add_count)
                
                # Add the images and labels
                images = np.concatenate((images, images[add_indices]))
                labels = np.concatenate((labels, labels[add_indices]))

        return images, labels

    @staticmethod
    def print_label_percentage(y):
        total_count = len(y)
        unique_labels, counts = np.unique(y, return_counts=True)
        for label, count in zip(unique_labels, counts):
            print(f'Label {label}: {count / total_count * 100:.2f}%')

    @staticmethod
    def print_class_counts(y):
        unique_labels, counts = np.unique(y, return_counts=True)
        print('Class counts:', dict(zip(unique_labels, counts)))

### 11. Model Evaluation

In [None]:
# Load data
with open('data.npy', 'rb') as f:
    data = np.load(f, allow_pickle=True).item()
    X = data['image']
    y = data['label']

In [None]:
# Split train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)

# Filter test data that contains no labels
# In Coursemology, the test data is guaranteed to have labels
nan_indices = np.argwhere(np.isnan(y_test)).squeeze()
mask = np.ones(y_test.shape, bool)
mask[nan_indices] = False
X_test = X_test[mask]
y_test = y_test[mask]

# Train and predict
model = Model()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

# Evaluate model prediction
# Learn more: https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
print("F1 Score (macro): {0:.2f}".format(f1_score(y_test, y_pred, average='macro'))) # You may encounter errors, you are expected to figure out what's the issue.

In [None]:
# print first 20 predictions beside ground truth
#for i in range(20):
#    print(f'Prediction: {y_pred[i]}, Ground Truth: {y_test[i]}')

print(f'Total predictions: {len(y_pred)}')
print(f'Total ground truth: {len(y_test)}')

# get the indices where true label is 1
label_1_indices_true = np.where(y_test == 1)[0]
# get the indices where predicted label is 1
label_1_indices_pred = np.where(y_pred == 1)[0]

# get the indices where true label is 2
label_2_indices_true = np.where(y_test == 2)[0]
# get the indices where predicted label is 2
label_2_indices_pred = np.where(y_pred == 2)[0]

# get the indices where true label is 0
label_0_indices_true = np.where(y_test == 0)[0]
# get the indices where predicted label is 0
label_0_indices_pred = np.where(y_pred == 0)[0]

# print number of true and predicted for each label
for label in range(3):
    label_indices_true = np.where(y_test == label)[0]
    label_indices_pred = np.where(y_pred == label)[0]
    print(f'Label {label}: {len(label_indices_true)} true, {len(label_indices_pred)} predicted')

print("=====================================")

# out of the images with label 0, how many did we predict correctly
correct_count = 0
for label_0_index in label_0_indices_true:
    if label_0_index in label_0_indices_pred:
        correct_count += 1
print(f'Label 0: {correct_count} correct out of {len(label_0_indices_true)}')

# out of the images with label 1, how many did we predict correctly
correct_count = 0
for label_1_index in label_1_indices_true:
    if label_1_index in label_1_indices_pred:
        correct_count += 1
print(f'Label 1: {correct_count} correct out of {len(label_1_indices_true)}')

# out of the images with label 2, how many did we predict correctly
correct_count = 0
for label_2_index in label_2_indices_true:
    if label_2_index in label_2_indices_pred:
        correct_count += 1
print(f'Label 2: {correct_count} correct out of {len(label_2_indices_true)}')

# Convert y_test and y_pred to binary format (1 for label '1' and 0 for all other labels)
y_test_binary = (y_test == 1).astype(int)
y_pred_binary = (y_pred == 1).astype(int)

# Calculate Precision, Recall, and F1 Score for label '1' (binary classification)
precision = precision_score(y_test_binary, y_pred_binary, pos_label=1)
recall = recall_score(y_test_binary, y_pred_binary, pos_label=1)
f1 = f1_score(y_test_binary, y_pred_binary, pos_label=1)

print(f"Precision for label 1: {precision:.2f}")
print(f"Recall for label 1: {recall:.2f}")
print(f"F1 Score for label 1: {f1:.2f}")



In [None]:
from sklearn.metrics import f1_score

def evaluate_model_f1(model, X_val, y_val):
    """
    Evaluate the model on the validation set using macro-averaged F1 score.

    Parameters:
    model (Model): The trained model.
    X_val (np.ndarray): Validation input data.
    y_val (np.ndarray): Validation target data.

    Returns:
    float: Macro F1 score of the model on the validation set.
    """

    with torch.no_grad():  # No need to track gradients for validation
        # One call predicts the whole validation set
        predictions = model.predict(X_val)
    f1 = f1_score(y_val, predictions, average='macro')
    return f1


### 12. Hyperparameters Search

In [None]:
from statistics import mean

# Define ranges for hyperparameters
random_state = [42, 43, 50, 51, 52, 53]
#learning_rates = [0.001, 0.002, 0.003, 0.004, 0.005]
#epochs = [20, 25, 30, 35, 40]
#dropouts = [0.2, 0.25, 0.3, 0.35]

learning_rates = [0.0015, 0.002, 0.0025]
epochs = [20, 25, 30]
dropouts = [0.15, 0.2, 0.25]

best_avg_f1 = 0
best_params = {}
f1_scores = {}

# Iterate over each combination of hyperparameters
for lr in learning_rates:
    for epoch in epochs:
        for dropout in dropouts:
            current_param_key = f"LR: {lr}, Epochs: {epoch}, Dropout: {dropout}"
            f1_scores[current_param_key] = []

            # Iterate over each random state
            for rs in random_state:
                X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=rs)
                nan_indices = np.argwhere(np.isnan(y_test)).squeeze()
                mask = np.ones(y_test.shape, bool)
                mask[nan_indices] = False
                X_test = X_test[mask]
                y_test = y_test[mask]

                print(f'Random State: {rs}, Learning Rate: {lr}, Epochs: {epoch}, Dropout: {dropout}')
                # Create a new instance of the model with current parameters
                model = Model(learning_rate=lr, dropout=dropout, epochs=epoch)
                
                # Train the model
                model.fit(X_train, y_train)

                # Evaluate the model
                f1 = evaluate_model_f1(model, X_test, y_test)

                # Store the F1 score for the current parameter combination and random state
                f1_scores[current_param_key].append(f1)

# Average the F1 scores for each parameter combination and find the best parameters
for params, scores in f1_scores.items():
    avg_f1 = mean(scores)
    if avg_f1 > best_avg_f1:
        best_avg_f1 = avg_f1
        best_params = params

print("Best params:", best_params)


In [None]:
# random state from 1 to 30
random_states = range(1, 31)
f1_scores = []
for r in random_states:
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=r)
    nan_indices = np.argwhere(np.isnan(y_test)).squeeze()
    mask = np.ones(y_test.shape, bool)
    mask[nan_indices] = False
    X_test = X_test[mask]
    y_test = y_test[mask]
    model = Model()
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    f1_scores.append(f1_score(y_test, y_pred, average='macro'))

print(f'Average F1 Score: {np.mean(f1_scores):.2f}')
print(f'Max F1 Score: {np.max(f1_scores):.2f}')
print(f'Min F1 Score: {np.min(f1_scores):.2f}')
# print percentage that is greater than 0.72
print(f'Percentage greater than 0.72: {np.sum(np.array(f1_scores) > 0.72) / len(f1_scores) * 100:.2f}%')

In [None]:
print(f1_scores)
# dropout 0.2, epochs 25, lr 0.002
# dropout 0.15, epochs 20, lr 0.0015
# {
#     'LR: 0.001, Epochs: 20, Dropout: 0.2': [0.4290123456790123, 0.42907157335018026, 0.6346958087091563, 0.6779798086116009, 0.7363465160075329, 0.7902589688762176],
#     'LR: 0.001, Epochs: 20, Dropout: 0.25': [0.47154225085259566, 0.6412891880419623, 0.6304584561760517, 0.6519717999559375, 0.6439074375304421, 0.7298374607791648],
#     'LR: 0.001, Epochs: 20, Dropout: 0.3': [0.45526269942009506, 0.528621739188989, 0.625323487600325, 0.6832603689262831, 0.7066936249731949, 0.6547142337095012],
#     'LR: 0.001, Epochs: 20, Dropout: 0.35': [0.3933808612994841, 0.6161857401470043, 0.5917379515030029, 0.6657403365827634, 0.7513366066557556, 0.7004013042387761],
#     'LR: 0.001, Epochs: 25, Dropout: 0.2': [0.4699940582293523, 0.6074631402948776, 0.698690880239785, 0.7299043062200957, 0.7939065273121674, 0.8005701615457713],
#     'LR: 0.001, Epochs: 25, Dropout: 0.25': [0.47502128477657846, 0.5897210531356873, 0.6139634831687811, 0.6823791125160987, 0.6796211162041206, 0.7139925481252706],
#     'LR: 0.001, Epochs: 25, Dropout: 0.3': [0.4609870643486858, 0.5734053574071334, 0.6288562467997951, 0.7268431983385254, 0.682183908045977, 0.6206733048838312],
#     'LR: 0.001, Epochs: 25, Dropout: 0.35': [0.43531202435312033, 0.5063236047107015, 0.6490366095298729, 0.6737299604683117, 0.7910197598473264, 0.7333333333333334],
#     'LR: 0.001, Epochs: 30, Dropout: 0.2': [0.4670349621330013, 0.6512357618740597, 0.6625966625966626, 0.73970680084493, 0.7930228975005095, 0.6777595538789569],
#     'LR: 0.001, Epochs: 30, Dropout: 0.25': [0.47154225085259566, 0.7291524653487844, 0.7160007297938332, 0.7289321789321789, 0.771820541342232, 0.6836806993328732],
#     'LR: 0.001, Epochs: 30, Dropout: 0.3': [0.46741423040982516, 0.6059964018615163, 0.568103818103818, 0.6982045277127243, 0.7331083585364407, 0.7522820800598579],
#     'LR: 0.001, Epochs: 30, Dropout: 0.35': [0.4318181818181818, 0.5076194192074021, 0.6224886224886225, 0.6816326530612246, 0.7003174603174602, 0.6638579889144223],
#     'LR: 0.001, Epochs: 35, Dropout: 0.2': [0.4742441132350715, 0.5733743214800043, 0.6571169394596726, 0.669226830517153, 0.7217993915631555, 0.7730425017077095],
#     'LR: 0.001, Epochs: 35, Dropout: 0.25': [0.4348845598845599, 0.5634108527131784, 0.6277380968764478, 0.6810950300652818, 0.7684089082820447, 0.7044858523119393],
#     'LR: 0.001, Epochs: 35, Dropout: 0.3': [0.4635996659242762, 0.5603024260551678, 0.6452991452991452, 0.5918215321885046, 0.76757139485485, 0.7346556062650095],
#     'LR: 0.001, Epochs: 35, Dropout: 0.35': [0.4612650515015233, 0.504985754985755, 0.6181882161801976, 0.6712320483749056, 0.6823301793119562, 0.6186868686868687],
#     'LR: 0.001, Epochs: 40, Dropout: 0.2': [0.4807984270390005, 0.6617211753663995, 0.6274455974670107, 0.7910457516339869, 0.7065918653576437, 0.6971830985915494],
#     'LR: 0.001, Epochs: 40, Dropout: 0.25': [0.4340722662870313, 0.6215784230637703, 0.7140613847251963, 0.754156954156954, 0.6989992721979621, 0.8085659871012423],
#     'LR: 0.001, Epochs: 40, Dropout: 0.3': [0.46205623710780586, 0.5759983470453435, 0.6807198463249419, 0.7473544973544973, 0.7942686357243319, 0.6989533630906571],
#     'LR: 0.001, Epochs: 40, Dropout: 0.35': [0.42065464183744067, 0.5468202543796497, 0.6176086504470784, 0.6802671523982999, 0.768098568098568, 0.6350665310619543],
#     'LR: 0.002, Epochs: 20, Dropout: 0.2': [0.4406018518518519, 0.6939968048720341, 0.7732703844701879, 0.65441400304414, 0.8016664339244984, 0.7492753623188405],
#     'LR: 0.002, Epochs: 20, Dropout: 0.25': [0.46724286302241175, 0.6041938287701, 0.6789865871833086, 0.7702506153501902, 0.6503641793863842, 0.6681547619047619],
#     'LR: 0.002, Epochs: 20, Dropout: 0.3': [0.482078853046595, 0.6095142805166455, 0.7305940406149175, 0.6868946868946869, 0.7004272237048751, 0.7435542431137145],
#     'LR: 0.002, Epochs: 20, Dropout: 0.35': [0.4224520577461754, 0.602431598766174, 0.6332607116920843, 0.6215780998389694, 0.7090857366380997, 0.7860148032723023],
#     'LR: 0.002, Epochs: 25, Dropout: 0.2': [0.5040084388185654, 0.6492926284437827, 0.7495726495726496, 0.71011119278472, 0.8089998608361503, 0.7784469721875205],
#     'LR: 0.002, Epochs: 25, Dropout: 0.25': [0.5271789991276074, 0.6025910364145658, 0.6711089546606286, 0.7187590187590187, 0.7374001452432825, 0.7044858523119393],
#     'LR: 0.002, Epochs: 25, Dropout: 0.3': [0.4619883040935672, 0.4296296296296296, 0.6278186143957956, 0.68986271725164, 0.7830310010557634, 0.7333517207563359],
#     'LR: 0.002, Epochs: 25, Dropout: 0.35': [0.433540757694967, 0.38342342342342345, 0.5944018245905038, 0.7114396541904231, 0.7536143077155614, 0.6907020872865276],
#     'LR: 0.002, Epochs: 30, Dropout: 0.2': [0.47965753542885753, 0.4635008452435838, 0.6931242854809733, 0.7210176991150443, 0.7281873491723599, 0.7044905058150092],
#     'LR: 0.002, Epochs: 30, Dropout: 0.25': [0.5096251266464032, 0.4624926661637751, 0.6586604884477225, 0.6016726403823177, 0.7931153641679959, 0.6646644824657202],
#     'LR: 0.002, Epochs: 30, Dropout: 0.3': [0.49157686114207855, 0.6248765066192452, 0.6699404761904763, 0.6801923543881504, 0.6895805344081206, 0.6645173453996983],
#     'LR: 0.002, Epochs: 30, Dropout: 0.35': [0.4379699248120301, 0.579909906175296, 0.5891383083436063, 0.6853843474533129, 0.8155172413793105, 0.7409337778403843],
#     'LR: 0.002, Epochs: 35, Dropout: 0.2': [0.48297811878374536, 0.6372759856630824, 0.7248989239285937, 0.749213357160377, 0.749698119717095, 0.6801795685602268],
#     'LR: 0.002, Epochs: 35, Dropout: 0.25': [0.4270186496915988, 0.449813258636788, 0.6260348583877996, 0.7058919819739865, 0.8653483045639908, 0.7618357487922706],
#     'LR: 0.002, Epochs: 35, Dropout: 0.3': [0.5168611301761611, 0.5873015873015873, 0.6738679759956355, 0.7109118086696563, 0.7405017921146954, 0.6399342076786564],
#     'LR: 0.002, Epochs: 35, Dropout: 0.35': [0.42465753424657526, 0.559371250299976, 0.6360750360750361, 0.6432447283467271, 0.6939000745712155, 0.7544027401810925],
#     'LR: 0.002, Epochs: 40, Dropout: 0.2': [0.50525167523765, 0.6254017342793038, 0.6394938394938395, 0.7565909308791103, 0.8292817104539304, 0.737012987012987],
#     'LR: 0.002, Epochs: 40, Dropout: 0.25': [0.4311155913978495, 0.6975784898086794, 0.6998998553466117, 0.7571895424836601, 0.7753490394999828, 0.7242445054945055],
#     'LR: 0.002, Epochs: 40, Dropout: 0.3': [0.45896877269426284, 0.5952294841987428, 0.662917186173, 0.6350488084984208, 0.7175965665236053, 0.7716392625809666],
#     'LR: 0.002, Epochs: 40, Dropout: 0.35': [0.4285745758977748, 0.613152400835073, 0.5847092778960276, 0.6730457889427454, 0.7618799329325645, 0.7043961507048756],
#     'LR: 0.003, Epochs: 20, Dropout: 0.2': [0.4648198725451944, 0.675846786957898, 0.6284521198469311, 0.7672771672771672, 0.7848610498831877, 0.6208250910237666],
#     'LR: 0.003, Epochs: 20, Dropout: 0.25': [0.48526941374101834, 0.6637709936295192, 0.7001307650490002, 0.6828947835659246, 0.671872571872572, 0.6610137855638215],
#     'LR: 0.003, Epochs: 20, Dropout: 0.3': [0.4488611634992021, 0.6637709936295192, 0.6986111111111111, 0.7525641025641026, 0.7688020168884028, 0.7584325396825395],
#     'LR: 0.003, Epochs: 20, Dropout: 0.35': [0.4273643879814479, 0.5651464590285502, 0.6913980570031525, 0.65230383043976, 0.6457556935817805, 0.7791054409680095],
#     'LR: 0.003, Epochs: 25, Dropout: 0.2': [0.45526269942009506, 0.6421326483567147, 0.679766081871345, 0.6981133704056156, 0.8496703885492313, 0.7095489604292422],
#     'LR: 0.003, Epochs: 25, Dropout: 0.25': [0.45858279651383094, 0.5653235653235653, 0.7238550403107364, 0.7769771955361476, 0.7351795989050891, 0.7298374607791648],
#     'LR: 0.003, Epochs: 25, Dropout: 0.3': [0.4765654908394859, 0.5468202543796497, 0.6421356421356421, 0.7223250163536118, 0.7287254855926499, 0.680232834402471],
#     'LR: 0.003, Epochs: 25, Dropout: 0.35': [0.4987346700408799, 0.5960900140646976, 0.602596003723053, 0.6802503716828056, 0.6968512325034064, 0.7002525252525252],
#     'LR: 0.003, Epochs: 30, Dropout: 0.2': [0.4638151425762045, 0.5875022349365278, 0.6882325363338021, 0.7214232449131778, 0.7671764400285732, 0.6324868745312332],
#     'LR: 0.003, Epochs: 30, Dropout: 0.25': [0.49007899483734524, 0.6555955715619581, 0.680368149258704, 0.7047886692356218, 0.7266992266992266, 0.7202876293200947],
#     'LR: 0.003, Epochs: 30, Dropout: 0.3': [0.4772893772893773, 0.5348797602016487, 0.7140670064398877, 0.6761904761904761, 0.6849849849849851, 0.845891422010825],
#     'LR: 0.003, Epochs: 30, Dropout: 0.35': [0.46751727804359383, 0.5842821339061941, 0.6088336783988958, 0.731454196028187, 0.6850810750154296, 0.7972903951867721],
#     'LR: 0.003, Epochs: 35, Dropout: 0.2': [0.4848679383712399, 0.5950274938434031, 0.7197069420496752, 0.6910310620487826, 0.6933417530162557, 0.6405753549013163],
#     'LR: 0.003, Epochs: 35, Dropout: 0.25': [0.4387569031989164, 0.5881370091896407, 0.6752864157119477, 0.7128344671201815, 0.7315423976608186, 0.7821743792790564],
#     'LR: 0.003, Epochs: 35, Dropout: 0.3': [0.4572043010752688, 0.5318382892153384, 0.6608969154808041, 0.6500184628369468, 0.721125730994152, 0.7861223571749889],
#     'LR: 0.003, Epochs: 35, Dropout: 0.35': [0.4664516774161292, 0.5141712599018101, 0.6218330089889723, 0.8452325678298035, 0.6899980677507602, 0.7425364758698092],
#     'LR: 0.003, Epochs: 40, Dropout: 0.2': [0.4757860681427561, 0.4588164251207729, 0.691796157059315, 0.7289321789321789, 0.7954725523486136, 0.7407889409451096],
#     'LR: 0.003, Epochs: 40, Dropout: 0.25': [0.4386430678466076, 0.4481242487505535, 0.6546919084232518, 0.7085482501385992, 0.7338643790849672, 0.7121157801384985],
#     'LR: 0.003, Epochs: 40, Dropout: 0.3': [0.4612650515015233, 0.5836036692789016, 0.6512345679012346, 0.7474569146700295, 0.7854345508612468, 0.6402563637462966],
#     'LR: 0.003, Epochs: 40, Dropout: 0.35': [0.46838443306063154, 0.6112731208592721, 0.6844793250138462, 0.7441999808263828, 0.7976624590930341, 0.5148148148148148],
#     'LR: 0.004, Epochs: 20, Dropout: 0.2': [0.5203270497388144, 0.5503934057699512, 0.6930521863531753, 0.656127672387835, 0.6509319254459643, 0.7485871935211142],
#     'LR: 0.004, Epochs: 20, Dropout: 0.25': [0.4646063281824871, 0.633459595959596, 0.7071259709557584, 0.8187937131975463, 0.6258464473056747, 0.7263920099875155],
#     'LR: 0.004, Epochs: 20, Dropout: 0.3': [0.49157686114207855, 0.5592557201252853, 0.6324469181612038, 0.7270150484436199, 0.7765316924705571, 0.6217506071164608],
#     'LR: 0.004, Epochs: 20, Dropout: 0.35': [0.46085070597087596, 0.6025910364145658, 0.6404614542545577, 0.7091678011760276, 0.7374929987200702, 0.6960067872125585],
#     'LR: 0.004, Epochs: 25, Dropout: 0.2': [0.5324792190955541, 0.7062832062832062, 0.687631681623098, 0.7701900683645105, 0.7820149712306574, 0.712739054844318],
#     'LR: 0.004, Epochs: 25, Dropout: 0.25': [0.49512987012987014, 0.5434962314612205, 0.7140613847251963, 0.7477519272492268, 0.7017632773730335, 0.7735042735042734],
#     'LR: 0.004, Epochs: 25, Dropout: 0.3': [0.4506884639711964, 0.6706015739629185, 0.5353094926933587, 0.7505912405173442, 0.7905662621936712, 0.6904563892801984],
#     'LR: 0.004, Epochs: 25, Dropout: 0.35': [0.43047682933646375, 0.7465656031494213, 0.6076409414340448, 0.71011119278472, 0.7153776571687019, 0.6646399062548577],
#     'LR: 0.004, Epochs: 30, Dropout: 0.2': [0.459366807824158, 0.608457711442786, 0.6955605599673396, 0.790703292986398, 0.7018526743653934, 0.757431813040942],
#     'LR: 0.004, Epochs: 30, Dropout: 0.25': [0.45800252723934864, 0.6426608251905453, 0.6100000870026708, 0.7189994914823289, 0.6498403072569233, 0.7008064728440674],
#     'LR: 0.004, Epochs: 30, Dropout: 0.3': [0.45401303307762103, 0.6454490928175138, 0.626858553554396, 0.6780626780626781, 0.781559330709921, 0.704399888610415],
#     'LR: 0.004, Epochs: 30, Dropout: 0.35': [0.4296535341311461, 0.5104177288331463, 0.6887239744382602, 0.7094158755681993, 0.7257418909592822, 0.6721440999964491],
#     'LR: 0.004, Epochs: 35, Dropout: 0.2': [0.4812409812409812, 0.6041938287701, 0.7091688818004607, 0.7466049057795557, 0.7527550463720676, 0.7508450154190064],
#     'LR: 0.004, Epochs: 35, Dropout: 0.25': [0.46548400541835994, 0.6003787878787878, 0.7102453102453102, 0.7085858585858585, 0.7068060574771984, 0.6045843045843046],
#     'LR: 0.004, Epochs: 35, Dropout: 0.3': [0.43725497291885307, 0.5638263707915449, 0.6769994764055238, 0.6845182347098056, 0.7972949297452608, 0.6294878899728487],
#     'LR: 0.004, Epochs: 35, Dropout: 0.35': [0.48932522985818916, 0.6041938287701, 0.6059240852780179, 0.6016261882598516, 0.7211458850803112, 0.6301158301158302],
#     'LR: 0.004, Epochs: 40, Dropout: 0.2': [0.5040084388185654, 0.685368093350355, 0.737705899970051, 0.7176353058706, 0.7483551874856221, 0.733221310665217],
#     'LR: 0.004, Epochs: 40, Dropout: 0.25': [0.470306673345578, 0.6087538433713148, 0.7541937897493453, 0.7610119047619048, 0.8037264363379014, 0.7584325396825395],
#     'LR: 0.004, Epochs: 40, Dropout: 0.3': [0.4612650515015233, 0.5277963168953157, 0.6517974562035037, 0.7318090307736554, 0.7185469037400368, 0.8215492277992277],
#     'LR: 0.004, Epochs: 40, Dropout: 0.35': [0.41058201058201055, 0.7771145378507341, 0.4919958610991258, 0.7252808988764045, 0.6710142026922848, 0.6744380823528465],
#     'LR: 0.005, Epochs: 20, Dropout: 0.2': [0.49220552151746233, 0.6713214653732722, 0.6676989676989676, 0.6992694166573002, 0.7428571428571429, 0.7236092299056752],
#     'LR: 0.005, Epochs: 20, Dropout: 0.25': [0.448063198962839, 0.4580799370669571, 0.6485090787939364, 0.7332710190777489, 0.7402218296825223, 0.7263920099875155],
#     'LR: 0.005, Epochs: 20, Dropout: 0.3': [0.45526269942009506, 0.5790794979079498, 0.603566948089377, 0.7434537125812293, 0.7009804377985308, 0.6236302108233265],
#     'LR: 0.005, Epochs: 20, Dropout: 0.35': [0.4735811628925471, 0.5670410397683124, 0.659421596921597, 0.7517661296422359, 0.5845052995528229, 0.6474414474414475],
#     'LR: 0.005, Epochs: 25, Dropout: 0.2': [0.4845089550971904, 0.6323286849602638, 0.7036593228601532, 0.740076411044153, 0.7444468100205804, 0.6468189151038516],
#     'LR: 0.005, Epochs: 25, Dropout: 0.25': [0.49371662643552044, 0.6984126984126985, 0.6485814062895813, 0.8205715361209572, 0.7030368809721423, 0.7580733442802408],
#     'LR: 0.005, Epochs: 25, Dropout: 0.3': [0.4966475095785441, 0.648306088604596, 0.6955962135668051, 0.6098620392189601, 0.7812548922671282, 0.8496296296296296],
#     'LR: 0.005, Epochs: 25, Dropout: 0.35': [0.47389488840892735, 0.5378193065559232, 0.5496291821873217, 0.7702506153501902, 0.7038754311286346, 0.5880139184487011],
#     'LR: 0.005, Epochs: 30, Dropout: 0.2': [0.45008389261744974, 0.6195700061679443, 0.659881086576929, 0.6987785645814651, 0.8101713378383368, 0.7675641431738992],
#     'LR: 0.005, Epochs: 30, Dropout: 0.25': [0.482832618025751, 0.5659632475303615, 0.721055441165256, 0.6235897363921111, 0.7047846889952153, 0.7679712981082844],
#     'LR: 0.005, Epochs: 30, Dropout: 0.3': [0.48579438244649303, 0.5771403737505433, 0.6942245209203635, 0.7963472834067548, 0.6951790402770796, 0.6328181248655876],
#     'LR: 0.005, Epochs: 30, Dropout: 0.35': [0.41391042627670366, 0.527951095512515, 0.6873991602585271, 0.7094158755681993, 0.7242926155969633, 0.7227278763547956],
#     'LR: 0.005, Epochs: 35, Dropout: 0.2': [0.4812409812409812, 0.637037037037037, 0.638539887391662, 0.8495972945312152, 0.7774428040121718, 0.7924065769805679],
#     'LR: 0.005, Epochs: 35, Dropout: 0.25': [0.4765654908394859, 0.7414506875391025, 0.687631681623098, 0.7283034566220407, 0.7142892304230265, 0.7231933561047486],
#     'LR: 0.005, Epochs: 35, Dropout: 0.3': [0.4822180671237275, 0.5286227246360137, 0.6927236971484759, 0.6638140796653517, 0.7357336349256318, 0.6600125633777538],
#     'LR: 0.005, Epochs: 35, Dropout: 0.35': [0.44403586426058334, 0.5503934057699512, 0.6229136656102948, 0.668501767013585, 0.7427323883845623, 0.647736777873764],
#     'LR: 0.005, Epochs: 40, Dropout: 0.2': [0.45418646676929253, 0.40308248800817803, 0.7799494590539368, 0.746518259444812, 0.7452176578786021, 0.7038809144072302],
#     'LR: 0.005, Epochs: 40, Dropout: 0.25': [0.4608849166942513, 0.5722579377064427, 0.6765211974531399, 0.6498771498771498, 0.7105974354551204, 0.6843106291890374],
#     'LR: 0.005, Epochs: 40, Dropout: 0.3': [0.4797101449275362, 0.7901066391468868, 0.6793650793650795, 0.6767083753385124, 0.6446110348549373, 0.667888214196939],
#     'LR: 0.005, Epochs: 40, Dropout: 0.35': [0.35862896157217333, 0.558652495529572, 0.7068710440021353, 0.7278585214069085, 0.6624354494573116, 0.7722385141739979],
# }
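
# Sketch (an addition, not part of the original run): one possible way to
# summarise a results dict shaped like the dump above.
# Assumptions, none of which are confirmed by the notebook:
#   - every key has the form 'LR: <float>, Epochs: <int>, Dropout: <float>';
#   - every value is a list of scores where higher is better;
#   - ranking by the mean is reasonable. The notebook does not say what the
#     six scores per configuration represent, so pass a different `statistic`
#     (e.g. max, or lambda s: s[-1]) if another summary fits better.
# `parse_config_key`, `rank_configs` and `sample_results` are hypothetical names.
import re
from statistics import mean

CONFIG_KEY_PATTERN = re.compile(r"LR: ([0-9.]+), Epochs: ([0-9]+), Dropout: ([0-9.]+)")


def parse_config_key(key):
    """Split a key such as 'LR: 0.001, Epochs: 20, Dropout: 0.2' into typed fields."""
    match = CONFIG_KEY_PATTERN.fullmatch(key)
    if match is None:
        raise ValueError(f"Unrecognised config key: {key!r}")
    lr, epochs, dropout = match.groups()
    return {"lr": float(lr), "epochs": int(epochs), "dropout": float(dropout)}


def rank_configs(results, statistic=mean, top_k=5):
    """Return up to `top_k` configurations sorted by statistic(scores), highest first."""
    rows = [
        {**parse_config_key(key), "summary": statistic(scores)}
        for key, scores in results.items()
    ]
    rows.sort(key=lambda row: row["summary"], reverse=True)
    return rows[:top_k]


# Two entries copied from the dump above, used only to exercise the helpers.
sample_results = {
    'LR: 0.001, Epochs: 20, Dropout: 0.2': [0.4290123456790123, 0.42907157335018026, 0.6346958087091563, 0.6779798086116009, 0.7363465160075329, 0.7902589688762176],
    'LR: 0.001, Epochs: 25, Dropout: 0.2': [0.4699940582293523, 0.6074631402948776, 0.698690880239785, 0.7299043062200957, 0.7939065273121674, 0.8005701615457713],
}

for row in rank_configs(sample_results):
    print(row)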
# second