In [78]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from PIL import Image
import numpy as np
import os

In [60]:
class CNNBlock(nn.Module):
  """Convolution -> BatchNorm -> LeakyReLU(0.2) unit.

  Args:
    in_channels: number of channels in the incoming feature map.
    out_channels: number of channels produced by the convolution.
    stride: stride of the 4x4 convolution (2 by default).
  """

  def __init__(self, in_channels, out_channels, stride=2):
    super().__init__()
    # The convolution carries no bias term; it is followed directly by BatchNorm.
    convolution = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=4,
        stride=stride,
        padding=1,
        padding_mode='reflect',
        bias=False,
    )
    normalisation = nn.BatchNorm2d(out_channels)
    activation = nn.LeakyReLU(0.2)
    self.conv = nn.Sequential(convolution, normalisation, activation)

  def forward(self, x):
    """Apply the block to `x` and return the resulting feature map."""
    out = self.conv(x)
    return out

In [61]:
class Discriminator(nn.Module):
  """Fully convolutional discriminator for (input, candidate) image pairs.

  The two images are concatenated along the channel dimension, so the first
  convolution sees `in_channels * 2` channels. The network ends in a
  1-channel convolution with no sigmoid, i.e. it returns a spatial grid of
  raw logits (one per receptive-field patch) rather than a single score.

  Args:
    in_channels: channels of EACH of the two images passed to `forward`.
    features: channel widths of the successive stages. The first entry is
      used by the initial (BatchNorm-free) convolution; every following
      entry adds one `CNNBlock`. All blocks downsample with stride 2 except
      the last one, which uses stride 1.
  """

  def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
    super(Discriminator, self).__init__()
    # Copy into a list: accepts any sequence and avoids sharing a mutable default.
    features = list(features)

    self.initial = nn.Sequential(
        nn.Conv2d(in_channels=in_channels * 2,
                  out_channels=features[0],
                  kernel_size=4,
                  stride=2,
                  padding=1,
                  padding_mode='reflect'),
        nn.LeakyReLU(0.2)
    )

    layers = []
    in_channels = features[0]
    last_position = len(features) - 1

    for position, feature in enumerate(features[1:], start=1):
      # Decide by position, not by value: comparing `feature == features[-1]`
      # would also give stride 1 to any earlier stage that happens to have
      # the same width as the final one (e.g. features=[64, 512, 512]).
      stride = 1 if position == last_position else 2
      layers.append(CNNBlock(in_channels, feature, stride=stride))
      in_channels = feature

    # Final projection to a single logit channel.
    layers.append(
        nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode='reflect')
    )

    self.model = nn.Sequential(*layers)

  def forward(self, x, y):
    """Score the pair (`x`, `y`).

    Args:
      x: conditioning image batch.
      y: candidate (real or generated) image batch; must be concatenable
        with `x` along dim 1.

    Returns:
      Tensor of raw logits with a single channel.
    """
    x = torch.cat([x, y], dim=1)
    x = self.initial(x)
    return self.model(x)


In [62]:
def test_disc():
  """Smoke test: feed two random 256x256 RGB batches to a default Discriminator and print the output shape."""
  batch_shape = (1, 3, 256, 256)
  first_image = torch.rand(*batch_shape)
  second_image = torch.rand(*batch_shape)

  discriminator = Discriminator()
  scores = discriminator(first_image, second_image)

  print(scores.shape)

In [63]:
# Run the discriminator smoke test; it prints the shape of the model output.
test_disc()

torch.Size([1, 1, 30, 30])


In [64]:
class Block(nn.Module):
  """Resampling unit: (Conv2d | ConvTranspose2d) -> BatchNorm -> activation, with optional dropout.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels of the produced feature map.
    act: 'relu' selects `nn.ReLU`; any other value selects `nn.LeakyReLU(0.2)`.
    dropout: when truthy, `nn.Dropout(0.5)` is applied to the output.
    down: True uses a stride-2 convolution (reflect padding); False uses a
      stride-2 transposed convolution.
  """

  def __init__(self, in_channels, out_channels, act='relu', dropout=False, down=True):
    super().__init__()

    # Both variants use a 4x4 kernel, stride 2, padding 1 and no bias.
    if down:
      resample = nn.Conv2d(
          in_channels, out_channels,
          kernel_size=4, stride=2, padding=1, padding_mode='reflect', bias=False,
      )
    else:
      resample = nn.ConvTranspose2d(
          in_channels, out_channels,
          kernel_size=4, stride=2, padding=1, bias=False,
      )

    normalisation = nn.BatchNorm2d(out_channels)

    if act == 'relu':
      activation = nn.ReLU()
    else:
      activation = nn.LeakyReLU(0.2)

    self.conv = nn.Sequential(resample, normalisation, activation)

    # The dropout module is always created; `use_dropout` decides whether it runs.
    self.use_dropout = dropout
    self.dropout = nn.Dropout(0.5)

  def forward(self, x):
    """Apply the block to `x`, running dropout only if it was enabled."""
    features = self.conv(x)
    if not self.use_dropout:
      return features
    return self.dropout(features)


In [65]:
class Generator(nn.Module):
  """Encoder/decoder generator with skip connections between mirrored stages.

  Encoder: `initial`, then `down1`..`down6`, then `neck`; each is a stride-2
  4x4 convolution. Decoder: `up1`..`up7` and `final`, each a stride-2 4x4
  transposed convolution. From `up2` onwards every decoder stage receives
  the previous decoder output concatenated (dim 1) with the output of the
  mirrored encoder stage, which is why those stages take doubled input
  channels. `up1`..`up3` apply dropout. The output passes through `Tanh`.

  Args:
    in_channels: channels of the input image; the output has the same number.
    features: base channel width; deeper stages use multiples of it.
  """

  def __init__(self, in_channels, features=64):
    super(Generator, self).__init__()
    self.initial = nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=features, kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
        # Leaky like every other encoder stage (down1..down6 use act='leaky').
        # The pix2pix architecture uses LeakyReLU(0.2) throughout the encoder
        # and plain ReLU only in the decoder; a plain ReLU here was the odd one out.
        nn.LeakyReLU(0.2)
    )

    # Encoder stages (no dropout).
    self.down1 = Block(features, features*2, act='leaky', down=True, dropout=False)
    self.down2 = Block(features*2, features*4, act='leaky', down=True, dropout=False)
    self.down3 = Block(features*4, features*8, act='leaky', down=True, dropout=False)
    self.down4 = Block(features*8, features*8, act='leaky', down=True, dropout=False)
    self.down5 = Block(features*8, features*8, act='leaky', down=True, dropout=False)
    self.down6 = Block(features*8, features*8, act='leaky', down=True, dropout=False)

    # Innermost stage; its output feeds the decoder directly (no skip).
    self.neck = nn.Sequential(
        nn.Conv2d(in_channels=features*8, out_channels=features*8, padding_mode='reflect', kernel_size=4, stride=2, padding=1),
        nn.ReLU()
    )

    # Decoder stages. Input widths are doubled from up2 on because of the
    # channel-wise concatenation with the mirrored encoder output.
    self.up1 = Block(features*8, features*8, act='relu', down=False, dropout=True)
    self.up2 = Block(features*8*2, features*8, act='relu', down=False, dropout=True)
    self.up3 = Block(features*8*2, features*8, act='relu', down=False, dropout=True)
    self.up4 = Block(features*8*2, features*8, act='relu', down=False, dropout=False)
    self.up5 = Block(features*8*2, features*4, act='relu', down=False, dropout=False)
    self.up6 = Block(features*4*2, features*2, act='relu', down=False, dropout=False)
    self.up7 = Block(features*2*2, features, act='relu', down=False, dropout=False)

    self.final = nn.Sequential(
        nn.ConvTranspose2d(features*2, in_channels, 4, 2, 1),
        nn.Tanh()
    )

  def forward(self, x):
    """Translate the image batch `x`.

    Returns:
      Tensor with `in_channels` channels and values in [-1, 1] (Tanh output).
    """
    # Encoder: keep every intermediate output for the skip connections.
    d1 = self.initial(x)
    d2 = self.down1(d1)
    d3 = self.down2(d2)
    d4 = self.down3(d3)
    d5 = self.down4(d4)
    d6 = self.down5(d5)
    d7 = self.down6(d6)

    neck = self.neck(d7)

    # Decoder: pair each stage with the encoder output of matching resolution.
    up1 = self.up1(neck)
    up2 = self.up2(torch.cat([up1, d7], 1))
    up3 = self.up3(torch.cat([up2, d6], 1))
    up4 = self.up4(torch.cat([up3, d5], 1))
    up5 = self.up5(torch.cat([up4, d4], 1))
    up6 = self.up6(torch.cat([up5, d3], 1))
    up7 = self.up7(torch.cat([up6, d2], 1))
    return self.final(torch.cat([up7, d1], 1))

In [66]:
def test_gen():
  """Smoke test: run a 3-channel Generator on one random 256x256 batch and print the output shape."""
  sample = torch.rand(1, 3, 256, 256)
  generator = Generator(in_channels=3, features=64)
  translated = generator(sample)
  print(translated.shape)

In [67]:
# Run the generator smoke test; it prints the shape of the model output.
test_gen()

torch.Size([1, 3, 256, 256])


In [98]:
# --- Runtime ---------------------------------------------------------------
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
  DEVICE = 'cuda'
else:
  DEVICE = 'cpu'
NUM_WORKERS = 2

# --- Data ------------------------------------------------------------------
IMG_CHANNELS = 3
IMG_SIZE = 256
BATCH_SIZE = 16

# --- Optimisation ----------------------------------------------------------
EPOCHS = 10
LEARNING_RATE = 2e-4
L1_LAMBDA = 100  # multiplier applied to the L1 term of the generator loss

# --- Checkpointing ---------------------------------------------------------
GEN_CHECKPOINT = 'checkpoints/gen.pth.tar'
DISC_CHECKPOINT = 'checkpoints/disc.pth.tar'
SAVE_MODEL = True
LOAD_MODEL = False

In [99]:
# Geometric augmentations. 'image0' is registered as an additional image
# target so the input and target halves go through the same pipeline call
# and stay spatially aligned. The output size comes from IMG_SIZE instead of
# a repeated literal so the config cell is the single place to change it.
both_transform = A.Compose(
    [A.Resize(width=IMG_SIZE, height=IMG_SIZE), A.HorizontalFlip(p=0.5)],
    additional_targets={'image0': 'image'},
)

# Per-image pipeline: occasional colour jitter, then normalisation with
# mean 0.5 / std 0.5 on pixels scaled by 1/255 (i.e. [0, 255] -> [-1, 1]),
# then conversion to a channel-first torch tensor.
image_transform = A.Compose([
    A.ColorJitter(p=0.1),
    A.Normalize(mean=[0.5] * IMG_CHANNELS, std=[0.5] * IMG_CHANNELS, max_pixel_value=255.0),
    ToTensorV2(),
])

In [100]:
class MapDataset(Dataset):
  """Paired-image dataset read from a flat directory.

  Every file in `root_dir` is treated as one image holding the input on its
  left half and the target on its right half.

  Args:
    root_dir: directory whose entries are the paired image files.
  """

  def __init__(self, root_dir):
    self.root_dir = root_dir
    # os.listdir returns entries in arbitrary order; sort so that a given
    # index always refers to the same file across runs and machines.
    self.file_paths = sorted(os.listdir(self.root_dir))

  def __len__(self):
    """Number of files found in `root_dir`."""
    return len(self.file_paths)

  def __getitem__(self, index):
    """Load, split and augment the pair at `index`.

    Returns:
      Tuple `(input_image, target_image)` as produced by `image_transform`.
    """
    img_file = self.file_paths[index]
    img_path = os.path.join(self.root_dir, img_file)

    # Close the file handle deterministically and force 3 channels so that
    # greyscale / RGBA / palette files do not break the slicing below.
    with Image.open(img_path) as img:
      image = np.array(img.convert('RGB'))

    # Split down the middle rather than at a fixed column; for 1200-pixel-wide
    # files this is the same column 600 that was previously hard-coded.
    half_width = image.shape[1] // 2
    input_image = image[:, :half_width, :]
    target_image = image[:, half_width:, :]

    # Shared geometric augmentation keeps the two halves aligned.
    augmented = both_transform(image=input_image, image0=target_image)
    input_image, target_image = augmented['image'], augmented['image0']

    # NOTE(review): image_transform is called separately for input and target,
    # so any random augmentation inside it is sampled independently for the
    # two images - confirm that is intended for the target.
    input_image = image_transform(image=input_image)['image']
    target_image = image_transform(image=target_image)['image']

    return input_image, target_image


In [101]:
def save_examples(gen, val_loader, epoch, folder):
  """Write one batch of input / generated / label images to `folder`.

  Takes the first batch yielded by `val_loader`, runs `gen` on it in eval
  mode without gradients, and saves three PNG files named
  `y_gen_{epoch}.png`, `input_{epoch}.png` and `label_{epoch}.png`.
  All three are mapped back with `* 0.5 + 0.5` before saving.

  Args:
    gen: generator module.
    val_loader: iterable yielding `(x, y)` tensor batches.
    epoch: value interpolated into the file names.
    folder: output directory; created if it does not exist.
  """
  if folder:
    os.makedirs(folder, exist_ok=True)

  x, y = next(iter(val_loader))
  x, y = x.to(DEVICE), y.to(DEVICE)

  # Remember the current mode so the caller's train/eval state is restored
  # exactly, even if generation or saving raises.
  was_training = gen.training
  gen.eval()
  try:
    with torch.no_grad():
      y_fake = gen(x)

    save_image(y_fake * 0.5 + 0.5, os.path.join(folder, f'y_gen_{epoch}.png'))
    save_image(x * 0.5 + 0.5, os.path.join(folder, f'input_{epoch}.png'))
    # The label gets the same de-normalisation as the other two images, and
    # is written on every call: the batch is whatever the loader yields this
    # time, so a label saved once cannot be assumed to match later inputs.
    save_image(y * 0.5 + 0.5, os.path.join(folder, f'label_{epoch}.png'))
  finally:
    gen.train(was_training)


def save_checkpoint(model, optimizer, filename='my_checkpoint.pth.tar'):
  """Serialise the model and optimizer state dicts to `filename`.

  The saved object is a dict with keys 'state_dict' (model) and 'optimizer'.
  When `filename` is a path, its parent directory is created if missing.

  Args:
    model: module providing `state_dict()`.
    optimizer: optimizer providing `state_dict()`.
    filename: destination passed to `torch.save`.
  """
  print(f'=> Saving Checkpoint At {filename}')

  # torch.save does not create intermediate directories. Only attempt this
  # for real paths; file-like objects are passed through untouched.
  if isinstance(filename, (str, os.PathLike)):
    parent = os.path.dirname(os.fspath(filename))
    if parent:
      os.makedirs(parent, exist_ok=True)

  checkpoint = {
      'state_dict': model.state_dict(),
      'optimizer': optimizer.state_dict()
  }

  torch.save(checkpoint, filename)

def load_checkpoint(filename, model, optimizer, lr=None):
  """Restore model and optimizer state from a checkpoint file.

  Expects the object stored at `filename` to be a dict with keys
  'state_dict' and 'optimizer'.

  Args:
    filename: checkpoint location passed to `torch.load`.
    model: module whose parameters are overwritten.
    optimizer: optimizer whose state is overwritten.
    lr: if given, every parameter group's learning rate is set to this value
      after loading; if None, the learning rate restored with the optimizer
      state is kept.
  """
  print(f'=> Loading Checkpoint From {filename}')
  # NOTE(review): torch.load unpickles the file; only load checkpoints from
  # a trusted source.
  checkpoint = torch.load(filename, map_location=DEVICE)
  model.load_state_dict(checkpoint['state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer'])

  # Loading the optimizer state also restores its saved learning rate, so an
  # explicitly requested rate has to be re-applied afterwards.
  if lr is not None:
    for param_group in optimizer.param_groups:
      param_group['lr'] = lr

In [102]:
def train_step(loader, gen, disc, gen_opt, disc_opt, g_scalar, d_scalar, BCE, L1):
  """Run one pass over `loader`, updating the discriminator and then the generator on every batch.

  Args:
    loader: iterable of `(x, y)` batches (wrapped in a tqdm progress bar).
    gen: generator module.
    disc: discriminator module, called as `disc(x, y)`.
    gen_opt: optimizer stepped for the generator.
    disc_opt: optimizer stepped for the discriminator.
    g_scalar: gradient scaler used for the generator loss.
    d_scalar: gradient scaler used for the discriminator loss.
    BCE: criterion applied to discriminator outputs against all-ones / all-zeros targets.
    L1: criterion applied between the generated and target images.
  """
  for source, target in tqdm(loader, leave=True):
    source, target = source.to(DEVICE), target.to(DEVICE)

    # ---- Discriminator update ----
    with torch.cuda.amp.autocast():
      generated = gen(source)
      real_logits = disc(source, target)
      # Detached so this loss does not back-propagate into the generator.
      fake_logits = disc(source, generated.detach())
      loss_on_real = BCE(real_logits, torch.ones_like(real_logits))
      loss_on_fake = BCE(fake_logits, torch.zeros_like(fake_logits))
      disc_loss = (loss_on_real + loss_on_fake) / 2

    disc.zero_grad()
    d_scalar.scale(disc_loss).backward()
    d_scalar.step(disc_opt)
    d_scalar.update()

    # ---- Generator update ----
    with torch.cuda.amp.autocast():
      # Same generated batch, this time with the graph attached.
      fooled_logits = disc(source, generated)
      adversarial_loss = BCE(fooled_logits, torch.ones_like(fooled_logits))
      reconstruction_loss = L1(generated, target) * L1_LAMBDA
      gen_loss = adversarial_loss + reconstruction_loss

    gen.zero_grad()
    g_scalar.scale(gen_loss).backward()
    g_scalar.step(gen_opt)
    g_scalar.update()

In [103]:
# Build both networks with the channel count from the config cell.
gen = Generator(in_channels=IMG_CHANNELS).to(DEVICE)
disc = Discriminator(in_channels=IMG_CHANNELS).to(DEVICE)

gen_opt = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
disc_opt = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

bce = nn.BCEWithLogitsLoss()
l1 = nn.L1Loss()

if LOAD_MODEL:
  # load_checkpoint takes the learning rate as its fourth argument; without
  # it this branch raised a TypeError as soon as LOAD_MODEL was enabled.
  load_checkpoint(DISC_CHECKPOINT, disc, disc_opt, LEARNING_RATE)
  load_checkpoint(GEN_CHECKPOINT, gen, gen_opt, LEARNING_RATE)

train_dataset = MapDataset('data/pix2pix/maps/maps/train')
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

val_dataset = MapDataset('data/pix2pix/maps/maps/val')
# Not shuffled: save_examples takes the first batch on each call, so a fixed
# order avoids drawing a random validation file for every epoch's preview.
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

g_scalar = torch.cuda.amp.GradScaler()
d_scalar = torch.cuda.amp.GradScaler()

# Make sure every output location exists before the first write.
examples_dir = 'logs/Pix2Pix'
os.makedirs(examples_dir, exist_ok=True)
if SAVE_MODEL:
  for checkpoint_path in (GEN_CHECKPOINT, DISC_CHECKPOINT):
    os.makedirs(os.path.dirname(checkpoint_path) or '.', exist_ok=True)

for epoch in range(EPOCHS):
  train_step(train_loader, gen, disc, gen_opt, disc_opt, g_scalar, d_scalar, bce, l1)

  if SAVE_MODEL:
    save_checkpoint(gen, gen_opt, GEN_CHECKPOINT)
    save_checkpoint(disc, disc_opt, DISC_CHECKPOINT)

  save_examples(gen, val_loader, epoch, examples_dir)

100%|██████████████████████| 69/69 [00:16<00:00,  4.14it/s]


=> Saving Checkpoint At checkpoints/gen.pth.tar
=> Saving Checkpoint At checkpoints/disc.pth.tar


100%|██████████████████████| 69/69 [00:16<00:00,  4.20it/s]


=> Saving Checkpoint At checkpoints/gen.pth.tar
=> Saving Checkpoint At checkpoints/disc.pth.tar


100%|██████████████████████| 69/69 [00:16<00:00,  4.20it/s]


=> Saving Checkpoint At checkpoints/gen.pth.tar
=> Saving Checkpoint At checkpoints/disc.pth.tar


100%|██████████████████████| 69/69 [00:16<00:00,  4.18it/s]


=> Saving Checkpoint At checkpoints/gen.pth.tar
=> Saving Checkpoint At checkpoints/disc.pth.tar


100%|██████████████████████| 69/69 [00:16<00:00,  4.19it/s]


=> Saving Checkpoint At checkpoints/gen.pth.tar
=> Saving Checkpoint At checkpoints/disc.pth.tar


100%|██████████████████████| 69/69 [00:16<00:00,  4.22it/s]


=> Saving Checkpoint At checkpoints/gen.pth.tar
=> Saving Checkpoint At checkpoints/disc.pth.tar


100%|██████████████████████| 69/69 [00:16<00:00,  4.21it/s]


=> Saving Checkpoint At checkpoints/gen.pth.tar
=> Saving Checkpoint At checkpoints/disc.pth.tar


100%|██████████████████████| 69/69 [00:16<00:00,  4.20it/s]


=> Saving Checkpoint At checkpoints/gen.pth.tar
=> Saving Checkpoint At checkpoints/disc.pth.tar


100%|██████████████████████| 69/69 [00:16<00:00,  4.21it/s]


=> Saving Checkpoint At checkpoints/gen.pth.tar
=> Saving Checkpoint At checkpoints/disc.pth.tar


100%|██████████████████████| 69/69 [00:16<00:00,  4.22it/s]


=> Saving Checkpoint At checkpoints/gen.pth.tar
=> Saving Checkpoint At checkpoints/disc.pth.tar
