In [2]:
import main as a
import pickle as pkl
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from importlib import reload
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import f1_score
import pandas as pd

In [3]:
with open("tadpole/X.pkl", "rb") as file:
    X = pkl.load(file)
early = False
with open("tadpole/y.pkl", "rb") as file:
    if early:
        y = pkl.load(file)[:,1:]
    else:
        y = pkl.load(file)[:,:-1]

In [4]:
X = X.reshape(X.shape[0]*X.shape[1], X.shape[2])
y = y.flatten()
weight = torch.tensor(compute_class_weight(class_weight="balanced", classes=np.unique(y), y=y), dtype=torch.float32)

In [5]:
X_train, X_test, y_train, y_test = train_test_split(X,y, stratify=y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train,y_train, stratify=y_train, test_size=0.1, random_state=42)

In [6]:
reload(a)

<module 'main' from '/Users/aviadsusman/Documents/Python_Projects/FLA_2/FLA/main.py'>

In [7]:
train_dataset = a.npDataset(X_train,y_train)
test_dataset = a.npDataset(X_test,y_test)
val_dataset = a.npDataset(X_val,y_val)
batch_size = 100

train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [8]:
hidden_dims = [160,40,10]
attn_heads = 0
model = a.FLANN(input_dim=337, hidden_dims=hidden_dims, output_dim=3, attn_heads=attn_heads, activation=nn.ReLU())

In [9]:
criterion = nn.CrossEntropyLoss(weight=weight)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
metric = a.MacroF1Score(num_classes=3)

In [10]:
num_epochs = 500
best_val_loss = float('inf')
best_model = None
patience = 10
early_stop_counter = 0
for epoch in range(num_epochs):
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    model.eval()
    val_losses = []
    for inputs, labels in val_loader:
        with torch.no_grad():
            outputs = model(inputs)
            val_loss = criterion(outputs, labels)
            val_losses.append(val_loss.item())
    
    avg_val_loss = np.mean(val_losses)
    print(f'Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}')
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_model = model.state_dict()
        early_stop_counter = 0
    else:
        early_stop_counter += 1
    
    if early_stop_counter >= patience:
        print(f'Early stopping after epoch {epoch+1} with validation loss {best_val_loss:.4f}')
        break

model.load_state_dict(best_model)
# model.eval()

Epoch 1, Validation Loss: 0.7146
Epoch 2, Validation Loss: 0.5688
Epoch 3, Validation Loss: 0.4947
Epoch 4, Validation Loss: 0.5060
Epoch 5, Validation Loss: 0.4300
Epoch 6, Validation Loss: 0.3755
Epoch 7, Validation Loss: 0.2981
Epoch 8, Validation Loss: 0.3217
Epoch 9, Validation Loss: 0.2932
Epoch 10, Validation Loss: 0.3437
Epoch 11, Validation Loss: 0.4066
Epoch 12, Validation Loss: 0.3132
Epoch 13, Validation Loss: 0.4432
Epoch 14, Validation Loss: 0.2864
Epoch 15, Validation Loss: 0.2529
Epoch 16, Validation Loss: 0.4256
Epoch 17, Validation Loss: 0.5065
Epoch 18, Validation Loss: 0.3385
Epoch 19, Validation Loss: 0.3681
Epoch 20, Validation Loss: 0.5179
Epoch 21, Validation Loss: 0.3981
Epoch 22, Validation Loss: 0.4388
Epoch 23, Validation Loss: 0.4168
Epoch 24, Validation Loss: 0.5389
Epoch 25, Validation Loss: 0.5706
Early stopping after epoch 25 with validation loss 0.2529


<All keys matched successfully>

In [11]:
test_losses = []
test_predictions = []
test_true_labels = []

for inputs, labels in test_loader:
    with torch.no_grad():
        outputs = model(inputs)
        test_loss = criterion(outputs, labels)
        test_losses.append(test_loss.item())
        predictions = torch.argmax(outputs, dim=1)
        test_predictions.extend(predictions.cpu().numpy())
        test_true_labels.extend(labels.cpu().numpy())

avg_test_loss = np.mean(test_losses)
test_score = f1_score(test_true_labels, test_predictions, average='weighted')
print(f'Test Loss: {avg_test_loss:.4f}, Test Score: {test_score:.4f}, Predicted Proba: {1/np.exp(avg_test_loss):.4f}')

Test Loss: 0.5963, Test Score: 0.8732, Predicted Proba: 0.5508


In [13]:
seeds=10
heads=15
scores = {head: [] for head in range(heads)}
probas = {head: [] for head in range(heads)}
for seed in range(seeds):
    X_train, X_test, y_train, y_test = train_test_split(X,y, stratify=y, test_size=0.2, random_state=seed)
    X_train, X_val, y_train, y_val = train_test_split(X_train,y_train, stratify=y_train, test_size=0.1, random_state=seed)
    train_dataset = a.npDataset(X_train,y_train)
    test_dataset = a.npDataset(X_test,y_test)
    val_dataset = a.npDataset(X_val,y_val)
    batch_size = 100

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    hidden_dims = [160,40,10]
    for head_num in range(heads):
        attn_heads = head_num
        model = a.FLANN(input_dim=337, hidden_dims=hidden_dims, output_dim=3, attn_heads=attn_heads, activation=nn.ReLU())
        criterion = nn.CrossEntropyLoss(weight=weight)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        metric = a.MacroF1Score(num_classes=3)       

        num_epochs = 500
        best_val_loss = float('inf')
        best_model = None
        patience = 10
        early_stop_counter = 0
        for epoch in range(num_epochs):
            model.train()
            for inputs, labels in train_loader:
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            model.eval()
            val_losses = []
            for inputs, labels in val_loader:
                with torch.no_grad():
                    outputs = model(inputs)
                    val_loss = criterion(outputs, labels)
                    val_losses.append(val_loss.item())
            
            avg_val_loss = np.mean(val_losses)
            print(f'Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}')
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_model = model.state_dict()
                early_stop_counter = 0
            else:
                early_stop_counter += 1
            
            if early_stop_counter >= patience:
                print(f'Early stopping after epoch {epoch+1} with validation loss {best_val_loss:.4f}')
                break

        model.load_state_dict(best_model) 
        test_losses = []
        test_predictions = []
        test_true_labels = []

        for inputs, labels in test_loader:
            with torch.no_grad():
                outputs = model(inputs)
                test_loss = criterion(outputs, labels)
                test_losses.append(test_loss.item())
                predictions = torch.argmax(outputs, dim=1)
                test_predictions.extend(predictions.cpu().numpy())
                test_true_labels.extend(labels.cpu().numpy())

        avg_test_loss = np.mean(test_losses)
        test_score = f1_score(test_true_labels, test_predictions, average='weighted')
        print(f'Test Loss: {avg_test_loss:.4f}, Test Score: {test_score:.4f}, Predicted Proba: {1/np.exp(avg_test_loss):.4f}')
        scores[head_num].append(test_score)
        probas[head_num].append(1/np.exp(avg_test_loss))

scores = pd.DataFrame(scores)
probas = pd.DataFrame(probas)


    

Epoch 1, Validation Loss: 0.7542
Epoch 2, Validation Loss: 0.6166
Epoch 3, Validation Loss: 0.5397
Epoch 4, Validation Loss: 0.4844
Epoch 5, Validation Loss: 0.4243
Epoch 6, Validation Loss: 0.4521
Epoch 7, Validation Loss: 0.4518
Epoch 8, Validation Loss: 0.6154
Epoch 9, Validation Loss: 0.4840
Epoch 10, Validation Loss: 0.4529
Epoch 11, Validation Loss: 0.5213
Epoch 12, Validation Loss: 0.4457
Epoch 13, Validation Loss: 0.6760
Epoch 14, Validation Loss: 0.5944
Epoch 15, Validation Loss: 0.5048
Early stopping after epoch 15 with validation loss 0.4243
Test Loss: 0.4312, Test Score: 0.8475, Predicted Proba: 0.6498
Epoch 1, Validation Loss: 0.6755
Epoch 2, Validation Loss: 0.5904
Epoch 3, Validation Loss: 0.5442
Epoch 4, Validation Loss: 0.4970
Epoch 5, Validation Loss: 0.4607
Epoch 6, Validation Loss: 0.4657
Epoch 7, Validation Loss: 0.4921
Epoch 8, Validation Loss: 0.5264
Epoch 9, Validation Loss: 0.6483
Epoch 10, Validation Loss: 0.5127
Epoch 11, Validation Loss: 0.3712
Epoch 12, Val