In [1]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
import os
from PIL import Image

import torchvision
from torchvision.utils import save_image

import albumentations
from albumentations.pytorch import ToTensorV2

from tqdm.auto import tqdm

root_path = '/mnt/c/Users/121js/OneDrive/Desktop/TorchImages/'

# Discriminator Model

In [2]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
        super().__init__()
        first_block = self._make_conv_block(in_channels*2, features[0], 4, 2, 1, use_bn=False)
        middle_blocks = []
        for i in range(len(features)-1):
            middle_blocks.append(
                self._make_conv_block(features[i], features[i+1], 4, 1 if features[i+1] == features[-1] else 2, 1)
            )
        last_block = self._make_conv_block(features[-1], 1, 4, 1, 1, use_bn=False, use_act=False)
        self.model = nn.Sequential(
            first_block,
            *middle_blocks,
            last_block
            )

    def _make_conv_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, use_bn=True, use_act=True):
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
                padding_mode='reflect',
            )
        ]
        if use_bn:
            layers.append(nn.BatchNorm2d(out_channels))
        if use_act:
            layers.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*layers)
    
    def forward(self, x, y):
        return self.model(torch.cat([x, y], dim=1))
    
test_x = torch.randn((1, 3, 256, 256))
test_y = torch.randn((1, 3, 256, 256))
model = Discriminator(in_channels=3)
patch = model(test_x, test_y)
print(model)
print(patch.shape)
assert patch.shape == (1, 1, 30, 30)

Discriminator(
  (model): Sequential(
    (0): Sequential(
      (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): LeakyReLU(negative_slope=0.2)
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Le

# Generator Model

In [3]:
class Generator(nn.Module):
    def __init__(self, in_channels=3, features=64):
        super().__init__()
        self.first_down = self._make_unet_block(in_channels, features, mode='down', use_bn=False, act='leaky_relu')

        self.down2 = self._make_unet_block(features, features*2, mode='down', act='leaky_relu', use_drop=False)
        self.down3 = self._make_unet_block(features*2, features*4, mode='down', act='leaky_relu', use_drop=False)
        self.down4 = self._make_unet_block(features*4, features*8, mode='down', act='leaky_relu', use_drop=False)
        self.down5 = self._make_unet_block(features*8, features*8, mode='down', act='leaky_relu', use_drop=False)
        self.down6 = self._make_unet_block(features*8, features*8, mode='down', act='leaky_relu', use_drop=False)
        self.down7 = self._make_unet_block(features*8, features*8, mode='down', act='leaky_relu', use_drop=False)

        self.last_down = self._make_unet_block(features*8, features*8, mode='down', use_bn=False, use_drop=False)
        self.first_up = self._make_unet_block(features*8, features*8, mode='up', use_bn=True, act='relu', use_drop=True)

        self.up2 = self._make_unet_block(features*8*2, features*8, mode='up', act='relu', use_drop=True)
        self.up3 = self._make_unet_block(features*8*2, features*8, mode='up', act='relu', use_drop=True)
        self.up4 = self._make_unet_block(features*8*2, features*8, mode='up', act='relu', use_drop=False)
        self.up5 = self._make_unet_block(features*8*2, features*4, mode='up', act='relu', use_drop=False)
        self.up6 = self._make_unet_block(features*4*2, features*2, mode='up', act='relu', use_drop=False)
        self.up7 = self._make_unet_block(features*2*2, features, mode='up', act='relu', use_drop=False)

        self.last_up = self._make_unet_block(features*2, in_channels, mode='up', use_bn=False, use_act=False, is_last_act=True)
    
    def _make_unet_block(self, in_channels, out_channels, mode='down', use_bn=True, use_act=True, act='relu', is_last_act=False, use_drop=False, leak=0.2, drop=0.5):
        layers = []
        layers.append(nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode='reflect') if mode == 'down' else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False))
        if use_bn:
            layers.append(nn.BatchNorm2d(out_channels))
        if use_act:
            layers.append(nn.ReLU() if act == 'relu' else nn.LeakyReLU(leak))
        if is_last_act:
            layers.append(nn.Tanh())
        if use_drop:
            layers.append(nn.Dropout(drop))
        return nn.Sequential(*layers)
    
    def forward(self, x):
        d1 = self.first_down(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)

        bottleneck = self.last_down(d7)
        u1 = self.first_up(bottleneck)

        u2 = self.up2(torch.cat([u1, d7], 1))
        u3 = self.up3(torch.cat([u2, d6], 1))
        u4 = self.up4(torch.cat([u3, d5], 1))
        u5 = self.up5(torch.cat([u4, d4], 1))
        u6 = self.up6(torch.cat([u5, d3], 1))
        u7 = self.up7(torch.cat([u6, d2], 1))

        y_hat = self.last_up(torch.cat([u7, d1], 1))
        return y_hat
    
test_x = torch.randn((1, 3, 256, 256))
model = Generator(in_channels=3)
test_y_hat = model(test_x)
print(model)
print(test_y_hat.shape)
assert test_y_hat.shape == test_x.shape

Generator(
  (first_down): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (down2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (down3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (down4): Sequential(
    (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (down5):

# Parameters

In [4]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
TRAIN_DIR = root_path + 'dogs/mix'
VAL_DIR = root_path + 'dogs/mix'
IMG_SIZE = 256
NUM_CHANNELS = 3
BATCH_SIZE = 32
BLUR_KERNEL = 21
NUM_IMG_TO_SHOW = 5
LR = 2e-4
L1_LAMBDA = 100
NUM_EPOCHS = 500
LOAD_MODEL = False
SAVE_MODEL = False
CHECKPOINT_DISC = 'disc.pth.tar'
CHECKPOINT_GEN = 'gen.pth.tar'

both_transformations = albumentations.Compose(
    transforms=[
        albumentations.Resize(width=IMG_SIZE, height=IMG_SIZE),
    ],
    additional_targets={'image0': 'image'},
)

input_transformations = albumentations.Compose(
    [
        # albumentations.HorizontalFlip(p=0.5),
        albumentations.ColorJitter(p=1),
        albumentations.GaussianBlur(blur_limit=(BLUR_KERNEL, BLUR_KERNEL), p=1), 
        # albumentations.ElasticTransform(p=1),
        albumentations.Normalize(
            mean=[0.5 for _ in range(NUM_CHANNELS)],
            std=[0.5 for _ in range(NUM_CHANNELS)],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]
)
target_transformations = albumentations.Compose(
    [
        albumentations.Normalize(
            mean=[0.5 for _ in range(NUM_CHANNELS)],
            std=[0.5 for _ in range(NUM_CHANNELS)],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]
)

# Dataset

In [5]:
class MapDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.list_files = os.listdir(root_dir)
    
    def __len__(self):
        return len(self.list_files)
    
    def __getitem__(self, index):
        img_file = self.list_files[index]
        img_path = os.path.join(self.root_dir, img_file)
        input_image, target_image = np.array(Image.open(img_path)), np.array(Image.open(img_path))

        augmentations = both_transformations(image=input_image, image0=target_image)
        input_image, target_image = augmentations['image'], augmentations['image0']
        input_image, target_image = input_transformations(image=input_image)['image'], target_transformations(image=target_image)['image']

        return input_image, target_image
    
images = MapDataset(TRAIN_DIR)
images_loader = DataLoader(images, batch_size=8)
for x, y in images_loader:
    print(x.shape, y.shape)
    save_image(x, 'testing/x.png')
    save_image(y, 'testing/y.png')
    break

torch.Size([8, 3, 256, 256]) torch.Size([8, 3, 256, 256])


# Utils

In [6]:
def save_some_examples(gen, val_loader, epoch, folder):
    x, y = next(iter(val_loader))
    num_img_to_show = 5
    x, y = x[:num_img_to_show].to(DEVICE), y[:num_img_to_show].to(DEVICE)
    gen.eval()
    with torch.no_grad():
        y_fake = gen(x) 
        y_fake = (y_fake * 0.5) + 0.5
        res_grid = torchvision.utils.make_grid(
            torch.cat(
                [
                    y_fake,
                    x*0.5 + 0.5,
                    y*0.5 + 0.5
                    ]
                    ),
            nrow=y_fake.shape[0]
            )
        save_image(res_grid, folder + f'/{epoch+1}.png')
    gen.train()

def save_checkpoint(model, optimizer, filename='my_checkpoint.pth.tar'):
    print('==> Saving Checkpoint <==')
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)

def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print('==> Loading Checkpoint <==')
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# Training

In [7]:
def train_func(disc, gen, loader, opt_disc, opt_gen, l1, bce, d_scaler, g_scaler, save_step, folder):
    # loop = tqdm(loader, leave=True)
    loop = loader
    list_losses = []
    for idx, (x, y) in enumerate(loop):
        x, y = x.to(DEVICE), y.to(DEVICE)

        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            d_real = disc(x, y)
            d_fake = disc(x, y_fake.detach())
            d_real_loss = bce(d_real, torch.ones_like(d_real))
            d_fake_loss = bce(d_fake, torch.zeros_like(d_fake))
            d_loss = d_real_loss + d_fake_loss  # Authors say dividing by 2 trains D slow than G

        disc.zero_grad()
        d_scaler.scale(d_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        with torch.cuda.amp.autocast():
            d_fake = disc(x, y_fake)
            g_fake_loss = bce(d_fake, torch.ones_like(d_fake))
            g_loss = g_fake_loss + (l1(y_fake, y) * L1_LAMBDA)

        gen.zero_grad()
        g_scaler.scale(g_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()
    
        if idx % 3 == 0:
            list_losses.append(round(g_loss.item(),4))
            gen.eval()
            with torch.no_grad():
                y_fake = gen(x[: NUM_IMG_TO_SHOW]) 
                y_fake = (y_fake * 0.5) + 0.5
                res_grid = torchvision.utils.make_grid(
                    torch.cat([
                        y_fake,
                        x[: NUM_IMG_TO_SHOW]*0.5 + 0.5,
                        y[: NUM_IMG_TO_SHOW]*0.5 + 0.5
                        ]),
                    nrow=y_fake.shape[0]
                    )
                save_image(res_grid, f'{folder}/{save_step}.png')
                save_step += 1
            gen.train()
    return save_step, list_losses

def main_func():
    disc = Discriminator(in_channels=3).to(DEVICE)
    gen = Generator(in_channels=3).to(DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=LR, betas=(0.5, 0.999))
    opt_gen = optim.Adam(gen.parameters(), lr=LR, betas=(0.5, 0.999))
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    if LOAD_MODEL:
        load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LR)
        load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LR)

    train_dataset = MapDataset(root_dir=TRAIN_DIR)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    # val_dataset = MapDataset(root_dir=VAL_DIR)
    # val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    
    gen_losses = []
    save_step = 1
    for epoch in tqdm(range(NUM_EPOCHS)):
        save_step, list_losses = train_func(disc, gen, train_dataloader, opt_disc, opt_gen, L1_LOSS, BCE, d_scaler, g_scaler, save_step, 'results')
        gen_losses += list_losses
        # if SAVE_MODEL and epoch%5==0:
        #     save_checkpoint(gen, opt_gen, CHECKPOINT_GEN)
        #     save_checkpoint(disc, opt_disc, CHECKPOINT_DISC)
        
        # save_some_examples(gen, val_dataloader, epoch, 'results')
    return gen_losses

In [8]:
gen_losses = main_func()

  0%|          | 0/500 [00:00<?, ?it/s]

In [9]:
import pickle
with open('gen_losses.pkl', 'wb') as f:
    pickle.dump(gen_losses, f)