In [None]:
import copy
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, roc_curve, average_precision_score, precision_recall_fscore_support, auc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
from torchvision import datasets, transforms
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

In [None]:
# Download dataset
! curl -L -o ~/peripheral-blood-cell.zip\
    https://www.kaggle.com/api/v1/datasets/download/bzhbzh35/peripheral-blood-cell

In [None]:
! unzip peripheral-blood-cell.zip

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Seed Python, NumPy and PyTorch (CPU + all GPUs) so random augmentation, the sampler,
# the new head's initialisation and dropout are repeatable across runs (cuDNN kernels
# may still be nondeterministic). NumPy's global RNG is also what scikit-learn's
# train_test_split uses when no random_state is given.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Dataset path: the download cell saves the archive under the home directory (~), so
# resolve the extracted folder from there instead of hard-coding /root.
dataset_dir = os.path.expanduser('~/PBC_dataset_normal_DIB_224/PBC_dataset_normal_DIB_224')

# Class names (label index order used in every report and plot)
class_names = ['basophil', 'eosinophil', 'erythroblast', 'ig', 'lymphocyte', 'monocyte', 'neutrophil', 'platelet']

# Data augmentation and normalization.
# Both pipelines resize to 224x224 and normalise with the standard ImageNet channel
# mean/std; only the training pipeline adds random augmentation.
# NOTE(review): RandomRotation(15) followed by RandomAffine(degrees=15) rotates twice,
# and hue=0.2 is a strong hue shift for stained smears — worth tuning.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

_train_steps = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
]
_eval_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
]

transform = {
    'train': transforms.Compose(_train_steps),
    'test': transforms.Compose(_eval_steps),
}

# Load the dataset twice from the same folder. A `Subset` only wraps its dataset, so if
# train/val/test subsets shared one ImageFolder they would also share its `.transform`,
# and switching val/test to the eval transform would silently strip augmentation from
# training. `full_dataset` is the deterministic (eval-transform) view used for val/test
# and plotting; `train_view` applies the training augmentation.
full_dataset = datasets.ImageFolder(dataset_dir, transform=transform['test'])
train_view = datasets.ImageFolder(dataset_dir, transform=transform['train'])

# ImageFolder numbers classes by sorted folder name; the hard-coded `class_names`
# used for reports and plots must follow exactly the same order.
assert full_dataset.classes == class_names, (
    f'class_names {class_names} does not match dataset folders {full_dataset.classes}')
assert train_view.targets == full_dataset.targets

# Split indices into 70% train, 10% validation and 20% test, stratified by class:
# the second split takes 12.5% of the remaining 80%, i.e. 10% of the total.
SPLIT_SEED = 42  # fixed so the same images land in each split on every run
train_idx, test_idx = train_test_split(
    list(range(len(full_dataset))), test_size=0.2,
    stratify=full_dataset.targets, random_state=SPLIT_SEED)
train_idx, val_idx = train_test_split(
    train_idx, test_size=0.125,
    stratify=[full_dataset.targets[i] for i in train_idx], random_state=SPLIT_SEED)

# Create subsets: training reads augmented images, val/test read deterministic ones.
train_dataset = Subset(train_view, train_idx)
val_dataset = Subset(full_dataset, val_idx)
test_dataset = Subset(full_dataset, test_idx)

print(f'Training samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')
print(f'Test samples: {len(test_dataset)}')

In [None]:
# Validation/test must use the deterministic test transform while training keeps its
# augmentation. A `Subset` only wraps a dataset, so if the training subset wraps the SAME
# object as val/test, assigning `.transform` below would also change it for training.
# In that case give training its own shallow copy (keeping the train transform) first.
if train_dataset.dataset is val_dataset.dataset or train_dataset.dataset is test_dataset.dataset:
    _train_base = copy.copy(train_dataset.dataset)
    _train_base.transform = transform['train']
    train_dataset = Subset(_train_base, train_dataset.indices)

val_dataset.dataset.transform = transform['test']
test_dataset.dataset.transform = transform['test']



# Function to calculate class weights for balanced training
def calculate_class_weights(dataset, num_classes=None):
    """Return inverse-frequency weights (1 / sample count) per class as a float32 tensor.

    Args:
        dataset: an object with a ``targets`` sequence of integer labels (e.g. an
            ``ImageFolder``), or a ``torch.utils.data.Subset`` of one, in which case
            only the labels selected by ``dataset.indices`` are counted.
        num_classes: optional total number of classes. When given, the result has at
            least this length even if the highest-numbered classes are absent.

    Returns:
        1-D ``torch.float32`` tensor whose entry ``c`` is ``1 / count(c)``. Classes with
        no samples get weight 0 instead of ``inf``.
    """
    if isinstance(dataset, torch.utils.data.Subset):
        targets = np.array([dataset.dataset.targets[i] for i in dataset.indices], dtype=np.int64)
    else:
        targets = np.array(dataset.targets, dtype=np.int64)  # int dtype so [] still works

    class_sample_count = np.bincount(targets, minlength=num_classes or 0)
    class_weights = np.zeros(class_sample_count.shape, dtype=np.float64)
    # Only divide where a class is present; absent classes keep weight 0 (not 1/0 = inf).
    np.divide(1.0, class_sample_count, out=class_weights, where=class_sample_count > 0)

    return torch.from_numpy(class_weights).float()

# Create a weighted sampler for oversampling minority classes
def create_weighted_sampler(dataset, generator=None):
    """Build a ``WeightedRandomSampler`` that draws the classes roughly uniformly.

    Each sample gets the weight of its class from ``calculate_class_weights`` (1 / class
    count), and ``len(dataset)`` indices are drawn with replacement per epoch, so minority
    classes are oversampled.

    Args:
        dataset: object with a ``targets`` label sequence, or a ``Subset`` of one.
        generator: optional ``torch.Generator`` for reproducible sampling; ``None`` uses
            the global torch RNG.

    Returns:
        The configured ``WeightedRandomSampler``.
    """
    if isinstance(dataset, torch.utils.data.Subset):
        targets = np.array([dataset.dataset.targets[i] for i in dataset.indices], dtype=np.int64)
    else:
        targets = np.array(dataset.targets, dtype=np.int64)

    class_weights = calculate_class_weights(dataset)
    # Vectorised per-sample lookup (instead of indexing the tensor once per sample in a
    # Python loop). Double precision is what WeightedRandomSampler converts weights to.
    samples_weights = class_weights.double()[torch.from_numpy(targets)]
    sampler = WeightedRandomSampler(
        samples_weights, num_samples=len(samples_weights), replacement=True, generator=generator)

    return sampler

# Use WeightedRandomSampler in the DataLoader for balanced sampling.
# The sampler takes the place of shuffle=True for training; val/test keep a fixed order
# and their natural class distribution.
train_sampler = create_weighted_sampler(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=32, sampler=train_sampler, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

print(f'Train DataLoader size: {len(train_loader.dataset)}')
print(f'Validation DataLoader size: {len(val_loader.dataset)}')
print(f'Test DataLoader size: {len(test_loader.dataset)}')

# Class weights come from the TRAINING split only, so validation/test labels never feed
# a statistic used for training.
# NOTE(review): these weights are reported but not passed to nn.CrossEntropyLoss in the
# training cells. Because the weighted sampler already balances the classes, also
# weighting the loss would likely correct for the imbalance twice.
class_weights = calculate_class_weights(train_dataset)
class_weights = class_weights.to(device)

# Print class weights
print(f'Class Weights: {class_weights}')

In [None]:
# Function to visualize class distribution and total sample count from batches
def visualize_batch_class_distribution(data_loader, num_batches=10):
    """Bar-plot the per-class label counts seen in the first `num_batches` batches.

    Intended as a sanity check that the weighted training sampler balances the classes.
    Counts are indexed by label against the global ``class_names`` list.

    Args:
        data_loader: iterable of ``(images, labels)`` batches with CPU integer labels.
        num_batches: maximum number of batches to draw. If the loader is shorter, all of
            its batches are used and the printed/plotted count is the actual number.
    """
    class_count = np.zeros(len(class_names))
    total_samples = 0
    batches_seen = 0

    # zip() stops as soon as range() is exhausted, so exactly `num_batches` batches are
    # requested from the loader (an enumerate/break loop fetches one extra batch).
    for _, (images, labels) in zip(range(num_batches), data_loader):
        batches_seen += 1
        total_samples += labels.size(0)  # Add the number of samples in the current batch
        class_count += np.bincount(labels.numpy(), minlength=len(class_names))

    # Print total number of samples processed
    print(f"Total number of samples processed in {batches_seen} batches: {total_samples}")

    # Plot the batch distribution
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.barplot(x=class_names, y=class_count, ax=ax)
    ax.set_title(f"Class Distribution in {batches_seen} Sampled Batches")
    ax.tick_params(axis='x', labelrotation=45)
    ax.set_ylabel("Number of samples")
    plt.show()

# Visualize the class distribution over multiple batches and show the total number of samples
visualize_batch_class_distribution(train_loader, num_batches=50)


In [None]:
# Function to display an image tensor
def imshow(img):
    """Turn a normalised CHW image tensor into an HWC NumPy array ready for matplotlib.

    Despite the name, nothing is drawn. The ImageNet mean/std normalisation applied by
    the transforms is undone (``x * std + mean``), and values are clipped to [0, 1] so
    the result can be passed straight to ``Axes.imshow``. ``img`` must be a CPU tensor.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    hwc = np.transpose(img.numpy(), (1, 2, 0))
    return np.clip(hwc * channel_std + channel_mean, 0, 1)

def _labels_of(dataset):
    """Return every item's integer label without loading images, or None if unavailable.

    Supports objects exposing a ``targets`` sequence (e.g. ``ImageFolder``) and
    ``torch.utils.data.Subset`` wrappers (possibly nested) around them.
    """
    if isinstance(dataset, torch.utils.data.Subset):
        parent_labels = _labels_of(dataset.dataset)
        if parent_labels is None:
            return None
        return [parent_labels[i] for i in dataset.indices]
    labels = getattr(dataset, 'targets', None)
    return None if labels is None else list(labels)


# Sample 'num_samples' images per class
def plot_class_samples(dataset, class_names, num_samples=5):
    """Show a grid with the first `num_samples` images of every class, one row per class.

    When labels are available without decoding images (``dataset.targets`` or a
    ``Subset`` of such a dataset), the indices are chosen from the labels and only the
    displayed images are loaded. Scanning the dataset in order instead would decode
    almost every image when classes are stored contiguously. Otherwise the dataset is
    iterated until every class has `num_samples` images.

    Args:
        dataset: map-style dataset yielding ``(normalised CHW image tensor, label)``.
        class_names: display names indexed by label; one grid row per name.
        num_samples: images per class (grid columns).
    """
    num_classes = len(class_names)
    samples_per_class = {class_name: [] for class_name in class_names}

    labels = _labels_of(dataset)
    if labels is not None:
        # Pick the first `num_samples` indices of each class, then load just those.
        chosen = {class_name: [] for class_name in class_names}
        for idx, label in enumerate(labels):
            picked = chosen[class_names[label]]
            if len(picked) < num_samples:
                picked.append(idx)
        for class_name, indices in chosen.items():
            samples_per_class[class_name] = [(dataset[i][0], class_name) for i in indices]
    else:
        # Fallback: walk the dataset until every class has 'num_samples' images.
        for img, label in dataset:
            class_name = class_names[label]
            if len(samples_per_class[class_name]) < num_samples:
                samples_per_class[class_name].append((img, class_name))
            if all(len(v) == num_samples for v in samples_per_class.values()):
                break

    # squeeze=False keeps `axes` 2-D even for a single row or a single column.
    fig, axes = plt.subplots(nrows=num_classes, ncols=num_samples,
                             figsize=(15, 3 * num_classes), squeeze=False)
    # Hide every frame up front, including cells left empty for classes with fewer images.
    for ax in axes.flat:
        ax.axis('off')

    for row, class_name in enumerate(class_names):
        for col, (img, _) in enumerate(samples_per_class[class_name]):
            axes[row, col].imshow(imshow(img))
        # Set the title for the first image of each row
        axes[row, 0].set_title(class_name, fontsize=14, pad=20, loc='left')

    # Adjust layout to give more space for titles
    plt.subplots_adjust(left=0.15, hspace=0.5, wspace=0.3)
    plt.show()

# Sample and plot 5 images for each class
plot_class_samples(full_dataset, class_names, num_samples=5)

In [None]:
MODEL_NAME = 'tf_efficientnet_b0'
NUM_CLASSES = len(class_names)

# Model creation: pretrained timm backbone. num_classes=0 removes timm's own classifier,
# so the model's features have size `model.num_features`.
model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=0)
num_in_features = model.num_features

# Custom classification head assigned to the `classifier` attribute.
# NOTE(review): this relies on the architecture's forward pass calling `self.classifier`
# on the pooled features (true for timm EfficientNet) — re-check if MODEL_NAME changes.
head_layers = [
    nn.BatchNorm1d(num_in_features),
    nn.Linear(num_in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, NUM_CLASSES),
]
model.classifier = nn.Sequential(*head_layers)

model = model.to(device)

In [None]:
# Hyper-parameters for the two-phase training schedule below.
# NOTE(review): BATCH_SIZE is not read by the DataLoaders, which pass batch_size=32 directly.
BATCH_SIZE = 32

# Phase 1: classifier head only (backbone frozen)
NUM_EPOCHS_INITIAL = 10
LEARNING_RATE_INITIAL = 1e-3

# Phase 2: fine-tune the whole network with a much smaller learning rate
NUM_EPOCHS_FINE_TUNE = 10
LEARNING_RATE_FINE_TUNE = 1e-5
# --- 4. Validation Function ---
def validate(model, dataloader, criterion, device=None):
    """Compute the average loss and accuracy of `model` over `dataloader` without training.

    The model is switched to eval mode and left there; training loops switch back with
    ``model.train()``.

    Args:
        model: network returning per-class scores; predictions are ``argmax`` over dim 1.
        dataloader: DataLoader of ``(inputs, labels)``; ``len(dataloader.dataset)`` is used
            as the total sample count.
        criterion: loss function, assumed to return the batch mean (e.g.
            ``nn.CrossEntropyLoss()``). It is re-weighted by batch size so the result is a
            per-sample average even when the last batch is smaller.
        device: where to move each batch. Defaults to the device of the model's first
            parameter.

    Returns:
        ``(epoch_loss, epoch_acc)``: the loss as a Python float, and the accuracy in [0, 1]
        as a 0-dim float64 tensor.

    Raises:
        ValueError: if the dataloader's dataset is empty.
    """
    num_samples = len(dataloader.dataset)
    if num_samples == 0:
        raise ValueError("validate() received a dataloader with an empty dataset")
    if device is None:
        device = next(model.parameters()).device

    model.eval()  # Set model to evaluation mode
    running_loss = 0.0
    running_corrects = 0  # plain int: avoids accumulating a tensor on the GPU

    # No need to track gradients for validation
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            running_corrects += int((preds == labels).sum().item())

    epoch_loss = running_loss / num_samples
    epoch_acc = torch.tensor(running_corrects / num_samples, dtype=torch.float64)

    return epoch_loss, epoch_acc

In [None]:
# Training Phase 1: Train only the classifier head
print("\n--- Phase 1: Training the classifier head ---")
# Freeze every parameter, then re-enable gradients for the new head only.
for param in model.parameters():
    param.requires_grad = False
for param in model.classifier.parameters():
    param.requires_grad = True

criterion = nn.CrossEntropyLoss()
# Only the head's parameters require gradients at this point.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LEARNING_RATE_INITIAL)

best_val_loss = float('inf')
best_model_wts = copy.deepcopy(model.state_dict())

for epoch in range(NUM_EPOCHS_INITIAL):
    # Training. requires_grad=False stops weight updates but NOT BatchNorm running-stat
    # updates, which happen whenever a BN layer is in train mode. Keep the frozen
    # backbone in eval mode and put only the head (BatchNorm1d + Dropout) in train mode.
    model.eval()
    model.classifier.train()
    running_loss = 0.0
    seen = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        seen += inputs.size(0)

    # Per-sample average over what the sampler actually drew this epoch.
    train_loss = running_loss / max(seen, 1)

    # Validation (also puts the whole model in eval mode; reset at the next epoch)
    val_loss, val_acc = validate(model, val_loader, criterion)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS_INITIAL} -> "
          f"Train Loss: {train_loss:.4f}, "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Keep a copy of the weights with the lowest validation loss so far
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_wts = copy.deepcopy(model.state_dict())
        print("  -> New best model saved!")

# Load best model weights from Phase 1 before fine-tuning (';' hides the key-match repr)
model.load_state_dict(best_model_wts);

In [None]:
# Training Phase 2: Fine-tune the whole model
print("\n--- Phase 2: Fine-tuning the entire model ---")
# Unfreeze all layers (backbone and head)
for param in model.parameters():
    param.requires_grad = True

# Re-create the optimizer over all parameters with the lower fine-tuning learning rate
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE_FINE_TUNE)

# Reset best loss for the fine-tuning phase. Consequence: the first fine-tuning epoch
# always replaces the phase-1 checkpoint in best_model_wts, even if its validation loss
# is higher than phase 1's best.
best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS_FINE_TUNE):
    # Training (every layer, including backbone BatchNorm, in train mode)
    model.train()
    running_loss = 0.0
    start_time = time.perf_counter()  # monotonic, high-resolution clock for durations

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)

    train_loss = running_loss / len(train_dataset)
    epoch_time = time.perf_counter() - start_time

    # Validation
    val_loss, val_acc = validate(model, val_loader, criterion)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS_FINE_TUNE} -> "
          f"Train Loss: {train_loss:.4f}, "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Time: {epoch_time:.2f}s")

    # Keep a copy of the weights with the lowest validation loss in this phase
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_wts = copy.deepcopy(model.state_dict())
        print("  -> New best model saved!")

In [None]:
#  Run Inference on Test Set
# Evaluate the checkpoint selected by validation loss (best_model_wts), not the weights
# left behind by the last fine-tuning epoch. Switch to eval mode explicitly so
# Dropout/BatchNorm behave deterministically, instead of relying on validate() having
# been the last thing to run.
model.load_state_dict(best_model_wts)
model.eval()

all_preds = []
all_labels = []

print("\nRunning inference on the test set...")
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs.to(device))
        _, preds = torch.max(outputs, 1)

        all_preds.extend(preds.cpu().numpy())
        # Labels are only compared on the CPU, so they never need to visit the GPU.
        all_labels.extend(labels.numpy())

print("Inference complete.")

# Pass the full label range so report rows and matrix rows/columns line up with
# class_names even if a class is missing from the test labels or the predictions.
label_ids = list(range(len(class_names)))

# --- 5. Generate and Print Classification Report ---
print("\n--- Classification Report ---")
report = classification_report(all_labels, all_preds, labels=label_ids, target_names=class_names)
print(report)

# --- 6. Generate and Plot Confusion Matrix ---
print("\n--- Confusion Matrix ---")
cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

# Plotting the confusion matrix
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted Label')
ax.set_ylabel('True Label')
ax.set_title('Confusion Matrix for WBC Classification')
plt.show()

In [None]:

# Restore the best-validation checkpoint and persist only its state_dict (weights),
# not the pickled module.
CHECKPOINT_PATH = 'wbc_classifier_efficientnet_b0_best.pth'

print("\nFinished Training. Loading best model weights.")
model.load_state_dict(best_model_wts)

print("Attempting to save the model now...")
torch.save(model.state_dict(), CHECKPOINT_PATH)
print(f"Best model saved to {CHECKPOINT_PATH}")