In [None]:
import os
import cv2
import math
import torch
import shutil 
import random
import numpy as np 
import pandas as pd 
from PIL import Image
import torch.nn as nn
from torch import cuda
from torch.optim import Adam
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from skimage.color import rgb2lab, lab2rgb, rgb2gray
import torchvision.models as models
from tqdm import tqdm

In [None]:
import os, time, shutil, argparse
import pandas as pd
from functools import partial
import pickle
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models
import numpy as np
import matplotlib.pyplot as plt
from skimage.color import lab2rgb
from skimage import io
import torch.optim as optim

In [None]:
from google.colab import drive
import shutil
def save_model(model):
  """Copy a checkpoint file from the Colab workspace to Google Drive.

  Args:
    model: File name of the checkpoint, relative to ``/content/``
      (for example ``"netD_no_gan_train1.pt"``).

  Side effects:
    Mounts Google Drive at ``/content/drive``, copies the file into the
    ``cv thesis/model`` folder of MyDrive, prints ``"Model Saved"`` and
    finally flushes and unmounts the drive.

  Raises:
    Whatever ``shutil.copy`` / ``os.makedirs`` raise (e.g. the source file
    does not exist). The drive is still flushed and unmounted in that case.
  """
  import os  # local import so the cell does not depend on the import cells above

  source_path = "/content/" + model
  target_dir = "/content/drive/MyDrive/cv thesis/model"

  drive.mount('/content/drive')
  try:
    # Without the directory, shutil.copy would create a plain *file* called
    # "model" and every later checkpoint would silently overwrite it.
    os.makedirs(target_dir, exist_ok=True)
    shutil.copy(source_path, target_dir)
    print("Model Saved")
  finally:
    # Always flush pending writes and release the mount, even if the copy failed.
    drive.flush_and_unmount()

In [None]:
# Install/upgrade gdown, then download the training data from Google Drive.
# Each download is addressed by its Drive file id; images.zip is unpacked
# into the current working directory afterwards.
!pip install --upgrade --no-cache-dir gdown
!gdown --id 1Csq7o2JqQjUkmPM4JPFwPIA74-_ik5Mt #Sample 12k imagenet
!gdown --id 1smp7oD7RftKQCyQrlL7Jf3T0OGg8UURl #Caption file - sampled.csv
!unzip images.zip

In [None]:
# !gdown --id 1omdfiq5ijDIskNjPepir3mxMaNV3twgc #DAVIS dataset
# !unzip DAVIS.zip

In [None]:
# !gdown --id 1YHdTCB4eMCbn4TAPtrdgIJLYF5cC9rt4 #Kota dataset
# !unzip kota.zip

In [None]:
!ls

images	images.zip  sample_data  sampled.csv


In [None]:
# Download model weights from Google Drive (addressed by file id).
# Active downloads: the grayscale ResNet weights and the generator that was
# fine-tuned for 10 epochs. The commented-out lines are alternative checkpoints.
!gdown --id 1-zdr4QV1_-OQxQ-URN2zb1JNojHm3pyO #resnet gray weight
# # !gdown --id 1-jFyCtViLuwmRlfepHKKIdgFjvDvPlTS #epoch trained weight
# !gdown --id 1HM1fSJ4JhRGcBoODbmSbFpWkUXEkLg9f #Finetuned model epoch 2
# !gdown --id 1-7GOmJvjcfx4HSUV5NLUVRcJpvlLPjI9 #Finetuned model epoch 5
!gdown --id 1-k5sNFrq7K1ZSfbxIHGDTtSLtgCxY310 #Finetuned generator model epoch 10

In [None]:
class ImageDataset(Dataset):
    """Dataset yielding a grayscale input and scaled ab colour targets per image.

    Args:
        root: Directory that contains the image files.
        captions_file: CSV file with an ``image`` column holding file names
            relative to ``root``.
        color_transform: Optional callable applied to the PIL image before it
            is converted to a NumPy array (e.g. resize / crop).
        transform: Optional callable applied separately to the grayscale array
            and to the ab array (e.g. ``ToTensor``).
    """

    def __init__(self, root, captions_file, color_transform = None, transform = None):
        self.root = root
        self.color_transform = color_transform
        self.transform = transform
        self.df = pd.read_csv(captions_file, index_col=None)
        self.images = self.df["image"]

    def __len__(self):
        """Number of rows in the captions CSV."""
        return len(self.df)

    def __getitem__(self, index):
        """Load image ``index`` and return ``(L, ab)``.

        ``L`` is the ``rgb2gray`` conversion of the RGB image (not the Lab
        lightness channel). ``ab`` holds the last two Lab channels after the
        ``(value + 128) / 255`` scaling.
        """
        image_path = os.path.join(self.root, self.images[index])
        rgb = Image.open(image_path).convert("RGB")
        if self.color_transform:
            rgb = self.color_transform(rgb)
        rgb = np.array(rgb)

        # Scale the Lab image, then keep only the two chrominance channels.
        lab = rgb2lab(rgb).astype("float32")
        ab = ((lab + 128) / 255)[:, :, 1:]

        L = rgb2gray(rgb)

        if self.transform:
            L = self.transform(L)
            ab = self.transform(ab)

        return L, ab

In [None]:
class ColorizationNet(nn.Module):
    """Convolutional decoder turning mid-level features into a 2-channel map.

    Args:
        midlevel_input_size: Channel count of the mid-level feature map.
        global_input_size: Size of the global feature vector. It only sizes the
            ``fusion`` layer, which ``forward`` does not use.

    Note:
        ``fusion``, ``bn1``, ``deconv1_new``, ``bn4`` and ``bn5`` are registered
        (so they appear in the state dict) but are never called in ``forward``.
    """

    def __init__(self, midlevel_input_size=128, global_input_size=512):
        super().__init__()
        self.midlevel_input_size = midlevel_input_size
        self.global_input_size = global_input_size

        # Fusion of mid-level and global features (registered only).
        fused_size = midlevel_input_size + global_input_size
        self.fusion = nn.Linear(fused_size, midlevel_input_size)
        self.bn1 = nn.BatchNorm1d(midlevel_input_size)

        # Decoder: 3x3 convolutions that preserve the spatial size.
        self.deconv1_new = nn.ConvTranspose2d(midlevel_input_size, 128, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Conv2d(midlevel_input_size, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.conv5 = nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1)
        # Shared x2 upsampling, applied three times in forward() (x8 overall).
        self.upsample = nn.Upsample(scale_factor=2)

        print('Loaded colorization net.')

    def forward(self, midlevel_input, global_input):
        """Decode ``midlevel_input``; ``global_input`` is accepted but ignored.

        Returns a tensor with 2 channels whose height and width are 8x those
        of ``midlevel_input``.
        """
        # Stage 1: conv -> batch norm -> ReLU, then upsample.
        features = self.conv1(midlevel_input)
        features = F.relu(self.bn2(features))
        features = self.upsample(features)

        # Stage 2: two convolutions (only the first is normalised), then upsample.
        features = self.conv2(features)
        features = F.relu(self.bn3(features))
        features = F.relu(self.conv3(features))
        features = self.upsample(features)

        # Stage 3: sigmoid-activated conv, 2-channel projection, final upsample.
        features = F.sigmoid(self.conv4(features))
        two_channel = self.conv5(features)
        return self.upsample(two_channel)

In [None]:
class ColorNet(nn.Module):
    """Generator: grayscale ResNet-18 feature extractor plus ``ColorizationNet``.

    Args:
        resnet_gray_weights_path: Path to a state dict for the single-channel
            ResNet-18 (365 classes). The weights are loaded whenever the file
            exists, on GPU and CPU runtimes alike; otherwise the backbone keeps
            its random initialisation and a notice is printed. Loading is only
            needed when no full ``ColorNet`` checkpoint is restored afterwards.
    """

    def __init__(self, resnet_gray_weights_path='/content/resnet_gray_weights.pth.tar'):
        super(ColorNet, self).__init__()

        # Build ResNet and collapse the first conv layer's RGB kernels into a
        # single input channel by summing over the channel dimension.
        resnet_gray_model = models.resnet18(num_classes=365)
        resnet_gray_model.conv1.weight = nn.Parameter(resnet_gray_model.conv1.weight.sum(dim=1).unsqueeze(1).data)

        # Load the pretrained ResNet-gray weights when the file is present.
        # map_location keeps this working on CPU-only runtimes; the model is
        # moved to its final device by the caller.
        if resnet_gray_weights_path and os.path.exists(resnet_gray_weights_path):
            resnet_gray_weights = torch.load(resnet_gray_weights_path, map_location='cpu')
            resnet_gray_model.load_state_dict(resnet_gray_weights)
            print('Pretrained ResNet-gray weights loaded')
        else:
            print('Pretrained ResNet-gray weights not found; using randomly initialised backbone')

        # Mid-level features come from the first six child modules of the
        # ResNet, global features from the first nine. Both sequences share
        # the same underlying modules.
        backbone_children = list(resnet_gray_model.children())
        self.midlevel_resnet = nn.Sequential(*backbone_children[0:6])
        self.global_resnet = nn.Sequential(*backbone_children[0:9])
        self.fusion_and_colorization_net = ColorizationNet()

    def forward(self, input_image):
        """Predict the 2-channel output for a single-channel ``input_image`` batch."""
        # Pass input through ResNet-gray to extract features
        midlevel_output = self.midlevel_resnet(input_image)
        # NOTE: ColorizationNet.forward currently ignores the global features;
        # they are still computed so its two-argument interface stays satisfied.
        global_output = self.global_resnet(input_image)

        # Decode the features into the colour channels
        output = self.fusion_and_colorization_net(midlevel_output, global_output)
        return output


In [None]:
class CNNBlock(nn.Module):
    """4x4 convolution (reflect padding, no bias) -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=False,
            padding_mode="reflect",
        )
        normalisation = nn.InstanceNorm2d(out_channels, affine = True)
        activation = nn.LeakyReLU(0.2)
        # Attribute name and element order define the state-dict keys.
        self.conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Apply the convolution, normalisation and activation in sequence."""
        return self.conv(x)


class Discriminator(nn.Module):
    """Convolutional discriminator that scores a pair of images patch-wise.

    Args:
        in_channels: Total channel count of ``x`` and ``y`` after they are
            concatenated in ``forward``.
        features: Output widths of the successive convolution stages. Any
            sequence works; the default is an immutable tuple.

    The first stage is a plain strided convolution with LeakyReLU. Every later
    stage is a ``CNNBlock`` with stride 2, except the last one, which uses
    stride 1. A final convolution maps to a single-channel map of logits.
    """

    def __init__(self, in_channels=4, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Select the stride-1 stage by position, not by value, so repeated
            # widths (e.g. [64, 128, 128]) still get exactly one stride-1 block.
            stride = 1 if index == last_index else 2
            layers.append(CNNBlock(in_channels, feature, stride=stride))
            in_channels = feature

        layers.append(
            nn.Conv2d(
                in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            ),
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` along the channel axis and return logits."""
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        x = self.model(x)
        return x

In [None]:
# PIL-level preprocessing: resize the shorter side to 300 px with bicubic
# interpolation, then take a 256x256 centre crop.
# InterpolationMode replaces the PIL integer constant, which torchvision
# flags as deprecated ("Argument interpolation should be of type InterpolationMode").
color_transform = transforms.Compose([
            transforms.Resize(300, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(256),

        ])
# Array-level preprocessing applied to both the grayscale input and the ab target.
transform = transforms.Compose([
            transforms.ToTensor()
        ])

dataset = ImageDataset(r"/content/images", r"/content/sampled.csv", color_transform, transform)
# Shuffled batches of 64, loaded in the main process; the last partial batch is kept.
loader = DataLoader(dataset=dataset, batch_size = 64, num_workers = 0, shuffle = True, pin_memory = True, drop_last = False)

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [None]:
# Sanity check: number of batches per epoch and number of images in the dataset.
batch_count = len(loader)
image_count = len(loader.dataset)
print(batch_count, image_count)

188 12020


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")
# To force CPU execution, set: device = "cpu"

Using cuda device


In [None]:
# Build the discriminator with its default configuration (4 input channels).
net_D = Discriminator()
# Move its parameters to the selected device. As the last expression of the
# cell, the returned module is also what the notebook displays.
net_D.to(device)

In [None]:
# Build the generator; it is used frozen (eval mode) while the discriminator trains.
net_G = ColorNet()
# Use GPU if available
if device == "cuda":  
  net_G.to(device)    
  print('Loaded model onto GPU.') 

# Restore the fine-tuned generator if its checkpoint has been downloaded.
# map_location lets a checkpoint saved from GPU tensors load on a CPU-only runtime.
generator_checkpoint_path = "/content/netG_pretrained_fine_tune_10.pt"
if os.path.exists(generator_checkpoint_path):
  checkpoint = torch.load(generator_checkpoint_path, map_location=device)
  net_G.load_state_dict(checkpoint)
  print("Pretrained Model loaded")

Pretrained ResNet-gray weights loaded
Loaded colorization net.
Loaded model onto GPU.
Pretrained Model loaded


In [None]:
# Discriminator objective: binary cross-entropy computed on raw logits.
BCE = nn.BCEWithLogitsLoss()

# Adam over the discriminator's parameters only; the generator is not optimised here.
opt_disc = optim.Adam(
    params=net_D.parameters(),
    lr=3e-4,
    betas=(0.5, 0.999),
)

# Gradient scaler paired with the autocast blocks in the training cells.
d_scaler = torch.cuda.amp.GradScaler()

In [None]:
if __name__ == '__main__':
    # Train the discriminator for epochs 1-10 against a frozen generator.
    print('Starting discriminator training.')
    net_G.eval()   # generator is only used for inference
    net_D.train()
    for e in range(1, 11):
        epoch_loss = 0.0  # sum of per-batch losses, averaged after the epoch
        for (gray, ab) in tqdm(loader):

            # Use GPU if available
            L = gray.to(device = device, dtype = torch.float32)
            ab = ab.to(device = device, dtype = torch.float32)

            # The generator is frozen: skip building its autograd graph.
            with torch.no_grad():
                out_ab = net_G(L)

            fake = torch.cat([L, out_ab], dim = 1)
            color = torch.cat([L, ab], dim = 1)

            # Train Discriminator. The forward passes belong inside autocast;
            # otherwise mixed precision never applies to the network itself.
            with torch.cuda.amp.autocast():
                D_real = net_D(L, color)
                D_fake = net_D(L, fake)
                D_real_loss = BCE(D_real, torch.ones_like(D_real))
                D_fake_loss = BCE(D_fake, torch.zeros_like(D_fake))
                D_loss = (D_real_loss + D_fake_loss) / 2

            opt_disc.zero_grad()
            d_scaler.scale(D_loss).backward()
            d_scaler.step(opt_disc)
            d_scaler.update()

            epoch_loss += D_loss.item()

        # Report the mean loss over all batches of the epoch (previously only
        # the last batch's loss divided by the batch count was printed).
        print(f"Epoch {e}")
        print(f"Loss: {epoch_loss/len(loader):.5f}")
        path = "netD_no_gan_train" + str(e) + ".pt"
        print("Path :", path)
        torch.save(net_D.state_dict(), "/content/" + path)
        save_model(path)

Starting validation.


100%|██████████| 188/188 [08:30<00:00,  2.72s/it]


Epoch 1
Loss: 0.00338
Path : netD_no_gan_train1.pt
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Model Saved


100%|██████████| 188/188 [08:33<00:00,  2.73s/it]


Epoch 2
Loss: 0.00254
Path : netD_no_gan_train2.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:34<00:00,  2.74s/it]


Epoch 3
Loss: 0.00190
Path : netD_no_gan_train3.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:36<00:00,  2.75s/it]


Epoch 4
Loss: 0.00117
Path : netD_no_gan_train4.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:34<00:00,  2.74s/it]


Epoch 5
Loss: 0.00154
Path : netD_no_gan_train5.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:34<00:00,  2.74s/it]


Epoch 6
Loss: 0.00042
Path : netD_no_gan_train6.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:31<00:00,  2.72s/it]


Epoch 7
Loss: 0.00055
Path : netD_no_gan_train7.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:24<00:00,  2.68s/it]


Epoch 8
Loss: 0.00027
Path : netD_no_gan_train8.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:13<00:00,  2.62s/it]


Epoch 9
Loss: 0.00040
Path : netD_no_gan_train9.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:15<00:00,  2.64s/it]


Epoch 10
Loss: 0.00021
Path : netD_no_gan_train10.pt
Mounted at /content/drive
Model Saved


In [None]:
if __name__ == '__main__':
    # Continue discriminator training for epochs 11-20 against the frozen generator.
    print('Starting discriminator training.')
    net_G.eval()   # generator is only used for inference
    net_D.train()
    for e in range(11, 21):
        epoch_loss = 0.0  # sum of per-batch losses, averaged after the epoch
        for (gray, ab) in tqdm(loader):

            # Use GPU if available
            L = gray.to(device = device, dtype = torch.float32)
            ab = ab.to(device = device, dtype = torch.float32)

            # The generator is frozen: skip building its autograd graph.
            with torch.no_grad():
                out_ab = net_G(L)

            fake = torch.cat([L, out_ab], dim = 1)
            color = torch.cat([L, ab], dim = 1)

            # Train Discriminator. The forward passes belong inside autocast;
            # otherwise mixed precision never applies to the network itself.
            with torch.cuda.amp.autocast():
                D_real = net_D(L, color)
                D_fake = net_D(L, fake)
                D_real_loss = BCE(D_real, torch.ones_like(D_real))
                D_fake_loss = BCE(D_fake, torch.zeros_like(D_fake))
                D_loss = (D_real_loss + D_fake_loss) / 2

            opt_disc.zero_grad()
            d_scaler.scale(D_loss).backward()
            d_scaler.step(opt_disc)
            d_scaler.update()

            epoch_loss += D_loss.item()

        # Report the mean loss over all batches of the epoch (previously only
        # the last batch's loss divided by the batch count was printed).
        print(f"Epoch {e}")
        print(f"Loss: {epoch_loss/len(loader):.5f}")
        path = "netD_no_gan_train" + str(e) + ".pt"
        print("Path :", path)
        torch.save(net_D.state_dict(), "/content/" + path)
        save_model(path)

Starting validation.


100%|██████████| 188/188 [08:34<00:00,  2.74s/it]


Epoch 11
Loss: 0.00022
Path : netD_no_gan_train11.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:33<00:00,  2.73s/it]


Epoch 12
Loss: 0.00020
Path : netD_no_gan_train12.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:33<00:00,  2.73s/it]


Epoch 13
Loss: 0.00064
Path : netD_no_gan_train13.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:32<00:00,  2.72s/it]


Epoch 14
Loss: 0.00011
Path : netD_no_gan_train14.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:33<00:00,  2.73s/it]


Epoch 15
Loss: 0.00022
Path : netD_no_gan_train15.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:37<00:00,  2.75s/it]


Epoch 16
Loss: 0.00020
Path : netD_no_gan_train16.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:36<00:00,  2.75s/it]


Epoch 17
Loss: 0.00016
Path : netD_no_gan_train17.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:38<00:00,  2.76s/it]


Epoch 18
Loss: 0.00007
Path : netD_no_gan_train18.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:37<00:00,  2.75s/it]


Epoch 19
Loss: 0.00020
Path : netD_no_gan_train19.pt
Mounted at /content/drive
Model Saved


100%|██████████| 188/188 [08:36<00:00,  2.75s/it]


Epoch 20
Loss: 0.00006
Path : netD_no_gan_train20.pt
Mounted at /content/drive
Model Saved


In [None]:
if __name__ == '__main__':
    # Colourise batches with the generator and display each result as RGB.
    print('Starting validation.')

    # `val_loader` is not created anywhere in this notebook. Use it when the
    # session defines one, otherwise fall back to `loader` so the cell still
    # runs after Restart & Run All.
    eval_loader = globals().get("val_loader", loader)

    # Switch the generator (`net_G`; the old name `model` was never defined)
    # to evaluation mode.
    net_G.eval()

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for (gray, _) in tqdm(eval_loader):

            # Use GPU if available
            L = gray.to(device = device, dtype = torch.float32)
            output = net_G(L)

            # Stack input and predicted channels, then move to NumPy as (N, 3, H, W).
            fake = torch.cat([L, output], dim = 1).cpu().numpy()
            for i in range(fake.shape[0]):
                color_image = fake[i]
                color_image = color_image.transpose((1, 2, 0))  # -> (H, W, 3)
                # Map the [0, 1] grayscale input to the 0-100 lightness range and
                # invert the dataset's (value + 128) / 255 scaling on the ab channels.
                color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
                color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128
                color_image = lab2rgb(color_image.astype(np.float64))
                plt.axis(False)
                plt.imshow(color_image)
                plt.show()