In [1]:
import copy
import os
import sys
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, roc_auc_score
from torch.autograd import Variable
from torch.utils.data import DataLoader

from data.dataset import REFLACXWithClinicalDataset, RecordPoint
from model.xami import XAMIMultiModalSum


In [2]:
# Probe for CUDA once; the flag is kept around for any later cell that needs it.
use_gpu = torch.cuda.is_available()

# Name of the device that the model and every batch are moved to.
if use_gpu:
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Instantiate the dataset with its default arguments.
reflacx_dataset = REFLACXWithClinicalDataset()

# Build the model from the dataset and move it onto the selected device in
# one expression (assumes XAMIMultiModalSum is an nn.Module, whose .to()
# returns the module itself -- TODO confirm).
xami_mutlimodal = XAMIMultiModalSum(reflacx_dataset, device).to(device)

In [4]:
## Learning hyper-parameters

lr = 1e-4
batch_size = 64

# Adam only receives the parameters that still require gradients, so any
# frozen parameters of the model are left out of the optimizer.
optimizer = optim.Adam(
    (param for param in xami_mutlimodal.parameters() if param.requires_grad),
    lr=lr,
    weight_decay=0,
)

# Multiply the learning rate by 0.1 once the monitored value (the validation
# loss passed to scheduler.step) stops improving by more than the threshold.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.1,
    patience=1,
    threshold=0.001,
)


In [5]:
# Separate the dataset: 80% train, 10% test, and validation takes whatever
# is left so the three lengths always sum to len(reflacx_dataset).

train_dataset_len = int(len(reflacx_dataset) * .8)
test_dataset_len = int(len(reflacx_dataset) * .1)
val_dataset_len = len(reflacx_dataset) - (train_dataset_len + test_dataset_len)


# random_split draws a random permutation of the indices from the supplied
# generator and hands out non-overlapping subsets in the order of `lengths`.
# The fixed seed keeps the partition identical across kernel restarts.
(
    train_dataset,
    val_dataset,
    test_dataset
) = torch.utils.data.random_split(
    dataset=reflacx_dataset,
    lengths=[train_dataset_len, val_dataset_len, test_dataset_len],
    generator=torch.Generator().manual_seed(123),
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=reflacx_dataset.train_collate_fn,
)

# Evaluation loaders are not shuffled: the reported metrics are means of
# per-batch values, so a fixed batch composition keeps the validation and
# test numbers (and the scheduler / early-stopping decisions that depend on
# them) reproducible from run to run.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=reflacx_dataset.test_collate_fn,
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=reflacx_dataset.test_collate_fn,
)



In [7]:
def transform_data(data, device, train=None):
    """Move one collated batch onto ``device``.

    Args:
        data: A 3-item sequence ``(image, clinical_data, label)`` where
            ``clinical_data`` is itself a pair
            ``(clinical_numerical_data, clinical_categorical_data)`` and
            ``clinical_categorical_data`` is a dict whose values support
            ``.to(device)``.
        device: Target device, passed straight to ``.to()``.
        train: Accepted for compatibility with callers in this notebook that
            pass ``train=True`` / ``train=False``; the value is not used.

    Returns:
        ``(image, (clinical_numerical_data, clinical_categorical_data), label)``
        with every tensor moved to ``device``.
    """
    image, clinical_data, label = data
    clinical_numerical_data, clinical_categorical_data = clinical_data

    # Build a fresh dict instead of assigning into the one that came out of
    # the dataloader, so the caller's batch object is left untouched.
    categorical_on_device = {
        col: values.to(device)
        for col, values in clinical_categorical_data.items()
    }

    # torch.autograd.Variable is deprecated (tensors and Variables were merged
    # in PyTorch 0.4), so the tensors are returned directly rather than
    # re-wrapped with requires_grad=False.
    return (
        image.to(device),
        (clinical_numerical_data.to(device), categorical_on_device),
        label.to(device),
    )

In [8]:
# train the model for one epoch
def train_epoch(epoch, model, dataloader, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Uses the notebook-level ``device`` and ``transform_data``.

    Args:
        epoch: Epoch number, used only in the progress messages.
        model: Module called as ``model(image, clinical_data)``.
        dataloader: Iterable of batches accepted by ``transform_data``.
        loss_fn: Callable ``loss_fn(outputs, label)`` returning a scalar loss.
        optimizer: Optimizer stepped once per batch.

    Returns:
        ``(train_loss, train_acc, train_auc)`` -- each the mean of the
        per-batch values collected during the pass.
    """
    model.train()
    model.to(device)

    batch_losses = []
    batch_accuracy = []
    batch_auc = []

    for batch_idx, data in enumerate(dataloader):
        # transform_data only moves the batch to the device; it takes no
        # train/eval switch, so none is passed.
        image, clinical_data, label = transform_data(data, device)

        optimizer.zero_grad()
        outputs = model(image, clinical_data)
        loss = loss_fn(outputs, label)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        print("Epoch: {:d} Batch:  ({:d}) Train Loss: {:.4f}".format(
            epoch, batch_idx, batch_loss))
        sys.stdout.flush()

        batch_losses.append(batch_loss)

        # Convert once and reuse for both metrics.
        y_true = label.detach().cpu().numpy().flatten()
        y_score = outputs.detach().cpu().numpy().flatten()

        # NOTE(review): the 0.5 cut-off assumes `outputs` are probabilities.
        # If the model emits raw logits the equivalent cut-off would be 0 --
        # confirm against the model definition.
        batch_accuracy.append(
            accuracy_score(y_true, (y_score > 0.5).astype('int64'))
        )
        batch_auc.append(roc_auc_score(y_true, y_score))

    train_loss = np.mean(batch_losses)
    train_acc = np.mean(batch_accuracy)
    train_auc = np.mean(batch_auc)

    print(
        f"Epoch {epoch} | Loss: {train_loss:.2f} | ACC: {train_acc*100:.2f}% | AUC: {train_auc:.2f}")

    return train_loss, train_acc, train_auc


In [9]:
# evaluate the model for one epoch (used for validation and test)
def test_epoch(epoch, model, dataloader, loss_fn):
    """Evaluate ``model`` over ``dataloader`` without updating its weights.

    Uses the notebook-level ``device`` and ``transform_data``.

    Args:
        epoch: Epoch number, used only in the progress messages.
        model: Module called as ``model(image, clinical_data)``.
        dataloader: Iterable of batches accepted by ``transform_data``.
        loss_fn: Callable ``loss_fn(outputs, label)`` returning a scalar loss.

    Returns:
        ``(test_loss, test_acc, test_auc)`` -- each the mean of the per-batch
        values collected during the pass.
    """
    model.eval()
    model.to(device)

    batch_losses = []
    batch_accuracy = []
    batch_auc = []

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch_idx, data in enumerate(dataloader):
            # transform_data only moves the batch to the device; it takes no
            # train/eval switch, so none is passed.
            image, clinical_data, label = transform_data(data, device)
            outputs = model(image, clinical_data)
            loss = loss_fn(outputs, label)

            batch_loss = loss.item()
            batch_losses.append(batch_loss)

            # Convert once and reuse for both metrics.
            y_true = label.detach().cpu().numpy().flatten()
            y_score = outputs.detach().cpu().numpy().flatten()

            # NOTE(review): the 0.5 cut-off assumes `outputs` are
            # probabilities. If the model emits raw logits the equivalent
            # cut-off would be 0 -- confirm against the model definition.
            batch_accuracy.append(
                accuracy_score(y_true, (y_score > 0.5).astype('int64'))
            )
            batch_auc.append(roc_auc_score(y_true, y_score))

            print("Epoch: {:d} Batch:  ({:d}) Test Loss: {:.4f}".format(
                epoch, batch_idx, batch_loss))
            sys.stdout.flush()

    test_loss = np.mean(batch_losses)
    test_acc = np.mean(batch_accuracy)
    test_auc = np.mean(batch_auc)

    print(
        f"Epoch {epoch} | Loss: {test_loss:.2f} | ACC: {test_acc*100:.2f}% | AUC: {test_auc:.2f}")

    return test_loss, test_acc, test_auc


In [10]:
loss_fn = nn.MultiLabelSoftMarginLoss()

num_epochs = 1

# state_dict() returns references to the live parameter tensors, so the
# snapshot must be deep-copied; otherwise the "best" weights silently follow
# the model as training continues.
best_model_wts = copy.deepcopy(xami_mutlimodal.state_dict())
best_loss = float("inf")
counter = 0  # consecutive epochs without a validation-loss improvement

# torch.save does not create missing directories.
os.makedirs("saved_models", exist_ok=True)

for epoch in range(1, num_epochs + 1):
    print("Epoch {}/{}".format(epoch, num_epochs))
    print("-" * 10)

    train_loss, train_acc, _ = train_epoch(
        epoch, xami_mutlimodal, dataloader=train_dataloader,
        loss_fn=loss_fn, optimizer=optimizer)
    val_loss, val_acc, _ = test_epoch(
        epoch, xami_mutlimodal, dataloader=val_dataloader, loss_fn=loss_fn)

    scheduler.step(val_loss)

    if val_loss < best_loss:
        best_loss = val_loss
        best_model_wts = copy.deepcopy(xami_mutlimodal.state_dict())
        counter = 0

        # Checkpoint only when the model improved, so the loss in the file
        # name always matches the weights stored in the file. ':' is replaced
        # because the timestamp would otherwise be an invalid file name on
        # some platforms.
        checkpoint_name = f"{best_loss:.4f}_{str(datetime.now())}".replace(":", "_")
        torch.save(best_model_wts, os.path.join("saved_models", checkpoint_name))
    else:
        counter += 1
        # Early stopping: give up after more than 3 epochs in a row without
        # improvement.
        if counter > 3:
            break

print(f"Best Validation Loss: {best_loss:.4f}")

Epoch 1/1
----------
Epoch: 1 Batch:  (0) Train Loss: 0.9523
Epoch: 1 Batch:  (1) Train Loss: 0.9579
Epoch: 1 Batch:  (2) Train Loss: 0.9647
Epoch: 1 Batch:  (3) Train Loss: 0.9411
Epoch: 1 Batch:  (4) Train Loss: 0.9508
Epoch 1 | Loss: 0.95 | ACC: 47.50% | AUC: 0.62
Epoch: 1 Batch:  (0) Test Loss: 0.9426
Epoch: 1 Batch:  (1) Test Loss: 0.9357
Epoch: 1 Batch:  (2) Test Loss: 0.9081
Epoch 1 | Loss: 0.93 | ACC: 52.40% | AUC: 0.64
Best Validation Loss: 0.9288
