In [4]:
import sys
from pathlib import Path

import torch
from torch import GradScaler, nn, optim
from torch.utils import data
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.models import segmentation

sys.path.append(str(Path("..").resolve()))
from src.datasets import DATASET_ZOO
from src.models import MODEL_ZOO
from src.pipeline import LocalLogger, Trainer, init_logging
from src.utils.rng import seed
from src.utils.transform import DataAugment, DataTransform

In [None]:
# --- Data ---------------------------------------------------------------------
root = Path(r"../dataset")
entry = DATASET_ZOO["VOC"]
meta = entry.meta

# Seed *before* building the model so randomly initialised layers are reproducible
# on Restart & Run All (previously the only seed call came after model creation).
# NOTE(review): assumes `seed` seeds torch's global RNG — confirm in src.utils.rng.
seed(42)

# Training images are resized to a fixed 500x500 so they can be batched; validation
# uses DataTransform() with no size argument.
train_dataset = entry.construct_train(
    root=root, transforms=DataTransform((500, 500)), year="2007"
)
val_dataset = entry.construct_val(root=root, transforms=DataTransform(), year="2007")

# Set to a small int (e.g. 50) for a quick smoke-test on a subset; None = full data.
DEBUG_SUBSET = None
if DEBUG_SUBSET is not None:
    train_dataset = data.Subset(train_dataset, range(min(DEBUG_SUBSET, len(train_dataset))))
    val_dataset = data.Subset(val_dataset, range(min(DEBUG_SUBSET, len(val_dataset))))

# shuffle=True: without it every epoch iterates the training set in the same fixed order.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, drop_last=True)
# Default batch_size=1.
# NOTE(review): presumably because un-resized validation images differ in size.
val_loader = DataLoader(val_dataset)

# --- Model / optimisation -------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MODEL_ZOO["deeplabv3_resnet50"](num_classes=meta.num_classes, aux_loss=True)
# Pixels labelled meta.ignore_index contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=meta.ignore_index)
# NOTE(review): lr=3e-4 is low for SGD on segmentation — confirm this is intentional.
optimizer = optim.SGD(model.parameters(), lr=3e-4, momentum=0.9, weight_decay=5e-4)

run_folder = Path("../runs/exp1")

In [10]:
# Augmentations applied on top of the dataset transforms: random horizontal flip
# (p=0.5) for training, an argument-less DataAugment() for validation.
train_augment = DataAugment(hflip=0.5)
val_augment = DataAugment()

# NOTE(review): the run log reports 10 epochs in total. If Trainer steps the scheduler
# once per epoch, step_size=10 halves the LR only after the last epoch, i.e. never
# during training. If it steps per batch, the LR collapses quickly. Confirm which.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
grad_scaler = GradScaler(device)

# Presumably per-output loss weights; "aux" matches the auxiliary head enabled by
# aux_loss=True on the model. Verify against Trainer.
aux_loss_weights = {"aux": 0.5}
run_loggers = (LocalLogger(run_folder, meta.labels),)

trainer = Trainer(
    model,
    train_loader,
    train_augment,
    val_loader,
    val_augment,
    criterion,
    optimizer,
    scheduler,
    grad_scaler,
    device,
    4,  # NOTE(review): meaning not visible here — check Trainer's signature
    10,  # total epochs, judging by the "Epoch [   0/10]" log line
    aux_loss_weights,
    meta.num_classes,
    meta.labels,
    meta.colors,
    run_folder,
    checkpoint_epochs=5,
    loggers=run_loggers,
)

In [None]:
# Re-seed right before training so batch order and augmentation randomness are
# reproducible for this run.
seed(42)

# parents=True also creates ../runs if missing. exist_ok is deliberately left False:
# re-running into an existing experiment folder raises instead of clobbering it.
run_folder.mkdir(parents=True)

log_path = run_folder / "run.log"
with init_logging(log_path):
    trainer.train()

2025-02-19 08:58:10 :: src.pipeline.trainer.INFO     :: ----- Epoch [   0/10] -----
train:   7%|▋         | 7/104 [00:37<07:45,  4.80s/it, acc=0.00536, macc=0.0253, miou=0.00266, fwiou=0.00172, dice=0.00527, loss=1.65, time=2.64] 