## Îç∞Ïù¥ÌÑ∞ÏÖã Îã§Ïö¥Î°úÎìú

In [1]:
!pip install -q torchvision

In [2]:
from torchvision.datasets import OxfordIIITPet

raw_dataset = OxfordIIITPet(
    root="./data/oxfordpet",
    download=True,
    target_types="segmentation"
)

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 792M/792M [00:50<00:00, 15.7MB/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19.2M/19.2M [00:02<00:00, 7.86MB/s]


## Overall Process Implementation

In [24]:
import os
import random
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset
from tqdm import tqdm
import logging
from PIL import Image
from torch import Tensor
from torchvision.transforms import Compose

### Model implementation

In [17]:
class DoubleConv(nn.Module):
  """(conv > bn > relu) * 2 """
  def __init__(self, in_channels, out_channels, mid_channels=None):
    super().__init__()
    if not mid_channels:
      mid_channels = out_channels
    self.double_conv = nn.Sequential(
        nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(mid_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )

  def forward(self, x):
    return self.double_conv(x)

class Down(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.maxpool_conv = nn.Sequential(
        nn.MaxPool2d(2),
        DoubleConv(in_channels, out_channels)
    )

  def forward(self, x):
    return self.maxpool_conv(x)

class Up(nn.Module):
  def __init__(self, in_channels, out_channels, bilinear=True):
    super().__init__()

    if bilinear:
      self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
      self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
    else:
      self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
      self.conv = DoubleConv(in_channels, out_channels)

  def forward(self, x1, x2):
    x1 = self.up(x1)
    diffY = x2.size()[2] - x1.size()[2]
    diffX = x2.size()[3] - x1.size()[3]
    x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

    x = torch.cat([x2, x1], dim=1)
    return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

In [18]:
class UNet(nn.Module):
  def __init__(self, n_channels, n_classes, bilinear=False):
    super(UNet, self).__init__()
    self.n_channels = n_channels
    self.n_classes = n_classes
    self.bilinear = bilinear

    self.inc = (DoubleConv(n_channels, 64))
    self.down1 = (Down(64, 128))
    self.down2 = (Down(128, 256))
    self.down3 = (Down(256, 512))
    factor = 2 if bilinear else 1
    self.down4 = (Down(512, 1024 // factor))
    self.up1 = (Up(1024, 512 // factor, bilinear))
    self.up2 = (Up(512, 256 // factor, bilinear))
    self.up3 = (Up(256, 128 // factor, bilinear))
    self.up4 = (Up(128, 64, bilinear))
    self.outc = (OutConv(64, n_classes))

  def forward(self, x):
    x1 = self.inc(x)
    x2 = self.down1(x1)
    x3 = self.down2(x2)
    x4 = self.down3(x3)
    x5 = self.down4(x4)
    x = self.up1(x5, x4)
    x = self.up2(x, x3)
    x = self.up3(x, x2)
    x = self.up4(x, x1)
    logits = self.outc(x)
    return logits

  def use_checkpointing(self):
    self.inc = torch.utils.checkpoint(self.inc)
    self.down1 = torch.utils.checkpoint(self.down1)
    self.down2 = torch.utils.checkpoint(self.down2)
    self.down3 = torch.utils.checkpoint(self.down3)
    self.down4 = torch.utils.checkpoint(self.down4)
    self.up1 = torch.utils.checkpoint(self.up1)
    self.up2 = torch.utils.checkpoint(self.up2)
    self.up3 = torch.utils.checkpoint(self.up3)
    self.up4 = torch.utils.checkpoint(self.up4)
    self.outc = torch.utils.checkpoint(self.outc)

## Dataset class

In [45]:
class SegmentationDataset(Dataset):
  """Custom dataset """
  def __init__(self, images, masks, transforms):
    self.images = images
    self.masks = masks
    self.transforms = transforms

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    image_path = self.images[idx]
    mask_path = self.masks[idx]

    # open
    image = Image.open(image_path).convert("RGB")
    mask = Image.open(mask_path).convert('L') # GRAYSCALE

    # Apply transform
    if self.transforms is not None:
      image = self.transforms(image)
      mask = self.transforms(mask)
      # transforms.toTensor() Î•º ÌïòÍ≤åÎêòÎ©¥ image/255 Î•º Ìï¥ÏÑú Îã§Ïãú Í≥±Ìï¥Ï£ºÍ≥† -1 ÏùÑ Ìï¥Ï£ºÎäîÍ≤É.
      mask = (mask*255).squeeze().to(torch.int64)
      mask -= 1 # Ìï¥Îãπ Îç∞Ïù¥ÌÑ∞Í∞Ä 1,2,3 Î∂ÄÌÑ∞ Î†àÏù¥Î∏îÎßÅÏù¥ ÎêòÏñ¥ÏûàÎã§Í≥† Ìï®.

    return image, mask

## Train loop

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device: ", device)

device:  cuda


In [63]:
# Variables
epochs = 5
batch_size = 16
learning_rate = 1e-3
bilinear = True
classes = 3
image_size = (256, 256)

model = UNet(n_channels=3, n_classes=classes, bilinear=bilinear)
model = model.to(device)

print(f'Network:\n'
      f'\t{model.n_channels} input channels\n'
      f'\t{model.n_classes} output channels (classes)\n'
      f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

Network:
	3 input channels
	3 output channels (classes)
	Bilinear upscaling


In [72]:
from random import shuffle
## Dataset
image_path = "/content/data/oxfordpet/oxford-iiit-pet/images"
mask_path = "/content/data/oxfordpet/oxford-iiit-pet/annotations/trimaps"
split_rate = 0.2

# Create a list of image paths
img_paths = sorted([
    os.path.join(image_path, name)
    for name in os.listdir(image_path)
    if name.endswith('.jpg')
])

# Create a list of mask paths
mask_paths = sorted([
    os.path.join(mask_path, name)
    for name in os.listdir(mask_path)
    if not name.startswith('.') and name.endswith('.png')
])

tmp = list(zip(img_paths, mask_paths))
shuffle(tmp)
img_paths, mask_paths = zip(*tmp)
img_paths, mask_paths = list(img_paths), list(mask_paths)
train_imgs = img_paths[int(split_rate*len(img_paths)):]
train_masks = mask_paths[int(split_rate*len(mask_paths)):]
test_imgs = img_paths[:int(split_rate * len(img_paths))]
test_masks = mask_paths[:int(split_rate * len(mask_paths))]

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor()
])

train_dataset = SegmentationDataset(train_imgs, train_masks, transform)
test_dataset = SegmentationDataset(test_imgs, test_masks, transform)
print('Train images: {}\n Test images: {}'.format(len(train_dataset), len(test_dataset)))

train_loader = DataLoader(train_dataset, batch_size=batch_size,
           num_workers=8, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size,
           num_workers=8, pin_memory=True)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=30, eta_min=1e-6
)

Train images: 5912
 Test images: 1478


In [65]:
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    # Average of Dice coefficient for all batches, or for a single mask
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3)

    inter = 2 * (input * target).sum(dim=sum_dim)
    sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim)
    sets_sum = torch.where(sets_sum == 0, inter, sets_sum)

    dice = (inter + epsilon) / (sets_sum + epsilon)
    return dice.mean()


def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    # Average of Dice coefficient for all classes
    return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon)


def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    # Dice loss (objective to minimize) between 0 and 1
    fn = multiclass_dice_coeff if multiclass else dice_coeff
    return 1 - fn(input, target, reduce_batch_first=True)

In [66]:
best_loss = float('inf')
checkpoint_path = "/content/segmentation_checkpoints"
os.makedirs(checkpoint_path, exist_ok=True)

for epoch in range(1, epochs+1):
  model.train()
  train_loss = 0.0
  test_loss = 0.0

  # for i, (images, masks) in tqdm( enumerate(train_loader)):
  for images, masks in tqdm(train_loader, desc=f"Epoch {epoch} [Train]"):
    images, masks = images.to(device), masks.to(device)
    pred = model(images)
    loss = criterion(pred, masks)
    dice_loss_ = dice_loss(
        F.softmax(pred, dim=1).float(),
        F.one_hot(masks, model.n_classes).permute(0, 3, 1, 2).float(),
        multiclass=True
      )
    loss += dice_loss_

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += loss.item()

  # evaluate
  with torch.no_grad():
    model.eval()
    # for i, (images, masks) in tqdm(enumerate(test_loader)):
    for images, masks in tqdm(test_loader, desc=f"Epoch {epoch} [Eval]"):
      images, masks = images.to(device), masks.to(device)
      pred = model(images)
      loss = criterion(pred, masks)
      loss += dice_loss(
        F.softmax(pred, dim=1).float(),
        F.one_hot(masks, model.n_classes).permute(0, 3, 1, 2).float(),
        multiclass=True
      )
      test_loss += loss.item()

    # Calculate
    avg_train_loss = train_loss / len(train_loader)
    avg_test_loss = test_loss / len(test_loader)

    print('Train loss: {} Test loss: {}'.format(avg_train_loss, avg_test_loss))

  scheduler.step()
  # Save checkpoint
  ckpt_file = os.path.join(checkpoint_path, "best_model.pth")
  if avg_test_loss < best_loss:
    best_loss = avg_test_loss
    torch.save({
          'epoch': epoch,
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'loss': best_loss
      }, ckpt_file)
    print(f"üìå Best model saved at epoch {epoch} (loss={best_loss:.4f})")

370it [03:57,  1.56it/s]
93it [00:20,  4.65it/s]


Train loss: 0.8864724131854804 Test loss: 0.6911396659830565
üìå Best model saved at epoch 1 (loss=0.6911)


370it [03:56,  1.56it/s]
93it [00:21,  4.33it/s]


Train loss: 0.6077948816724725 Test loss: 0.5773407079840219
üìå Best model saved at epoch 2 (loss=0.5773)


370it [03:56,  1.56it/s]
93it [00:20,  4.54it/s]


Train loss: 0.5287298710765065 Test loss: 0.5385000455764032
üìå Best model saved at epoch 3 (loss=0.5385)


370it [03:57,  1.56it/s]
93it [00:20,  4.45it/s]


Train loss: 0.47995712378540556 Test loss: 0.5138944602140816
üìå Best model saved at epoch 4 (loss=0.5139)


370it [03:56,  1.56it/s]
93it [00:20,  4.50it/s]


Train loss: 0.443558214645128 Test loss: 0.5053440204230688
üìå Best model saved at epoch 5 (loss=0.5053)


## Test

In [70]:
import numpy as np
import cv2

# class index ‚Üí RGB Ïª¨Îü¨
colors = {
    0: (0, 0, 0),       # background
    1: (255, 0, 0),     # class1
    2: (0, 255, 0),     # class2
}

def decode_segmap(mask):
    h,w = mask.shape
    result = np.zeros((h,w,3), dtype=np.uint8)
    for cls, color in colors.items():
        result[mask == cls] = color
    return result

test_loader = DataLoader(test_dataset, batch_size=1,num_workers=8, pin_memory=True)

model = UNet(n_channels=3, n_classes=3, bilinear=bilinear).to(device)
state_dict = torch.load("/content/best_model.pth", map_location=device)
model.load_state_dict(state_dict['model_state_dict'])

save_dir = "/content/seg_results/"
os.makedirs(save_dir, exist_ok=True)

with torch.no_grad():
  model.eval()
  for i, (images, masks) in tqdm(enumerate(test_loader)):
    images = images.to(device)
    pred = model(images) # (1, 3, H, W)
    pred_mask = torch.argmax(pred, dim=1) # (1, H, W)
    mask_np = pred_mask.squeeze().cpu().numpy().astype(np.uint8)
    # ÌååÏùº Ïù¥Î¶Ñ Ï†ÄÏû•
    save_path = os.path.join(save_dir, f"mask_{i}.png")
    color_mask = decode_segmap(mask_np)
    cv2.imwrite(save_path, color_mask)

print(f"\nüìå Saved inference masks to: {save_dir}")

1478it [00:31, 46.74it/s]


üìå Saved inference masks to: /content/seg_results/



