In [1]:
from pathlib import Path

import torch
from torch import Tensor, nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.models import segmentation
from torchvision.transforms import v2

In [16]:
# Data, model and optimizer configuration: DeepLabV3 (MobileNetV3-Large backbone)
# trained on the PASCAL VOC 2007 segmentation "train" split.
seed = 0
torch.manual_seed(seed)  # makes weight init and DataLoader shuffling repeatable

# Built from parts so the path resolves on both Windows and POSIX; the raw string
# r"..\dataset" is a single oddly-named file outside Windows.
root = Path("..") / "dataset"
transform = v2.Compose(
    [
        # Converts the PIL input image; tv_tensors.Mask targets pass through unchanged.
        v2.ToImage(),
        # Scales image pixels to float [0, 1]. With a plain dtype argument, ToDtype only
        # converts images/videos, so Mask class indices (including 255) stay intact.
        v2.ToDtype(torch.float32, scale=True),
        # Cropped jointly for image and mask. NOTE(review): inputs smaller than 500x500
        # are zero-padded, which labels padded mask pixels as class 0 (background)
        # rather than 255 (ignored) — consider a mask fill of 255.
        v2.CenterCrop([500, 500]),
    ]
)
# wrap_dataset_for_transforms_v2 converts the PIL target into a tv_tensors.Mask before
# `transforms` runs. Without it, ToImage/ToDtype treat the mask as an ordinary image
# and rescale its labels into [0, 1], which `.to(torch.long)` later truncates to 0/1.
dataset = datasets.wrap_dataset_for_transforms_v2(
    datasets.VOCSegmentation(
        root, year="2007", image_set="train", transforms=transform
    )
)
train_loader = DataLoader(dataset, batch_size=2, shuffle=True, drop_last=True)

num_epochs = 10
learn_step = 4  # number of batches whose gradients are accumulated per optimizer step
device = "cuda" if torch.cuda.is_available() else "cpu"
model = segmentation.deeplabv3_mobilenet_v3_large(num_classes=21)
# Pixels labelled 255 (VOC's boundary/"void" label) are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)
# lr is passed explicitly: some older PyTorch releases have no default lr for SGD.
learning_rate = 1e-3
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

In [None]:
# Train for `num_epochs` epochs. Gradients are accumulated over `learn_step` batches
# per optimizer step, and each epoch reports the mean per-batch "out" loss.
model.to(device)
num_batches = len(train_loader)
for i in range(num_epochs):
    model.train()
    total_loss = 0.0
    for j, (images, masks) in enumerate(train_loader):
        images: Tensor = images.to(device)
        # masks should be long int and no channel dim if it is 1
        masks: Tensor = masks.to(device, torch.long).squeeze(1)

        logits: dict[str, Tensor] = model(images)
        # Resize every head's logits to the mask resolution (an identity resize when
        # the model already returns outputs of that size).
        for k, v in logits.items():
            logits[k] = F.interpolate(
                v, masks.shape[-2:], mode="bilinear", align_corners=False
            )
        losses = {k: criterion(v, masks) for k, v in logits.items()}
        loss_sum = sum(losses.values())

        # Size of the accumulation group this batch belongs to; the final group of
        # an epoch may hold fewer than learn_step batches.
        group_start = (j // learn_step) * learn_step
        group_size = min(learn_step, num_batches - group_start)
        if isinstance(loss_sum, Tensor):
            # Scaling by the group size makes the accumulated gradient the mean over
            # the group instead of a sum that grows with learn_step.
            (loss_sum / group_size).backward()

        if (j + 1) % learn_step == 0 or j == num_batches - 1:
            print(f"Batch {j} optimize")
            optimizer.step()
            optimizer.zero_grad()
        total_loss += losses["out"].item()

    # losses["out"] is already a mean over each batch, so average over batches, not
    # samples; dividing by len(dataset) under-reported the loss by about the batch size.
    total_loss /= max(num_batches, 1)
    print("Epoch", i, total_loss)