In [None]:
import os
import re
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import label_binarize
from tqdm import tqdm
import numpy as np
from joblib import Parallel, delayed
import torch.backends.cudnn as cudnn

# Turn on cuDNN's benchmark mode for the whole session.
cudnn.benchmark = True
# Release any cached, currently unused CUDA memory before the notebook allocates its own.
torch.cuda.empty_cache()

In [None]:
# load data
# Read the label vector and the two feature arrays stored under `directory`.
directory = "transformer_data"
labels = np.load(os.path.join(directory, "labels.npy"))
aout_data_array = np.load(os.path.join(directory, "aout_data_3d.npy"))
visit_data_array = np.load(os.path.join(directory, "visit_data.npy"))

# All three arrays are indexed by the same permutation below, so they must
# describe the same number of samples.
assert len(aout_data_array) == len(visit_data_array) == len(labels), (
    "labels, aout_data_array and visit_data_array must have the same number of samples"
)

# Shuffle the samples with a fixed seed so that the shuffle (and therefore the
# train/validation split derived from it) is identical after a kernel restart.
shuffle_rng = np.random.default_rng(42)
indices = shuffle_rng.permutation(len(labels))
visit_data_array = visit_data_array[indices]
aout_data_array = aout_data_array[indices]
labels = labels[indices]

In [None]:
# Print the shape of each loaded array, in the order:
# aout_data_array, visit_data_array, labels.
for loaded_array in (aout_data_array, visit_data_array, labels):
    print(loaded_array.shape)

In [None]:
# Binarize labels
# Convert the class ids into indicator columns for the classes 0, 1, 2 and 3.
# NOTE(review): this overwrites `labels` in place; re-running the cell without
# reloading the data would binarize the already-binarized array.
labels = label_binarize(labels, classes=[0, 1, 2, 3])

In [None]:
# split test and train dataset
# The split is done on sample indices (not on the features themselves) so that
# one partition can be applied to both feature arrays.
sample_indices = np.arange(len(labels))
X_train, X_val, Y_train, Y_val = train_test_split(
    sample_indices,
    labels,
    test_size=0.2,
    random_state=42,
)

# X_train / X_val hold row indices; use them to slice each feature array.
visit_train = visit_data_array[X_train]
visit_val = visit_data_array[X_val]
aout_train = aout_data_array[X_train]
aout_val = aout_data_array[X_val]

In [None]:
class VisitDataset(Dataset):
    """Map-style dataset yielding ``(visit_features, aout_features, label)`` per sample.

    The two feature containers are stored as given; the labels are copied and
    cast to ``float32`` once at construction time.
    """

    def __init__(self, visit_features, aout_features, labels):
        """Keep references to both feature arrays and a float32 copy of ``labels``."""
        self.visit_features = visit_features
        self.aout_features = aout_features
        self.labels = labels.astype(np.float32)

    def __len__(self):
        """Number of samples, taken from the label array."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``idx``-th entry of each of the three stored arrays."""
        visit_item = self.visit_features[idx]
        aout_item = self.aout_features[idx]
        label_item = self.labels[idx]
        return visit_item, aout_item, label_item

In [None]:
# Wrap the train / validation arrays in datasets and batch them.
train_dataset = VisitDataset(visit_train, aout_train, Y_train)
val_dataset = VisitDataset(visit_val, aout_val, Y_val)

# Options shared by both loaders; only the training loader shuffles.
loader_options = dict(batch_size=64, pin_memory=True, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

In [None]:
class TwoTowerTransformer(nn.Module):
    """Two-tower transformer classifier over a pair of sequences.

    Each input ``x_i`` of shape ``(batch, seq_len_i, input_dim_i)`` goes through
    its own tower: linear embedding -> BatchNorm1d over the sequence axis ->
    ReLU -> a linear layer that mixes along the sequence axis -> transformer
    encoder. The two tower outputs are concatenated along the sequence axis,
    passed through a final transformer encoder, flattened per sample and
    projected to ``num_classes`` logits.

    Args:
        input_dim1, input_dim2: feature size of each input sequence.
        hidden_dim1, hidden_dim2: embedding size (and feed-forward size) of
            each tower.
        hidden_dim3: model size of the final encoder. Because the tower outputs
            are concatenated along the sequence axis and fed to that encoder,
            ``hidden_dim1 == hidden_dim2 == hidden_dim3`` is required.
        num_heads: attention heads used by every encoder layer.
        num_layers: number of layers in each of the three encoders.
        num_classes: number of output logits.
        seq_len1, seq_len2: fixed sequence lengths of the two inputs.

    Raises:
        ValueError: if the three hidden sizes are not all equal.

    Note:
        ``dense1`` / ``dense2`` are ``seq_len x seq_len`` linear layers, so
        their parameter count grows quadratically with the sequence length.
    """

    def __init__(self, input_dim1, input_dim2, hidden_dim1, hidden_dim2, hidden_dim3, num_heads, num_layers, num_classes, seq_len1, seq_len2):
        super().__init__()
        # Fail early: with unequal sizes the forward pass can never succeed
        # (torch.cat and the final encoder's d_model would both mismatch).
        if not (hidden_dim1 == hidden_dim2 == hidden_dim3):
            raise ValueError(
                "hidden_dim1, hidden_dim2 and hidden_dim3 must be equal, got "
                f"{hidden_dim1}, {hidden_dim2}, {hidden_dim3}"
            )
        self.embedding1 = nn.Linear(input_dim1, hidden_dim1)
        self.embedding2 = nn.Linear(input_dim2, hidden_dim2)
        self.relu = nn.ReLU()
        # BatchNorm1d treats dim 1 as channels, i.e. the sequence axis here.
        self.bn1 = nn.BatchNorm1d(seq_len1)
        self.bn2 = nn.BatchNorm1d(seq_len2)
        # Third positional argument of TransformerEncoderLayer is dim_feedforward.
        self.transformer1 = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(hidden_dim1, num_heads, hidden_dim1),
            num_layers=num_layers)
        self.transformer2 = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(hidden_dim2, num_heads, hidden_dim2),
            num_layers=num_layers)
        self.final_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(hidden_dim3, num_heads, hidden_dim1 + hidden_dim2),
            num_layers=num_layers)
        self.dropout = nn.Dropout(0.3)
        self.dense1 = nn.Linear(seq_len1, seq_len1)
        self.dense2 = nn.Linear(seq_len2, seq_len2)
        self.fc = nn.Linear(hidden_dim3 * (seq_len1 + seq_len2), num_classes)

    def forward(self, x1, x2):
        """Return logits of shape ``(batch, num_classes)``.

        Args:
            x1: tensor of shape ``(batch, seq_len1, input_dim1)``.
            x2: tensor of shape ``(batch, seq_len2, input_dim2)``.
        """
        # (B, L, in) -> (B, L, H)
        x1 = self.relu(self.bn1(self.embedding1(x1)))
        x2 = self.relu(self.bn2(self.embedding2(x2)))

        # Mix along the sequence axis, then switch to the sequence-first layout
        # the encoders expect: (B, L, H) -> (B, H, L) -> dense -> (L, B, H).
        x1 = self.dense1(x1.permute(0, 2, 1)).permute(2, 0, 1)
        x2 = self.dense2(x2.permute(0, 2, 1)).permute(2, 0, 1)

        x1 = self.transformer1(x1)
        x2 = self.transformer2(x2)

        # Concatenate the towers along the sequence axis (dim 0 in this layout).
        x = torch.cat((x1, x2), dim=0)
        x = self.final_transformer(x)

        # Go back to batch-first BEFORE flattening. Reshaping the (S, B, E)
        # tensor directly to (B, -1) would interleave features of different
        # samples into the same output row.
        x = x.permute(1, 0, 2)
        x = x.reshape(x.size(0), -1)

        x = self.dropout(x)
        x = self.fc(x)
        return x

In [None]:
#define early stopping
class EarlyStopping:
    """Signal when a monitored loss has stopped improving.

    Call the instance once per epoch with the current loss. The loss counts as
    "not improved" when it is strictly greater than the best value seen so far,
    or when it is NaN. After ``patience`` consecutive non-improving calls the
    instance returns ``True`` (and keeps returning ``True`` afterwards).

    Args:
        patience: number of consecutive non-improving epochs tolerated.
        verbose: if true, print the counter on every non-improving epoch.
    """

    def __init__(self, patience=3, verbose=True):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0          # consecutive epochs without improvement
        self.best_score = None    # lowest loss seen so far
        self.early_stop = False

    def __call__(self, val_loss, model):
        """Record ``val_loss`` and return whether training should stop.

        ``model`` is accepted for interface compatibility; it is not used here
        because automatic checkpointing is not enabled (see ``save_checkpoint``).
        """
        # NaN is the only value that is not equal to itself. Without this check
        # a NaN loss would be treated as an improvement, overwrite best_score,
        # and make every later comparison false so stopping never triggers.
        loss_is_nan = val_loss != val_loss
        got_worse = self.best_score is not None and val_loss > self.best_score

        if loss_is_nan or got_worse:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # First valid loss, or a loss at least as good as the best so far.
            self.best_score = val_loss
            self.counter = 0
        return self.early_stop

    def save_checkpoint(self, model):
        """Write ``model.state_dict()`` to ``checkpoint.pth`` in the working directory."""
        if self.verbose:
            print('Saving model...')
        torch.save(model.state_dict(), 'checkpoint.pth')

In [None]:
# Model hyper-parameters. The three hidden sizes are kept equal because the
# tower outputs are concatenated and fed to the final encoder.
input_dim1 = 5
input_dim2 = 5
hidden_dim1 = 512
hidden_dim2 = 512
hidden_dim3 = 512
seq_len1 = 36000
seq_len2 = 5
num_heads = 2
num_layers = 2
num_classes = 4

# Always build a torch.device (never a bare string) so `device` has the same
# type whether or not CUDA is available.
device_id = 0
device = torch.device('cuda:{}'.format(device_id) if torch.cuda.is_available() else 'cpu')
print(device)

model = TwoTowerTransformer(input_dim1, input_dim2, hidden_dim1, hidden_dim2, hidden_dim3, num_heads, num_layers, num_classes, seq_len1, seq_len2)


In [None]:
from torchsummary import summary

def count_trainable_parameters(model):
    """Return the total number of elements across parameters with ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Count the model's trainable parameters and report the total.
num_params = count_trainable_parameters(model)
print(f'Trainable parameters: {num_params}')

In [None]:
# Wrap the model in DataParallel when more than one CUDA device is visible.
# NOTE(review): the message reports the total device count, but device_ids is
# fixed to [0, 1], so only the first two GPUs are used.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model, device_ids=[0,1])

# Move the (possibly wrapped) model to the target device before the optimizer
# is built from its parameters.
model = model.to(device)
# Optional warm start from a saved state dict (disabled):
# pretrained_path = '../Sleep_Apnea/temp_epoch10.pth'
# model.load_state_dict(torch.load(pretrained_path))

# Loss, optimizer, early-stopping helper and LR schedule used by the training cell.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
early_stopping = EarlyStopping(patience=5, verbose=True)
# NOTE(review): T_max=50 here — confirm it matches the number of epochs the
# training cell actually runs.
scheduler = CosineAnnealingLR(optimizer, T_max=50, eta_min=0)

In [None]:
# Train for up to `num_epochs` epochs, validating after every epoch.
#
# * `best_loss` / `best_auc` track the best VALIDATION loss / AUC seen so far.
# * The model is checkpointed to 'best_model.pth' whenever the validation loss
#   improves, and early stopping is driven by the same validation loss.
# * One line per epoch is appended to 'log_twotower.txt'; the file is closed
#   automatically, even if training raises.
num_epochs = 100
best_loss = float('inf')
best_auc = 0.0

with open("log_twotower.txt", "w") as log_file:
    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        total_loss = 0.0
        all_outputs = []
        all_labels = []

        # The batch labels get their own name so the loop does not overwrite
        # the notebook-level `labels` array.
        for inputs1, inputs2, batch_labels in tqdm(train_loader):
            inputs1 = inputs1.to(device, dtype=torch.float32)
            inputs2 = inputs2.to(device, dtype=torch.float32)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs1, inputs2)
            loss = criterion(outputs, batch_labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            all_outputs.append(outputs.detach().cpu().numpy())
            all_labels.append(batch_labels.detach().cpu().numpy())

        # The LR schedule advances once per epoch.
        scheduler.step()
        all_outputs = np.concatenate(all_outputs)
        all_labels = np.concatenate(all_labels)

        avg_loss = total_loss / len(train_loader)
        train_auc_score = roc_auc_score(all_labels, all_outputs, multi_class='ovr')

        print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {avg_loss:.4f}, Train AUC: {train_auc_score:.4f}')

        # ---- validation pass ----
        model.eval()
        val_total_loss = 0.0
        val_all_outputs = []
        val_all_labels = []

        with torch.no_grad():
            for val_inputs1, val_inputs2, val_labels in val_loader:
                val_inputs1 = val_inputs1.to(device, dtype=torch.float32)
                val_inputs2 = val_inputs2.to(device, dtype=torch.float32)
                val_labels = val_labels.to(device)

                val_outputs = model(val_inputs1, val_inputs2)
                val_loss = criterion(val_outputs, val_labels)

                val_total_loss += val_loss.item()
                val_all_outputs.append(val_outputs.detach().cpu().numpy())
                val_all_labels.append(val_labels.detach().cpu().numpy())

        val_all_outputs = np.concatenate(val_all_outputs)
        val_all_labels = np.concatenate(val_all_labels)

        val_auc_score = roc_auc_score(val_all_labels, val_all_outputs, multi_class='ovr')
        avg_val_loss = val_total_loss / len(val_loader)

        print(f"Validation Loss: {avg_val_loss:.4f}, Validation AUC: {val_auc_score:.4f}")

        # ---- model selection on validation metrics ----
        best_auc = max(best_auc, val_auc_score)
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            # NOTE: if `model` is wrapped in nn.DataParallel, the saved keys
            # carry the wrapper's 'module.' prefix.
            torch.save(model.state_dict(), 'best_model.pth')
            print("Best model saved")

        log_file.write(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {avg_loss:.4f} Train AUC: {train_auc_score:.4f} Val Loss: {avg_val_loss:.4f} Val AUC: {val_auc_score:.4f}\n")
        log_file.write(f'Best loss: {best_loss}, Best AUC: {best_auc}\n\n')
        log_file.flush()

        if early_stopping(avg_val_loss, model):
            print('Early stopping triggered')
            break

print(f'Best loss is {best_loss}, best AUC is {best_auc}')


In [None]:
#test
# Final evaluation of the in-memory model on the validation loader.
# NOTE(review): this scores whatever weights `model` holds after training
# stopped; load 'best_model.pth' first if the best checkpoint is the one to report.
model.eval()
total_loss = 0.0
all_outputs = []
all_labels = []

# Inference only: disable autograd so no graph is kept for each batch.
with torch.no_grad():
    # `batch_labels` keeps the loop from overwriting the notebook-level `labels`.
    for inputs1, inputs2, batch_labels in val_loader:
        inputs1 = inputs1.to(device, dtype=torch.float32)
        inputs2 = inputs2.to(device, dtype=torch.float32)
        batch_labels = batch_labels.to(device)

        outputs = model(inputs1, inputs2)
        loss = criterion(outputs, batch_labels)

        total_loss += loss.item()
        all_outputs.append(outputs.detach().cpu().numpy())
        all_labels.append(batch_labels.detach().cpu().numpy())

all_outputs = np.concatenate(all_outputs)
all_labels = np.concatenate(all_labels)
auc_score = roc_auc_score(all_labels, all_outputs, multi_class='ovr')
print(f"Validation Loss: {total_loss / len(val_loader):.4f}, Validation AUC: {auc_score:.4f}")