In [1]:
import main as a
import pickle as pkl
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from importlib import reload
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import f1_score
import pandas as pd
from sklearn.impute import KNNImputer
from sklearn.preprocessing import MinMaxScaler

In [2]:
# Directory holding the COVID-19 "Basics" CSV exports. Kept in one place so the
# notebook can be pointed at another machine/location by editing a single line.
DATA_DIR = "/Users/aviadsusman/Documents/Python_Projects/FeatureLevelAttention/FLA/data/COVID-19 data/Basics"

# `data` holds the input table and `labels` the outcome table; both are read as
# pandas DataFrames and consumed by the next cell.
data = pd.read_csv(f"{DATA_DIR}/concatenated.csv")
labels = pd.read_csv(f"{DATA_DIR}/outcome.csv")

In [3]:
# Feature matrix: every column of `data` except NEW_MASKED_MRN
# (presumably a record identifier rather than a feature -- confirm).
X_raw = data.drop(columns=['NEW_MASKED_MRN']).to_numpy()

# Binary target vector.
y = labels['DECEASED_INDICATOR'].to_numpy()

# np.unique returns the counts in sorted-label order, so index 0 is the count
# of the smaller label value and index 1 the count of the larger one.
class_values, y_counts = np.unique(y, return_counts=True)

# Ratio (count of first class) / (count of second class), stored as a
# one-element float32 tensor.
first_to_second_ratio = y_counts[0] / y_counts[1]
weight = torch.tensor([first_to_second_ratio], dtype=torch.float32)

In [4]:
# Two-stage preprocessing: fill missing entries with a 5-nearest-neighbour
# imputer, then min-max scale every feature column.
# NOTE(review): both transformers are fitted on ALL rows, before the data is
# split into train/validation/test, so statistics from held-out rows influence
# the training features. Consider fitting on the training split only.
imputer = KNNImputer(n_neighbors=5)
scaler = MinMaxScaler()

X_imputed_not_norm = imputer.fit_transform(X_raw)
X = scaler.fit_transform(X_imputed_not_norm)

In [5]:
# Sanity check: (rows, columns) of the imputed and scaled feature matrix.
X.shape

(4783, 98)

In [None]:
# Stratified hold-out: 20% of all rows become the test set, then 10% of the
# remaining rows become the validation set. Both splits share one seed.
split_seed = 42

X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    stratify=y,
    random_state=split_seed,
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train,
    test_size=0.1,
    stratify=y_train,
    random_state=split_seed,
)

In [23]:
# Re-import the local `main` module (aliased as `a`) so that edits made to it
# are picked up without restarting the kernel.
reload(a)

<module 'main' from '/Users/aviadsusman/Documents/Python_Projects/FeatureLevelAttention/FLA/main.py'>

In [None]:
batch_size = 100

# Wrap each split in the project's dataset class.
train_dataset = a.npDataset(X_train, y_train)
val_dataset = a.npDataset(X_val, y_val)
test_dataset = a.npDataset(X_test, y_test)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Model hyper-parameters for the single-run baseline.
hidden_dims = [50,25,10]
attn_heads = 0  # number of attention heads passed to FLANN

# Take the input width from the data instead of hard-coding 98, so the cell
# keeps working if the feature table gains or loses columns.
model = a.FLANN(
    input_dim=X_train.shape[1],
    hidden_dims=hidden_dims,
    output_dim=1,
    attn_heads=attn_heads,
    activation=nn.ReLU(),
)

In [None]:
# Binary cross-entropy on the model outputs, with the positive class
# re-weighted by `weight`; optimised with Adam.
learning_rate = 0.001

criterion = nn.BCEWithLogitsLoss(pos_weight=weight)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Train with early stopping on the mean validation loss, then restore the
# weights from the best validation epoch.
num_epochs = 500
patience = 10  # epochs without validation improvement tolerated before stopping
best_val_loss = float('inf')
best_model = None
early_stop_counter = 0

for epoch in range(num_epochs):
    model.train()
    # The batch variable is called `targets` rather than `labels` so that the
    # `labels` DataFrame loaded at the top of the notebook is not overwritten
    # (which would break re-running the feature-extraction cell).
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        # unsqueeze(1) adds a trailing dimension to the targets before they are
        # passed to the criterion alongside the model outputs.
        loss = criterion(outputs, targets.unsqueeze(1))
        loss.backward()
        optimizer.step()

    model.eval()
    val_losses = []
    with torch.no_grad():
        for inputs, targets in val_loader:
            outputs = model(inputs)
            val_loss = criterion(outputs, targets.unsqueeze(1))
            val_losses.append(val_loss.item())

    avg_val_loss = np.mean(val_losses)
    print(f'Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}')
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        # state_dict() hands back references to the live parameter tensors, so
        # the snapshot must be cloned; otherwise later optimizer steps silently
        # overwrite the "best" weights and early stopping restores the LAST
        # epoch instead of the best one.
        best_model = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        early_stop_counter = 0
    else:
        early_stop_counter += 1

    if early_stop_counter >= patience:
        print(f'Early stopping after epoch {epoch+1} with validation loss {best_val_loss:.4f}')
        break

model.load_state_dict(best_model)
# Make inference mode explicit for the evaluation cell that follows.
model.eval()

In [None]:
# Evaluate the restored model on the held-out test split: mean batch loss and F1.
test_losses = []
test_predictions = []
test_true_labels = []

model.eval()  # do not rely on the mode left behind by the training cell
with torch.no_grad():
    # `targets` rather than `labels`, so the `labels` DataFrame loaded at the
    # top of the notebook is not overwritten.
    for inputs, targets in test_loader:
        outputs = model(inputs)
        targets = targets.unsqueeze(1)
        test_loss = criterion(outputs, targets)
        test_losses.append(test_loss.item())
        test_predictions.extend(outputs.cpu().numpy())
        test_true_labels.extend(targets.cpu().numpy())

avg_test_loss = np.mean(test_losses)
# The criterion is BCEWithLogitsLoss, which applies the sigmoid itself, so the
# model outputs are treated as raw logits (assumes FLANN has no final sigmoid
# -- confirm in main.py). A probability cut of 0.5 is a logit cut of 0, since
# sigmoid(0) == 0.5; thresholding logits at 0.5 would instead require
# probability > ~0.62 and under-predict the positive class.
test_predictions = [logit > 0 for logit in test_predictions]
test_score = f1_score(test_true_labels, test_predictions)
print(f'Test Loss: {avg_test_loss:.4f}, Test Score: {test_score:.4f}')

In [24]:
# Sweep over attention-head counts on 10 differently seeded stratified splits.
# For every (seed, head count) pair a fresh model is trained with early
# stopping and evaluated on the test split. Raw test outputs, mean test losses
# and the per-seed test labels are pickled at the end.
head_counts = [20, 25]  # earlier sweep: [0, 5, 10, 15]
test_prediction_dict = {h: [] for h in head_counts}
test_label_list = []
losses = {h: [] for h in head_counts}

# Settings shared by every run in the sweep.
batch_size = 100
hidden_dims = [50, 25, 10]
num_epochs = 500
patience = 10  # epochs without validation improvement tolerated before stopping

for seed in range(10):
    # split data -- the split depends only on the seed, so it is built once per
    # seed and shared by every head count instead of being recomputed.
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.2, random_state=seed)
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, stratify=y_train, test_size=0.1, random_state=seed)
    train_dataset = a.npDataset(X_train, y_train)
    test_dataset = a.npDataset(X_test, y_test)
    val_dataset = a.npDataset(X_val, y_val)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    for head in head_counts:
        print(f'seed {seed+1}, with {head} heads')

        # Seed torch's global RNG so batch shuffling (and any torch-driven
        # parameter initialisation) is repeatable and identical across head
        # counts for a given seed.
        torch.manual_seed(seed)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        # make model -- input width is taken from the data rather than hard-coded
        model = a.FLANN(
            input_dim=X_train.shape[1],
            hidden_dims=hidden_dims,
            output_dim=1,
            attn_heads=head,
            activation=nn.ReLU(),
        )
        criterion = nn.BCEWithLogitsLoss(pos_weight=weight)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        # train
        best_val_loss = float('inf')
        best_model = None
        early_stop_counter = 0
        for epoch in range(num_epochs):
            model.train()
            # `targets` rather than `labels`, so the `labels` DataFrame loaded
            # at the top of the notebook is not overwritten.
            for inputs, targets in train_loader:
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets.unsqueeze(1))
                loss.backward()
                optimizer.step()

            model.eval()
            val_losses = []
            with torch.no_grad():
                for inputs, targets in val_loader:
                    outputs = model(inputs)
                    val_loss = criterion(outputs, targets.unsqueeze(1))
                    val_losses.append(val_loss.item())

            avg_val_loss = np.mean(val_losses)
            print(f'Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}')
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                # state_dict() returns references to the live tensors; clone
                # them so later optimizer steps cannot overwrite the snapshot
                # (otherwise the last epoch, not the best, is restored).
                best_model = {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                }
                early_stop_counter = 0
            else:
                early_stop_counter += 1

            if early_stop_counter >= patience:
                print(f'Early stopping after epoch {epoch+1} with validation loss {best_val_loss:.4f}')
                break

        model.load_state_dict(best_model)
        model.eval()

        # eval
        test_losses = []
        test_predictions = []
        test_true_labels = []

        with torch.no_grad():
            for inputs, targets in test_loader:
                outputs = model(inputs)
                targets = targets.unsqueeze(1)
                test_loss = criterion(outputs, targets)
                test_losses.append(test_loss.item())
                test_predictions.extend(outputs.cpu().numpy())
                test_true_labels.extend(targets.cpu().numpy())
        avg_test_loss = np.mean(test_losses)
        # Outputs are treated as raw logits (the criterion is
        # BCEWithLogitsLoss; assumes FLANN has no final sigmoid -- confirm in
        # main.py). Probability 0.5 corresponds to logit 0, so threshold at 0.
        test_predictions_f1 = [logit > 0 for logit in test_predictions]
        test_score = f1_score(test_true_labels, test_predictions_f1)
        print(f'Test Loss: {avg_test_loss:.4f}, Test Score: {test_score:.4f} for seed {seed+1} and {head} heads.')

        # Test labels are identical for every head count within a seed, so
        # store them once per seed. The previous `head == 0` guard never fired
        # when 0 was not in `head_counts`, leaving the saved label list empty.
        if head == head_counts[0]:
            test_label_list.append(test_true_labels)
        # The raw (un-thresholded) outputs are what gets stored and pickled.
        test_prediction_dict[head].append(test_predictions)
        losses[head].append(avg_test_loss)

with open("/Users/aviadsusman/Documents/Python_Projects/FeatureLevelAttention/FLA/results/covid/test_pred_dict_20s.pkl", "wb") as file:
    pkl.dump(test_prediction_dict, file=file)
with open("/Users/aviadsusman/Documents/Python_Projects/FeatureLevelAttention/FLA/results/covid/test_losses_dict_20s.pkl", "wb") as file:
    pkl.dump(losses, file=file)
with open("/Users/aviadsusman/Documents/Python_Projects/FeatureLevelAttention/FLA/results/covid/test_labels_20s.pkl", "wb") as file:
    pkl.dump(test_label_list, file=file)

seed 1, with 20 heads
Epoch 1, Validation Loss: 0.9845
Epoch 2, Validation Loss: 0.9430
Epoch 3, Validation Loss: 0.8871
Epoch 4, Validation Loss: 0.9084
Epoch 5, Validation Loss: 0.8520
Epoch 6, Validation Loss: 0.8356
Epoch 7, Validation Loss: 0.8425
Epoch 8, Validation Loss: 0.8188
Epoch 9, Validation Loss: 0.8194
Epoch 10, Validation Loss: 0.8028
Epoch 11, Validation Loss: 0.7952
Epoch 12, Validation Loss: 0.7998
Epoch 13, Validation Loss: 0.7939
Epoch 14, Validation Loss: 0.8358
Epoch 15, Validation Loss: 0.7869
Epoch 16, Validation Loss: 0.7903
Epoch 17, Validation Loss: 0.7964
Epoch 18, Validation Loss: 0.7777
Epoch 19, Validation Loss: 0.7945
Epoch 20, Validation Loss: 0.7827
Epoch 21, Validation Loss: 0.7741
Epoch 22, Validation Loss: 0.7792
Epoch 23, Validation Loss: 0.7747
Epoch 24, Validation Loss: 0.7891
Epoch 25, Validation Loss: 0.7642
Epoch 26, Validation Loss: 0.7586
Epoch 27, Validation Loss: 0.7984
Epoch 28, Validation Loss: 0.7551
Epoch 29, Validation Loss: 0.7448
E