<a href="https://colab.research.google.com/github/akibkhan1/skin-lesion-classification/blob/main/Implementing_Pix2Pix_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

### Building Discriminator

In [None]:
class disc_block(nn.Module):
  """Conv2d -> BatchNorm2d -> LeakyReLU(0.2) block used by the discriminator.

  Args:
    in_channels: number of channels of the incoming feature map.
    out_channels: number of channels produced by the convolution.
    stride: stride of the 4x4 convolution (default 2).
    padding: padding added on every side before the convolution (default 1).
      Pass 0 to get an unpadded convolution.
  """

  def __init__(self, in_channels, out_channels, stride=2, padding=1):
    super().__init__()
    self.conv = nn.Sequential(
        # padding_mode="reflect" only has an effect when padding > 0, so the
        # padding has to be given explicitly for the reflect mode to be used.
        # bias is disabled because the BatchNorm right after has its own shift.
        nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=padding, bias=False, padding_mode="reflect"),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2)
    )

  def forward(self, x):
    """Run `x` through conv, batch norm and LeakyReLU and return the result."""
    return self.conv(x)

class discriminator(nn.Module):
  """Conditional patch discriminator that scores an (input, target) image pair.

  The two images are concatenated along the channel axis, passed through an
  initial conv block (no BatchNorm), a stack of `disc_block`s and a final
  1-channel convolution. The output is a map of raw logits (no sigmoid is
  applied here); its spatial size depends on the input size and on the
  padding used inside `disc_block`.

  Args:
    in_channels: channels of ONE image; the first conv receives twice as many.
    channels: output width of each stage, first entry is the initial block.
  """

  def __init__(self, in_channels=3, channels=(64, 128, 256, 512)):
    super().__init__()

    # the initial block is different from the general sequential blocks. it doesn't contain BatchNorm
    # in_channels*2 because forward() concatenates x and y on the channel axis

    self.initial_block = nn.Sequential(
        nn.Conv2d(in_channels*2, channels[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
        nn.LeakyReLU(0.2)
    )

    # creating the subsequent blocks after the initial block

    blocks = []
    in_channels = channels[0]
    last_position = len(channels) - 1

    for position, channel in enumerate(channels[1:], start=1):
      # Only the last block uses a stride of 1 (as the authors did for the 512
      # convolution). It is picked by position, not by value, so a width list
      # with repeated entries such as (64, 128, 512, 512) still works.
      block_stride = 1 if position == last_position else 2
      blocks.append(disc_block(in_channels, channel, stride=block_stride))
      in_channels = channel

    # final convolution collapses the features to one logit per patch
    blocks.append(
        nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")
    )

    self.model = nn.Sequential(*blocks)

  def forward(self, x, y):
    """Return the logit map for input image `x` paired with image `y`."""
    x = torch.cat([x, y], dim=1)
    x = self.initial_block(x)
    return self.model(x)

### Unit Test

In [None]:
def test():
  """Smoke test: push one random 256x256 image pair through a fresh
  discriminator and print the shape of its output."""
  source = torch.rand((1, 3, 256, 256))
  target = torch.rand((1, 3, 256, 256))
  net = discriminator()
  print(net(source, target).shape)

In [None]:
# Run the discriminator shape check; it prints the output tensor size.
test()

torch.Size([1, 1, 26, 26])


### Building Generator

In [None]:
class gen_block(nn.Module):
  """Single generator stage: resampling conv -> BatchNorm2d -> activation,
  optionally followed by Dropout(0.5).

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels of the produced feature map.
    contracting_path: True builds a strided nn.Conv2d (encoder side), False
      builds an nn.ConvTranspose2d (decoder side). Both use a 4x4 kernel,
      stride 2 and padding 1.
    activation: "relu" selects nn.ReLU; any other value selects
      nn.LeakyReLU(0.2).
    use_dropout: when truthy, forward() applies dropout to the output.
  """

  def __init__(self, in_channels, out_channels, contracting_path=True, activation="relu", use_dropout=False):
    super().__init__()

    # down-sampling conv on the contracting path, up-sampling one otherwise
    if contracting_path:
      resample = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False, padding_mode='reflect')
    else:
      resample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)

    norm = nn.BatchNorm2d(out_channels)

    if activation == "relu":
      nonlinearity = nn.ReLU()
    else:
      nonlinearity = nn.LeakyReLU(0.2)

    self.conv = nn.Sequential(resample, norm, nonlinearity)
    self.use_dropout = use_dropout
    # the dropout layer is always created; forward() decides whether to use it
    self.dropout = nn.Dropout(0.5)

  def forward(self, x):
    """Apply the stage to `x`; dropout is added only if `use_dropout` is set."""
    features = self.conv(x)
    if not self.use_dropout:
      return features
    return self.dropout(features)

class generator(nn.Module):
  """U-Net shaped generator built from `gen_block` stages.

  Seven stride-2 stages plus a bottleneck shrink the input, then seven
  transposed-conv stages and a final transposed conv grow it back. Every
  decoder stage after the first receives the matching encoder output
  concatenated on the channel axis (skip connection). The final layer maps
  back to `in_channels` channels and ends with Tanh.

  Args:
    in_channels: channels of the input image (also of the output image).
    channels: base width; deeper stages use multiples of it.
  """

  def __init__(self, in_channels=3, channels=64):
    super().__init__()
    c = channels

    def down(n_in, n_out):
      # encoder stage: strided conv + BatchNorm + LeakyReLU
      return gen_block(n_in, n_out, contracting_path=True, activation="leaky_relu")

    def up(n_in, n_out, dropout):
      # decoder stage: transposed conv + BatchNorm + ReLU (+ optional dropout)
      return gen_block(n_in, n_out, contracting_path=False, activation="relu", use_dropout=dropout)

    # First stage has no BatchNorm, unlike the gen_block stages that follow.
    # Trailing numbers are the spatial size for a 256x256 input.
    self.initial_block = nn.Sequential(
        nn.Conv2d(in_channels, c, kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
        nn.LeakyReLU(0.2),
    )  # 128

    # contracting path
    self.down1 = down(c, c * 2)      # 64
    self.down2 = down(c * 2, c * 4)  # 32
    self.down3 = down(c * 4, c * 8)  # 16
    self.down4 = down(c * 8, c * 8)  # 8
    self.down5 = down(c * 8, c * 8)  # 4
    self.down6 = down(c * 8, c * 8)  # 2

    self.bottle_neck = nn.Sequential(
        nn.Conv2d(c * 8, c * 8, kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
        nn.ReLU(),
    )  # 1*1

    # expanding path; from up2 on, input widths are doubled by the skip concat
    self.up1 = up(c * 8, c * 8, True)
    self.up2 = up(c * 8 * 2, c * 8, True)
    self.up3 = up(c * 8 * 2, c * 8, True)
    self.up4 = up(c * 8 * 2, c * 8, False)
    self.up5 = up(c * 8 * 2, c * 4, False)
    self.up6 = up(c * 4 * 2, c * 2, False)
    self.up7 = up(c * 2 * 2, c, False)

    self.final_block = nn.Sequential(
        nn.ConvTranspose2d(c * 2, in_channels, kernel_size=4, stride=2, padding=1),
        nn.Tanh(),
    )

  def forward(self, x):
    """Translate image batch `x` and return the generated image batch."""
    # encoder: keep every stage output, they feed the skip connections
    skips = [self.initial_block(x)]
    for stage in (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6):
      skips.append(stage(skips[-1]))

    out = self.up1(self.bottle_neck(skips[-1]))

    # decoder: up2..up7 consume the encoder outputs deepest-first
    decoder = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)
    for stage, skip in zip(decoder, reversed(skips[1:])):
      out = stage(torch.cat([out, skip], dim=1))

    # the output of initial_block is the last skip, joined before the final layer
    return self.final_block(torch.cat([out, skips[0]], dim=1))

### Unit Test

In [None]:
def test_generator():
  """Smoke test: feed one random 3x256x256 image to a fresh generator and
  print the shape of what comes out."""
  sample = torch.randn((1, 3, 256, 256))
  net = generator(in_channels=3, channels=64)
  print(net(sample).shape)

In [None]:
# Run the generator shape check; it prints the output tensor size.
test_generator()

torch.Size([1, 3, 256, 256])


### Preparing Data

In [None]:
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset
from torchvision.utils import save_image
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
!pip install albumentations==0.4.6
import albumentations 
from albumentations.pytorch import ToTensorV2

In [None]:
# Notebook shell command: download the dataset archive from Google Drive by file id.
!gdown --id 19BErvkVNU02jPVlbPtjlfNkm_dg7Y7-t

Downloading...
From: https://drive.google.com/uc?id=19BErvkVNU02jPVlbPtjlfNkm_dg7Y7-t
To: /content/noaugmentation_512_384_dunet.zip
126MB [00:00, 145MB/s]


In [None]:
# Notebook shell commands: extract the downloaded archive, then delete the zip.
!unzip /content/noaugmentation_512_384_dunet.zip
!rm /content/noaugmentation_512_384_dunet.zip

In [None]:
# Folders holding the images and their masks for the "valid" split.
image_dir = "/content/new_data/valid/image/"
mask_dir = "/content/new_data/valid/mask/"

# File names only (no directory part); os.listdir gives no ordering guarantee.
images = os.listdir(image_dir)
masks = os.listdir(mask_dir)
# Quick sanity check that both folders hold the same number of files.
print(len(images), len(masks))

259 259


In [None]:
# Resize every image to 256x256 and overwrite the original file in place.
for img_file in images:
  img_path = os.path.join(image_dir, img_file)
  # Open through a context manager so the source file handle is released
  # before the same path is written again.
  with Image.open(img_path) as source_image:
    image = source_image.resize((256, 256))
  image.save(img_path)

In [None]:
# Resize every mask to 256x256 and overwrite the original file in place.
# NOTE(review): resize() uses Pillow's default resampling filter; if the masks
# must stay strictly binary, confirm whether Image.NEAREST is wanted here.
for mask_file in masks:
  mask_path = os.path.join(mask_dir, mask_file)
  # Open through a context manager so the source file handle is released
  # before the same path is written again.
  with Image.open(mask_path) as source_mask:
    mask = source_mask.resize((256, 256))
  mask.save(mask_path)

In [None]:
# Applied jointly to the input image ("image") and its target ("image0"), so
# every geometric change is identical for both members of the pair.
both_transform = albumentations.Compose(
    [
        albumentations.Resize(width=256, height=256),
        # The flip has to live here: flipping only the input would leave the
        # target un-flipped and the pair would no longer line up.
        albumentations.HorizontalFlip(p=0.5),
    ],
    additional_targets={"image0": "image"},
)

# Input-only pipeline: scale pixels from [0, 255] to [-1, 1] and convert to a
# CHW tensor. Only photometric (non-geometric) augmentations belong here.
transform_only_input = albumentations.Compose(
    [
        # albumentations.ColorJitter(p=0.2),
        albumentations.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ]
)

# Target-only pipeline: same scaling to [-1, 1] and tensor conversion.
transform_only_mask = albumentations.Compose(
    [
        albumentations.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ]
)

In [None]:
class map_dataset(Dataset):
  """Paired image/mask dataset read from `<root_dir>/image/` and `<root_dir>/mask/`.

  Item `i` pairs the i-th image file with the i-th mask file after both
  listings have been sorted by name.
  NOTE(review): this assumes image and mask file names sort into the same
  order (e.g. identical names in both folders) - confirm for the dataset.

  Args:
    root_dir: directory containing the `image/` and `mask/` sub-folders.
  """

  def __init__(self, root_dir):
    self.root_dir = root_dir
    self.image_dir = os.path.join(self.root_dir, "image/")
    self.mask_dir = os.path.join(self.root_dir, "mask/")
    # os.listdir returns entries in arbitrary order; sorting both listings
    # makes the index-based image/mask pairing deterministic.
    self.image_list = sorted(os.listdir(self.image_dir))
    self.mask_list = sorted(os.listdir(self.mask_dir))

  def __len__(self):
    """Number of samples, taken from the image folder."""
    return len(self.image_list)

  def __getitem__(self, index):
    """Return the `(input_image, target_image)` tensors for sample `index`."""
    image_path = os.path.join(self.image_dir, self.image_list[index])
    mask_path = os.path.join(self.mask_dir, self.mask_list[index])

    # Context managers close the files right after decoding. Both images are
    # forced to RGB because the Normalize steps use 3-channel mean/std.
    with Image.open(image_path) as image_file:
      input_image = np.array(image_file.convert("RGB"))
    with Image.open(mask_path) as mask_file:
      target_image = np.array(mask_file.convert("RGB"))

    # shared transform first, so both members of the pair are treated alike
    augmentations = both_transform(image=input_image, image0=target_image)
    input_image = augmentations["image"]
    target_image = augmentations["image0"]

    # then the per-side normalisation / tensor conversion
    input_image = transform_only_input(image=input_image)["image"]
    target_image = transform_only_mask(image=target_image)["image"]

    return input_image, target_image

### Utils

In [None]:
def save_some_examples(DEVICE, gen, val_loader, epoch, folder):
    """Save generator output for the first validation batch as PNG files.

    Writes `y_gen_<epoch>.png` and `input_<epoch>.png` into `folder`; the
    ground-truth `label_<epoch>.png` is written only when `epoch == 1`.
    Tensors are mapped back with `* 0.5 + 0.5` before saving, which undoes a
    mean-0.5 / std-0.5 normalisation.

    Args:
        DEVICE: device the batch is moved to before running the generator.
        gen: generator module; switched to eval mode for the forward pass and
            put back into train mode afterwards.
        val_loader: iterable of `(x, y)` batches; only the first one is used.
        epoch: number used in the output file names.
        folder: output directory; created if it does not exist.
    """
    # save_image does not create directories, so make sure the target exists
    os.makedirs(folder, exist_ok=True)

    x, y = next(iter(val_loader))
    x, y = x.to(DEVICE), y.to(DEVICE)
    gen.eval()
    try:
        with torch.no_grad():
            y_fake = gen(x)
            y_fake = y_fake * 0.5 + 0.5  # remove normalization
            save_image(y_fake, os.path.join(folder, f"y_gen_{epoch}.png"))
            save_image(x * 0.5 + 0.5, os.path.join(folder, f"input_{epoch}.png"))
            # NOTE(review): a caller counting epochs from 0 gets the label
            # image on its second epoch - confirm that this is intended.
            if epoch == 1:
                save_image(y * 0.5 + 0.5, os.path.join(folder, f"label_{epoch}.png"))
    finally:
        # always hand the generator back in training mode, even if saving fails
        gen.train()


def save_checkpoint(model, optimizer, filename="/content/weights/my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to `filename` with torch.save.

    The saved object is a dict with the keys "state_dict" (model) and
    "optimizer". When `filename` is a path, missing parent directories are
    created first.
    """
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    # torch.save does not create directories; only do so for real paths so
    # that file-like objects passed as `filename` keep working.
    if isinstance(filename, (str, os.PathLike)):
        parent_dir = os.path.dirname(filename)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore `model` and `optimizer` from a checkpoint and reset the learning rate.

    Args:
        checkpoint_file: path (or file-like object) readable by torch.load; the
            stored dict must contain "state_dict" and "optimizer".
        model: module that receives the stored weights.
        optimizer: optimizer that receives the stored state.
        lr: learning rate written into every parameter group after loading.
    """
    print("=> Loading checkpoint")
    # There is no `config` object in this file, so the target device is taken
    # from the model itself (CPU when the model has no parameters).
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else "cpu"
    checkpoint = torch.load(checkpoint_file, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

### Configs

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Root folders; each is expected to contain "image/" and "mask/" sub-folders.
# NOTE(review): training reads the "valid" split and validation the "test"
# split - confirm this is intentional.
TRAIN_DIR = "/content/new_data/valid/"
VAL_DIR = "/content/new_data/test/"
# Learning rate given to both Adam optimizers in main().
LEARNING_RATE = 2e-4
# Batch size and worker count of the training DataLoader.
BATCH_SIZE = 16
NUM_WORKERS = 2
# IMAGE_SIZE and CHANNELS_IMG are not referenced by the code visible here.
IMAGE_SIZE = 256
CHANNELS_IMG = 3
# Weight of the L1 term in the generator loss (see train_fn).
L1_LAMBDA = 100
# LAMBDA_GP is not referenced by the code visible here.
LAMBDA_GP = 10
# Number of passes over the training loader.
NUM_EPOCHS = 50
# Toggle restoring from / writing to the checkpoint files below.
LOAD_MODEL = False
SAVE_MODEL = False
# Checkpoint file names (relative to the current working directory).
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"

### Training

In [None]:
# Enable cuDNN's benchmark mode (auto-selection of convolution algorithms).
torch.backends.cudnn.benchmark = True

def train_fn(
    disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler,
):
    """Train discriminator and generator for one pass over `loader`.

    For every `(x, y)` batch the discriminator is updated first (real pair
    vs. detached fake pair), then the generator (adversarial loss plus the
    L1 distance to `y` weighted by the global L1_LAMBDA). Forward passes run
    under CUDA autocast and each optimizer steps through its own GradScaler.
    Batches are moved to the global DEVICE.

    Args:
        disc, gen: discriminator and generator modules.
        loader: iterable yielding `(x, y)` batches.
        opt_disc, opt_gen: optimizers for `disc` and `gen`.
        l1_loss: reconstruction criterion applied to `(y_fake, y)`.
        bce: adversarial criterion applied to discriminator outputs.
        g_scaler, d_scaler: gradient scalers for generator / discriminator.
    """
    progress = tqdm(loader, leave=True)

    for step, (x, y) in enumerate(progress):
        x, y = x.to(DEVICE), y.to(DEVICE)

        # ---- discriminator update ----
        with torch.cuda.amp.autocast():
            y_fake = gen(x)

            d_real = disc(x, y)
            real_targets = torch.ones_like(d_real)
            d_real_loss = bce(d_real, real_targets)

            # detach: this step must not push gradients into the generator
            d_fake = disc(x, y_fake.detach())
            fake_targets = torch.zeros_like(d_fake)
            d_fake_loss = bce(d_fake, fake_targets)

            d_loss = (d_real_loss + d_fake_loss) / 2

        disc.zero_grad()
        d_scaler.scale(d_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- generator update ----
        with torch.cuda.amp.autocast():
            # score the fake again, this time keeping the graph into `gen`
            d_fake = disc(x, y_fake)
            adversarial_term = bce(d_fake, torch.ones_like(d_fake))
            l1_term = l1_loss(y_fake, y) * L1_LAMBDA
            g_loss = adversarial_term + l1_term

        opt_gen.zero_grad()
        g_scaler.scale(g_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # refresh the progress-bar statistics only every 10th batch
        if step % 10 != 0:
            continue
        progress.set_postfix(
            D_real=torch.sigmoid(d_real).mean().item(),
            D_fake=torch.sigmoid(d_fake).mean().item(),
        )


def main():
    """Build the models, optimizers and data loaders, then train for NUM_EPOCHS.

    After every epoch a sample of generator output is written to
    `/content/evaluation`. When SAVE_MODEL is set, both networks are
    checkpointed every fifth epoch; when LOAD_MODEL is set, both are restored
    from their checkpoint files before training starts.
    """
    disc = discriminator(in_channels=3).to(DEVICE)
    gen = generator(in_channels=3, channels=64).to(DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999),)
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    # BCEWithLogitsLoss applies the sigmoid itself, so it takes raw logits
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE,
        )

    train_dataset = map_dataset(root_dir=TRAIN_DIR)
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
    )
    # one gradient scaler per optimizer for the mixed-precision steps
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()
    val_dataset = map_dataset(root_dir=VAL_DIR)
    # unshuffled, batch size 1: the same sample is rendered after every epoch
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # The sample images are written here; create the folder up front so the
    # first save does not fail on a missing directory.
    examples_dir = "/content/evaluation"
    os.makedirs(examples_dir, exist_ok=True)

    for epoch in range(NUM_EPOCHS):
        train_fn(
            disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
        )

        if SAVE_MODEL and epoch % 5 == 0:
            save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)

        save_some_examples(DEVICE, gen, val_loader, epoch, folder=examples_dir)

In [None]:
# Start training; one tqdm progress bar is printed per epoch.
main()

100%|██████████| 17/17 [00:06<00:00,  2.59it/s, D_fake=0.447, D_real=0.567]
100%|██████████| 17/17 [00:03<00:00,  5.05it/s, D_fake=0.319, D_real=0.64]
100%|██████████| 17/17 [00:03<00:00,  5.05it/s, D_fake=0.167, D_real=0.848]
100%|██████████| 17/17 [00:03<00:00,  5.08it/s, D_fake=0.148, D_real=0.855]
100%|██████████| 17/17 [00:03<00:00,  5.07it/s, D_fake=0.0433, D_real=0.954]
100%|██████████| 17/17 [00:03<00:00,  5.03it/s, D_fake=0.129, D_real=0.865]
100%|██████████| 17/17 [00:03<00:00,  5.09it/s, D_fake=0.358, D_real=0.602]
100%|██████████| 17/17 [00:03<00:00,  5.05it/s, D_fake=0.209, D_real=0.758]
100%|██████████| 17/17 [00:03<00:00,  5.00it/s, D_fake=0.52, D_real=0.613]
100%|██████████| 17/17 [00:03<00:00,  4.98it/s, D_fake=0.507, D_real=0.641]
100%|██████████| 17/17 [00:03<00:00,  5.01it/s, D_fake=0.414, D_real=0.519]
100%|██████████| 17/17 [00:03<00:00,  5.01it/s, D_fake=0.454, D_real=0.555]
100%|██████████| 17/17 [00:03<00:00,  5.01it/s, D_fake=0.384, D_real=0.62]
100%|█████████