In [None]:
import os

import numpy as np
import pandas as pd
from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader

from config import Config, Device
from datasets import MRIDataset, BalancedMRIDataset
from models import CNN, CNNOneChannel
from train import Trainer, TrainerOneChannel
from test_file import Tester

In [2]:
# Accelerator selected by config.Device; echoed so the run log records it.
device = Device.device
print(str(device))

mps


In [3]:
# Paths and hyper-parameters; the numeric training settings come from config.Config.
data_path = os.path.join(os.getcwd(), "data")  # root directory handed to the datasets
labels_path = "train.csv"                      # labels CSV, resolved relative to the CWD

batch_size, num_epochs, learning_rate = (
    Config.batch_size,
    Config.num_epochs,
    Config.learning_rate,
)
image_size = 224  # intended resize target (pixels, square)

In [20]:
# older code - not adding augmented data for the minority class 
# train_transforms = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Resize((image_size, image_size)),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomRotation(10),
#     # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

# test_transforms = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Resize((image_size, image_size)),
#     # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

In [4]:
# New dataset code: augmented samples are added for the minority class.
# Every spatial size derives from `image_size` (config cell) instead of repeating 224.

# Deterministic preprocessing applied to every training sample.
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((image_size, image_size)),
])

# Random augmentations handed to BalancedMRIDataset via `augment_transform`;
# presumably applied only to the extra minority-class samples (confirm in datasets.py).
augment_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(10),  # uniform rotation in [-10, 10] degrees
    transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
])

# Evaluation preprocessing: same as the training base pipeline, no randomness.
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((image_size, image_size)),
])

In [5]:
'''
The BalancedMRIDataset class is used.

MRIDataset -> original, unbalanced dataset. Class ratio (minority/majority): 20/80
BalancedMRIDataset -> balanced dataset. Class ratio (minority/majority): 42/58

Augmented samples are added to the minority class.
'''


# Slices per MRI volume, shared by every split so they cannot drift apart
# (presumably why the CNN's first conv layer has 20 input channels).
MAX_SLICES = 20

train_dataset = BalancedMRIDataset(
    data_path,
    labels_path,
    split='train',
    transform=train_transforms,
    augment_transform=augment_transforms,
    max_slices=MAX_SLICES,
    augment=True,  # augmentation is only requested for the training split
)

val_dataset = BalancedMRIDataset(
    data_path,
    labels_path,
    split='val',
    transform=test_transforms,
    max_slices=MAX_SLICES,
)

test_dataset = BalancedMRIDataset(
    data_path,
    labels_path,
    split='test',
    transform=test_transforms,
    max_slices=MAX_SLICES,
)

# Use the configured batch size (config cell) rather than a hard-coded 32.
train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=batch_size)
test_dl = DataLoader(test_dataset, batch_size=batch_size)

In [6]:
from collections import Counter

# Tally training labels by walking the whole dataset once.
# NOTE: every sample is loaded (and transformed) just to read its label, so this can be slow.
label_counts = Counter(sample_label.item() for _image, sample_label in train_dataset)

print("Label distribution in train_dataset:")
for label, count in label_counts.items():
    print(f"Label {label}: {count} samples")


# Label distribution in train_dataset:
# Label 0: 1645 samples
# Label 1: 1175 samples


Label distribution in train_dataset:
Label 0: 1645 samples
Label 1: 1175 samples


In [None]:
'''
Since the dataset is so imbalanced, augmented samples could be appended to the existing
training data up front, but that turned out not to be a good idea:
- the transforms are designed to be applied to arrays and images, not tensors
- the transforms work on 2D images, so applying them to a 3D (multi-slice) tensor is confusing

Doing the augmentation inside the dataset class (BalancedMRIDataset) instead is also more
efficient in terms of computation and memory, so that is the approach used.
'''

In [8]:
# Build the baseline CNN on the selected device and display its layer summary.
model = CNN()
model = model.to(device=device)
model

CNN(
  (cnn1): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (cnn2): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=359552, out_features=300, bias=True)
  (fc2): Linear(in_features=300, out_features=1, bias=True)
)

In [10]:
def compute_class_weights_from_csv(csv_file_path, label_column="prediction"):
    """Compute scikit-learn 'balanced' class weights from a labels CSV.

    Each class c gets weight ``n_samples / (n_classes * n_c)``, so rarer
    classes receive larger weights.

    Args:
        csv_file_path: path to a CSV file containing a label column.
        label_column: name of the column holding the class labels
            (default ``"prediction"``); values are cast to ``int``.

    Returns:
        1-D float32 tensor with one weight per class, ordered by sorted label.

    Raises:
        KeyError: if ``label_column`` is not a column of the CSV.
        ValueError: if the CSV contains no labels.
    """
    df = pd.read_csv(csv_file_path)
    if label_column not in df.columns:
        raise KeyError(f"column {label_column!r} not found in {csv_file_path}")

    # Cast to integers if they are not already (e.g. stored as floats).
    labels = df[label_column].to_numpy().astype(int)
    if labels.size == 0:
        raise ValueError(f"no labels found in {csv_file_path}")

    unique_labels = np.unique(labels)  # sorted, so weights line up with label order
    class_weights = compute_class_weight(
        class_weight='balanced', classes=unique_labels, y=labels)

    return torch.tensor(class_weights, dtype=torch.float)


# NOTE(review): weights come from the whole labels CSV (presumably all splits, before
# augmentation), while train_dataset is re-balanced by augmentation (~42/58 per the label
# count above), so these weights may over-correct.
class_weights_per_class = compute_class_weights_from_csv(labels_path)  # presumably [w_0, w_1]

# Weight of the positive class (label 1); kept as `class_weights` for the loss cell.
class_weights = class_weights_per_class[1]

# BCEWithLogitsLoss documents pos_weight as n_negative / n_positive. With 'balanced'
# weights w_c = N / (2 * n_c) that ratio is w_1 / w_0, not w_1 alone.
pos_weight = class_weights_per_class[1] / class_weights_per_class[0]

print("Class Weights:", class_weights)
print("pos_weight (n_neg / n_pos):", pos_weight)

Class Weights: tensor(4.0051)


In [11]:
class FocalLoss(nn.Module):
    """
    Binary focal loss computed on raw logits, for fighting class imbalance.

    Per element:  loss = -alpha * w_t * (1 - p_t) ** gamma * log(p_t)
    where p_t is the predicted probability of the true target and w_t is the
    weight of that target's class (1 when ``class_weights`` is None).

    Args:
        class_weights: None; a single weight for the positive class (label 1),
            with negatives weighted 1; or two weights ``[w_neg, w_pos]``.
            A tensor or a plain Python number / sequence is accepted.
        device: device the class-weight tensor is moved to.
        alpha: global scale factor applied to every element. NOTE: this is a
            uniform multiplier, not the per-class alpha_t of Lin et al. (2017).
        gamma: focusing parameter; ``gamma=0`` reduces to (weighted) BCE.

    Raises:
        ValueError: if ``class_weights`` has neither 1 nor 2 elements.
    """
    def __init__(self, class_weights, device, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        # Retained for backward compatibility only: log(p_t) is now computed
        # stably from the logits, so no epsilon guard is needed.
        self.epsilon = 1e-8
        self.device = device

        if class_weights is None:
            self.class_weights = None
        else:
            # float32 copy, detached from any autograd graph, on the target device
            weights = torch.as_tensor(class_weights, dtype=torch.float).clone().detach()
            if weights.numel() not in (1, 2):
                raise ValueError(
                    f"class_weights must have 1 or 2 elements, got {weights.numel()}"
                )
            self.class_weights = weights.to(self.device)

    def forward(self, logits, target):
        """
        logits & target should be tensors of the same shape (e.g. [batch_size, 1]);
        target holds 0/1 labels (soft labels also work). Returns the mean loss.
        """
        target = target.to(dtype=logits.dtype)
        # log(p_t) == -BCE(logits, target). The logits-based BCE is numerically
        # stable; sigmoid -> log saturates (capped loss, zero gradient) once
        # sigmoid rounds to exactly 0 or 1 for large |logits|.
        log_pt = -nn.functional.binary_cross_entropy_with_logits(
            logits, target, reduction="none"
        )
        pt = torch.exp(log_pt)
        focal_loss = -1.0 * (self.alpha * (1.0 - pt) ** self.gamma) * log_pt

        if self.class_weights is not None:
            # Weight every element by the weight of its own target class, so the
            # weights re-balance the classes instead of scaling the whole loss.
            if self.class_weights.numel() == 1:
                w_neg = 1.0
                w_pos = self.class_weights.reshape(())
            else:
                w_neg, w_pos = self.class_weights[0], self.class_weights[1]
            focal_loss = focal_loss * (target * w_pos + (1.0 - target) * w_neg)

        return torch.mean(focal_loss)


In [7]:
# Loss and optimizer.
# Alternatives tried earlier (note: `class_weights` is the 'balanced' weight of class 1,
# which is not the n_neg / n_pos ratio BCEWithLogitsLoss documents for pos_weight):
#   nn.BCEWithLogitsLoss(pos_weight=class_weights)
#   FocalLoss(class_weights=class_weights, device=device, alpha=0.25, gamma=2)
criterion = nn.BCEWithLogitsLoss()
criterion.to(device)  # Module.to moves the module in place and returns it
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [8]:
# Class name of the current model; used below to name the checkpoint file.
model_name = type(model).__name__
model_name

'CNNOneChannel'

In [None]:
# Train with the epoch budget from the config cell (previously hard-coded to 10).
# NOTE(review): threshold=0 here vs threshold=0.5 in the Tester cell is consistent only if
# Trainer compares raw logits while Tester compares sigmoid probabilities — confirm in
# train.py / test_file.py.
trainer = Trainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_dl=train_dl,
    val_dl=val_dl,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    device=device,
    num_epochs=num_epochs,
    patience=5,  # presumably early-stopping patience (epochs) — confirm in train.py
    threshold=0,
    save_path=f"saved_models/{model_name}.pth",
)

# Start training
trainer.train()

In [60]:
# Release cached allocator memory on whichever accelerator is in use;
# torch.cuda.empty_cache() alone has no effect on the "mps" device.
_device_type = torch.device(device).type
if _device_type == "cuda":
    torch.cuda.empty_cache()
elif _device_type == "mps" and hasattr(torch, "mps"):
    torch.mps.empty_cache()

In [None]:
# Rebuild the architecture and restore the checkpoint the Trainer cell writes
# (saved_models/<model class name>.pth). The old hard-coded "model_parameters1.pth"
# is never produced by a fresh Restart & Run All.
# map_location lets a checkpoint saved on one device (cuda/mps/cpu) load on another.
model = CNN().to(device=device)
checkpoint_path = f"saved_models/{type(model).__name__}.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

In [None]:
# Evaluate the restored model on the held-out test split.
# NOTE(review): the Trainer cell uses threshold=0 — confirm whether each class thresholds
# logits or probabilities before comparing their metrics.
tester = Tester(
    test_dataset=test_dataset,
    test_dl=test_dl,
    model=model,
    device=device,
    threshold=0.5,  # decision threshold for binary classification
)

tester.test()  # run testing and report the metrics