In [1]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
from pprint import pprint
import json
import random
from sklearn.model_selection import KFold
from sklearn.metrics import f1_score
from tqdm import trange, tqdm

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!unzip '/content/drive/MyDrive/School/2025 Spring/Advanced ML/AML Project/Data/new/preprocessed_selected_features.zip'

Archive:  /content/drive/MyDrive/School/2025 Spring/Advanced ML/AML Project/Data/new/preprocessed_selected_features.zip
   creating: preprocessed_selected_features/
   creating: preprocessed_selected_features/test/
   creating: preprocessed_selected_features/train/
  inflating: preprocessed_selected_features/test/connectome_matrices.csv  
  inflating: preprocessed_selected_features/test/aux.csv  
  inflating: preprocessed_selected_features/train/labels.csv  
  inflating: preprocessed_selected_features/train/connectome_matrices.csv  
  inflating: preprocessed_selected_features/train/aux.csv  


In [8]:
def compute_leaderboard_f1_multiclass(y_true, y_pred):
    """
    Multiclass version of compute_leaderboard_f1_binary.
    Assumes class encoding:
        0 -> [ADHD=0, Sex_F=0]
        1 -> [ADHD=0, Sex_F=1]
        2 -> [ADHD=1, Sex_F=0]
        3 -> [ADHD=1, Sex_F=1]

    Returns:
    - average of two F1 scores:
        (1) ADHD F1 with extra weight on ADHD=1 & Sex_F=1
        (2) Sex_F F1 (unweighted)
    """
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    # Decode back to binary labels
    true_adhd = (y_true // 2)       # 2 or 3 → 1
    true_sex_f = (y_true % 2)       # 1 or 3 → 1
    pred_adhd = (y_pred // 2)
    pred_sex_f = (y_pred % 2)

    # ADHD: apply weight=2 if true_adhd=1 and true_sex_f=1
    weights = np.where((true_adhd == 1) & (true_sex_f == 1), 2, 1)
    f1_adhd = f1_score(true_adhd, pred_adhd, sample_weight=weights, average='binary')
    f1_sex_f = f1_score(true_sex_f, pred_sex_f, average='binary')

    return (f1_adhd + f1_sex_f) / 2

In [9]:
y_true_temp = [0, 1, 2, 3]
y_pred_temp = [0, 1, 3, 3]
compute_leaderboard_f1_multiclass(y_true_temp, y_pred_temp)

0.9

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [11]:
TRAIN_X_PATH = 'preprocessed_selected_features/train/connectome_matrices.csv'
TRAIN_Y_PATH = 'preprocessed_selected_features/train/labels.csv'

In [12]:
train_X_df = pd.read_csv(TRAIN_X_PATH)
train_y_df = pd.read_csv(TRAIN_Y_PATH)
train_X_df.set_index('participant_id', inplace=True)
train_y_df.set_index('participant_id', inplace=True)
train_y_df = train_y_df.reindex(train_X_df.index)

In [13]:
train_X_df.head()

Unnamed: 0_level_0,0throw_1thcolumn,0throw_2thcolumn,0throw_3thcolumn,0throw_4thcolumn,0throw_5thcolumn,0throw_6thcolumn,0throw_7thcolumn,0throw_8thcolumn,0throw_9thcolumn,0throw_10thcolumn,...,195throw_196thcolumn,195throw_197thcolumn,195throw_198thcolumn,195throw_199thcolumn,196throw_197thcolumn,196throw_198thcolumn,196throw_199thcolumn,197throw_198thcolumn,197throw_199thcolumn,198throw_199thcolumn
participant_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
70z8Q2xdTXM3,0.270399,0.580746,0.485892,0.064059,0.617299,0.337467,0.55023,-0.087434,0.579197,0.535364,...,0.305246,0.58361,0.5409,0.228036,0.343643,0.485184,0.019701,0.614717,0.52442,0.40429
WHWymJu6zNZi,0.745668,0.635297,0.560712,0.541223,0.439375,0.473556,0.129684,-0.095509,0.132499,0.314387,...,0.295154,-0.021363,0.000563,-0.119118,0.511165,0.396962,0.201877,0.664817,0.612853,0.557002
4PAQp1M6EyAo,-0.141711,0.503933,0.294476,0.697041,0.840358,0.476623,0.690517,0.215647,0.573338,0.43098,...,0.464666,-0.031043,-0.048386,0.092712,0.464166,0.337855,0.471782,0.50496,0.566427,0.691008
obEacy4Of68I,0.242208,0.829234,0.74409,0.627094,0.756269,0.699015,0.565864,0.463884,0.584879,0.295275,...,0.140506,-0.261833,0.269811,-0.023052,0.490709,0.684443,0.26087,0.373375,0.490076,0.617905
s7WzzDcmDOhF,0.275725,0.675102,0.702433,0.613503,0.804479,0.637771,0.28674,0.404671,0.33113,0.377236,...,-0.223803,0.010372,-0.154615,-0.602806,0.55511,-0.248578,0.25421,0.061103,0.132521,0.119855


In [14]:
train_y_df.head()

Unnamed: 0_level_0,ADHD_Outcome,Sex_F
participant_id,Unnamed: 1_level_1,Unnamed: 2_level_1
70z8Q2xdTXM3,1,0
WHWymJu6zNZi,1,1
4PAQp1M6EyAo,1,1
obEacy4Of68I,1,1
s7WzzDcmDOhF,1,1


In [15]:
class Model(nn.Module):
    def __init__(self, input_dim, layer_dims, dropout=0.5, output_dim=4):
        super(Model, self).__init__()
        layers = []
        prev_dim = input_dim
        for dim in layer_dims:
            layers.append(nn.Linear(prev_dim, dim))
            layers.append(nn.ReLU())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            prev_dim = dim
        layers.append(nn.Linear(prev_dim, output_dim))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

In [16]:
temp_model = Model(100, [64, 32], dropout=0.3, output_dim=4).to(device)
summary(temp_model, (100,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                   [-1, 64]           6,464
              ReLU-2                   [-1, 64]               0
           Dropout-3                   [-1, 64]               0
            Linear-4                   [-1, 32]           2,080
              ReLU-5                   [-1, 32]               0
           Dropout-6                   [-1, 32]               0
            Linear-7                    [-1, 4]             132
Total params: 8,676
Trainable params: 8,676
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.03
Estimated Total Size (MB): 0.04
----------------------------------------------------------------


In [17]:
X = np.array(train_X_df.values, dtype=np.float32)

In [18]:
X.shape

(1213, 19900)

In [19]:
y_two_vars = train_y_df.values
y = np.array(y_two_vars[:, 0] * 2 + y_two_vars[:, 1], dtype=np.uint8)

In [20]:
y[:3], y_two_vars[:3]

(array([2, 3, 3], dtype=uint8),
 array([[1, 0],
        [1, 1],
        [1, 1]]))

In [21]:
layer_dims_list = [
    [256, 128, 64, 32],
    [256, 128, 128, 64, 32],
    [256, 128, 128, 64, 64, 32],
    [256, 128, 128, 64, 64, 32, 32, 16],
    [512, 256, 128, 128, 64, 64, 32, 32, 16],
]
dropouts = [0.0, .1, .2, .3, .4, .5]

In [22]:
criterion = nn.CrossEntropyLoss()

In [23]:
seed = 42  # Choose any fixed number
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # If using CUDA

In [24]:
kf = KFold(n_splits=5, shuffle=True, random_state=42)

In [25]:
num_epochs = 200

results = {"-".join(map(str, layer_dims)): {} for layer_dims in layer_dims_list}
epoch_history = {"-".join(map(str, layer_dims)): {} for layer_dims in layer_dims_list}
for layer_dims in layer_dims_list:
    for dropout in dropouts:
        print("complexity:", layer_dims, "dropout rate:", dropout)
        f1_scores = []
        best_epochs = []

        for fold, (train_index, test_index) in enumerate(kf.split(X)):
            model = Model(
                input_dim=X.shape[1], layer_dims=layer_dims, dropout=dropout, output_dim=4
            ).to(device)
            X_train, X_test = X[train_index], X[test_index]
            y_train, y_test = y[train_index], y[test_index]

            optimizer = optim.Adam(model.parameters(), lr=0.001)
            best_test_loss = float("inf")
            best_f1 = 0.0
            best_epoch = 0

            for epoch in trange(num_epochs):
                model.train()
                optimizer.zero_grad()
                outputs = model(torch.tensor(X_train).to(device))
                loss = criterion(outputs, torch.tensor(y_train).to(device))
                loss.backward()
                optimizer.step()

                model.eval()
                with torch.no_grad():
                    test_outputs = model(torch.tensor(X_test).to(device))
                    test_loss = criterion(test_outputs, torch.tensor(y_test).to(device)).item()
                    predicted = torch.argmax(test_outputs.data, 1).cpu()
                    f1 = compute_leaderboard_f1_multiclass(y_test, predicted)

                if test_loss < best_test_loss:
                    best_test_loss = test_loss
                    best_f1 = f1
                    best_epoch = epoch


            f1_scores.append(float(best_f1))
            best_epochs.append(best_epoch)
        print(f1_scores)
        results["-".join(map(str, layer_dims))][dropout] = f1_scores
        epoch_history["-".join(map(str, layer_dims))][dropout] = best_epochs

complexity: [256, 128, 64, 32] dropout rate: 0.0


100%|██████████| 200/200 [00:18<00:00, 10.95it/s]
100%|██████████| 200/200 [00:17<00:00, 11.53it/s]
100%|██████████| 200/200 [00:17<00:00, 11.16it/s]
100%|██████████| 200/200 [00:18<00:00, 11.07it/s]
100%|██████████| 200/200 [00:18<00:00, 10.78it/s]


[0.46378621378621376, 0.6666709989342622, 0.667961715467218, 0.5467344249952946, 0.6568572321149642]
complexity: [256, 128, 64, 32] dropout rate: 0.1


100%|██████████| 200/200 [00:18<00:00, 10.67it/s]
100%|██████████| 200/200 [00:17<00:00, 11.13it/s]
100%|██████████| 200/200 [00:17<00:00, 11.18it/s]
100%|██████████| 200/200 [00:19<00:00, 10.51it/s]
100%|██████████| 200/200 [00:17<00:00, 11.12it/s]


[0.46858288770053474, 0.43018867924528303, 0.6407315340909091, 0.5421909562316423, 0.6145672441706831]
complexity: [256, 128, 64, 32] dropout rate: 0.2


100%|██████████| 200/200 [00:18<00:00, 10.79it/s]
100%|██████████| 200/200 [00:17<00:00, 11.17it/s]
100%|██████████| 200/200 [00:18<00:00, 11.11it/s]
100%|██████████| 200/200 [00:18<00:00, 10.80it/s]
100%|██████████| 200/200 [00:17<00:00, 11.21it/s]


[0.49778621125869704, 0.4552430695058256, 0.7383192623115847, 0.42292490118577075, 0.4085106382978723]
complexity: [256, 128, 64, 32] dropout rate: 0.3


100%|██████████| 200/200 [00:19<00:00, 10.38it/s]
100%|██████████| 200/200 [00:17<00:00, 11.34it/s]
100%|██████████| 200/200 [00:17<00:00, 11.23it/s]
100%|██████████| 200/200 [00:18<00:00, 10.79it/s]
100%|██████████| 200/200 [00:17<00:00, 11.18it/s]


[0.4358610914245216, 0.43018867924528303, 0.43761996161228406, 0.6162870945479642, 0.47919216646266827]
complexity: [256, 128, 64, 32] dropout rate: 0.4


100%|██████████| 200/200 [00:18<00:00, 10.89it/s]
100%|██████████| 200/200 [00:17<00:00, 11.20it/s]
100%|██████████| 200/200 [00:17<00:00, 11.33it/s]
100%|██████████| 200/200 [00:18<00:00, 10.94it/s]
100%|██████████| 200/200 [00:17<00:00, 11.15it/s]


[0.5834160691303548, 0.43018867924528303, 0.43761996161228406, 0.4720901340873067, 0.41922290388548056]
complexity: [256, 128, 64, 32] dropout rate: 0.5


100%|██████████| 200/200 [00:17<00:00, 11.11it/s]
100%|██████████| 200/200 [00:14<00:00, 13.94it/s]
100%|██████████| 200/200 [00:17<00:00, 11.43it/s]
100%|██████████| 200/200 [00:17<00:00, 11.15it/s]
100%|██████████| 200/200 [00:19<00:00, 10.50it/s]


[0.4117647058823529, 0.43018867924528303, 0.43761996161228406, 0.41030927835051545, 0.5681114551083591]
complexity: [256, 128, 128, 64, 32] dropout rate: 0.0


100%|██████████| 200/200 [00:17<00:00, 11.23it/s]
100%|██████████| 200/200 [00:16<00:00, 11.87it/s]
100%|██████████| 200/200 [00:17<00:00, 11.27it/s]
100%|██████████| 200/200 [00:17<00:00, 11.30it/s]
100%|██████████| 200/200 [00:18<00:00, 10.86it/s]


[0.6474789915966386, 0.43018867924528303, 0.7029601029601029, 0.6052947023739914, 0.6307300509337861]
complexity: [256, 128, 128, 64, 32] dropout rate: 0.1


100%|██████████| 200/200 [00:17<00:00, 11.21it/s]
100%|██████████| 200/200 [00:18<00:00, 10.94it/s]
100%|██████████| 200/200 [00:17<00:00, 11.31it/s]
100%|██████████| 200/200 [00:17<00:00, 11.28it/s]
100%|██████████| 200/200 [00:18<00:00, 10.85it/s]


[0.6800573888091822, 0.7134256734649362, 0.7341716857502151, 0.42292490118577075, 0.4785238959467634]
complexity: [256, 128, 128, 64, 32] dropout rate: 0.2


100%|██████████| 200/200 [00:17<00:00, 11.28it/s]
100%|██████████| 200/200 [00:18<00:00, 11.03it/s]
100%|██████████| 200/200 [00:17<00:00, 11.25it/s]
100%|██████████| 200/200 [00:18<00:00, 10.80it/s]
100%|██████████| 200/200 [00:18<00:00, 10.88it/s]


[0.6922525107604017, 0.43018867924528303, 0.5383784974664023, 0.42292490118577075, 0.42105263157894735]
complexity: [256, 128, 128, 64, 32] dropout rate: 0.3


100%|██████████| 200/200 [00:17<00:00, 11.32it/s]
100%|██████████| 200/200 [00:17<00:00, 11.18it/s]
100%|██████████| 200/200 [00:18<00:00, 11.07it/s]
100%|██████████| 200/200 [00:17<00:00, 11.34it/s]
100%|██████████| 200/200 [00:18<00:00, 10.90it/s]


[0.4117647058823529, 0.43018867924528303, 0.5755102040816327, 0.42292490118577075, 0.4900181488203267]
complexity: [256, 128, 128, 64, 32] dropout rate: 0.4


100%|██████████| 200/200 [00:17<00:00, 11.34it/s]
100%|██████████| 200/200 [00:18<00:00, 11.10it/s]
100%|██████████| 200/200 [00:18<00:00, 11.03it/s]
100%|██████████| 200/200 [00:17<00:00, 11.27it/s]
100%|██████████| 200/200 [00:18<00:00, 10.90it/s]


[0.4117647058823529, 0.43018867924528303, 0.43761996161228406, 0.42292490118577075, 0.42105263157894735]
complexity: [256, 128, 128, 64, 32] dropout rate: 0.5


100%|██████████| 200/200 [00:17<00:00, 11.21it/s]
100%|██████████| 200/200 [00:17<00:00, 11.25it/s]
100%|██████████| 200/200 [00:18<00:00, 10.58it/s]
100%|██████████| 200/200 [00:17<00:00, 11.12it/s]
100%|██████████| 200/200 [00:18<00:00, 10.81it/s]


[0.4117647058823529, 0.4408269771176235, 0.43761996161228406, 0.4012738853503185, 0.42105263157894735]
complexity: [256, 128, 128, 64, 64, 32] dropout rate: 0.0


100%|██████████| 200/200 [00:17<00:00, 11.28it/s]
100%|██████████| 200/200 [00:17<00:00, 11.20it/s]
100%|██████████| 200/200 [00:18<00:00, 11.08it/s]
100%|██████████| 200/200 [00:17<00:00, 11.21it/s]
100%|██████████| 200/200 [00:18<00:00, 10.69it/s]


[0.42395982783357244, 0.43018867924528303, 0.43761996161228406, 0.42292490118577075, 0.5924812030075188]
complexity: [256, 128, 128, 64, 64, 32] dropout rate: 0.1


100%|██████████| 200/200 [00:17<00:00, 11.18it/s]
100%|██████████| 200/200 [00:18<00:00, 10.98it/s]
100%|██████████| 200/200 [00:18<00:00, 10.94it/s]
100%|██████████| 200/200 [00:17<00:00, 11.25it/s]
100%|██████████| 200/200 [00:18<00:00, 10.76it/s]


[0.6405510441712411, 0.43018867924528303, 0.43761996161228406, 0.42292490118577075, 0.42105263157894735]
complexity: [256, 128, 128, 64, 64, 32] dropout rate: 0.2


100%|██████████| 200/200 [00:18<00:00, 10.67it/s]
100%|██████████| 200/200 [00:17<00:00, 11.23it/s]
100%|██████████| 200/200 [00:18<00:00, 11.05it/s]
100%|██████████| 200/200 [00:17<00:00, 11.18it/s]
100%|██████████| 200/200 [00:18<00:00, 10.87it/s]


[0.45827633378932964, 0.43018867924528303, 0.47984496124031006, 0.42292490118577075, 0.42105263157894735]
complexity: [256, 128, 128, 64, 64, 32] dropout rate: 0.3


100%|██████████| 200/200 [00:18<00:00, 11.11it/s]
100%|██████████| 200/200 [00:17<00:00, 11.24it/s]
100%|██████████| 200/200 [00:18<00:00, 11.03it/s]
100%|██████████| 200/200 [00:17<00:00, 11.34it/s]
100%|██████████| 200/200 [00:18<00:00, 10.76it/s]


[0.4117647058823529, 0.43018867924528303, 0.43761996161228406, 0.42292490118577075, 0.42105263157894735]
complexity: [256, 128, 128, 64, 64, 32] dropout rate: 0.4


100%|██████████| 200/200 [00:17<00:00, 11.23it/s]
100%|██████████| 200/200 [00:18<00:00, 11.03it/s]
100%|██████████| 200/200 [00:18<00:00, 11.08it/s]
100%|██████████| 200/200 [00:10<00:00, 19.12it/s]
100%|██████████| 200/200 [00:18<00:00, 10.80it/s]


[0.4784313725490196, 0.48069372975033353, 0.43761996161228406, 0.42292490118577075, 0.42105263157894735]
complexity: [256, 128, 128, 64, 64, 32] dropout rate: 0.5


100%|██████████| 200/200 [00:10<00:00, 19.31it/s]
100%|██████████| 200/200 [00:17<00:00, 11.34it/s]
100%|██████████| 200/200 [00:10<00:00, 19.39it/s]
100%|██████████| 200/200 [00:17<00:00, 11.31it/s]
100%|██████████| 200/200 [00:10<00:00, 19.62it/s]


[0.4117647058823529, 0.43018867924528303, 0.43761996161228406, 0.42292490118577075, 0.42105263157894735]
complexity: [256, 128, 128, 64, 64, 32, 32, 16] dropout rate: 0.0


100%|██████████| 200/200 [00:18<00:00, 11.06it/s]
100%|██████████| 200/200 [00:10<00:00, 19.79it/s]
100%|██████████| 200/200 [00:17<00:00, 11.19it/s]
100%|██████████| 200/200 [00:10<00:00, 19.45it/s]
100%|██████████| 200/200 [00:17<00:00, 11.37it/s]


[0.42381289865343724, 0.43018867924528303, 0.43761996161228406, 0.42292490118577075, 0.4140786749482402]
complexity: [256, 128, 128, 64, 64, 32, 32, 16] dropout rate: 0.1


100%|██████████| 200/200 [00:14<00:00, 13.98it/s]
100%|██████████| 200/200 [00:17<00:00, 11.27it/s]
100%|██████████| 200/200 [00:13<00:00, 14.41it/s]
100%|██████████| 200/200 [00:14<00:00, 14.27it/s]
100%|██████████| 200/200 [00:17<00:00, 11.40it/s]


[0.4117647058823529, 0.43018867924528303, 0.43761996161228406, 0.42292490118577075, 0.42105263157894735]
complexity: [256, 128, 128, 64, 64, 32, 32, 16] dropout rate: 0.2


100%|██████████| 200/200 [00:14<00:00, 13.63it/s]
100%|██████████| 200/200 [00:13<00:00, 14.74it/s]
100%|██████████| 200/200 [00:18<00:00, 11.00it/s]
100%|██████████| 200/200 [00:14<00:00, 13.76it/s]
100%|██████████| 200/200 [00:17<00:00, 11.58it/s]


[0.4117647058823529, 0.43018867924528303, 0.43761996161228406, 0.42292490118577075, 0.42105263157894735]
complexity: [256, 128, 128, 64, 64, 32, 32, 16] dropout rate: 0.3


100%|██████████| 200/200 [00:14<00:00, 14.15it/s]
100%|██████████| 200/200 [00:14<00:00, 13.80it/s]
100%|██████████| 200/200 [00:18<00:00, 10.89it/s]
100%|██████████| 200/200 [00:11<00:00, 18.06it/s]
100%|██████████| 200/200 [00:17<00:00, 11.16it/s]


[0.5950980392156863, 0.43018867924528303, 0.43761996161228406, 0.42292490118577075, 0.42105263157894735]
complexity: [256, 128, 128, 64, 64, 32, 32, 16] dropout rate: 0.4


100%|██████████| 200/200 [00:13<00:00, 15.38it/s]
100%|██████████| 200/200 [00:18<00:00, 11.11it/s]
100%|██████████| 200/200 [00:11<00:00, 17.18it/s]
100%|██████████| 200/200 [00:18<00:00, 10.98it/s]
100%|██████████| 200/200 [00:11<00:00, 18.15it/s]


[0.4117647058823529, 0.4161735700197239, 0.43761996161228406, 0.42292490118577075, 0.42105263157894735]
complexity: [256, 128, 128, 64, 64, 32, 32, 16] dropout rate: 0.5


100%|██████████| 200/200 [00:17<00:00, 11.43it/s]
100%|██████████| 200/200 [00:12<00:00, 15.88it/s]
100%|██████████| 200/200 [00:17<00:00, 11.47it/s]
100%|██████████| 200/200 [00:18<00:00, 11.05it/s]
100%|██████████| 200/200 [00:09<00:00, 20.37it/s]


[0.4117647058823529, 0.43018867924528303, 0.38620689655172413, 0.42292490118577075, 0.42105263157894735]
complexity: [512, 256, 128, 128, 64, 64, 32, 32, 16] dropout rate: 0.0


100%|██████████| 200/200 [00:20<00:00,  9.98it/s]
100%|██████████| 200/200 [00:17<00:00, 11.57it/s]
100%|██████████| 200/200 [00:20<00:00,  9.91it/s]
100%|██████████| 200/200 [00:16<00:00, 11.86it/s]
100%|██████████| 200/200 [00:20<00:00,  9.90it/s]


[0.4117647058823529, 0.43018867924528303, 0.43761996161228406, 0.42292490118577075, 0.5794684731631058]
complexity: [512, 256, 128, 128, 64, 64, 32, 32, 16] dropout rate: 0.1


100%|██████████| 200/200 [00:15<00:00, 12.56it/s]
100%|██████████| 200/200 [00:20<00:00,  9.77it/s]
100%|██████████| 200/200 [00:16<00:00, 12.28it/s]
100%|██████████| 200/200 [00:19<00:00, 10.16it/s]
100%|██████████| 200/200 [00:20<00:00,  9.78it/s]


[0.4117647058823529, 0.43018867924528303, 0.4305555555555556, 0.42292490118577075, 0.5190145576707726]
complexity: [512, 256, 128, 128, 64, 64, 32, 32, 16] dropout rate: 0.2


100%|██████████| 200/200 [00:15<00:00, 12.80it/s]
100%|██████████| 200/200 [00:20<00:00,  9.79it/s]
100%|██████████| 200/200 [00:16<00:00, 12.29it/s]
100%|██████████| 200/200 [00:16<00:00, 11.83it/s]
100%|██████████| 200/200 [00:20<00:00,  9.77it/s]


[0.47918043621943157, 0.43018867924528303, 0.43761996161228406, 0.42292490118577075, 0.42105263157894735]
complexity: [512, 256, 128, 128, 64, 64, 32, 32, 16] dropout rate: 0.3


100%|██████████| 200/200 [00:16<00:00, 12.39it/s]
100%|██████████| 200/200 [00:20<00:00,  9.73it/s]
100%|██████████| 200/200 [00:16<00:00, 11.81it/s]
100%|██████████| 200/200 [00:20<00:00,  9.71it/s]
100%|██████████| 200/200 [00:14<00:00, 13.51it/s]


[0.41140529531568226, 0.6936617331374986, 0.3902439024390244, 0.5969825089771805, 0.41561181434599154]
complexity: [512, 256, 128, 128, 64, 64, 32, 32, 16] dropout rate: 0.4


100%|██████████| 200/200 [00:19<00:00, 10.11it/s]
100%|██████████| 200/200 [00:18<00:00, 10.78it/s]
100%|██████████| 200/200 [00:19<00:00, 10.03it/s]
100%|██████████| 200/200 [00:20<00:00,  9.63it/s]
100%|██████████| 200/200 [00:16<00:00, 12.17it/s]


[0.4117647058823529, 0.43018867924528303, 0.4129979035639413, 0.5883925270850513, 0.42105263157894735]
complexity: [512, 256, 128, 128, 64, 64, 32, 32, 16] dropout rate: 0.5


100%|██████████| 200/200 [00:20<00:00,  9.86it/s]
100%|██████████| 200/200 [00:19<00:00, 10.13it/s]
100%|██████████| 200/200 [00:17<00:00, 11.52it/s]
100%|██████████| 200/200 [00:19<00:00, 10.17it/s]
100%|██████████| 200/200 [00:20<00:00,  9.88it/s]

[0.4117647058823529, 0.43018867924528303, 0.43761996161228406, 0.39072847682119205, 0.42105263157894735]





In [26]:
results_json = json.dumps(results, indent=4)
print(results_json)

{
    "256-128-64-32": {
        "0.0": [
            0.46378621378621376,
            0.6666709989342622,
            0.667961715467218,
            0.5467344249952946,
            0.6568572321149642
        ],
        "0.1": [
            0.46858288770053474,
            0.43018867924528303,
            0.6407315340909091,
            0.5421909562316423,
            0.6145672441706831
        ],
        "0.2": [
            0.49778621125869704,
            0.4552430695058256,
            0.7383192623115847,
            0.42292490118577075,
            0.4085106382978723
        ],
        "0.3": [
            0.4358610914245216,
            0.43018867924528303,
            0.43761996161228406,
            0.6162870945479642,
            0.47919216646266827
        ],
        "0.4": [
            0.5834160691303548,
            0.43018867924528303,
            0.43761996161228406,
            0.4720901340873067,
            0.41922290388548056
        ],
        "0.5": [
            0

In [27]:
full_results = {}
summary_results = {}
final_epoch_history = {}

for layer_dims in results.keys():
    for dropout in results[layer_dims].keys():
        full_results[layer_dims+'-'+str(dropout)] = results[layer_dims][dropout]
        summary_results[layer_dims+'-'+str(dropout)] = float(np.mean(results[layer_dims][dropout]))
        final_epoch_history[layer_dims+'-'+str(dropout)] = epoch_history[layer_dims][dropout]

summary_results = dict(sorted(summary_results.items(), key=lambda item: item[1], reverse=True))
keys = list(summary_results.keys())
final_epoch_history = dict(sorted(final_epoch_history.items(), key=lambda item: keys.index(item[0]), reverse=False))


with open("full_results.json", "w") as f:
    json.dump(results, f, indent=4)


with open("summary_results.json", "w") as f:
    json.dump(summary_results, f, indent=4)


with open("epoch_history.json", "w") as f:
    json.dump(final_epoch_history, f, indent=4)

In [28]:
best_layer_dims, best_dropout, n_epochs = [256,128,128,64,32], 0.1, 25

In [29]:
model = Model(
    input_dim=X.shape[1], layer_dims=best_layer_dims, dropout=best_dropout, output_dim=4
)
optimizer = optim.Adam(model.parameters(), lr=0.001)

model.train()
for epoch in trange(n_epochs):
    optimizer.zero_grad()
    outputs = model(torch.tensor(X))
    loss = criterion(outputs, torch.tensor(y))
    loss.backward()
    optimizer.step()

100%|██████████| 25/25 [00:11<00:00,  2.09it/s]


In [30]:
torch.save(model.state_dict(), "256-128-128-64-32-0.1.pth")

In [31]:
model.eval()
with torch.no_grad():
    yhat = model(torch.tensor(X))
predicted = torch.argmax(yhat, 1)
f1 = compute_leaderboard_f1_multiclass(y, predicted)
print(f1)

0.42492138364779874
