In [1]:
# Importing required libraries

import torch
import numpy as np
from PIL import Image
import os
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from matplotlib import pyplot as plt
from tqdm import tqdm
import torch.nn as nn
import torch.backends.cudnn
import torch.optim as optim
import albumentations as A
import cv2
from albumentations.pytorch import ToTensorV2


In [11]:
# Some hyperparameters and paths

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
Device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Root directory of the FLIR ADAS 1.3 dataset. Set the FLIR_ADAS_ROOT environment
# variable to point somewhere else instead of editing the notebook; the default
# keeps the original location.
DATA_ROOT = os.environ.get(
    "FLIR_ADAS_ROOT",
    "/home/adhvik/Downloads/Adhvik/courses/summer/impl/FLIR_ADAS_1_3",
)
train_r = os.path.join(DATA_ROOT, "train", "RGB")            # training RGB images
train_t = os.path.join(DATA_ROOT, "train", "thermal_8_bit")  # training thermal images
val_r = os.path.join(DATA_ROOT, "val", "RGB")                # validation RGB images
val_t = os.path.join(DATA_ROOT, "val", "thermal_8_bit")      # validation thermal images

lr_g = 3e-4       # generator learning rate
lr_d = 3e-4       # discriminator learning rate
batch_size = 16   # training batch size
num_workers = 2   # DataLoader worker processes for the training loader
num_epoch = 5     # number of passes of the training loop
lamda = 10        # weight of the L1 term in the generator loss

print(Device)


cuda:0


In [3]:
# Transformations

# Per-image transform: resize to 256x256 and convert the numpy array to a torch
# tensor with ToTensorV2. Normalisation is disabled here; the training loop
# divides the tensors by 255 instead.
transform = A.Compose(
    [
        # A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        A.Resize(width=256, height=256),
        ToTensorV2(),
    ]
)

# Joint transform for an RGB/thermal pair, called as
# `transform_2(image=rgb, image0=thermal)`.
# `additional_targets` maps each EXTRA keyword name to the built-in target type it
# should be handled as, i.e. {"image0": "image"}. The previous mapping
# {"image": "image0"} was reversed, so `image0` was never registered as an image
# target and was not resized together with `image`.
transform_2 = A.Compose(
    [A.Resize(width=2 * 256, height=2 * 256)],
    additional_targets={"image0": "image"},
)

In [4]:
# Data preprocessing and Input data prep..
class pairing_data(Dataset):
    """Dataset of paired RGB / thermal images stored in two directories.

    Every entry of the RGB directory is paired with a file of the same stem in the
    thermal directory; an RGB file ending in ``.jpg`` is paired with ``<stem>.jpeg``.

    Args:
        par_dir_rgb: directory containing the RGB images.
        par_dir_th: directory containing the matching thermal images.

    Attributes:
        files_rgb: sorted file names found in ``par_dir_rgb``.
        files_th: thermal file name for each entry of ``files_rgb`` (same index).
        n: number of pairs.
    """

    def __init__(self, par_dir_rgb, par_dir_th):
        super().__init__()
        self.rgb = par_dir_rgb
        self.th = par_dir_th
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs and machines.
        self.files_rgb = sorted(os.listdir(self.rgb))
        self.n = len(self.files_rgb)
        self.files_th = [self._thermal_name(name) for name in self.files_rgb]

    @staticmethod
    def _thermal_name(rgb_name):
        """Return the thermal file name paired with ``rgb_name``.

        Only a trailing ``.jpg`` extension is rewritten to ``.jpeg``. A plain
        ``str.replace('jpg', 'jpeg')`` would also rewrite any "jpg" occurring
        inside the stem. Names with any other extension are returned unchanged.
        """
        stem, ext = os.path.splitext(rgb_name)
        if ext == ".jpg":
            return stem + ".jpeg"
        return rgb_name

    def __len__(self):
        """Number of RGB/thermal pairs."""
        return len(self.files_rgb)

    def __getitem__(self, index):
        """Load, resize and tensorise the pair at ``index``.

        Returns:
            ``(inp, out)``: the transformed RGB image and the transformed thermal
            image. NOTE(review): channel counts depend on how the files are
            stored on disk (no mode conversion is applied) -- confirm that the
            thermal images decode to a single channel, as the models assume.
        """
        r_path = os.path.join(self.rgb, self.files_rgb[index])
        t_path = os.path.join(self.th, self.files_th[index])

        # The context managers close the file handles as soon as the pixel data
        # has been copied into numpy arrays.
        with Image.open(r_path) as rgb_img:
            inp = np.array(rgb_img)
        with Image.open(t_path) as th_img:
            out = np.array(th_img)

        # Joint transform on the pair first, then the per-image transform.
        mod = transform_2(image=inp, image0=out)
        inp = mod["image"]
        out = mod["image0"]

        inp = transform(image=inp)["image"]
        out = transform(image=out)["image"]

        return inp, out

In [5]:
class CNN(nn.Module):
    """Convolution followed by LeakyReLU(0.2), without normalisation.

    NOTE: the third and fourth constructor arguments reach ``nn.Conv2d`` in swapped
    roles. The value bound to ``stride`` is used as the convolution's kernel size
    and the value bound to ``kernel_size`` is used as its stride. Existing callers
    pass these two values positionally and rely on this, so the behaviour is kept
    as is; e.g. ``CNN(c_in, c_out, 4, 2)`` builds a 4x4 convolution with stride 2.

    Args:
        inp: number of input channels.
        out: number of output channels.
        stride: effectively the kernel size (see note above).
        kernel_size: effectively the stride (see note above).
    """

    def __init__(self, inp, out, stride=2, kernel_size=4):
        super().__init__()
        conv = nn.Conv2d(
            in_channels=inp,
            out_channels=out,
            kernel_size=stride,   # swapped on purpose, see class docstring
            stride=kernel_size,   # swapped on purpose, see class docstring
            padding=1,
            padding_mode='reflect',
            bias=False,
        )
        self.conv_layer = nn.Sequential(conv, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Apply the convolution and the activation to ``x``."""
        return self.conv_layer(x)


class CNNB(nn.Module):
    """Convolution -> BatchNorm2d -> LeakyReLU(0.2).

    NOTE: the third and fourth constructor arguments reach ``nn.Conv2d`` in swapped
    roles. The value bound to ``stride`` is used as the convolution's kernel size
    and the value bound to ``kernel_size`` is used as its stride. Existing callers
    pass these two values positionally and rely on this, so the behaviour is kept
    as is; e.g. ``CNNB(c_in, c_out, 4, 1)`` builds a 4x4 convolution with stride 1.

    Args:
        inp: number of input channels.
        out: number of output channels (also the BatchNorm feature count).
        stride: effectively the kernel size (see note above).
        kernel_size: effectively the stride (see note above).
    """

    def __init__(self, inp, out, stride=2, kernel_size=4):
        super().__init__()
        conv = nn.Conv2d(
            in_channels=inp,
            out_channels=out,
            kernel_size=stride,   # swapped on purpose, see class docstring
            stride=kernel_size,   # swapped on purpose, see class docstring
            padding=1,
            padding_mode='reflect',
            bias=False,
        )
        norm = nn.BatchNorm2d(out)
        self.conv_layer1 = nn.Sequential(conv, norm, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Apply convolution, batch normalisation and activation to ``x``."""
        return self.conv_layer1(x)

    

class unet_block(nn.Module):
    """One U-Net stage: (transposed) convolution -> BatchNorm2d -> activation.

    Args:
        inp: number of input channels.
        out: number of output channels.
        dir: ``'down'`` builds a 4x4 stride-2 ``Conv2d`` with reflect padding;
            any other value builds a 4x4 stride-2 ``ConvTranspose2d``.
        act_fn: ``'ReLU'`` selects ``nn.ReLU``; any other value selects
            ``nn.LeakyReLU(0.2)``.
        drop: when true, ``Dropout(0.5)`` is applied to the block output.
    """

    def __init__(self, inp, out, dir='down', act_fn="ReLU", drop=False):
        super().__init__()
        if dir == 'down':
            conv = nn.Conv2d(inp, out, 4, 2, 1, bias=False, padding_mode='reflect')
        else:
            conv = nn.ConvTranspose2d(inp, out, 4, 2, 1, bias=False)
        activation = nn.ReLU() if act_fn == 'ReLU' else nn.LeakyReLU(0.2)
        self.block = nn.Sequential(conv, nn.BatchNorm2d(out), activation)
        self.drop = drop
        # Always created so the module layout does not depend on `drop`.
        self.dropout = nn.Dropout(0.5)
        self.dir = dir

    def forward(self, x):
        """Run the block; dropout is applied only when ``drop`` was requested."""
        x = self.block(x)
        # Previously dropout ran unconditionally, so `drop=False` had no effect.
        return self.dropout(x) if self.drop else x


In [6]:
class Disc(nn.Module):
    """Conditional patch discriminator over an (input, target) image pair.

    The two images are concatenated along the channel axis and passed through a
    stack of convolution blocks ending in a 1-channel convolution. The output is
    a map of raw logits (no sigmoid), which is what ``nn.BCEWithLogitsLoss``
    expects. Applying a sigmoid here as well would squash the scores twice.

    Args:
        inp: number of channels of the conditioning image ``x``. The target
            image ``y`` is assumed to contribute exactly one extra channel.
        feature_dim: output channels of each convolution block. The last block
            uses stride 1, all earlier ones stride 2.
    """

    def __init__(self, inp=3, feature_dim=(64, 128, 256, 512)):
        super().__init__()
        feature_dim = list(feature_dim)  # accept tuples or lists; never mutate the default
        channels = feature_dim[0]
        # CNN/CNNB receive (kernel, stride) positionally as their 3rd/4th arguments:
        # 4x4 kernels, stride 2 everywhere except the last feature block.
        convlayers = [CNN(inp + 1, channels, 4, 2)]
        last = len(feature_dim) - 1
        for position, feature in enumerate(feature_dim[1:], start=1):
            # Pick the stride-1 block by position so repeated widths cannot
            # accidentally mark an earlier block as "last".
            convlayers.append(CNNB(channels, feature, 4, 1 if position == last else 2))
            channels = feature
        convlayers.append(
            nn.Conv2d(channels, 1, kernel_size=4, stride=1, padding=1, padding_mode='reflect')
        )
        self.model = nn.Sequential(*convlayers)

    def forward(self, x, y):
        """Score the pair ``(x, y)``; returns a map of per-patch logits."""
        return self.model(torch.cat([x, y], dim=1))

In [7]:
class Gen(nn.Module):
    """U-Net style encoder/decoder generator with skip connections.

    Seven stride-2 encoder stages (``l1``..``l7``) are followed by a stride-2
    bottleneck (``ul``). Seven decoder stages (``l11``..``l17``) plus the output
    layer ``map_`` upsample back, each one (except ``l11``) taking the previous
    decoder output concatenated with the mirrored encoder output. The result has
    a single channel and passes through ``Tanh``.

    Args:
        inp: number of input image channels.
        feature_dim: base channel width; deeper stages use 2x, 4x and 8x of it.
    """

    def __init__(self, inp=3, feature_dim=64):
        super().__init__()
        f1 = feature_dim
        f2 = feature_dim * 2
        f4 = feature_dim * 4
        f8 = feature_dim * 8

        def down(c_in, c_out):
            # Encoder stage: strided convolution with LeakyReLU, no dropout requested.
            return unet_block(c_in, c_out, dir='down', act_fn='LReLU', drop=False)

        def up(c_in, c_out, drop):
            # Decoder stage: transposed convolution with ReLU.
            return unet_block(c_in, c_out, dir='up', act_fn='ReLU', drop=drop)

        # ---- encoder (first stage has no batch norm) ----
        self.l1 = nn.Sequential(
            nn.Conv2d(inp, f1, 4, 2, 1, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )
        self.l2 = down(f1, f2)
        self.l3 = down(f2, f4)
        self.l4 = down(f4, f8)
        self.l5 = down(f8, f8)
        self.l6 = down(f8, f8)
        self.l7 = down(f8, f8)

        # ---- bottleneck ----
        self.ul = nn.Sequential(nn.Conv2d(f8, f8, 4, 2, 1), nn.ReLU())

        # ---- decoder; inputs after l11 are doubled by the skip concatenation ----
        self.l11 = up(f8, f8, drop=True)
        self.l12 = up(f8 * 2, f8, drop=True)
        self.l13 = up(f8 * 2, f8, drop=True)
        self.l14 = up(f8 * 2, f8, drop=False)
        self.l15 = up(f8 * 2, f4, drop=False)
        self.l16 = up(f8, f2, drop=False)
        self.l17 = up(f4, f1, drop=False)

        # ---- output: single channel, Tanh ----
        self.map_ = nn.Sequential(
            nn.ConvTranspose2d(f2, 1, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Translate the input batch ``x`` into a single-channel image batch."""
        # Encoder pass, remembering every stage output for the skip connections.
        skips = []
        h = x
        for encoder in (self.l1, self.l2, self.l3, self.l4, self.l5, self.l6, self.l7):
            h = encoder(h)
            skips.append(h)

        h = self.ul(h)
        h = self.l11(h)

        # Decoder pass: each stage consumes the most recent unused encoder output
        # (l12 <- l7, l13 <- l6, ..., l17 <- l2, map_ <- l1).
        for decoder in (self.l12, self.l13, self.l14, self.l15, self.l16, self.l17, self.map_):
            h = decoder(torch.cat([h, skips.pop()], dim=1))

        return h

In [8]:
# Networks, moved to the selected device (discriminator is built first).
disc = Disc(inp=3).to(Device)
gen = Gen(inp=3, feature_dim=64).to(Device)

# One Adam optimiser per network, both using betas=(0.5, 0.999).
opt_disc = optim.Adam(params=disc.parameters(), lr=lr_d, betas=(0.5, 0.999))
opt_gen = optim.Adam(params=gen.parameters(), lr=lr_g, betas=(0.5, 0.999))

# Adversarial criterion (BCEWithLogitsLoss applies the sigmoid itself, so it is
# meant to be fed raw scores) and the L1 reconstruction criterion.
BCE = nn.BCEWithLogitsLoss()
L1_LOSS = nn.L1Loss()

In [9]:
# Training pairs: shuffled mini-batches loaded by background workers.
train_dataset = pairing_data(par_dir_rgb=train_r, par_dir_th=train_t)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

# Validation pairs: one pair at a time, in dataset order.
val_dataset = pairing_data(par_dir_rgb=val_r, par_dir_th=val_t)
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False)

In [None]:
# Training the gen and disc

# Only the first MAX_BATCHES_PER_EPOCH batches of every epoch are used for training.
MAX_BATCHES_PER_EPOCH = 10


def _to_display(batch):
    """Return the first image of an (N, C, H, W) batch as a numpy array for imshow.

    The tensor is detached and copied to the CPU first (``.numpy()`` fails on CUDA
    tensors), channels are moved last, and a single channel is squeezed away so
    matplotlib receives a plain (H, W) array.
    """
    img = batch[0].detach().cpu().numpy()
    img = np.moveaxis(img, 0, -1)
    return img[..., 0] if img.shape[-1] == 1 else img


for epoch in range(num_epoch):
    print('epoch: ', epoch + 1)
    loop = tqdm(train_loader, leave=True)
    for idx, (x, y) in enumerate(loop):
        if idx >= MAX_BATCHES_PER_EPOCH:
            break

        # Scale the raw 0-255 tensors to [0, 1] and move them to the device.
        x = (x.float() / 255).to(Device)
        y = (y.float() / 255).to(Device)

        # ---- Training disc ----
        y_fake = gen(x)
        D_real = disc(x, y)
        # detach(): the discriminator step must not back-propagate into the generator.
        D_fake = disc(x, y_fake.detach())

        D_real_loss = BCE(D_real, torch.ones_like(D_real))
        D_fake_loss = BCE(D_fake, torch.zeros_like(D_fake))
        D_loss = (D_real_loss + D_fake_loss) / 2

        opt_disc.zero_grad()
        # No retain_graph needed: this graph only contains the discriminator, and
        # the generator step below runs a fresh discriminator forward pass.
        D_loss.backward()
        opt_disc.step()

        # ---- Training gen ----
        D_fake = disc(x, y_fake)
        G_fake_loss = BCE(D_fake, torch.ones_like(D_fake))
        L1 = L1_LOSS(y_fake, y) * lamda
        G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        G_loss.backward()
        opt_gen.step()

        loop.set_postfix(D_loss=D_loss.item(), G_loss=G_loss.item())

    # Preview the generator on the first validation pair after every epoch.
    # The generator's train/eval mode is left untouched; no_grad only avoids
    # building an autograd graph for the preview.
    with torch.no_grad():
        for x, y in val_loader:
            x = (x.float() / 255).to(Device)
            plt.imshow(_to_display(gen(x)), cmap='gray')
            plt.show()
            plt.imshow(_to_display(x))
            plt.show()
            break

In [None]:
# Show the generator output and the corresponding input for the first validation pair.
with torch.no_grad():  # inference only: no autograd graph is needed
    for x, y in val_loader:
        x = (x.float() / 255).to(Device)

        # .cpu() is required before .numpy() whenever the tensors live on the GPU.
        img = gen(x)[0].detach().cpu().numpy()
        img = np.moveaxis(img, 0, -1)
        if img.shape[-1] == 1:
            # Drop the single channel so imshow receives a plain (H, W) array.
            img = img[..., 0]

        x = np.moveaxis(x[0].detach().cpu().numpy(), 0, -1)

        plt.imshow(img, cmap='gray')
        plt.show()
        plt.imshow(x)
        plt.show()

        break