## Step 1: Setup and Configuration
This section imports all necessary libraries and defines absolute paths to our data and model directories. All paths are built from the project root, which is taken to be the parent of the current working directory (expected to be this notebook's folder). Using absolute paths prevents errors when code is run from a different location or by background workers.

In [None]:
import os
import json
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from torchvision.models import efficientnet_b2
from sklearn.metrics import classification_report, confusion_matrix

# --- 1. Robust Path Definitions (Best Practice) ---
# os.path.abspath('') resolves against the current working directory
# (assumed to be the folder containing this notebook); the project root is
# taken to be its parent directory.
notebook_dir = os.path.abspath('')
base_dir = os.path.dirname(notebook_dir)

print(f"Project base directory determined as: {base_dir}")

# Dataset locations (JSON annotation files + image folder) live under root_path.
root_path = os.path.join(base_dir, 'dataset', 'malaria')
train_json_path, test_json_path = [
    os.path.join(root_path, json_name)
    for json_name in ('training.json', 'test.json')
]
image_path = os.path.join(root_path, 'images')

# Checkpoints go into their own tree under the project root; makedirs is a
# no-op when the folder already exists.
models_dir = os.path.join(base_dir, 'effecientnetb2_model', 'efficientnet_models')
os.makedirs(models_dir, exist_ok=True)

print(f"Train JSON: {train_json_path}")
print(f"Test JSON:  {test_json_path}")
print(f"Images:     {image_path}")
print(f"Models Dir: {models_dir}")

## Step 2: Model and Dataset Definitions
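Here we define the core components: the `MalariaDataset` class that loads our JSON annotation format, the `EfficientNetDetector` model, and a custom collate function that drops samples whose image file is missing. (The loss functions are defined inside `train_model` in Step 3.)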

In [None]:
class MalariaDataset(Dataset):
    """Image dataset backed by a JSON annotation file of per-image object lists.

    Each JSON entry is read as ``entry['image']['pathname']`` plus an
    ``entry['objects']`` list whose items carry a ``'category'`` string and a
    ``'bounding_box'`` with ``'minimum'``/``'maximum'`` corners keyed by
    ``'c'`` and ``'r'`` (used as the x and y coordinates respectively).

    Args:
        json_path: Path to the JSON annotation file.
        image_root: Directory holding the images; only the basename of each
            entry's ``pathname`` is used to locate the file.
        transform: Optional callable applied to the PIL image.
        category_map: Optional ``{category_name: class_index}`` mapping. When
            omitted, it is built from the categories in this file, sorted
            alphabetically. Pass the training map to other splits so class
            indices agree.

    Attributes:
        labels: One integer per entry: the class index of the entry's first
            object, or ``-1`` when the entry has no objects.
    """

    def __init__(self, json_path, image_root, transform=None, category_map=None):
        with open(json_path, 'r') as f:
            self.entries = json.load(f)
        self.image_root = image_root
        self.transform = transform

        if category_map is None:
            all_categories = {obj['category'] for item in self.entries for obj in item['objects']}
            self.category_map = {cat: idx for idx, cat in enumerate(sorted(all_categories))}
        else:
            self.category_map = category_map

        # Image-level label used for sampling: first object's class, or -1 when
        # the image has no annotated objects (callers must handle -1).
        self.labels = [
            self.category_map[item['objects'][0]['category']] if item['objects'] else -1
            for item in self.entries
        ]

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for entry *idx*, or ``None`` if the image file is missing.

        ``target`` contains:
            ``'boxes'``: float32 tensor of shape ``(num_objects, 4)`` holding
                ``[c_min, r_min, c_max, r_max]`` in original-image pixels.
            ``'labels'``: int64 tensor of shape ``(num_objects,)``.
            ``'orig_size'``: ``(width, height)`` of the image before
                ``transform``, so callers can normalize boxes even after the
                transform resizes the image.
        """
        entry = self.entries[idx]
        image_name = os.path.basename(entry['image']['pathname'])
        image_full_path = os.path.join(self.image_root, image_name)

        try:
            # The context manager releases the file handle; convert() loads the
            # pixels into a new image before the handle is closed.
            with Image.open(image_full_path) as img:
                image = img.convert("RGB")
        except FileNotFoundError:
            print(f"Error: Image not found at {image_full_path}")
            return None # Will be filtered by collate_fn

        orig_size = image.size  # PIL reports (width, height)

        boxes = []
        labels = []
        for obj in entry['objects']:
            bb = obj['bounding_box']
            boxes.append([bb['minimum']['c'], bb['minimum']['r'], bb['maximum']['c'], bb['maximum']['r']])
            labels.append(self.category_map[obj['category']])

        # reshape keeps a consistent (N, 4) shape even when there are no objects.
        boxes = torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.tensor(labels, dtype=torch.int64)

        if self.transform:
            image = self.transform(image)

        target = {'boxes': boxes, 'labels': labels, 'orig_size': orig_size}
        return image, target

class EfficientNetDetector(nn.Module):
    """EfficientNet-B2 backbone (ImageNet 'IMAGENET1K_V1' weights) with two linear heads.

    The torchvision classifier is swapped for ``nn.Identity`` so the backbone
    emits its feature vector directly; a class-score head and a 4-value box
    head both read that same vector.

    Args:
        num_classes: Number of class scores produced by ``classifier``.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = efficientnet_b2(weights='IMAGENET1K_V1')
        # Read the feature width off the stock classifier before discarding it.
        feature_dim = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()
        self.classifier = nn.Linear(feature_dim, num_classes)
        self.bbox_head = nn.Linear(feature_dim, 4)

    def forward(self, x):
        """Return ``(class_scores, bbox_preds)`` computed from shared backbone features.

        Assumes ``x`` is an image batch ``(B, 3, H, W)``; the outputs then have
        ``num_classes`` and 4 values per image respectively.
        """
        pooled = self.backbone(x)
        return self.classifier(pooled), self.bbox_head(pooled)

def custom_collate_fn(batch):
    """Collate ``(image, target)`` samples, discarding ``None`` entries.

    ``None`` samples (e.g. from missing image files) are dropped. Returns
    ``(images, targets)`` where ``images`` stacks the surviving image tensors
    along a new first dimension and ``targets`` is the list of their targets.
    If nothing survives, an empty tensor and an empty list are returned so the
    caller can skip the batch.
    """
    survivors = [sample for sample in batch if sample is not None]
    if len(survivors) == 0:
        return torch.tensor([]), []
    image_tensors = [sample[0] for sample in survivors]
    target_dicts = [sample[1] for sample in survivors]
    return torch.stack(image_tensors, dim=0), target_dicts

## Step 3: Training and Validation Functions
This is the core logic. `train_model` runs one epoch of training. `validate_model` reduces each image to a single ground-truth label for the classification report. That label is the first parasite-class object in the image. If the image contains only red blood cells, leukocytes or 'difficult' objects, the label of its first object is used instead.

In [None]:
def train_model(model, loader, optimizer, device, epoch, num_classes):
    """Train *model* for one epoch and return ``(epoch_loss, accuracy)``.

    Each image with at least one object contributes two terms:
    - a margin loss between its class scores and every object label;
    - ``0.1`` times a smooth-L1 loss between the predicted box and its first
      object's box.
    Per-image losses are summed over the batch, and batches without any
    labeled image are skipped.

    Args:
        model: Module returning ``(class_scores, bbox_preds)`` for a batch.
        loader: Yields ``(images, targets_list)`` batches; each target has
            ``'boxes'`` and ``'labels'`` tensors and optionally
            ``'orig_size'`` = ``(width, height)`` of the untransformed image.
        optimizer: Optimizer stepped once per batch that contains labels.
        device: Device to move images and targets to.
        epoch: Epoch number, used only for the progress-bar label.
        num_classes: Unused; kept for signature compatibility.

    Returns:
        ``epoch_loss``: mean loss over the batches that produced an optimizer
        step (``0`` if none did).
        ``accuracy``: percentage of annotated objects whose label equals the
        single argmax prediction of their image (``0`` if there were no objects).
    """
    model.train()
    running_loss, correct, total_objects = 0.0, 0, 0
    num_steps = 0
    pbar = tqdm(loader, desc=f"Training Epoch {epoch}")

    def margin_loss(class_scores, targets, margin=0.2):
        # Penalize the true class for scoring below `margin` and every other
        # class for scoring above `1 - margin`; averaged over target rows.
        one_hot_targets = F.one_hot(targets.long(), num_classes=class_scores.size(-1)).float()
        left = F.relu(margin - class_scores) * one_hot_targets
        right = F.relu(class_scores - (1 - margin)) * (1.0 - one_hot_targets)
        return (left + right).sum(dim=-1).mean()

    def bbox_loss(preds, targets):
        return F.smooth_l1_loss(preds, targets)

    for images, targets_list in pbar:
        if not images.numel():
            continue  # every sample in this batch was dropped by the collate_fn
        images = images.to(device)
        optimizer.zero_grad()
        class_scores, bbox_preds = model(images)

        batch_class_loss = 0
        batch_bbox_loss = 0
        batch_objects = 0
        # Fallback normalization scale when the dataset gives no original size.
        resized_size = images.shape[-1]

        for i, target in enumerate(targets_list):
            target_labels = target['labels'].to(device)
            if len(target_labels) == 0:
                continue
            target_boxes = target['boxes'].to(device)

            batch_class_loss += margin_loss(class_scores[i].unsqueeze(0), target_labels)

            # Boxes are in original-image pixels, so normalize by the original
            # (width, height) when available instead of the resized tensor size.
            if 'orig_size' in target:
                width, height = target['orig_size']
            else:
                width = height = resized_size
            scale = torch.tensor([width, height, width, height], dtype=target_boxes.dtype, device=device)
            # For bbox loss, we only compare against the first object's box for simplicity
            batch_bbox_loss += bbox_loss(bbox_preds[i], target_boxes[0] / scale)

            predicted = torch.argmax(class_scores[i])
            correct += (predicted == target_labels).sum().item()
            batch_objects += len(target_labels)

        if batch_objects == 0:
            # No labeled image in this batch: the losses are still plain ints
            # with no autograd graph, so there is nothing to backpropagate.
            continue

        loss = batch_class_loss + 0.1 * batch_bbox_loss
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        total_objects += batch_objects
        num_steps += 1
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    epoch_loss = running_loss / num_steps if num_steps > 0 else 0
    accuracy = 100 * correct / total_objects if total_objects > 0 else 0
    return epoch_loss, accuracy

def validate_model(model, loader, device, num_classes, category_map, return_preds=False):
    """Evaluate *model* with one prediction and one ground-truth label per image.

    The prediction is the argmax of the image's class scores. The ground truth
    is the first label among the image's objects that is not a non-parasite
    class ('red blood cell', 'leukocyte', 'difficult'). If the image holds only
    non-parasite objects, its first object's label is used instead. Images
    with no objects are skipped.

    Args:
        model: Module returning ``(class_scores, bbox_preds)`` for a batch.
        loader: Yields ``(images, targets_list)`` batches; each target has a
            ``'labels'`` tensor.
        device: Device to run the model on.
        num_classes: Unused; kept for signature compatibility.
        category_map: ``{category_name: class_index}`` mapping.
        return_preds: When True, return the label lists instead of accuracy.

    Returns:
        ``(true_labels, predicted_labels)`` lists when ``return_preds`` is
        True. Otherwise, accuracy in percent, or ``0.0`` when no labeled
        image was seen.
    """
    model.eval()
    all_labels_for_report = []
    all_preds_for_report = []
    pbar = tqdm(loader, desc="Validating")

    # Integer labels of the non-parasite classes present in this category map.
    non_parasite_labels = {category_map[cat] for cat in ['red blood cell', 'leukocyte', 'difficult'] if cat in category_map}

    with torch.no_grad():
        for images, targets_list in pbar:
            if not images.numel():
                continue
            images = images.to(device)
            class_scores, _ = model(images)
            # One argmax + host transfer per batch instead of per image.
            batch_preds = torch.argmax(class_scores, dim=1).tolist()

            for target, predicted_idx in zip(targets_list, batch_preds):
                target_labels = target['labels'].tolist()
                if not target_labels:
                    continue

                all_preds_for_report.append(predicted_idx)
                # Prefer the first parasite label; fall back to the first object.
                true_label_for_report = next(
                    (lbl for lbl in target_labels if lbl not in non_parasite_labels),
                    target_labels[0],
                )
                all_labels_for_report.append(true_label_for_report)

    if return_preds:
        return all_labels_for_report, all_preds_for_report
    if not all_labels_for_report:
        # np.mean of an empty array is NaN (plus a warning), and NaN never
        # compares greater than a best score, so report 0.0 instead.
        return 0.0
    # Return overall accuracy if not doing a detailed report
    return 100 * np.mean(np.array(all_labels_for_report) == np.array(all_preds_for_report))

## Step 4: Main Training Pipeline
This is where we put everything together. We define the hyperparameters, build the datasets and dataloaders, and run the training loop for a fixed number of epochs. Whenever the validation accuracy improves, the current model weights are saved as the best checkpoint.

In [None]:
# --- 1. Define Hyperparameters for our single run ---
params = {
    'lr': 0.001,
    'optimizer': 'Adam',
    'batch_size': 32,
    'image_size': 224,
    'sampling': 'oversample' # Use oversampling to help with imbalance
}
NUM_EPOCHS = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# --- 2. Create Datasets and Category Map ---
transform = transforms.Compose([
    transforms.Resize((params['image_size'], params['image_size'])),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_ds = MalariaDataset(train_json_path, image_path, transform=transform)
category_map = train_ds.category_map
num_classes = len(category_map)
# Reuse the training category map so class indices agree across splits.
test_ds = MalariaDataset(test_json_path, image_path, transform=transform, category_map=category_map)

print(f"Found {num_classes} classes: {category_map}")

# --- 3. Create DataLoaders (with optional oversampling) ---
sampler = None
if params['sampling'] == 'oversample':
    print("Applying weighted random oversampling...")
    # train_ds.labels holds one class index per image, or -1 for images
    # without objects. np.bincount rejects negative values, so count only
    # labeled images. Unlabeled ones get weight 0 and are never drawn
    # (train_model skips them anyway).
    image_labels = np.asarray(train_ds.labels, dtype=np.int64)
    is_labeled = image_labels >= 0
    class_counts = np.bincount(image_labels[is_labeled], minlength=num_classes)
    # Classes that never occur get weight 0 rather than a divide-by-zero inf.
    class_weights = np.zeros(len(class_counts), dtype=np.float64)
    np.divide(1.0, class_counts, out=class_weights, where=class_counts > 0)
    sample_weights = np.zeros(len(image_labels), dtype=np.float64)
    sample_weights[is_labeled] = class_weights[image_labels[is_labeled]]
    if sample_weights.sum() > 0:
        sampler = WeightedRandomSampler(torch.from_numpy(sample_weights).double(), len(sample_weights))
    else:
        print("No labeled training images found; falling back to plain shuffling.")

train_loader = DataLoader(
    train_ds, 
    batch_size=params['batch_size'], 
    sampler=sampler,
    collate_fn=custom_collate_fn,
    # Set shuffle=False when using a sampler
    shuffle=sampler is None 
)

test_loader = DataLoader(
    test_ds, 
    batch_size=params['batch_size'], 
    shuffle=False, 
    collate_fn=custom_collate_fn
)

# --- 4. Initialize Model and Optimizer ---
model = EfficientNetDetector(num_classes=num_classes).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])

# --- 5. Run Training Loop ---
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint; with 0.0, a run stuck at 0% would leave nothing for Step 5 to load.
best_val_accuracy = float('-inf')
history = {'train_loss': [], 'train_accuracy': [], 'val_accuracy': []}
save_path = os.path.join(models_dir, 'best_model.pth')

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc = train_model(model, train_loader, optimizer, DEVICE, epoch, num_classes)
    val_acc = validate_model(model, test_loader, DEVICE, num_classes, category_map)
    
    history['train_loss'].append(train_loss)
    history['train_accuracy'].append(train_acc)
    history['val_accuracy'].append(val_acc)
    
    print(f"Epoch {epoch} Summary: Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")
    
    if val_acc > best_val_accuracy:
        best_val_accuracy = val_acc
        torch.save(model.state_dict(), save_path)
        print(f"🎉 New best model saved to {save_path} with accuracy: {best_val_accuracy:.2f}%")

print("\n--- Training Complete ---")

## Step 5: Final Evaluation and Analysis
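Now that the model is trained, we load the best-performing checkpoint (the one saved at the epoch with the highest validation accuracy) and run a detailed evaluation on the test set. Note that this same test set was used to pick that checkpoint, so these numbers are somewhat optimistic rather than a fully unbiased estimate. An unbiased assessment would require a separate validation split for model selection.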

In [None]:
print("--- Starting Final Evaluation on Test Set ---")

# --- 1. Load the Best Model State ---
# Rebuild the architecture, then restore the weights written by the training loop.
best_model_path = os.path.join(models_dir, 'best_model.pth')
eval_model = EfficientNetDetector(num_classes=num_classes).to(DEVICE)
eval_model.load_state_dict(
    torch.load(best_model_path, map_location=DEVICE)
)
print(f"Loaded best model from {best_model_path}")

# --- 2. Get Final Predictions ---
# return_preds=True yields parallel lists of image-level true/predicted indices.
y_true, y_pred = validate_model(
    eval_model, test_loader, DEVICE, num_classes, category_map, return_preds=True
)

# --- 3. Generate Reports ---
# Order class names by class index so position k describes label k.
class_names = sorted(category_map, key=category_map.get)
labels_for_report = list(range(len(class_names)))

# Classification Report
print("\n--- Classification Report ---")
report = classification_report(
    y_true, y_pred,
    labels=labels_for_report,
    target_names=class_names,
    zero_division=0,
)
print(report)

# Confusion Matrix (rows = actual, columns = predicted)
print("\n--- Confusion Matrix ---")
cm = confusion_matrix(y_true, y_pred, labels=labels_for_report)
plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.title('Confusion Matrix', fontsize=16)
plt.ylabel('Actual Label')
plt.xlabel('Predicted Label')
plt.show()