In [None]:
from google.colab import drive

# All data/checkpoint paths below live under this mount point
# (e.g. /content/drive/MyDrive/CSC413/...).
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
import torch 
import torchvision
import pickle
import numpy as np 
import shutil
import matplotlib.pyplot as plt

import torch.optim
from scipy.io.idl import AttrDict

from PIL import Image


# IMAGE UNZIP

In [None]:
import zipfile
import os
from os import walk

extract_x64 = False 

if extract_x64: 
  _, _, filenames = next(walk("."))

  for filename in filenames: 
    print(filename)
    with zipfile.ZipFile(filename, 'r') as zip_ref:
      zip_ref.extractall(".")

# ARCHITECTURE

In [None]:
class generator (torch.nn.Module):
  """SRGAN-style generator that upsamples an RGB image by a factor of 4.

  Pipeline:
    * 9x9 conv + PReLU;
    * a stack of ``ResidualBlock`` modules;
    * 3x3 conv + BatchNorm, plus a long skip connection from the first feature map;
    * two (conv -> PixelShuffle(2) -> PReLU) upsampling stages;
    * a final 9x9 conv back to 3 channels.

  Args:
    num_residual_blocks: number of residual blocks to stack (default 16,
      the value the notebook was written with).
  """

  def __init__(self, num_residual_blocks=16):
    super(generator, self).__init__()
    self.Conv1 = torch.nn.Conv2d(3, 64, 9, padding=4)

    # Built in one go. Submodule names ("residual_block.0", ...) and creation
    # order match the old incremental construction, so saved state_dicts load.
    self.residual_block = torch.nn.Sequential(
        *[ResidualBlock() for _ in range(num_residual_blocks)])

    self.Conv2 = torch.nn.Conv2d(64, 64, 3, padding=1)
    # Each upsampling conv emits 4x the channels; PixelShuffle(2) folds them
    # back to 64 channels at twice the spatial resolution.
    self.Conv3 = torch.nn.Conv2d(64, 256, 3, padding=1)
    self.Conv4 = torch.nn.Conv2d(64, 256, 3, padding=1)
    self.ConvFinal = torch.nn.Conv2d(64, 3, 9, padding=4)
    self.PixelShuffle = torch.nn.PixelShuffle(upscale_factor=2)
    # NOTE: one PReLU (a single learnable slope) is shared by every activation
    # in forward(); kept so existing checkpoints keep the same keys.
    self.PreLU = torch.nn.PReLU()
    self.BatchNorm = torch.nn.BatchNorm2d(num_features=64)


  def forward(self, x):
    """Super-resolve ``x``.

    Args:
      x: float tensor of shape (N, 3, H, W).

    Returns:
      Tensor of shape (N, 3, 4H, 4W). No output activation is applied.
    """
    x = self.PreLU(self.Conv1(x))
    skip = x

    # Bug fix: the residual stack's output used to be discarded, so the
    # blocks neither influenced the result nor received gradients.
    x = self.residual_block(x)

    x = self.BatchNorm(self.Conv2(x))
    x = x + skip

    x = self.PreLU(self.PixelShuffle(self.Conv3(x)))
    x = self.PreLU(self.PixelShuffle(self.Conv4(x)))

    return self.ConvFinal(x)



class ResidualBlock (torch.nn.Module):
  """Conv-BN-PReLU-Conv-BN body (64 -> 64 channels) with an identity shortcut."""

  def __init__(self):
    super(ResidualBlock, self).__init__()
    self.block = torch.nn.Sequential(
        torch.nn.Conv2d(64, 64, 3, padding=1),
        torch.nn.BatchNorm2d(64),
        torch.nn.PReLU(),
        torch.nn.Conv2d(64, 64, 3, padding=1),
        torch.nn.BatchNorm2d(64),
    )

  def forward(self, x):
    """Return ``block(x) + x``; padding=1 keeps the spatial size unchanged."""
    return self.block(x) + x



In [None]:
class discriminator (torch.nn.Module):
  """Real/fake classifier for 64x64 RGB images (SRGAN-style).

  Eight unpadded 3x3 convs (every second one with stride 2) widen the
  features 64 -> 128 -> 256 -> 512 while shrinking a 64x64 input to 1x1.
  A 512 -> 1024 -> 1 MLP follows. ``forward`` returns probabilities
  (sigmoid already applied) of shape (N, 1).
  """

  def __init__(self):
    super(discriminator, self).__init__()
    # Convolutional trunk. Creation order is unchanged, so parameter
    # initialisation and state_dict key order stay the same.
    self.Conv1 = torch.nn.Conv2d(3, 64, kernel_size=3)
    self.Conv2 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=2)
    self.BatchNorm2 = torch.nn.BatchNorm2d(64)
    self.Conv3 = torch.nn.Conv2d(64, 128, kernel_size=3)
    self.BatchNorm3 = torch.nn.BatchNorm2d(128)
    self.Conv4 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=2)
    self.BatchNorm4 = torch.nn.BatchNorm2d(128)
    self.Conv5 = torch.nn.Conv2d(128, 256, kernel_size=3)
    self.BatchNorm5 = torch.nn.BatchNorm2d(256)
    self.Conv6 = torch.nn.Conv2d(256, 256, kernel_size=3, stride=2)
    self.BatchNorm6 = torch.nn.BatchNorm2d(256)
    self.Conv7 = torch.nn.Conv2d(256, 512, kernel_size=3)
    self.BatchNorm7 = torch.nn.BatchNorm2d(512)
    self.Conv8 = torch.nn.Conv2d(512, 512, kernel_size=3, stride=2)
    self.BatchNorm8 = torch.nn.BatchNorm2d(512)

    # Classifier head. in_features=512 relies on the trunk ending at 1x1,
    # which holds for 64x64 inputs.
    self.hidden = torch.nn.Linear(512, 1024)
    self.final_classifier = torch.nn.Linear(1024, 1)
    self.leakyReLU = torch.nn.LeakyReLU(0.2)
    self.dropout = torch.nn.Dropout(0.3)


  def forward(self, x):
    """Return the probability that each image in the batch is real, shape (N, 1)."""
    feats = self.leakyReLU(self.Conv1(x))

    # (conv, norm, apply dropout afterwards?)
    stages = (
        (self.Conv2, self.BatchNorm2, False),
        (self.Conv3, self.BatchNorm3, True),
        (self.Conv4, self.BatchNorm4, False),
        (self.Conv5, self.BatchNorm5, True),
        (self.Conv6, self.BatchNorm6, False),
        (self.Conv7, self.BatchNorm7, False),
        (self.Conv8, self.BatchNorm8, False),
    )
    for conv, norm, drop_after in stages:
      feats = self.leakyReLU(norm(conv(feats)))
      if drop_after:
        feats = self.dropout(feats)

    flat = feats.reshape((feats.shape[0], -1))
    hidden = self.leakyReLU(self.hidden(self.dropout(flat)))
    return torch.sigmoid(self.final_classifier(hidden))

  

In [None]:
class vgg_loss ():
  """Frozen feature extractor: the first 27 layers of torchvision's VGG19 ``features``.

  Calling the instance returns those layers' activations for the input. The
  training loop compares them with an MSE for the perceptual loss.
  NOTE(review): inputs are passed through without ImageNet mean/std
  normalisation, which torchvision's pretrained VGG weights expect; confirm
  this is intended.
  """

  def __init__(self, device):
    super(vgg_loss, self).__init__()
    features = torchvision.models.vgg19(pretrained=True).features

    # Freeze every VGG weight. Gradients still flow back to the *input*,
    # which is what lets the generator learn from this loss.
    features.requires_grad_(False)

    self.vgg = features[:27].to(device)
    self.MSE = torch.nn.MSELoss()

  def __call__(self, input ):
    return self.vgg(input)




# DATA LOADERS



In [None]:
def unpickle(file):
    """Load and return the object stored in the pickle file at ``file``.

    SECURITY: ``pickle.load`` can execute arbitrary code. Only use this on
    trusted files, such as the dataset batches this notebook downloads.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle)

In [None]:
def get_images_x64(filename):
  """Load a 64x64 dataset batch file and return its images as an (N, 64, 64, 3) array.

  Each row of ``d['data']`` holds three 4096-value planes, one per channel,
  stored one after another. They are regrouped so the channel becomes the
  last axis.
  """
  rows = unpickle(filename)['data']
  count = rows.shape[0]
  # (N, 3, 64, 64) planar view -> (N, 64, 64, 3), materialised as a fresh
  # contiguous array.
  planar = rows.reshape((count, 3, 64, 64))
  return np.ascontiguousarray(planar.transpose(0, 2, 3, 1))

In [None]:
data_dir = '/content/drive/MyDrive/CSC413/x64/'

class ImageNetSR(torch.utils.data.Dataset):
    """Super-resolution pairs built from the 64x64 dataset batch files.

    Each item is ``{'HR': image, 'LR': image bicubically resized to 16x16}``,
    where ``image`` is the ``ToTensor`` conversion of one stored 64x64 image.
    ``train=True`` reads ``data_dir + batch``; otherwise ``data_dir + "val_data"``.
    """

    def __init__(self, train=False,  batch=""):
        self.train = train
        self.transform = torchvision.transforms.ToTensor()
        self.downsample = torchvision.transforms.Resize(
            (16, 16),
            interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
        )

        source = batch if train else "val_data"
        self.dir = data_dir + source

        # The whole batch file is loaded into memory up front.
        self.X = get_images_x64(self.dir)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        high_res = self.transform(self.X[idx])
        return {'HR': high_res, 'LR': self.downsample(high_res)}

# TRAINING CODE

## Helper Functions

In [None]:
def gan_checkpoint(path, G, D):
    """Save the generator and discriminator weights into directory ``path``.

    Writes ``G.pkl`` and then ``D.pkl``, each a ``state_dict`` saved with
    ``torch.save``. The directory is not created here; it must already exist.
    """
    for model, file_name in ((G, 'G.pkl'), (D, 'D.pkl')):
        torch.save(model.state_dict(), os.path.join(path, file_name))

def load_checkpoint(opts):
    """Rebuild the generator and discriminator from a checkpoint directory.

    Args:
        opts: attribute-style options providing ``load`` (a directory holding
            ``G.pkl`` / ``D.pkl`` as written by ``gan_checkpoint``) and
            ``device``.

    Returns:
        ``(G, D)``, both moved to ``opts.device``. Previously only G was
        moved, which left D on the CPU.
    """
    G = generator()
    D = discriminator()

    # Load onto the CPU first so GPU-saved checkpoints also open on
    # CPU-only machines; the models are moved to the target device below.
    for model, file_name in ((G, 'G.pkl'), (D, 'D.pkl')):
        state = torch.load(os.path.join(opts.load, file_name),
                           map_location=lambda storage, loc: storage)
        model.load_state_dict(state)

    G.to(opts.device)
    D.to(opts.device)

    return G, D

In [None]:
"""
CREDITS: CSC413 PA4 DCGAN
"""

def train(opts):
    """Runs the SRGAN training loop.

        * Cycles through train_data_batch_1 ... train_data_batch_10, loading
          the next file whenever the current one is exhausted
        * Logs D/G losses every opts.log_step iterations
        * Saves a checkpoint (G.pkl, D.pkl) and a sample figure every
          opts.checkpoint_every iterations into opts.checkpoint + str(iteration)

    Returns:
        (G, D). Also returned when training is interrupted with Ctrl-C.
    """
    # Create generator and discriminator
    G = generator().to(opts.device)
    D = discriminator().to(opts.device)

    betas = (opts.beta1, opts.beta2)
    g_optimizer = torch.optim.Adam(G.parameters(), opts.lr, betas)
    d_optimizer = torch.optim.Adam(D.parameters(), opts.lr, betas)

    print("loading training ... ")
    batch = 1
    train_iter = iter(_make_train_loader(opts, batch))

    total_train_iters = opts.train_iters

    # Loss history; recorded every opts.log_step iterations.
    losses = {'iteration': [], 'D_loss': [], 'G_loss': []}

    VGG = vgg_loss(opts.device)
    MSE = torch.nn.MSELoss()

    # Fixed image used to visualise progress at every checkpoint.
    data_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.ToTensor()
    ])
    image_name = "/content/drive/MyDrive/CSC413/LR_Test.jpg"
    # convert('RGB'): the generator's first conv expects exactly 3 channels.
    og_image = Image.open(image_name).convert('RGB')
    image = data_transforms(og_image).float().unsqueeze(0).to(opts.device)

    try:
        for iteration in range(1, total_train_iters + 1):
            G.train()

            # Advance to the next batch file when the current loader runs
            # dry (1 -> 2 -> ... -> 10 -> 1). This replaces a fixed-period
            # reset, which dropped the last mini-batch and could raise
            # StopIteration when files differ in size.
            try:
                sample = next(train_iter)
            except StopIteration:
                batch = batch % 10 + 1
                print("loading train_data_batch_" + str(batch))
                train_iter = iter(_make_train_loader(opts, batch))
                sample = next(train_iter)

            HR = sample['HR'].to(opts.device)
            LR = sample['LR'].to(opts.device)

            # One generator forward pass per iteration, shared by both steps.
            HRE = G(LR)

            # ------------------------ DISCRIMINATOR LOSS ---------------------
            d_optimizer.zero_grad()
            D_1 = torch.mean(D(HR))
            # detach(): the D step must not backprop into (or keep) G's graph.
            D_2 = torch.mean(D(HRE.detach()))
            D_loss = 1 - D_1 + D_2

            # D is only updated while its loss is above 0.1.
            if D_loss.item() > 0.1:
                D_loss.backward()
                d_optimizer.step()

            # ------------------------ GENERATOR LOSS -----------------------
            g_optimizer.zero_grad()
            logits_HRE = D(HRE)

            G_1 = torch.mean(1 - logits_HRE)
            G_VGG = MSE(VGG(HRE), VGG(HR))
            G_MSE = MSE(HRE, HR)
            G_loss = G_MSE + 0.006 * G_VGG + 0.001 * G_1
            G_loss.backward()
            g_optimizer.step()

            # ------------------------  CHECKPOINTS & SAMPLING -------------------
            if iteration % opts.log_step == 0:
                losses['iteration'].append(iteration)
                losses['D_loss'].append(D_loss.item())
                losses['G_loss'].append(G_loss.item())
                print('Iteration [{:4d}/{:4d}] | D_loss: {:6.4f} |  G_loss: {:6.4f}'.format(
                    iteration, total_train_iters, D_loss.item(), G_loss.item()))

            if iteration % opts.checkpoint_every == 0:
                ckpt_dir = opts.checkpoint + str(iteration)
                if os.path.exists(ckpt_dir):
                    shutil.rmtree(ckpt_dir)
                # makedirs also creates opts.checkpoint itself if missing.
                os.makedirs(ckpt_dir)
                gan_checkpoint(ckpt_dir, G, D)
                _save_sample(G, image, og_image, ckpt_dir)

    except KeyboardInterrupt:
        print('Exiting early from training.')

    return G, D


def _make_train_loader(opts, batch):
    """Return a shuffled DataLoader over train_data_batch_<batch>."""
    train_set = ImageNetSR(train=True, batch="train_data_batch_" + str(batch))
    return torch.utils.data.DataLoader(train_set, opts.batch_size,
                                       shuffle=True, num_workers=opts.num_workers)


def _save_sample(G, image, og_image, ckpt_dir):
    """Run G on ``image`` and save ``og_image`` beside the result as ckpt_dir/sample.png."""
    was_training = G.training
    G.eval()
    # no_grad: this is inference only; don't build an autograd graph.
    with torch.no_grad():
        image_est = G(image).squeeze(0).cpu().numpy()
    G.train(was_training)

    # CHW -> HWC for matplotlib. Clipping matches what imshow would do anyway,
    # without its out-of-range warning.
    image_est = np.clip(np.transpose(image_est, (1, 2, 0)), 0.0, 1.0)

    f, axarr = plt.subplots(1, 2)
    axarr[0].imshow(og_image)
    axarr[1].imshow(image_est)
    f.savefig(os.path.join(ckpt_dir, 'sample.png'), bbox_inches='tight')
    # One figure per checkpoint; pyplot keeps them alive unless closed.
    plt.close(f)

## Hyperparameters and running loop

In [None]:
opts = AttrDict()

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Seed torch's RNG (weight init, DataLoader shuffling, dropout) for reproducible runs.
SEED = 0
torch.manual_seed(SEED)

# Anomaly detection records extra information for every autograd op and slows
# training substantially. Enable it only while debugging NaNs / in-place errors.
DETECT_ANOMALY = False

args_dict = {
    'checkpoint': "/content/drive/MyDrive/CSC413/checkpoints/",
    'lr': 0.0001,
    'beta1':0.9,
    'beta2':0.999,
    'batch_size': 32,
    'device': device,
    'epochs': 12,          # not read by train() as written
    'num_workers': 4,
    'resume': False,       # not read by train() as written
    'log_step': 100,
    'checkpoint_every':1000,
    'train_iters':200000,
    'batch' : 'train_data_batch_1'   # not read by train(); it starts at batch 1
}

torch.autograd.set_detect_anomaly(DETECT_ANOMALY)
opts.update(args_dict)
train(opts)

# TEST IMAGE CODE

In [None]:
from torchvision.utils import save_image
from tifffile import imsave
import cv2 




def torch_to_saveable_image(image_est):
  """Convert a CHW float RGB tensor into an HWC uint8 BGR array for ``cv2.imwrite``.

  Values are clipped to [0, 1] and scaled to 0..255. Previously the image
  was min-max stretched, which rescaled every image independently. That
  changed the brightness/contrast of the SR output and the bicubic baseline
  by different amounts, so the saved pair was not a fair comparison.

  Args:
    image_est: tensor of shape (3, H, W), nominally in [0, 1], RGB order
      (assumed; values outside the range are clipped).

  Returns:
    np.ndarray of shape (H, W, 3), dtype uint8, channels in BGR order.
  """
  rgb = image_est.detach().cpu().numpy().transpose(1, 2, 0)
  rgb = np.clip(rgb, 0.0, 1.0)
  rgb_u8 = np.round(rgb * 255.0).astype(np.uint8)
  # Reversing the channel axis is the RGB -> BGR swap cv2 expects; make it
  # contiguous because cv2 does not accept negative-stride views.
  return np.ascontiguousarray(rgb_u8[..., ::-1])


use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

load_opts = AttrDict()
args_dict = {
    'load': '/content/drive/MyDrive/CSC413/checkpoints/17000',
    'device': device
}
load_opts.update(args_dict)

G, _ = load_checkpoint(load_opts)
# Inference: make BatchNorm use its running statistics rather than the
# statistics of this single image.
G.eval()

LR_SIZE = 64   # side length fed to the generator
SCALE = 4      # the generator upsamples with two PixelShuffle(2) stages

data_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((LR_SIZE, LR_SIZE))
])

filename = "obama"
filetype = ".jpg"
# NOTE(review): weights come from checkpoints/17000, but the image is read from
# (and results written to) checkpoints/172000 — confirm this is intended.
directory = "/content/drive/MyDrive/CSC413/checkpoints/172000/"

image_name = directory + filename + filetype
# convert('RGB'): the generator expects exactly 3 input channels.
og_image = Image.open(image_name).convert('RGB')
image = data_transforms(og_image).float().unsqueeze(0).to(device)

with torch.no_grad():
  image_est = G(image).squeeze(0)
image_est = torch_to_saveable_image(image_est)

# Baseline for comparison. Resize defaults to bilinear interpolation, so
# request bicubic explicitly.
bicubic_resize = torchvision.transforms.Resize(
    (LR_SIZE * SCALE, LR_SIZE * SCALE),
    interpolation=torchvision.transforms.InterpolationMode.BICUBIC)
bicubic = torch_to_saveable_image(bicubic_resize(image.squeeze(0)))

# torch_to_saveable_image returns BGR for cv2; flip back to RGB for matplotlib.
plt.imshow(image_est[..., ::-1])
cv2.imwrite(directory + filename + "_SRGAN.png", image_est)
cv2.imwrite(directory + filename + "_bicubic.png", bicubic)

