In [None]:
import os
import re
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import label_binarize
from tqdm import tqdm

In [None]:
# Load the saved feature and label arrays from disk.
directory = "transformer_data"
x_dir = os.path.join(directory, "x.npy")
y_dir = os.path.join(directory, "y.npy")
X = np.load(x_dir)
Y = np.load(y_dir)
# Show a compact summary and a short preview rather than dumping the whole array.
print(X.shape, X.dtype)
print(X[:5])
print(Y.shape)
# Indicator matrix with one column per class id in [0, 1, 2, 3].
Y_bin = label_binarize(Y, classes=[0, 1, 2, 3])

In [None]:
# Hold out 20% of the samples for validation; the fixed seed keeps the split reproducible.
X_train, X_val, Y_train, Y_val = train_test_split(
    X,
    Y_bin,
    test_size=0.2,
    random_state=42,
)

In [None]:
#define the Dataset
class VisitDataset(Dataset):
    """Map-style dataset that pairs each feature entry with its label entry."""

    def __init__(self, features, labels):
        """Store the features untouched and keep a float32 copy of the labels."""
        self.features = features
        float_labels = labels.astype(np.float32)
        self.labels = float_labels

    def __len__(self):
        """Return the number of samples, measured on the label array."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(features, labels)`` pair stored at position ``idx``."""
        sample = self.features[idx]
        target = self.labels[idx]
        return sample, target

In [None]:
# Wrap each split in a dataset; `test_dataset` is built from the validation split.
train_dataset = VisitDataset(features=X_train, labels=Y_train)
test_dataset = VisitDataset(features=X_val, labels=Y_val)

In [None]:
# Only the training loader is shuffled; the evaluation loader keeps dataset order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=20480,
    shuffle=True,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=20480,
    shuffle=False,
)

In [None]:
#naive transformer
class Transformer(nn.Module):
    """Linear embedding -> Transformer encoder stack -> linear classification head.

    ``forward`` takes a 3-D tensor laid out as (batch, sequence, input_dim) and
    returns raw logits of shape (batch, num_classes). The encoder output is
    flattened per sample before the head, so the head's ``hidden_dim`` input
    size only matches when the sequence length is 1.
    """

    def __init__(self, input_dim, hidden_dim, num_heads, num_layers, num_classes):
        super().__init__()
        # Project the raw features into the encoder's model dimension.
        self.embedding = nn.Linear(input_dim, hidden_dim)
        # The feed-forward width is set equal to the model width.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        embedded = self.embedding(x)
        # The encoder is sequence-first: (batch, seq, hid) -> (seq, batch, hid).
        encoded = self.transformer(embedded.transpose(0, 1))
        batch_size = encoded.size(1)
        # Back to batch-first, then flatten everything after the batch axis.
        flattened = encoded.transpose(0, 1).reshape(batch_size, -1)
        return self.fc(flattened)

In [None]:
#define early stopping
class EarlyStopping:
    """Signal that training should stop once the monitored loss stops improving.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        verbose: When True, print the counter on every non-improving call.

    Attributes are ``counter`` (current run of non-improving calls),
    ``best_score`` (lowest loss seen, ``None`` until a finite value arrives)
    and ``early_stop`` (latched to True once patience is exhausted).
    """

    def __init__(self, patience=5, verbose=True):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        """Record one epoch's loss and return whether training should stop.

        A loss equal to the current best still counts as an improvement and
        resets the counter. A NaN loss never counts as an improvement, so a
        diverged run exhausts its patience instead of overwriting
        ``best_score`` with NaN (after which no comparison could succeed).

        Args:
            val_loss: Loss value to monitor; lower is better.
            model: Accepted for interface compatibility; not used here.

        Returns:
            The current value of ``early_stop``.
        """
        is_nan = val_loss != val_loss  # NaN is the only value unequal to itself
        improved = (not is_nan) and (
            self.best_score is None or val_loss <= self.best_score
        )
        if improved:
            self.best_score = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

    def save_checkpoint(self, model):
        """Write the model's state dict to ``checkpoint.pth``."""
        if self.verbose:
            print('Saving model...')
        torch.save(model.state_dict(), 'checkpoint.pth')

In [None]:
from sklearn.metrics import roc_curve, auc
from scipy.special import expit

def calculate_auc(y_true, y_scores):
    """Average the per-label ROC AUC over every label that can be scored.

    Args:
        y_true: 2-D binary indicator array with one column per label.
        y_scores: Raw scores with the same layout; they are passed through
            the logistic sigmoid before the ROC curve is computed.

    Returns:
        The mean AUC over the labels that have both positive and negative
        samples, or NaN when no label qualifies.
    """
    probabilities = expit(y_scores)
    per_label_auc = []

    for i in range(y_true.shape[1]):
        targets = y_true[:, i]
        positives = np.sum(targets)
        # A label with a single class present has no defined ROC curve.
        if positives == 0 or positives == len(targets):
            print(f"Warning: No positive or negative samples in y_true for label {i}, skipping this label.")
            continue

        fpr, tpr, _ = roc_curve(targets, probabilities[:, i])
        if len(fpr) <= 1:
            print(f"Warning: Not enough data to calculate AUC for label {i}")
            continue
        per_label_auc.append(auc(fpr, tpr))

    if not per_label_auc:
        return float('nan')
    return sum(per_label_auc, 0.0) / len(per_label_auc)

In [None]:
# Model hyperparameters.
input_dim = 5
hidden_dim = 512
num_heads = 2
num_layers = 2
num_classes = 4
device_id = 0
# Always build a torch.device so `device` has the same type on GPU and CPU machines.
device = torch.device('cuda:{}'.format(device_id) if torch.cuda.is_available() else 'cpu')
print(device)
model = Transformer(input_dim, hidden_dim, num_heads, num_layers, num_classes).to(device)

# Warm-start from a previously saved state dict. map_location remaps the stored
# tensors onto `device`, so a checkpoint written on a GPU still loads when this
# notebook falls back to the CPU.
pretrained_path = 'best_model.pth'
model.load_state_dict(torch.load(pretrained_path, map_location=device))

# Multi-label objective: an independent sigmoid/BCE term per class column.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
early_stopping = EarlyStopping(patience=5, verbose=True)
# The schedule only advances when scheduler.step() is called by the training code.
scheduler = CosineAnnealingLR(optimizer, T_max=50, eta_min=0)

In [None]:

# Training and validation
num_epochs = 300
# Lowest validation loss seen so far; start at +inf so the first epoch always
# produces a checkpoint regardless of the loss scale.
best_loss = float('inf')
best_auc = 0.0

# The context manager closes the log even if training raises or is interrupted.
with open("log_transformer.txt", "w") as log_file:
    for epoch in range(num_epochs):
        # ---- Training pass ----
        model.train()
        total_loss = 0.0
        all_outputs = []
        all_labels = []

        for inputs, labels in tqdm(train_loader):
            # Insert a length-1 axis at position 1; the model expects 3-D input.
            inputs = inputs.unsqueeze(1).to(device=device, dtype=torch.float32)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            all_outputs.append(outputs.detach().cpu().numpy())
            all_labels.append(labels.detach().cpu().numpy())

        all_outputs = np.concatenate(all_outputs)
        all_labels = np.concatenate(all_labels)
        # AUC is only attempted when more than one distinct label row is present.
        unique_labels = np.unique(all_labels, axis=0)
        if len(unique_labels) > 1:
            train_auc_score = calculate_auc(all_labels, all_outputs)
        else:
            train_auc_score = 0.0
        avg_train_loss = total_loss / len(train_loader)

        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_train_loss:.4f}, AUC: {train_auc_score:.4f}')

        # ---- Validation pass ----
        model.eval()
        val_total_loss = 0.0
        val_all_outputs = []
        val_all_labels = []

        with torch.no_grad():
            for val_inputs, val_labels in test_loader:
                val_inputs = val_inputs.unsqueeze(1).to(device=device, dtype=torch.float32)
                val_labels = val_labels.to(device)
                val_outputs = model(val_inputs)
                val_loss = criterion(val_outputs, val_labels)

                val_total_loss += val_loss.item()
                val_all_outputs.append(val_outputs.cpu().numpy())
                val_all_labels.append(val_labels.cpu().numpy())

        val_all_outputs = np.concatenate(val_all_outputs)
        val_all_labels = np.concatenate(val_all_labels)

        val_auc_score = calculate_auc(val_all_labels, val_all_outputs)
        avg_val_loss = val_total_loss / len(test_loader)

        print(f"Validation Loss: {avg_val_loss:.4f}, Validation AUC: {val_auc_score:.4f}")

        # Logging
        log_file.write(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {avg_train_loss:.4f} Train AUC: {train_auc_score:.4f} Val Loss: {avg_val_loss:.4f} Val AUC: {val_auc_score:.4f}\n")

        # Checkpoint on held-out loss: training loss keeps falling while the
        # model overfits, so it cannot identify the best-generalising weights.
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            torch.save(model.state_dict(), 'best_model_new.pth')
            print("Best model saved")

        # max() keeps the running best when val_auc_score is NaN.
        best_auc = max(best_auc, val_auc_score)
        log_file.write(f'Best loss: {best_loss}, Best AUC: {best_auc}\n\n')
        log_file.flush()

        # Advance the learning-rate schedule once per epoch, after the
        # optimizer steps of that epoch.
        scheduler.step()

        # Early stopping monitors the same held-out loss as checkpointing.
        if early_stopping(avg_val_loss, model):
            print('Early stopping triggered')
            break

print(f'Best loss is {best_loss}, best AUC is {best_auc}')


In [None]:
#Test
# NOTE(review): this scores whatever weights `model` currently holds, on the
# same loader the training cell uses for validation.
model.eval()
total_loss = 0.0
all_outputs = []
all_labels = []
# Inference only: disable autograd so no graph is built or kept in memory.
with torch.no_grad():
    for inputs, labels in test_loader:
        # Insert a length-1 axis at position 1; the model expects 3-D input.
        inputs = inputs.unsqueeze(1).to(device=device, dtype=torch.float32)
        labels = labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        total_loss += loss.item()
        all_outputs.append(outputs.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

all_outputs = np.concatenate(all_outputs)
all_labels = np.concatenate(all_labels)
# Use the same per-label AUC helper as the training cell, so the reported
# number is comparable and a label column with a single class is skipped
# instead of raising.
auc_score = calculate_auc(all_labels, all_outputs)
print(f"Validation Loss: {total_loss / len(test_loader):.4f}, Validation AUC: {auc_score:.4f}")