In [1]:
import sys
from pathlib import Path

import torch
from torch import GradScaler, Tensor, nn, optim
from torch.nn import functional as F
from torch.utils import data
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.models import segmentation
from torchvision.transforms import v2

sys.path.append(str(Path("..").resolve()))
from src.datasets import DATASET_ZOO, VOC_COLORS, VOC_LABELS
from src.models import MODEL_ZOO
from src.pipeline import LocalLogger, Trainer, init_logging
from src.utils.rng import seed
from src.utils.transform import DataAugment, DataTransform

In [2]:
# List the keys of the project's dataset and model zoos (both are dict-like),
# one line per zoo: datasets first, then models.
for zoo in (DATASET_ZOO, MODEL_ZOO):
    print(zoo.keys())

dict_keys([])
dict_keys([])


In [6]:
# Experiment setup: data, loaders, model, loss, optimizer and output folder.
#
# Seed before anything random is constructed: the model below gets randomly
# initialised weights, so seeding only in the training cell would leave the
# initial weights unreproducible. Assumes `seed` seeds torch's global RNG
# (see src.utils.rng) - TODO confirm.
seed(42)

# Build the dataset path portably; a "..\dataset" literal is only a valid
# relative path on Windows.
root = Path("..") / "dataset"

# Small subsets for a quick experiment; raise for a full run.
n_train_images = 50
n_val_images = 50

train_dataset = datasets.VOCSegmentation(
    root=root, transforms=DataTransform((500, 500)), year="2007", image_set="train"
)
train_dataset = data.Subset(train_dataset, range(n_train_images))

# Validate on the held-out "val" split. Loading "train" here would make the
# validation subset the very same images as the training subset above, so the
# metrics would measure memorisation rather than generalisation.
val_dataset = datasets.VOCSegmentation(
    root=root, transforms=DataTransform(), year="2007", image_set="val"
)
val_dataset = data.Subset(val_dataset, range(n_val_images))

# Reshuffle the training samples every epoch; drop_last guarantees every
# batch has exactly 2 samples.
train_loader = DataLoader(
    train_dataset, batch_size=2, shuffle=True, drop_last=True
)
# Default batch_size=1: DataTransform() gets no target size for validation, so
# images presumably keep their native (varying) resolution and cannot be stacked.
val_loader = DataLoader(val_dataset)

device = "cuda" if torch.cuda.is_available() else "cpu"
# 21 outputs = 20 Pascal VOC object classes + background. aux_loss=True adds
# the auxiliary classifier head (presumably weighted by {"aux": 0.5} in the Trainer).
model = segmentation.deeplabv3_mobilenet_v3_large(num_classes=21, aux_loss=True)
# 255 is the VOC "void"/boundary label; those pixels are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)
optimizer = optim.SGD(model.parameters(), lr=3e-4, momentum=0.9, weight_decay=5e-4)

run_folder = Path("../runs/exp1")

In [7]:
# Per-split augmentations, presumably applied on the fly by the Trainer:
# hflip=0.5 is presumably a 50% random horizontal flip for training; none for validation.
train_augment = DataAugment(hflip=0.5)
val_augment = DataAugment()
# Multiply the learning rate by 0.5 every 10 `lr_scheduler.step()` calls.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
# torch AMP gradient scaler bound to the selected device.
grad_scaler = GradScaler(device)

trainer = Trainer(
    model,
    train_loader,
    train_augment,
    val_loader,
    val_augment,
    criterion,
    optimizer,
    lr_scheduler,
    grad_scaler,
    device,
    # NOTE(review): the meaning of the next two positional ints is not visible
    # from this notebook - consider passing them as keyword arguments.
    4,
    10,
    {"aux": 0.5},  # presumably the loss weight of the model's auxiliary head
    21,  # presumably the number of classes; matches num_classes=21 of the model
    VOC_LABELS,
    VOC_COLORS,
    run_folder,
    checkpoint_epochs=5,
    loggers=(LocalLogger(run_folder, VOC_LABELS),),
)

In [None]:
# Re-seed right before training so randomness during training starts from a
# known state (assumes `seed` seeds torch's global RNG - TODO confirm).
seed(42)

# Create the output folder (and any missing parents). exist_ok stays False on
# purpose: re-running this cell against an existing run raises FileExistsError
# instead of mixing outputs from two different runs in one folder.
run_folder.mkdir(exist_ok=False, parents=True)

log_path = run_folder / "run.log"
with init_logging(log_path):  # presumably routes logging to run.log for this run
    trainer.train()