In [1]:
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, ConcatDataset, WeightedRandomSampler, Subset, DataLoader
import os
import torchxrayvision as xrv
import torchvision.transforms as transforms
from skimage.color import rgb2gray
from skimage.transform import resize
import pydicom
from torchxrayvision.datasets import XRayCenterCrop
import pandas as pd
import wandb
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import precision_score, recall_score, f1_score, cohen_kappa_score, classification_report, confusion_matrix
import helpers, train_utils, classes
from collections import Counter
import torch
import torch.nn as nn

  from tqdm.autonotebook import tqdm


In [2]:
# --- Data locations -----------------------------------------------------------
# Single root folder for both MBOD datasets: edit this one line when the data
# moves instead of hunting down four copies of the same absolute path.
DATA_ROOT = 'C:/Users/user-pc/Masters/MSc - Project/MBOD_Datasets'

# Value passed as `target_size` to the dataset classes and to `set_target`.
TARGET_SIZE = 224

# Each dataset lives in its own folder, with its metadata spreadsheet inside it.
dicom_dir_1 = f'{DATA_ROOT}/Dataset 1'
metadata_1 = pd.read_excel(f'{dicom_dir_1}/FileDatabaseWithRadiology.xlsx')
dicom_dir_2 = f'{DATA_ROOT}/Dataset 2'
metadata_2 = pd.read_excel(f'{dicom_dir_2}/Database_Training-2024.08.28.xlsx')

# --- Dataset objects ----------------------------------------------------------
d1 = classes.DICOMDataset1(dicom_dir=dicom_dir_1, metadata_df=metadata_1, target_size=TARGET_SIZE)
d2 = classes.DICOMDataset2(dicom_dir=dicom_dir_2, metadata_df=metadata_2, target_size=TARGET_SIZE)

# --- Train / validation / test split ------------------------------------------
# helpers.split_dataset returns three index collections per dataset.
# NOTE(review): the split is computed BEFORE set_target is called below; keep
# this order unless helpers.split_dataset is confirmed not to depend on it.
train_indices_d1, val_indices_d1, test_indices_d1 = helpers.split_dataset(d1)
train_indices_d2, val_indices_d2, test_indices_d2 = helpers.split_dataset(d2)

# Keep the indices together so later cells can rebuild exactly the same subsets.
split_indices = {
    'd1': {'train': train_indices_d1, 'val': val_indices_d1, 'test': test_indices_d1},
    'd2': {'train': train_indices_d2, 'val': val_indices_d2, 'test': test_indices_d2}
}

# --- Target selection ---------------------------------------------------------
# presumably selects which metadata column the datasets yield as the label —
# TODO confirm against classes.DICOMDataset1/2.set_target
label = 'Profusion'
d1.set_target(target_label=label, target_size=TARGET_SIZE)
d2.set_target(target_label=label, target_size=TARGET_SIZE)

# --- Subsets built from the stored indices ------------------------------------
train_d1 = Subset(d1, train_indices_d1)
val_d1 = Subset(d1, val_indices_d1)
test_d1 = Subset(d1, test_indices_d1)

train_d2 = Subset(d2, train_indices_d2)
val_d2 = Subset(d2, val_indices_d2)
test_d2 = Subset(d2, test_indices_d2)

In [3]:
def create_weighted_sampler(dataset, target_label):
    """Build a WeightedRandomSampler that draws each class with equal probability.

    Every sample is weighted by the inverse of its class frequency, so rare
    classes are drawn as often as common ones.

    Args:
        dataset: Iterable yielding ``(sample, label)`` pairs, where ``label``
            is a non-negative integer class id (anything ``int()`` accepts).
        target_label: Unused; kept so existing callers keep working.

    Returns:
        A ``WeightedRandomSampler`` that draws ``len(dataset)`` indices per
        epoch (with replacement, the sampler's default).

    Raises:
        ValueError: If ``dataset`` yields no samples.
    """
    # Walk the dataset exactly once: each iteration may load and preprocess an
    # image, so reading the labels a second time would double the cost.
    labels = [int(lbl) for _, lbl in dataset]
    if not labels:
        raise ValueError("Cannot build a weighted sampler from an empty dataset")

    class_counts = np.bincount(labels)

    # Inverse-frequency weight per class. Class ids that never occur keep a
    # weight of 0 instead of triggering a divide-by-zero warning; they are
    # never looked up below anyway.
    class_weights = np.zeros(class_counts.shape, dtype=float)
    seen = class_counts > 0
    class_weights[seen] = 1.0 / class_counts[seen]

    sample_weights = [class_weights[lbl] for lbl in labels]

    # Create a weighted sampler
    sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
    return sampler

# Training subsets, rebuilt from the stored split indices.
train_d1, train_d2 = Subset(d1, train_indices_d1), Subset(d2, train_indices_d2)

# Transforms handed to classes.AugmentedDataset.
# NOTE(review): the flip uses p=1.0 (always applied) — presumably each entry
# produces its own augmented copy; confirm against classes.AugmentedDataset.
augmentations_list = [
    transforms.RandomHorizontalFlip(p=1.0),
    transforms.RandomRotation(15),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
]

# Augmented views of the two training subsets (same transform list for both).
augmented_train_d1 = classes.AugmentedDataset(
    base_dataset=train_d1,
    augmentations_list=augmentations_list,
)
augmented_train_d2 = classes.AugmentedDataset(
    base_dataset=train_d2,
    augmentations_list=augmentations_list,
)

# Loader settings shared by both datasets, so the two calls cannot drift apart.
_loader_options = dict(batch_size=32, oversam=True, target=label)

# helpers.create_dataloaders returns four loaders per dataset:
# plain train, augmented train, validation and test.
(train_loader_d1,
 train_aug_loader_d1,
 val_loader_d1,
 test_loader_d1) = helpers.create_dataloaders(
    train_d1, augmented_train_d1, val_d1, test_d1, **_loader_options
)

(train_loader_d2,
 train_aug_loader_d2,
 val_loader_d2,
 test_loader_d2) = helpers.create_dataloaders(
    train_d2, augmented_train_d2, val_d2, test_d2, **_loader_options
)

Sampler: <torch.utils.data.sampler.WeightedRandomSampler object at 0x000001AAF3B5DD50>
Sampler: <torch.utils.data.sampler.WeightedRandomSampler object at 0x000001AAF3E85F90>


In [11]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torchxrayvision DenseNet loaded with the "densenet121-res224-all" weights.
model = xrv.models.DenseNet(weights="densenet121-res224-all")

# Replace the classification head with the project's own classifier.
model.classifier = classes.BaseClassifier(in_features=1024)

# Move the model to the device only AFTER the new head is attached. The
# original moved the backbone first, which left the freshly created classifier
# on the CPU and causes a device mismatch on a CUDA forward pass.
model = model.to(device)

# NOTE(review): torchxrayvision's pretrained DenseNet may post-process its
# outputs using `op_threshs` calibrated for the original head — confirm whether
# it should be reset (e.g. `model.op_threshs = None`) for the custom classifier.

# Random training-time augmentations, applied in this order:
# horizontal flip (50% chance), rotation of up to 15 degrees, Gaussian blur.
augmentations = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
])
