<a href="https://colab.research.google.com/github/Bustion11/NN-projects/blob/main/UNet/U_Net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torchvision
from torchvision.datasets import Cityscapes
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import albumentations as A
import os

In [None]:
class Unet(nn.Module):
  """U-Net style encoder/decoder built from double-convolution blocks.

  The encoder applies one block per entry of ``features``, each followed by
  a 2x2 max-pool. A center block doubles the channel count, and the decoder
  mirrors the encoder with transposed convolutions, concatenating the
  matching encoder output (skip connection) before each decoder block.

  Args:
    img_channels: Number of channels of the input tensor.
    n_classes: Number of output channels produced by the final 1x1 conv.
    features: Channel widths of the encoder stages, shallowest first. The
      decoder wiring requires every entry to be exactly twice the previous
      one (each up-convolution expects ``feature * 2`` input channels).
  """

  def _block(self, in_channel, out_channel, kernel_size=3, stride=1, padding=1):
    """Return two stacked Conv2d -> BatchNorm2d -> ReLU(inplace) stages.

    The convolutions are created with ``bias=False``; each one is directly
    followed by a ``BatchNorm2d`` layer.
    """
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, bias=False),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(True),
        nn.Conv2d(out_channel, out_channel, kernel_size, stride, padding, bias=False),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(True)
    )

  def __init__(self, img_channels, n_classes, features=(64, 128, 256, 512)):
    super().__init__()
    # Immutable default + local copy: no list object is shared between calls,
    # and any sequence of ints (list, tuple, ...) is accepted.
    features = list(features)

    self.img_channels = img_channels
    self.n_classes = n_classes

    # Shared 2x2 / stride-2 pooling used after every encoder stage.
    self.down = nn.MaxPool2d(2, 2)

    self.stack_down = nn.ModuleList()
    self.stack_up = nn.ModuleList()
    self.stack_upscale = nn.ModuleList()

    # Encoder: img_channels -> features[0] -> features[1] -> ...
    in_channels = img_channels
    for feature in features:
      self.stack_down.append(self._block(in_channels, feature))
      in_channels = feature

    # Decoder, deepest stage first. Each up-convolution halves the channel
    # count; the following block consumes the concatenation with the skip
    # connection (hence ``feature * 2`` input channels).
    for feature in reversed(features):
      self.stack_up.append(self._block(feature*2, feature))
      self.stack_upscale.append(nn.ConvTranspose2d(feature*2, feature, 2, 2))

    self.center = self._block(features[-1], features[-1]*2)

    # 1x1 convolution mapping the last decoder output to n_classes channels.
    self.final = nn.Conv2d(features[0], n_classes, 1)

  def forward(self, x):
    """Run the network on ``x`` and return the output of the final 1x1 conv.

    The output has ``n_classes`` channels and the same spatial size as the
    input, because the decoder is resized to match each skip connection.
    """
    samples = []
    for layer in self.stack_down:
      x = layer(x)
      samples.append(x)  # kept for the skip connection
      x = self.down(x)

    x = self.center(x)

    for up, up_scale, sample in zip(self.stack_up, self.stack_upscale, reversed(samples)):
      x = up_scale(x)
      # Max-pooling floors odd sizes, so the upsampled map can be one pixel
      # smaller than the skip connection; resize it to match before concat.
      if x.shape[2:] != sample.shape[2:]:
        x = nn.functional.interpolate(
            x, size=sample.shape[2:], mode="bilinear", align_corners=False
        )
      x = torch.cat((x, sample), dim=1)
      x = up(x)

    return self.final(x)

In [None]:
# Instantiate the network for 3-channel inputs and 2 output channels.
model = Unet(img_channels=3, n_classes=2)

# Bare expression: as the last statement of a notebook cell it displays the
# module's repr (the layer structure).
model

In [None]:
# Config
BATCH_SIZE = 16
LR = 1e-4
IMG_CHANNELS = 3
N_CLASSES = None  # To be assigned
TARGET_TYPE = 'semantic'
# os.cpu_count() returns None when the CPU count cannot be determined; fall
# back to 0 so the value is always an int usable as a worker count.
N_WORKERS = os.cpu_count() or 0
# Height/width passed to the Resize step of the augmentation pipelines.
IMAGE_HEIGHT = 200
IMAGE_WIDTH = 200

In [None]:
#Cityscapes('/data/cityscapes', "train", target_type=TARGET_TYPE)
#Cityscapes('/data/cityscapes', "test", target_type=TARGET_TYPE)

In [None]:
import numpy as np
from PIL import Image

class SegmentationDataset(Dataset):
    """Dataset of (image, mask) pairs read from two directories.

    Every entry of ``image_dir`` is treated as one sample. The mask for an
    image is looked up in ``mask_dir`` under the image's file name with
    ``".jpg"`` replaced by ``"_mask.gif"``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the matching mask files.
        transform: Optional callable invoked as
            ``transform(image=image, mask=mask)`` that returns a mapping
            with ``"image"`` and ``"mask"`` keys.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort them so that a
        # given index always refers to the same sample across runs/machines.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        """Return the number of entries found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load and return the ``(image, mask)`` pair at ``index``.

        The image is loaded as an RGB array. The mask is loaded as a
        single-channel float32 array in which the value 255 is mapped to 1.
        If a transform was given, its ``"image"`` and ``"mask"`` outputs are
        returned instead of the raw arrays.
        """
        file_name = self.images[index]
        img_path = os.path.join(self.image_dir, file_name)
        mask_path = os.path.join(self.mask_dir, file_name.replace(".jpg", "_mask.gif"))

        # Context managers close the underlying files as soon as the pixel
        # data has been copied into NumPy arrays.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(mask_path) as mask_img:
            mask = np.array(mask_img.convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

In [None]:
def get_loaders(BATCH_SIZE, N_WORKERS, IMG_TRANSFORM, TARGET_TRANSFORM, DATASET:Dataset):
  """Placeholder for building the train/test data loaders.

  No loaders are constructed yet: every argument is currently ignored and
  the function returns ``(None, None)`` as ``(train_loader, test_loader)``.
  """
  loaders = (None, None)
  return loaders

In [None]:
def train_fn(model, optimizer, scaler, train_loader, test_loader, loss_fn=nn.BCEWithLogitsLoss()):
  """Run one pass over ``train_loader``, updating ``model`` with scaled gradients.

  Args:
    model: Module called as ``model(x)`` on each batch.
    optimizer: Optimizer stepped once per batch (through ``scaler``).
    scaler: Gradient scaler providing ``scale``, ``step`` and ``update``.
    train_loader: Iterable yielding ``(x, y)`` batches.
    test_loader: Currently unused.
    loss_fn: Callable computing the loss from ``(model(x), y)``.
  """
  for batch_idx, (x, y) in enumerate(train_loader):
    # Forward pass and loss are computed under CUDA autocast.
    with torch.cuda.amp.autocast():
      loss = loss_fn(model(x), y)

    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # Report the scalar loss value every 100 batches.
    if batch_idx % 100 == 0:
      print(loss.item())

In [None]:
# Training pipeline: resize to the configured size, then random rotation
# (limit=35, always applied), horizontal flip (p=0.5), vertical flip (p=0.1)
# and normalization with mean 0 / std 1 / max_pixel_value 255.
train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(mean=[0.0] * 3, std=[1.0] * 3, max_pixel_value=255.0),
        # ToTensorV2(),  # currently disabled
    ]
)

# Evaluation pipeline: same resize and normalization, no random augmentation.
test_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(mean=[0.0] * 3, std=[1.0] * 3, max_pixel_value=255.0),
        # ToTensorV2(),  # currently disabled
    ]
)