In [10]:
import torch
import torch.nn as nn
import fla2 as a
import pickle as pkl
from torch.utils.data import DataLoader
from main import npDataset
import time
import numpy as np
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from importlib import reload
import os


In [13]:
cwd=os.getcwd()
with open(f"{cwd}/data/diabetes/X.pkl", "rb") as file:
    X = pkl.load(file)
with open(f"{cwd}/data/diabetes/y.pkl", "rb") as file:
    y = pkl.load(file)

In [14]:
y_counts = np.unique(y, return_counts=True)[1]
weight = torch.tensor([y_counts[0]/y_counts[1]], dtype=torch.float32)

In [18]:
reload(a)
head_counts = [1,2,3,4]
test_prediction_dict = {h: [] for h in head_counts}
test_label_list = []
losses = {h: [] for h in head_counts}

forward_times = []
loss_times = []
backwards_times = []
optimizer_times = []

for seed in range(10):
    for i, head in enumerate(head_counts):
        print(f'seed {seed+1}, with {head} heads')
        #split data
        X_train, X_test, y_train, y_test = train_test_split(X,y, stratify=y, test_size=0.2, random_state=seed)
        X_train, X_val, y_train, y_val = train_test_split(X_train,y_train, stratify=y_train, test_size=0.1, random_state=seed)
        train_dataset = npDataset(X_train,y_train)
        test_dataset = npDataset(X_test,y_test)
        val_dataset = npDataset(X_val,y_val)
        batch_size = 64
        train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        #make model
        hidden_dims = [50,25,10]
        embed_dim = head
        model = a.TabularAttentionClassifier(input_dim=108, embed_dim=2*head, hidden_dims=hidden_dims, num_heads=head)
        criterion = nn.BCEWithLogitsLoss(pos_weight=weight)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        #train
        num_epochs = 500
        best_val_loss = float('inf')
        best_model = None
        patience = 10
        early_stop_counter = 0
        for epoch in range(num_epochs):
            model.train()
            for inputs, labels in train_loader:
                optimizer.zero_grad()
                start_f = time.time()
                outputs = model(inputs)
                end_f = time.time()
                forward_times.append(end_f-start_f)
                labels = labels.unsqueeze(1)
                start_l = time.time()
                loss = criterion(outputs, labels)
                end_l = time.time()
                loss_times.append(end_l-start_l)
                start_b = time.time()
                loss.backward()
                end_b = time.time()
                backwards_times.append(end_b-start_b)
                start_o = time.time()
                optimizer.step()
                end_o = time.time()
                optimizer_times.append(end_o-start_o)

            model.eval()
            val_losses = []
            for inputs, labels in val_loader:
                with torch.no_grad():
                    outputs = model(inputs)
                    labels = labels.unsqueeze(1)
                    val_loss = criterion(outputs, labels)
                    val_losses.append(val_loss.item())
            
            avg_val_loss = np.mean(val_losses)
            print(f'Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}')
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_model = model.state_dict()
                early_stop_counter = 0
            else:
                early_stop_counter += 1
            
            if early_stop_counter >= patience:
                print(f'Early stopping after epoch {epoch+1} with validation loss {best_val_loss:.4f}')
                break
            
        model.load_state_dict(best_model)

        #eval
        test_losses = []
        test_predictions = []
        test_true_labels = []

        for inputs, labels in test_loader:
            with torch.no_grad():
                outputs = model(inputs)
                labels = labels.unsqueeze(1)
                test_loss = criterion(outputs, labels)
                test_losses.append(test_loss.item())
                test_predictions.extend(outputs.cpu().numpy())
                test_true_labels.extend(labels.cpu().numpy())
        avg_test_loss = np.mean(test_losses)
        test_predictions_f1 = [y>0.5 for y in test_predictions]
        test_score = f1_score(test_true_labels, test_predictions_f1)
        print(f'Test Loss: {avg_test_loss:.4f}, Test Score: {test_score:.4f} for seed {seed+1} and {head} heads.')
        if i == 0:
            test_label_list.append(test_true_labels)
        test_prediction_dict[head].append(test_predictions)
        losses[head].append(avg_test_loss)
with open ("/Users/aviadsusman/Documents/Python_Projects/FeatureLevelAttention/FLA/results/fla2/diabetes/test_pred_dict_1_to_4.pkl", "wb") as file:
    pkl.dump(test_prediction_dict, file=file)
with open ("/Users/aviadsusman/Documents/Python_Projects/FeatureLevelAttention/FLA/results/fla2/diabetes/test_losses_dict_1_to_4.pkl", "wb") as file:
    pkl.dump(losses, file=file)
with open ("/Users/aviadsusman/Documents/Python_Projects/FeatureLevelAttention/FLA/results/fla2/diabetes/test_labels.pkl", "wb") as file:
    pkl.dump(test_label_list, file=file)

seed 1, with 1 heads


AssertionError: was expecting embedding dimension of 2, but got 108

In [50]:
test_prediction_dict[5][0]

[array([-1.2212318], dtype=float32),
 array([0.18830653], dtype=float32),
 array([0.4394591], dtype=float32),
 array([-1.0474021], dtype=float32),
 array([0.23518611], dtype=float32),
 array([0.01598065], dtype=float32),
 array([-0.9846909], dtype=float32),
 array([0.5603118], dtype=float32),
 array([0.8618928], dtype=float32),
 array([-1.1921793], dtype=float32),
 array([-0.43146533], dtype=float32),
 array([-0.8783313], dtype=float32),
 array([-0.07539515], dtype=float32),
 array([0.59515655], dtype=float32),
 array([1.1372019], dtype=float32),
 array([1.0608116], dtype=float32),
 array([-0.0414625], dtype=float32),
 array([-0.35973752], dtype=float32),
 array([0.9244994], dtype=float32),
 array([-1.2234756], dtype=float32),
 array([-0.72969186], dtype=float32),
 array([0.00058858], dtype=float32),
 array([0.6104087], dtype=float32),
 array([-1.2745628], dtype=float32),
 array([0.48556817], dtype=float32),
 array([-0.08482729], dtype=float32),
 array([-1.1975491], dtype=float32),
 ar