In [None]:
import os
import json
import math
import random
from pathlib import Path
from typing import Dict, List, Sequence, Tuple

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
from tqdm.auto import tqdm

from huggingface_hub import hf_hub_download


In [None]:
from torchvision import models

# Outline of the model setup: build a DenseNet-121 with a 10-output classifier,
# then replace that classifier with one sized to the project's class list.
# NOTE(review): `in_features` and `CLASS_NAMES` are only assigned in later
# cells, so running this cell top-to-bottom on a fresh kernel raises NameError
# on its last line. The checkpoint-loading cell further down repeats these
# statements in full — confirm whether this cell should stay.
model = models.densenet121(num_classes=10)
...  # bare Ellipsis expression (a no-op); presumably marks steps left out of this outline
model.classifier = nn.Linear(in_features, len(CLASS_NAMES))


In [None]:
from pathlib import Path

# Filesystem layout: every path hangs off the project checkout, which is
# resolved to an absolute path once.
PROJECT_ROOT = Path("../Marine_Learning").resolve()
IMAGE_DIR = PROJECT_ROOT.joinpath("data", "data", "classification_dataset", "images")
MANIFEST_PATH = PROJECT_ROOT.joinpath("data", "data", "classification_dataset", "labels.txt")
SPLIT_PATH = PROJECT_ROOT.joinpath("models", "splits.json")

# Label vocabulary: a class's position in CLASS_NAMES is its integer index, and
# CLASS_TO_INDEX maps the lower-cased name back to that position.
CLASS_NAMES = ["Scallop", "Roundfish", "Crab", "Whelk", "Skate", "Flatfish", "Eel"]
CLASS_TO_INDEX = dict(zip((name.lower() for name in CLASS_NAMES), range(len(CLASS_NAMES))))
# Tunables shared by the split, transform and loader cells.
SEED = 415
IMAGE_SIZE = 224
BATCH_SIZE = 16
TRAIN_RATIO = 0.8  # per-class share of samples assigned to train
VAL_RATIO = 0.1    # per-class share assigned to val; what is left goes to test


In [None]:
import json
import random
from typing import Dict, List, Sequence, Tuple
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

def _load_manifest(manifest_path: Path, image_dir: Path) -> List[Tuple[str, int]]:
    """Parse the label manifest into ``(image_path, class_index)`` records.

    Every non-blank line is split on whitespace. The first token is the image
    filename, joined onto ``image_dir``; the remaining tokens are joined with a
    single space, lower-cased, and looked up in ``CLASS_TO_INDEX``.

    Args:
        manifest_path: Text file with one ``<filename> <label...>`` entry per line.
        image_dir: Directory the manifest filenames are relative to.

    Returns:
        One ``(str(image_path), class_index)`` tuple per manifest entry, in
        file order.

    Raises:
        FileNotFoundError: If the manifest or any referenced image is missing.
        KeyError: If a line has no label, or a label that is not a key of
            ``CLASS_TO_INDEX``. The message names the manifest line.
        ValueError: If the manifest has no non-blank lines.
    """
    if not manifest_path.exists():
        raise FileNotFoundError(f"Manifest missing at {manifest_path}")
    records: List[Tuple[str, int]] = []
    with manifest_path.open("r", encoding="utf-8") as handle:
        # Line numbers are tracked only so errors can point at the bad entry.
        for line_no, line in enumerate(handle, start=1):
            stripped = line.strip()
            if not stripped:
                continue
            filename, *label_parts = stripped.split()
            label_key = " ".join(label_parts).lower()
            if label_key not in CLASS_TO_INDEX:
                # Same exception type as the bare dict lookup would raise, but
                # with enough context to locate and fix the entry.
                problem = (
                    f"unknown label {label_key!r}"
                    if label_key
                    else f"no label after filename {filename!r}"
                )
                raise KeyError(
                    f"{manifest_path}, line {line_no}: {problem}; "
                    f"expected one of {sorted(CLASS_TO_INDEX)}"
                )
            label_idx = CLASS_TO_INDEX[label_key]
            image_path = image_dir / filename
            if not image_path.exists():
                raise FileNotFoundError(f"Missing image: {image_path}")
            records.append((str(image_path), label_idx))
    if not records:
        raise ValueError("Manifest contained no usable records.")
    return records

def _stratified_split(records: Sequence[Tuple[str, int]], train_ratio: float, val_ratio: float, seed: int):
    grouped: Dict[int, List[Tuple[str, int]]] = {}
    for path, label in records:
        grouped.setdefault(label, []).append((path, label))

    rng = random.Random(seed)
    splits = {"train": ([], []), "val": ([], []), "test": ([], [])}

    for label, items in grouped.items():
        rng.shuffle(items)
        total = len(items)
        train_count = max(1, round(total * train_ratio))
        val_count = max(1, round(total * val_ratio))
        if train_count + val_count >= total:
            val_count = max(1, total - train_count - 1)
        train_items = items[:train_count]
        val_items = items[train_count: train_count + val_count]
        test_items = items[train_count + val_count:]

        for split_name, subset in (("train", train_items), ("val", val_items), ("test", test_items)):
            paths, labels = splits[split_name]
            for path, lbl in subset:
                paths.append(path)
                labels.append(lbl)

    return splits

def load_or_create_splits(split_path: Path):
    """Return the train/val/test split, reusing ``split_path`` when it exists.

    The manifest at ``MANIFEST_PATH`` is always parsed first. If ``split_path``
    exists, its JSON (split name -> list of image file names) is turned back
    into paths under ``IMAGE_DIR`` with labels taken from the manifest.
    Otherwise a new stratified split is computed with ``TRAIN_RATIO``,
    ``VAL_RATIO`` and ``SEED``, written to ``split_path`` as file names only,
    and returned.

    Only the final path component of each image is stored and matched, so
    image file names are expected to be unique within the manifest.

    Args:
        split_path: Location of the JSON split file to read or create.

    Returns:
        A dict mapping split name to ``(paths, labels)`` parallel lists.

    Raises:
        KeyError: If the saved split lists a file name that the manifest no
            longer contains.
    """
    records = _load_manifest(MANIFEST_PATH, IMAGE_DIR)
    if split_path.exists():
        with split_path.open("r", encoding="utf-8") as handle:
            saved = json.load(handle)
        filename_to_label = {Path(path).name: label for path, label in records}
        splits = {}
        for split_name, filenames in saved.items():
            paths = [str(IMAGE_DIR / name) for name in filenames]
            # A split file written against an older manifest would otherwise
            # fail with a bare KeyError on the first unknown file name.
            stale = [p for p in paths if Path(p).name not in filename_to_label]
            if stale:
                raise KeyError(
                    f"{split_path} split {split_name!r} lists {len(stale)} file(s) "
                    f"absent from {MANIFEST_PATH} (first: {Path(stale[0]).name!r}); "
                    "delete the split file to regenerate it"
                )
            labels = [filename_to_label[Path(p).name] for p in paths]
            splits[split_name] = (paths, labels)
        return splits

    splits = _stratified_split(records, TRAIN_RATIO, VAL_RATIO, SEED)
    # Persist file names rather than full paths so the file does not embed the
    # resolved location of IMAGE_DIR.
    serializable = {name: [Path(path).name for path in paths] for name, (paths, _) in splits.items()}
    split_path.parent.mkdir(parents=True, exist_ok=True)
    with split_path.open("w", encoding="utf-8") as handle:
        json.dump(serializable, handle, indent=2)
    return splits


In [None]:
class BenthicDataset(Dataset):
    """Map-style dataset that pairs image files with integer labels.

    Args:
        filepaths: Iterable of image paths; copied into a list.
        labels: Iterable of labels, one per path; copied into a list.
        transform: Optional callable applied to the RGB image before it is
            returned.

    Raises:
        ValueError: If ``filepaths`` and ``labels`` differ in length.
    """

    def __init__(self, filepaths, labels, transform=None):
        self.filepaths = list(filepaths)
        self.labels = list(labels)
        # Fail at construction time rather than with an IndexError part-way
        # through an epoch.
        if len(self.filepaths) != len(self.labels):
            raise ValueError(
                f"filepaths and labels must have the same length "
                f"(got {len(self.filepaths)} and {len(self.labels)})"
            )
        self.transform = transform

    def __len__(self):
        """Number of samples (one per file path)."""
        return len(self.filepaths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is opened from disk, converted to RGB, and passed through
        ``transform`` when one was given.
        """
        # convert() returns a separate image object, so the file can be closed
        # as soon as the conversion is done.
        with Image.open(self.filepaths[idx]) as source:
            image = source.convert("RGB")
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

def build_transforms(image_size=224, random_erasing_p=0.5):
    """Create the training and evaluation image pipelines.

    Args:
        image_size: Side length of the square crop both pipelines produce.
        random_erasing_p: Probability passed to ``RandomErasing`` in the
            training pipeline.

    Returns:
        ``(train_transform, eval_transform)``, both ``transforms.Compose``.
    """
    # One Normalize instance is shared by both pipelines.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # Training: resize to 1.2x, then random crop / flips / colour jitter.
    train_steps = [
        transforms.Resize(int(image_size * 1.2)),
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(
            brightness=0.15,
            contrast=0.15,
            saturation=0.15,
            hue=0.05,
        ),
    ]
    train_steps += [transforms.ToTensor(), normalize]
    # Erasing is the last step, applied after tensor conversion and normalisation.
    train_steps.append(
        transforms.RandomErasing(
            p=random_erasing_p,
            scale=(0.02, 0.25),
            ratio=(0.3, 3.3),
            value="random",
        )
    )

    # Evaluation: resize to 1.1x and centre-crop, with no random steps.
    eval_steps = [
        transforms.Resize(int(image_size * 1.1)),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ]

    return transforms.Compose(train_steps), transforms.Compose(eval_steps)

def build_dataloaders(splits, image_size, batch_size):
    """Wrap the three splits in ``DataLoader`` objects.

    Args:
        splits: Mapping with ``"train"``, ``"val"`` and ``"test"`` keys, each
            holding a ``(filepaths, labels)`` pair.
        image_size: Forwarded to ``build_transforms``.
        batch_size: Batch size used by all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the train loader
        shuffles; val and test use the evaluation transform.
    """
    augmenting_tfm, plain_tfm = build_transforms(image_size=image_size)

    # (split key, transform, shuffle) in the order the loaders are returned.
    plan = (
        ("train", augmenting_tfm, True),
        ("val", plain_tfm, False),
        ("test", plain_tfm, False),
    )

    # Pull every split out of the mapping before any dataset is constructed.
    contents = []
    for split_name, _, _ in plan:
        files, labels = splits[split_name]
        contents.append((files, labels))

    datasets = [
        BenthicDataset(files, labels, tfm)
        for (files, labels), (_, tfm, _) in zip(contents, plan)
    ]

    # Pinned host memory is requested only when CUDA is available.
    use_pinned = torch.cuda.is_available()
    return tuple(
        DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=batch_size,
            num_workers=0,
            pin_memory=use_pinned,
        )
        for dataset, (_, _, shuffle) in zip(datasets, plan)
    )


In [None]:
# Reuse the split saved at SPLIT_PATH when present, otherwise create and save
# one, then wrap the three parts in DataLoaders.
splits = load_or_create_splits(SPLIT_PATH)
train_loader, val_loader, test_loader = build_dataloaders(
    splits,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
)

# Report how many samples landed in each split.
print(
    f"Train samples: {len(train_loader.dataset)} | "
    f"Val samples: {len(val_loader.dataset)} | "
    f"Test samples: {len(test_loader.dataset)}"
)


In [None]:
import matplotlib.pyplot as plt

# Preview one training batch (after the training transform) in a grid.
batch_images, batch_labels = next(iter(train_loader))
# Normalize computes (x - mean) / std, so using mean = -m/s and std = 1/s
# yields x * s + m, which undoes the normalisation applied in build_transforms.
inv_norm = transforms.Normalize(
    mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
    std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
)

# Five columns, with as many rows as needed to fit the batch.
cols = 5
rows = math.ceil(batch_images.size(0) / cols)
plt.figure(figsize=(cols * 3, rows * 3))
for idx in range(min(batch_images.size(0), cols * rows)):
    # Undo normalisation, clamp to [0, 1], and move the first axis last
    # (presumably channels-first from ToTensor) before handing to imshow.
    img = inv_norm(batch_images[idx]).clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    label = CLASS_NAMES[batch_labels[idx]]
    ax = plt.subplot(rows, cols, idx + 1)
    ax.imshow(img)
    ax.set_title(label)
    ax.axis("off")
plt.tight_layout()
plt.show()


In [None]:
from torchvision import models

# Checkpoint to fetch from the Hugging Face Hub.
# NOTE(review): confirm this repo id and file name resolve on the Hub.
REPO_ID = "timm/densenet121.cifar10"
STATE_DICT_FILE = "pytorch_model.bin"

state_dict_path = hf_hub_download(repo_id=REPO_ID, filename=STATE_DICT_FILE)
# NOTE(review): torch.load unpickles the downloaded file. Consider passing
# weights_only=True if the installed torch version supports it — TODO confirm.
state_dict = torch.load(state_dict_path, map_location="cpu")

# Build the network with a 10-output classifier before loading — presumably to
# match the checkpoint's classifier shape; verify against the checkpoint.
model = models.densenet121(num_classes=10)
# strict=False reports key mismatches instead of raising, so check the two
# printed lists to see how much of the checkpoint was actually loaded.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

# Replace the classifier with a new linear layer sized to CLASS_NAMES:
# truncated-normal weights (std 0.02) and a zero bias.
in_features = model.classifier.in_features
model.classifier = nn.Linear(in_features, len(CLASS_NAMES))
nn.init.trunc_normal_(model.classifier.weight, mean=0.0, std=0.02)
if model.classifier.bias is not None:
    nn.init.zeros_(model.classifier.bias)

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


In [None]:
# Plot accuracy and loss curves from `history`, reading the keys "train_acc",
# "val_acc", "train_loss" and "val_loss".
# NOTE(review): `history` is not assigned in any cell visible here; presumably
# a training cell fills it — confirm that cell runs before this one.
plt.figure(figsize=(8, 4))
plt.plot(history["train_acc"], label="Train Acc")
plt.plot(history["val_acc"], label="Val Acc")
plt.legend(); plt.title("Accuracy vs Epoch")

# Second figure: the loss curves.
plt.figure(figsize=(8, 4))
plt.plot(history["train_loss"], label="Train Loss")
plt.plot(history["val_loss"], label="Val Loss")
plt.legend(); plt.title("Loss vs Epoch")


In [None]:
# Per-class metrics on the test set, returned as a nested dict
# (output_dict=True) with entries named after CLASS_NAMES.
# NOTE(review): `test_targets` and `test_preds` are not assigned in any cell
# visible here. `.numpy()` is called on both, so they are presumably CPU
# tensors from an evaluation cell — confirm.
targets = classification_report(
    test_targets.numpy(),
    test_preds.numpy(),
    target_names=CLASS_NAMES,
    output_dict=True,
)

# Print only the "Crab" and "Eel" entries of the report.
focus = {cls: targets[cls] for cls in ("Crab", "Eel")}
print(json.dumps(focus, indent=2))
