**Data layout (TIFF supported):**
```
./Data/
  Train/
    Normal/      *.tif, *.tiff, *.png, *.jpg
    Abnormal/    *.tif, *.tiff, *.png, *.jpg
  Validation/
    Normal/      ...
    Abnormal/    ...
  Test/
    Normal/      ...   (used for pretrained weight evaluation)
    Abnormal/    ...
```
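Before training, it can help to confirm that the folders above exist and contain images. The following is a minimal sketch using only the standard library; the helper name `count_images` and the extension set are illustrative assumptions, not part of the notebook.

```python
from pathlib import Path

# Extensions listed in the layout above (compared case-insensitively).
IMAGE_EXTS = {".tif", ".tiff", ".png", ".jpg"}
SPLITS = ("Train", "Validation", "Test")
CLASSES = ("Normal", "Abnormal")

def count_images(root="./Data"):
    """Hypothetical helper: count image files per (split, class) folder."""
    counts = {}
    for split in SPLITS:
        for cls in CLASSES:
            folder = Path(root) / split / cls
            if not folder.is_dir():
                counts[(split, cls)] = None  # folder is missing
                continue
            counts[(split, cls)] = sum(
                1 for p in folder.iterdir()
                if p.is_file() and p.suffix.lower() in IMAGE_EXTS
            )
    return counts

# Print a short report for the default ./Data root.
for (split, cls), n in count_images().items():
    print(f"{split}/{cls}: {'MISSING' if n is None else n}")
```

A `MISSING` entry or a count of zero usually means the folder name differs in case (for example `train` instead of `Train`).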

> **Colab/GitHub-ready:** This notebook uses relative paths (`./Data`, `./checkpoints`) so it runs without path changes when opened from GitHub in Colab.  
> **Training:** runs for **5 epochs** and saves to `./checkpoints/best_train.pth`.  
> **Evaluation:** uses your pretrained file `./checkpoints/epoch=75-val_loss=0.69-val_acc=0.77.ckpt`.

# 🧪 Sperm Classification (Normal vs Abnormal) — DenseNet-169 (Colab Notebook)

**Goal:** a quick *trial* training run (5 epochs) using **DenseNet‑169**. This notebook keeps key hyperparameters consistent with your setup: **image size 800×800** and normalization `mean=0.2636`, `std=0.1562`. It also provides a simple evaluation section that loads **your pretrained weights** and runs them on test images.

> Tip: If you store data/weights in Google Drive, use the Drive mount cell below.

## 1) Runtime & Drive (optional)
If your images/checkpoints live in Drive, mount it first.

In [None]:
# (Optional) Mount Google Drive if needed
# from google.colab import drive
# drive.mount('/content/drive')
#
# After mounting, you can refer to files like:
# data_dir = '/content/drive/MyDrive/your_dataset_root'
# weights_path = '/content/drive/MyDrive/weights/densenet169_sperm.pth'


## 2) Setup
Colab usually has recent PyTorch/torchvision. If you need specific versions, uncomment the cell below.

In [None]:
# If you need to pin specific versions, uncomment and set the versions you prefer
# !pip install --quiet torch torchvision torchmetrics==1.4.0


## 3) Configuration
Set your dataset paths, batch size, and training hyperparameters. We use **800×800** images and the grayscale normalization stats you provided. We convert grayscale to **3 channels** inside the transform (to feed DenseNet).

In [None]:
from dataclasses import dataclass

@dataclass
class CFG:
    # Paths
    data_dir: str = "./Data"  # root containing 'Train', 'Validation' and 'Test' (ImageFolder style)
    # Example structure:
    # ./Data/
    #   Train/
    #     Normal/...
    #     Abnormal/...
    #   Validation/
    #     Normal/...
    #     Abnormal/...
    #   Test/
    #     Normal/...
    #     Abnormal/...

    # Training
    epochs: int = 5  # trial run as requested
    batch_size: int = 8
    num_workers: int = 2
    lr: float = 1e-4
    weight_decay: float = 0.0
    seed: int = 42

    # Image/Transforms — kept consistent with your code
    img_size: int = 800
    mean: float = 0.2636
    std: float = 0.1562

    # Checkpoints
    save_dir: str = "./checkpoints"
    ckpt_name: str = "epoch=75-val_loss=0.69-val_acc=0.77.ckpt"  # pretrained weights, used for evaluation
    train_ckpt_name: str = "best_train.pth"  # best checkpoint written by the training loop

cfg = CFG()
print(cfg)

# Datasets and dataloaders are built in section 5, after the imports and transforms are defined.


## 4) Imports & Reproducibility

In [None]:
import os, random, math, time
from pathlib import Path

from PIL import Image

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models

import numpy as np
import matplotlib.pyplot as plt

# Reproducibility
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(cfg.seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

## 5) Datasets & Dataloaders
We keep your normalization and grayscale handling. Each image is converted to 3 identical channels so DenseNet‑169 can ingest it, and the **statistics remain the same**: the single-channel mean and std are simply repeated across the three channels.
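As a small worked illustration of that statement, the sketch below applies the per-channel arithmetic `(x - mean) / std` to one replicated gray pixel. The example intensity `0.5` and the helper name `normalize` are assumptions chosen for illustration.

```python
MEAN, STD = 0.2636, 0.1562  # grayscale statistics used throughout this notebook

def normalize(pixel, mean=MEAN, std=STD):
    """Per-channel normalization arithmetic: (x - mean) / std."""
    return (pixel - mean) / std

gray = 0.5                  # example intensity in [0, 1], as produced by ToTensor
rgb = [gray, gray, gray]    # a 3-channel grayscale image repeats the same value
normalized = [normalize(v) for v in rgb]

print(normalized)           # three identical numbers
```

Because every channel holds the same value and uses the same mean and std, the three normalized channels stay identical.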

In [None]:
# Transforms — training includes flips; validation only resize+normalize.
# We convert to 3 channels (RGB) so DenseNet works, but keep grayscale stats replicated across channels.

IM_SIZE = cfg.img_size
MEAN = [cfg.mean, cfg.mean, cfg.mean]
STD = [cfg.std, cfg.std, cfg.std]

train_tfms = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((IM_SIZE, IM_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

val_tfms = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((IM_SIZE, IM_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

# Expect ImageFolder layout with /Train, /Validation, /Test under cfg.data_dir
train_dir = os.path.join(cfg.data_dir, "Train")
val_dir = os.path.join(cfg.data_dir, "Validation")
test_dir = os.path.join(cfg.data_dir, "Test")

train_ds = datasets.ImageFolder(train_dir, transform=train_tfms)
val_ds = datasets.ImageFolder(val_dir, transform=val_tfms)
test_ds = datasets.ImageFolder(test_dir, transform=val_tfms)

class_names = train_ds.classes
num_classes = len(class_names)
print(f"Classes: {class_names} -> num_classes={num_classes}")

train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,  num_workers=cfg.num_workers, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=True)
len(train_ds), len(val_ds), len(test_ds)

## 6) Model — DenseNet‑169 (pretrained)
We replace the classifier with a 2‑unit linear layer for binary classification. Loss: `CrossEntropyLoss`. Optimizer: `Adam` (lr=1e-4 by default).

In [None]:
# Build DenseNet-169
dnet = models.densenet169(weights=models.DenseNet169_Weights.DEFAULT)

# Replace classifier for our task
in_features = dnet.classifier.in_features
dnet.classifier = nn.Linear(in_features, 2)

dnet = dnet.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(dnet.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

## 7) Training Loop (5 epochs)
A minimal, well‑commented loop with val accuracy each epoch. Checkpoints are saved to `cfg.save_dir`.

In [None]:
from typing import Tuple

os.makedirs(cfg.save_dir, exist_ok=True)
best_val_acc = 0.0
history = {"train_loss": [], "val_loss": [], "val_acc": []}

def run_epoch(model, loader, train: bool) -> Tuple[float, float]:
    """Run one pass over `loader` and return (average loss, accuracy)."""
    epoch_loss = 0.0
    total = 0
    correct = 0

    if train:
        model.train()
    else:
        model.eval()

    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        if train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(train):
            outputs = model(images)           # logits
            loss = criterion(outputs, labels) # CE loss

            if train:
                loss.backward()
                optimizer.step()

        epoch_loss += loss.item() * images.size(0)

        # accuracy on the fly
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = epoch_loss / max(1, total)
    acc = correct / max(1, total)
    return avg_loss, acc

for epoch in range(1, cfg.epochs + 1):
    t0 = time.time()
    train_loss, _ = run_epoch(dnet, train_loader, train=True)
    val_loss, val_acc = run_epoch(dnet, val_loader, train=False)

    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_path = os.path.join(cfg.save_dir, cfg.train_ckpt_name)
        torch.save({"model_state": dnet.state_dict(),
                    "class_names": class_names,
                    "cfg": cfg.__dict__}, best_path)

    t1 = time.time()
    print(f"Epoch {epoch:02d}/{cfg.epochs} | "
          f"train_loss: {train_loss:.4f} | val_loss: {val_loss:.4f} | val_acc: {val_acc:.4f} | "
          f"time: {t1 - t0:.1f}s")

print(f"Best val_acc: {best_val_acc:.4f} | Saved to: {os.path.join(cfg.save_dir, cfg.train_ckpt_name)}")

## 8) Plot Training Curves

In [None]:
# Plot loss and val accuracy
plt.figure()
plt.plot(history["train_loss"], label="train_loss")
plt.plot(history["val_loss"], label="val_loss")
plt.xlabel("epoch"); plt.ylabel("loss"); plt.legend(); plt.title("Loss")

plt.figure()
plt.plot(history["val_acc"], label="val_acc")
plt.xlabel("epoch"); plt.ylabel("accuracy"); plt.legend(); plt.title("Validation Accuracy")

plt.show()

## 9) Evaluation — Load Pretrained Weights and Test
Point `weights_path` to a checkpoint: either the `.pth` file saved by this notebook or **your own weights** (the default is the `.ckpt` file named in `cfg.ckpt_name`). The cells below cover a **single‑image test** and a **folder‑based batch test** with a confusion matrix.

In [None]:
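# === Evaluate PRETRAINED weights on TEST ===
from sklearn.metrics import classification_report

weights_path = os.path.join(cfg.save_dir, cfg.ckpt_name)
model_inf, class_names_inf = load_model_for_inference(weights_path, num_classes=len(class_names))
model_inf.eval()  # make sure batch-norm layers run in inference mode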

y_true, y_pred = [], []
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        logits = model_inf(images)
        preds = logits.argmax(dim=1).cpu().numpy().tolist()
        y_pred.extend(preds)
        y_true.extend(labels.numpy().tolist())

print("Classes:", class_names_inf)
print(classification_report(y_true, y_pred, target_names=test_ds.classes, digits=4))

### 9.a) Example: load weights

In [None]:
# Evaluation will use the PRETRAINED checkpoint (do not overwrite).
weights_path = os.path.join(cfg.save_dir, cfg.ckpt_name)
model_inf, class_names_inf = load_model_for_inference(weights_path)
print("Loaded inference model; classes:", class_names_inf)


### 9.b) Single‑image prediction

In [None]:
# Set an example image
# img_path = "./Data/Validation/Normal/example_001.png"
# result = predict_single_image(img_path, model_inf, class_names_inf)
# print(result)

### 9.c) Batch evaluation on a folder (e.g., your validation set)

In [None]:
# data_root = "./Data/Validation"
# cm, report, classes = evaluate_folder(data_root, model_inf, val_tfms)
# print("Classes:", classes)
# print("Classification report:\n", report)

# # Plot confusion matrix (no seaborn, simple matplotlib)
# import itertools
# plt.figure()
# plt.imshow(cm, interpolation='nearest')
# plt.title("Confusion Matrix")
# plt.colorbar()
# tick_marks = np.arange(len(classes))
# plt.xticks(tick_marks, classes, rotation=45)
# plt.yticks(tick_marks, classes)

# thresh = cm.max() / 2.
# for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
#     plt.text(j, i, format(cm[i, j], 'd'),
#              horizontalalignment="center",
#              color="white" if cm[i, j] > thresh else "black")

# plt.ylabel('True label')
# plt.xlabel('Predicted label')
# plt.tight_layout()
# plt.show()

## 10) Notes
- Image size kept at **800×800** to match your code.
- Normalization kept at **mean=0.2636**, **std=0.1562** (replicated across 3 channels under the hood).
- Random H/V flips used in training, none in validation — consistent with your approach.
- You can easily swap the dataset paths or use Drive.
- For a quick run, keep 5 epochs; adjust later as needed.
- If your labels are reversed (0/1 mapping), just read `class_names` from `train_ds.classes` and interpret accordingly.
- The evaluation helpers mimic your testing flow: softmax probs, per‑image prediction, and a batch confusion matrix.
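
On the label-mapping note above: `ImageFolder` sorts the class folder names and numbers them from 0, so with the folders used here `Abnormal` becomes 0 and `Normal` becomes 1. The sketch below reproduces that rule in plain Python; `class_to_idx` here is an illustrative stand-in, not the torchvision attribute itself.

```python
def class_to_idx(class_folders):
    """Mimic ImageFolder's rule: sort folder names, then number them from 0."""
    return {name: i for i, name in enumerate(sorted(class_folders))}

mapping = class_to_idx(["Normal", "Abnormal"])
print(mapping)  # {'Abnormal': 0, 'Normal': 1}
```

In the notebook itself, `train_ds.class_to_idx` gives the actual mapping for your data.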