In [1]:
from radfusion import RadFusionCT
import torch
from torch import nn
from tqdm import tqdm
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
import time

In [2]:
all_ct_embs = "/home/nuren.zhaksylyk/Documents/HC701/project/PEFormer/data/ct/all_ct_out_embeds.pickle"
all_emr_embs = "/home/nuren.zhaksylyk/Documents/HC701/project/PEFormer/data/ehr/all_emr_embs.pickle"
labels_csv = "/home/nuren.zhaksylyk/Documents/HC701/project/PEFormer/data/ct/all_ct_labels_split.csv"

In [3]:
dataset_train = RadFusionCT(pkl_ct_file=all_ct_embs, pkl_emr_file=all_emr_embs, csv_file=labels_csv, split='train')
dataset_val = RadFusionCT(pkl_ct_file=all_ct_embs, pkl_emr_file=all_emr_embs, csv_file=labels_csv, split='val')
dataset_test = RadFusionCT(pkl_ct_file=all_ct_embs, pkl_emr_file=all_emr_embs, csv_file=labels_csv, split='test')

Loaded 1454 samples for train set.
Loaded 193 samples for val set.
Loaded 190 samples for test set.


In [4]:
study_num, embeddings, mask, target = dataset_train[0]
print(study_num)
print(embeddings.shape)
print(mask.shape)
print(target)

1436
torch.Size([110, 2048])
torch.Size([110])
0


In [5]:
class AbsolutePositionalEncoder(nn.Module):
    def __init__(self, emb_dim, max_position=111):
        super(AbsolutePositionalEncoder, self).__init__()
        self.position = torch.arange(max_position).unsqueeze(1)

        self.positional_encoding = torch.zeros(1, max_position, emb_dim)

        _2i = torch.arange(0, emb_dim, step=2).float()

        # PE(pos, 2i) = sin(pos/10000^(2i/d_model))
        self.positional_encoding[0, :, 0::2] = torch.sin(self.position / (10000 ** (_2i / emb_dim)))

        # PE(pos, 2i+1) = cos(pos/10000^(2i/d_model))
        self.positional_encoding[0, :, 1::2] = torch.cos(self.position / (10000 ** (_2i / emb_dim)))

    def forward(self, x):
        # batch_size, input_len, embedding_dim
        batch_size, seq_len, _ = x.size()

        return self.positional_encoding[:batch_size, :seq_len, :]



class PEFormer(nn.Module):
    def __init__(self, 
                 max_num_emb=110, 
                 emb_dim = 2048, 
                 num_heads = 8,
                 num_enc_layer = 6,
                 num_dec_layer = 6,
                 num_classes = 1,
                 dim_feedforward = 2048):
        super(PEFormer, self).__init__()

        self.positional_encoder = AbsolutePositionalEncoder(emb_dim, max_position=max_num_emb+1)

        self.cl_token = nn.Parameter(torch.randn(1, emb_dim))

        self.transformer = nn.Transformer(d_model=emb_dim, nhead=num_heads, num_encoder_layers=num_enc_layer, num_decoder_layers=num_dec_layer, batch_first=True, dim_feedforward=dim_feedforward)

        self.fc = nn.Linear(emb_dim, num_classes)
    
    def forward(self, x, mask):
        batch_size, seq_len, emb_dim = x.size()

        x = torch.cat([self.cl_token.expand(batch_size, -1, -1), x], dim=1)
        pos_enc = self.positional_encoder(x).to(x.device)
        x = x + pos_enc

        #add False to the beginning of mask to account for CL token
        mask = torch.cat([torch.tensor([[False]]).expand(batch_size, 1).to(mask.device), mask], dim=1)

        x = self.transformer(x, x, src_key_padding_mask = mask)

        x = self.fc(x[:, 0, :])

        return x

In [6]:
def train_step(
        model: torch.nn.Module,
        train_loader,
        loss_fn: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
):
    """
    Train model for one epoch.

    Args:
        model: PyTorch model to train.
        train_loader: PyTorch dataloader for training data.
        loss_fn: PyTorch loss function.
        optimizer: PyTorch optimizer.
        device: PyTorch device to use for training.

    Returns:
        Average loss, accuracy, macro F1 score, and macro recall for the epoch.
    """

    model.train()
    train_loss = 0.0

    targets = []
    predictions = []

    for i, (_, emb, mask, target) in enumerate(tqdm(train_loader)):
        target = target.float()
        emb, mask, target = emb.to(device), mask.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(emb, mask)     
        output = output.squeeze() 
        
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        prediction = torch.sigmoid(output) > 0.5

        predictions.extend(prediction.cpu().numpy())
        targets.extend(target.cpu().numpy())



    train_loss /= len(train_loader)
    train_acc = accuracy_score(targets, predictions)
    train_macro_f1 = f1_score(targets, predictions, average='macro')
    train_macro_recall = recall_score(targets, predictions, average='macro')

    return train_loss, train_acc, train_macro_f1, train_macro_recall


In [7]:
def val_step(
        model: torch.nn.Module,
        val_loader,
        loss_fn: torch.nn.Module,
        device: torch.device,
):
    """
    Evaluate model on val data.

    Args:
        model: PyTorch model to evaluate.
        val_loader: PyTorch dataloader for val data.
        loss_fn: PyTorch loss function.
        device: PyTorch device to use for evaluation.

    Returns:
        Average loss, accuracy, macro F1 score, and macro recall for the validation set.
    """

    model.eval()
    val_loss = 0.0
    val_targets = []
    val_predictions = []


    with torch.no_grad():
        for i, (_, emb, mask, target) in enumerate(tqdm(val_loader)):
            target = target.float()
            emb, mask, target = emb.to(device), mask.to(device), target.to(device)
            output = model(emb, mask)
            
            output = output.squeeze()
            loss = loss_fn(output, target)
            val_loss += loss.item()

            prediction = torch.sigmoid(output) > 0.5


            val_predictions.extend(prediction.cpu().numpy())
            val_targets.extend(target.cpu().numpy())
        
        val_loss /= len(val_loader)
        val_acc = accuracy_score(val_targets, val_predictions)
        val_macro_f1 = f1_score(val_targets, val_predictions, average='macro')
        val_macro_recall = recall_score(val_targets, val_predictions, average='macro')



    return val_loss, val_acc, val_macro_f1, val_macro_recall


In [8]:
def trainer(
        model: torch.nn.Module,
        train_loader,
        val_loader,
        loss_fn: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler,
        lr_scheduler_name: str,
        device: torch.device,
        epochs: int,
        save_dir: str,
        early_stopper=None,
        start_epoch = 1,
):
    """
    Train and evaluate model.

    Args:
        model: PyTorch model to train.
        train_loader: PyTorch dataloader for training data.
        val_loader: PyTorch dataloader for val data.
        loss_fn: PyTorch loss function.
        optimizer: PyTorch optimizer.
        lr_scheduler: PyTorch learning rate scheduler.
        device: PyTorch device to use for training.
        epochs: Number of epochs to train the model for.

    Returns:
        Average loss and accuracy for the val set.
    """

    results = {
        "train_loss": [],
        "val_loss": [],
        "train_acc": [],
        "val_acc": [],
        "train_f1":[],
        "val_f1":[],
        "train_recall":[],
        "val_recall":[],
    }
    best_val_loss = 1e10

    for epoch in range(start_epoch, epochs + 1):

        print(f"Epoch {epoch}:")
        train_loss, train_acc, train_macro_f1, train_macro_recall = train_step(model, train_loader, loss_fn, optimizer, device)
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Train F1: {train_macro_f1:.4f}, Train recall: {train_macro_recall:.4f}")

        

        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["train_f1"].append(train_macro_f1)
        results["train_recall"].append(train_macro_recall)


        val_loss, val_acc, val_f1, val_recall = val_step(model, val_loader, loss_fn, device)
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}, Val recall: {val_recall:.4f}")
        print()

        if lr_scheduler_name == "ReduceLROnPlateau":
            lr_scheduler.step(val_loss)
        elif lr_scheduler_name != "None":
            lr_scheduler.step()
        
        results["val_loss"].append(val_loss)
        results["val_acc"].append(val_acc)
        results["val_f1"].append(val_f1)
        results["val_recall"].append(val_recall)
  
        
        
        # wandb.log({"train_loss": train_loss, "val_loss": val_loss, "train_acc": train_acc, "val_acc": val_acc, "train_f1": train_macro_f1, "train_recall": train_macro_recall, "val_f1": val_f1, "val_recall": val_recall,  "train_kappa": train_kappa, "val_kappa": val_kappa, "trian_auc": train_auc, "val_auc": val_auc})
        

        checkpoint = { 
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_sched': lr_scheduler}
            
                    
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(checkpoint, os.path.join(save_dir, "best_checkpoint.pth"))

        torch.save(checkpoint, os.path.join(save_dir, "last_checkpoint.pth"))

        if early_stopper is not None:
            if early_stopper.early_stop(val_loss):
                print("Early stopping")
                break

    return results


In [9]:
def START_seed(seed_value=9):
    seed = seed_value
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed) 

In [10]:
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
NUM_EPOCHS = 2
LEARNING_SCHEDULER = 'CosineAnnealingLR'
LOSS = 'BCEWithLogitsLoss'
SAVE_DIR = "/home/nuren.zhaksylyk/Documents/HC701/project/PEFormer/runs"
DEVICE = torch.device(f"cuda:0" if torch.cuda.is_available() else 'cpu')

print(f"Using {DEVICE} device")

Using cuda:0 device


In [11]:
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=False)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

In [12]:
START_seed()
run_id = time.strftime("%Y-%m-%d_%H-%M-%S")
if not os.path.exists(SAVE_DIR):
    os.makedirs(SAVE_DIR)

os.mkdir(SAVE_DIR + run_id)

save_dir = SAVE_DIR + run_id

model = PEFormer()
model.to(DEVICE)

torch.compile(model)

if LOSS == "MSE":
    loss = torch.nn.MSELoss()
elif LOSS == "L1Loss":
    loss = torch.nn.L1Loss()
elif LOSS == "SmoothL1Loss":
    loss = torch.nn.SmoothL1Loss()
elif LOSS == "CrossEntropyLoss":
    loss = torch.nn.CrossEntropyLoss()
elif LOSS == "BCEWithLogitsLoss":
    loss = torch.nn.BCEWithLogitsLoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

if LEARNING_SCHEDULER == "CosineAnnealingLR":
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, verbose=True)
elif LEARNING_SCHEDULER == "ReduceLROnPlateau":
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
elif LEARNING_SCHEDULER == "StepLR":
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1, verbose=True)
elif LEARNING_SCHEDULER == "MultiStepLR":
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15, 20], gamma=0.1)
else:
    lr_scheduler = None

Adjusting learning rate of group 0 to 1.0000e-04.


In [13]:
results = trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss_fn=loss,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        lr_scheduler_name=LEARNING_SCHEDULER,
        device=DEVICE,
        epochs=NUM_EPOCHS,
        save_dir=save_dir,
        early_stopper=None,
    )

Epoch 1:


100%|██████████| 46/46 [00:51<00:00,  1.13s/it]


Train Loss: 0.8903, Train Acc: 0.6018, Train F1: 0.4611, Train recall: 0.4926


  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
 86%|████████▌ | 6/7 [00:01<00:00,  4.00it/s]


ValueError: Target size (torch.Size([1])) must be the same as input size (torch.Size([]))