In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from gan_synthesis.datasets.dataset import Dataset
from torch.utils.data import DataLoader
from gan_synthesis.u_net_models.unet import UNet
from tqdm import tqdm

# Instantiate data loaders.
# split(0.95): the first return value is used for training, the second for testing.
dataset = Dataset()
train_set, test_set = dataset.split(0.95)
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
# No shuffling for evaluation. The loop below averages per-batch losses, so
# shuffling would change which samples end up in the short final batch and make
# the reported test loss fluctuate between epochs even for a fixed model.
test_loader = DataLoader(test_set, batch_size=16, shuffle=False)

# Instantiate models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNet(anchor=32).to(device)

# Optimizer and class-weighted cross-entropy.
# dataset.weights is assumed to hold one weight per class -- TODO confirm.
optimizer = optim.Adam(model.parameters(), lr=0.005)
criterion = nn.CrossEntropyLoss(weight=dataset.weights.to(device))


train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []
ce_losses = []


num_epochs = 200
for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    running_train_loss = 0.0
    # KLD / Dice terms are not part of this objective; they stay 0.0 so that
    # avg_kld_loss / avg_dice_loss remain defined for any later cells.
    running_kld_loss = 0.0
    running_ce_loss = 0.0
    running_dice_loss = 0.0
    # Reset once per epoch (not per batch) so accuracy covers the whole epoch.
    correct, total = 0, 0

    for image, seg in tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False):
        contrast_input = image.to(device=device, dtype=torch.float32)
        seg_target = seg.squeeze(1).long().to(device)

        recon = model(contrast_input)

        ce_loss = criterion(recon, seg_target)
        loss = ce_loss

        # argmax is not differentiable; detach so no graph is kept for metrics.
        pred = recon.detach().argmax(dim=1)
        correct += (pred == seg_target).sum().item()
        total += seg_target.numel()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a Python float. Accumulating the tensor itself would
        # chain every batch's autograd graph into the running sum (memory leak).
        running_ce_loss += ce_loss.item()
        running_train_loss += loss.item()

    n_train_batches = len(train_loader)
    avg_train_loss = running_train_loss / n_train_batches
    avg_kld_loss = running_kld_loss / n_train_batches
    avg_ce_loss = running_ce_loss / n_train_batches
    avg_dice_loss = running_dice_loss / n_train_batches
    ce_losses.append(avg_ce_loss)
    train_losses.append(avg_train_loss)

    acc = correct / total
    train_accuracies.append(acc)

    # ---- Validation ----
    model.eval()
    running_val_loss = 0.0
    correct, total = 0, 0

    with torch.no_grad():
        for image, seg in tqdm(test_loader, desc="Test", leave=False):
            # Same float32 cast as training so eval sees the same input dtype.
            contrast_input = image.to(device=device, dtype=torch.float32)
            seg_target = seg.squeeze(1).long().to(device)
            recon = model(contrast_input)
            ce_loss = criterion(recon, seg_target)
            loss = ce_loss
            running_val_loss += loss.item()

            pred = torch.argmax(recon, dim=1)
            correct += (pred == seg_target).sum().item()
            total += seg_target.numel()

    avg_val_loss = running_val_loss / len(test_loader)
    val_losses.append(avg_val_loss)

    acc = correct / total
    val_accuracies.append(acc)

Epoch 1:   0%|          | 0/22 [00:00<?, ?it/s]

                                                          

In [None]:
# train.py
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple, Dict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from gan_synthesis.datasets.dataset import Dataset
from gan_synthesis.u_net_models.unet import UNet


@dataclass
class TrainConfig:
    """Hyper-parameters and runtime options for one U-Net training run.

    Field order matters: it fixes the positional argument order of the
    dataclass-generated ``__init__``.
    """

    # --- reproducibility / run length ---
    seed: int = 42
    epochs: int = 200
    # --- optimisation (Adam) ---
    batch_size: int = 16
    lr: float = 0.005
    weight_decay: float = 0.0001
    # --- data loading / hardware ---
    num_workers: int = 4
    pin_memory: bool = True
    amp: bool = True  # mixed precision; a GradScaler is only created on CUDA
    # --- I/O and data split ---
    ckpt_dir: str = "checkpoints"
    val_split: float = 0.05  # fraction of the dataset held out for validation
    anchor: int = 32  # forwarded to UNet(anchor=...)
    # --- ReduceLROnPlateau schedule ---
    lr_patience: int = 10  # epochs without val-loss improvement before decay
    lr_factor: float = 0.5
    min_lr: float = 0.00001


def set_seed(seed: int, deterministic: bool = False) -> None:
    """Seed the random number generators and configure cuDNN.

    Args:
        seed: Seed applied to Python's ``random`` module and to torch
            (CPU generator and all CUDA devices).
        deterministic: If True, request deterministic kernels and disable
            cuDNN autotuning. This is slower but repeatable. The default
            (False) keeps the fast path with ``cudnn.benchmark = True``.

    NOTE(review): NumPy's global RNG is not seeded here. If Dataset uses NumPy
    randomness (e.g. for the split), seed it separately -- TODO confirm.
    """
    import random

    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable

    if deterministic:
        # On CUDA, some cuBLAS ops also need CUBLAS_WORKSPACE_CONFIG set in the
        # environment before CUDA initialises; otherwise they raise.
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        # Autotuned cuDNN kernels are faster but not bit-reproducible.
        torch.backends.cudnn.benchmark = True


def accuracy_from_logits(logits: torch.Tensor, targets: torch.Tensor) -> float:
    """Pixel accuracy: share of positions whose arg-max class equals the label.

    Args:
        logits: Class scores with the class axis at dim 1.
        targets: Integer labels comparable (broadcastable) with
            ``logits.argmax(dim=1)``.

    Returns:
        ``number_of_matches / targets.numel()`` as a Python float.
    """
    hits = logits.argmax(dim=1).eq(targets)
    return hits.sum().item() / targets.numel()

def build(cfg: TrainConfig):
    """Construct everything needed for training/eval without running any step.

    Args:
        cfg: Training configuration.

    Returns:
        Dict with keys "device", "model", "criterion", "optimizer",
        "scheduler", "scaler" (None unless AMP on CUDA), "train_set",
        "val_set", "train_loader" and "val_loader".
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    # --- Data ---
    dataset = Dataset()
    train_set, val_set = dataset.split(1 - cfg.val_split)

    common_loader_kwargs = dict(
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        # Pinned host memory only speeds up host->GPU copies. On CPU-only
        # machines PyTorch warns and ignores it, so only request it with CUDA.
        pin_memory=cfg.pin_memory and use_cuda,
        persistent_workers=cfg.num_workers > 0,
    )
    train_loader = DataLoader(train_set, shuffle=True,  **common_loader_kwargs)
    val_loader   = DataLoader(val_set,   shuffle=False, **common_loader_kwargs)

    # --- Model / Loss / Optimizer / Scheduler / Scaler ---
    model = UNet(anchor=cfg.anchor).to(device)

    # Class weights are optional: use them only if the dataset exposes them.
    class_weights = getattr(dataset, "weights", None)
    if class_weights is not None:
        class_weights = class_weights.to(device=device, dtype=torch.float32)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=cfg.lr_factor,
        patience=cfg.lr_patience,
        min_lr=cfg.min_lr,
    )
    # Mixed precision is only enabled on CUDA; elsewhere training runs in fp32.
    scaler = torch.amp.GradScaler("cuda") if (cfg.amp and use_cuda) else None

    return {
        "device":       device,
        "model":        model,
        "criterion":    criterion,
        "optimizer":    optimizer,
        "scheduler":    scheduler,
        "scaler":       scaler,
        "train_set":    train_set,
        "val_set":      val_set,
        "train_loader": train_loader,
        "val_loader":   val_loader,
    }


def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scaler: torch.amp.GradScaler | None,
    device: torch.device,
) -> Dict[str, float]:
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; it is switched to train mode.
        loader: Yields ``(images, seg)`` batches.
        criterion: Loss applied to ``(logits, targets)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        scaler: GradScaler for CUDA mixed precision, or None for fp32 training.
        device: Device the batches are moved to.

    Returns:
        ``{"loss": ..., "acc": ...}``. "loss" is the batch-size-weighted mean of
        the per-batch losses. "acc" is pixel accuracy pooled over every pixel of
        the epoch. A short final batch therefore weighs in proportion to its
        size instead of counting as much as a full batch.

    Raises:
        ZeroDivisionError: if ``loader`` yields no batches.
    """
    model.train()
    loss_sum = 0.0   # sum over batches of (batch mean loss * batch size)
    n_samples = 0
    n_correct = 0
    n_pixels = 0

    pbar = tqdm(loader, desc="Train", leave=False)
    for images, seg in pbar:
        images = images.to(device=device, dtype=torch.float32, non_blocking=True)
        targets = seg.squeeze(1).long().to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        if scaler is not None:
            with torch.amp.autocast("cuda"):
                logits = model(images)
                loss = criterion(logits, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(images)
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()

        # Single .item() per quantity: each one forces a host/device sync.
        batch_loss = loss.item()
        batch_correct = (logits.detach().argmax(dim=1) == targets).sum().item()
        batch_pixels = targets.numel()
        batch_size = targets.shape[0]

        loss_sum += batch_loss * batch_size
        n_samples += batch_size
        n_correct += batch_correct
        n_pixels += batch_pixels
        pbar.set_postfix({"loss": f"{batch_loss:.4f}", "acc": f"{batch_correct / batch_pixels:.4f}"})

    return {"loss": loss_sum / n_samples, "acc": n_correct / n_pixels}


@torch.inference_mode()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> Dict[str, float]:
    """Compute loss and pixel accuracy of ``model`` over ``loader`` without gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode and left there.
        loader: Yields ``(images, seg)`` batches.
        criterion: Loss applied to ``(logits, targets)``.
        device: Device the batches are moved to.

    Returns:
        ``{"loss": ..., "acc": ...}``. "loss" is the batch-size-weighted mean of
        the per-batch losses. "acc" is pixel accuracy pooled over all pixels.
        With a small validation set a short final batch would otherwise carry as
        much weight as a full one and bias the value fed to the scheduler.

    Raises:
        ZeroDivisionError: if ``loader`` yields no batches.
    """
    # inference_mode (as a decorator) replaces no_grad: it is a bit faster and
    # the tensors it produces cannot be used in autograd by accident.
    model.eval()
    loss_sum = 0.0   # sum over batches of (batch mean loss * batch size)
    n_samples = 0
    n_correct = 0
    n_pixels = 0

    pbar = tqdm(loader, desc="Val", leave=False)
    for images, seg in pbar:
        images = images.to(device=device, dtype=torch.float32, non_blocking=True)
        targets = seg.squeeze(1).long().to(device, non_blocking=True)
        logits = model(images)
        loss = criterion(logits, targets)

        batch_loss = loss.item()
        batch_correct = (logits.argmax(dim=1) == targets).sum().item()
        batch_pixels = targets.numel()
        batch_size = targets.shape[0]

        loss_sum += batch_loss * batch_size
        n_samples += batch_size
        n_correct += batch_correct
        n_pixels += batch_pixels
        pbar.set_postfix({"loss": f"{batch_loss:.4f}", "acc": f"{batch_correct / batch_pixels:.4f}"})

    return {"loss": loss_sum / n_samples, "acc": n_correct / n_pixels}


def _atomic_torch_save(obj, path: Path) -> None:
    """torch.save ``obj`` to a sibling temp file, then rename it over ``path``.

    If the run is interrupted mid-write (e.g. KeyboardInterrupt), the previous
    checkpoint at ``path`` survives instead of being left truncated.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    torch.save(obj, tmp_path)
    # Path.replace -> os.replace: overwrites an existing target and is atomic
    # on POSIX when source and target are on the same filesystem.
    tmp_path.replace(path)


def main(cfg: TrainConfig):
    """Train the U-Net for ``cfg.epochs`` epochs with per-epoch checkpointing.

    Each epoch writes ``<ckpt_dir>/latest.pt``. That file holds the model,
    optimizer, scheduler and scaler state, plus config, history and best val
    loss, so a run can be resumed. ``<ckpt_dir>/best_model.pt`` holds the model
    weights with the lowest validation loss seen so far.

    Args:
        cfg: Training configuration.

    Returns:
        ``(model, history)``: the trained model and a dict of per-epoch lists
        keyed "train_loss", "train_acc", "val_loss" and "val_acc".
    """
    set_seed(cfg.seed)
    parts = build(cfg)

    device       = parts["device"]
    model        = parts["model"]
    criterion    = parts["criterion"]
    optimizer    = parts["optimizer"]
    scheduler    = parts["scheduler"]
    scaler       = parts["scaler"]
    train_loader = parts["train_loader"]
    val_loader   = parts["val_loader"]

    # --- Bookkeeping ---
    ckpt_dir = Path(cfg.ckpt_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    best_val = math.inf
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    for epoch in range(cfg.epochs):
        print(f"\nEpoch {epoch+1}/{cfg.epochs}")

        train_metrics = train_one_epoch(model, train_loader, criterion, optimizer, scaler, device)
        val_metrics   = evaluate(model, val_loader, criterion, device)

        scheduler.step(val_metrics["loss"])

        history["train_loss"].append(train_metrics["loss"])
        history["train_acc"].append(train_metrics["acc"])
        history["val_loss"].append(val_metrics["loss"])
        history["val_acc"].append(val_metrics["acc"])

        print(
            f"train: loss {train_metrics['loss']:.4f} | acc {train_metrics['acc']:.4f} "
            f"|| val: loss {val_metrics['loss']:.4f} | acc {val_metrics['acc']:.4f} "
            f"|| lr: {optimizer.param_groups[0]['lr']:.2e}"
        )

        is_best = val_metrics["loss"] < best_val
        if is_best:
            best_val = val_metrics["loss"]

        # Save latest (with scheduler state and best_val so resuming keeps the
        # LR schedule and best-model tracking consistent).
        _atomic_torch_save(
            {
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "scaler_state": scaler.state_dict() if scaler is not None else None,
                "best_val": best_val,
                "cfg": cfg.__dict__,
                "history": history,
            },
            ckpt_dir / "latest.pt",
        )

        # Save best on val loss
        if is_best:
            _atomic_torch_save(model.state_dict(), ckpt_dir / "best_model.pt")

    print("\nTraining complete.")
    return model, history  # <-- return the trained model too



if __name__ == "__main__":
    # Run with the default hyper-parameters. `cfg`, `model` and `history`
    # stay bound at module level so later notebook cells can inspect them.
    cfg = TrainConfig()
    model, history = main(cfg=cfg)



Epoch 1/200


                                                                               

KeyboardInterrupt: 