# Segmentation

In [1]:
!pip install pytorch-lightning



In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
from torchvision.datasets import VOCSegmentation
from PIL import Image
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [14]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

Device: cpu


In [16]:
class SegTransform:
    """Joint image/mask preprocessing pipeline for segmentation samples.

    The pipeline resizes, randomly flips horizontally, normalises the image
    with the ``A.Normalize()`` defaults and converts the result to tensors.
    Image and mask are handed to albumentations in a single call so the
    spatial operations stay aligned between the two.

    Args:
        height: Output height passed to ``A.Resize``.
        width: Output width passed to ``A.Resize``.
        flip_p: Probability of the horizontal flip. Use ``0.0`` to obtain a
            deterministic pipeline (e.g. for validation or visualisation).
    """

    def __init__(self, height=224, width=224, flip_p=0.5):
        self.transform = A.Compose([
            A.Resize(height, width),
            A.HorizontalFlip(p=flip_p),
            A.Normalize(),
            ToTensorV2(),
        ])

    def __call__(self, image, mask):
        """Transform one ``(image, mask)`` pair.

        Both inputs are converted with ``np.array`` first, so anything that
        numpy can turn into an array (such as PIL images) is accepted.

        Returns:
            ``(image, mask)`` as produced by the pipeline, with the mask cast
            to ``long`` so it can be used directly as a class-index target.
        """
        transformed = self.transform(image=np.array(image), mask=np.array(mask))
        return transformed["image"], transformed["mask"].long()

In [17]:
class VOCDataset(torch.utils.data.Dataset):
    """Thin wrapper around ``VOCSegmentation`` with an optional joint transform.

    Args:
        root: Directory handed to ``VOCSegmentation`` (the data is downloaded
            there, since ``download=True`` is always passed).
        year: Dataset year forwarded to ``VOCSegmentation``.
        image_set: Split name forwarded to ``VOCSegmentation``.
        transforms: Optional callable ``(image, mask) -> (image, mask)``
            applied to every sample; samples are returned untouched when it
            is falsy.
    """

    def __init__(self, root, year='2012', image_set='train', transforms=None):
        self.transforms = transforms
        self.dataset = VOCSegmentation(
            root,
            year=year,
            image_set=image_set,
            download=True,
        )

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at ``idx``, transformed if configured."""
        image, target = self.dataset[idx]
        if not self.transforms:
            return image, target
        image, target = self.transforms(image, target)
        return image, target

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self.dataset)

In [18]:
class SimpleConv(nn.Module):
    """A single 3x3 convolution (padding 1) followed by an in-place ReLU.

    With kernel size 3 and padding 1 the spatial size of the input is kept;
    only the channel count changes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        activation = nn.ReLU(inplace=True)
        # Kept as ``self.conv`` / Sequential so parameter names stay
        # ``conv.0.weight`` and ``conv.0.bias``.
        self.conv = nn.Sequential(conv3x3, activation)

    def forward(self, x):
        """Apply convolution + ReLU to ``x`` and return the activated features."""
        features = self.conv(x)
        return features

class UNet(nn.Module):
    """Small U-Net style encoder/decoder built from ``SimpleConv`` stages.

    The encoder has three conv + 2x max-pool stages (32, 64, 128 channels)
    and a 256-channel bottleneck. The decoder upsamples bilinearly,
    concatenates the matching encoder feature map along the channel axis and
    reduces the channels again with a ``SimpleConv``. A final 1x1 convolution
    produces ``n_classes`` output channels.

    Each decoder stage upsamples to the exact spatial size of its skip
    tensor rather than by a fixed factor of two, so inputs whose height or
    width is not a multiple of 8 no longer fail in ``torch.cat``. For sizes
    that are multiples of 8 the target sizes equal those of 2x upsampling.

    Args:
        n_classes: Number of output channels of the final 1x1 convolution.
    """

    def __init__(self, n_classes):
        super().__init__()
        # Encoder stages.
        self.conv_down1 = SimpleConv(3, 32)
        self.conv_down2 = SimpleConv(32, 64)
        self.conv_down3 = SimpleConv(64, 128)
        self.conv_down4 = SimpleConv(128, 256)

        self.maxpool = nn.MaxPool2d(2)

        # Decoder stages; input channels = skip channels + upsampled channels.
        self.conv_up3 = SimpleConv(128 + 256, 128)
        self.conv_up2 = SimpleConv(64 + 128, 64)
        self.conv_up1 = SimpleConv(32 + 64, 32)

        self.conv_last = nn.Conv2d(32, n_classes, 1)

    @staticmethod
    def _upsample_like(x, skip):
        """Bilinearly resize ``x`` to the spatial size of ``skip``."""
        return nn.functional.interpolate(
            x, size=skip.shape[-2:], mode='bilinear', align_corners=True
        )

    def forward(self, x):
        """Map an image batch to per-pixel class scores.

        ``x`` is expected to have 3 channels (the first stage is
        ``SimpleConv(3, 32)``). Height and width must be at least 8 so that
        three 2x poolings leave a non-empty feature map. The output has
        ``n_classes`` channels and the spatial size of ``x``.
        """
        # Encoder: keep the pre-pool activations for the skip connections.
        conv1 = self.conv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.conv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.conv_down3(x)
        x = self.maxpool(conv3)

        x = self.conv_down4(x)

        # Decoder: upsample to the skip's size, concatenate, then convolve.
        x = torch.cat([self._upsample_like(x, conv3), conv3], dim=1)
        x = self.conv_up3(x)

        x = torch.cat([self._upsample_like(x, conv2), conv2], dim=1)
        x = self.conv_up2(x)

        x = torch.cat([self._upsample_like(x, conv1), conv1], dim=1)
        x = self.conv_up1(x)

        return self.conv_last(x)

In [19]:
class LitSegModel(pl.LightningModule):
    """Lightning wrapper that trains a ``UNet`` with cross-entropy loss.

    Pixels whose target equals ``ignore_index`` are excluded from the loss
    and from both validation metrics.

    Args:
        lr: Learning rate for the Adam optimiser.
        n_classes: Number of classes predicted by the underlying ``UNet``.
        ignore_index: Target value marking pixels that must not be scored.

    Note:
        Hyper-parameters are not stored in checkpoints by this class; when
        loading a checkpoint trained with non-default values, pass them again
        as keyword arguments to ``load_from_checkpoint``.
    """

    def __init__(self, lr=1e-3, n_classes=21, ignore_index=255):
        super().__init__()
        self.model = UNet(n_classes=n_classes)
        self.lr = lr
        self.n_classes = n_classes
        self.ignore_index = ignore_index

    def forward(self, x):
        """Return the raw class scores produced by the wrapped model."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y, ignore_index=self.ignore_index)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log loss, mean IoU and mean Dice for one validation batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y, ignore_index=self.ignore_index)
        # Dimension 1 holds the per-class scores, so argmax yields class ids.
        preds = logits.argmax(dim=1)
        iou = self.compute_iou(preds, y)
        dice = self.dice_coeff(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_iou", iou, prog_bar=True)
        self.log("val_dice", dice, prog_bar=True)
        return {"loss": loss, "iou": iou, "dice": dice}

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _per_class_counts(self, pred, target):
        """Count, per class, true positives, predicted pixels and target pixels.

        Only pixels whose target is a valid class id (not ``ignore_index`` and
        inside ``[0, n_classes)``) take part in any of the three counts.
        """
        n = self.n_classes
        valid = (target != self.ignore_index) & (target >= 0) & (target < n)
        p = pred[valid]
        t = target[valid]
        true_pos = torch.bincount(t[p == t], minlength=n).float()
        # Slice in case ``pred`` carries ids >= n_classes, so all three
        # vectors have the same length.
        pred_count = torch.bincount(p, minlength=n)[:n].float()
        target_count = torch.bincount(t, minlength=n).float()
        return true_pos, pred_count, target_count

    def compute_iou(self, pred, target, eps=1e-6):
        """Mean intersection-over-union across the classes present in the batch.

        A class is "present" when it occurs in the prediction or the target
        among the non-ignored pixels. The result is a scalar tensor in
        ``[0, 1]``; it is ``0`` when there is no scorable pixel at all.
        """
        true_pos, pred_count, target_count = self._per_class_counts(pred, target)
        union = pred_count + target_count - true_pos
        present = union > 0
        if not present.any():
            return union.new_zeros(())
        return (true_pos[present] / (union[present] + eps)).mean()

    def dice_coeff(self, pred, target, eps=1e-6):
        """Mean Dice coefficient across the classes present in the batch.

        Per class this is ``2 * TP / (|pred| + |target|)`` over non-ignored
        pixels, so the result is a scalar tensor in ``[0, 1]``; it is ``0``
        when there is no scorable pixel at all.
        """
        true_pos, pred_count, target_count = self._per_class_counts(pred, target)
        total = pred_count + target_count
        present = total > 0
        if not present.any():
            return total.new_zeros(())
        return (2 * true_pos[present] / (total[present] + eps)).mean()

In [None]:
# Training samples go through the augmenting pipeline (random horizontal flip).
transform = SegTransform()

# Validation must be deterministic: the same image has to yield the same
# metrics every epoch, otherwise early stopping and checkpoint selection
# react to augmentation noise. Reuse SegTransform but swap in a pipeline
# without the random flip.
val_transform = SegTransform()
val_transform.transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(),
    ToTensorV2(),
])

train_dataset = VOCDataset(root="./", image_set='train', transforms=transform)
val_dataset = VOCDataset(root="./", image_set='val', transforms=val_transform)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

# Keep only the checkpoint with the best validation IoU; stop once the
# validation loss has not improved for 3 validation runs.
checkpoint = ModelCheckpoint(monitor="val_iou", mode="max", save_top_k=1, filename="best_unet")
early_stop = EarlyStopping(monitor="val_loss", mode="min", patience=3)

model = LitSegModel()

trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=10,
    callbacks=[checkpoint, early_stop],
    log_every_n_steps=10
)

trainer.fit(model, train_loader, val_loader)

def show_pred(model, dataset, idx=0, mean=(0.485, 0.456, 0.406),
              std=(0.229, 0.224, 0.225), ignore_index=255):
    """Plot the image, ground-truth mask and predicted mask of one sample.

    Args:
        model: Segmentation model mapping a ``(1, C, H, W)`` batch to
            per-class scores along dimension 1. It is switched to eval mode.
        dataset: Indexable returning ``(image_tensor, mask_tensor)``.
        idx: Index of the sample to display.
        mean: Per-channel mean used when the image was normalised; the
            defaults match the ``A.Normalize()`` defaults.
        std: Per-channel standard deviation used for normalisation.
        ignore_index: Mask value to leave uncoloured in the ground-truth plot.
    """
    model.eval()
    # Run on whatever device the model actually lives on instead of relying
    # on a global that may disagree with it.
    device = next(model.parameters()).device

    with torch.no_grad():
        img, mask = dataset[idx]
        logits = model(img.unsqueeze(0).to(device))
        # squeeze(0) removes only the batch axis, even if H or W equals 1.
        pred = logits.argmax(dim=1).squeeze(0).cpu().numpy()
    n_classes = logits.shape[1]

    # Undo the normalisation so imshow receives values in [0, 1]; showing the
    # normalised tensor directly gets clipped and distorts the colours.
    img_cpu = img.detach().cpu().float()
    mean_t = torch.tensor(mean, dtype=img_cpu.dtype).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=img_cpu.dtype).view(-1, 1, 1)
    display = (img_cpu * std_t + mean_t).clamp(0.0, 1.0).permute(1, 2, 0).numpy()

    # NaN pixels are drawn as "no data", keeping ignored pixels out of the
    # colour scale.
    truth = mask.detach().cpu().float()
    truth[truth == ignore_index] = float("nan")
    truth = truth.numpy()

    # Shared limits so a class id has the same colour in both mask panels.
    vmax = max(n_classes - 1, 1)

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    axs[0].imshow(display)
    axs[0].set_title("Image")
    axs[1].imshow(truth, vmin=0, vmax=vmax)
    axs[1].set_title("Ground Truth")
    axs[2].imshow(pred, vmin=0, vmax=vmax)
    axs[2].set_title("Prediction")
    for ax in axs:
        ax.axis("off")
    plt.tight_layout()
    plt.show()


# ModelCheckpoint was created without ``dirpath``, so the file is written
# inside the trainer's log directory (and may get a version suffix), not to
# "./best_unet.ckpt". Ask the callback where the best checkpoint was saved.
best_ckpt_path = checkpoint.best_model_path
if not best_ckpt_path:
    raise RuntimeError("No checkpoint was saved; run trainer.fit() to completion first.")

best_model = LitSegModel.load_from_checkpoint(best_ckpt_path)
best_model.to(DEVICE)
show_pred(best_model, val_dataset, idx=3)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type | Params | Mode 
---------------------------------------
0 | model | UNet | 969 K  | train
---------------------------------------
969 K     Trainable params
0         Non-trainable params
969 K     Total params
3.880     Total estimated model params size (MB)
32        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]