In [None]:
import torch
import torch.nn as nn
#Discriminator

class Discriminator(nn.Module):
  """Conditional discriminator scoring (input image, candidate image) pairs.

  The two images are concatenated along the channel axis and passed through a
  stack of 4x4 convolutions and max-pools. The output is a single-channel spatial
  map of logits rather than one scalar. For 256 x 256 inputs with the default
  ``features`` that map is 28 x 28; see the per-layer shape notes below, written
  H x W x C.

  Args:
    in_channels: channels of EACH input image. The first conv sees
      ``2 * in_channels`` because the pair is concatenated.
    features: output widths of the five hidden conv layers (indices 0..4 are used).
  """
  def __init__(self, in_channels=3, features=[64,128,256, 256, 512]):
    super().__init__()

    def conv(c_in, c_out, pad):
      # Every conv here is 4x4, stride 1, reflect-padded; only the padding rule varies.
      return nn.Conv2d(c_in, c_out, kernel_size=4, stride=1, padding=pad, padding_mode="reflect",)

    def pool():
      return nn.MaxPool2d(kernel_size=4, stride=2)

    f = features
    # Layer order (and therefore Sequential indices / state_dict keys) is fixed:
    # saved checkpoints depend on it.
    layers = [
        conv(in_channels * 2, f[0], "same"), nn.LeakyReLU(0.2),               # 256 x 256 x 64
        conv(f[0], f[1], "same"), nn.BatchNorm2d(f[1]), nn.LeakyReLU(0.2),    # 256 x 256 x 128
        pool(),                                                               # 127 x 127 x 128
        conv(f[1], f[2], "same"), nn.BatchNorm2d(f[2]), nn.LeakyReLU(0.2),    # 127 x 127 x 256
        conv(f[2], f[3], "valid"), nn.BatchNorm2d(f[3]), nn.LeakyReLU(0.2),   # 124 x 124 x 256
        pool(),                                                               # 61 x 61 x 256
        conv(f[3], f[4], "valid"), nn.BatchNorm2d(f[4]), nn.LeakyReLU(0.2),   # 58 x 58 x 512
        pool(),                                                               # 28 x 28 x 512
        conv(f[4], 1, "same"),                                                # 28 x 28 x 1
    ]
    self.model = nn.Sequential(*layers)

  def forward(self, x, y):
    """Score ``y`` as a plausible target for conditioning image ``x``.

    Both tensors are expected as (N, C, H, W) with matching N, H and W. Returns
    the raw (pre-sigmoid) logit map.
    """
    pair = torch.cat([x, y], 1)  # stack condition and candidate on the channel axis
    return self.model(pair)
    


def test():
  """Smoke test: run a default Discriminator on one random 256 x 256 pair and print the output shape."""
  inputs, targets = torch.randn((1, 3, 256, 256)), torch.randn((1, 3, 256, 256))
  patch_logits = Discriminator()(inputs, targets)
  print(patch_logits.shape)


if __name__ == "__main__":
  test()

In [None]:
#generator
class Block(nn.Module):
  """Stride-2 resampling unit: 4x4 (transposed) conv -> BatchNorm -> activation [-> dropout].

  Args:
    in_channels, out_channels: channel counts of the resampling conv.
    down: True halves H and W with a reflect-padded ``Conv2d``; False doubles them
      with a ``ConvTranspose2d``. Both use kernel 4, stride 2, padding 1, no bias
      (the following BatchNorm supplies the shift).
    act: "relu" selects ReLU; any other value selects LeakyReLU(0.2).
    use_dropout: if True, apply Dropout(p=0.5) to the output.
  """
  def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
    super().__init__()
    if down:
      resample = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
    else:
      resample = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2)
    # Sequential indices 0/1/2 determine state_dict keys ("conv.0.weight", ...).
    self.conv = nn.Sequential(resample, norm, activation)
    self.use_dropout = use_dropout
    self.dropout = nn.Dropout(0.5)

  def forward(self, x):
    out = self.conv(x)
    if self.use_dropout:
      out = self.dropout(out)
    return out


class Generator(nn.Module):
  """U-Net generator mapping an image to an image of the same size and channel count.

  Encoder:
    - ``initial_down``: 3x3 conv, then a stride-2 4x4 conv.
    - Six stages, each a resolution-preserving 3x3 conv/BatchNorm/LeakyReLU
      (``l1``..``l6``) followed by a stride-2 ``Block`` (``down1``..``down6``).
    - A stride-2 ``bottleneck``.

  Decoder:
    - ``up1``..``up7`` are transposed-conv ``Block``s. Every step after ``up1``
      consumes the previous output concatenated (channel axis) with the
      encoder map of the same resolution (skip connection).
    - ``final_up`` maps back to ``in_channels`` with Tanh, so outputs lie in [-1, 1].

  There are eight stride-2 halvings, so H and W must be multiples of 256 for the
  skip concatenations to line up. Shape comments assume a 256 x 256 input with
  ``features=64`` and are written H x W x C.

  Args:
    in_channels: channels of both the input and the output image.
    features: base width; deeper layers use 2x, 4x, 8x and 16x this.
  """

  @staticmethod
  def _conv_bn_lrelu(in_ch, out_ch):
    # Resolution-preserving 3x3 reflect-padded conv -> BatchNorm -> LeakyReLU(0.2).
    # Sequential indices 0/1/2 match the original layout so checkpoints still load.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, 1, padding="same", padding_mode="reflect"),
        nn.BatchNorm2d(out_ch),
        nn.LeakyReLU(0.2),
    )

  def __init__(self, in_channels=3, features=64):
    super().__init__()
    f = features
    self.initial_down = nn.Sequential(
        nn.Conv2d(in_channels, f, 3, 1, padding="same", padding_mode="reflect"),
        nn.LeakyReLU(0.2),  # 256 x 256 x 64
        nn.Conv2d(f, f, 4, 2, 1, padding_mode="reflect"),
        nn.LeakyReLU(0.2),  # 128 x 128 x 64
    )

    self.l1 = self._conv_bn_lrelu(f, f * 2)                                      # 128 x 128 x 128
    self.down1 = Block(f * 2, f * 2, down=True, act="leaky", use_dropout=False)  # 64 x 64 x 128
    self.l2 = self._conv_bn_lrelu(f * 2, f * 4)                                  # 64 x 64 x 256
    self.down2 = Block(f * 4, f * 4, down=True, act="leaky", use_dropout=False)  # 32 x 32 x 256
    self.l3 = self._conv_bn_lrelu(f * 4, f * 8)                                  # 32 x 32 x 512
    self.down3 = Block(f * 8, f * 8, down=True, act="leaky", use_dropout=False)  # 16 x 16 x 512
    self.l4 = self._conv_bn_lrelu(f * 8, f * 8)                                  # 16 x 16 x 512
    self.down4 = Block(f * 8, f * 8, down=True, act="leaky", use_dropout=False)  # 8 x 8 x 512
    self.l5 = self._conv_bn_lrelu(f * 8, f * 8)                                  # 8 x 8 x 512
    self.down5 = Block(f * 8, f * 8, down=True, act="leaky", use_dropout=False)  # 4 x 4 x 512
    self.l6 = self._conv_bn_lrelu(f * 8, f * 8)                                  # 4 x 4 x 512
    self.down6 = Block(f * 8, f * 8, down=True, act="leaky", use_dropout=False)  # 2 x 2 x 512

    self.bottleneck = nn.Sequential(
        nn.Conv2d(f * 8, f * 16, 4, 2, 1, padding_mode="reflect"),
        nn.ReLU(),
    )  # 1 x 1 x 1024

    # Decoder: input widths after the first step are doubled by the skip concatenation.
    self.up1 = Block(f * 16, f * 8, down=False, act="relu", use_dropout=True)     # 2 x 2 x 512
    self.up2 = Block(f * 8 * 2, f * 8, down=False, act="relu", use_dropout=True)  # 4 x 4 x 512
    self.up3 = Block(f * 8 * 2, f * 8, down=False, act="relu", use_dropout=True)  # 8 x 8 x 512
    self.up4 = Block(f * 8 * 2, f * 8, down=False, act="relu", use_dropout=False) # 16 x 16 x 512
    self.up5 = Block(f * 8 * 2, f * 4, down=False, act="relu", use_dropout=False) # 32 x 32 x 256
    self.up6 = Block(f * 4 * 2, f * 2, down=False, act="relu", use_dropout=False) # 64 x 64 x 128
    self.up7 = Block(f * 2 * 2, f, down=False, act="relu", use_dropout=False)     # 128 x 128 x 64

    self.final_up = nn.Sequential(
        nn.ConvTranspose2d(f * 2, in_channels, kernel_size=4, stride=2, padding=1),
        nn.Tanh(),
    )  # 256 x 256 x in_channels

  def forward(self, x):
    # Encoder: keep every downsampled map for the skip connections.
    d1 = self.initial_down(x)           # 128 x 128 x 64
    d2 = self.down1(self.l1(d1))        # 64 x 64 x 128
    d3 = self.down2(self.l2(d2))        # 32 x 32 x 256
    d4 = self.down3(self.l3(d3))        # 16 x 16 x 512
    # Fix: previously down4 was fed d4 instead of l4's output, so l4 was computed,
    # discarded, and its parameters never received gradients. Checkpoints saved
    # before this fix therefore carry an untrained l4.
    d5 = self.down4(self.l4(d4))        # 8 x 8 x 512
    d6 = self.down5(self.l5(d5))        # 4 x 4 x 512
    d7 = self.down6(self.l6(d6))        # 2 x 2 x 512
    bottleneck = self.bottleneck(d7)    # 1 x 1 x 1024

    # Decoder with skip connections (concatenate on the channel axis).
    up1 = self.up1(bottleneck)
    up2 = self.up2(torch.cat([up1, d7], 1))
    up3 = self.up3(torch.cat([up2, d6], 1))
    up4 = self.up4(torch.cat([up3, d5], 1))
    up5 = self.up5(torch.cat([up4, d4], 1))
    up6 = self.up6(torch.cat([up5, d3], 1))
    up7 = self.up7(torch.cat([up6, d2], 1))
    return self.final_up(torch.cat([up7, d1], 1))
    


    



In [None]:
!pip install albumentations==0.4.6

In [None]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Hyperparameters and paths
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Data locations -- replace the placeholders with your own dataset folders.
TRAIN_DIR = "drive/MyDrive/Sketch_gan/file_name"  # training data
VAL_DIR = "drive/MyDrive/Sketch_gan/file_name"    # validation data
TEST_DIR = "drive/MyDrive/Sketch_gan/file_name"   # test data

# Optimisation
LEARNING_RATE = 5e-4
BATCH_SIZE = 32
NUM_WORKERS = 2
NUM_EPOCHS = 50
L1_LAMBDA = 1000  # weight of the L1 reconstruction term in the generator loss

# Image geometry
IMAGE_SIZE = 256
CHANNELS_IMG = 3

# Checkpointing -- replace the placeholders with where weights are saved/loaded.
LOAD_MODEL = False
SAVE_MODEL = False
CHECKPOINT_DISC = "drive/MyDrive/Sketch_gan/file_name"
CHECKPOINT_GEN = "drive/MyDrive/Sketch_gan/file_name"



In [None]:
from torchvision.utils import save_image
# Helpers for saving sample images and model checkpoints
def save_some_examples(gen, loader, epoch, folder):
    """Save generator outputs for the first batch of ``loader`` as one image grid.

    Writes ``{folder}/y_gen_{epoch}.png``. The generator is put in eval mode for the
    forward pass and is always returned to train mode afterwards.

    Args:
        gen: the generator module (assumed to already live on DEVICE).
        loader: iterable of (input, target) batches; only the first batch is used.
        epoch: label interpolated into the output filename.
        folder: existing output directory.
    """
    # Fix: this previously read the global ``val_loader`` instead of the ``loader``
    # argument; val_loader is only a local of main(), so the call raised NameError.
    x, _ = next(iter(loader))
    x = x.to(DEVICE)
    gen.eval()
    try:
        with torch.no_grad():
            y_fake = gen(x) * 0.5 + 0.5  # undo the [-1, 1] normalization
            save_image(y_fake, folder + f"/y_gen_{epoch}.png")
    finally:
        gen.train()

def save_some_test_examples(gen, test_loader, epoch, folder):
    """Save generator outputs and their inputs for the first batch of ``test_loader``.

    ``test_loader`` yields input batches only (no targets). Writes
    ``{folder}/y_gen_{epoch}.png`` (generated) and ``{folder}/input_{epoch}.png``
    (inputs), both mapped from [-1, 1] back to [0, 1]. The generator is switched to
    eval mode for the forward pass and back to train mode afterwards.
    """
    batch = next(iter(test_loader)).to(DEVICE)
    gen.eval()
    with torch.no_grad():
        prediction = gen(batch) * 0.5 + 0.5  # undo the [-1, 1] normalization
        save_image(prediction, folder + f"/y_gen_{epoch}.png")
        save_image(batch * 0.5 + 0.5, folder + f"/input_{epoch}.png")
    gen.train()


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Serialize model and optimizer state dicts to ``filename`` with ``torch.save``.

    The saved object is a dict with keys ``"state_dict"`` (model) and
    ``"optimizer"`` (optimizer), the layout ``load_checkpoint`` reads back.
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from a checkpoint dict, then force the learning rate.

    Args:
        checkpoint_file: path to a file holding ``{"state_dict": ..., "optimizer": ...}``.
        model: module whose weights are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        lr: learning rate written into every param group after loading.
    """
    print("=> Loading checkpoint")
    state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])

    # The restored optimizer state carries the learning rate from when it was
    # saved; overwrite it so the ``lr`` passed in here actually takes effect.
    for group in optimizer.param_groups:
        group["lr"] = lr

In [None]:
import numpy as np
# import config
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
import cv2 as cv

#loading dataset
class SketchDataset(Dataset):
    """Paired dataset of input images (``root_dir/x_label``) and targets (``root_dir/y_label``).

    Files are paired by position after sorting each folder by filename: the i-th
    sorted input goes with the i-th sorted target. Each item is an
    ``(input, target)`` pair of float tensors of shape (3, image_size, image_size),
    normalized from [0, 255] to [-1, 1].

    Args:
        root_dir: directory containing the ``x_label`` and ``y_label`` subfolders.
        image_size: side length both images are resized to (default 256).

    Raises:
        ValueError: if the two folders hold a different number of files. Positional
            pairing would otherwise silently misalign inputs and targets.
    """

    def __init__(self, root_dir, image_size=256):
        self.x_root_dir = root_dir + "/x_label"
        self.x_list_files = os.listdir(self.x_root_dir)
        self.y_root_dir = root_dir + "/y_label"
        self.y_list_files = os.listdir(self.y_root_dir)
        self.x_sorted_files = sorted(self.x_list_files)
        self.y_sorted_files = sorted(self.y_list_files)
        if len(self.x_sorted_files) != len(self.y_sorted_files):
            raise ValueError(
                f"{self.x_root_dir} has {len(self.x_sorted_files)} files but "
                f"{self.y_root_dir} has {len(self.y_sorted_files)}; inputs and targets "
                "are paired by sorted position, so the counts must match"
            )

        # Transforms are built once here rather than on every __getitem__ call.
        # Resize input and target in one call; "image0" is registered as a second image.
        self.both_transform = A.Compose(
            [A.Resize(width=image_size, height=image_size)],
            additional_targets={"image0": "image"},
        )
        # Same pipeline for input and target: uint8 [0, 255] -> float [-1, 1], HWC -> CHW tensor.
        self.normalize_to_tensor = A.Compose(
            [
                A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
                ToTensorV2(),
            ]
        )

    @staticmethod
    def _load_rgb(path):
        # Decode an image file as an H x W x 3 array; the context manager closes the file.
        with Image.open(path) as img:
            return np.array(img.convert('RGB'))

    def __len__(self):
        return len(self.x_sorted_files)

    def __getitem__(self, index):
        x_img_path = os.path.join(self.x_root_dir, self.x_sorted_files[index])
        y_img_path = os.path.join(self.y_root_dir, self.y_sorted_files[index])

        input_image = self._load_rgb(x_img_path)
        target_image = self._load_rgb(y_img_path)

        resized = self.both_transform(image=input_image, image0=target_image)
        input_image = self.normalize_to_tensor(image=resized["image"])["image"]
        target_image = self.normalize_to_tensor(image=resized["image0"])["image"]
        return input_image, target_image



        

        

In [None]:
class TestSketchDataset(Dataset):
    """Input-only dataset for inference: images from ``root_dir/x_label``.

    Items are returned in sorted-filename order. Each is a float tensor of shape
    (3, image_size, image_size), normalized from [0, 255] to [-1, 1].

    Args:
        root_dir: directory containing the ``x_label`` subfolder.
        image_size: side length images are resized to (default 256).
    """

    def __init__(self, root_dir, image_size=256):
        self.x_root_dir = root_dir + "/x_label"
        self.x_list_files = os.listdir(self.x_root_dir)
        self.x_sorted_files = sorted(self.x_list_files)

        # Built once here rather than on every __getitem__ call:
        # resize -> [-1, 1] normalization -> HWC ndarray to CHW tensor.
        self.transform = A.Compose(
            [
                A.Resize(width=image_size, height=image_size),
                A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
                ToTensorV2(),
            ]
        )

    def __len__(self):
        return len(self.x_sorted_files)

    def __getitem__(self, index):
        x_img_path = os.path.join(self.x_root_dir, self.x_sorted_files[index])
        # The context manager closes the file once it is decoded.
        with Image.open(x_img_path) as img:
            input_image = np.array(img.convert('RGB'))
        return self.transform(image=input_image)["image"]



        

In [None]:
# Colab only: mount Google Drive; the "drive/MyDrive/..." paths in this notebook
# presumably resolve through this mount (relative to Colab's working directory) -- confirm.
from google.colab import drive

drive.mount("/content/drive")

In [None]:
import torch

import torch.nn as nn
import torch.optim as optim



from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.utils import save_image

# torch.backends.cudnn.benchmark = True
#training model

def train_fn(
    disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler,
):
    """Run one epoch of conditional-GAN training under CUDA autocast with gradient scaling.

    For every (x, y) batch:
      1. Discriminator step: score (x, y) as real and (x, gen(x)) as fake; the fake
         image is detached so this step does not backpropagate into the generator.
      2. Generator step: adversarial loss (fool the discriminator) plus
         ``L1_LAMBDA`` times the L1 distance to the target.
    Every 10 batches, the mean sigmoid of the real/fake discriminator logits is
    shown on the progress bar.

    Args:
        disc, gen: discriminator and generator modules.
        loader: iterable of (input, target) batches.
        opt_disc, opt_gen: their optimizers.
        l1_loss: reconstruction loss callable.
        bce: adversarial loss on logits (main() passes ``nn.BCEWithLogitsLoss()``).
        g_scaler, d_scaler: AMP gradient scalers for the generator/discriminator.
    """
    progress = tqdm(loader, leave=True)

    for step, (x, y) in enumerate(progress):
        x, y = x.to(DEVICE), y.to(DEVICE)

        # ---- discriminator update ----
        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            real_logits = disc(x, y)
            loss_real = bce(real_logits, torch.ones_like(real_logits))
            fake_logits = disc(x, y_fake.detach())
            loss_fake = bce(fake_logits, torch.zeros_like(fake_logits))
            # NOTE(review): pix2pix usually averages with / 2. If bce uses mean
            # reduction (the default), / 32 scales the discriminator loss down a
            # further 16x. Confirm this is intentional.
            disc_loss = (loss_real + loss_fake) / 32

        disc.zero_grad()
        d_scaler.scale(disc_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- generator update ----
        with torch.cuda.amp.autocast():
            fake_logits = disc(x, y_fake)
            adv_loss = bce(fake_logits, torch.ones_like(fake_logits))
            gen_loss = adv_loss + l1_loss(y_fake, y) * L1_LAMBDA

        opt_gen.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if step % 10 == 0:
            progress.set_postfix(
                D_real=torch.sigmoid(real_logits).mean().item(),
                D_fake=torch.sigmoid(fake_logits).mean().item(),
            )


def main():
    """Build the models, optimizers and data loaders, optionally resume, then train.

    - If LOAD_MODEL, weights and optimizer state are restored from epoch
      ``load_epoch`` and epoch numbering continues from ``load_epoch + 1``.
    - One shared epoch number names both the checkpoints
      (``gen_{epoch}`` / ``disc_{epoch}``) and the validation sample images.
    - A pre-training snapshot is saved with the label ``before_epoch_{start_epoch}``
      so it cannot collide with any per-epoch image.
    - Runs NUM_EPOCHS epochs. When SAVE_MODEL is set, checkpoints are saved after
      every epoch; validation samples are saved after every epoch regardless.
    """
    disc = Discriminator(in_channels=3).to(DEVICE)
    gen = Generator(in_channels=3, features=64).to(DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999),)
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    # Previously sample images used a hard-coded "22 + epoch" while checkpoints used
    # "epoch"; derive one offset from the resumed checkpoint instead.
    start_epoch = 0
    if LOAD_MODEL:
        load_epoch = 21  # enter the epoch number of weights to load
        load_checkpoint(
            CHECKPOINT_GEN + f"/gen_{load_epoch}.pth.tar", gen, opt_gen, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_DISC + f"/disc_{load_epoch}.pth.tar", disc, opt_disc, LEARNING_RATE,
        )
        start_epoch = load_epoch + 1

    train_dataset = SketchDataset(root_dir=TRAIN_DIR)
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,  # was a hard-coded 32 that ignored the config constant
        shuffle=True,
        num_workers=NUM_WORKERS,
    )
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()
    # Sample-image helpers use only the first batch, so each saved grid holds up to 100 images.
    val_dataset = SketchDataset(root_dir=VAL_DIR)
    val_loader = DataLoader(val_dataset, batch_size=100, shuffle=False)
    test_dataset = TestSketchDataset(root_dir=TEST_DIR)
    test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

    val_folder = "drive/MyDrive/Sketch_gan/val_eval_simha"
    test_folder = "drive/MyDrive/Sketch_gan/test_eval_simha"
    snapshot_label = f"before_epoch_{start_epoch}"
    save_some_examples(gen, val_loader, snapshot_label, folder=val_folder)
    save_some_test_examples(gen, test_loader, snapshot_label, folder=test_folder)

    for epoch in range(start_epoch, start_epoch + NUM_EPOCHS):
        train_fn(
            disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
        )

        if SAVE_MODEL:
            save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN + f"/gen_{epoch}.pth.tar")
            save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC + f"/disc_{epoch}.pth.tar")

        save_some_examples(gen, val_loader, epoch, folder=val_folder)
        # Optional: also sample the test set every epoch.
        # save_some_test_examples(gen, test_loader, epoch, folder=test_folder)


if __name__ == "__main__":
    main()