<a href="https://colab.research.google.com/github/Diishasing/From-Drawings-to-Rembrandt-style-using-Pix2piX-/blob/main/From_Drawings_to_Rembrandt_style_using_Pix2piX_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

IMPORTS

In [None]:
# Colab-only: mount Google Drive at /content/drive. The dataset folder and the
# sample-output folder used further down both live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision.utils import save_image
from tqdm import tqdm
import torch.optim as optim
from torch.utils.data import DataLoader

DISCRIMINATOR

In [None]:
class CNNB(nn.Module):
  """Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stage used inside the discriminator.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels produced by the convolution.
    stride: stride of the 4x4 convolution (default 2).
  """

  def __init__(self, in_channels, out_channels, stride = 2):
    super().__init__()

    self.conv = nn.Sequential(
        # padding = 1 was missing: with the default padding of 0 the
        # padding_mode = 'reflect' argument had no effect and every stage
        # shrank the feature map more than the other 4x4 convolutions in this
        # file (which all use padding = 1).
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size = 4,
            stride = stride,
            padding = 1,
            bias = False,  # BatchNorm directly after makes a conv bias redundant
            padding_mode = 'reflect',
        ),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2),
        )

  def forward(self, x):
    """Apply the conv/norm/activation stack to x and return the result."""
    return self.conv(x)


class Discriminator(nn.Module):
  """Conditional (PatchGAN-style) discriminator for image pairs.

  forward(x, y) concatenates the two images along the channel axis and returns
  a one-channel map of raw scores; no sigmoid is applied here.

  Args:
    in_channels: channels of ONE image; the first convolution takes twice that
      many because both images are stacked.
    features: channel widths of the successive stages. The first entry is used
      by the initial convolution (no batch norm), the rest by CNNB stages.
  """

  def __init__(self, in_channels = 3, features = (64, 128, 256, 512)):
    super().__init__()

    # Work on a private list so any sequence (list, tuple, ...) is accepted and
    # the caller's object / the default is never shared.
    features = list(features)

    self.initial = nn.Sequential(
        nn.Conv2d(in_channels * 2, features[0], kernel_size = 4, stride = 2, padding = 1, padding_mode = 'reflect'),
        nn.LeakyReLU(0.2),
    )

    layers = []
    in_channels = features[0]
    last_idx = len(features) - 1
    for idx, feature in enumerate(features[1:], start = 1):
      # Only the LAST stage uses stride 1. Selecting it by position rather than
      # by comparing widths keeps this true when a width is repeated,
      # e.g. features = [64, 512, 512].
      layers.append(
          CNNB(in_channels, feature, stride = 1 if idx == last_idx else 2),
      )
      in_channels = feature

    # Final projection to a single score channel.
    layers.append(
        nn.Conv2d(
            in_channels, 1, kernel_size = 4, stride = 1, padding = 1, padding_mode = 'reflect'
        ),
    )

    self.model = nn.Sequential(*layers)

  def forward(self, x, y):
    """Score the pair (x, y); both tensors are concatenated on dim 1."""
    x = torch.cat([x, y], dim = 1)
    x = self.initial(x)
    return self.model(x)

GENERATOR — MODIFIED U-NET ARCHITECTURE

In [None]:
'''U-NET Architecture'''

class Block(nn.Module):
  """Single U-Net stage: 4x4 stride-2 resampling conv, batch norm, activation.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels produced by the stage.
    down: True builds a reflect-padded Conv2d (encoder side); False builds a
      ConvTranspose2d (decoder side).
    act: 'relu' selects ReLU; any other value selects LeakyReLU(0.2).
    use_dropout: when True, Dropout(0.5) is applied to the stage output.
  """

  def __init__(self, in_channels, out_channels, down = True, act = 'relu', use_dropout = False):
    super(Block, self).__init__()

    # Resampling layer: strided conv on the way down, transposed conv on the way up.
    if down:
      resample = nn.Conv2d(
          in_channels, out_channels,
          kernel_size = 4, stride = 2, padding = 1,
          bias = False, padding_mode = 'reflect',
      )
    else:
      resample = nn.ConvTranspose2d(
          in_channels, out_channels,
          kernel_size = 4, stride = 2, padding = 1,
          bias = False,
      )

    if act == 'relu':
      activation = nn.ReLU()
    else:
      activation = nn.LeakyReLU(0.2)

    self.conv = nn.Sequential(resample, nn.BatchNorm2d(out_channels), activation)
    self.use_dropout = use_dropout
    # Always constructed; only consulted in forward when use_dropout is set.
    self.dropout = nn.Dropout(0.5)
    self.down = down

  def forward(self, x):
    """Run the stage on x, applying dropout afterwards if it was requested."""
    out = self.conv(x)
    if self.use_dropout:
      return self.dropout(out)
    return out


class Generator(nn.Module):
  """U-Net generator: eight stride-2 encoder stages mirrored by eight decoder stages.

  Every decoder stage after the first receives the previous decoder output
  concatenated (on the channel axis) with the encoder feature map of matching
  depth. The output has in_channels channels and passes through Tanh.

  Args:
    in_channels: channels of the input image (and of the generated image).
    features: base channel width; deeper stages use multiples of it.

  Note: eight halvings mean the spatial size has to survive being divided by
  256; this notebook resizes its images to 256x256 before they get here.
  """

  def __init__(self, in_channels = 3, features = 64):
    super().__init__()

    f1, f2, f4, f8, f16 = (features * k for k in (1, 2, 4, 8, 16))

    # First encoder stage has no batch norm.
    self.initial_down = nn.Sequential(
        nn.Conv2d(in_channels, f1, kernel_size = 4, stride = 2, padding = 1, padding_mode = 'reflect'),
        nn.LeakyReLU(0.2),
    )

    # Encoder stages down1..down6, registered in order.
    encoder_widths = [(f1, f2), (f2, f4), (f4, f8), (f8, f8), (f8, f8), (f8, f8)]
    for stage, (c_in, c_out) in enumerate(encoder_widths, start = 1):
      setattr(
          self,
          f'down{stage}',
          Block(c_in, c_out, down = True, act = 'leaky', use_dropout = False),
      )

    # Bottleneck: last halving, no batch norm.
    self.bottom = nn.Sequential(
        nn.Conv2d(f8, f8, kernel_size = 4, stride = 2, padding = 1, padding_mode = 'reflect'),
        nn.ReLU(),
    )

    # Decoder stages up1..up7. Input widths from up2 on are doubled by the
    # skip concatenation; only the first three stages use dropout.
    decoder_specs = [
        (f8, f8, True),
        (f16, f8, True),
        (f16, f8, True),
        (f16, f8, False),
        (f16, f4, False),
        (f8, f2, False),
        (f4, f1, False),
    ]
    for stage, (c_in, c_out, drop) in enumerate(decoder_specs, start = 1):
      setattr(
          self,
          f'up{stage}',
          Block(c_in, c_out, down = False, act = 'relu', use_dropout = drop),
      )

    self.final = nn.Sequential(
        nn.ConvTranspose2d(f2, in_channels, kernel_size = 4, stride = 2, padding = 1),
        nn.Tanh(),
    )

  def forward(self, x):
    """Encode x, then decode while concatenating the matching encoder maps."""
    # skips[0] is the output of initial_down, skips[i] the output of down{i}.
    skips = [self.initial_down(x)]
    for stage in range(1, 7):
      skips.append(getattr(self, f'down{stage}')(skips[-1]))

    out = self.up1(self.bottom(skips[-1]))

    # up2 pairs with the deepest encoder map (down6), up7 with down1.
    for stage, skip in zip(range(2, 8), reversed(skips[1:])):
      out = getattr(self, f'up{stage}')(torch.cat([out, skip], dim=1))

    return self.final(torch.cat([out, skips[0]], dim=1))



AUGMENTATION

In [None]:
# Geometry shared by both halves of a pair: the same resize and the same
# (random) horizontal flip are applied to 'image' and to the extra target 'image0'.
transform_both = A.Compose(
    [
        A.Resize(width = 256, height = 256),
        A.HorizontalFlip(p = 0.5),
    ],
    additional_targets = {'image0': 'image'},
)

# Input half only: occasional colour jitter, then scale pixels from [0, 255]
# to [-1, 1] and convert to a CHW tensor.
transform_input = A.Compose([
    A.ColorJitter(p = 0.1),
    A.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5], max_pixel_value = 255.0),
    ToTensorV2(),
])

# Target half only: same normalisation and tensor conversion, no colour jitter.
transform_mask = A.Compose([
    A.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5], max_pixel_value = 255.0),
    ToTensorV2(),
])

DATASET

In [None]:
class Data(Dataset):
  """Paired-image dataset: each file stores input and target side by side.

  Columns [0, 256) of an image are taken as the generator input and columns
  from 256 onward as the target. Both halves go through `transform_both`
  together, then through `transform_input` / `transform_mask` respectively.

  Args:
    root_dir: directory containing the image files.
  """

  def __init__(self, root_dir):
    self.root_dir = root_dir
    # Keep regular files only (sub-directories such as .ipynb_checkpoints would
    # crash Image.open) and sort them so that index -> file is reproducible.
    self.list_files = sorted(
        name
        for name in os.listdir(self.root_dir)
        if os.path.isfile(os.path.join(self.root_dir, name))
    )
    # Short summary instead of dumping every file name into the cell output.
    print(f'Found {len(self.list_files)} files in {self.root_dir}')

  def __len__(self):
    """Number of image files found in root_dir."""
    return len(self.list_files)

  def __getitem__(self, index):
    """Return the (input, target) tensors for the file at position index."""
    img_file = self.list_files[index]
    img_path = os.path.join(self.root_dir, img_file)
    # Close the file promptly, and force 3 channels so that grayscale / RGBA
    # files do not break the H x W x C slicing or the 3-channel Normalize.
    with Image.open(img_path) as pil_img:
      img = np.array(pil_img.convert('RGB'))
    input_img = img[:, :256, :]
    target_img = img[:, 256:, :]

    # Same resize / flip for both halves.
    augmentation = transform_both(image = input_img, image0 = target_img)
    input_img, target_img = augmentation['image'], augmentation['image0']

    input_img = transform_input(image = input_img)['image']

    target_img = transform_mask(image = target_img)['image']

    return input_img, target_img

HYPERPARAMETERS

In [None]:
# Device that the models and every batch are moved to.
gpu = 'cpu'
if torch.cuda.is_available():
  gpu = 'cuda'

# Optimisation
l_r = 2e-4  # as stated in the paper
L1_LAMBDA = 100  # weight of the L1 term in the generator loss
EPOCHS = 1000

# Data loading
BATCH_SIZE = 1
NUM_WORKERS = 2
IMAGE_SIZE = 256
CHANNELS_IMG = 3

# Checkpoint switches (no cell visible in this notebook reads them yet)
LOAD_MODEL = True
SAVE_MODEL = True
# CHECKPOINT_DISC = 'disc.path.tar'
# CHECKPOINT_GEN = 'gen.path.tar'

TRAINING

In [None]:
# Networks, moved to the selected device.
gen = Generator(in_channels = 3).to(gpu)
disc = Discriminator(in_channels = 3).to(gpu)

# One Adam optimiser per network, same learning rate and betas for both.
opt_gen = optim.Adam(
    gen.parameters(),
    lr = l_r,
    betas = (0.5, 0.999),
)
opt_disc = optim.Adam(
    disc.parameters(),
    lr = l_r,
    betas = (0.5, 0.999),
)

# Losses: adversarial term on raw discriminator scores, plus L1 reconstruction.
criterion = nn.BCEWithLogitsLoss()
L1_LOSS = nn.L1Loss()

# Training pairs are read from Google Drive.
train_data = Data(root_dir = '/content/drive/MyDrive/Pix2piX/gd')
train_loader = DataLoader(
    train_data,
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers = NUM_WORKERS,
)

# Separate gradient scalers for the two mixed-precision backward passes.
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()


#val dataloader




# fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(gpu)

# #tensorboard directory writer
# writer_real = SummaryWriter(f'runs/real')
# writer_fake = SummaryWriter(f'runs/fake')
# step = 0

# gen.train()
# disc.train()


['282.jpg', '23.jpg', '165.jpg', '75.jpg', '35.jpg', '242.jpg', '22.jpg', '199.jpg', '16.jpg', '232.jpg', '274.jpg', '56.jpg', '233.jpg', '172.jpg', '237.jpg', '187.jpg', '247.jpg', '281.jpg', '120.jpg', '288.jpg', '95.jpg', '61.jpg', '173.jpg', '100.jpg', '79.jpg', '118.jpg', '290.jpg', '171.jpg', '271.jpg', '139.jpg', '137.jpg', '12.jpg', '245.jpg', '236.jpg', '28.jpg', '21.jpg', '44.jpg', '104.jpg', '197.jpg', '97.jpg', '250.jpg', '217.jpg', '52.jpg', '179.jpg', '155.jpg', '62.jpg', '48.jpg', '42.jpg', '198.jpg', '51.jpg', '85.jpg', '157.jpg', '230.jpg', '161.jpg', '259.jpg', '277.jpg', '145.jpg', '255.jpg', '45.jpg', '152.jpg', '177.jpg', '235.jpg', '6.jpg', '124.jpg', '289.jpg', '275.jpg', '135.jpg', '65.jpg', '71.jpg', '125.jpg', '166.jpg', '87.jpg', '258.jpg', '84.jpg', '77.jpg', '1.jpg', '298.jpg', '150.jpg', '129.jpg', '34.jpg', '287.jpg', '32.jpg', '123.jpg', '215.jpg', '138.jpg', '201.jpg', '3.jpg', '101.jpg', '246.jpg', '151.jpg', '116.jpg', '122.jpg', '216.jpg', '293.jpg',

In [None]:
def train_fn(disc, gen, loader, opt_disc, opt_gen, l1, bce, g_scaler, d_scaler):
  """Train the discriminator and the generator for one pass over loader.

  For every batch the discriminator is updated first (real pair vs. detached
  fake pair), then the generator (adversarial term + L1_LAMBDA * L1 term).

  Args:
    disc: discriminator, called as disc(x, y).
    gen: generator, called as gen(x).
    loader: iterable yielding (input, target) batches.
    opt_disc: optimiser holding the discriminator parameters.
    opt_gen: optimiser holding the generator parameters.
    l1: reconstruction loss, called as l1(y_fake, y).
    bce: adversarial loss, called on discriminator outputs and 0/1 targets.
    g_scaler: gradient scaler used for the generator step.
    d_scaler: gradient scaler used for the discriminator step.

  Relies on the notebook globals `gpu` and `L1_LAMBDA`.
  """
  loop = tqdm(loader, leave = True)

  for x, y in loop:
    x, y = x.to(gpu), y.to(gpu)

    #train the discriminator
    with torch.cuda.amp.autocast():
      y_fake = gen(x)
      D_real = disc(x, y)
      # detach(): the discriminator update must not back-propagate into gen
      D_fake = disc(x, y_fake.detach())
      D_real_loss = bce(D_real, torch.ones_like(D_real))
      D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
      D_loss = (D_real_loss + D_fake_loss) / 2

    disc.zero_grad()
    # No retain_graph: this graph only covers the discriminator (the fake is
    # detached), so the generator graph of y_fake is untouched and still
    # available for the generator step below.
    d_scaler.scale(D_loss).backward()
    d_scaler.step(opt_disc)
    d_scaler.update()

    #train generator
    with torch.cuda.amp.autocast():
      # fresh discriminator pass on the non-detached fake
      D_fake = disc(x, y_fake)
      G_fake_loss  = bce(D_fake, torch.ones_like(D_fake))
      L1 = l1(y_fake, y)  * L1_LAMBDA
      G_loss = G_fake_loss + L1

    gen.zero_grad()
    # Last use of this graph in the iteration, so it can be freed right away.
    g_scaler.scale(G_loss).backward()
    g_scaler.step(opt_gen)
    g_scaler.update()



In [None]:
def save_some_examples(gen, val_loader, epoch, folder):
    """Save one generated batch (and, at epoch 1, its targets) as PNG files.

    Takes the first batch from val_loader, runs gen on it in eval mode without
    gradients, and writes `y_gen_{epoch}.png` into folder. When epoch == 1 the
    corresponding targets are written as `label_{epoch}.png`.

    Args:
        gen: generator; put back into training mode before returning.
        val_loader: iterable yielding (input, target) batches.
        epoch: epoch number used in the file names.
        folder: output directory; created if it does not exist.

    Relies on the notebook global `gpu`.
    """
    # save_image does not create missing directories.
    os.makedirs(folder, exist_ok = True)

    x, y = next(iter(val_loader))
    x, y = x.to(gpu), y.to(gpu)
    gen.eval()
    try:
        with torch.no_grad():
            y_fake = gen(x)
            y_fake = y_fake * 0.5 + 0.5  # remove normalization#
            save_image(y_fake, os.path.join(folder, f"y_gen_{epoch}.png"))
            # save_image(x * 0.5 + 0.5, os.path.join(folder, f"input_{epoch}.png"))
            if epoch == 1:
                save_image(y * 0.5 + 0.5, os.path.join(folder, f"label_{epoch}.png"))
    finally:
        # Restore training mode even if inference or saving failed, so a
        # caught error cannot leave the generator stuck in eval mode.
        gen.train()

In [None]:
for epoch in range(EPOCHS):
  # One full pass over the training pairs.
  train_fn(
      disc = disc,
      gen = gen,
      loader = train_loader,
      opt_disc = opt_disc,
      opt_gen = opt_gen,
      l1 = L1_LOSS,
      bce = criterion,
      g_scaler = g_scaler,
      d_scaler = d_scaler,
  )

  # Dump a generated sample for this epoch (drawn from the training loader).
  save_some_examples(
      gen = gen,
      val_loader = train_loader,
      epoch = epoch,
      folder = "/content/drive/MyDrive/Pix2piX/evaluated",
  )



  # for batch_idx, (real, _) in enumerate(dataloader):
  #   real = real.to(gpu)
  #   noise = torch.randn((BATCH_SIZE, Z_DIM, 1, 1)).to(gpu)
  #   fake = gen(noise)

  #   #discriminator loss
  #   disc_real = disc(real).reshape(-1)
  #   loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))

  #   disc_fake = disc(fake).reshape(-1)
  #   loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

  #   loss_disc = (loss_disc_real + loss_disc_fake) / 2

  #   disc.zero_grad()
  #   loss_disc.backward(retain_graph = True)
  #   opt_disc.step()

  #   #generator loss
  #   output = disc(fake).reshape(-1)
  #   loss_gen = criterion(output, torch.ones_like(output))
  #   gen.zero_grad()
  #   loss_gen.backward()
  #   opt_gen.step()

  #   #print the losses to the tensorboard
  #   if batch_idx % 100 == 0:
  #     print(
  #         f'Epoch [{epoch}/{EPOCHS}] \
  #           Batch [{batch_idx} / {len(dataloader)}] \
  #           Loss D {loss_disc:.4f} \
  #           Loss G {loss_gen:.4f}'
  #     )

  #     with torch.no_grad():
  #       fake = gen(fixed_noise)

  #       img_grid_real = torchvision.utils.make_grid(
  #           real[:25], normalize = True
  #       )
  #       img_grid_fake = torchvision.utils.make_grid(
  #           fake[:25], normalize = True
  #       )

  #       writer_real.add_image('real', img_grid_real, global_step = step)
  #       writer_real.add_image('fake', img_grid_fake, global_step = step)

  #       step += 1


100%|██████████| 302/302 [00:16<00:00, 18.80it/s]
100%|██████████| 302/302 [00:16<00:00, 18.79it/s]
100%|██████████| 302/302 [00:16<00:00, 18.68it/s]
100%|██████████| 302/302 [00:16<00:00, 18.76it/s]
100%|██████████| 302/302 [00:16<00:00, 18.85it/s]
100%|██████████| 302/302 [00:16<00:00, 18.87it/s]
100%|██████████| 302/302 [00:15<00:00, 18.91it/s]
100%|██████████| 302/302 [00:17<00:00, 17.22it/s]
100%|██████████| 302/302 [00:16<00:00, 18.62it/s]
100%|██████████| 302/302 [00:16<00:00, 18.85it/s]
100%|██████████| 302/302 [00:15<00:00, 18.97it/s]
100%|██████████| 302/302 [00:15<00:00, 18.90it/s]
100%|██████████| 302/302 [00:15<00:00, 18.92it/s]
100%|██████████| 302/302 [00:15<00:00, 18.90it/s]
100%|██████████| 302/302 [00:15<00:00, 18.91it/s]
100%|██████████| 302/302 [00:15<00:00, 18.90it/s]
100%|██████████| 302/302 [00:16<00:00, 18.69it/s]
100%|██████████| 302/302 [00:15<00:00, 18.89it/s]
100%|██████████| 302/302 [00:15<00:00, 18.94it/s]
100%|██████████| 302/302 [00:15<00:00, 18.89it/s]
