<a href="https://colab.research.google.com/github/Karko93/CIL-GAN-Project/blob/master/SRGAN_Ketzel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import os
import csv 
import torchvision
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import torch
import torch.utils.data as datatorch
import torch.nn as nn
import torch.backends.cudnn as cudnn
#from google.colab import drive
import time
import datetime
import pandas as pd
import torch.nn.functional as F
from torchvision.models import vgg19
import math
 

In [None]:
# Label table, read with every column forced to int64 and converted to a NumPy array.
# ImageDataset (defined in this file) matches column 0 against the numeric file name
# and returns column 1.
labeled_data = pd.read_csv(
    "cosmology_aux_data_170429/labeled.csv",
    dtype=np.int64,
).to_numpy()

# Score table, read with pandas' inferred dtypes and converted to a NumPy array.
scored_data = pd.read_csv(
    "cosmology_aux_data_170429/scored.csv",
).to_numpy()


In [None]:
# Directories holding the three image collections.
labeled_dir = 'cosmology_aux_data_170429/labeled'
query_dir = 'cosmology_aux_data_170429/query'
scored_dir = 'cosmology_aux_data_170429/scored'

# os.listdir() returns entries in arbitrary, platform-dependent order. Sorting
# makes the index -> file mapping of the datasets built from these lists
# reproducible from run to run and machine to machine.
labeled_files = sorted(os.listdir(labeled_dir))
scored_files = sorted(os.listdir(scored_dir))
query_files = sorted(os.listdir(query_dir))

In [None]:
class ImageDataset(Dataset):
    """Dataset yielding a (low-resolution, high-resolution) tensor pair per image file.

    Each image is resized with bicubic interpolation to ``hr_shape`` for the
    high-resolution target and to a quarter of that size (per side) for the
    low-resolution input, converted to a tensor and normalised with mean 0.5 /
    std 0.5.

    Args:
        file_array: Sequence of file names located inside ``dir``.
        dir: Directory containing the image files.
        hr_shape: ``(height, width)`` of the high-resolution target.
        mode: Stored on the instance as ``self.mode``; not otherwise read by
            this class.
        labels: Optional 2-D array. Column 0 is compared with the integer
            parsed from the file name (extension stripped); the matching
            rows' column 1 is returned with each sample.
    """

    def __init__(self, file_array, dir, hr_shape, mode='train', labels=None):
        self.files = file_array
        self.mode = mode
        self.labels = labels
        self.dir = dir
        hr_height, hr_width = hr_shape

        # Low-resolution input: a quarter of the target size along each side.
        # The width is honoured here; previously the height was used for both
        # sides, which silently squashed non-square targets.
        self.lr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height // 4, hr_width // 4), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        # High-resolution target at the requested (height, width).
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height, hr_width), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def _label_for(self, filename):
        """Return column 1 of every ``self.labels`` row whose column 0 equals the file's id.

        The id is the file name without its extension, parsed as an integer, so
        extensions of any length are handled.
        """
        image_id = int(os.path.splitext(filename)[0])
        mask = np.isin(self.labels[:, 0], image_id)
        return self.labels[mask, 1]

    def __getitem__(self, index):
        """Return ``(img_lr, img_hr)`` or, when labels were given, ``(img_lr, img_hr, label)``."""
        filename = self.files[index]
        # Image.open is lazy and keeps the file handle open; the context manager
        # closes it once both transforms have read the pixel data.
        with Image.open(os.path.join(self.dir, filename)) as img:
            img_lr = self.lr_transform(img)
            img_hr = self.hr_transform(img)

        if self.labels is not None:
            return img_lr, img_hr, self._label_for(filename)
        return img_lr, img_hr

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.files)

In [None]:
# (height, width) of the high-resolution target; the low-resolution input is a
# quarter of this per side (see ImageDataset).
hr_shape = (1000, 1000)

# Scored and labeled images carry their label tables; query images have none.
scored_dataset = ImageDataset(
    scored_files,
    scored_dir,
    hr_shape=hr_shape,
    mode='train',
    labels=scored_data,
)
labeled_dataset = ImageDataset(
    labeled_files,
    labeled_dir,
    hr_shape=hr_shape,
    mode='train',
    labels=labeled_data,
)
query_dataset = ImageDataset(
    query_files,
    query_dir,
    hr_shape=hr_shape,
    mode='test',
    labels=None,
)

In [None]:
class FeatureExtractor(nn.Module):
    """Truncated pretrained VGG19 whose activations feed the content loss.

    Only the first 18 modules of VGG19's ``features`` stack are kept.
    """

    def __init__(self):
        super().__init__()
        pretrained_vgg = vgg19(pretrained=True)
        leading_layers = list(pretrained_vgg.features.children())[:18]
        self.feature_extractor = nn.Sequential(*leading_layers)

    def forward(self, img):
        """Return the truncated-VGG feature maps for ``img``."""
        features = self.feature_extractor(img)
        return features


class ResidualBlock(nn.Module):
    """Conv -> BN -> PReLU -> Conv -> BN branch added back onto its input.

    Args:
        in_features: Channel count of the input; the block preserves it.
    """

    def __init__(self, in_features):
        super().__init__()

        def conv3x3():
            # 3x3 convolution, stride 1, padding 1: spatial size is unchanged.
            return nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1)

        def batch_norm():
            # NOTE(review): the second positional argument of BatchNorm2d is
            # `eps`, so the value 0.8 sets eps (momentum stays at its default).
            # It is kept as-is to preserve behaviour; confirm whether a
            # momentum setting was intended instead.
            return nn.BatchNorm2d(in_features, eps=0.8)

        self.conv_block = nn.Sequential(
            conv3x3(),
            batch_norm(),
            nn.PReLU(),
            conv3x3(),
            batch_norm(),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.conv_block(x)
        return x + residual


class GeneratorResNet(nn.Module):
    """SRGAN generator: residual trunk followed by two 2x pixel-shuffle stages.

    The two upsampling stages each double height and width, so the output is
    4x the spatial size of the input. The final activation is Tanh.

    Args:
        in_channels: Channels of the low-resolution input.
        out_channels: Channels of the generated image.
        n_residual_blocks: Number of ResidualBlock modules in the trunk.
    """

    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super().__init__()

        # Entry layer: 9x9 convolution to 64 feature maps.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
        )

        # Residual trunk.
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(n_residual_blocks)]
        )

        # Convolution applied after the trunk, before the long skip connection.
        # NOTE(review): 0.8 is BatchNorm2d's second positional argument, i.e.
        # `eps`; kept unchanged to preserve behaviour.
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
        )

        # Two identical upsampling stages; PixelShuffle(2) folds the 256
        # channels back to 64 while doubling each spatial side.
        upsample_stages = []
        for _ in range(2):
            upsample_stages.extend(
                [
                    nn.Conv2d(64, 256, 3, 1, 1),
                    nn.BatchNorm2d(256),
                    nn.PixelShuffle(upscale_factor=2),
                    nn.PReLU(),
                ]
            )
        self.upsampling = nn.Sequential(*upsample_stages)

        # Output layer: 9x9 convolution to the image channels, squashed by Tanh.
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return the upscaled image generated from the low-resolution batch ``x``."""
        shallow = self.conv1(x)
        deep = self.conv2(self.res_blocks(shallow))
        # Long skip connection around the whole residual trunk.
        merged = torch.add(shallow, deep)
        upscaled = self.upsampling(merged)
        return self.conv3(upscaled)


class Discriminator(nn.Module):
    """Patch discriminator: four down-sampling blocks followed by a 1-channel conv.

    Args:
        input_shape: ``(channels, height, width)`` of the images to classify.

    Attributes:
        input_shape: The shape passed in.
        output_shape: ``(1, patch_h, patch_w)``, the per-sample shape of the
            tensor returned by ``forward`` for inputs of ``input_shape``.
    """

    def __init__(self, input_shape):
        super().__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        block_filters = [64, 128, 256, 512]

        # Every block contains one stride-2 3x3 convolution with padding 1,
        # which maps a side of length n to ceil(n / 2). Applying that once per
        # block gives the exact output size; the former `int(n / 16) + 1` was
        # one too large whenever n was a multiple of 16.
        patch_h, patch_w = in_height, in_width
        for _ in block_filters:
            patch_h, patch_w = (patch_h + 1) // 2, (patch_w + 1) // 2
        self.output_shape = (1, patch_h, patch_w)

        def discriminator_block(in_filters, out_filters, first_block=False):
            """Return the layers of one block: stride-1 conv then stride-2 conv.

            The first block skips batch normalisation after its first convolution.
            """
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate(block_filters):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        # Final projection to a single-channel patch map.
        layers.append(nn.Conv2d(in_filters, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the patch map of shape ``(batch, *self.output_shape)`` for ``img``."""
        return self.model(img)

In [None]:
# Create the output folders for sample images and model checkpoints; existing
# folders are left untouched.
for _output_dir in ("saved_images", "saved_models"):
    os.makedirs(_output_dir, exist_ok=True)

In [None]:
# --- Training configuration ---

# NOTE(review): the optimizers in this file are created with a hard-coded
# lr=0.0002 and do not read this value.
learning_rate = 1e-4
batch_size = 1

# Epoch range: training runs over range(start_epoch, epochs).
start_epoch = 0
epochs = 4

# A checkpoint is written when `epoch % checkpoint_interval == 0`.
checkpoint_interval = 1

# Derived from the batch size: 10000 samples expressed in batches.
logstep = 10000 // batch_size


In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Networks: generator, discriminator (sized for 3-channel images of hr_shape)
# and the VGG feature extractor used by the content loss.
generator = GeneratorResNet(in_channels=3, out_channels=3, n_residual_blocks=16).to(device)
discriminator = Discriminator(input_shape=(3,) + tuple(hr_shape)).to(device)
feature_extractor = FeatureExtractor().to(device)

# The feature extractor is only used for inference.
feature_extractor.eval()

# Loss functions: MSE for the adversarial term, L1 for the content term.
criterion_GAN = nn.MSELoss()
criterion_content = nn.L1Loss()

# One Adam optimizer per network.
# NOTE(review): lr is hard-coded to 0.0002 here, while the `learning_rate`
# setting defined earlier in the file is 0.0001 and is not used.
optimizer_G = torch.optim.Adam(
    generator.parameters(),
    lr=0.0002,
    betas=(0.5, 0.999),
)
optimizer_D = torch.optim.Adam(
    discriminator.parameters(),
    lr=0.0002,
    betas=(0.5, 0.999),
)

# Data loaders iterate the datasets in order (no shuffling).
labeled_Dataloader = datatorch.DataLoader(
    dataset=labeled_dataset,
    batch_size=batch_size,
    shuffle=False,
)
scored_Dataloader = datatorch.DataLoader(
    dataset=scored_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Put the VGG feature extractor into inference mode. This repeats the call
# already made in the setup cell above; eval() returns the module itself.
feature_extractor.eval()

In [None]:
if start_epoch != 0:
    # Resume from saved weights.
    #
    # Checkpoints are written as "<name>_<epoch>.pth" once an epoch has
    # finished, so when training restarts at `start_epoch` the checkpoint to
    # load is the one of the previous epoch. The paths previously contained an
    # unformatted "%d", so the literal file "generator_%d.pth" was requested.
    resume_epoch = start_epoch - 1
    # map_location lets checkpoints saved on a GPU be loaded on a CPU-only machine.
    generator.load_state_dict(
        torch.load("saved_models/generator_%d.pth" % resume_epoch, map_location=device)
    )
    discriminator.load_state_dict(
        torch.load("saved_models/discriminator_%d.pth" % resume_epoch, map_location=device)
    )

In [None]:

# One training pass of the SRGAN over the scored images.
#
# NOTE(review): this cell used `epoch` without an enclosing epoch loop. It is
# pinned to `start_epoch` here so that the batch counter and the checkpoint
# file names are well defined for this single pass.
epoch = start_epoch
sample_interval = 100  # number of batches between saved comparison grids

for idx, (lr_img, hr_img, _label) in enumerate(scored_Dataloader):
    # Stack three copies along the channel axis (dim 1) so the batch matches
    # the 3-channel inputs of the generator, discriminator and VGG, then move
    # it to the selected device (works on CPU as well as GPU).
    lr_img = torch.cat([lr_img, lr_img, lr_img], 1).to(device)
    hr_img = torch.cat([hr_img, hr_img, hr_img], 1).to(device)

    # Adversarial targets: one value per discriminator output patch.
    valid = torch.ones((lr_img.size(0), *discriminator.output_shape), device=device)
    fake = torch.zeros((hr_img.size(0), *discriminator.output_shape), device=device)

    # ------------------
    #  Train Generator
    # ------------------

    optimizer_G.zero_grad()

    # Generate a high resolution image from the low resolution input
    gen_hr = generator(lr_img)

    # Adversarial loss: the generator wants its output classified as valid
    loss_GAN = criterion_GAN(discriminator(gen_hr), valid)

    # Content loss: distance between VGG features of generated and real images
    gen_features = feature_extractor(gen_hr)
    real_features = feature_extractor(hr_img)
    loss_content = criterion_content(gen_features, real_features.detach())

    # Total generator loss
    loss_G = loss_content + 1e-3 * loss_GAN

    loss_G.backward()
    optimizer_G.step()

    # ---------------------
    #  Train Discriminator
    # ---------------------

    optimizer_D.zero_grad()

    # Real images should score as valid; generated ones (detached so no
    # gradient reaches the generator) as fake.
    loss_real = criterion_GAN(discriminator(hr_img), valid)
    loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)

    # Total discriminator loss
    loss_D = (loss_real + loss_fake) / 2

    loss_D.backward()
    optimizer_D.step()

    print(idx, len(scored_Dataloader), loss_D.item(), loss_G.item())

    batches_done = epoch * len(scored_Dataloader) + idx
    if batches_done % sample_interval == 0:
        # Save the upsampled input next to the SRGAN output. The grid goes to
        # "saved_images", the folder this notebook creates.
        upsampled_lr = nn.functional.interpolate(lr_img, scale_factor=4)
        gen_grid = torchvision.utils.make_grid(gen_hr.detach(), nrow=1, normalize=True)
        lr_grid = torchvision.utils.make_grid(upsampled_lr, nrow=1, normalize=True)
        img_grid = torch.cat((lr_grid, gen_grid), -1)
        torchvision.utils.save_image(img_grid, "saved_images/%d.png" % batches_done, normalize=False)

if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
    # Save model checkpoints
    torch.save(generator.state_dict(), "saved_models/generator_%d.pth" % epoch)
    torch.save(discriminator.state_dict(), "saved_models/discriminator_%d.pth" % epoch)
    

# Full SRGAN training over the labeled images, from `start_epoch` to `epochs`.
#
# NOTE(review): the original loop contained two stray `break` statements that
# made the whole training body unreachable, and the code behind them used
# names that are not defined in this file (`lr`, `hr`, `imgs_lr`, `imgs_hr`,
# `i`, `opt`, `sys`). The breaks are removed and the names aligned with the
# loop variables so the loop trains for real.
sample_interval = 100  # number of batches between saved comparison grids

for epoch in range(start_epoch, epochs):

    for idx, (lr_img, hr_img, _label) in enumerate(labeled_Dataloader):
        # Stack three copies along the channel axis (dim 1) so the batch
        # matches the 3-channel networks, then move it to the selected device.
        lr_img = torch.cat([lr_img, lr_img, lr_img], 1).to(device)
        hr_img = torch.cat([hr_img, hr_img, hr_img], 1).to(device)

        # Adversarial ground truths, created on the same device as the images
        valid = torch.ones((lr_img.size(0), *discriminator.output_shape), device=device)
        fake = torch.zeros((hr_img.size(0), *discriminator.output_shape), device=device)

        # ------------------
        #  Train Generator
        # ------------------

        optimizer_G.zero_grad()

        # Generate a high resolution image from the low resolution input
        gen_hr = generator(lr_img)

        # Adversarial loss
        loss_GAN = criterion_GAN(discriminator(gen_hr), valid)

        # Content loss
        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(hr_img)
        loss_content = criterion_content(gen_features, real_features.detach())

        # Total loss
        loss_G = loss_content + 1e-3 * loss_GAN

        loss_G.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss of real and fake images (generated batch detached so no
        # gradient reaches the generator)
        loss_real = criterion_GAN(discriminator(hr_img), valid)
        loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)

        # Total loss
        loss_D = (loss_real + loss_fake) / 2

        loss_D.backward()
        optimizer_D.step()

        # --------------
        #  Log Progress
        # --------------

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, epochs, idx, len(labeled_Dataloader), loss_D.item(), loss_G.item())
        )

        batches_done = epoch * len(labeled_Dataloader) + idx
        if batches_done % sample_interval == 0:
            # Save the upsampled input next to the SRGAN output. The grid goes
            # to "saved_images", the folder this notebook creates.
            upsampled_lr = nn.functional.interpolate(lr_img, scale_factor=4)
            gen_grid = torchvision.utils.make_grid(gen_hr.detach(), nrow=1, normalize=True)
            lr_grid = torchvision.utils.make_grid(upsampled_lr, nrow=1, normalize=True)
            img_grid = torch.cat((lr_grid, gen_grid), -1)
            torchvision.utils.save_image(img_grid, "saved_images/%d.png" % batches_done, normalize=False)

    if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
        # Save model checkpoints
        torch.save(generator.state_dict(), "saved_models/generator_%d.pth" % epoch)
        torch.save(discriminator.state_dict(), "saved_models/discriminator_%d.pth" % epoch)

# Scratch check: concatenating three copies of a 1-channel batch along dim 1
# yields a 3-channel batch of the same spatial size.
xx = torch.rand(1, 1, 20, 20)
y = torch.cat((xx,) * 3, dim=1)
print(xx.size(), y.size())