In [None]:
import os
import random

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms, models
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, Dataset

import pydicom
from tqdm import tqdm

from config import Device

In [2]:
device = Device.device
print(device)

mps


In [3]:
data_path = os.path.join(os.getcwd(), "data")
labels_path = "labels.csv"

In [4]:
batch_size = 32
num_epochs = 10
learning_rate = 0.001

In [5]:
from datasets import MRIDataset

data_path = os.path.join("C:\\Users\\asus\\Desktop\\iaaa-mri-train", "data")
labels_path = "train.csv"
batch_size = 32

train_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),  # resize images to 224x224, required by resnet-50
        transforms.RandomHorizontalFlip(), 
        transforms.RandomRotation(10),  
        # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
        # resnet transform has not been applied
])  

test_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        # resnet transform has not been applied
    ])


train_dataset = MRIDataset(
    data_path, labels_path, split="train", transform=train_transforms, max_slices=20)
val_dataset = MRIDataset(data_path, labels_path, split="val",
                         transform=train_transforms, max_slices=20)
test_dataset = MRIDataset(
    data_path, labels_path, split="test", transform=test_transforms, max_slices=20)


train_dl = DataLoader(train_dataset, batch_size, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size)
test_dl = DataLoader(test_dataset, batch_size)

In [6]:
data_, label_ = next(iter(train_dataset))
print(data_.size())
print(label_.size())

torch.Size([20, 224, 224])
torch.Size([1])


In [7]:
data_, label_ = next(iter(train_dl))
print(data_.size())
print(label_.size())

torch.Size([32, 20, 224, 224])
torch.Size([32, 1])


In [None]:
'''
challenge: resnet input: 3 channels images, the mri images: 20 channels
-> 1 layer (conv2d) to transform 16 to 3 berfore resnet model

challenge: using pretrained resent50 model and update all parameters (aksh-ai)
or just fine tune last layer
-> update all parameters

challenge: resnet-50 transform normalizing is not applicable
we have gray scale 18 channel data 
'''

In [8]:
class MriResentModel(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()

        self.conv = nn.Conv2d(in_dim, 3, kernel_size=1, padding=0)
        
        # instantiate transfer learning model
        self.resnet_model = models.resnet50(pretrained=True)

        # set all paramters as trainable
        for param in self.resnet_model.parameters():
            param.requires_grad = True

        # get input of fc layer
        n_inputs = self.resnet_model.fc.in_features # 2048

        # redefine fc layer / top layer/ head for our classification problem
        self.resnet_model.fc = nn.Sequential(nn.Linear(n_inputs, 2048),
                                        nn.SELU(),
                                        nn.Dropout(p=0.4),
                                        nn.Linear(2048, 2048),
                                        nn.SELU(),
                                        nn.Dropout(p=0.4),
                                        nn.Linear(2048, out_dim),
                                        nn.LogSigmoid())

        # set model to run on GPU or CPU absed on availibility
        # self.resnet_model.to(device)

    def forward(self, x):
        x = self.conv(x)
        x = self.resnet_model(x)
        return x

In [18]:
mri_resnet_model = MriResentModel(20,1)
mri_resnet_model.to(device)

MriResentModel(
  (conv): Conv2d(20, 3, kernel_size=(1, 1), stride=(1, 1))
  (resnet_model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        

In [19]:
random_data = torch.randn(1,20,224,224)
random_data = random_data.to(device)
mri_resnet_model(random_data).size()

torch.Size([1, 1])

In [20]:
data_, label_ = next(iter(train_dl))
data_ = data_.float().to(device) # changes to float32 which is compatible for 'mps'
mri_resnet_model(data_).size()

torch.Size([32, 1])

In [25]:
# loss and optimizer

# criterion = nn.BCEWithLogitsLoss()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(mri_resnet_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

In [26]:
import torch
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, roc_auc_score
import numpy as np

# Early stopping parameters
patience = 5
best_avg_metric = 0
epochs_no_improve = 0

for epoch in range(num_epochs):
    mri_resnet_model.train()
    train_loss = 0.0
    train_correct = 0.0

    for images, labels in tqdm(train_dl):
        images = images.float().to(device=device)
        labels = labels.float().to(device=device) # BCEWithLogitsLoss expects floats

        # forward pass
        outputs = mri_resnet_model(images)
        loss = criterion(outputs, labels)

        # backward and optimizer
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        preds = (outputs > 0).float()  # convert logits to binary predictions
        train_correct += (preds == labels).sum().item()

    train_accuracy = train_correct / len(train_dataset)

    # Evaluate
    mri_resnet_model.eval()
    test_correct = 0.0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in test_dl:
            images = images.float().to(device=device)
            labels = labels.float().to(device=device)

            outputs = mri_resnet_model(images)

            preds = (outputs > 0).float()  # convert logits to binary predictions
            test_correct += (preds == labels).sum().item()

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(outputs.cpu().numpy())

    test_accuracy = test_correct / len(test_dataset)

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    all_preds_binary = (all_preds > 0).astype(float)

    precision = precision_score(all_labels, all_preds_binary)
    recall = recall_score(all_labels, all_preds_binary)
    auc = roc_auc_score(all_labels, all_preds)

    avg_metric = (precision + recall + auc) / 3

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
    print(f"Epoch {epoch+1}/{num_epochs}, Test Accuracy: {test_accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, AUC: {auc:.4f}, Avg Metric: {avg_metric:.4f}")

    # Early stopping
    if avg_metric > best_avg_metric:
        best_avg_metric = avg_metric
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print("Early stopping triggered")
        break


100%|██████████| 59/59 [04:03<00:00,  4.12s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 1/10, Train Loss: 0.0000, Train Accuracy: 0.8750
Epoch 1/10, Test Accuracy: 0.8754, Precision: 0.0000, Recall: 0.0000, AUC: 0.5569, Avg Metric: 0.1856


100%|██████████| 59/59 [04:02<00:00,  4.11s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 2/10, Train Loss: 0.0000, Train Accuracy: 0.8750
Epoch 2/10, Test Accuracy: 0.8754, Precision: 0.0000, Recall: 0.0000, AUC: 0.5547, Avg Metric: 0.1849


  0%|          | 0/59 [00:02<?, ?it/s]


KeyboardInterrupt: 