In [15]:
import main as a
from datasets import load_diabetes
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.impute import KNNImputer
from sklearn.preprocessing import MinMaxScaler
import pickle as pkl
from importlib import reload
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

In [7]:
data = load_diabetes()

In [8]:
X_train = pd.concat([data['X_train'][k] for k in data['X_train'].keys()], axis=1)
X_test = pd.concat([data['X_test'][k] for k in data['X_test'].keys()], axis=1)
X_raw = pd.concat([X_train, X_test], axis=0).to_numpy()
y = pd.concat([data['y_train'], data['y_test']], axis=0).to_numpy().flatten()
y_counts = np.unique(y, return_counts=True)[1]
weight = torch.tensor([y_counts[0]/y_counts[1]], dtype=torch.float32)

In [9]:
imputer = KNNImputer(n_neighbors=5)
X_imputed_not_norm = imputer.fit_transform(X_raw)
scaler = MinMaxScaler()
X = scaler.fit_transform(X_imputed_not_norm)

In [19]:
reload(a)

<module 'main' from '/Users/aviadsusman/Documents/Python_Projects/FeatureLevelAttention/FLA/main.py'>

In [20]:
head_counts = [20,25]#[0,5,10,15]
test_prediction_dict = {h: [] for h in head_counts}
test_label_list = []
losses = {h: [] for h in head_counts}
for seed in range(10):
    for head in head_counts:
        print(f'seed {seed+1}, with {head} heads')
        #split data
        X_train, X_test, y_train, y_test = train_test_split(X,y, stratify=y, test_size=0.2, random_state=seed)
        X_train, X_val, y_train, y_val = train_test_split(X_train,y_train, stratify=y_train, test_size=0.1, random_state=seed)
        train_dataset = a.npDataset(X_train,y_train)
        test_dataset = a.npDataset(X_test,y_test)
        val_dataset = a.npDataset(X_val,y_val)
        batch_size = 100
        train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        #make model
        hidden_dims = [50,25,10]
        attn_heads = head
        model = a.FLANN(input_dim=108, hidden_dims=hidden_dims, output_dim=1, attn_heads=attn_heads, activation=nn.ReLU())
        criterion = nn.BCEWithLogitsLoss(pos_weight=weight)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        #train
        num_epochs = 500
        best_val_loss = float('inf')
        best_model = None
        patience = 10
        early_stop_counter = 0
        for epoch in range(num_epochs):
            model.train()
            for inputs, labels in train_loader:
                optimizer.zero_grad()
                outputs = model(inputs)
                labels = labels.unsqueeze(1)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            model.eval()
            val_losses = []
            for inputs, labels in val_loader:
                with torch.no_grad():
                    outputs = model(inputs)
                    labels = labels.unsqueeze(1)
                    val_loss = criterion(outputs, labels)
                    val_losses.append(val_loss.item())
            
            avg_val_loss = np.mean(val_losses)
            print(f'Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}')
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_model = model.state_dict()
                early_stop_counter = 0
            else:
                early_stop_counter += 1
            
            if early_stop_counter >= patience:
                print(f'Early stopping after epoch {epoch+1} with validation loss {best_val_loss:.4f}')
                break
            
        model.load_state_dict(best_model)

        #eval
        test_losses = []
        test_predictions = []
        test_true_labels = []

        for inputs, labels in test_loader:
            with torch.no_grad():
                outputs = model(inputs)
                labels = labels.unsqueeze(1)
                test_loss = criterion(outputs, labels)
                test_losses.append(test_loss.item())
                test_predictions.extend(outputs.cpu().numpy())
                test_true_labels.extend(labels.cpu().numpy())
        avg_test_loss = np.mean(test_losses)
        test_predictions_f1 = [y>0.5 for y in test_predictions]
        test_score = f1_score(test_true_labels, test_predictions_f1)
        print(f'Test Loss: {avg_test_loss:.4f}, Test Score: {test_score:.4f} for seed {seed+1} and {head} heads.')
        if head == 0:
            test_label_list.append(test_true_labels)
        test_prediction_dict[head].append(test_predictions)
        losses[head].append(avg_test_loss)
with open ("/Users/aviadsusman/Documents/Python_Projects/FeatureLevelAttention/FLA/results/diabetes/test_pred_dict_20s.pkl", "wb") as file:
    pkl.dump(test_prediction_dict, file=file)
with open ("/Users/aviadsusman/Documents/Python_Projects/FeatureLevelAttention/FLA/results/diabetes/test_losses_dict_20s.pkl", "wb") as file:
    pkl.dump(losses, file=file)
with open ("/Users/aviadsusman/Documents/Python_Projects/FeatureLevelAttention/FLA/results/diabetes/test_labels.pkl", "wb") as file:
    pkl.dump(test_label_list, file=file)

seed 1, with 20 heads
Epoch 1, Validation Loss: 1.1861
Epoch 2, Validation Loss: 1.1574
Epoch 3, Validation Loss: 1.1600
Epoch 4, Validation Loss: 1.1715
Epoch 5, Validation Loss: 1.1081
Epoch 6, Validation Loss: 1.1219
Epoch 7, Validation Loss: 1.1244
Epoch 8, Validation Loss: 1.1440
Epoch 9, Validation Loss: 1.1421
Epoch 10, Validation Loss: 1.1409
Epoch 11, Validation Loss: 1.1387
Epoch 12, Validation Loss: 1.1188
Epoch 13, Validation Loss: 1.1411
Epoch 14, Validation Loss: 1.1722
Epoch 15, Validation Loss: 1.1230
Early stopping after epoch 15 with validation loss 1.1081
Test Loss: 1.1660, Test Score: 0.2598 for seed 1 and 20 heads.
seed 1, with 25 heads
Epoch 1, Validation Loss: 1.1827
Epoch 2, Validation Loss: 1.1623
Epoch 3, Validation Loss: 1.1404
Epoch 4, Validation Loss: 1.1568
Epoch 5, Validation Loss: 1.1270
Epoch 6, Validation Loss: 1.1219
Epoch 7, Validation Loss: 1.1541
Epoch 8, Validation Loss: 1.1211
Epoch 9, Validation Loss: 1.1059
Epoch 10, Validation Loss: 1.1195
Epo

KeyboardInterrupt: 

In [21]:
with open ("/Users/aviadsusman/Documents/Python_Projects/FeatureLevelAttention/FLA/results/diabetes/test_pred_dict_20s.pkl", "wb") as file:
    pkl.dump(test_prediction_dict, file=file)
with open ("/Users/aviadsusman/Documents/Python_Projects/FeatureLevelAttention/FLA/results/diabetes/test_losses_dict_20s.pkl", "wb") as file:
    pkl.dump(losses, file=file)