In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.functional as TF

In [None]:
import os
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader

In [None]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import numpy as np

### Model Block

In [None]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is kept; only the channel count changes, from ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(
                    stage_in,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,  # no conv bias; BatchNorm2d follows immediately
                )
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        # Attribute name and layer order are kept so state_dict keys stay
        # "conv.0", "conv.1", "conv.3", "conv.4".
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the conv -> batch norm -> ReLU stack twice."""
        return self.conv(x)

In [None]:
class UNet(nn.Module):
    """U-Net style encoder/decoder built from ``DoubleConv`` blocks.

    The encoder applies one ``DoubleConv`` per entry of ``features``, each
    followed by 2x2 max pooling. A bottleneck ``DoubleConv`` doubles the last
    feature width. The decoder then walks the features in reverse: a 2x2
    transposed convolution upsamples, the result is concatenated with the
    matching encoder output (skip connection) and passed through a
    ``DoubleConv``. A final 1x1 convolution maps to ``out_channels``; no
    activation is applied to the output.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the returned tensor.
        features: Channel widths of the encoder stages, shallowest first.
            Any non-empty sequence of ints is accepted.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        # The default is an immutable tuple rather than a shared mutable list;
        # copy into a local list so any sequence type passed by a caller works.
        features = list(features)

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET: self.ups alternates [upsample, DoubleConv, upsample, ...]
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
            # Input is feature*2 because the skip tensor is concatenated on dim 1.
            self.ups.append(DoubleConv(feature*2, feature))

        # Registered after the ups on purpose: keeping the original registration
        # order keeps model.parameters() order (and optimizer state) compatible.
        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return the un-activated output map for the batch ``x``."""
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        # Decoder consumes the skips deepest-first.
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            # Pooling floors odd sizes, so the upsampled map can be smaller than
            # the skip tensor. Only the spatial dims (H, W) are relevant here.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = TF.resize(x, size=list(skip_connection.shape[2:]))

            x = self.ups[idx+1](torch.cat((skip_connection, x), dim=1))

        return self.final_conv(x)

### Test Block

In [None]:
def test():
  """Smoke test: a 1-channel UNet must return a tensor shaped like its input."""
  batch = torch.randn((3,1,160,160))
  net = UNet(in_channels=1, out_channels=1)
  out = net(batch)
  # Print the output shape first, then the input shape.
  for shape in (out.shape, batch.shape):
    print(shape)
  assert out.shape == batch.shape

In [None]:
# Run the shape smoke test; it prints the output shape followed by the input shape.
test()

torch.Size([3, 1, 160, 160])
torch.Size([3, 1, 160, 160])


### Dataset

In [None]:
class CustomDataset(Dataset):
  """Image / mask pairs read from two directories.

  Every regular file in ``image_dir`` is one sample. Its mask is looked up in
  ``mask_dir`` under the same file name.

  Args:
    image_dir: Directory holding the input images.
    mask_dir: Directory holding the masks, named like the images.
    transform: Optional callable invoked as ``transform(image=..., mask=...)``
      that returns a mapping with ``"image"`` and ``"mask"`` entries.
  """

  def __init__(self, image_dir, mask_dir, transform=None):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform = transform
    # os.listdir returns entries in arbitrary order and also lists
    # sub-directories (e.g. notebook checkpoint folders). Sort so the
    # index -> sample mapping is reproducible, and keep regular files only.
    self.images = sorted(
        name for name in os.listdir(image_dir)
        if os.path.isfile(os.path.join(image_dir, name))
    )

  def __len__(self):
    """Number of image files found in ``image_dir``."""
    return len(self.images)

  def __getitem__(self, index):
    """Return ``(image, mask)`` for the sample at ``index``.

    The image is loaded as RGB. The mask is loaded as single-channel float32
    with pixel value 255 remapped to 1.0. If a transform was given, both go
    through it together.
    """
    file_name = self.images[index]
    img_path = os.path.join(self.image_dir, file_name)
    mask_path = os.path.join(self.mask_dir, file_name)

    # Context managers close the underlying files as soon as pixels are copied.
    with Image.open(img_path) as img:
      image = np.array(img.convert("RGB"))
    with Image.open(mask_path) as msk:
      mask = np.array(msk.convert("L"), dtype=np.float32)
    mask[mask==255.0] = 1.0

    if self.transform is not None:
      augmentations = self.transform(image=image, mask=mask)
      image = augmentations["image"]
      mask = augmentations["mask"]

    return image, mask

### Utils

In [None]:
def save_checkpoint(state, filename="./drive/MyDrive/Seg/my_checkpoint.pth.tar"):
  """Serialize ``state`` to ``filename`` with ``torch.save``.

  Args:
    state: Object to save (``main`` passes a dict of model and optimizer
      state dicts).
    filename: Destination path. Missing parent directories are created so a
      finished epoch is not lost to a missing folder.
  """
  print("===> Saving checkpoint...")
  # Only path-like targets have a parent directory; anything else is handed
  # to torch.save untouched.
  if isinstance(filename, (str, os.PathLike)):
    parent = os.path.dirname(os.fspath(filename))
    if parent:
      os.makedirs(parent, exist_ok=True)
  torch.save(state, filename)

def load_checkpoint(checkpoint, model, optimizer=None):
  """Restore ``model`` (and optionally ``optimizer``) from a checkpoint dict.

  Args:
    checkpoint: Mapping with a ``"state_dict"`` entry for the model and,
      optionally, an ``"optimizer"`` entry.
    model: Module whose weights are overwritten in place.
    optimizer: Optional optimizer. When given and the checkpoint contains an
      ``"optimizer"`` entry, its state is restored as well. Omitting it keeps
      the previous model-only behaviour.
  """
  print("===> Load 'checkpoint...")
  model.load_state_dict(checkpoint["state_dict"])
  if optimizer is not None and "optimizer" in checkpoint:
    optimizer.load_state_dict(checkpoint["optimizer"])

def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True
  ):
  """Build the training and validation ``DataLoader`` objects.

  Args:
    train_dir, train_maskdir: Image and mask directories for training.
    val_dir, val_maskdir: Image and mask directories for validation.
    batch_size: Batch size used by both loaders.
    train_transform, val_transform: Transforms handed to ``CustomDataset``.
    num_workers: Worker processes per loader.
    pin_memory: Forwarded to both loaders.

  Returns:
    ``(train_loader, val_loader)``.
  """
  train_ds = CustomDataset(train_dir, train_maskdir, train_transform)
  train_loader = DataLoader(
      train_ds,
      batch_size=batch_size,
      num_workers=num_workers,
      pin_memory=pin_memory,
      shuffle=True,
  )

  val_ds = CustomDataset(val_dir, val_maskdir, val_transform)
  # Validation is not shuffled: metrics do not depend on order, and a fixed
  # order keeps batch index -> samples stable for the saved prediction images.
  val_loader = DataLoader(
      val_ds,
      batch_size=batch_size,
      num_workers=num_workers,
      pin_memory=pin_memory,
      shuffle=False,
  )

  return train_loader, val_loader

def check_accuracy(loader, model, device="cuda"):
  """Print and return the pixel accuracy (in percent) of ``model`` on ``loader``.

  Model outputs go through a sigmoid and are thresholded at 0.5, then
  compared element-wise with the targets.

  Args:
    loader: Iterable of ``(x, y)`` batches; ``y`` gets a channel axis
      inserted at dim 1 before comparison.
    model: Module to evaluate. It is switched to eval mode for the duration
      and put back into whichever mode it was in before the call.
    device: Device the batches are moved to.

  Returns:
    Accuracy as a float in ``[0, 100]``; ``0.0`` if the loader yields no
    pixels (instead of dividing by zero).
  """
  num_correct = 0
  num_pixels = 0
  was_training = model.training
  model.eval()

  try:
    with torch.no_grad():
      for x, y in loader:
        x = x.to(device)
        y = y.to(device).unsqueeze(1)
        preds = torch.sigmoid(model(x))
        preds = (preds > 0.5).float()
        # .item() keeps the running totals as plain Python ints.
        num_correct += (preds == y).sum().item()
        num_pixels += torch.numel(preds)
  finally:
    # Restore the caller's mode even if evaluation raised.
    model.train(was_training)

  accuracy = num_correct / num_pixels * 100 if num_pixels else 0.0
  print(f"Accuracy : {accuracy:.2f}")
  return accuracy

def save_predictions_as_images(loader, model, folder="./drive/MyDrive/Seg/saved_images/", device="cuda"):
  """Write thresholded predictions and their targets as PNG files.

  For batch number ``idx`` two files are written into ``folder``:
  ``pred_{idx}.png`` (sigmoid output thresholded at 0.5) and ``y_{idx}.png``
  (the target with a channel axis inserted at dim 1).

  Args:
    loader: Iterable of ``(x, y)`` batches.
    model: Module used for prediction. It is switched to eval mode for the
      duration and put back into its previous mode afterwards.
    folder: Output directory, created if missing. A trailing slash is
      optional.
    device: Device the inputs are moved to.
  """
  if folder:
    os.makedirs(folder, exist_ok=True)

  was_training = model.training
  model.eval()
  try:
    for idx, (x, y) in enumerate(loader):
      x = x.to(device)
      with torch.no_grad():
        preds = torch.sigmoid(model(x))
        preds = (preds > 0.5).float()
      # os.path.join gives the same path as before for a slash-terminated
      # folder and a correct one when the slash is omitted.
      torchvision.utils.save_image(preds, os.path.join(folder, f"pred_{idx}.png"))
      torchvision.utils.save_image(y.unsqueeze(1), os.path.join(folder, f"y_{idx}.png"))
  finally:
    # Restore the caller's mode even if saving raised.
    model.train(was_training)

### Train

In [None]:
# Hyperparameters
LEARNING_RATE = 1e-4
# Use the GPU when one is visible, otherwise fall back to the CPU so the
# notebook still runs on a CPU-only runtime.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 10
NUM_EPOCHS = 3
NUM_WORKERS = 0
# Every image and mask is resized to this size by the transforms.
IMAGE_HEIGHT = 500
IMAGE_WIDTH =  500
# Pinned host memory only helps host -> GPU copies; stays True on CUDA.
PIN_MEMORY = DEVICE == "cuda"
# NOTE(review): LOAD_MODEL is not read by any code visible in this notebook.
LOAD_MODEL = True
TRAIN_IMG_DIR = "/content/drive/MyDrive/Seg/train/images"
TRAIN_MASK_DIR = "/content/drive/MyDrive/Seg/train/labels"
VAL_IMG_DIR = "/content/drive/MyDrive/Seg/test/images"
VAL_MASK_DIR = "/content/drive/MyDrive/Seg/test/labels"

In [None]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
  """Run one pass over ``loader``, updating ``model`` with mixed-precision steps.

  Args:
    loader: Iterable of ``(data, targets)`` batches.
    model: Module being trained.
    optimizer: Optimizer stepped once per batch through ``scaler``.
    loss_fn: Callable taking ``(predictions, targets)``.
    scaler: Gradient scaler used to scale the loss and step the optimizer.
  """
  progress = tqdm(loader)

  for inputs, masks in progress:
    inputs = inputs.to(DEVICE)
    # Cast targets to float and insert a channel axis at dim 1 before the loss.
    masks = masks.float().unsqueeze(1).to(DEVICE)

    # Forward pass and loss under autocast.
    with torch.cuda.amp.autocast():
      outputs = model(inputs)
      loss = loss_fn(outputs, masks)

    # Backward pass: clear old gradients, backprop the scaled loss, then step.
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # Show the latest batch loss on the progress bar.
    progress.set_postfix(loss=loss.item())

In [None]:
# Training pipeline: resize, random geometric augmentation, scale pixels to
# [0, 1] (mean 0 / std 1 with max_pixel_value 255), then convert to tensors.
train_transform = A.Compose([
      A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
      A.Rotate(limit=35, p=1.0),
      A.HorizontalFlip(p=0.5),
      # Was p=1.0, which flipped every training sample: that adds no variety
      # and makes the training orientation differ from validation.
      A.VerticalFlip(p=0.5),
      A.Normalize(
          mean=[0.0, 0.0, 0.0],
          std=[1.0, 1.0, 1.0],
          max_pixel_value = 255.0
      ),
      ToTensorV2(),
  ], is_check_shapes=False)

# Validation pipeline: the same resize and scaling, without random augmentation.
val_transform = A.Compose([
      A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
      A.Normalize(
          mean=[0.0, 0.0, 0.0],
          std=[1.0, 1.0, 1.0],
          max_pixel_value = 255.0
      ),
      ToTensorV2(),
  ],  is_check_shapes=False)

In [None]:
# Build the training and validation loaders from the configured directories,
# transforms and loader settings.
train_loader, val_loader = get_loaders(
      train_dir=TRAIN_IMG_DIR,
      train_maskdir=TRAIN_MASK_DIR,
      val_dir=VAL_IMG_DIR,
      val_maskdir=VAL_MASK_DIR,
      batch_size=BATCH_SIZE,
      train_transform=train_transform,
      val_transform=val_transform,
      num_workers=NUM_WORKERS,
      pin_memory=PIN_MEMORY,
  )

In [None]:
def main(train_loader, val_loader):
  """Train a 3-channel-in / 1-channel-out UNet for ``NUM_EPOCHS`` epochs.

  After each epoch a checkpoint holding the model and optimizer state is
  written, pixel accuracy on ``val_loader`` is printed, and the validation
  predictions are saved as images.

  Args:
    train_loader: Loader consumed by ``train_fn`` once per epoch.
    val_loader: Loader used for the accuracy check and the saved predictions.
  """
  # NOTE(review): LOAD_MODEL is defined in the hyperparameter cell but is not
  # consulted here; training always starts from freshly initialised weights.
  images_dir = "./drive/MyDrive/Seg/saved_images/"

  net = UNet(in_channels=3, out_channels=1).to(DEVICE)
  # The model returns raw logits, so the loss applies the sigmoid itself.
  criterion = nn.BCEWithLogitsLoss()
  adam = optim.Adam(net.parameters(), lr=LEARNING_RATE)
  grad_scaler = torch.cuda.amp.GradScaler()

  for epoch in range(NUM_EPOCHS):
    train_fn(train_loader, net, adam, criterion, grad_scaler)

    # One checkpoint file per epoch, numbered by the epoch index.
    checkpoint_path = f"./drive/MyDrive/Seg/my_checkpoint_{epoch}.pth.tar"
    save_checkpoint(
        {"state_dict": net.state_dict(), "optimizer": adam.state_dict()},
        filename=checkpoint_path,
    )

    check_accuracy(val_loader, net, device=DEVICE)

    save_predictions_as_images(val_loader, net, folder=images_dir, device=DEVICE)

In [None]:
# Start training; each epoch also writes a checkpoint, prints validation
# accuracy and saves prediction images.
main(train_loader, val_loader)

100%|██████████| 246/246 [50:01<00:00, 12.20s/it, loss=0.256]


===> Saving checkpoint...
Accuracy : 99.40
