In [None]:
import os

import numpy as np
import pandas as pd
from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader

from config import Config, Device
from datasets import MRIDataset, BalancedMRIDataset
from models import CNN, CNNOneChannel
from train import Trainer, TrainerOneChannel
from test import Tester, TesterOneChannel

In [3]:
device = Device.device
print(device)

mps


In [4]:
data_path = os.path.join(os.getcwd(), "data")
labels_path = "train.csv"

batch_size = Config.batch_size
num_epochs = Config.num_epochs
learning_rate = Config.learning_rate
mean = Config.mean # mean of the entire datasaet
std = Config.std # std of the entire dataaset
image_size = 224

In [5]:
# new dataset code - adding augmented data for the minority class

'''
new dataset -> adding augmented data for the minority class to avoide unbalanced data

ToTensor -> re-scales the data to the range [0,1]

Note -> in case of pretrained models typically: Normalize(mean=0.5, std=0.5)
'''

resclaed_mean = round(mean/255,4) # re-scale the actual mean
rescaled_std = round(std/mean, 4) # re-scale the actual std

train_transforms = transforms.Compose([
    transforms.RandomRotation(degrees=10),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[resclaed_mean], std=[rescaled_std])
])

augment_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[resclaed_mean], std=[rescaled_std])
])

test_transforms = transforms.Compose([
    # transforms.Lambda(lambda img: img.astype(np.float32)),
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=[resclaed_mean], std=[rescaled_std])
])

In [6]:
'''
BalancedMRIDataset class is used.

MRIDatset -> original unbalabed dataset. ratio: 20/80
BalancedMRIDataset -> balanced dataset. ratio :42/58

augmented data is added to the minority class 

random seceltion padding is used instead of zero padding
'''

train_dataset = BalancedMRIDataset(
    data_path,
    labels_path,
    split='train',
    transform=train_transforms,
    augment_transform=augment_transforms,
    max_slices=20,
    augment=True
)

val_dataset = BalancedMRIDataset(
    data_path,
    labels_path,
    split='val',
    transform=test_transforms,
    max_slices=20
)

test_dataset = BalancedMRIDataset(
    data_path,
    labels_path,
    split='test',
    transform=test_transforms,
    max_slices=20
)

train_dl = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=32)
test_dl = DataLoader(test_dataset, batch_size=32)

In [8]:
data_, label_ = next(iter(train_dl))
data_.size()

torch.Size([32, 20, 224, 224])

In [9]:
model = CNNOneChannel().to(device=device)
model

CNNOneChannel(
  (cnn1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1))
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (cnn2): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=359552, out_features=300, bias=True)
  (fc2): Linear(in_features=300, out_features=1, bias=True)
)

In [7]:
def compute_class_weights_from_csv(csv_file_path):
    # Read the CSV file
    df = pd.read_csv(csv_file_path)

    labels = df['prediction'].values

    # Convert labels to integers if they are not already
    labels = labels.astype(int)

    # Compute class weights
    unique_labels = np.unique(labels)
    class_weights = compute_class_weight(
        class_weight='balanced', classes=unique_labels, y=labels)

    # Convert to torch tensor
    return torch.tensor(class_weights, dtype=torch.float)


# Path to your CSV file
class_weights = compute_class_weights_from_csv(labels_path)

# For binary classification, use the appropriate class weight
# Assuming binary classification with class labels 0 and 1
class_weights = class_weights[1]  # Adjust if necessary
print("Class Weights:", class_weights)

Class Weights: tensor(4.0051)


In [None]:
class FocalLoss(nn.Module):
    """
    The focal loss for fighting against class-imbalance
    """
    def __init__(self, class_weights, device, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = 1e-8  # prevent training from Nan-loss error
        self.device = device
        
        # Ensure class_weights is a tensor and moved to the correct device
        self.class_weights = class_weights.clone().detach().to(self.device) if class_weights is not None else None

    def forward(self, logits, target):
        """
        logits & target should be tensors with shape [batch_size, num_classes]
        """
        probs = torch.sigmoid(logits)
        one_subtract_probs = 1.0 - probs
        # add epsilon
        probs_new = probs + self.epsilon
        one_subtract_probs_new = one_subtract_probs + self.epsilon
        # calculate focal loss
        log_pt = target * torch.log(probs_new) + (1.0 - target) * torch.log(one_subtract_probs_new)
        pt = torch.exp(log_pt)
        focal_loss = -1.0 * (self.alpha * (1 - pt) ** self.gamma) * log_pt


        if self.class_weights is not None:
            focal_loss = focal_loss * self.class_weights
        
        return torch.mean(focal_loss)


In [10]:
# loss and optimizer
'''
dataset is almost balanced so pos_weight and FocalLoss is not chosen 
'''

# criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights).to(device)
# criterion = FocalLoss(class_weights=class_weights, device=device, alpha=0.25, gamma=2).to(device)
criterion = nn.BCEWithLogitsLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [11]:
model_name = model.__class__.__name__
model_name

'CNNOneChannel'

In [None]:
'''
images have one channel therefore feeding needs to be adjusted.
for each patient each image is feeded to the model seperately and then 
the average outputs goes to loss function

'''

trainer = TrainerOneChannel(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_dl=train_dl,
    val_dl=val_dl,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    device=device,
    num_epochs=3,
    patience=5,
    threshold=0.5,
    save_path=f"saved_models/{model_name}.pth"
)

# Start training
trainer.train()

In [10]:
model = CNNOneChannel().to(device=device)
model.load_state_dict(torch.load("saved_models/CNNOneChannel.pth"))

<All keys matched successfully>

In [11]:
tester = TesterOneChannel(
    model=model,
    test_dl=test_dl,
    test_dataset=test_dataset,
    device=device,
    threshold=0.5  # Set the threshold for binary classification
)

# Perform testing and print metrics
tester.test()

Test Accuracy: 27.4952, Precision: 0.0000, Recall: 0.0000, AUC: 0.5158, Avg Metric: 0.1719


In [25]:
torch.cuda.empty_cache()