# Data

In [0]:
# # Detect if we are in Google Colaboratory
# try:
#     import google.colab
#     IN_COLAB = True
# except ImportError:
#     IN_COLAB = False

# from pathlib import Path
# # Determine the locations of auxiliary libraries and datasets.
# # `AUX_DATA_ROOT` is where 'notmnist.py', 'animation.py' and 'tiny-imagenet-2020.zip' are.
# if IN_COLAB:
#     google.colab.drive.mount("/content/drive")
    
#     # Change this if you created the shortcut in a different location
#     AUX_DATA_ROOT = Path("/content/drive/My Drive/")
    
#     assert AUX_DATA_ROOT.is_dir(), "Have you forgotten to 'Add a shortcut to Drive'?"
    
#     import sys
#     sys.path.insert(0, str(AUX_DATA_ROOT))
# else:
#     AUX_DATA_ROOT = Path(".")

# # unzip the data
# ![ ! -d "Data" ] && unzip -q "{AUX_DATA_ROOT / 'cropped_data_720x720'}"
# ![ ! -d "Ground_truth" ] && unzip -q "{AUX_DATA_ROOT / 'cropped_gt_720x720'}"

In [0]:
import os
from collections import Counter

def get_idx_and_num_samples(root):
    """Count how many files share each numeric filename prefix under ``root``.

    ``root`` is walked recursively. Every file name is expected to look like
    ``<int>_<rest>``; the integer in front of the first underscore is treated
    as the sample index.

    Args:
        root: directory to scan.

    Returns:
        A pair ``(n, count)`` of parallel lists. ``n`` holds the distinct
        sample indices in the order they were first encountered and
        ``count[i]`` is the number of files whose prefix equals ``n[i]``.

    Raises:
        ValueError: if the part of a file name before the first underscore
            is not an integer.
    """
    tally = Counter()
    for _dirpath, _dirnames, filenames in os.walk(root):
        for filename in filenames:
            prefix, _, _ = filename.partition("_")
            tally[int(prefix)] += 1

    return list(tally.keys()), list(tally.values())

In [0]:
# Scan the cropped input-image folder and collect the distinct sample indices
# (`n`) together with the number of image files found for each index (`count`).

root = '/content/content/cropped_data_720x720'
n, count = get_idx_and_num_samples(root)

In [0]:
## to remove 3, 4 and 5 samples

# for ind, number in zip(n, count):
#     if number == 3:
#         os.remove(f"/content/content/cropped_data/{ind}_3.png")

#     if number == 4:
#         os.remove(f"/content/content/cropped_data/{ind}_3.png")
#         os.remove(f"/content/content/cropped_data/{ind}_4.png")

#     if number == 5:
#         os.remove(f"/content/content/cropped_data/{ind}_3.png")
#         os.remove(f"/content/content/cropped_data/{ind}_4.png")
#         os.remove(f"/content/content/cropped_data/{ind}_5.png")

# Generator

In [0]:
### ================== GENERATOR============================ ####

import torch.nn.functional as F
from torch import nn
import torch
import torch.optim as optim


class DenseBlock(nn.Module):
    """Residual block of 3x3 convolutions joined by additive skip connections.

    Every convolution keeps the channel count and the spatial size
    (kernel 3, stride 1, padding 1), so the output has the same shape as the
    input.

    Args:
        channels: number of input and output feature channels.
        beta: factor applied to the residual branch before it is added back
            to the block input.
    """

    def __init__(self, channels, beta=0.5):
        super().__init__()
        self.beta = beta

        def conv_act():
            # 3x3 convolution followed by an in-place LeakyReLU.
            return nn.Sequential(
                nn.Conv2d(channels, channels, 3, 1, padding=1),
                nn.LeakyReLU(inplace=True),
            )

        # The sub-modules are created in the order module1..module4 and then
        # last_conv, so parameter names and initialisation order stay fixed.
        self.conv_module1 = conv_act()
        self.conv_module2 = conv_act()
        self.conv_module3 = conv_act()
        self.conv_module4 = conv_act()
        self.last_conv = nn.Conv2d(channels, channels, 3, 1, padding=1)

    def forward(self, x):
        """Return ``x + beta * last_conv(...)`` for a feature map ``x``.

        Only ``conv_module1``, ``conv_module2`` and ``last_conv`` take part in
        the computation; ``conv_module3`` and ``conv_module4`` are registered
        on the module but are not called here.
        """
        stage1 = x + self.conv_module1(x)
        stage2 = x + stage1 + self.conv_module2(stage1)
        fused = x + stage1 + stage2
        residual = self.last_conv(fused)

        return x + residual * self.beta


class Light(nn.Module):
    """U-Net style generator assembled from convolutions and DenseBlocks.

    The encoder has three stages (64, 128 and 256 channels), each followed by
    a 2x average pooling. After a convolutional bottleneck the decoder
    upsamples by 2x three times, concatenating the matching encoder output
    before every decoder stage. Because of the three pool/upsample pairs the
    input height and width have to be divisible by 8 for the concatenations to
    line up.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        residual_beta: ``beta`` passed to every DenseBlock.
    """

    def __init__(self, in_c, out_c, residual_beta=0.5):
        super().__init__()
        self.residual_beta = residual_beta

        def dense_pair(channels):
            # Two DenseBlocks in sequence, wrapped in their own Sequential.
            return nn.Sequential(
                *(DenseBlock(channels, beta=residual_beta) for _ in range(2))
            )

        def encoder_stage(c_in, c_out):
            # conv -> PReLU -> two DenseBlocks
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, stride=1, padding=1),
                nn.PReLU(),
                dense_pair(c_out),
            )

        def decoder_stage(c_in, c_mid, c_out):
            # conv -> PReLU -> two DenseBlocks -> conv -> PReLU
            return nn.Sequential(
                nn.Conv2d(c_in, c_mid, 3, padding=1),
                nn.PReLU(),
                dense_pair(c_mid),
                nn.Conv2d(c_mid, c_out, 3, padding=1),
                nn.PReLU(),
            )

        # Modules are created in the same order as they are used so that
        # parameter names and initialisation order are stable.
        self.inconv = nn.Sequential(
            nn.Conv2d(in_c, 64, 9, 1, padding=4),
            nn.PReLU(),
        )

        self.down1 = encoder_stage(64, 64)
        self.down2 = encoder_stage(64, 128)
        self.down3 = encoder_stage(128, 256)

        self.bottom = nn.Sequential(
            nn.Conv2d(256, 512, 3, 1, padding=1),
            nn.PReLU(),
            nn.Conv2d(512, 512, 3, stride=1, padding=1),
            nn.PReLU(),
            nn.Conv2d(512, 256, 3, 1, padding=1),
            nn.PReLU(),
        )

        # Decoder inputs are twice as wide as the previous output because the
        # encoder skip tensor is concatenated along the channel axis.
        self.up1 = decoder_stage(512, 256, 128)
        self.up2 = decoder_stage(256, 128, 64)
        self.up3 = decoder_stage(128, 64, 64)

        self.out = nn.Conv2d(64, out_c, 9, 1, padding=4)

    def forward(self, x):
        """Map an ``in_c``-channel batch to an ``out_c``-channel batch in [0, 1]."""
        features = self.inconv(x)

        # Encoder: remember each stage output for the skip connections.
        skips = []
        for stage in (self.down1, self.down2, self.down3):
            features = stage(features)
            skips.append(features)
            features = F.avg_pool2d(features, kernel_size=2, stride=2)

        features = self.bottom(features)

        # Decoder: upsample, concatenate the deepest remaining skip, refine.
        for stage in (self.up1, self.up2, self.up3):
            features = F.interpolate(features, scale_factor=2)
            features = stage(torch.cat([skips.pop(), features], dim=1))

        # tanh gives [-1, 1]; shift and scale into [0, 1].
        return (torch.tanh(self.out(features)) + 1) / 2

# Discriminator

In [0]:
import torch.nn.functional as F
from torch import nn

class DiscriminatorPatch64(nn.Module):
    """Convolutional patch discriminator for 3-channel images.

    Four stride-2 convolutions followed by a 4x4 max pooling shrink the
    spatial size by a factor of 64 overall; a final 1x1 convolution then
    produces one score per remaining location.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride) of every convolution that is
        # followed by BatchNorm and LeakyReLU.
        bn_stages = (
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        )

        # The stem has no batch normalisation.
        layers = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
        ]
        for c_in, c_out, stride in bn_stages:
            layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1)
            )
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))

        layers.append(nn.MaxPool2d(kernel_size=(4, 4), stride=(4, 4)))
        layers.append(nn.Conv2d(512, 1, kernel_size=1))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-patch probabilities of shape ``(batch, height, width)``."""
        logits = self.net(x)
        batch, _, height, width = logits.shape
        # Drop the single output channel, then squash the scores into (0, 1).
        return torch.sigmoid(logits.view(batch, height, width))

In [0]:
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import random
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np


class microDataset(Dataset):
    """Dataset of (image 1, image 2, ground truth) triplets keyed by sample index.

    For a sample index ``k`` the files ``<img_dir>/k_1.png``,
    ``<img_dir>/k_2.png`` and ``<gt_dir>/k_gt.png`` are loaded.

    Args:
        img_dir: directory holding the ``<k>_1.png`` / ``<k>_2.png`` inputs.
        gt_dir: directory holding the ``<k>_gt.png`` ground-truth images.
        n: sequence of sample indices; dataset position ``idx`` maps to
            ``n[idx]``.
        transform: optional callable applied to each of the three PIL images.
            When ``None`` the PIL images are returned unchanged.
    """

    def __init__(self, img_dir, gt_dir, n, transform=None):
        self.n = n
        self.img_dir = img_dir
        self.gt_dir = gt_dir
        # Sorted file listings of both folders (top level only).
        self.img = sorted(next(os.walk(self.img_dir))[2])
        self.gt = sorted(next(os.walk(self.gt_dir))[2])
        self.transform = transform

    def __len__(self):
        # __getitem__ looks samples up through self.n, so that sequence (and
        # not the number of files in gt_dir) defines the valid index range.
        return len(self.n)

    @staticmethod
    def _load_image(path):
        """Read the image at ``path`` fully into memory and close the file."""
        with open(path, 'rb') as handle:
            image = Image.open(handle)
            # PIL decodes lazily; force the read while the file is still open.
            image.load()
        return image

    def _transformed(self, image, seed):
        """Apply ``self.transform`` after re-seeding Python's ``random``."""
        # Re-seeding before every call makes transforms that draw from the
        # ``random`` module pick the same parameters for all three images.
        # NOTE(review): transforms drawing from torch's RNG are not
        # synchronised by this - confirm before adding random augmentations.
        random.seed(seed)
        if self.transform is None:
            return image
        return self.transform(image)

    def __getitem__(self, idx):
        sample_id = self.n[idx]
        img1 = self._load_image(os.path.join(self.img_dir, f"{sample_id}_1.png"))
        img2 = self._load_image(os.path.join(self.img_dir, f"{sample_id}_2.png"))
        ground_truth = self._load_image(os.path.join(self.gt_dir, f"{sample_id}_gt.png"))

        seed = np.random.randint(666)

        img1 = self._transformed(img1, seed)
        img2 = self._transformed(img2, seed)
        ground_truth = self._transformed(ground_truth, seed)

        return img1, img2, ground_truth

In [0]:
# Only convert the PIL images to tensors; no augmentation is applied.
data_transform = transforms.ToTensor()
# `n` is the list of sample indices computed earlier with get_idx_and_num_samples.
train_dataset = microDataset('/content/content/cropped_data_720x720', '/content/content/cropped_gt_720x720', n, transform=data_transform)
# NOTE(review): gan_train builds its own DataLoader from the dataset it receives,
# so this loader is not consumed by the training call visible in this file.
train_dataloader = DataLoader(train_dataset, batch_size = 1, shuffle=True, pin_memory=True)

In [0]:
def gan_train(epochs, dataset, batch_size, optim_disc, optim_gen, discriminator, generator, scheduler_optim_discriminator, scheduler_optim_generator, exp_name = 'my GAN', flag_gp=False):
    """Train a generator/discriminator pair and log the losses to TensorBoard.

    Each batch yields ``(img1, img2, target)``. The generator input is the
    channel-wise concatenation of ``img1``, ``img2`` and two all-ones
    single-channel masks. Per batch the discriminator is updated first
    (BCE on real targets vs. generated images), then the generator
    (L1 to the target plus ``0.001`` times the adversarial BCE term).

    Args:
        epochs: number of passes over ``dataset``.
        dataset: dataset yielding ``(img1, img2, target)`` tensor triplets.
        batch_size: batch size for the internally created DataLoader; it is
            doubled once when ``epoch == 100`` is reached.
        optim_disc: optimizer of the discriminator parameters.
        optim_gen: optimizer of the generator parameters.
        discriminator: module returning probabilities in [0, 1] (required by
            ``nn.BCELoss``).
        generator: module mapping the concatenated input to an image.
        scheduler_optim_discriminator: LR scheduler stepped once per epoch.
        scheduler_optim_generator: LR scheduler stepped once per epoch.
        exp_name: sub-directory of ``logs/`` used by the SummaryWriter.
        flag_gp: accepted for interface compatibility; currently unused.

    Returns:
        ``(losses_D_hist, losses_G_hist)``: per-iteration lists holding the
        running mean of the discriminator loss within the current epoch and
        the raw generator loss.
    """
    writer = SummaryWriter(f'logs/{exp_name}')
    losses_D_hist = []
    losses_G_hist = []

    # Loss modules are stateless, so build them once instead of per batch.
    bce = nn.BCELoss()
    l1 = nn.L1Loss()

    # Run on whatever device the generator lives on rather than assuming CUDA.
    device = next(generator.parameters()).device

    try:
        for epoch in range(epochs):

            # Doubling the batch size late in training was tried;
            # it doesn't improve the results.
            if epoch == 100:
                batch_size *= 2

            discriminator_running_loss = 0.0
            total = 0
            # Defined up-front so the epoch summary works for an empty loader.
            loss_discriminator = float('nan')
            gen_loss_value = float('nan')

            loader = torch.utils.data.DataLoader(dataset, batch_size = batch_size, shuffle=True, pin_memory=True)

            for img1, img2, target in loader:
                real_data1 = img1.to(device)
                real_data2 = img2.to(device)
                target = target.to(device)
                # The last batch may be smaller than ``batch_size``.
                n_samples = real_data1.size(0)

                # All-ones single-channel mask matching the batch and image
                # size (previously hard-coded to 1 x 1 x 720 x 720).
                mask = torch.ones_like(real_data1[:, :1])
                cat_data = torch.cat((real_data1, real_data2, mask, mask), dim=1)

                # The generator is unchanged between the two updates, so one
                # forward pass serves both.
                fake_data = generator(cat_data)

                # ---- discriminator update ----
                optim_disc.zero_grad()
                optim_gen.zero_grad()
                pred_real = discriminator(target)
                # detach(): the discriminator loss must not backpropagate
                # into the generator.
                pred_fake = discriminator(fake_data.detach())
                disc_loss = bce(pred_real, torch.ones_like(pred_real)) + bce(pred_fake, torch.zeros_like(pred_fake))

                disc_loss.backward()
                optim_disc.step()

                total += n_samples
                discriminator_running_loss += disc_loss.item() * n_samples
                loss_discriminator = discriminator_running_loss / total
                losses_D_hist.append(loss_discriminator)

                # ---- generator update ----
                optim_disc.zero_grad()
                optim_gen.zero_grad()
                pred_fake = discriminator(fake_data)
                gen_loss = l1(fake_data, target) + 0.001 * bce(pred_fake, torch.ones_like(pred_fake))

                gen_loss.backward()
                optim_gen.step()

                gen_loss_value = gen_loss.item()
                losses_G_hist.append(gen_loss_value)

                # Send both losses to TensorBoard.
                writer.add_scalar('Discriminator loss', loss_discriminator, global_step = len(losses_D_hist))
                writer.add_scalar('Generator loss', gen_loss_value, global_step = len(losses_G_hist))

            scheduler_optim_generator.step()
            scheduler_optim_discriminator.step()

            print(f'epoch : {epoch} | D_loss : {np.round(loss_discriminator, 4)} | G_loss : {np.round(gen_loss_value, 4)}')
    finally:
        # Flush pending events and release the log file even on failure.
        writer.close()

    return losses_D_hist, losses_G_hist

In [0]:
from IPython import display
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from torch.optim.lr_scheduler import ExponentialLR
import torch.optim as optim
from torch.autograd import Variable

### utils
# NOTE(review): `device` is selected here, but the models below are moved with
# `.cuda()` directly, so this cell still requires a CUDA device.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Timestamped experiment name; gan_train uses it as the sub-directory of `logs/`.
exp_name = datetime.now().isoformat(timespec='seconds') + f'GAN'
batch_size = 1
epochs = 100

### models
# 8 input channels: gan_train concatenates two images and two single-channel
# masks (presumably 3-channel images - confirm against the data); 3 output channels.
generator = Light(8, 3).cuda()
discriminator = DiscriminatorPatch64().cuda()

### optimizers
# Adam with learning rates 2.5e-5 (discriminator) and 5e-5 (generator).
optim_discriminator = optim.Adam(discriminator.parameters(), lr = 0.25 * 1e-4, betas=(0.0, 0.9))
optim_generator = optim.Adam(generator.parameters(), lr = 0.5 * 1e-4, betas=(0.0, 0.9))

### scheduler: StepLR with step_size=1 and gamma=0.8. gan_train steps both
### schedulers once per epoch, so each learning rate is multiplied by 0.8 after every epoch.
scheduler_optim_discriminator = torch.optim.lr_scheduler.StepLR(optim_discriminator, 1, 0.8)
scheduler_optim_generator = torch.optim.lr_scheduler.StepLR(optim_generator, 1, 0.8)

# With epochs = 100 the `epoch == 100` batch-size doubling inside gan_train is never reached.
loss_D_naive, loss_G_naive = gan_train(epochs, train_dataset, batch_size, optim_discriminator, optim_generator, discriminator, generator, scheduler_optim_discriminator, scheduler_optim_generator, exp_name = exp_name, flag_gp=False)