Imports

In [7]:
from pathlib import Path
import json
import csv
import re

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as T
from PIL import Image


Configuration

In [8]:
PROCESSED_ROOT = Path("../../data/processed/Stage2")
INDEX_CSV      = PROCESSED_ROOT / "index.csv"

TRAIN_LABELS_JSON = Path("../../data/labels/Stage2/train.json")
VAL_LABELS_JSON   = Path("../../data/labels/Stage2/val.json")

NUM_CLASSES = 2
BATCH_SIZE  = 60
LR          = 1e-3
EPOCHS      = 16
IMAGE_SIZE  = 256

# Choose "rgb" (3-channel) or "gray" (1-channel)
IMAGE_MODE = "rgb"   # change to "gray" if you want grayscale training

MODEL_OUT_PATH = Path("../../models/classifier/Stage2/model.pt")
MODEL_OUT_PATH.parent.mkdir(parents=True, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


Load index.csv

In [9]:
def read_index_csv(index_csv_path: Path):
    rows = []
    with open(index_csv_path, "r", newline="") as f:
        reader = csv.DictReader(f)
        for r in reader:
            rows.append(r)
    if not rows:
        raise ValueError(f"index.csv is empty: {index_csv_path}")
    if "filepath" not in rows[0]:
        raise ValueError("index.csv must contain a 'filepath' column")
    return rows

index_rows = read_index_csv(INDEX_CSV)
print("Rows in index.csv:", len(index_rows))
print("Example row keys:", list(index_rows[0].keys()))
print("Example filepath:", index_rows[0]["filepath"])


Rows in index.csv: 1100
Example row keys: ['filepath', 'split', 'stage', 'scan_session_id', 'filename', 'scan_timestamp', 'source_interim_path']
Example filepath: images/train/2026-01-21_15-51-50-237_aug01.png


Load label JSONs (train/val)

This expects a simple dict mapping:

{
  "images/train/xxx.png": 0,
  "images/train/yyy.png": 1
}

In [10]:
def load_label_map(label_path: Path):
    if not label_path.exists():
        raise FileNotFoundError(f"Label file not found: {label_path}")
    with open(label_path, "r") as f:
        data = json.load(f)
    if isinstance(data, dict):
        # Original format: {filepath: class_id}
        out = {}
        for k, v in data.items():
            out[k.replace("\\", "/")] = int(v)
        return out
    elif isinstance(data, list):
        # Custom format: list of {"image": filename, "no_contraband": 0/1, "isolated_items": 0/1}
        out = {}
        for item in data:
            filepath = item['image']
            # Assume class 0 for no_contraband=1, class 1 otherwise
            class_id = 0 if item.get('no_contraband', 0) == 1 else 1
            out[filepath] = class_id
        return out
    else:
        raise ValueError("Label JSON must be a dict {filepath: class_id} or Label Studio list format")

train_label_map = load_label_map(TRAIN_LABELS_JSON)
val_label_map   = load_label_map(VAL_LABELS_JSON)

print("Train labels:", len(train_label_map))
print("Val labels:", len(val_label_map))

from collections import Counter
print("Train label distribution:", Counter(train_label_map.values()))
print("Val label distribution:", Counter(val_label_map.values()))

# Show a couple samples
for i, (k, v) in enumerate(train_label_map.items()):
    print("Example train label:", k, "->", v)
    if i >= 2:
        break


Train labels: 935
Val labels: 110
Train label distribution: Counter({0: 935})
Val label distribution: Counter({0: 110})
Example train label: 2026-01-21_15-43-51-512_aug00.png -> 0
Example train label: 2026-01-21_15-43-51-512_aug01.png -> 0
Example train label: 2026-01-21_15-43-51-512_aug02.png -> 0


Define Dataset (respects split, no leakage)

In [11]:
class ProcessedSplitDataset(Dataset):
    def __init__(self, index_rows, processed_root: Path, split: str, label_map: dict, transform=None):
        self.processed_root = processed_root
        self.split = split
        self.transform = transform
        self.label_map = label_map

        # Filter filepaths by split
        filepaths = []
        for r in index_rows:
            fp = r["filepath"].replace("\\", "/")
            row_split = (r.get("split") or "").strip().lower()

            if row_split:
                if row_split == split:
                    filepaths.append(fp)
            else:
                # fallback if index.csv doesn't have split column
                if f"images/{split}/" in fp:
                    filepaths.append(fp)

        if not filepaths:
            raise ValueError(f"No samples found for split='{split}'")

        # Keep only those that have labels (strict)
        self.filepaths = []
        self.labels = []
        for fp in filepaths:
            filename = Path(fp).name
            label = self.label_map.get(filename, 0)  # default to 0 if missing
            self.filepaths.append(fp)
            self.labels.append(label)

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx):
        rel_path = self.filepaths[idx]
        img_path = self.processed_root / rel_path

        img = Image.open(img_path)

        # Force consistent mode
        if IMAGE_MODE == "gray":
            img = img.convert("L")
        else:
            img = img.convert("RGB")

        if self.transform:
            img = self.transform(img)

        y = torch.tensor(self.labels[idx], dtype=torch.long)
        return img, y

Transforms + DataLoaders

In [12]:
if IMAGE_MODE == "gray":
    transform = T.Compose([
        T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        T.ToTensor(),  # -> [1, H, W]
    ])
    in_channels = 1
else:
    transform = T.Compose([
        T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        T.ToTensor(),  # -> [3, H, W]
    ])
    in_channels = 3

train_ds = ProcessedSplitDataset(index_rows, PROCESSED_ROOT, "train", train_label_map, transform=transform)
val_ds   = ProcessedSplitDataset(index_rows, PROCESSED_ROOT, "val",   val_label_map,   transform=transform)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print("Train samples:", len(train_ds))
print("Val samples:", len(val_ds))

# Quick check tensor shapes
x, y = next(iter(train_loader))
print("Batch image shape:", x.shape, "| Batch label shape:", y.shape)


Train samples: 935
Val samples: 110


Batch image shape: torch.Size([60, 3, 256, 256]) | Batch label shape: torch.Size([60])


Train with the previous model from previous stage

In [13]:
class SimpleCNN(nn.Module):
    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # /2
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),          # /4
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),          # /8
        )
        # IMAGE_SIZE is global; compute flattened size dynamically
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, IMAGE_SIZE, IMAGE_SIZE)
            out = self.features(dummy)
            flat_dim = out.view(1, -1).shape[1]

        self.classifier = nn.Sequential(
            nn.Linear(flat_dim, 128), nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

model = SimpleCNN(in_channels=in_channels, num_classes=NUM_CLASSES).to(device)
model.load_state_dict(torch.load("../../models/classifier/Stage1/model.pt")) #IMPORTANT, LOAD THE MODEL FROM PREVIOUS STAGE
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

print(model)


SimpleCNN(
  (features): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=65536, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=128, out_features=2, bias=True)
  )
)


Training Loop (train + val)

In [14]:
for epoch in range(1, EPOCHS + 1):
    # ---- Train ----
    model.train()
    train_loss_sum, train_correct, train_total = 0.0, 0, 0

    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        train_loss_sum += loss.item() * imgs.size(0)
        preds = logits.argmax(dim=1)
        train_correct += (preds == labels).sum().item()
        train_total += labels.size(0)

    train_loss = train_loss_sum / max(1, train_total)
    train_acc  = train_correct / max(1, train_total)

    # ---- Val ----
    model.eval()
    val_loss_sum, val_correct, val_total = 0.0, 0, 0

    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)

            logits = model(imgs)
            loss = criterion(logits, labels)

            val_loss_sum += loss.item() * imgs.size(0)
            preds = logits.argmax(dim=1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    val_loss = val_loss_sum / max(1, val_total)
    val_acc  = val_correct / max(1, val_total)

    print(
        f"Epoch [{epoch}/{EPOCHS}] "
        f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.3f} "
        f"| Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.3f}"
    )


Epoch [1/16] Train Loss: 0.0000 | Train Acc: 1.000 | Val Loss: 0.0000 | Val Acc: 1.000
Epoch [2/16] Train Loss: 0.0000 | Train Acc: 1.000 | Val Loss: 0.0000 | Val Acc: 1.000
Epoch [3/16] Train Loss: 0.0000 | Train Acc: 1.000 | Val Loss: 0.0000 | Val Acc: 1.000
Epoch [4/16] Train Loss: 0.0000 | Train Acc: 1.000 | Val Loss: 0.0000 | Val Acc: 1.000
Epoch [5/16] Train Loss: 0.0000 | Train Acc: 1.000 | Val Loss: 0.0000 | Val Acc: 1.000
Epoch [6/16] Train Loss: 0.0000 | Train Acc: 1.000 | Val Loss: 0.0000 | Val Acc: 1.000
Epoch [7/16] Train Loss: 0.0000 | Train Acc: 1.000 | Val Loss: 0.0000 | Val Acc: 1.000
Epoch [8/16] Train Loss: 0.0000 | Train Acc: 1.000 | Val Loss: 0.0000 | Val Acc: 1.000
Epoch [9/16] Train Loss: 0.0000 | Train Acc: 1.000 | Val Loss: 0.0000 | Val Acc: 1.000
Epoch [10/16] Train Loss: 0.0000 | Train Acc: 1.000 | Val Loss: 0.0000 | Val Acc: 1.000
Epoch [11/16] Train Loss: 0.0000 | Train Acc: 1.000 | Val Loss: 0.0000 | Val Acc: 1.000
Epoch [12/16] Train Loss: 0.0000 | Train 

Save Model

In [15]:
torch.save(model.state_dict(), MODEL_OUT_PATH)
print("✅ Saved model to:", MODEL_OUT_PATH)


✅ Saved model to: ../../models/classifier/Stage2/model.pt
