In [1]:
# Mount Google Drive inside the Colab VM at /content/drive so that files stored
# on Drive are reachable through the local filesystem.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
[K     |████████████████████████████████| 5.5 MB 9.2 MB/s 
[?25hCollecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 64.0 MB/s 
Collecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.11.0-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 74.3 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.11.0 tokenizers-0.13.2 transformers-4.24.0


In [1]:
import os

# Use an absolute path so this cell can be re-run safely: a relative path would
# only resolve while the working directory is still /content.
PROJECT_DIR = "/content/drive/MyDrive/Colab Notebooks/Unet"
os.chdir(PROJECT_DIR)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import shutil
import sys
from PIL import Image
import json
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from tqdm import tqdm
from transformers import get_scheduler
from torchvision.utils import save_image

In [3]:
# definition of all blocks
class block(nn.Module):
  """Two unpadded 3x3 convolutions with a single ReLU in between.

  Neither convolution pads its input, so each one trims one pixel from every
  border: an (H, W) feature map comes out as (H - 4, W - 4).
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    kernel_size = 3
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size)
    self.relu  = nn.ReLU()
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size)

  def forward(self, x):
    """Apply conv1 -> ReLU -> conv2; no activation follows conv2."""
    hidden = self.conv1(x)
    hidden = self.relu(hidden)
    return self.conv2(hidden)

class encoder(nn.Module):
  """Contracting path: a chain of conv blocks, each followed by 2x2 max pooling.

  `channels` lists the channel count at every stage boundary; consecutive
  entries give the (in, out) channels of one block.
  """

  def __init__(self, channels=(3,16,32,64)):
    super().__init__()
    stage_io = zip(channels[:-1], channels[1:])
    self.blocks = nn.ModuleList([block(c_in, c_out) for c_in, c_out in stage_io])
    self.pool = nn.MaxPool2d(2)

  def forward(self, x):
    """Return the pre-pooling output of every block, shallowest first."""
    skips = []
    current = x
    for stage in self.blocks:
      stage_out = stage(current)
      skips.append(stage_out)
      # The pooled tensor feeds the next stage; the unpooled one is kept.
      current = self.pool(stage_out)
    return skips

class decoder(nn.Module):
  """Expanding path: upsample, concatenate the matching skip tensor, convolve.

  Stage i upsamples from channels[i] to channels[i+1] with a stride-2
  transposed convolution, concatenates a centre-cropped skip tensor on the
  channel axis, and runs a conv block that expects channels[i] input channels.
  """

  def __init__(self, channels=(64, 32, 16)):
    super().__init__()
    self.channels = channels
    stage_io = list(zip(channels[:-1], channels[1:]))
    self.upconvs = nn.ModuleList([nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in stage_io])
    self.blocks = nn.ModuleList([block(c_in, c_out) for c_in, c_out in stage_io])

  def crop(self, encoder_features, x):
    """Centre-crop `encoder_features` to the spatial size (H, W) of `x`."""
    _batch, _chan, height, width = x.shape
    center_crop = transforms.CenterCrop([height, width])
    return center_crop(encoder_features)

  def forward(self, x, encoder_features):
    """Run all stages; `encoder_features[i]` is the skip tensor for stage i."""
    for stage, (upconv, conv_block) in enumerate(zip(self.upconvs, self.blocks)):
      x = upconv(x)
      skip = self.crop(encoder_features[stage], x)
      x = conv_block(torch.cat([x, skip], dim=1))
    return x

In [4]:
# Unet
class UNet(nn.Module):
  """Encoder/decoder segmentation network with a 1x1 convolution head.

  When `retain_dim` is true the head output is resized to `out_sz` with
  `F.interpolate` (default interpolation mode).
  """

  def __init__(self, encoder_channels=(3, 16, 32, 64), decoder_channels=(64, 32, 16), num_class=1, retain_dim=True, out_sz=(128, 128)):
    super().__init__()
    self.encoder = encoder(encoder_channels)
    self.decoder = decoder(decoder_channels)
    self.head = nn.Conv2d(decoder_channels[-1], num_class, 1)
    self.retain_dim = retain_dim
    self.out_sz = out_sz

  def forward(self, x):
    """Return per-pixel logits for `x`."""
    # Deepest feature map first: it seeds the decoder, the rest are skips.
    deepest_first = self.encoder(x)[::-1]
    bottleneck = deepest_first[0]
    skips = deepest_first[1:]
    logits = self.head(self.decoder(bottleneck, skips))
    if not self.retain_dim:
        return logits
    return F.interpolate(logits, self.out_sz)

In [5]:
# Input images: resize to 128x128, convert to a tensor, then normalize each of
# the three channels with the given per-channel mean and std.
transform_images = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Masks: resize to the same 128x128 and convert to a tensor; no normalization.
transform_labels = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

In [6]:
class ipt_dataset(Dataset):
  """Image/mask pairs listed in a CSV annotation file.

  The first column of the CSV holds the sample ids. For an id `x` the image is
  read from `<root_dir>/images/x.png` and the mask from `<root_dir>/masks/x.png`.

  Args:
    root_dir: directory containing the `images` and `masks` sub-directories.
    annotation_file: path of the CSV file read with `pd.read_csv`.
    transform_images: optional callable applied to the RGB image.
    transform_labels: optional callable applied to the 1-bit mask.
  """

  def __init__(self, root_dir, annotation_file, transform_images=None, transform_labels=None):
    self.root_dir = root_dir
    self.annotations = pd.read_csv(annotation_file)

    # feature extraction
    self.transform_images = transform_images
    self.transform_labels = transform_labels

  def __len__(self):
    """Number of rows in the annotation file."""
    return len(self.annotations)

  def __getitem__(self, index):
    """Return `(image, mask)` for row `index`, transformed when transforms are set."""
    img_id = self.annotations.iloc[index, 0]
    img = Image.open(os.path.join(self.root_dir, "images" , f"{img_id}.png")).convert("RGB")
    label = Image.open(os.path.join(self.root_dir, "masks" , f"{img_id}.png")).convert("1")

    # Both transforms default to None, so only call the ones that were given;
    # without a transform the PIL image is returned as-is.
    if self.transform_images is not None:
      img = self.transform_images(img)
    if self.transform_labels is not None:
      label = self.transform_labels(label)

    return (img, label)

In [7]:
# Training configuration.
device = "cuda" if torch.cuda.is_available() else "cpu"
num_epochs = 5
learning_rate = 2e-5
batch_size = 16
shuffle = True
num_workers = 1

# Image/mask pairs are read from the "train" directory as listed in "train.csv".
dataset = ipt_dataset("train","train.csv", transform_images=transform_images, transform_labels=transform_labels)
train_loader = DataLoader(
    dataset=dataset,
    shuffle=shuffle,
    batch_size=batch_size,
    num_workers=num_workers,
    # Pinned host memory only benefits host-to-GPU copies; skip it on CPU-only runtimes.
    pin_memory=(device == "cuda"),
)

In [8]:
# Network, loss and optimizer.
model = UNet()
model = model.to(device)
criterion = nn.BCEWithLogitsLoss()  # consumes raw logits from the model head
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)

# One scheduler step is taken per batch, so the schedule spans every batch of
# every epoch; no warmup steps are used.
total_training_steps = len(train_loader) * num_epochs
learning_rate_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=total_training_steps,
)

In [9]:
# training
# One optimizer step and one scheduler step per batch; a checkpoint is written
# (overwriting the previous one) at the end of every epoch.
for epoch in range(num_epochs):
  model.train()
  loop = tqdm(train_loader, total = len(train_loader), leave = True)
  # The epoch label is constant within an epoch, so set it once up front.
  loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
  for imgs, labels in loop:
    imgs = imgs.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()
    outputs = model(imgs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    learning_rate_scheduler.step()
    # Loss of the current batch only (not an epoch average).
    loop.set_postfix(loss = loss.item())
  checkpoint = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                # Needed to resume the learning-rate schedule where it stopped.
                'scheduler': learning_rate_scheduler.state_dict(),
  }
  torch.save(checkpoint, "checkpoint")

Epoch [1/5]: 100%|██████████| 4000/4000 [34:06<00:00,  1.95it/s, loss=1.49]
Epoch [2/5]: 100%|██████████| 4000/4000 [01:17<00:00, 51.56it/s, loss=0.336]
Epoch [3/5]: 100%|██████████| 4000/4000 [01:16<00:00, 52.44it/s, loss=0.23]
Epoch [4/5]: 100%|██████████| 4000/4000 [01:17<00:00, 51.33it/s, loss=0.684]
Epoch [5/5]: 100%|██████████| 4000/4000 [01:19<00:00, 50.16it/s, loss=0.67]


In [28]:
# testing
# NOTE: this iterates train_loader, so the saved predictions are for training
# batches, not held-out data.
PRED_THRESHOLD = 0.35      # sigmoid probability above which a pixel is marked positive
NUM_BATCHES_TO_SAVE = 7    # stop after writing this many batches to disk

model.eval()
loop = tqdm(train_loader, total = len(train_loader), leave = True)
cnt = 0
# Inference only: disable autograd so no graph is built for the forward passes.
with torch.no_grad():
  for imgs, labels in loop:
    imgs = imgs.to(device)
    labels = labels.to(device)

    outputs = torch.sigmoid(model(imgs))
    save_image(imgs, f'{cnt}.png')
    # Keep the (N, 1, H, W) layout: squeezing to (N, H, W) makes save_image
    # treat the whole batch as one N-channel image, and would also drop the
    # batch dimension when N == 1.
    outputs = (outputs > PRED_THRESHOLD).float()
    save_image(outputs, f'{cnt}_p.png')
    save_image(labels, f'{cnt}_t.png')
    cnt += 1
    if cnt == NUM_BATCHES_TO_SAVE:
      break


  0%|          | 6/4000 [00:00<03:14, 20.54it/s]
