In [0]:
# wget -N images.cocodataset.org/zips/train2017.zip
# wget -N images.cocodataset.org/zips/val2017.zip
# wget -N images.cocodataset.org/zips/test2017.zip
# pip3 install tensorboard
# tensorboard --logdir=runs

In [0]:
import os
import time
import numpy as np 
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
import cv2
import matplotlib.pyplot as plt
from torchvision.utils import save_image

In [0]:
class Configuration:
    """Run-level switches and settings, read as class attributes by later cells."""

    # Compute device: the GPU whenever PyTorch can see one, otherwise the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Checkpoint handling. When either flag is set the checkpoint cell loads a saved
    # model/optimizer; `load_model_to_test` additionally skips the training cells.
    model_file_name = 'checkpoint.pt'
    load_model_to_train = True
    load_model_to_test = True

    # Print running training statistics every `point_batches` batches.
    point_batches = 500

In [0]:
class HyperParameters:
    """Training hyper-parameters, read as class attributes by the training cells."""

    # Optimisation schedule
    epochs = 30
    learning_rate = 0.001
    # NOTE(review): intended learning-rate decay factor; confirm the scheduler uses it.
    learning_rate_decay = 0.5

    # Data loading
    batch_size = 32
    num_workers = 16

In [0]:
# Instantiate the settings holders used by every later cell and report the device.
config, hparams = Configuration(), HyperParameters()
print(f'Device: {config.device}')

Device: cuda


In [0]:
class CustomDataset(Dataset):
    """Image-folder dataset yielding Lab-space tensors for the colorization model.

    Every entry of ``root_dir`` is treated as an image file. ``__getitem__`` returns
    the 5-tuple ``(l_encoder, ab_encoder, l_inception, rgb_encoder, file_name)``:

    * ``l_encoder``   -- (3, 224, 224) L channel scaled to [-1, 1], repeated 3 times
    * ``ab_encoder``  -- (2, 224, 224) a and b channels divided by 128
    * ``l_inception`` -- (3, 300, 300) L channel scaled to [-1, 1], repeated 3 times
    * ``rgb_encoder`` -- (3, 224, 224) colour image in [0, 1] (OpenCV BGR channel order)
    * ``file_name``   -- the file's name inside ``root_dir``

    If reading or converting an image fails, four ``torch.tensor(-1)`` placeholders
    and the string ``'Error'`` are returned instead.

    The ``show_*`` helpers open OpenCV windows for visual inspection of one sample.
    """

    def __init__(self, root_dir, process_type):
        self.root_dir = root_dir
        # os.listdir order is arbitrary; sorting keeps indices (and therefore the
        # order of saved test outputs) reproducible across runs and machines.
        self.files = sorted(os.listdir(root_dir))
        self.process_type = process_type
        # Guard the preview so an empty directory does not raise IndexError here.
        first_file = self.files[0] if self.files else None
        print('File[0]:', first_file, '| Total Files:', len(self.files), '| Process:', self.process_type)

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _l_channel_tensor(lab_img):
        """Return the L plane of an OpenCV float Lab image, scaled to [-1, 1], as a (3, H, W) tensor."""
        # For float32 input OpenCV produces L in [0, 100], so L/50 - 1 maps it to [-1, 1].
        l_img = lab_img[:, :, 0] / 50.0 - 1.0
        # ToTensor turns (H, W) into (1, H, W); expand repeats it (as a view) to 3 channels.
        return torchvision.transforms.ToTensor()(l_img).expand(3, -1, -1)

    def __getitem__(self, index):
        try:
            #*** Read the image (OpenCV loads BGR) and rescale to float32 in [0, 1] ***
            bgr_img = cv2.imread(os.path.join(self.root_dir, self.files[index]))
            if bgr_img is None:
                # cv2.imread signals unreadable/undecodable files by returning None.
                raise IOError('cv2.imread could not read the file')
            # Kept on the instance so the show_* helpers can display it.
            self.rgb_img = bgr_img.astype(np.float32) / 255.0

            #*** Resize for the encoder (224x224) and for Inception (300x300) ***
            rgb_encoder_img = cv2.resize(self.rgb_img, (224, 224))
            rgb_inception_img = cv2.resize(self.rgb_img, (300, 300))

            ''' Encoder Images '''
            # For float32 input OpenCV yields L in [0, 100] and a/b in about [-127, 127].
            self.lab_encoder_img = cv2.cvtColor(rgb_encoder_img, cv2.COLOR_BGR2Lab)
            l_encoder_img = self._l_channel_tensor(self.lab_encoder_img)
            # a and b divided by 128 and stacked into a (2, H, W) regression target.
            ab_encoder_img = torch.stack([
                torch.Tensor(self.lab_encoder_img[:, :, 1] / 128.0),
                torch.Tensor(self.lab_encoder_img[:, :, 2] / 128.0),
            ])

            ''' Inception Images '''
            self.lab_inception_img = cv2.cvtColor(rgb_inception_img, cv2.COLOR_BGR2Lab)
            l_inception_img = self._l_channel_tensor(self.lab_inception_img)

            ''' return images to data-loader '''
            # HWC float32 -> CHW tensor (ToTensor does not rescale float arrays).
            rgb_encoder_img = torchvision.transforms.ToTensor()(rgb_encoder_img)
            return l_encoder_img, ab_encoder_img, l_inception_img, rgb_encoder_img, self.files[index]

        except Exception as e:
            # NOTE(review): the default collate_fn cannot stack these scalar placeholders
            # together with real (C, H, W) samples, so a batch that mixes good and bad
            # images will fail to collate; a custom collate_fn would be needed for that.
            print('Exception at ', self.files[index], e)
            return torch.tensor(-1), torch.tensor(-1), torch.tensor(-1), torch.tensor(-1), 'Error'

    def _show(self, window_name, image):
        """Show ``image`` in an OpenCV window named ``window_name`` until a key is pressed."""
        # cv2.imshow requires the window name as its first argument, and transposed /
        # expanded tensor views are not contiguous, so hand OpenCV a compact copy.
        # NOTE(review): cv2.imshow is disabled in Google Colab; use
        # google.colab.patches.cv2_imshow there instead.
        cv2.imshow(window_name, np.ascontiguousarray(image))
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    def show_rgb(self, index):
        """Display the full-size colour image of sample ``index``."""
        self.__getitem__(index)
        print("RGB image size:", self.rgb_img.shape)
        self._show('RGB image', self.rgb_img)

    def show_lab_encoder(self, index):
        """Display the 224x224 Lab image fed (as L / ab) to the encoder."""
        self.__getitem__(index)
        print("Encoder Lab image size:", self.lab_encoder_img.shape)
        self._show('Encoder Lab image', self.lab_encoder_img)

    def show_lab_inception(self, index):
        """Display the 300x300 Lab image whose L channel feeds Inception."""
        self.__getitem__(index)
        print("Inception Lab image size:", self.lab_inception_img.shape)
        self._show('Inception Lab image', self.lab_inception_img)

    def show_other_images(self, index):
        """Display every tensor returned by ``__getitem__`` for sample ``index``."""
        a, b, c, d, _ = self.__getitem__(index)
        print("Encoder l channel image size:", a.shape)
        self._show('Encoder L channel', a.detach().numpy().transpose(1, 2, 0))
        print("Encoder ab channel image size:", b.shape)
        ab = b.detach().numpy().transpose(1, 2, 0)
        self._show('Encoder a channel', ab[:, :, 0])
        self._show('Encoder b channel', ab[:, :, 1])
        print("Inception l channel image size:", c.shape)
        self._show('Inception L channel', c.detach().numpy().transpose(1, 2, 0))
        print("Color resized image size:", d.shape)
        self._show('Color resized image', d.detach().numpy().transpose(1, 2, 0))

In [0]:
# Training images (COCO train2017, unpacked under /content).
train_dataset = CustomDataset(root_dir='/content/train2017', process_type='train')

In [0]:
# Visual sanity check of the first training sample at each preprocessing stage.
for show_stage in (train_dataset.show_rgb,
                   train_dataset.show_lab_encoder,
                   train_dataset.show_lab_inception,
                   train_dataset.show_other_images):
    show_stage(0)

In [0]:
class my_model(nn.Module):
    """Encoder / fusion / decoder network predicting a/b chroma from a grayscale input.

    ``forward(x, emb)``:
      * ``x``   -- (B, 3, H, W) input (the notebook feeds the L channel repeated 3x);
        H and W are assumed divisible by 8 so the skip connections line up -- TODO confirm.
      * ``emb`` -- (B, embedding_dim) global image embedding, tiled over the
        bottleneck grid (the notebook passes Inception-v3 outputs).
    Returns a (B, 2, H, W) float tensor in [-1, 1] (tanh, then nearest upsampling).

    Args:
        depth_after_fusion: channel width after the fusion 1x1 convs. Only 256 is
            supported: the decoder re-applies ``conv9`` (in_channels=depth_after_fusion)
            to a 128 + 128 channel concatenation.
        embedding_dim: length of ``emb``. The fusion conv sees 256 + embedding_dim
            channels; the default 1000 reproduces the original hard-coded 1256.

    Raises:
        ValueError: if ``depth_after_fusion`` is not 256.
    """

    def __init__(self, depth_after_fusion, embedding_dim=1000):
        super(my_model, self).__init__()
        if depth_after_fusion != 256:
            # Fail fast with a clear message instead of a channel-mismatch error in forward().
            raise ValueError('my_model only supports depth_after_fusion=256, got %r' % (depth_after_fusion,))

        # Encoder Layer Network: three stride-2 convs, 3 -> 256 channels.
        # (Registration order is unchanged so saved state dicts / optimizer states still match.)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.r1 = nn.ReLU(inplace=True)
        self.b1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.r2 = nn.ReLU(inplace=True)
        self.b2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.r3 = nn.ReLU(inplace=True)
        self.b3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.r4 = nn.ReLU(inplace=True)
        self.b4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.r5 = nn.ReLU(inplace=True)
        self.b5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.r6 = nn.ReLU(inplace=True)
        self.b6 = nn.BatchNorm2d(512)
        self.conv7 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.r7 = nn.ReLU(inplace=True)
        self.b7 = nn.BatchNorm2d(512)
        self.conv8 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.r8 = nn.ReLU(inplace=True)
        self.b8 = nn.BatchNorm2d(256)

        # Fusion Layer: 256 encoder channels + embedding_dim tiled embedding channels.
        self.after_fusion = nn.Conv2d(in_channels=256 + embedding_dim, out_channels=depth_after_fusion, kernel_size=1, stride=1, padding=0)
        # Input is the fused features concatenated with the 256-channel encoder skip.
        self.after_fusion_res = nn.Conv2d(in_channels=depth_after_fusion*2, out_channels=depth_after_fusion, kernel_size=1, stride=1, padding=0)
        self.bf = nn.BatchNorm2d(depth_after_fusion)
        self.rf = nn.ReLU(inplace=True)

        # Decoder Layer Network
        self.conv9 = nn.Conv2d(in_channels=depth_after_fusion, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.r9 = nn.ReLU(inplace=True)
        self.b9 = nn.BatchNorm2d(128)
        self.u9 = nn.Upsample(scale_factor=2.0)
        self.conv10 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.r10 = nn.ReLU(inplace=True)
        self.b10 = nn.BatchNorm2d(64)
        self.conv11 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.r11 = nn.ReLU(inplace=True)
        self.b11 = nn.BatchNorm2d(64)
        self.u11 = nn.Upsample(scale_factor=2.0)
        self.conv12 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.r12 = nn.ReLU(inplace=True)
        self.b12 = nn.BatchNorm2d(32)
        self.conv13 = nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1)
        self.t13 = nn.Tanh()
        self.u13 = nn.Upsample(scale_factor=2.0)

    def forward(self, x, emb):
        """Predict (B, 2, H, W) a/b channels from ``x`` and the global embedding ``emb``."""
        # NOTE: the ReLUs are in-place, so each `skip_y*` captured below is overwritten
        # by the activation; the skip connections therefore carry post-ReLU features.

        # Encoder forward
        y = self.conv1(x)
        skip_y1 = self.b1(y)
        y = self.r1(skip_y1)

        y = self.conv2(y)
        y = self.b2(y)
        y = self.r2(y)

        y = self.conv3(y)
        skip_y2 = self.b3(y)
        y = self.r3(skip_y2)

        y = self.conv4(y)
        y = self.b4(y)
        y = self.r4(y)

        y = self.conv5(y)
        y = self.b5(y)
        y = self.r5(y)

        y = self.conv6(y)
        y = self.b6(y)
        y = self.r6(y)

        y = self.conv7(y)
        y = self.b7(y)
        y = self.r7(y)

        y = self.conv8(y)
        skip_y3 = self.b8(y)
        y = self.r8(skip_y3)

        # Fusion layer: (B, E) -> (B, E, h, w) tiled over the bottleneck grid.
        emb = emb[:, :, None, None].expand(-1, -1, y.shape[2], y.shape[3])
        fusion = torch.cat((y, emb), 1)
        y = self.after_fusion(fusion)
        y = torch.cat((y, skip_y3), 1)  # Skip connection

        y = self.after_fusion_res(y)
        y = self.bf(y)
        y = self.rf(y)

        # Decoder forward
        # NOTE(review): conv9 and conv10 are each applied twice (shared weights) while
        # b9/b10 are used once; kept as-is so existing checkpoints stay compatible.
        y = self.u9(y)
        y = self.conv9(y)
        y = torch.cat((y, skip_y2), 1)  # Skip connection

        y = self.conv9(y)
        y = self.b9(y)
        y = self.r9(y)

        y = self.u11(y)
        y = self.conv10(y)
        y = torch.cat((y, skip_y1), 1)  # Skip connection

        y = self.conv10(y)
        y = self.b10(y)
        y = self.r10(y)

        y = self.conv11(y)
        y = self.b11(y)
        y = self.r11(y)

        y = self.conv12(y)
        y = self.b12(y)
        y = self.r12(y)

        y = self.conv13(y)
        y = self.t13(y)
        y = self.u13(y)

        return y.float()

In [0]:
if config.load_model_to_train or config.load_model_to_test:
    # Resume from a checkpoint that stores the pickled model/optimizer objects,
    # their state dicts, and the last training loss.
    checkpoint_path = "/content/drive/My Drive/IDL Project/Models/Coco_checkpoint10.pt"
    map_location = torch.device(config.device)
    # Whole pickled objects need full unpickling: torch >= 2.6 defaults to
    # weights_only=True (which rejects them), and torch < 1.13 does not accept the
    # argument at all (TypeError), so fall back to the plain call there.
    # SECURITY: full unpickling can execute arbitrary code -- only load trusted files.
    try:
        checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
    except TypeError:
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(config.device)
    optimizer = checkpoint['optimizer']
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print('Loaded pretrain model | Previous train loss:', checkpoint['train_loss'])
else:
    # Fresh model with 256 channels after the fusion layer; Adam with light weight decay.
    model = my_model(256).to(config.device)
    # model.apply(init_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=hparams.learning_rate, weight_decay=1e-6)

In [0]:
# Show the layer-by-layer architecture (nn.Module's repr).
print(repr(model))

In [0]:
# Pretrained Inception-v3 used only as a frozen feature extractor for the fusion layer.
inception_model = models.inception_v3(pretrained=True).float().to(config.device)
inception_model.eval()
# It is never optimised, so stop autograd from tracking or accumulating grads for its weights.
inception_model.requires_grad_(False)
loss_criterion = torch.nn.MSELoss(reduction='mean').to(config.device)
milestone_list = list(range(0, hparams.epochs, 2))  # NOTE(review): not referenced by any cell
# Use the configured decay factor; ReduceLROnPlateau's own default factor is 0.1.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=hparams.learning_rate_decay, patience=2, verbose=True)
writer = SummaryWriter()

In [0]:
if not config.load_model_to_test:
    train_dataset = CustomDataset('/content/train2017', 'train')
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=hparams.batch_size, shuffle=True, num_workers=hparams.num_workers)

    # validataion_dataset = CustomDataset('data/validation','validation')
    # validation_dataloader = torch.utils.data.DataLoader(validataion_dataset, batch_size=hparams.batch_size, shuffle=False, num_workers=hparams.num_workers)

    # Report the real image count: len(dataloader) * batch_size over-counts when
    # the final batch is partial.
    print('Train:', len(train_dataloader), '| Total Images:', len(train_dataset))
    # print('Valid:',len(validation_dataloader), '| Total Images:',len(validation_dataloader)*hparams.batch_size)

In [0]:
if not config.load_model_to_test:
    checkpoint_dir = "/content/drive/My Drive/IDL Project/Models/"
    for epoch in range(0, hparams.epochs):
        print('Starting epoch:', epoch+1)

        #*** Training step ***
        loop_start = time.time()
        main_start = time.time()
        avg_loss = 0.0          # sum of per-batch mean losses over the epoch
        batch_loss = 0.0        # sum of per-batch losses since the last progress print
        window_batches = 0      # number of batches summed into batch_loss
        processed_batches = 0   # batches actually trained on this epoch (skips excluded)
        model.train()

        for idx, (img_l_encoder, img_ab_encoder, img_l_inception, img_rgb, file_name) in enumerate(train_dataloader):
            #*** Skip bad data ***
            # A good batch is (B, 3, H, W); a batch made only of CustomDataset's scalar
            # error placeholders collates to a 1-D tensor (so `not ndim` never fired).
            if img_l_encoder.ndim != 4:
                continue

            #*** Move data to GPU if available ***
            img_l_encoder = img_l_encoder.to(config.device)
            img_ab_encoder = img_ab_encoder.to(config.device)
            img_l_inception = img_l_inception.to(config.device)

            #*** Initialize Optimizer ***
            optimizer.zero_grad()

            #*** Forward Propagation ***
            # Inception is a frozen feature extractor: no autograd graph is needed through it.
            with torch.no_grad():
                img_embs = inception_model(img_l_inception.float())
            output_ab = model(img_l_encoder, img_embs)

            #*** Back propagation ***
            loss = loss_criterion(output_ab, img_ab_encoder.float())
            loss.backward()

            #*** Weight Update ****
            optimizer.step()

            #*** Loss bookkeeping ***
            loss_value = loss.item()
            avg_loss += loss_value
            batch_loss += loss_value
            processed_batches += 1
            window_batches += 1

            #*** Print stats after every point_batches iterations ***
            # (idx + 1) so the first report covers a full window rather than batch 0 alone.
            if (idx + 1) % config.point_batches == 0:
                loop_end = time.time()
                print('Batch:', idx, '| Processing time for', config.point_batches, ':', loop_end-loop_start, 's | Batch Loss:', batch_loss/window_batches)
                loop_start = time.time()
                batch_loss = 0.0
                window_batches = 0

            torch.cuda.empty_cache()

        #*** Print Training Data Stats ***
        # MSELoss(reduction='mean') already averages inside each batch, so the epoch loss
        # is the mean over processed batches (the old `avg_loss/len(...)*batch_size`
        # multiplied by batch_size due to operator precedence).
        train_loss = avg_loss / max(processed_batches, 1)
        writer.add_scalar('Loss/train', train_loss, epoch)
        print('Training Loss:', train_loss, '| Processed in ', time.time()-main_start, 's')

        #*** Reduce Learning Rate ***
        scheduler.step(train_loss)

        # TODO: re-enable a validation pass once the validation dataloader is restored.

        #*** Save the Model to disk ***
        checkpoint_path = checkpoint_dir + "Coco_checkpoint" + str(epoch) + ".pt"
        checkpoint = {'model': model, 'model_state_dict': model.state_dict(), 'optimizer': optimizer, 'optimizer_state_dict': optimizer.state_dict(), 'train_loss': train_loss}
        torch.save(checkpoint, checkpoint_path)
        print("Model saved at:", checkpoint_path)

In [0]:
# Held-out images for inference; batch_size=1 so each prediction maps to one file name.
test_dataset = CustomDataset(root_dir='/content/drive/My Drive/val2017', process_type='test')
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)
print('Test: ', len(test_dataloader), '| Total Image:', len(test_dataloader))

In [0]:
def concatente_and_colorize(im_lab, img_ab, ab_scale=128.0):
    """Merge a normalized L channel with normalized a/b channels and convert to RGB.

    Only the first batch element of each tensor is used.

    Args:
        im_lab: tensor whose first element is (C, H, W); channel 0 holds L scaled
            to [-1, 1]. The inference loop passes a (1, 1, H, W) slice.
        img_ab: tensor whose first element is (2, H, W) holding a/b divided by ``ab_scale``.
        ab_scale: factor undoing the a/b normalization. The default 128 matches the
            ``/ 128.0`` in ``CustomDataset`` (the previous hard-coded 127 did not).

    Returns:
        (1, 3, H, W) float tensor with RGB values nominally in [0, 1].
    """
    l_norm = im_lab[0].cpu().detach().numpy()[0]                    # (H, W)
    ab_norm = img_ab[0].cpu().detach().numpy().transpose(1, 2, 0)   # (H, W, 2)
    lab = np.empty((*l_norm.shape, 3), dtype=np.float32)
    # Undo the dataset normalization into OpenCV's float Lab ranges.
    lab[:, :, 0] = (l_norm + 1) * 50
    lab[:, :, 1:] = ab_norm * ab_scale
    rgb = cv2.cvtColor(lab, cv2.COLOR_Lab2RGB)
    return torch.stack([torchvision.transforms.ToTensor()(rgb)], dim=0)

In [0]:
def make_color(l_img, ab_img):
    """Debug helper: plot the L, a and b planes of one sample and return it as RGB.

    Only the first batch element of each tensor is used.

    Args:
        l_img: tensor whose first element is (C, H, W); only channel 0 is used as L,
            so both single-channel and 3x-repeated L inputs work.
        ab_img: tensor whose first element is (2, H, W) holding a and b.

    Returns:
        (H, W, 3) float32 numpy array from OpenCV's Lab -> RGB conversion.

    NOTE(review): values go to OpenCV unscaled, so this assumes raw Lab input
    (L in [0, 100], a/b about [-127, 127]) rather than the normalized tensors
    produced by CustomDataset -- confirm with the caller.
    """
    # .cpu() so CUDA tensors work; keep a single L plane so the result has 3 channels.
    l_np = l_img[0].cpu().detach().numpy().transpose(1, 2, 0)[:, :, :1]
    ab_np = ab_img[0].cpu().detach().numpy().transpose(1, 2, 0)

    # One panel per plane (consecutive plt.imshow calls would overwrite each other).
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    planes = (l_np[:, :, 0], ab_np[:, :, 0], ab_np[:, :, 1])
    for ax, plane, title in zip(axes, planes, ('L', 'a', 'b')):
        ax.imshow(plane)
        ax.set_title(title)
    print(np.min(ab_np[:, :, 0]), np.max(ab_np[:, :, 0]))
    print(np.min(ab_np[:, :, 1]), np.max(ab_np[:, :, 1]))
    print(l_np.shape, ab_np.shape)

    # cvtColor's Lab -> RGB path only accepts 8-bit or float32 input.
    lab_np = np.concatenate((l_np, ab_np), axis=2).astype(np.float32)
    return cv2.cvtColor(lab_np, cv2.COLOR_Lab2RGB)

In [0]:
def colorize(im_lab):
    """Convert the first image of a (B, 3, H, W) Lab tensor to a (1, 3, H, W) RGB tensor.

    Expects raw Lab values (L in [0, 100], a/b about [-127, 127]), the ranges used
    by OpenCV's float Lab -> RGB conversion. The output is clipped to [0, 1].
    """
    # The original called skimage's `color.lab2rgb`, but `color` was never imported
    # (NameError); OpenCV, already used throughout, performs the same conversion.
    lab_np = im_lab[0].cpu().detach().numpy().transpose(1, 2, 0).astype(np.float32)
    rgb_np = np.clip(cv2.cvtColor(lab_np, cv2.COLOR_Lab2RGB), 0.0, 1.0)
    return torch.stack([torchvision.transforms.ToTensor()(rgb_np)], dim=0)

In [0]:
#*** Inference Step ***
# Evaluation only: BatchNorm uses running statistics and no autograd graph is kept.
model.eval()
output_dir = '/content/drive/My Drive/Output_after_epoch12_coco/'
os.makedirs(output_dir, exist_ok=True)
avg_loss = 0.0
evaluated = 0   # images actually scored (bad files are skipped)
loop_start = time.time()
with torch.no_grad():
    for img_l_encoder, img_ab_encoder, img_l_inception, img_rgb, file_name in test_dataloader:
        #*** Skip bad data ***
        # Error placeholders collate to a 1-D tensor instead of (B, 3, H, W).
        if img_l_encoder.ndim != 4:
            continue

        #*** Move data to GPU if available ***
        img_l_encoder = img_l_encoder.to(config.device)
        img_ab_encoder = img_ab_encoder.to(config.device)
        img_l_inception = img_l_inception.to(config.device)

        #*** Forward Propagation ***
        img_embs = inception_model(img_l_inception.float())
        output_ab = model(img_l_encoder, img_embs)

        #*** Adding l channel to ab channels and saving the RGB result ***
        color_img = concatente_and_colorize(img_l_encoder[:, :1], output_ab)
        save_image(color_img[0], os.path.join(output_dir, file_name[0]))

        #*** Loss Calculation ***
        loss = loss_criterion(output_ab, img_ab_encoder.float())
        avg_loss += loss.item()
        evaluated += 1

test_loss = avg_loss / max(evaluated, 1)
# `epoch` only exists if the training cell ran in this kernel (it is skipped when
# config.load_model_to_test is set); log at step 0 otherwise instead of raising NameError.
writer.add_scalar('Loss/test', test_loss, globals().get('epoch', 0))
print('Test Loss:', test_loss, '| Processed in ', time.time()-loop_start, 's')

In [0]:
# Flush any pending TensorBoard events and close the SummaryWriter's event file.
writer.close()