In [None]:
import os

import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from torchinfo import summary
from kornia.filters import spatial_gradient

from unet_model import UNet
from discriminator import Discriminator
%matplotlib inline

# Seed the global torch and NumPy RNGs so runs are repeatable
# (MyDataset.transform draws its flip / jitter decisions from np.random).
SEED = 1337
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
class MyDataset(Dataset):
    """Paired input / target dataset backed by two ``.npy`` files.

    ``X`` is loaded from ``X_path`` and transposed with ``(0, 3, 1, 2)``; this
    implies a 4-D array, presumably stored channels-last ``(N, H, W, C)`` and
    served channels-first ``(C, H, W)`` — confirm against the data files.
    Each target ``y[idx]`` is served with a leading singleton axis
    (``np.expand_dims(..., 0)``).

    When ``transform_flag`` is true, items go through :meth:`transform`
    (random crop, horizontal flip, photometric jitter) and are returned as
    tensors; otherwise the NumPy arrays are returned unchanged.

    Args:
        X_path: path to the input ``.npy`` array.
        y_path: path to the target ``.npy`` array.
        transform_flag: apply random augmentation in ``__getitem__``.
        crop_size: ``(height, width)`` of the random crop used by
            :meth:`transform`. Defaults to ``(480, 480)``.
    """

    def __init__(self, X_path="dataset/x_train.npy", y_path="dataset/y_train.npy",
                 transform_flag=False, crop_size=(480, 480)):
        self.X = np.load(X_path).transpose(0, 3, 1, 2)
        self.y = np.load(y_path)
        self.transform_flag = transform_flag
        self.crop_size = tuple(crop_size)

    def __len__(self):
        return self.X.shape[0]

    def transform(self, image, mask):
        """Randomly augment one ``(image, mask)`` pair.

        Geometric ops (crop, horizontal flip) are applied identically to both;
        photometric ops (brightness, contrast, gamma, hue, saturation) are
        applied to ``image`` only, each with probability 0.9.

        Uses the public ``torchvision.transforms.functional`` API; the private
        ``transforms.functional_tensor`` module was deprecated in torchvision
        0.15 and removed in 0.17.

        NOTE(review): torchvision's colour ops treat float tensors as images in
        [0, 1] (e.g. ``adjust_gamma`` clamps to that range). Confirm the dtype /
        value range of the stored arrays before relying on these ops.
        """
        TF = transforms.functional
        image = torch.tensor(image)
        mask = torch.tensor(mask)

        # Random crop: the same window is cut from image and mask.
        i, j, h, w = transforms.RandomCrop.get_params(image, output_size=self.crop_size)
        image = TF.crop(image, i, j, h, w)
        mask = TF.crop(mask, i, j, h, w)

        # Random horizontal flip (p = 0.5), applied to both.
        if np.random.rand() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)

        # Random brightness, factor in [0.5, 1.5)
        if np.random.rand() > 0.1:
            image = TF.adjust_brightness(image, np.random.rand() + 0.5)

        # Random contrast, factor in [0.5, 1.5)
        if np.random.rand() > 0.1:
            image = TF.adjust_contrast(image, np.random.rand() + 0.5)

        # Random gamma, gamma in [0.5, 1.5)
        if np.random.rand() > 0.1:
            image = TF.adjust_gamma(image, np.random.rand() + 0.5)

        # Random hue shift in [-0.5, 0.5)
        if np.random.rand() > 0.1:
            image = TF.adjust_hue(image, np.random.rand() - 0.5)

        # Random saturation, factor in [0.5, 1.5)
        if np.random.rand() > 0.1:
            image = TF.adjust_saturation(image, np.random.rand() + 0.5)

        return image, mask

    def __getitem__(self, idx):
        if self.transform_flag:
            return self.transform(self.X[idx], np.expand_dims(self.y[idx], 0))
        else:
            return self.X[idx], np.expand_dims(self.y[idx], 0)

In [None]:
# Hyper-parameters
batch_size = 1

# Only the training split is augmented (random crop / flip / colour jitter).
train_dataset = MyDataset("dataset/x_train.npy", "dataset/y_train.npy", transform_flag=True)
val_dataset = MyDataset("dataset/x_val.npy", "dataset/y_val.npy", transform_flag=False)
test_dataset = MyDataset("dataset/x_test.npy", "dataset/y_test.npy", transform_flag=False)

# Shuffle only the training split; val/test keep a fixed order so the
# per-index plots in the training loop always show the same test samples.
train_loader, val_loader, test_loader = (
    torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=shuffle_split)
    for split, shuffle_split in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [None]:
# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Generator: UNet(3, 1, bilinear=False), warm-started from
# models/Unet_l1/model_100.pth. map_location lets a checkpoint saved on one
# device be loaded onto whichever device was selected above.
generator = UNet(3, 1, bilinear=False)
generator.load_state_dict(torch.load("models/Unet_l1/model_100.pth", map_location=device))
generator = generator.to(device)
optimizerG = torch.optim.Adam(generator.parameters(), lr=1e-4, weight_decay=1e-5)

# Discriminator: freshly constructed (no checkpoint), same optimiser settings.
discriminator = Discriminator()
discriminator = discriminator.to(device)
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=1e-4, weight_decay=1e-5)

In [None]:
def scaleInvLoss(pred, groundtruth, lamda = 1, grad=True):
    """Scale-invariant log error, as in https://arxiv.org/pdf/1406.2283.pdf.

    With ``d = log(pred) - log(groundtruth)`` (element-wise)::

        loss = mean(d**2) - lamda * mean(d)**2  [+ mean(spatial_gradient(d)**2)]

    which equals ``sum(d**2)/n - lamda * sum(d)**2/n**2`` with ``n = pred.numel()``.
    With ``lamda=1`` the first two terms are invariant to multiplying ``pred``
    by a positive constant; ``lamda=0`` gives the plain mean squared log error.

    Args:
        pred: predicted values; must be strictly positive (``torch.log`` of a
            non-positive value yields ``-inf``/``nan``).
        groundtruth: target values, broadcast-compatible with ``pred`` and
            strictly positive.
        lamda: weight of the scale-invariance term.
        grad: if true, add the mean squared spatial gradient of ``d`` computed
            with kornia's ``spatial_gradient``. NOTE(review): kornia expects a
            ``(B, C, H, W)`` tensor here — confirm callers pass 4-D inputs.

    Returns:
        A scalar tensor.
    """
    d = torch.log(pred) - torch.log(groundtruth)
    loss = torch.mean(d ** 2) - lamda * torch.mean(d) ** 2
    if grad:
        # Distinct name so the boolean flag ``grad`` is not shadowed by a tensor.
        grad_term = torch.mean(spatial_gradient(d) ** 2)
        loss = loss + grad_term
    return loss

def train(epoch, adversarial_weight=0.5):
    """Run one adversarial training epoch over ``train_loader``.

    Per batch: one discriminator step (real targets labelled 1, detached
    generator outputs labelled 0, BCE-with-logits averaged over both), then one
    generator step on ``MSE(fake, real) + adversarial_weight * BCE(D(fake), 1)``.

    Uses the module-level ``generator``, ``discriminator``, ``optimizerG``,
    ``optimizerD``, ``train_loader`` and ``device``.

    Args:
        epoch: epoch number, used only for logging.
        adversarial_weight: weight of the adversarial term in the generator
            loss (default 0.5, the previously hard-coded value).

    Returns:
        ``(content_loss, adversarial_loss, discriminator_loss)`` as Python
        floats, each averaged over the batches of the epoch.
    """
    generator_content_loss = 0.0
    generator_advarsarial_loss = 0.0
    discriminator_total_loss = 0.0
    generator.train()
    discriminator.train()
    print(f"Training Epoch {epoch}")
    for data, target in tqdm(train_loader):
        data = data.to(device).float()
        real_target = target.to(device).float()

        # One generator forward per batch: the detached copy feeds the D step,
        # the attached tensor is reused for the G step below.
        fake_target = generator(data)

        # ---- discriminator step: real -> 1, fake -> 0 ----
        optimizerD.zero_grad()
        discriminator_output_real = discriminator(real_target)
        discriminator_output_fake = discriminator(fake_target.detach())
        discriminator_loss_real = F.binary_cross_entropy_with_logits(
            discriminator_output_real, torch.ones_like(discriminator_output_real))
        discriminator_loss_fake = F.binary_cross_entropy_with_logits(
            discriminator_output_fake, torch.zeros_like(discriminator_output_fake))
        discriminator_loss = (discriminator_loss_real + discriminator_loss_fake) / 2
        discriminator_loss.backward()
        optimizerD.step()

        # ---- generator step: content loss + fool the (just-updated) D ----
        optimizerG.zero_grad()
        discriminator_output_fake = discriminator(fake_target)
        content_loss = F.mse_loss(fake_target, real_target)
        adversarial_loss = F.binary_cross_entropy_with_logits(
            discriminator_output_fake, torch.ones_like(discriminator_output_fake))
        generator_loss = content_loss + adversarial_weight * adversarial_loss
        generator_loss.backward()
        optimizerG.step()

        # .item() accumulates plain floats, so no autograd history or GPU
        # tensors are kept alive across iterations (or in the returned tuple).
        discriminator_total_loss += discriminator_loss.item()
        generator_content_loss += content_loss.item()
        generator_advarsarial_loss += adversarial_loss.item()

    # Each loss is already a per-batch mean, so average over batches.
    n_batches = max(len(train_loader), 1)
    discriminator_total_loss /= n_batches
    generator_content_loss /= n_batches
    generator_advarsarial_loss /= n_batches
    print(f"Training : Epoch {epoch} : Content Loss : {generator_content_loss}, Advarsarial Loss : {generator_advarsarial_loss} discriminator loss : {discriminator_total_loss}")
    return (generator_content_loss, generator_advarsarial_loss, discriminator_total_loss)

def validation(epoch=None):
    """Evaluate generator and discriminator losses on ``val_loader`` (no updates).

    Computes the same losses as ``train`` — discriminator BCE on real (1) /
    fake (0), generator MSE content loss and BCE adversarial loss (fake -> 1) —
    under ``torch.no_grad()`` with both networks in eval mode.

    Args:
        epoch: epoch number, used only for logging. When omitted, falls back to
            the module-level ``epoch`` variable, which is what this function
            previously always read.

    Returns:
        ``(content_loss, adversarial_loss, discriminator_loss)`` as Python
        floats, averaged over the batches of ``val_loader``.
    """
    if epoch is None:
        # Backward-compatible fallback for callers that pass no argument.
        epoch = globals().get("epoch")
    generator_content_loss = 0.0
    generator_advarsarial_loss = 0.0
    discriminator_total_loss = 0.0
    generator.eval()
    discriminator.eval()
    print(f"Validating Epoch {epoch}")
    with torch.no_grad():
        for data, target in tqdm(val_loader):
            data = data.to(device).float()
            real_target = target.to(device).float()

            fake_target = generator(data)

            # Discriminator loss (evaluation only, no optimiser step).
            discriminator_output_real = discriminator(real_target)
            discriminator_output_fake = discriminator(fake_target)
            discriminator_loss_real = F.binary_cross_entropy_with_logits(
                discriminator_output_real, torch.ones_like(discriminator_output_real))
            discriminator_loss_fake = F.binary_cross_entropy_with_logits(
                discriminator_output_fake, torch.zeros_like(discriminator_output_fake))
            discriminator_loss = (discriminator_loss_real + discriminator_loss_fake) / 2
            discriminator_total_loss += discriminator_loss.item()

            # Generator losses (evaluation only, no optimiser step).
            content_loss = F.mse_loss(fake_target, real_target)
            adversarial_loss = F.binary_cross_entropy_with_logits(
                discriminator_output_fake, torch.ones_like(discriminator_output_fake))
            generator_content_loss += content_loss.item()
            generator_advarsarial_loss += adversarial_loss.item()

    # Average over the *validation* batches (previously divided by the size of
    # the training set, which mis-scaled every validation loss).
    n_batches = max(len(val_loader), 1)
    discriminator_total_loss /= n_batches
    generator_content_loss /= n_batches
    generator_advarsarial_loss /= n_batches
    print(f"Validation : Epoch {epoch} : Content Loss : {generator_content_loss}, Advarsarial Loss : {generator_advarsarial_loss} discriminator loss : {discriminator_total_loss}")
    return (generator_content_loss, generator_advarsarial_loss, discriminator_total_loss)

def test():
    """Run the generator over ``test_loader`` and collect results for plotting.

    Prints the mean ``scaleInvLoss`` (default arguments, so including the
    gradient term) over the test batches.

    Returns:
        ``(images, predictions, ground_truths)``: lists holding one CPU tensor
        per batch, in ``test_loader`` order. ``images`` keep the dtype the
        loader produced (the ``.float()`` cast is applied only to the
        generator input).
    """
    images, predictions, ground_truths = [], [], []
    generator.eval()
    test_loss = 0.0
    print("Visualizing test results")
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for data, target in tqdm(test_loader):
            data, target = data.to(device), target.to(device)
            output = generator(data.float())
            test_loss += scaleInvLoss(output, target).item()  # one scalar per batch
            images.append(data.cpu())
            predictions.append(output.cpu())
            ground_truths.append(target.cpu())

    # scaleInvLoss yields one value per batch, so average over batches.
    test_loss /= max(len(test_loader), 1)
    print(f'test set: Average loss: {test_loss}')

    return images, predictions, ground_truths

In [None]:
epochs = 100
N_VIS = 5  # number of test samples plotted after each epoch

train_loss = []
validation_loss = []
learning_rate = []  # NOTE(review): not filled in this cell — drop if unused downstream

# torch.save does not create missing directories.
os.makedirs('models/gan', exist_ok=True)

for epoch in range(1, epochs + 1):
    loss = train(epoch)
    train_loss.append(loss)
    loss = validation()
    validation_loss.append(loss)
    model_file = 'models/gan/model_' + str(epoch) + '.pth'
    torch.save(generator.state_dict(), model_file)

    images, predictions, ground_truths = test()
    n_rows = min(N_VIS, len(images))  # avoid IndexError on small test sets
    if n_rows == 0:
        continue
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 6 * n_rows), squeeze=False)
    for ax, image, pred, ground_truth in zip(axes, images, predictions, ground_truths):
        # Prediction and ground truth are shown in the same space: the generator
        # is trained with MSE directly against the target (and scaleInvLoss logs
        # both), so exponentiating only the prediction made them incomparable.
        # NOTE(review): imshow clips RGB floats outside [0, 1] — confirm the
        # input image range.
        ax[0].imshow(image.squeeze().permute(1, 2, 0))
        ax[1].imshow(pred.squeeze())
        ax[2].imshow(ground_truth.squeeze())
    plt.show()
    plt.close(fig)