### Imports

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
import numpy as np
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import glob
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from torch import nn, optim
from PIL import Image
from skimage.color import rgb2lab, lab2rgb
from tqdm import tqdm
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# Every later cell moves models and batches to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"
print(device)

### Fetching and cleaning the Dataset

In [None]:
!pip install -U fastai
import fastai
from fastai.data.external import untar_data, URLs
# Fetch the fastai COCO sample and collect every JPEG in its train_sample folder.
coco_path = untar_data(URLs.COCO_SAMPLE)
coco_path = str(coco_path) + "/train_sample"
paths = glob.glob(coco_path + "/*.jpg")
# Fixed seed so the random subset and the train/validation split are reproducible.
np.random.seed(123)
paths_subset = np.random.choice(paths, 10_000, replace=False) # choosing 10,000 images randomly (without replacement)
rand_idxs = np.random.permutation(10_000)
train_idxs = rand_idxs[:8000] # choosing the first 8000 as training set
val_idxs = rand_idxs[8000:] # choosing last 2000 as validation set
train_paths = paths_subset[train_idxs]
val_paths = paths_subset[val_idxs]
print(train_paths)

In [None]:
# Preview up to nine training images in a 3x3 grid (row by row).
fig, axes = plt.subplots(3, 3, figsize=(10,10))

img_count = 0
for ax in axes.flat:
    # Stop early when there are fewer than nine training images.
    if img_count >= len(train_paths):
        break
    ax.imshow(Image.open(train_paths[img_count]))
    img_count += 1

In [None]:
# Scan a ~1% sample of the training images and record the minimum and maximum
# of the first Lab channel (lightness), to sanity-check the normalisation
# constants used below.
IMGSIZE = 256
lentt = len(train_paths)
mn = 10000000
mx = -11111
for i in range(lentt // 100):
    # Index with the loop variable: the previous code always opened
    # train_paths[0], so the statistics covered a single image only.
    with Image.open(train_paths[i]) as src:
        img = src.convert("RGB")
    img = transforms.Resize((IMGSIZE, IMGSIZE),  Image.BICUBIC)(img)
    img = np.array(img)
    imgInLAB = rgb2lab(img).astype("float32")
    imgInLAB = transforms.ToTensor()(imgInLAB)
    mn = min((imgInLAB[[0], ...]).min(), mn)
    mx = max((imgInLAB[[0], ...]).max(), mx)
    # Same scaling the dataset class applies to each sample.
    L = imgInLAB[[0], ...] / 50. - 1. # Between -1 and 1
    ab = imgInLAB[[1, 2], ...] / 110. # Between -1 and 1
print(mx, mn)

### Initializing DataLoaders

In [None]:
IMGSIZE = 256
class MakeDataset(Dataset):
    """Dataset that turns image files into normalised Lab channel tensors.

    Each item is loaded from disk, converted to RGB, resized to
    IMGSIZE x IMGSIZE with bicubic interpolation and converted to Lab.

    Args:
        paths: indexable collection of image file paths.
    """

    def __init__(self, paths):
        self.transforms = transforms.Resize((IMGSIZE, IMGSIZE),  Image.BICUBIC)
        self.IMGSIZE = IMGSIZE
        self.paths = paths
        # Built once here instead of on every __getitem__ call.
        self.to_tensor = transforms.ToTensor()

    def __getitem__(self, i):
        """Return ``[L_array, ab_array]`` for the image at index ``i``.

        ``L_array`` is the first Lab channel scaled by ``/ 50 - 1`` and
        ``ab_array`` holds the remaining two channels scaled by ``/ 110``.
        """
        # The context manager releases the file handle deterministically;
        # convert() returns an independent in-memory copy of the pixels.
        with Image.open(self.paths[i]) as src:
            img = src.convert("RGB")
        img = np.array(self.transforms(img))
        imgInLAB = self.to_tensor(rgb2lab(img).astype("float32"))
        L_array = imgInLAB[[0], ...] / 50. - 1.
        ab_array = imgInLAB[[1, 2], ...] / 110.
        return [L_array, ab_array]

    def __len__(self):
        """Return the number of image paths."""
        return len(self.paths)
    

In [None]:
# Batch size and number of loader worker processes shared by both loaders.
BatchSize = 16
Workers = 4

# Training batches, reshuffled on every pass.
train_dl = DataLoader(
    MakeDataset(paths=train_paths),
    batch_size=BatchSize,
    num_workers=Workers,
    pin_memory=True,
    shuffle=True,
)
# Validation batches; also shuffled, so each preview draws different images.
val_dl = DataLoader(
    MakeDataset(paths=val_paths),
    batch_size=BatchSize,
    num_workers=Workers,
    pin_memory=True,
    shuffle=True,
)


In [None]:
def lab_to_rgb(L, ab):
    """Convert a batch of normalised L and ab tensors to RGB images.

    Undoes the dataset scaling (``L = L_raw / 50 - 1``, ``ab = ab_raw / 110``),
    joins the channels along dim 1, moves channels last and converts every
    image with ``lab2rgb``.

    Args:
        L: tensor with the lightness channel on dim 1.
        ab: tensor with the two colour channels on dim 1.

    Returns:
        NumPy array of the converted images stacked along axis 0.
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 110.
    lab_batch = torch.cat([lightness, chroma], dim=1)
    # Channels-last NumPy layout on the host, one image at a time for lab2rgb.
    lab_batch = lab_batch.permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab_img) for lab_img in lab_batch], axis=0)

In [None]:
# Pull one training batch and show the L, a and b channels of its first
# sample next to the RGB image reconstructed from them.
data = next(iter(train_dl))
L_array, ab_array = data[0], data[1]
print(L_array.shape, ab_array.shape)

fig, (ax0, ax1, ax2, ax3) = plt.subplots(1, 4, figsize=(15,10))
for ax, img, title in (
    (ax0, L_array[0][0], 'L'),
    (ax1, ab_array[0][0], 'a'),
    (ax2, ab_array[0][1], 'b'),
    (ax3, lab_to_rgb(L_array, ab_array)[0], 'RGB'),
):
    ax.imshow(img)
    ax.set_title(title)
plt.show()

## Modeling the Conditional Generative Adversarial Network (cGAN)

### Generator Code

In [None]:
class Gen_Block(nn.Module):
    """One resampling stage of the generator.

    Applies a stride-2 4x4 convolution (``down=True``, followed by
    LeakyReLU(0.2)) or a stride-2 4x4 transposed convolution (``down=False``,
    followed by ReLU). Optional BatchNorm and Dropout(0.5) run between the
    convolution and the activation, in that order.

    Args:
        inputs: number of input channels.
        outputs: number of output channels.
        down: halve the spatial size when True, double it when False.
        batchNorm: insert BatchNorm2d after the convolution.
        dropout: insert Dropout(0.5) before the activation.
    """

    def __init__(self, inputs, outputs, down=True, batchNorm=True, dropout=False):
        super().__init__()
        self.batchNorm = batchNorm
        self.dropout = dropout

        conv_cls = nn.Conv2d if down else nn.ConvTranspose2d
        self.block1 = conv_cls(inputs, outputs, kernel_size=4, stride=2, padding=1, bias=False)
        self.block4 = nn.LeakyReLU(0.2, True) if down else nn.ReLU(True)
        # Optional layers are only registered when requested.
        if batchNorm:
            self.block2 = nn.BatchNorm2d(outputs)
        if dropout:
            self.block3 = nn.Dropout(0.5)

    def forward(self, x):
        """Run convolution, optional norm/dropout, then the activation."""
        stages = [self.block1]
        if self.batchNorm:
            stages.append(self.block2)
        if self.dropout:
            stages.append(self.block3)
        stages.append(self.block4)

        for stage in stages:
            x = stage(x)
        return x

In [None]:
class generator(nn.Module):
    """U-shaped encoder/decoder that maps an L image to its two ab channels.

    Eight stride-2 stages shrink the input by a factor of 256, and eight
    transposed stages grow it back. Each decoder stage after the first
    receives the previous decoder output concatenated (on the channel dim)
    with the encoder output of matching resolution. The final layer emits
    two channels squashed by ``tanh``.

    Because every stage halves or doubles the size exactly, input height and
    width must be multiples of 256 for the skip concatenations to line up.

    Args:
        inputs: number of input channels (1 for a lone L channel).
        outputs: base feature width of the first encoder stage; deeper
            stages use 2x, 4x and 8x this width. These two arguments were
            previously ignored in favour of hard-coded 1 and 64; the
            defaults build the same network as before.
    """

    def __init__(self, inputs=1, outputs=64):
        super().__init__()
        f = outputs

        # Encoder: no BatchNorm on the first stage.
        self.d1 = Gen_Block(inputs, f, batchNorm=False)
        self.d2 = Gen_Block(f, f * 2)
        self.d3 = Gen_Block(f * 2, f * 4)
        self.d4 = Gen_Block(f * 4, f * 8)
        self.d5 = Gen_Block(f * 8, f * 8)
        self.d6 = Gen_Block(f * 8, f * 8)
        self.d7 = Gen_Block(f * 8, f * 8)
        # Bottleneck stage without BatchNorm.
        self.d8 = nn.Sequential(
            nn.Conv2d(f * 8, f * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
        )

        # Decoder: input widths are doubled by the skip concatenation;
        # the first three stages use dropout.
        self.u1 = Gen_Block(f * 8, f * 8, False, dropout=True)
        self.u2 = Gen_Block(f * 16, f * 8, False, dropout=True)
        self.u3 = Gen_Block(f * 16, f * 8, False, dropout=True)
        self.u4 = Gen_Block(f * 16, f * 8, False)
        self.u5 = Gen_Block(f * 16, f * 4, False)
        self.u6 = Gen_Block(f * 8, f * 2, False)
        self.u7 = Gen_Block(f * 4, f, False)
        self.u8 = nn.Sequential(
            nn.ConvTranspose2d(f * 2, 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return the predicted two-channel map with values in [-1, 1]."""
        dd1 = self.d1(x)
        dd2 = self.d2(dd1)
        dd3 = self.d3(dd2)
        dd4 = self.d4(dd3)
        dd5 = self.d5(dd4)
        dd6 = self.d6(dd5)
        dd7 = self.d7(dd6)
        dd8 = self.d8(dd7)

        uu1 = self.u1(dd8)
        uu2 = self.u2(torch.cat([uu1, dd7], 1))
        uu3 = self.u3(torch.cat([uu2, dd6], 1))
        uu4 = self.u4(torch.cat([uu3, dd5], 1))
        uu5 = self.u5(torch.cat([uu4, dd4], 1))
        uu6 = self.u6(torch.cat([uu5, dd3], 1))
        uu7 = self.u7(torch.cat([uu6, dd2], 1))
        return self.u8(torch.cat([uu7, dd1], 1))

In [None]:
# Smoke-test an untrained generator on the first sample of the batch loaded
# earlier, then display the module structure.
chk_block = generator(1,64)

# Re-add the batch and channel dims that the indexing removed.
x = L_array[0][0].unsqueeze(0).unsqueeze(0)
ab = ab_array[0].unsqueeze(0)

y = chk_block(x)
print('y',ab.shape)

# Both images are drawn on the same axes, so the generated one (drawn last)
# is what stays visible.
plt.imshow(lab_to_rgb(x,ab)[0])
plt.imshow(lab_to_rgb(x,y.detach())[0])
plt.show()
generator(1)

### Discriminator Code

In [None]:
class disc_block(nn.Module):
    """Convolution stage of the discriminator.

    A Conv2d optionally followed by BatchNorm2d and LeakyReLU(0.2). The
    convolution carries a bias only when BatchNorm is disabled.

    Args:
        inputs: number of input channels.
        outputs: number of output channels.
        kernel: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        batchNorm: append BatchNorm2d after the convolution.
        activation: append LeakyReLU(0.2) as the last step.
    """

    def __init__(self, inputs, outputs,  kernel=4, stride=2, padding=1, batchNorm=True, activation=True):
        super().__init__()
        self.batchNorm = batchNorm
        self.activation = activation

        use_bias = not batchNorm
        self.block1 = nn.Conv2d(inputs, outputs, kernel, stride, padding, bias=use_bias)
        if batchNorm:
            self.block2 = nn.BatchNorm2d(outputs)
        if activation:
            self.block3 = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        """Apply the convolution and whichever optional layers are enabled."""
        stages = [self.block1]
        if self.batchNorm:
            stages.append(self.block2)
        if self.activation:
            stages.append(self.block3)

        for stage in stages:
            x = stage(x)
        return x
        

class discriminator(nn.Module):
    """Convolutional critic built from five ``disc_block`` stages.

    Three stride-2 stages are followed by two stride-1 stages. The last
    stage has a single output channel and neither BatchNorm nor activation,
    so the result is a spatial map of raw (unnormalised) scores.

    Args:
        inputs: number of channels of the image fed to the first stage.
    """

    def __init__(self, inputs=3):
        super().__init__()

        self.b1 = disc_block(inputs, 64, batchNorm=False)
        self.b2 = disc_block(64, 128)
        self.b3 = disc_block(128, 256)
        self.b4 = disc_block(256, 512, stride=1)
        self.b5 = disc_block(512, 1, stride=1, batchNorm=False, activation=False)

    def forward(self, x):
        """Pass ``x`` through the five stages in order and return the score map."""
        for stage in (self.b1, self.b2, self.b3, self.b4, self.b5):
            x = stage(x)
        return x

In [None]:
# Smoke test: run a single-channel discriminator on random noise, then
# display the structure of a three-channel one (the cell's last expression).
x=torch.randn(1,1,256,256)
get=discriminator(1)
get(x).shape
discriminator(3)

In [None]:
def visualize(gen, dl, folder, epoch, SAVE = True):
    """Plot greyscale inputs, generated colourisations and ground truth.

    Draws one batch from ``dl``, runs ``gen`` on its L channel without
    gradient tracking and shows up to five samples in three rows:
    input L (grey), generated image, real image.

    Args:
        gen: generator module mapping L to ab.
        dl: loader yielding ``(L, ab)`` batches.
        folder: directory the figure is written to when ``SAVE`` is true.
        epoch: epoch number used in the output file name.
        SAVE: write the figure to ``folder`` before showing it.
    """
    data = next(iter(dl))
    L = data[0].to(device)
    ab = data[1].to(device)

    # Evaluate in eval mode, then restore whatever mode the caller had
    # (previously the model was unconditionally switched to train mode).
    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            ab_gen = gen(L)
    finally:
        gen.train(was_training)

    real_imgs = lab_to_rgb(L, ab)
    fake_imgs = lab_to_rgb(L, ab_gen.detach())

    # Never index past the batch: a short batch used to raise IndexError.
    n_show = min(5, len(real_imgs))
    plt.figure(figsize=(15, 8))
    for i in range(n_show):
        ax = plt.subplot(3, 5, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax = plt.subplot(3, 5, i + 1 + 5)
        ax.imshow(fake_imgs[i])
        ax = plt.subplot(3, 5, i + 1 + 10)
        ax.imshow(real_imgs[i])
    if SAVE:
        plt.savefig(folder + f"/Results_After_Epoch_{epoch}.png")
    # Show regardless of SAVE, matching VisualizeLoss; this call used to be
    # nested under the SAVE branch, so SAVE=False never displayed anything.
    plt.show()

def VisualizeLoss(Loss_Arr, folder, epoch, generator = True, SAVE = True):
    """Plot a loss history and optionally save the figure.

    Args:
        Loss_Arr: sequence of loss values, one per training step.
        folder: directory the figure is written to when ``SAVE`` is true.
        epoch: epoch number used in the output file name.
        generator: True labels the file "Generator", False "Discriminator".
            (The name shadows the ``generator`` class inside this function;
            it is kept because callers may pass it by keyword.)
        SAVE: write the figure to ``folder`` before showing it.
    """
    plt.figure(figsize = (12,10))
    plt.plot(range(len(Loss_Arr)), Loss_Arr)
    if SAVE:
        # Local renamed from `str`, which shadowed the builtin.
        label = "Generator" if generator else "Discriminator"
        plt.savefig(folder + f"/{label}_Loss_After_Epoch_{epoch}.png")
    plt.show()


## Training The Model

In [None]:
# Training hyperparameters and running state.
LearningRate = 2e-4  # Adam learning rate for both networks
EPOCHS = 350  # last epoch number the training loop will run
LAMBDA = 100  # multiplier applied to the L1 term of the generator loss
epoch = 1  # current epoch; overwritten by load_checkpoint when resuming
DISC_LOSS = []  # per-step discriminator losses
GEN_LOSS = []  # per-step generator losses

In [None]:
# Checkpoint files to resume from, and the folder that receives saved figures.
CHECKPOINT_DISC = "../input/model-params/disc.pth.tar"
CHECKPOINT_GEN = "../input/model-params/gen.pth.tar"
folder = "/kaggle/working"
LOAD_MODEL = True  # restore both networks from the checkpoints above
SAVE = True  # write result/loss figures to `folder`

In [None]:
def save_checkpoint(model, optimizer, epoch, filename):
    """Write model/optimizer state plus training progress to ``filename``.

    Besides the two state dicts, the checkpoint records the given ``epoch``
    and the module-level ``DISC_LOSS`` / ``GEN_LOSS`` histories.
    """
    print("=> Saving checkpoint")
    weights = model.state_dict()
    opt_state = optimizer.state_dict()
    torch.save(
        {
            "state_dict": weights,
            "optimizer": opt_state,
            "epoch": epoch,
            "DISC_LOSS": DISC_LOSS,
            "GEN_LOSS": GEN_LOSS,
        },
        filename,
    )

def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore a checkpoint into ``model``/``optimizer`` and global state.

    Also overwrites the module-level ``epoch``, ``DISC_LOSS`` and
    ``GEN_LOSS`` with the stored values, then forces every optimizer
    parameter group to use learning rate ``lr``.
    """
    global epoch, DISC_LOSS, GEN_LOSS

    print("=> Loading checkpoint")
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])

    epoch = state["epoch"]
    # Copies, so later appends do not touch the loaded checkpoint object.
    DISC_LOSS = state["DISC_LOSS"].copy()
    GEN_LOSS = state["GEN_LOSS"].copy()

    # The stored optimizer state carries its own lr; replace it with the
    # requested one.
    for group in optimizer.param_groups:
        group["lr"] = lr

In [None]:
# The discriminator sees L and ab concatenated (1 + 2 = 3 channels);
# the generator takes the single L channel.
disc = discriminator(3).to(device)
gen = generator(1).to(device)
disc_opt = optim.Adam(disc.parameters(),lr=LearningRate, betas=(0.5,0.999))
gen_opt = optim.Adam(gen.parameters(),lr=LearningRate, betas=(0.5,0.999))
# The discriminator outputs raw scores, hence the logits variant of BCE.
loss_fn = nn.BCEWithLogitsLoss()
L1_loss = nn.L1Loss()
# One gradient scaler per optimizer for the mixed-precision training steps.
disc_scaler = torch.cuda.amp.GradScaler()
gen_scaler = torch.cuda.amp.GradScaler()

# Each call also overwrites the global epoch/DISC_LOSS/GEN_LOSS, so the
# values stored in the discriminator checkpoint (loaded last) are the ones kept.
if LOAD_MODEL:
    load_checkpoint(CHECKPOINT_GEN, gen, gen_opt, LearningRate)
    load_checkpoint(CHECKPOINT_DISC, disc, disc_opt, LearningRate)

In [None]:
# From here on the checkpoint names point at the output location, so saving
# does not write to the paths the models were loaded from.
SAVE_MODEL = True  # write a checkpoint for both networks during training
CHECKPOINT_DISC = "/kaggle/working/disc.pth.tar"
CHECKPOINT_GEN = "/kaggle/working/gen.pth.tar"

In [None]:
def train_fn(disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler):
    """Run one pass over ``loader``, updating both networks on every batch.

    Per batch, the discriminator is trained first on a real ``(L, ab)`` pair
    and a generated pair (with the generator output detached). The generator
    is then trained to make the discriminator output "real" while staying
    close to the true ab channels under an L1 penalty weighted by ``LAMBDA``.
    Step losses are appended to the global ``DISC_LOSS`` / ``GEN_LOSS``.

    Args:
        disc, gen: discriminator and generator modules.
        loader: iterable of ``(L, ab)`` batches.
        opt_disc, opt_gen: optimizers for the two networks.
        l1_loss: reconstruction criterion applied to ``(generated, ab)``.
        bce: adversarial criterion applied to discriminator outputs.
        g_scaler, d_scaler: gradient scalers for the generator and
            discriminator steps respectively.
    """
    progress = tqdm(loader, leave=True)
    for L, ab in progress:
        L = L.to(device)
        ab = ab.to(device)

        # ---- Discriminator step ----
        with torch.cuda.amp.autocast():
            y_fake = gen(L)
            real_pair = torch.concat([L, ab], 1)
            D_real = disc(real_pair)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            # Detached so this step does not backpropagate into the generator.
            fake_pair = torch.concat([L, y_fake.detach()], 1)
            D_fake = disc(fake_pair)
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2
        DISC_LOSS.append(D_loss.item())

        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generator step ----
        with torch.cuda.amp.autocast():
            # Re-score the (non-detached) fake with the updated discriminator.
            D_fake = disc(torch.concat([L, y_fake], 1))
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(y_fake, ab) * LAMBDA
            G_loss = G_fake_loss + L1
        GEN_LOSS.append(G_loss.item())

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

In [None]:
# Main training loop. Each iteration previews the current generator,
# checkpoints both networks, plots the loss curves on even epochs and then
# trains for one epoch.
while epoch <= EPOCHS:
    print("\nEpoch",epoch,'\n')
    visualize(gen, val_dl, folder, epoch, SAVE)
    # Saved before training, so a checkpoint tagged N holds the weights
    # after N-1 completed epochs and resuming restarts epoch N.
    if SAVE_MODEL:
        save_checkpoint(gen, gen_opt, epoch, filename=CHECKPOINT_GEN)
        save_checkpoint(disc, disc_opt, epoch, filename=CHECKPOINT_DISC)
    if epoch % 2 == 0:
        print("Generator Loss\n")
        VisualizeLoss(GEN_LOSS, folder, epoch, True, SAVE)
        print("Discriminator Loss\n")
        VisualizeLoss(DISC_LOSS, folder, epoch, False, SAVE)
    # Scalers are passed by keyword: the positional order used before handed
    # the discriminator scaler to g_scaler and vice versa.
    train_fn(
        disc, gen, train_dl, disc_opt, gen_opt, L1_loss, loss_fn,
        g_scaler=gen_scaler, d_scaler=disc_scaler,
    )
    epoch += 1

# Persist the weights produced by the last epoch as well; without this the
# final epoch's training never reached a checkpoint.
if SAVE_MODEL:
    save_checkpoint(gen, gen_opt, epoch, filename=CHECKPOINT_GEN)
    save_checkpoint(disc, disc_opt, epoch, filename=CHECKPOINT_DISC)


In [None]:
# Show the full generator loss history without writing a file.
VisualizeLoss(GEN_LOSS,folder,epoch,True,SAVE=False)

In [None]:
# Show the full discriminator loss history without writing a file.
VisualizeLoss(DISC_LOSS,folder,epoch,False,SAVE=False)

In [None]:
# Preview the generator on a validation batch and save the figure.
visualize(gen, val_dl,folder,epoch,SAVE=True)

In [None]:
# Preview the generator on a training batch (figure is not saved).
visualize(gen, train_dl,folder,epoch,SAVE=False)