ResNet-50 Training

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, WeightedRandomSampler, Subset
from torchvision import datasets, models, transforms
import time
import copy
import os
from PIL import Image, ImageFile
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

# Let PIL decode partially written / truncated files instead of raising mid-load.
ImageFile.LOAD_TRUNCATED_IMAGES = True

def ImageLoader(path):
    """Load the image at ``path`` as an RGB ``PIL.Image``.

    Best-effort: if the file cannot be opened or decoded, the error is printed
    and a black 224x224 RGB placeholder is returned so one bad file does not
    abort a DataLoader epoch. NOTE(review): the placeholder keeps the sample's
    label, so reported paths should be removed from the dataset.

    Args:
        path: Filesystem path of the image.

    Returns:
        PIL.Image.Image: the decoded image converted to RGB, or the placeholder.
    """
    try:
        # The context manager closes the file handle even if decoding fails
        # part-way; convert() returns a fully loaded, independent image.
        with Image.open(path) as img:
            return img.convert('RGB')
    except Exception as e:  # deliberate catch-all: PIL raises several exception types
        print(f"Skipping corrupt image: {path} ({e})")
        return Image.new('RGB', (224, 224))

# Use the first CUDA GPU ("cuda:0") when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ImageNet channel mean/std, applied after ToTensor in every pipeline.
_imagenetNormalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Per-split preprocessing:
#   'train'        -> random crop / flip / rotation plus mild brightness-contrast jitter
#   'val', 'test'  -> one shared deterministic resize + centre-crop pipeline
dataTransforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        _imagenetNormalize,
    ]),
}
dataTransforms['val'] = dataTransforms['test'] = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    _imagenetNormalize,
])

def loadDatasets():
    """Pool both chest X-ray datasets and split them 70/20/10 into train/val/test.

    The ``train`` and ``test`` folders of ``/content/chest_xray`` and
    ``/content/chest_xray2`` are merged into one list of ``(path, label)``
    samples (NORMAL=0, PNEUMONIA=1), then re-split with label stratification
    and ``random_state=42``.

    NOTE(review): if the two source folders share images, pooling them can put
    duplicates on both sides of a split; worth confirming.

    Returns:
        tuple: ``(trainDataset, valDataset, testDataset, sampler)``.

        - The three datasets are ``Subset`` views over the pooled samples.
        - Each is backed by its own dataset object, so it keeps its own
          transform (``dataTransforms['train']``, ``['val']``, ``['test']``).
        - ``sampler`` is a ``WeightedRandomSampler`` over the training subset.
          It draws ``len(trainDataset)`` indices with replacement, weighted by
          inverse class frequency.
    """
    ds1Path = '/content/chest_xray'
    ds2Path = '/content/chest_xray2'

    class PneumoniaDataset(torch.utils.data.Dataset):
        """``(image, label)`` pairs read from ``rootDir/NORMAL`` and ``rootDir/PNEUMONIA``.

        When ``samples`` is given it is used as the ``(path, label)`` list and
        no directory is scanned.
        """

        def __init__(self, rootDir, transform=None, samples=None):
            self.rootDir = rootDir
            self.transform = transform
            self.classes = ['NORMAL', 'PNEUMONIA']
            self.classToIdx = {'NORMAL': 0, 'PNEUMONIA': 1}

            if samples is not None:
                self.samples = samples
                return

            self.samples = []
            for className in self.classes:
                classDir = os.path.join(rootDir, className)
                if os.path.isdir(classDir):
                    for filename in os.listdir(classDir):
                        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                            path = os.path.join(classDir, filename)
                            label = self.classToIdx[className]
                            self.samples.append((path, label))

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            path, label = self.samples[idx]
            img = ImageLoader(path)
            if self.transform:
                img = self.transform(img)
            return img, label

    # Scan all source folders (only their sample lists are used)
    ds1Train = PneumoniaDataset(os.path.join(ds1Path, 'train'))
    ds2Train = PneumoniaDataset(os.path.join(ds2Path, 'train'))
    ds1Test = PneumoniaDataset(os.path.join(ds1Path, 'test'))
    ds2Test = PneumoniaDataset(os.path.join(ds2Path, 'test'))

    # Combine all samples from both datasets
    allSamples = ds1Train.samples + ds2Train.samples + ds1Test.samples + ds2Test.samples
    allLabels = [label for _, label in allSamples]

    # 70/20/10 split: 70% train vs 30% rest, then the rest split 2:1 into test:val
    trainIdx, testValIdx = train_test_split(
        range(len(allSamples)),
        test_size=0.3,
        stratify=allLabels,
        random_state=42
    )

    testIdx, valIdx = train_test_split(
        testValIdx,
        test_size=1/3,
        stratify=[allLabels[i] for i in testValIdx],
        random_state=42
    )

    def makeSplit(indices, transform):
        # One dataset object per split. If the Subsets shared a single dataset,
        # each `.transform` assignment would overwrite the previous one. Every
        # split, training included, would then use whichever transform was
        # assigned last (the deterministic test pipeline), which silently
        # disables augmentation.
        return Subset(PneumoniaDataset('', transform=transform, samples=allSamples), indices)

    trainDataset = makeSplit(trainIdx, dataTransforms['train'])
    valDataset = makeSplit(valIdx, dataTransforms['val'])
    testDataset = makeSplit(testIdx, dataTransforms['test'])

    # Inverse-frequency per-sample weights for balanced sampling
    trainLabels = [allLabels[i] for i in trainIdx]
    classCounts = np.bincount(trainLabels)
    print(f"Class distribution in training set: {classCounts}")

    weights = 1. / torch.tensor(classCounts, dtype=torch.float32)
    sampleWeights = weights[trainLabels]
    sampler = WeightedRandomSampler(
        weights=sampleWeights,
        num_samples=len(sampleWeights),
        replacement=True
    )

    # Print dataset sizes
    totalSize = len(allSamples)
    print(f"\nDataset splits:")
    print(f"Train: {len(trainIdx)} ({len(trainIdx)/totalSize:.1%})")
    print(f"Val: {len(valIdx)} ({len(valIdx)/totalSize:.1%})")
    print(f"Test: {len(testIdx)} ({len(testIdx)/totalSize:.1%})")

    return trainDataset, valDataset, testDataset, sampler

trainDataset, valDataset, testDataset, sampler = loadDatasets()

dataloaders = {
    'train': DataLoader(trainDataset, batch_size=32, sampler=sampler, num_workers=4),
    'val': DataLoader(valDataset, batch_size=32, shuffle=False, num_workers=4),
    'test': DataLoader(testDataset, batch_size=32, shuffle=False, num_workers=4)
}

# Initialize the model: ImageNet-pretrained ResNet-50 with a new 2-logit head
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
numFeatures = model.fc.in_features
model.fc = nn.Linear(numFeatures, 2)
model = model.to(device)

# Loss with inverse-frequency class weights.
# Labels are read from the Subset's (path, label) samples via its indices.
# Iterating `trainDataset` itself would load and augment every training image
# just to discard the pixels.
# NOTE(review): the WeightedRandomSampler already balances classes per batch,
# so weighting the loss as well double-counts the imbalance correction.
# Consider keeping only one of the two.
trainLabels = [trainDataset.dataset.samples[i][1] for i in trainDataset.indices]
classCounts = np.bincount(trainLabels)
classWeights = torch.tensor([1.0/classCounts[0], 1.0/classCounts[1]], dtype=torch.float32).to(device)
criterion = nn.CrossEntropyLoss(weight=classWeights)

# Optimizer and scheduler: SGD with momentum; LR multiplied by 0.1 every 7 epochs
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Training function w/ mixed precision
def trainModel(model, criterion, optimizer, scheduler, numEpochs=25):
    """Train ``model`` and return it with its best-validation weights loaded.

    Each epoch does one optimisation pass over ``dataloaders['train']`` and one
    evaluation pass over ``dataloaders['val']`` (module-level globals), then
    calls ``scheduler.step()``. On CUDA the forward pass runs under autocast
    with a ``GradScaler`` (mixed precision). On any other device both are
    disabled and training runs in full precision.

    Args:
        model: Network already on ``device``.
        criterion: Loss taking ``(logits, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Epoch-based LR scheduler whose ``step()`` takes no arguments.
        numEpochs: Number of epochs to run.

    Returns:
        ``model`` with the state dict from the epoch with the highest validation
        accuracy loaded. These are the initial weights if accuracy never exceeds 0.
    """
    since = time.time()
    bestModelWeights = copy.deepcopy(model.state_dict())
    bestAcc = 0.0
    # AMP only on CUDA: enabled=False elsewhere makes the full-precision
    # fallback explicit instead of relying on "CUDA not available" warnings.
    useAmp = device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=useAmp)

    for epoch in range(numEpochs):
        print(f'Epoch {epoch}/{numEpochs-1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            isTrain = phase == 'train'
            model.train(isTrain)  # train(False) is equivalent to eval()

            runningLoss = 0.0
            runningCorrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(isTrain):
                    with torch.cuda.amp.autocast(enabled=useAmp):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                    if isTrain:
                        # Scale the loss so small fp16 gradients do not underflow;
                        # scaler.step() unscales before the optimizer update.
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()

                runningLoss += loss.item() * inputs.size(0)
                runningCorrects += (preds == labels).sum().item()

            # Averages are per dataset sample. This assumes each loader yields
            # len(dataset) samples per epoch, which is true for the
            # WeightedRandomSampler built in loadDatasets.
            datasetSize = len(dataloaders[phase].dataset)
            epochLoss = runningLoss / datasetSize
            epochAcc = runningCorrects / datasetSize

            print(f'{phase} Loss: {epochLoss:.4f} Acc: {epochAcc:.4f}')

            if not isTrain and epochAcc > bestAcc:
                bestAcc = epochAcc
                bestModelWeights = copy.deepcopy(model.state_dict())

        scheduler.step()
        print()

    timeElapsed = time.time() - since
    print(f'Training complete in {timeElapsed//60:.0f}m {timeElapsed%60:.0f}s')
    print(f'Best val Acc: {bestAcc:.4f}')

    model.load_state_dict(bestModelWeights)
    return model

print("Training ResNet-50...")
# 25 epochs; trainModel restores the best-validation weights before returning.
model = trainModel(model, criterion, optimizer, scheduler, numEpochs=25)

# Persist only the learned parameters (the state dict) under /content/drive/MyDrive.
savePath = "/content/drive/MyDrive/pneumonia_resnet50.pth"
torch.save(obj=model.state_dict(), f=savePath)
print(f"Model saved to {savePath}")

# Evaluation function
def evaluateModel(model, dataloader):
    """Run ``model`` over ``dataloader`` in eval mode and collect per-sample outputs.

    Args:
        model: Classifier producing one logit per class; inputs are moved to
            the module-level ``device``.
        dataloader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        tuple: ``(labels, preds, probs)``, three lists in dataloader order:

        - true labels;
        - argmax predictions;
        - softmax probability vectors (NumPy values).
    """
    model.eval()
    labelsOut, predsOut, probsOut = [], [], []

    with torch.no_grad():
        for batchInputs, batchLabels in dataloader:
            logits = model(batchInputs.to(device))
            batchPreds = torch.max(logits, 1)[1]
            batchProbs = torch.nn.functional.softmax(logits, dim=1)

            predsOut.extend(batchPreds.cpu().numpy())
            labelsOut.extend(batchLabels.cpu().numpy())
            probsOut.extend(batchProbs.cpu().numpy())

    return labelsOut, predsOut, probsOut

# Score the best-validation model on the held-out test split.
testLabels, testPreds, testProbs = evaluateModel(model, dataloaders['test'])

print("\nClassification Report:",
      classification_report(testLabels, testPreds, target_names=['NORMAL', 'PNEUMONIA']),
      sep="\n")

print("\nConfusion Matrix:", confusion_matrix(testLabels, testPreds), sep="\n")

# Stack the per-sample softmax vectors; column 1 (PNEUMONIA) is the ROC AUC score.
testProbsArray = np.array(testProbs)
print(f"\nROC AUC: {roc_auc_score(testLabels, testProbsArray[:, 1]):.4f}")


DenseNet-121 Training

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, WeightedRandomSampler, Subset
from torchvision import models, transforms
import time
import copy
import os
from PIL import Image, ImageFile
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

# Let PIL decode partially written / truncated files instead of raising mid-load.
ImageFile.LOAD_TRUNCATED_IMAGES = True

def imageLoader(path):
    """Load the image at ``path`` as an RGB ``PIL.Image``.

    Best-effort: if the file cannot be opened or decoded, the error is printed
    and a black 224x224 RGB placeholder is returned so one bad file does not
    abort a DataLoader epoch. NOTE(review): the placeholder keeps the sample's
    label, so reported paths should be removed from the dataset.

    Args:
        path: Filesystem path of the image.

    Returns:
        PIL.Image.Image: the decoded image converted to RGB, or the placeholder.
    """
    try:
        # The context manager closes the file handle even if decoding fails
        # part-way; convert() returns a fully loaded, independent image.
        with Image.open(path) as img:
            return img.convert('RGB')
    except Exception as e:  # deliberate catch-all: PIL raises several exception types
        print(f"Skipping corrupt image: {path} ({e})")
        return Image.new('RGB', (224, 224))

# Use the first CUDA GPU ("cuda:0") when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ImageNet channel mean/std, applied after ToTensor in every pipeline.
_imagenetNormalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Per-split preprocessing:
#   'train'        -> random crop / flip / rotation plus mild brightness-contrast jitter
#   'val', 'test'  -> one shared deterministic resize + centre-crop pipeline
dataTransforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        _imagenetNormalize,
    ]),
}
dataTransforms['val'] = dataTransforms['test'] = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    _imagenetNormalize,
])

def loadDatasets():
    """Pool both chest X-ray datasets and split them 70/20/10 into train/val/test.

    The ``train`` and ``test`` folders of ``/content/chest_xray`` and
    ``/content/chest_xray2`` are merged into one list of ``(path, label)``
    samples (NORMAL=0, PNEUMONIA=1), then re-split with label stratification
    and ``random_state=42``.

    NOTE(review): if the two source folders share images, pooling them can put
    duplicates on both sides of a split; worth confirming.

    Returns:
        tuple: ``(trainDataset, valDataset, testDataset, sampler)``.

        - The three datasets are ``Subset`` views over the pooled samples.
        - Each is backed by its own dataset object, so it keeps its own
          transform (``dataTransforms['train']``, ``['val']``, ``['test']``).
        - ``sampler`` is a ``WeightedRandomSampler`` over the training subset.
          It draws ``len(trainDataset)`` indices with replacement, weighted by
          inverse class frequency.
    """
    ds1Path = '/content/chest_xray'
    ds2Path = '/content/chest_xray2'

    class PneumoniaDataset(torch.utils.data.Dataset):
        """``(image, label)`` pairs read from ``rootDir/NORMAL`` and ``rootDir/PNEUMONIA``.

        When ``samples`` is given it is used as the ``(path, label)`` list and
        no directory is scanned.
        """

        def __init__(self, rootDir, transform=None, samples=None):
            self.rootDir = rootDir
            self.transform = transform
            self.classes = ['NORMAL', 'PNEUMONIA']
            self.classToIdx = {'NORMAL': 0, 'PNEUMONIA': 1}

            if samples is not None:
                self.samples = samples
                return

            self.samples = []
            for className in self.classes:
                classDir = os.path.join(rootDir, className)
                if os.path.isdir(classDir):
                    for filename in os.listdir(classDir):
                        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                            path = os.path.join(classDir, filename)
                            label = self.classToIdx[className]
                            self.samples.append((path, label))

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            path, label = self.samples[idx]
            img = imageLoader(path)

            if self.transform:
                img = self.transform(img)

            return img, label

    # Scan all source folders (only their sample lists are used)
    ds1Train = PneumoniaDataset(os.path.join(ds1Path, 'train'))
    ds2Train = PneumoniaDataset(os.path.join(ds2Path, 'train'))
    ds1Test = PneumoniaDataset(os.path.join(ds1Path, 'test'))
    ds2Test = PneumoniaDataset(os.path.join(ds2Path, 'test'))

    allSamples = ds1Train.samples + ds2Train.samples + ds1Test.samples + ds2Test.samples
    allLabels = [label for _, label in allSamples]

    # 70/20/10 split: 70% train vs 30% rest, then the rest split 2:1 into test:val
    trainIdx, testValIdx = train_test_split(
        range(len(allSamples)),
        test_size=0.3,
        stratify=allLabels,
        random_state=42
    )

    testIdx, valIdx = train_test_split(
        testValIdx,
        test_size=1/3,
        stratify=[allLabels[i] for i in testValIdx],
        random_state=42
    )

    def makeSplit(indices, transform):
        # One dataset object per split. If the Subsets shared a single dataset,
        # each `.transform` assignment would overwrite the previous one. Every
        # split, training included, would then use whichever transform was
        # assigned last (the deterministic test pipeline), which silently
        # disables augmentation.
        return Subset(PneumoniaDataset('', transform=transform, samples=allSamples), indices)

    trainDataset = makeSplit(trainIdx, dataTransforms['train'])
    valDataset = makeSplit(valIdx, dataTransforms['val'])
    testDataset = makeSplit(testIdx, dataTransforms['test'])

    # Inverse-frequency per-sample weights for balanced sampling
    trainLabels = [allLabels[i] for i in trainIdx]
    classCounts = np.bincount(trainLabels)
    print(f"Class distribution in training set: {classCounts}")

    weights = 1. / torch.tensor(classCounts, dtype=torch.float32)
    sampleWeights = weights[trainLabels]
    sampler = WeightedRandomSampler(
        weights=sampleWeights,
        num_samples=len(sampleWeights),
        replacement=True
    )

    totalSize = len(allSamples)
    print(f"\nDataset splits:")
    print(f"Train: {len(trainIdx)} ({len(trainIdx)/totalSize:.1%})")
    print(f"Val: {len(valIdx)} ({len(valIdx)/totalSize:.1%})")
    print(f"Test: {len(testIdx)} ({len(testIdx)/totalSize:.1%})")

    return trainDataset, valDataset, testDataset, sampler

# Load data
trainDataset, valDataset, testDataset, sampler = loadDatasets()

# Create dataloaders
dataLoaders = {
    'train': DataLoader(trainDataset, batch_size=32, sampler=sampler, num_workers=4),
    'val': DataLoader(valDataset, batch_size=32, shuffle=False, num_workers=4),
    'test': DataLoader(testDataset, batch_size=32, shuffle=False, num_workers=4)
}

# Initialize DenseNet-121 model (ImageNet-pretrained weights)
model = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)

# Replace the original classifier with a linear layer producing the 2 class logits
numFeatures = model.classifier.in_features
model.classifier = nn.Linear(numFeatures, 2)
model = model.to(device)

# Loss function with inverse-frequency class weighting.
# Labels are read from the Subset's (path, label) samples via its indices.
# Iterating `trainDataset` itself would load and augment every training image
# just to discard the pixels.
# NOTE(review): the WeightedRandomSampler already balances classes per batch,
# so weighting the loss as well double-counts the imbalance correction.
trainLabels = [trainDataset.dataset.samples[i][1] for i in trainDataset.indices]
classCounts = np.bincount(trainLabels)
classWeights = torch.tensor([1.0/classCounts[0], 1.0/classCounts[1]], dtype=torch.float32).to(device)
criterion = nn.CrossEntropyLoss(weight=classWeights)

# Optimizer and scheduler: SGD with momentum; LR multiplied by 0.1 every 7 epochs
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Training function
def trainModel(model, criterion, optimizer, scheduler, numEpochs=25):
    """Train ``model`` and return it with its best-validation weights loaded.

    Each epoch does one optimisation pass over ``dataLoaders['train']`` and one
    evaluation pass over ``dataLoaders['val']`` (module-level globals), then
    calls ``scheduler.step()``. On CUDA the forward pass runs under autocast
    with a ``GradScaler`` (mixed precision). On any other device both are
    disabled and training runs in full precision.

    Args:
        model: Network already on ``device``.
        criterion: Loss taking ``(logits, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Epoch-based LR scheduler whose ``step()`` takes no arguments.
        numEpochs: Number of epochs to run.

    Returns:
        ``model`` with the state dict from the epoch with the highest validation
        accuracy loaded. These are the initial weights if accuracy never exceeds 0.
    """
    since = time.time()
    bestModelWeights = copy.deepcopy(model.state_dict())
    bestAcc = 0.0
    # AMP only on CUDA: enabled=False elsewhere makes the full-precision
    # fallback explicit instead of relying on "CUDA not available" warnings.
    useAmp = device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=useAmp)

    for epoch in range(numEpochs):
        print(f'Epoch {epoch}/{numEpochs-1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            isTrain = phase == 'train'
            model.train(isTrain)  # train(False) is equivalent to eval()

            runningLoss = 0.0
            runningCorrects = 0

            for inputs, labels in dataLoaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(isTrain):
                    with torch.cuda.amp.autocast(enabled=useAmp):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                    if isTrain:
                        # Scale the loss so small fp16 gradients do not underflow;
                        # scaler.step() unscales before the optimizer update.
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()

                runningLoss += loss.item() * inputs.size(0)
                runningCorrects += (preds == labels).sum().item()

            # Averages are per dataset sample. This assumes each loader yields
            # len(dataset) samples per epoch, which is true for the
            # WeightedRandomSampler built in loadDatasets.
            datasetSize = len(dataLoaders[phase].dataset)
            epochLoss = runningLoss / datasetSize
            epochAcc = runningCorrects / datasetSize

            print(f'{phase} Loss: {epochLoss:.4f} Acc: {epochAcc:.4f}')

            if not isTrain and epochAcc > bestAcc:
                bestAcc = epochAcc
                bestModelWeights = copy.deepcopy(model.state_dict())

        scheduler.step()
        print()

    timeElapsed = time.time() - since
    print(f'Training complete in {timeElapsed//60:.0f}m {timeElapsed%60:.0f}s')
    print(f'Best val Acc: {bestAcc:.4f}')

    model.load_state_dict(bestModelWeights)
    return model


print("Training DenseNet-121...")
# 25 epochs; trainModel restores the best-validation weights before returning.
model = trainModel(model, criterion, optimizer, scheduler, numEpochs=25)

# Save model: only the learned parameters (the state dict) are written.
savePath = "/content/drive/MyDrive/pneumonia_densenet121.pth"
torch.save(obj=model.state_dict(), f=savePath)
print(f"Model saved to {savePath}")

# Evaluation function
def evaluateModel(model, dataLoader):
    """Run ``model`` over ``dataLoader`` in eval mode and collect per-sample outputs.

    Args:
        model: Classifier producing one logit per class; inputs are moved to
            the module-level ``device``.
        dataLoader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        tuple: ``(labels, preds, probs)``, three lists in loader order:

        - true labels;
        - argmax predictions;
        - softmax probability vectors (NumPy values).
    """
    model.eval()
    labelsOut, predsOut, probsOut = [], [], []

    with torch.no_grad():
        for batchInputs, batchLabels in dataLoader:
            logits = model(batchInputs.to(device))
            batchPreds = torch.max(logits, 1)[1]
            batchProbs = torch.nn.functional.softmax(logits, dim=1)

            predsOut.extend(batchPreds.cpu().numpy())
            labelsOut.extend(batchLabels.cpu().numpy())
            probsOut.extend(batchProbs.cpu().numpy())

    return labelsOut, predsOut, probsOut

# Score the best-validation model on the held-out test split.
testLabels, testPreds, testProbs = evaluateModel(model, dataLoaders['test'])

print("\nClassification Report:",
      classification_report(testLabels, testPreds, target_names=['NORMAL', 'PNEUMONIA']),
      sep="\n")

print("\nConfusion Matrix:", confusion_matrix(testLabels, testPreds), sep="\n")

# Stack the per-sample softmax vectors; column 1 (PNEUMONIA) is the ROC AUC score.
testProbsArray = np.array(testProbs)
print(f"\nROC AUC: {roc_auc_score(testLabels, testProbsArray[:, 1]):.4f}")


PneumoniaNet (Custom CNN Architecture) and ResNet-18 Transfer Learning (default: `use_transfer_learning = True`)

In [None]:
import torch
from torch.utils.data import ConcatDataset, DataLoader, WeightedRandomSampler, Subset
from torchvision import datasets, transforms, models
import torch.nn as nn
import torch.optim as optim
import os
from PIL import Image, ImageFile
import numpy as np
import time
import copy
from sklearn.model_selection import train_test_split

# Let PIL decode partially written / truncated files instead of raising mid-load.
ImageFile.LOAD_TRUNCATED_IMAGES = True

def imageLoader(path):
    """Load the image at ``path`` as an RGB ``PIL.Image``.

    Best-effort: if the file cannot be opened or decoded, the error is printed
    and a black 224x224 RGB placeholder is returned so one bad file does not
    abort a DataLoader epoch. NOTE(review): the placeholder keeps the sample's
    label, so reported paths should be removed from the dataset.

    Args:
        path: Filesystem path of the image.

    Returns:
        PIL.Image.Image: the decoded image converted to RGB, or the placeholder.
    """
    try:
        # The context manager closes the file handle even if decoding fails
        # part-way; convert() returns a fully loaded, independent image.
        with Image.open(path) as img:
            return img.convert('RGB')
    except Exception as e:  # deliberate catch-all: PIL raises several exception types
        print(f"Skipping corrupt image: {path} ({e})")
        return Image.new('RGB', (224, 224))

# Use the first CUDA GPU ("cuda:0") when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ImageNet channel mean/std, applied after ToTensor in both pipelines.
_imagenet_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# 'train': random crop, flip, rotation, brightness/contrast jitter and small
#          translations for augmentation.
# 'val':   deterministic resize + centre crop (also reused for the test split).
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        _imagenet_normalize,
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        _imagenet_normalize,
    ]),
}

def load_datasets():
    """Pool both chest X-ray datasets and split them 70/20/10 into train/val/test.

    The ``train`` and ``test`` folders of ``/content/chest_xray`` and
    ``/content/chest_xray2`` are merged into one ``(path, label)`` list
    (NORMAL=0, PNEUMONIA=1), then split with label stratification and
    ``random_state=42``. The second split uses ``test_size=1/3``, the same value
    as the ResNet-50 / DenseNet-121 cells. Given the same directory listing,
    all models are therefore scored on the same test images.

    Returns:
        tuple: ``(train_dataset, val_dataset, test_dataset, sampler)``.

        - Each dataset exposes ``.samples`` and yields ``(float32 tensor, label)``.
        - Train uses ``data_transforms['train']``; val and test use
          ``data_transforms['val']``.
        - ``sampler`` is a ``WeightedRandomSampler`` drawing ``len(train)``
          indices with replacement, weighted by inverse class frequency.
    """
    ds1_path = '/content/chest_xray'
    ds2_path = '/content/chest_xray2'

    # Dataset over root_dir/{NORMAL,PNEUMONIA}; below it is only used to collect
    # the (path, label) sample lists.
    class PneumoniaDataset(torch.utils.data.Dataset):
        def __init__(self, root_dir, transform=None, loader=imageLoader):
            self.root_dir = root_dir
            self.transform = transform
            self.loader = loader
            self.classes = ['NORMAL', 'PNEUMONIA']  # Match classification names
            self.class_to_idx = {'NORMAL': 0, 'PNEUMONIA': 1}

            # Find all image paths and their labels
            self.samples = []
            for class_name in self.classes:
                class_dir = os.path.join(root_dir, class_name)
                if os.path.isdir(class_dir):
                    for filename in os.listdir(class_dir):
                        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                            path = os.path.join(class_dir, filename)
                            label = self.class_to_idx[class_name]
                            self.samples.append((path, label))

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            path, label = self.samples[idx]
            img = self.loader(path)

            # An all-black image (e.g. the loader's placeholder for a corrupt
            # file) is swapped for a fresh 224x224 black placeholder.
            # ndarray.any() scans the pixels in NumPy. The builtin all() used
            # before iterated every pixel value in Python.
            if img is not None and not np.asarray(img).any():
                img = Image.new('RGB', (224, 224))

            if self.transform:
                img = self.transform(img)

            # Only tensors have .float(); without a transform img is still a PIL image.
            if isinstance(img, torch.Tensor):
                img = img.float()
            return img, label

    # Scan every source folder (transforms are attached per split later)
    ds1_train_set = PneumoniaDataset(os.path.join(ds1_path, 'train'), transform=None)
    ds2_train_set = PneumoniaDataset(os.path.join(ds2_path, 'train'), transform=None)
    ds1_test_set = PneumoniaDataset(os.path.join(ds1_path, 'test'), transform=None)
    ds2_test_set = PneumoniaDataset(os.path.join(ds2_path, 'test'), transform=None)

    # Combine all samples
    all_samples = ds1_train_set.samples + ds2_train_set.samples + ds1_test_set.samples + ds2_test_set.samples
    all_labels = [label for _, label in all_samples]

    # Stratified 70/30 split of the pooled indices
    indices = list(range(len(all_samples)))
    train_idx, temp_idx = train_test_split(
        indices,
        test_size=0.3,  # 30% for test+validation
        stratify=all_labels,
        random_state=42
    )

    # Split the remaining 30% into 20% test and 10% validation. Exactly 1/3
    # (not 0.33) matches the other notebook cells, so the test sets coincide.
    test_idx, val_idx = train_test_split(
        temp_idx,
        test_size=1/3,
        stratify=[all_labels[i] for i in temp_idx],
        random_state=42
    )

    # Dataset over an explicit (path, label) list with a per-split transform
    class SubsetDataset(torch.utils.data.Dataset):
        def __init__(self, samples, transform=None, loader=imageLoader):
            self.samples = samples
            self.transform = transform
            self.loader = loader

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            path, label = self.samples[idx]
            img = self.loader(path)

            if self.transform:
                img = self.transform(img)

            return img.float(), label  # Ensure float32 output

    # Create subsets with proper transforms
    train_samples = [all_samples[i] for i in train_idx]
    val_samples = [all_samples[i] for i in val_idx]
    test_samples = [all_samples[i] for i in test_idx]

    train_dataset = SubsetDataset(train_samples, transform=data_transforms['train'])
    val_dataset = SubsetDataset(val_samples, transform=data_transforms['val'])
    test_dataset = SubsetDataset(test_samples, transform=data_transforms['val'])

    # Extract train labels for computing class weights and sampler
    train_labels = [label for _, label in train_samples]
    class_sample_counts = np.bincount(train_labels)
    print(f"Class distribution in training set: {class_sample_counts}")

    # Weighted sampler for balanced training (draws with replacement by default)
    weights = 1. / torch.tensor(class_sample_counts, dtype=torch.float32)
    samples_weights = weights[train_labels]
    sampler = WeightedRandomSampler(samples_weights, len(samples_weights))

    # Display dataset sizes
    total_size = len(train_samples) + len(val_samples) + len(test_samples)
    print(f"\nDataset split sizes:")
    print(f"- Training: {len(train_samples)} images ({len(train_samples)/total_size*100:.1f}%)")
    print(f"- Validation: {len(val_samples)} images ({len(val_samples)/total_size*100:.1f}%)")
    print(f"- Test: {len(test_samples)} images ({len(test_samples)/total_size*100:.1f}%)")
    print(f"- Total: {total_size} images")

    return train_dataset, val_dataset, test_dataset, sampler

# Build the 70/20/10 split from load_datasets().
train_dataset, val_dataset, test_dataset, sampler = load_datasets()

# Batch size 16, 2 worker processes, pinned host memory for faster GPU copies.
# Training draws from the weighted sampler; val/test are read in order.
dataloaders = {
    'train': DataLoader(train_dataset, batch_size=16, sampler=sampler, num_workers=2, pin_memory=True),
    **{
        split: DataLoader(split_ds, batch_size=16, shuffle=False, num_workers=2, pin_memory=True)
        for split, split_ds in (('val', val_dataset), ('test', test_dataset))
    },
}

# Samples per split, used as the denominator for epoch loss/accuracy.
dataset_sizes = {split: len(loader.dataset) for split, loader in dataloaders.items()}

# Updated model with further regularization
class PneumoniaNet(nn.Module):
    """Small four-stage CNN classifier trained from scratch.

    Each stage is Conv3x3 -> BatchNorm -> ReLU -> MaxPool(2). The first three
    stages are followed by spatial dropout (0.1, 0.1, 0.2). A global average
    pool then feeds a dropout-regularised two-layer MLP head
    (256 -> 128 -> ``num_classes``).
    """

    def __init__(self, num_classes=2):
        super().__init__()

        def stage(in_ch, out_ch, drop=None):
            # Layers are created in the same order as a flat listing, so
            # parameter registration, state-dict keys and RNG use are unchanged.
            layers = [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
            if drop is not None:
                layers.append(nn.Dropout2d(drop))
            return layers

        self.features = nn.Sequential(
            *stage(3, 32, drop=0.1),
            *stage(32, 64, drop=0.1),
            *stage(64, 128, drop=0.2),
            *stage(128, 256),
        )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        logits = self.classifier(torch.flatten(pooled, 1))
        return logits.float()  # float32 logits

# Using transfer learning model with ResNet18
def get_transfer_model(num_classes=2, weights=models.ResNet18_Weights.IMAGENET1K_V1):
    """Build a ResNet-18 for transfer learning with only its last block trainable.

    All backbone parameters are frozen except those of ``layer4.1`` (the final
    residual block). The original ``fc`` head is then replaced by
    ``Dropout(0.5) -> Linear(fc.in_features, num_classes)``, whose new
    parameters are trainable.

    Args:
        num_classes: Number of output logits.
        weights: torchvision weights passed to ``models.resnet18``. Defaults to
            the ImageNet weights used before; ``None`` gives a randomly
            initialised backbone (no download).

    Returns:
        nn.Module: the model cast to float32.
    """
    model = models.resnet18(weights=weights)

    # Freezing most of the network means backprop only computes gradients for
    # the last block and the head, which reduces memory use and speeds up
    # training. Selecting by parameter name states the intent directly. For
    # ResNet-18 this is the same set as the previous "all but the last 8
    # parameter tensors" slice (layer4.1's 6 tensors plus fc's 2).
    # NOTE(review): frozen BatchNorm layers still update running stats in train() mode.
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(('layer4.1.', 'fc.'))

    # Replace final fully connected layer
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_ftrs, num_classes)
    )

    return model.float()  # Ensure float32 parameters

# Backbone choice: True -> pretrained ResNet18 (transfer learning), False -> PneumoniaNet from scratch.
use_transfer_learning = True
print(f"Using {'transfer learning model (ResNet18)' if use_transfer_learning else 'custom PneumoniaNet model'}")
model = (get_transfer_model if use_transfer_learning else PneumoniaNet)().to(device)

# Class weights total/count (inverse frequency scaled by N) from the training sample list.
# NOTE(review): the WeightedRandomSampler from load_datasets already balances
# batches. Weighting the loss as well likely over-corrects toward the minority
# class; consider keeping only one of the two.
train_labels = [label for _path, label in train_dataset.samples]
class_counts = np.bincount(train_labels)
total = sum(class_counts)
class_weights = torch.tensor(
    [total / class_counts[0], total / class_counts[1]], dtype=torch.float32
).to(device)
print(f"Class weights: {class_weights}")
criterion = nn.CrossEntropyLoss(weight=class_weights)

# AdamW: Adam with decoupled weight decay as regularisation.
optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=1e-4)

# Divide the LR by 10 after 3 epochs without validation-loss improvement
# (train_model calls scheduler.step(val_loss)).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=3, verbose=True
)

# Enhanced training function with early stopping
def train_model(model, criterion, optimizer, scheduler, num_epochs=25, patience=7,
                checkpoint_path='/content/drive/MyDrive/pneumonia_best_model.pth'):
    """Train/validate ``model`` for ``num_epochs`` and return it with its best-val weights.

    Each epoch runs three steps:

    1. An optimisation pass over ``dataloaders['train']``, with the gradient
       norm clipped to 1.0.
    2. An evaluation pass over ``dataloaders['val']``.
    3. ``scheduler.step(val_loss)``.

    Whenever validation accuracy improves, the weights are kept in memory and a
    checkpoint is written to ``checkpoint_path``. ``dataloaders``,
    ``dataset_sizes`` and ``device`` are module-level globals.

    ``patience`` never stops training. Once ``patience`` consecutive epochs pass
    without improvement, a notice is printed each epoch and training continues.

    Args:
        model: Network already on ``device``.
        criterion: Loss taking ``(logits, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Scheduler whose ``step()`` takes the validation loss.
        num_epochs: Number of epochs to run.
        patience: Epochs without improvement before the notice is printed.
        checkpoint_path: Where the best-so-far checkpoint is saved.

    Returns:
        tuple: ``(model, metrics)``. ``model`` has the best-validation weights
        loaded. ``metrics`` has per-epoch ``train_losses``, ``val_losses``,
        ``train_accs``, ``val_accs`` and ``best_acc``, all plain Python floats.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    no_improve_epochs = 0

    # For tracking metrics
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Forward (with autograd only while training)
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        # Gradient clipping to prevent exploding gradients
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                # .item() keeps the count a Python int, so accuracies are plain
                # floats rather than (possibly CUDA) tensors.
                running_corrects += torch.sum(preds == labels).item()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'train':
                train_losses.append(epoch_loss)
                train_accs.append(epoch_acc)
                continue

            val_losses.append(epoch_loss)
            val_accs.append(epoch_acc)

            # Update scheduler based on validation loss
            scheduler.step(epoch_loss)

            # Keep (and checkpoint) the weights with the best validation accuracy
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                no_improve_epochs = 0
                # 'accuracy' is stored as a plain float rather than a device tensor.
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'accuracy': best_acc,
                }, checkpoint_path)
                print(f'New best model saved with accuracy: {best_acc:.4f}')
            else:
                no_improve_epochs += 1

        # Keeping track of epochs without improvement, but not stopping
        if no_improve_epochs >= patience:
            print(f'No improvement for {patience} epochs, but continuing training until end at least 30 per document requirement of at least 25')

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)

    # Return metrics for reporting/visualization
    metrics = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'best_acc': best_acc
    }

    return model, metrics

# Train for 30 epochs. patience=7 only triggers a printed notice; train_model
# never stops early.
print("Training Pneumonia Classification Model...")
model, metrics = train_model(model, criterion, optimizer, scheduler, num_epochs=30, patience=7)

# Save the final (best-validation) weights with the optimizer state and training curves.
torch.save(
    {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'metrics': metrics,
    },
    '/content/drive/MyDrive/pneumonia_final_model.pth',
)

print("Training completed. Model saved!")

# Evaluate the model on the test set
def evaluate_model(model, criterion, dataset_name='test'):
    """Score ``model`` on ``dataloaders[dataset_name]`` and print loss/accuracy.

    Uses the module-level ``dataloaders``, ``dataset_sizes`` and ``device``.

    Args:
        model: Classifier producing one logit per class.
        criterion: Loss taking ``(logits, labels)``.
        dataset_name: Key of the split to evaluate.

    Returns:
        tuple: ``(all_preds, all_labels)``, lists of predicted and true class
        indices in loader order (for a confusion matrix / classification report).
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    preds_out, labels_out = [], []

    with torch.no_grad():
        for inputs, labels in dataloaders[dataset_name]:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            batch_preds = torch.max(logits, 1)[1]

            total_loss += criterion(logits, labels).item() * inputs.size(0)
            total_correct += torch.sum(batch_preds == labels.data)

            preds_out.extend(batch_preds.cpu().numpy())
            labels_out.extend(labels.cpu().numpy())

    n_samples = dataset_sizes[dataset_name]
    eval_loss = total_loss / n_samples
    eval_acc = total_correct.double() / n_samples
    title = dataset_name.capitalize()

    print(f'\nFinal Evaluation on {dataset_name} set:')
    print(f'{title} Loss: {eval_loss:.4f}')
    print(f'{title} Accuracy: {eval_acc:.4f}')

    return preds_out, labels_out

# Final evaluation on the held-out test split
print("\nPerforming final evaluation on test set...")
predictions, true_labels = evaluate_model(model, criterion, 'test')

# Per-class precision/recall/F1 and the confusion matrix
from sklearn.metrics import classification_report, confusion_matrix

print("\nClassification Report:",
      classification_report(true_labels, predictions, target_names=['Normal', 'Pneumonia']),
      sep="\n")

cm = confusion_matrix(true_labels, predictions)
print("\nConfusion Matrix:", cm, sep="\n")