In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import SBDataset
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm


In [21]:
from torch.utils.data import Dataset

class SBDWithTransforms(Dataset):
    def __init__(self, base_dataset, img_transform=None, mask_transform=None):
        self.base = base_dataset
        self.img_transform = img_transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, mask = self.base[idx]

        if self.img_transform is not None:
            img = self.img_transform(img)

        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        return img, mask


In [22]:
from torch.utils.data import Subset, random_split

transform_img = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

transform_mask = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor()
])

sbd_base = SBDataset(
    root="data",
    image_set="train",
    mode="segmentation",
    download=False
)

train_dataset = SBDWithTransforms(
    sbd_base,
    img_transform=transform_img,
    mask_transform=transform_mask
)

train_sub_dataset = Subset(train_dataset, range(4000))

dataset_size = len(train_sub_dataset)
val_size = int(0.2 * dataset_size)
train_size = dataset_size - val_size

train_ds, val_ds = random_split(
    train_sub_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

train_loader = DataLoader(train_ds,
                          batch_size=8,
                          shuffle=True,
                          num_workers=0,
                          pin_memory=True,
                          persistent_workers=False)

val_loader = DataLoader(val_ds,
                        batch_size=8,
                        shuffle=False,
                        num_workers=0,
                        pin_memory=True,
                        persistent_workers=False
)


In [23]:
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


In [24]:
class UNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.down1 = DoubleConv(3, 64)
        self.down2 = DoubleConv(64, 128)
        self.down3 = DoubleConv(128, 256)

        self.pool = nn.MaxPool2d(2)

        self.middle = DoubleConv(256, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv3 = DoubleConv(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv2 = DoubleConv(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv1 = DoubleConv(128, 64)

        self.out = nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        c1 = self.down1(x)
        c2 = self.down2(self.pool(c1))
        c3 = self.down3(self.pool(c2))

        m = self.middle(self.pool(c3))

        u3 = self.up3(m)
        u3 = self.conv3(torch.cat([u3, c3], dim=1))

        u2 = self.up2(u3)
        u2 = self.conv2(torch.cat([u2, c2], dim=1))

        u1 = self.up1(u2)
        u1 = self.conv1(torch.cat([u1, c1], dim=1))

        return self.out(u1)


In [25]:
def validate(model, loader):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for imgs, masks in loader:
            imgs = imgs.to(device)
            masks = masks.squeeze(1).long().to(device)

            outputs = model(imgs)
            loss = criterion(outputs, masks)
            total_loss += loss.item()

            preds = outputs.argmax(dim=1)
            valid = masks != 255
            correct += (preds[valid] == masks[valid]).sum().item()
            total += valid.sum().item()

    val_loss = total_loss / len(loader)
    val_acc = 100.0 * correct / total
    return val_loss, val_acc


In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))


num_classes = 21  # SBD semantic classes
model = UNet(num_classes).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=255)
optimizer = optim.Adam(model.parameters(), lr=1e-4)


cuda
True
NVIDIA GeForce GTX 1660


In [27]:
import time
t0 = time.time()
img, mask = train_dataset[122]
print("Load time:", time.time() - t0)



Load time: 0.008517026901245117


In [28]:
import random
epochs = 3

idx = random.randint(0, len(train_dataset) - 1)
img, mask = train_dataset[idx]

# print(type(img))
# print(type(mask))

for epoch in range(epochs):
    model.train()
    train_total, train_correct, train_loss = 0 , 0, 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}", leave=True)

    for imgs, masks in pbar:
        imgs = imgs.to(device, non_blocking=True)
        masks = masks.squeeze(1).long().to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, masks)

        loss.backward()
        optimizer.step()

        preds = outputs.argmax(dim=1)
        valid = masks != 255

        train_correct += (preds[valid] == masks[valid]).sum().item()
        train_total += valid.sum().item()

        train_loss += loss.item()

        pbar.set_postfix(loss=loss.item())

    train_loss = train_loss / len(train_dataset)
    train_acc = 100 * train_correct / train_total

    val_loss, val_acc = validate(model, val_loader)

    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%"
    )

    model.eval()
    img, mask = train_dataset[idx]
    with torch.no_grad():
        pred = model(img.unsqueeze(0).to(device))
        pred = torch.argmax(pred, dim=1).squeeze().cpu()


Epoch 1/3: 100%|██████████| 400/400 [05:42<00:00,  1.17it/s, loss=1.3]  


Epoch 1/3 | Train Loss: 0.0698, Train Acc: 67.88% | Val Loss: 1.3188, Val Acc: 70.33%


Epoch 2/3: 100%|██████████| 400/400 [04:12<00:00,  1.58it/s, loss=1.82] 


Epoch 2/3 | Train Loss: 0.0613, Train Acc: 70.47% | Val Loss: 1.3279, Val Acc: 70.33%


Epoch 3/3: 100%|██████████| 400/400 [04:19<00:00,  1.54it/s, loss=0.939]


Epoch 3/3 | Train Loss: 0.0606, Train Acc: 70.47% | Val Loss: 1.2839, Val Acc: 70.33%


In [None]:
plt.figure(figsize=(12,4))

plt.subplot(1,3,1)
plt.title("Image")
plt.imshow(img.permute(1,2,0))
plt.axis("off")

plt.subplot(1,3,2)
plt.title("Ground Truth")
plt.imshow(mask.squeeze(), cmap="jet")
plt.axis("off")

plt.subplot(1,3,3)
plt.title("Prediction")
plt.imshow(pred, cmap="jet")
plt.axis("off")

plt.show()
