# ML Model Development Best Practices Template

This notebook acts as a **top-level run script** and living document for a new machine learning project.
All reusable code should live under `src/` (e.g. `src/utils/`), while this notebook orchestrates:

1. Environment and dependency management with **Poetry**
2. Generating a `requirements.txt` and importing dependencies
3. Creating **utility functions** in `src/utils` for loading data from XNAT
4. Formatting/organizing data into a **PyTorch `Dataset`/`DataLoader`**
5. Exploratory data analysis (EDA) and saving artifacts under `exploratory/`
6. Training a model with logging to `logs/`
7. Evaluating the model and saving metrics/plots under `results/`


## 1. Project setup and dependency management with Poetry

We use **Poetry** to manage dependencies and virtual environments in a reproducible way.

### 1.1. Initialize Poetry in the project root

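Run **one** of the following **once** at the project root (the same directory as this notebook):

```bash
# Initialize a new Poetry project interactively (answer the prompts or edit pyproject.toml afterwards)
poetry init

# Or non-interactively (-n), declaring every package the scaffolded src/utils modules import:
poetry init --name your_project_name -n \
  --dependency torch --dependency torchvision \
  --dependency pydicom --dependency xnat \
  --dependency pandas --dependency pillow \
  --dependency scikit-learn --dependency matplotlib
```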

This creates a `pyproject.toml` file that declares your project and dependencies.

### 1.2. Installing dependencies and activating the environment

After editing `pyproject.toml` as needed:

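```bash
# Install all dependencies into Poetry's managed virtualenv.
# --no-root skips installing the project itself as a package, since the code
# lives in src/utils rather than in a package named after the project.
poetry install --no-root

# Start a shell in the virtualenv (on Poetry 2.x this needs the poetry-plugin-shell
# plugin; `poetry run jupyter lab` works without activating a shell)
poetry shell
```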

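Inside the Poetry shell, start Jupyter so this notebook runs in the same environment (Jupyter itself must be installed there, e.g. with `poetry add --group dev jupyterlab`; in VS Code, select the Poetry virtualenv as the notebook kernel instead):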

```bash
jupyter lab
# or
jupyter notebook
```
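To confirm that the kernel really is running inside Poetry's virtualenv, a quick check like the one below can go in the first code cell. It is a minimal sketch that relies only on the standard-library behaviour that `sys.prefix` differs from `sys.base_prefix` inside a virtual environment; the helper name `running_in_virtualenv` is hypothetical.

```python
import sys


def running_in_virtualenv() -> bool:
    """Return True if this interpreter runs inside a venv/virtualenv.

    Inside a virtual environment (the kind Poetry creates), sys.prefix points at
    the environment directory while sys.base_prefix still points at the base
    Python installation.
    """
    return sys.prefix != getattr(sys, "base_prefix", sys.prefix)


print("Interpreter:", sys.executable)
print("Inside a virtualenv:", running_in_virtualenv())
```

If this prints `False`, the kernel is most likely a system or global Python; `sys.executable` should live under the directory printed by `poetry env info --path`.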


## 2. Exporting `requirements.txt` and core imports

Even though Poetry is the source of truth, it's often useful to generate a `requirements.txt`
for deployment or other tools.
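One use of the exported file is checking that the running environment still matches it. The sketch below assumes `requirements.txt` contains simple `name==version` pins (environment markers after `;` are stripped and ignored); `parse_pins` and `find_version_mismatches` are hypothetical helpers built on the standard `importlib.metadata` module. It is a plain string comparison, so local version tags such as `+cu121` on an installed wheel would be reported as mismatches.

```python
from importlib import metadata
from pathlib import Path
from typing import Dict, Optional, Tuple


def parse_pins(requirements_path: Path) -> Dict[str, str]:
    """Map package name -> version for every 'name==version' line."""
    pins: Dict[str, str] = {}
    for raw in requirements_path.read_text().splitlines():
        # Drop environment markers ("; python_version ...") and comments
        line = raw.split(";", 1)[0].split("#", 1)[0].strip()
        if "==" not in line:
            continue  # blank lines, options such as --index-url, unpinned specs
        name, version = line.split("==", 1)
        name = name.split("[", 1)[0].strip()  # drop extras, e.g. "pkg[extra]"
        pins[name] = version.strip()
    return pins


def find_version_mismatches(requirements_path: Path) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return {name: (pinned, installed or None)} for each pin not matched exactly."""
    mismatches: Dict[str, Tuple[str, Optional[str]]] = {}
    for name, pinned in parse_pins(requirements_path).items():
        try:
            installed: Optional[str] = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # pinned but not installed at all
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches


# Usage in the notebook (run from the project root):
# print(find_version_mismatches(Path("requirements.txt")))
```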


In [1]:
# 2.1 Export a requirements.txt from Poetry (run in a terminal, not in Python)
# -------------------------------------------------------------------------
# In your shell at the project root:
# poetry export -f requirements.txt --output requirements.txt --without-hashes
# (On Poetry 2.x the export command is provided by a plugin: poetry self add poetry-plugin-export)

# 2.2 Core imports for this notebook
# ----------------------------------
import os
from pathlib import Path
import logging
import sys

import torch
from torch.utils.data import DataLoader

# Local project imports (code will live under src/)
project_root = Path(".").resolve()
src_path = project_root / "src"
if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

# After we scaffold utils in later cells, these imports should work:
# from utils.xnat_io import get_xnat_connection, fetch_project_metadata
# from utils.datasets import XnatImageDataset, build_dataloaders
# from utils.training import train_model
# from utils.evaluation import evaluate_model


## 3. Utility functions for loading data from XNAT (`src/utils`)

All reusable I/O and data-wrangling logic should live under `src/utils/`.
This notebook will **scaffold** a minimal set of modules:

- `src/utils/xnat_io.py` – connecting to XNAT and pulling metadata
- `src/utils/datasets.py` – dataset and dataloader utilities
- `src/utils/training.py` – training loop and logging utilities
- `src/utils/evaluation.py` – evaluation helpers

You can then iterate on these modules without bloating the notebook.
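Because the notebook imports these modules once, edits to `src/utils/*.py` are not picked up until the modules are reloaded or the kernel is restarted. A minimal standard-library sketch using `importlib.reload`; the helper `reload_utils` is hypothetical, and IPython's `%autoreload` extension is an alternative that reloads automatically.

```python
import importlib
import sys
from typing import List


def reload_utils(package: str = "utils") -> List[str]:
    """Reload every already-imported module of `package` and return their names.

    Submodules (e.g. utils.datasets) are reloaded before the package itself,
    deepest first, so anything the package's __init__ re-imports sees the
    freshly reloaded submodules.
    """
    names = sorted(
        (
            name
            for name, module in list(sys.modules.items())
            if module is not None and (name == package or name.startswith(package + "."))
        ),
        key=lambda name: name.count("."),
        reverse=True,  # deepest submodules first
    )
    for name in names:
        importlib.reload(sys.modules[name])
    return names


# Usage after editing src/utils/datasets.py:
# reload_utils()
# from utils.datasets import build_dataloaders  # re-bind names imported with "from ... import"
```

The re-import line matters: names bound with `from utils.datasets import build_dataloaders` keep pointing at the old function object after a reload.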


In [2]:
# 3.1 Create project directories and utility module stubs

from pathlib import Path

project_root = Path(".").resolve()

# Core directories
src_dir = project_root / "src"
utils_dir = src_dir / "utils"
logs_dir = project_root / "logs"
results_dir = project_root / "results"
exploratory_dir = project_root / "exploratory"

for d in [src_dir, utils_dir, logs_dir, results_dir, exploratory_dir]:
    d.mkdir(parents=True, exist_ok=True)

# Ensure src and utils are Python packages
(src_dir / "__init__.py").touch(exist_ok=True)
(utils_dir / "__init__.py").touch(exist_ok=True)

# 3.2 Minimal xnat_io utilities
xnat_io_code = '''"""XNAT I/O utilities."""

from __future__ import annotations

from pathlib import Path
from typing import Dict, Any

import os

import xnat  # type: ignore


def get_xnat_connection(
    host: str | None = None,
    username: str | None = None,
    password: str | None = None,
):
    """Create and return an XNAT session.

    Arguments default to the XNAT_HOST, XNAT_USER and XNAT_PASS environment variables.
    """
    host = host or os.environ.get("XNAT_HOST")
    username = username or os.environ.get("XNAT_USER")
    password = password or os.environ.get("XNAT_PASS")

    if host is None or username is None or password is None:
        raise ValueError("XNAT_HOST, XNAT_USER, XNAT_PASS must be set or passed explicitly.")

    # xnat.connect takes the server URL as its first (positional) argument
    return xnat.connect(host, user=username, password=password)


def fetch_project_metadata(connection, project_id: str) -> Dict[str, Any]:
    """Fetch basic metadata for a given XNAT project."""
    if project_id not in connection.projects:
        raise KeyError(f"Project '{project_id}' not found on XNAT.")

    project = connection.projects[project_id]

    metadata: Dict[str, Any] = {
        "id": project.id,
        "name": project.name,
        "description": project.description,
        "num_experiments": len(project.experiments),
    }

    return metadata
'''

# 3.3 Minimal dataset/dataloader utilities
datasets_code = '''"""Dataset and DataLoader utilities for model-ready tensors."""

from __future__ import annotations

from pathlib import Path
from typing import Optional, Callable

import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader


class XnatImageDataset(Dataset):
    """Generic image dataset built from a manifest DataFrame.

    Expected columns:
    - 'filepath': path to image file on disk
    - 'label': integer class index
    """

    def __init__(
        self,
        manifest: pd.DataFrame,
        transform: Optional[Callable] = None,
    ) -> None:
        self.manifest = manifest.reset_index(drop=True)
        self.transform = transform

        if "filepath" not in self.manifest or "label" not in self.manifest:
            raise ValueError("Manifest must contain 'filepath' and 'label' columns.")

        if "class_name" in self.manifest:
            self.classes = sorted(self.manifest["class_name"].unique().tolist())
        else:
            self.classes = sorted(self.manifest["label"].unique().tolist())

        self.labels = self.manifest["label"].tolist()

    def __len__(self) -> int:
        return len(self.manifest)

    def __getitem__(self, idx: int):
        row = self.manifest.iloc[idx]
        img_path = Path(row["filepath"])
        label = int(row["label"])

        img = Image.open(img_path).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, label


def build_dataloaders(
    train_manifest: pd.DataFrame,
    val_manifest: pd.DataFrame,
    batch_size: int = 32,
    num_workers: int = 4,
    train_transform: Optional[Callable] = None,
    val_transform: Optional[Callable] = None,
):
    """Construct PyTorch dataloaders for train/val splits."""
    train_ds = XnatImageDataset(train_manifest, transform=train_transform)
    val_ds = XnatImageDataset(val_manifest, transform=val_transform)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return {"train": train_loader, "val": val_loader}
'''

# 3.4 Training utilities
training_code = '''"""Training utilities with basic logging."""

from __future__ import annotations

from typing import Dict

import torch
from torch import nn
from torch.utils.data import DataLoader


def train_model(
    model: nn.Module,
    dataloaders: Dict[str, DataLoader],
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler,
    device: torch.device,
    num_epochs: int,
    logger,
) -> Dict[str, list]:
    """Generic training loop."""
    history = {"train_loss": [], "val_loss": [], "train_err": [], "val_err": []}

    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_correct = 0
        running_total = 0

        for xb, yb in dataloaders["train"]:
            xb = xb.to(device)
            yb = yb.to(device)

            optimizer.zero_grad()
            outputs = model(xb)
            loss = criterion(outputs, yb)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * xb.size(0)
            preds = outputs.argmax(dim=1)
            running_correct += (preds == yb).sum().item()
            running_total += yb.size(0)

        epoch_train_loss = running_loss / running_total
        epoch_train_err = 1.0 - (running_correct / running_total)

        model.eval()
        val_running_loss = 0.0
        val_running_correct = 0
        val_running_total = 0

        with torch.no_grad():
            for xb, yb in dataloaders["val"]:
                xb = xb.to(device)
                yb = yb.to(device)
                outputs = model(xb)
                loss = criterion(outputs, yb)

                val_running_loss += loss.item() * xb.size(0)
                preds = outputs.argmax(dim=1)
                val_running_correct += (preds == yb).sum().item()
                val_running_total += yb.size(0)

        epoch_val_loss = val_running_loss / val_running_total
        epoch_val_err = 1.0 - (val_running_correct / val_running_total)

        history["train_loss"].append(epoch_train_loss)
        history["val_loss"].append(epoch_val_loss)
        history["train_err"].append(epoch_train_err)
        history["val_err"].append(epoch_val_err)

        if scheduler is not None:
            from torch.optim.lr_scheduler import ReduceLROnPlateau
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(epoch_val_loss)
            else:
                scheduler.step()

        logger.info(
            f"Epoch {epoch+1}/{num_epochs} | "
            f"train_loss={epoch_train_loss:.4f}, val_loss={epoch_val_loss:.4f}, "
            f"train_err={epoch_train_err:.4f}, val_err={epoch_val_err:.4f}"
        )

    return history
'''

# 3.5 Evaluation utilities
evaluation_code = '''"""Evaluation utilities for classification models."""

from __future__ import annotations

from pathlib import Path
from typing import Dict

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt


def evaluate_model(
    model: nn.Module,
    dataloaders: Dict[str, DataLoader],
    class_names,
    device: torch.device,
    results_dir: Path,
    prefix: str = "val",
) -> Dict[str, float]:
    """Run evaluation and save reports/plots to results_dir."""
    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for xb, yb in dataloaders[prefix]:
            xb = xb.to(device)
            outputs = model(xb)
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(yb.numpy())

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)

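    # Pass labels explicitly so the report and matrix have one entry per class even
    # when a class is missing from this split (otherwise target_names would not match).
    label_ids = list(range(len(class_names)))
    report = classification_report(
        all_labels, all_preds, labels=label_ids, target_names=class_names, output_dict=True, zero_division=0
    )
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)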

    results_dir.mkdir(parents=True, exist_ok=True)

    report_path = results_dir / f"{prefix}_classification_report.txt"
    with report_path.open("w") as f:
        for label, metrics in report.items():
            f.write(f"{label}: {metrics}\n")

    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(cm, interpolation="nearest")
    ax.figure.colorbar(im, ax=ax)
    ax.set(
        xticks=np.arange(len(class_names)),
        yticks=np.arange(len(class_names)),
        xticklabels=class_names,
        yticklabels=class_names,
        ylabel="True label",
        xlabel="Predicted label",
        title="Confusion matrix",
    )
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, int(cm[i, j]), ha="center", va="center")

    fig.tight_layout()
    fig_path = results_dir / f"{prefix}_confusion_matrix.png"
    fig.savefig(fig_path)
    plt.close(fig)

    return {
        "macro_f1": report["macro avg"]["f1-score"],
        "macro_precision": report["macro avg"]["precision"],
        "macro_recall": report["macro avg"]["recall"],
    }
'''

# Write the files
(utils_dir / "xnat_io.py").write_text(xnat_io_code)
(utils_dir / "datasets.py").write_text(datasets_code)
(utils_dir / "training.py").write_text(training_code)
(utils_dir / "evaluation.py").write_text(evaluation_code)

print("Scaffolded src/utils modules and core directories.")


Scaffolded src/utils modules and core directories.


## 4. Formatting and organizing data for the model

This section shows how to:

- Build a **manifest** (e.g. a `pandas.DataFrame`) with `filepath`, `label` (integer class index `0..N-1`) and `class_name` columns.
- Use `src/utils/datasets.py` to create PyTorch `Dataset`/`DataLoader` objects.

In a real project, you would typically generate the manifest by querying XNAT,
downloading DICOMs to disk, and converting them to PNG/JPEG or tensors.
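For a first iteration, before the XNAT download step exists, a manifest can be built from images already on disk. The sketch below assumes a hypothetical layout `data/images/<class_name>/<file>.png` (one folder per class) and writes the `filepath`, `label` and `class_name` columns that the later cells read; `build_manifest_from_folders` is a hypothetical helper, not part of `src/utils`.

```python
import csv
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}
FIELDNAMES = ["filepath", "label", "class_name"]


def build_manifest_from_folders(image_root: Path, manifest_path: Path) -> int:
    """Write a manifest CSV for images stored as image_root/<class_name>/<file>.

    Class folders without any images are skipped; the remaining classes get
    labels 0..N-1 in alphabetical order of their folder names. Returns the
    number of rows written.
    """
    classes = []
    for class_dir in sorted(p for p in image_root.iterdir() if p.is_dir()):
        images = sorted(
            p for p in class_dir.iterdir() if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
        )
        if images:
            classes.append((class_dir.name, images))

    rows = [
        {"filepath": str(img), "label": label, "class_name": class_name}
        for label, (class_name, images) in enumerate(classes)
        for img in images
    ]

    manifest_path.parent.mkdir(parents=True, exist_ok=True)
    with manifest_path.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDNAMES)
        writer.writeheader()
        writer.writerows(rows)
    return len(rows)


# Usage (hypothetical layout under data/images/):
# build_manifest_from_folders(Path("data/images"), Path("data/manifest.csv"))
```

Skipping empty class folders keeps the labels contiguous, which `nn.CrossEntropyLoss` and the confusion-matrix code both assume.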


In [3]:
import pandas as pd
from torchvision import transforms
from utils.datasets import build_dataloaders

data_dir = Path("data")
manifest_path = data_dir / "manifest.csv"

if manifest_path.exists():
    manifest = pd.read_csv(manifest_path)
else:
    manifest = pd.DataFrame(columns=["filepath", "label", "class_name"])

if not manifest.empty:
    frac_train = 0.8
    train_manifest = manifest.sample(frac=frac_train, random_state=42)
    val_manifest = manifest.drop(train_manifest.index).reset_index(drop=True)
    train_manifest = train_manifest.reset_index(drop=True)
else:
    train_manifest = manifest.copy()
    val_manifest = manifest.copy()

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

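# DataLoader(shuffle=True) raises "num_samples should be a positive integer value"
# for an empty dataset, so only build the loaders once the manifest has rows.
if not train_manifest.empty:
    dataloaders = build_dataloaders(
        train_manifest=train_manifest,
        val_manifest=val_manifest,
        batch_size=32,
        num_workers=4,
        train_transform=train_transform,
        val_transform=val_transform,
    )
    print("Train batches:", len(dataloaders["train"]))
    print("Val batches:", len(dataloaders["val"]))
else:
    dataloaders = None
    print("Manifest is empty; skipping DataLoader construction.")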

## 5. Exploring the data and saving results to `exploratory/`

This section performs basic exploratory data analysis:

- Class distribution plots
- Example images

All figures should be saved under `exploratory/` for traceability.
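The class-distribution cell that follows does not cover the "Example images" item. One option is to sample a few files per class with a fixed seed and save them as a grid under `exploratory/`. This is a sketch assuming the manifest's `filepath` and `class_name` columns; `sample_examples_per_class` and `save_example_grid` are hypothetical helpers, and matplotlib and Pillow are imported only inside the plotting function.

```python
import random
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterable, List, Mapping


def sample_examples_per_class(rows: Iterable[Mapping], k: int = 4, seed: int = 0) -> Dict[str, List[str]]:
    """Pick up to k filepaths per class_name, classes in alphabetical order."""
    by_class: Dict[str, List[str]] = defaultdict(list)
    for row in rows:
        by_class[row["class_name"]].append(row["filepath"])
    rng = random.Random(seed)  # fixed seed so the saved figure is reproducible
    return {
        name: rng.sample(paths, min(k, len(paths)))
        for name, paths in sorted(by_class.items())
    }


def save_example_grid(examples: Dict[str, List[str]], out_path: Path) -> None:
    """Plot one row per class and save the figure to out_path."""
    import matplotlib.pyplot as plt
    from PIL import Image

    if not examples:
        return
    n_rows = len(examples)
    n_cols = max(len(paths) for paths in examples.values())
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2.5 * n_cols, 2.5 * n_rows), squeeze=False)
    for r, (name, paths) in enumerate(examples.items()):
        for c in range(n_cols):
            ax = axes[r, c]
            ax.axis("off")
            if c < len(paths):
                ax.imshow(Image.open(paths[c]).convert("RGB"))
                ax.set_title(name, fontsize=8)
    fig.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)
    plt.close(fig)


# Usage inside the notebook (train_manifest is a DataFrame):
# examples = sample_examples_per_class(train_manifest.to_dict("records"), k=4)
# save_example_grid(examples, exploratory_dir / "train_examples.png")
```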


In [None]:
from collections import Counter
import matplotlib.pyplot as plt

exploratory_dir.mkdir(parents=True, exist_ok=True)

if not train_manifest.empty:
    label_counts = Counter(train_manifest["class_name"])

    labels = list(label_counts.keys())
    values = [label_counts[l] for l in labels]

    plt.figure(figsize=(6, 4))
    plt.bar(labels, values)
    plt.xticks(rotation=45, ha="right")
    plt.ylabel("Count")
    plt.title("Training class distribution")
    plt.tight_layout()

    class_dist_path = exploratory_dir / "train_class_distribution.png"
    plt.savefig(class_dist_path)
    plt.show()

    print(f"Saved class distribution plot to: {class_dist_path}")
else:
    print("Manifest is empty; skipping EDA plots.")


## 6. Training the model and logging to `logs/`

We now define a model, configure logging, and call the reusable training
loop from `src/utils/training.py`. All epoch-level logs go to a file under `logs/`.
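`logging.basicConfig` only configures the root logger when it has no handlers yet, so in a long-lived kernel a re-run can keep writing to the first log file unless it is forced. An alternative is a dedicated named logger whose handlers are replaced on every run. A minimal sketch using only the standard `logging` module; `make_run_logger` is a hypothetical helper, not part of `src/utils/training.py`.

```python
import logging
from datetime import datetime
from pathlib import Path
from typing import Tuple

LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"


def make_run_logger(logs_dir: Path, name: str = "training") -> Tuple[logging.Logger, Path]:
    """Return a logger writing to a timestamped file in logs_dir and to the console."""
    logs_dir.mkdir(parents=True, exist_ok=True)
    log_path = logs_dir / f"{name}_{datetime.now():%Y%m%d_%H%M%S}.log"

    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # avoid duplicate lines via any root-logger handlers

    # Close and drop handlers left over from a previous run of the cell
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    formatter = logging.Formatter(LOG_FORMAT)
    for handler in (logging.FileHandler(log_path), logging.StreamHandler()):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger, log_path


# Usage: logger, log_path = make_run_logger(logs_dir)
```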


In [None]:
from datetime import datetime

import torch.nn as nn
import torch.optim as optim
from torchvision import models

from utils.training import train_model

logs_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_path = logs_dir / f"training_{timestamp}.log"

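# force=True removes handlers left by a previous run of this cell; without it,
# basicConfig does nothing once the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(log_path),
        logging.StreamHandler(),
    ],
    force=True,
)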
logger = logging.getLogger("training")

logger.info("Starting training run")
logger.info(f"Logging to {log_path}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

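if not manifest.empty:
    # Count classes over the full manifest (not just the train split) so every
    # label in the data gets an output unit; labels are assumed to be 0..N-1.
    num_classes = manifest["label"].nunique()
else:
    num_classes = 2  # placeholder so the model can be built without data
logger.info(f"Detected {num_classes} classes")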

try:
    weights = models.ResNet50_Weights.IMAGENET1K_V2
    model = models.resnet50(weights=weights)
except AttributeError:
    model = models.resnet50(pretrained=True)

model.fc = nn.Linear(model.fc.in_features, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=2)

num_epochs = 5

if not train_manifest.empty:
    history = train_model(
        model=model,
        dataloaders=dataloaders,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        num_epochs=num_epochs,
        logger=logger,
    )
else:
    logger.warning("Manifest is empty; skipping actual training loop.")
    history = {"train_loss": [], "val_loss": [], "train_err": [], "val_err": []}

logger.info("Training complete")


## 7. Evaluating the model and saving results to `results/`

Finally, we evaluate the trained model on the validation set and save:

- Classification report
- Confusion matrix plot
- Aggregate metrics

All artifacts are written under `results/`.
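The aggregate metrics returned by `evaluate_model` are *macro* averages: precision, recall and F1 are computed per class and then averaged with equal weight, so rare classes count as much as common ones. The pure-Python sketch below reproduces that computation for intuition, treating undefined ratios as 0 to mirror `zero_division=0`; `macro_scores` is a hypothetical helper, and its results should agree with scikit-learn's `average="macro"` scores when all class labels are passed.

```python
from typing import Dict, List


def macro_scores(y_true: List[int], y_pred: List[int], num_classes: int) -> Dict[str, float]:
    """Macro-averaged precision/recall/F1 over classes 0..num_classes-1."""
    precisions, recalls, f1s = [], [], []
    pairs = list(zip(y_true, y_pred))
    for c in range(num_classes):
        tp = sum(1 for t, p in pairs if t == c and p == c)
        fp = sum(1 for t, p in pairs if t != c and p == c)
        fn = sum(1 for t, p in pairs if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    return {
        "macro_precision": sum(precisions) / num_classes,
        "macro_recall": sum(recalls) / num_classes,
        "macro_f1": sum(f1s) / num_classes,
    }
```

For example, with `y_true = [0, 0, 1, 1, 2]` and `y_pred = [0, 1, 1, 1, 0]`, the per-class F1 scores are 0.5, 0.8 and 0.0, so macro F1 is 1.3 / 3 ≈ 0.433 even though 3 of the 5 predictions are correct.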


In [None]:
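import json

from utils.evaluation import evaluate_model

results_dir.mkdir(parents=True, exist_ok=True)

if not train_manifest.empty:
    # Order class names by integer label so class_names[i] describes label i
    # (sorting names alphabetically only works if labels were assigned that way).
    class_names = (
        manifest.drop_duplicates("label")
        .sort_values("label")["class_name"]
        .astype(str)
        .tolist()
    )
    metrics = evaluate_model(
        model=model,
        dataloaders=dataloaders,
        class_names=class_names,
        device=device,
        results_dir=results_dir,
        prefix="val",
    )
    print("Evaluation metrics:", metrics)

    # evaluate_model only returns the aggregate metrics, so persist them explicitly
    metrics_path = results_dir / "val_metrics.json"
    metrics_path.write_text(json.dumps({k: float(v) for k, v in metrics.items()}, indent=2))
    print(f"Saved aggregate metrics to: {metrics_path}")
else:
    print("Manifest is empty; skipping evaluation.")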
