In [None]:
import torch
import torch.nn as nn
import numpy as np
from dataloader import EHRDataset
from torch.utils.data import DataLoader
from models.transformer_model import TransformerPredictor
from utils import create_tokenizer, compute_metrics
from tqdm import trange, tqdm
import copy

# Build the tokenizer before constructing the datasets. The call order is kept
# from the original; presumably EHRDataset relies on what create_tokenizer()
# produces -- TODO confirm against dataloader.py / utils.py.
create_tokenizer()
train_dataset, test_dataset = EHRDataset(mode="train"), EHRDataset(mode="test")

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
from torch.utils.data import WeightedRandomSampler

# Class-balanced sampling: every training example is drawn with a weight that
# is inversely proportional to the frequency of its label.
# NOTE(review): this assumes the iteration order of `train_dataset.targets`
# matches the dataset's index order, since the sampler yields positional
# indices -- confirm against EHRDataset.
targets = np.array(list(train_dataset.targets.values()))

# `label_index[i]` is the position of targets[i] inside `labels_uniques`.
# Indexing the weights by that position (rather than by the raw label value)
# keeps the lookup correct when labels are floats, booleans, or integers that
# do not start at 0.
labels_uniques, label_index, counts = np.unique(
    targets, return_inverse=True, return_counts=True
)
class_weights = [counts.sum() / c for c in counts]
weights = [class_weights[i] for i in label_index.ravel()]

# One epoch draws len(targets) samples (with replacement, the sampler default).
sampler = WeightedRandomSampler(weights, len(targets))

# Binary cross-entropy on probabilities, i.e. the model output is expected to
# already be in [0, 1].
criterion = nn.BCELoss().to(device)

In [None]:
# ---------------------------------------------------------------------------
# Model, optimiser and training loop with early stopping on test AUPRC.
# ---------------------------------------------------------------------------
d_embed, d_transformer = 48, 128

# `device` is handed to the constructor and there is no explicit
# model.to(device) here -- presumably TransformerPredictor places its own
# parameters; TODO confirm.
model = TransformerPredictor(
    d_embedding=d_embed,
    d_model=d_transformer,
    n_layers=2,
    tokenizer_codes=train_dataset.tokenizer,
    dropout=0.5,
    device=device,
)

# The base learning rate is deliberately 1: LambdaLR multiplies the base rate
# by the value returned from `lr_lambda`, so the lambda below yields the
# absolute learning rate. (weight_decay is left at the Adam default; the
# original had `weight_decay=1e-4` commented out.)
optimizer = torch.optim.Adam(model.parameters(), lr=1)

best_test_auprc = 0.
patience, current_patience = 15, 0
# Stay None until the first improvement so the reporting / saving code below
# never touches an unbound name.
best_dict, best_row = None, None

# Training loop
epochs = 100
start_lr, end_lr = 1e-4, 1e-5
# Linear interpolation from start_lr (epoch 0) towards end_lr (epoch == epochs).
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: start_lr*(1-epoch/epochs) + end_lr*(epoch/epochs))


def _format_metrics(acc, auprc, auroc, bce):
    """Render the four values returned by compute_metrics, rounded to 3 decimals."""
    return f"acc {round(acc, 3)}; auprc {round(auprc, 3)}; auroc {round(auroc, 3)}; bce {round(bce, 3)}"


# The loaders do not depend on the epoch, so build them once. Iterating a
# DataLoader again starts a fresh pass (and a fresh draw from `sampler`).
train_loader = DataLoader(dataset=train_dataset, batch_size=8, drop_last=False, sampler=sampler)
test_loader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False, drop_last=False)

for epoch in range(epochs):
    # ---- train -----------------------------------------------------------
    model.train()
    y_prob, y_true = [], []
    for batch in train_loader:
        optimizer.zero_grad()

        minutes, codes, values = batch['minutes'].to(device), batch['codes'].to(device), batch['values'].to(device)
        y = batch['target'].to(device)

        # Flatten to 1-D instead of squeeze(): squeeze() turns a batch of one
        # into a 0-d tensor, which BCELoss rejects (size mismatch) and whose
        # .tolist() is a bare float that cannot be appended with `+=`.
        # Assumes the model emits one probability per sample -- TODO confirm.
        output = model(codes, values, minutes).reshape(-1)
        loss = criterion(output, y.float().reshape(-1))
        loss.backward()
        optimizer.step()

        y_prob += output.detach().tolist()
        y_true += y.reshape(-1).tolist()

    acc, auprc, auroc, bce = compute_metrics(y_true, y_prob)
    print(f"Epoch {1+epoch}: train: {_format_metrics(acc, auprc, auroc, bce)}")
    scheduler.step()

    # ---- evaluate --------------------------------------------------------
    model.eval()
    y_prob, y_true = [], []
    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time on the large (256) test batches.
    with torch.no_grad():
        for batch in test_loader:
            minutes, codes, values = batch['minutes'].to(device), batch['codes'].to(device), batch['values'].to(device)
            y = batch['target'].to(device)

            output = model(codes, values, minutes).reshape(-1)

            y_prob += output.tolist()
            y_true += y.reshape(-1).tolist()

    acc, auprc, auroc, bce = compute_metrics(y_true, y_prob)
    print(f" test: {_format_metrics(acc, auprc, auroc, bce)}\n")
    if auprc > best_test_auprc:
        current_patience = 0
        best_test_auprc = auprc
        # Snapshot the weights; state_dict() returns references to the live
        # tensors, so a deep copy is required to freeze this epoch's values.
        best_dict = copy.deepcopy(model.state_dict())
        # Report the epoch 1-based, consistent with the per-epoch line above.
        best_row = f" best test (epoch {epoch + 1}): {_format_metrics(acc, auprc, auroc, bce)}"
    else:
        current_patience += 1

    if best_row is not None:
        tqdm.write(best_row)

    # Early stopping: give up after `patience` consecutive epochs without a
    # new best test AUPRC.
    if current_patience >= patience:
        break


# Persist the best snapshot (if any epoch ever improved on the initial 0.0).
if best_dict is not None:
    torch.save(best_dict, f"{round(100*best_test_auprc)}%_model.pt")

In [None]:
# Show (label, predicted probability) pairs for the test set. `y_true` and
# `y_prob` hold whatever the most recent evaluation pass in the training cell
# left in them, i.e. the last executed epoch -- not necessarily the epoch whose
# weights were saved as the best checkpoint.
[(label, prob) for label, prob in zip(y_true, y_prob)]