<h1 align="center">Malaria Detection - PyTorch</h1>

<p align="center"><i>V1.2 - Added Device Selection</i></p>

## Pip installs & Imports

In [277]:
# Optional one-time setup: uncomment the lines below to install the packages
# imported by the next cell.
# !pip install kagglehub
# !pip install numpy
# !pip install pandas
# !pip install matplotlib
# !pip install torch torchvision
# !pip install tqdm

In [278]:
# For data download
import kagglehub

# For data handling and visualization
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# For deep learning (PyTorch)
import torch
from torch import nn, optim
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms, datasets, models
from torchvision.models import ResNet50_Weights
from torchvision.transforms import functional as F

# For progress tracking
from tqdm import tqdm

# Miscellaneous / Utilities
import random
import shutil
import os

## Global variables

### Project specific variables

In [279]:
# --- Image / batching --------------------------------------------------------
IMG_SIZE = 224     # Intended edge length (pixels) that images are resized to
BATCH_SIZE = 64    # Intended number of samples per mini-batch

# --- Dataset split fractions -------------------------------------------------
VALIDATION_SIZE = 0.2
TEST_SIZE = 0.2

# --- Compute device ----------------------------------------------------------
# One of 'TPU', 'GPU' or 'CPU' (the device selector compares it upper-cased).
DEVICE = 'GPU'

# --- Dataset location --------------------------------------------------------
SOURCE = 'kaggle'
DATA_DIR = 'src/static/data'
# Alternative DATA_DIR values kept for reference:
# DATA_DIR = kagglehub.dataset_download('iarunava/cell-images-for-detecting-malaria', download_dir='src/static/data/') # Download from Kaggle
# DATA_DIR = 'src/static/data/downloads/extracted/ZIP.data.lhnc.nlm.nih.gov_publ_Mala_cell_imagCpSVVrJBQVm1EAGSYJgFN2ZUxCZtjRh76bGSL61Dxmg.zip/cell_images' # Local
# DATA_DIR = '/kaggle/input/cell-images-for-detecting-malaria' # Colab

# --- Model checkpoint --------------------------------------------------------
MODEL_DIR = "src/static/model/"
MODEL_NAME = 'best_model.pt'  # earlier experiments used 'model.keras'
MODEL_PATH = f"{MODEL_DIR}{MODEL_NAME}"  # MODEL_DIR already ends with '/'

### Colors

In [280]:
# ANSI escape sequences used to colorize the messages printed by this notebook
GREEN = "\x1b[92m"
BLUE = "\x1b[94m"
YELLOW = "\x1b[93m"
RED = "\x1b[91m"
RESET = "\x1b[0m"  # Resets all attributes (terminal default, not necessarily white)

## Device Selector

In [281]:
# Device selector: turn the configured DEVICE string into a torch device.
# `actual_device` records what was really selected, which can differ from the
# request when CUDA is unavailable.
requested_device = DEVICE.upper()

if requested_device == 'CPU':
    device = torch.device('cpu')
    actual_device = 'CPU'
elif requested_device == 'GPU':
    cuda_ready = torch.cuda.is_available()
    # Fall back to the CPU instead of failing when no CUDA device is present
    device = torch.device('cuda' if cuda_ready else 'cpu')
    actual_device = 'GPU' if cuda_ready else 'CPU (fallback)'
elif requested_device == 'TPU':
    try:
        # torch_xla is only imported when a TPU is explicitly requested
        import torch_xla
        import torch_xla.core.xla_model as xm
        device = xm.xla_device()
        actual_device = 'TPU'
    except ImportError:
        raise ImportError("TPU support requires torch_xla. Make sure you're running on a TPU environment like Google Colab.")
else:
    raise ValueError("Invalid DEVICE value. Choose from: 'TPU', 'GPU', 'CPU'.")

print(f"{YELLOW}Using device: {RESET}{actual_device}")

[93mUsing device: [0mCPU (fallback)


## Data source

In [282]:
# Guard clause: the Kaggle download is the only source handled by this cell.
# (A 'local' branch existed as commented-out code:
#      DATA_DIR = DATA_DIR + '/cell_images/')
if SOURCE != 'kaggle':
    raise ValueError("Invalid SOURCE value. Choose from: 'kaggle', 'local'")

# Class folders copied into DATA_DIR; anything else in the download is ignored
allowed_folders = ['Parasitized', 'Uninfected']
top_level_dupes = list(allowed_folders)

# Download dataset; the images to copy are read from its 'cell_images' sub-folder
downloaded_path = kagglehub.dataset_download('iarunava/cell-images-for-detecting-malaria')
source_path = os.path.join(downloaded_path, 'cell_images')

# Report (but do not copy) class folders sitting directly in the download root
for dupe_path in (os.path.join(downloaded_path, name) for name in top_level_dupes):
    if os.path.exists(dupe_path):
        print(f"{YELLOW}Skipping top-level duplicate: {RESET}{dupe_path}")

# Ensure target directory exists
os.makedirs(DATA_DIR, exist_ok=True)

# Remove a 'cell_images' folder from the target dir if one is present
nested_path = os.path.join(DATA_DIR, 'cell_images')
if os.path.exists(nested_path):
    print(f"{YELLOW}Removing nested folder: {RESET}{nested_path}")
    shutil.rmtree(nested_path)

# Copy each class folder once; folders already in place are left untouched
for folder in allowed_folders:
    src_folder = os.path.join(source_path, folder)
    dst_folder = os.path.join(DATA_DIR, folder)
    if not os.path.isdir(src_folder):
        continue
    if os.path.exists(dst_folder):
        print(f"{YELLOW}Skipping existing folder: {RESET}{dst_folder}")
    else:
        print(f"{YELLOW}Copying {RESET}{folder} {YELLOW}to {RESET}{dst_folder}")
        shutil.copytree(src_folder, dst_folder)

print(f"\n{YELLOW}Data source:{RESET}")
print(DATA_DIR)

print(f"\n{YELLOW}Labels found:{RESET}")
print(os.listdir(DATA_DIR))

[93mSkipping existing folder: [0msrc/static/data\Parasitized
[93mSkipping existing folder: [0msrc/static/data\Uninfected

[93mData source:[0m
src/static/data

[93mLabels found:[0m
['Parasitized', 'Uninfected']


In [None]:
# Define your transforms for the training, validation, and testing sets

# Custom transform for adding Gaussian noise
class AddGaussianNoise(object):
    """Transform that adds element-wise Gaussian noise to a tensor.

    The output is ``tensor + N(0, 1) * std + mean``, with a fresh noise sample
    drawn on every call.

    Args:
        mean: Constant offset added to every element.
        std: Scale applied to the standard-normal noise.
    """

    def __init__(self, mean=0., std=0.05):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Return a noisy copy of ``tensor``; the input is not modified."""
        # Draw the noise on the input's device so the transform also works for
        # tensors that do not live on the default (CPU) device.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"

# Train/Validation/Test Transforms

# Per-channel statistics handed to Normalize in every pipeline
# (presumably the ImageNet statistics the pretrained backbone expects).
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]


def _build_eval_transforms():
    """Deterministic pipeline: resize, center-crop, to-tensor, normalize."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])


# Training pipeline: random geometric/color augmentation on the PIL image,
# then tensor conversion, additive noise and normalization.
train_augmentations = [
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(contrast=0.5),  # Random contrast
]
train_tensor_steps = [
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.02),            # Add noise
    transforms.Normalize(norm_mean, norm_std),
]
train_transforms = transforms.Compose(train_augmentations + train_tensor_steps)

# Validation and test share the same deterministic preprocessing
validation_transforms = _build_eval_transforms()
test_transforms = _build_eval_transforms()

print(f"{GREEN}Transforms Defined.{RESET}")

[92mTransforms Defined.[0m


## Load Dataset and Split

In [284]:
# Index the images under DATA_DIR with the training transform attached
img_dir = DATA_DIR
train_data = datasets.ImageFolder(root=img_dir, transform=train_transforms)

print(f"{GREEN}Data loaded and transformed.{RESET}")

[92mData loaded and transformed.[0m


In [285]:
# number of subprocesses to use for data loading
num_workers = 0

# percentage of the full dataset to use as validation
valid_size = 0.2

# percentage of the full dataset to use as test
# NOTE(review): the config cell declares TEST_SIZE = 0.2, but this cell uses 0.1;
# the value here is kept as-is — confirm which one is intended.
test_size = 0.1

# Fixed seed so every run (and every kernel restart) produces the same split.
# Without it, a checkpoint trained in one session could later be evaluated on
# images it was trained on.
SPLIT_SEED = 42

# convert data to a normalized torch.FloatTensor
# NOTE: not used by any loader in this cell; kept in case other cells rely on it.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

# Validation and test images must NOT go through the random training
# augmentation, so they get their own dataset views over the same folder.
# ImageFolder lists classes and files in sorted order, so an index refers to
# the same image in all three datasets.
valid_data = datasets.ImageFolder(img_dir, transform=validation_transforms)
test_data = datasets.ImageFolder(img_dir, transform=test_transforms)
assert len(valid_data) == len(train_data) == len(test_data), "dataset views disagree on size"

# shuffle all indices once, then carve out validation / test / train slices
num_train = len(train_data)
split_rng = np.random.default_rng(SPLIT_SEED)
indices = split_rng.permutation(num_train).tolist()
valid_split = int(np.floor((valid_size) * num_train))
test_split = int(np.floor((valid_size+test_size) * num_train))
valid_idx, test_idx, train_idx = indices[:valid_split], indices[valid_split:test_split], indices[test_split:]

print(f"{BLUE}Total length:{RESET}")
print(num_train)
print(f"\n{GREEN}Train length:{RESET}")
print(len(train_idx))
print(f"\n{YELLOW}Validation length:{RESET}")
print(len(valid_idx))
print(f"\n{YELLOW}Test length:{RESET}")
print(len(test_idx))

# define samplers for obtaining training, validation and test batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
test_sampler = SubsetRandomSampler(test_idx)

# prepare data loaders (each split reads from the dataset with its own transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, sampler=train_sampler, num_workers=num_workers)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=32, sampler=valid_sampler, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=20, sampler=test_sampler, num_workers=num_workers)

[94mTotal length:[0m
27558

[92mTrain length:[0m
19291

[93mValidation length:[0m
5511

[93mTest length:[0m
2756


## Model ResNet50

In [286]:
# Pretrained ResNet50 used as a frozen feature extractor
model = models.resnet50(weights=ResNet50_Weights.DEFAULT)

# Freeze every pretrained parameter
model.requires_grad_(False)

# Swap the classification head for a fresh 2-output linear layer
# (in_features is read from the layer being replaced, i.e. 2048 for ResNet50)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=2, bias=True)

# Explicitly mark the new head as trainable
fc_parameters = model.fc.parameters()
for head_param in fc_parameters:
    head_param.requires_grad = True

print(f"{GREEN}Model loaded.{RESET}\n")

# model # Display Model details - comment out if not needed

[92mModel loaded.[0m



In [287]:
# Always define `use_cuda`, whatever device was selected: the training call
# passes it to train(), so leaving it unset on CPU runs raises a NameError on a
# fresh kernel.
use_cuda = actual_device == 'GPU' and torch.cuda.is_available()

if use_cuda:  # See later if I can relocate this to Model selection
    model = model.cuda()

# Only the replaced classification head (model.fc) is handed to the optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

print(f"{GREEN}Cross entropy loss and SGD optimizer loaded.{RESET}\n")

[92mCross entropy loss and SGD optimizer loaded.[0m



## Model Training

In [None]:
def train(n_epochs, model, optimizer, criterion, use_cuda, save_path, patience=5):
    """Train ``model`` with early stopping, tqdm progress bars and accuracy tracking.

    Batches come from the notebook-level ``train_loader`` and ``valid_loader``.

    Args:
        n_epochs: Maximum number of epochs to run.
        model: Network to optimize; updated in place.
        optimizer: Optimizer stepped once per training batch.
        criterion: Loss function called as ``criterion(output, target)``.
        use_cuda: When true, every batch is moved to the GPU with ``.cuda()``.
        save_path: File the best ``state_dict`` is written to whenever the
            validation loss improves.
        patience: Number of consecutive epochs without validation-loss
            improvement after which training stops.

    Returns:
        Tuple ``(model, train_losses, val_losses, train_accuracies,
        val_accuracies)``. The lists hold one entry per completed epoch, and
        ``model`` carries the weights of the best validation epoch (if any).
    """

    valid_loss_min = np.inf
    epochs_no_improve = 0
    best_model_wts = None

    # History is created once, outside the epoch loop, so it accumulates over
    # all epochs (and exists even when n_epochs == 0).
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(1, n_epochs + 1):
        # Running means are kept as plain floats so they never hold on to
        # tensors and are still valid numbers if a loader yields no batches.
        train_loss = 0.0
        valid_loss = 0.0
        correct_train = 0
        total_train = 0
        correct_val = 0
        total_val = 0

        ###################
        # train the model #
        ###################
        model.train()
        train_loader_tqdm = tqdm(train_loader, desc=f"{YELLOW}Epoch {RESET}{epoch}/{n_epochs} [{YELLOW}Training{RESET}]", leave=False)

        for batch_idx, (data, target) in enumerate(train_loader_tqdm):
            if use_cuda:
                data, target = data.cuda(), target.cuda()

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Incremental mean of the per-batch losses
            train_loss += (loss.item() - train_loss) / (batch_idx + 1)

            # Accuracy
            predicted = output.argmax(dim=1)
            correct_train += (predicted == target).sum().item()
            total_train += target.size(0)

            train_loader_tqdm.set_postfix(loss=train_loss)

        ######################
        # validate the model #
        ######################
        model.eval()
        with torch.no_grad():
            valid_loader_tqdm = tqdm(valid_loader, desc=f"{GREEN}Epoch {RESET}{epoch}/{n_epochs} [{GREEN}Validation{RESET}]", leave=False)

            for batch_idx, (data, target) in enumerate(valid_loader_tqdm):
                if use_cuda:
                    data, target = data.cuda(), target.cuda()

                output = model(data)
                loss = criterion(output, target)
                valid_loss += (loss.item() - valid_loss) / (batch_idx + 1)

                # Accuracy
                predicted = output.argmax(dim=1)
                correct_val += (predicted == target).sum().item()
                total_val += target.size(0)

                valid_loader_tqdm.set_postfix(val_loss=valid_loss)

        # Accuracy calculations (guarded against empty loaders)
        train_accuracy = 100 * correct_train / total_train if total_train else 0.0
        val_accuracy = 100 * correct_val / total_val if total_val else 0.0

        # Save stats
        train_losses.append(train_loss)
        val_losses.append(valid_loss)
        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)

        # Print stats
        print(f"{GREEN}Epoch: {RESET}{epoch} "
              f"\t{BLUE}Train Loss: {RESET}{train_loss:.6f} | {BLUE}Train Acc: {RESET}{train_accuracy:.2f}% "
              f"\t{BLUE}Val Loss: {RESET}{valid_loss:.6f} | {BLUE}Val Acc: {RESET}{val_accuracy:.2f}%")

        ########################
        # early stopping check #
        ########################
        if valid_loss < valid_loss_min:
            torch.save(model.state_dict(), save_path)
            print(f"{YELLOW}Validation loss decreased {RESET}({valid_loss_min:.6f} --> {valid_loss:.6f}). {YELLOW}Saving model...{RESET}\n")
            valid_loss_min = valid_loss
            epochs_no_improve = 0
            # state_dict() returns references to the live tensors, which later
            # epochs keep updating; clone them to take a real snapshot.
            best_model_wts = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            epochs_no_improve += 1
            print(f"{YELLOW}No improvement in validation loss for {epochs_no_improve} epoch(s).{RESET}\n")

        if epochs_no_improve >= patience:
            print(f"\n{YELLOW}Early stopping triggered after {patience} epochs with no improvement.{RESET}")
            break

    # Restore the snapshot taken at the best validation epoch
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)

    return model, train_losses, val_losses, train_accuracies, val_accuracies

In [289]:
# Display selected device again right before training starts
# (same message the device selector cell prints).
print(f"{YELLOW}Using device: {RESET}{actual_device}")

[93mUsing device: [0mCPU (fallback)


In [None]:
# Make sure the checkpoint directory is in place before training saves into it
if os.path.exists(MODEL_DIR):
    print(f"{YELLOW}Directory already exists: {RESET}{MODEL_DIR}\n")
else:
    os.makedirs(MODEL_DIR)
    print(f"{YELLOW}Created directory: {RESET}{MODEL_DIR}\n")

# Start Model Training: at most 25 epochs, stop after 3 epochs without
# validation-loss improvement, checkpoint written to MODEL_PATH.
training_run = train(25, model, optimizer, criterion, use_cuda,
                     save_path=MODEL_PATH, patience=3)

# Unpack the trained model and its per-epoch history
model, train_losses, val_losses, train_acc, val_acc = training_run

                                                                                  0s/it, loss=0.715]

KeyboardInterrupt: 

## Display Training Metrics

In [None]:
def plot_training_metrics(train_losses, val_losses, train_acc, val_acc):
    """Show training vs. validation loss and accuracy in two side-by-side plots.

    Args:
        train_losses: Training loss values, one per epoch.
        val_losses: Validation loss values, one per epoch.
        train_acc: Training accuracy values (percent), one per epoch.
        val_acc: Validation accuracy values (percent), one per epoch.
    """
    # The x axis is 1-based and sized from the training-loss history
    epochs = range(1, len(train_losses) + 1)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # (metric name, y-axis label, training series, validation series) per panel
    panels = (
        ('Loss', 'Loss', train_losses, val_losses),
        ('Accuracy', 'Accuracy (%)', train_acc, val_acc),
    )

    for ax, (metric, y_label, train_series, val_series) in zip(axes, panels):
        ax.plot(epochs, train_series, label=f'Training {metric}')
        ax.plot(epochs, val_series, label=f'Validation {metric}')
        ax.set_title(f'{metric} over Epochs')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

In [None]:
# Display training metrics: plot the history produced by the training cell.
# Requires the cell defining plot_training_metrics and the training cell to
# have been run first.
plot_training_metrics(train_losses, val_losses, train_acc, val_acc)

NameError: name 'plot_training_metrics' is not defined