In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.utils.data as Data
import torchvision.transforms as T
from glob import glob
import os
import cv2
import torch.nn.functional as F
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
from torchvision.utils import save_image
import pandas as pd
import time
import pickle
%matplotlib inline

# Tensor <-> image helpers shared by the cells below.
to_img = T.Compose([T.ToPILImage()])      # float CHW tensor in [0, 1] -> PIL image
to_tensor = T.Compose([T.ToTensor()])     # PIL image -> float CHW tensor in [0, 1]
load_norm = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # per channel: [0, 1] -> [-1, 1]
])

# Prefer the GPU whenever one is visible to PyTorch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class Parser():
    """Plain container for the hyper-parameters and paths used by this notebook.

    Despite the name this is not an argparse parser: every value is fixed here
    and read through the global ``args`` instance. Several fields (``z_dim``,
    ``y_dim``, ``conv_dim``, ``d_dim``, ``D_out_dim``, the ``*_mode`` strings,
    ``k``, ``lam``, ``gamma``) are not read by the code shown in this notebook.
    """

    def __init__(self):
        # --- optimisation ---
        self.n_epoch = 50
        self.batch_size = 64
        self.D_lr = 0.0004
        self.G_lr = 0.0001      # 0.0002 was noted as an alternative
        self.b1 = 0.5           # Adam beta1
        self.b2 = 0.999         # Adam beta2

        # --- image geometry / architecture ---
        self.img_size = 64      # side length of the high-resolution image
        self.ds_rate = 4        # low-resolution side = img_size // ds_rate
        self.n_ch = 3
        self.n_dim = 64
        self.num_res = 1        # ResBlocks per generator stage
        self.z_dim = 64
        self.y_dim = 2
        self.conv_dim = 64
        self.d_dim = 64
        self.D_out_dim = 16

        # --- logging / checkpoint frequencies (in training steps) ---
        self.show_freq = 50
        self.img_save_freq = 100
        self.model_save_freq = 1000

        # --- paths and naming ---
        self.train_img_path = "./data/celeba/"
        self.model_path = './SuperResolution3/Model/'
        self.img_path = './SuperResolution3/Image/'
        # NOTE: the spelling 'SuperResolutio' is part of saved file names and of
        # the checkpoint dict key; renaming it breaks loading older checkpoints.
        self.model_name = 'SuperResolutio'

        # --- miscellaneous (not read by the code shown here) ---
        self.D_mode = ''
        self.G_mode = ''
        self.L_mode = ''
        self.k = 0
        self.lam = 0.001
        self.gamma = 0.5
        
args = Parser()

# Create the output folders. exist_ok=True replaces the exists()/makedirs()
# pair, which can race with another process creating the same directory, and
# keeps this cell safe to re-run.
for _out_dir in (args.model_path, args.img_path):
    os.makedirs(_out_dir, exist_ok=True)
    
    
class CelebADataset(Data.Dataset):
    """CelebA super-resolution pairs: (full-size image, downsampled image).

    File names and per-image male flags are read from a pickle holding a
    ``(paths, male_flags)`` tuple. The last ``n_holdout`` entries form the
    held-out split; everything before them is the training split.

    Args:
        mode: ``'train'`` selects the training split; any other value selects
            the held-out split.
        args: hyper-parameter object providing ``img_size``, ``ds_rate`` and
            ``train_img_path``. Required.
        pickle_path: path of the ``(paths, male_flags)`` pickle.
            NOTE: ``pickle.load`` can execute arbitrary code -- only load
            trusted files.
        n_holdout: number of trailing entries reserved for the held-out split.

    Each item is ``(o_img, d_img)``: the image resized to
    ``img_size x img_size`` and to ``(img_size // ds_rate)`` square, both
    normalised to [-1, 1].
    """

    def __init__(self, mode='train', args=None, pickle_path="Path_Male.pickle", n_holdout=2000):
        if args is None:
            raise ValueError("CelebADataset needs `args` (img_size, ds_rate, train_img_path)")
        # Keep our own reference: __getitem__ used to read the *global* `args`,
        # silently ignoring the object passed in here.
        self.args = args
        self.mode = mode

        self.image_transform = T.Compose([
            T.Resize((args.img_size, args.img_size)),
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        self.down_sample_t = T.Compose([
            T.Resize((args.img_size // args.ds_rate, args.img_size // args.ds_rate)),
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        self.flip = T.Compose([
            T.RandomHorizontalFlip(),
        ])
        # Random flipping is augmentation: apply it to the training split only
        # so held-out samples are deterministic.
        self.augment = (mode == 'train')

        with open(pickle_path, 'rb') as f:
            paths, male = pickle.load(f)
            print('Loaded')

        def _split(seq):
            # Same result as seq[:-n] / seq[-n:] for n > 0, but also correct
            # for n_holdout == 0, where seq[:-0] would be empty.
            cut = max(len(seq) - n_holdout, 0)
            return seq[:cut] if mode == 'train' else seq[cut:]

        self.path = _split(paths)
        self.male = _split(male)

    def __getitem__(self, index):
        idx = index % len(self.path)
        img_file = os.path.join(self.args.train_img_path, self.path[idx])
        # Force 3-channel RGB (the Normalize transforms assume 3 channels) and
        # release the file handle as soon as the pixels are decoded.
        with Image.open(img_file) as raw:
            img = raw.convert('RGB')
        if self.augment:
            # Flip once, before both resizes, so the two resolutions stay aligned.
            img = self.flip(img)

        o_img = self.image_transform(img)
        d_img = self.down_sample_t(img)

        # is_male = self.male[idx]
        return o_img, d_img

    def __len__(self):
        return len(self.path)
    

#

In [None]:
def l2normalize(v, eps=1e-12):
    """Return ``v`` divided by its L2 norm; ``eps`` avoids division by zero."""
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    """Spectral normalisation (Miyato et al., 2018) of one weight of ``model``.

    The wrapped module's parameter ``<name>`` is replaced by three parameters
    registered on ``model``:

    * ``<name>_bar`` -- the trainable, un-normalised weight;
    * ``<name>_u`` / ``<name>_v`` -- non-trainable power-iteration vectors for
      the weight viewed as a ``(shape[0], -1)`` matrix.

    Every forward pass advances the power iteration by ``power_iteration``
    steps, estimates the largest singular value ``sigma`` and sets
    ``model.<name> = <name>_bar / sigma`` before calling ``model``.
    """

    def __init__(self, model, name='weight', power_iteration=1):
        super(SpectralNorm, self).__init__()
        self.model = model
        self.name = name
        self.power_iteration = power_iteration
        self.register_params()

    def register_params(self):
        """Move ``model.<name>`` to ``<name>_bar`` and add random unit ``u``/``v``."""
        w = getattr(self.model, self.name)
        height = w.shape[0]
        width = w.view(height, -1).shape[1]
        u = nn.Parameter(w.new(height).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(w.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)  # assign .data so u/v stay nn.Parameters
        w_bar = nn.Parameter(w.data)
        # Drop the original parameter so the normalised tensor can later be
        # installed as a plain attribute under the same name.
        del self.model._parameters[self.name]

        self.model.register_parameter(self.name + '_u', u)
        self.model.register_parameter(self.name + '_v', v)
        self.model.register_parameter(self.name + '_bar', w_bar)

    def update_u_v(self,):
        """Advance the power iteration and install the normalised weight on ``model``."""
        u = getattr(self.model, self.name + "_u")
        v = getattr(self.model, self.name + "_v")
        w = getattr(self.model, self.name + "_bar")
        w_mat = w.view(w.shape[0], -1)

        # Bug fix: the previous code rebound the local names `u`/`v`, so the
        # refined vectors were thrown away. Every call restarted from the same
        # random init, and sigma never converged. Gradients also flowed
        # through the iteration. The stored vectors are now updated in place
        # without autograd, as torch.nn.utils.spectral_norm does.
        with torch.no_grad():
            for _ in range(self.power_iteration):
                v.copy_(l2normalize(torch.mv(w_mat.t(), u)))
                u.copy_(l2normalize(torch.mv(w_mat, v)))
            # Clone so graphs built by earlier calls are untouched by the
            # in-place updates above. Example: D applied to real and to fake
            # images before a single backward().
            u_hat = u.clone()
            v_hat = v.clone()

        sigma = u_hat.dot(w_mat.mv(v_hat))
        setattr(self.model, self.name, w / sigma.expand_as(w))

    def forward(self, x):
        self.update_u_v()
        return self.model(x)
    
class Self_Attn(nn.Module):
    """SAGAN self-attention layer with a learnable residual gate.

    Output is ``gamma * attention(x) + x``. ``gamma`` starts at zero, so the
    layer is an identity map at initialisation. Query/key 1x1 projections
    reduce channels to ``dim // 8``; the value projection keeps ``dim``.
    """

    def __init__(self, dim):
        super(Self_Attn, self).__init__()
        self.dim = dim

        # 1x1 projections (creation order kept so parameter init is unchanged).
        self.query_conv = nn.Conv2d(dim, dim // 8, 1)
        self.key_conv = nn.Conv2d(dim, dim // 8, 1)
        self.value_conv = nn.Conv2d(dim, dim, 1)
        # Residual gate, initialised to 0.
        self.gamma = nn.Parameter(torch.zeros(1))

        self.sm = nn.Softmax(dim=-1)

    def forward(self, x):
        # Flatten the spatial grid into N = H * W positions.
        queries = self.query_conv(x).flatten(2).transpose(1, 2)  # (B, N, C//8)
        keys = self.key_conv(x).flatten(2)                       # (B, C//8, N)
        values = self.value_conv(x).flatten(2)                   # (B, C, N)

        # attn[b, i, j] = softmax over j of <query_i, key_j>
        attn = self.sm(queries @ keys)                           # (B, N, N)
        attended = (values @ attn.transpose(1, 2)).view_as(x)
        return self.gamma * attended + x
    
class SpectralNormConvT(nn.Module):
    """Spectral-normalised ConvTranspose2d -> BatchNorm2d -> ReLU.

    ``in_dim``/``out_dim`` are channel counts; ``k``, ``s``, ``p`` are the
    kernel size, stride and padding of the transposed convolution.
    """

    def __init__(self, in_dim, out_dim, k, s=1, p=0):
        super(SpectralNormConvT, self).__init__()
        deconv = SpectralNorm(nn.ConvTranspose2d(in_dim, out_dim, k, s, p))
        norm = nn.BatchNorm2d(out_dim)
        act = nn.ReLU(inplace=True)
        self.model = nn.Sequential(deconv, norm, act)

    def forward(self, x):
        out = self.model(x)
        return out
    
    
class SpectralNormConvUp(nn.Module):
    """2x Upsample (default mode) -> spectral-normalised Conv2d -> BatchNorm2d -> ReLU.

    ``in_dim``/``out_dim`` are channel counts; ``k``, ``s``, ``p`` are the
    kernel size, stride and padding of the convolution.
    """

    def __init__(self, in_dim, out_dim, k, s=1, p=0):
        super(SpectralNormConvUp, self).__init__()
        stages = [nn.Upsample(scale_factor=2)]
        stages.append(SpectralNorm(nn.Conv2d(in_dim, out_dim, k, s, p)))
        stages.append(nn.BatchNorm2d(out_dim))
        stages.append(nn.ReLU(inplace=True))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        out = self.model(x)
        return out
    
class SpectralNormConvBN(nn.Module):
    """Spectral-normalised Conv2d -> BatchNorm2d -> LeakyReLU(0.1).

    ``in_dim``/``out_dim`` are channel counts; ``k``, ``s``, ``p`` are the
    kernel size, stride and padding of the convolution.
    """

    def __init__(self, in_dim, out_dim, k, s=1, p=0):
        super(SpectralNormConvBN, self).__init__()
        conv = SpectralNorm(nn.Conv2d(in_dim, out_dim, k, s, p))
        norm = nn.BatchNorm2d(out_dim)
        act = nn.LeakyReLU(0.1, inplace=True)
        self.model = nn.Sequential(conv, norm, act)

    def forward(self, x):
        out = self.model(x)
        return out

class SpectralNormConv(nn.Module):
    """Spectral-normalised Conv2d -> LeakyReLU(0.1), without batch norm.

    ``in_dim``/``out_dim`` are channel counts; ``k``, ``s``, ``p`` are the
    kernel size, stride and padding of the convolution.
    """

    def __init__(self, in_dim, out_dim, k, s=1, p=0):
        super(SpectralNormConv, self).__init__()
        conv = SpectralNorm(nn.Conv2d(in_dim, out_dim, k, s, p))
        act = nn.LeakyReLU(0.1, inplace=True)
        self.model = nn.Sequential(conv, act)

    def forward(self, x):
        out = self.model(x)
        return out

In [None]:
# class ResBlock(nn.Module):
#     def __init__(self):
#         super(ResBlock, self).__init__()
        
#         self.init = True
#     def forward(self, x):
#         if self.init:
#             self._init_model(x)

#         return F.relu(x+self.model(x))
    
#     def _init_model(self, x):
#         dim = x.size(1)
#         self.model = nn.Sequential(
#             SpectralNorm(nn.Conv2d(dim, dim, 3, 1, 1)),
#             nn.BatchNorm2d(dim),
#             nn.ReLU(inplace=True),
#             SpectralNorm(nn.Conv2d(dim, dim, 3, 1, 1)),
#             nn.BatchNorm2d(dim),
#         )
#         self.init = False
#         print('initialised')
        
        
class ResBlock(nn.Module):
    """Residual block computing ``ReLU(x + BN(SNConv(ReLU(BN(SNConv(x))))))``.

    Both convolutions are 3x3, stride 1, padding 1 with ``dim`` in/out
    channels, so the output has the same shape as ``x``.
    """

    def __init__(self, dim):
        super(ResBlock, self).__init__()

        def conv_bn():
            # One spectrally-normalised 3x3 conv followed by batch norm.
            return [SpectralNorm(nn.Conv2d(dim, dim, 3, 1, 1)), nn.BatchNorm2d(dim)]

        first = conv_bn()
        second = conv_bn()
        self.model = nn.Sequential(*first, nn.ReLU(inplace=True), *second)

    def forward(self, x):
        residual = self.model(x)
        return F.relu(x + residual)
    
    

class Generator(nn.Module):
    """Super-resolution generator: ``(B, n_ch, s, s) -> (B, n_ch, 4s, 4s)`` in [-1, 1].

    Two ``SpectralNormConvUp`` stages each double the resolution, and a final
    Tanh bounds the output. Note that the 4x factor is fixed by the layer
    stack, not read from ``args.ds_rate``. Reads ``n_dim``, ``n_ch`` and
    ``num_res`` from the global ``args``.
    """

    def __init__(self,):
        super(Generator, self).__init__()
        nd = args.n_dim

        def res_stack(width):
            return [ResBlock(width) for _ in range(args.num_res)]

        # Built in this order (1x, 2x, 4x width) to keep the original
        # parameter-initialisation sequence.
        stack_1x = res_stack(nd)
        stack_2x = res_stack(nd * 2)
        stack_4x = res_stack(nd * 4)

        self.model = nn.Sequential(
            SpectralNormConvBN(args.n_ch, nd * 4, 3, 1, 1),
            *stack_4x,
            Self_Attn(nd * 4),
            SpectralNormConvUp(nd * 4, nd * 2, 3, 1, 1),   # 2x upsample
            *stack_2x,
            Self_Attn(nd * 2),
            SpectralNormConvUp(nd * 2, nd, 3, 1, 1),       # 2x upsample
            *stack_1x,
            nn.Conv2d(nd, args.n_ch, 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        out = self.model(x)
        return out
    

# g = Generator()        
# g()       



class Discriminator(nn.Module):
    """Hinge-loss critic with self-attention; returns raw (unbounded) scores.

    Four stride-2, padding-1, 4x4 convolutions halve the spatial size each
    time, then an unpadded 4x4 convolution maps to one channel, so a 64x64
    input yields a ``(B, 1, 1, 1)`` score. Reads ``n_dim`` and ``n_ch`` from
    the global ``args``.
    """

    def __init__(self,):
        super(Discriminator, self).__init__()
        nd = args.n_dim

        self.model = nn.Sequential(
            # Input channels follow args.n_ch (previously hard-coded 3), matching
            # the Generator's output channels.
            SpectralNormConv(args.n_ch, nd, 4, 2, 1),
            SpectralNormConv(nd, nd * 2, 4, 2, 1),
            SpectralNormConv(nd * 2, nd * 4, 4, 2, 1),
            Self_Attn(nd * 4),
            SpectralNormConv(nd * 4, nd * 8, 4, 2, 1),
            Self_Attn(nd * 8),
            nn.Conv2d(nd * 8, 1, 4),
        )

    def forward(self, x):
        return self.model(x)
        
# D = Discriminator()        
# D(torch.randn(5,3,64,64))      



class SAGAN(nn.Module):
    """Super-resolution GAN wrapper: models, optimisers, one training step, I/O helpers.

    Calling the module (``gan(o_img, d_img)``) runs ONE discriminator update
    and ONE generator update on a batch. ``o_img`` is the full-resolution
    target and ``d_img`` its downsampled version, which is the generator input.

    Losses:
      * D uses the hinge loss.
      * G uses ``-mean(D(G(d_img))) + 10 * L1``. The L1 term compares G's
        output, average-pooled back to the low resolution, with ``d_img``.

    After a call, ``D_loss``, ``G_loss``, ``l1loss`` and ``G_img`` hold the
    latest step's values. Hyper-parameters come from the global ``args``.
    """

    def __init__(self,):
        super(SAGAN, self).__init__()

        self.D = Discriminator()
        self.G = Generator()

        # SpectralNorm's u/v vectors have requires_grad=False and are skipped.
        self.G_optim = optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()),
                                  lr=args.G_lr, betas=(args.b1, args.b2))
        self.D_optim = optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()),
                                  lr=args.D_lr, betas=(args.b1, args.b2))

        self.BCE = nn.BCELoss()
        self.CE = nn.CrossEntropyLoss()
        self.L1 = nn.L1Loss()

        self.train_hist = {'G_loss': [], 'D_loss': []}

        # self.apply(self.weight_init)
        self.progress_photo = []

    def forward(self, o_img, d_img):
        """Run one D step then one G step on ``(o_img, d_img)``; returns None."""
        o_img = o_img.to(device)
        d_img = d_img.to(device)

        ############# Train D #############
        self.D_optim.zero_grad()
        D_real = self.D(o_img)
        D_real_loss = F.relu(1.0 - D_real).mean()  # push D(real) above +1

        # Generate once and reuse for the G step (previously G ran twice per
        # step). Detached here so the D update does not backprop into G.
        self.G_img = self.G(d_img)
        D_fake = self.D(self.G_img.detach())
        D_fake_loss = F.relu(1.0 + D_fake).mean()  # push D(fake) below -1

        self.D_loss = D_real_loss + D_fake_loss
        self.train_hist['D_loss'].append(self.D_loss.item())
        self.D_loss.backward()
        self.D_optim.step()

        ############# Train G #############
        self.G_optim.zero_grad()
        # G's parameters are unchanged since self.G_img was produced, so its
        # graph is still valid; only D is re-evaluated with its new weights.
        D_fake = self.D(self.G_img)

        # Content term: the output pooled back to the input size should match it.
        self.l1loss = self.L1(F.adaptive_avg_pool2d(self.G_img, args.img_size // args.ds_rate), d_img)

        self.G_loss = -D_fake.mean() + self.l1loss * 10
        self.train_hist['G_loss'].append(self.G_loss.item())
        self.G_loss.backward()
        self.G_optim.step()

        # Keep the most recent img_save_freq samples for image_save().
        self.progress_photo.append(self.G_img[0].detach())
        self.progress_photo = self.progress_photo[-args.img_save_freq:]

    def weight_init(self, m):
        # NOTE(review): currently unused (the apply() call in __init__ is
        # commented out). For convs wrapped by SpectralNorm the raw weight is
        # registered as `weight_bar`, so `m.weight` may not exist there.
        if type(m) in [nn.Conv2d, nn.ConvTranspose2d, nn.Linear]:
            # nn.init.xavier_normal_(m.weight,nn.init.calculate_gain('leaky_relu',param=0.02))
            nn.init.kaiming_normal_(m.weight, 0.2, nonlinearity='leaky_relu')

    def image_save(self, step):
        """Save the stored progress samples as one grid image (10 per row)."""
        img_save_path = args.img_path + args.model_name + "_Step_" + str(step) + ".png"
        grid = torch.stack(self.progress_photo[:args.img_save_freq])
        try:
            # `value_range` replaced the removed `range` keyword in torchvision.
            save_image(grid, img_save_path, nrow=10, normalize=True, value_range=(-1, 1))
        except TypeError:
            # Older torchvision (< 0.10) only understands `range`.
            save_image(grid, img_save_path, nrow=10, normalize=True, range=(-1, 1))
        print('Image saved')

    def model_save(self, step):
        """Save this module's state_dict under the key ``args.model_name``."""
        path = args.model_path + args.model_name + '_Step_' + str(step) + '.pth'
        torch.save({args.model_name: self.state_dict()}, path)
        print('Model saved')

    def load_step_dict(self, step):
        """Load a checkpoint written by ``model_save`` for the given step (onto CPU storage)."""
        path = args.model_path + args.model_name + '_Step_' + str(step) + '.pth'
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        self.load_state_dict(torch.load(path, map_location=lambda storage, loc: storage)[args.model_name])

    def plot_all_loss(self, step):
        """Plot every loss history in ``train_hist`` and save it as a PNG in the working directory."""
        fig, ax = plt.subplots(figsize=(20, 8))
        for name, values in self.train_hist.items():
            ax.plot(values, label=name)
        ax.set_ylabel('Loss', fontsize=15)
        ax.set_xlabel('Number of Steps', fontsize=15)
        ax.set_title('Loss', fontsize=30, fontweight="bold")
        ax.legend(loc='upper left')
        fig.savefig(args.model_name + "_Loss_" + str(step) + ".png")
        # Close the figure so figures do not pile up across calls (memory
        # growth) and are not re-displayed by a later plt.show().
        plt.close(fig)

    def num_all_params(self,):
        """Total element count of all parameters, including non-trainable u/v vectors."""
        return sum(param.nelement() for param in self.parameters())


In [None]:
dataset = CelebADataset(mode='train', args=args)
training_loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,  # every batch has exactly batch_size samples
    # Pinned memory only speeds up host->GPU copies; on a CPU-only machine
    # PyTorch warns and ignores it.
    pin_memory=(device.type == 'cuda'),
)


In [None]:
# Fresh model plus training counters; the training cell resumes from these.
gan = SAGAN()
gan = gan.to(device)
epoch, all_steps = 0, 1

In [None]:
# Train until args.n_epoch epochs have run or MAX_STEPS optimisation steps are
# done. `epoch` and `all_steps` come from the previous cell, so re-running this
# cell resumes where it stopped (and does nothing once MAX_STEPS is reached).
# The old code ended with `raise StopIteration`, which left the cell in an
# error state and trained one extra step on every re-run.
MAX_STEPS = 5000

done = all_steps > MAX_STEPS
while epoch < args.n_epoch and not done:
    for o_img, d_img in training_loader:

        start_t = time.time()
        gan(o_img, d_img)
        end_t = time.time()

        print('| Step [%d] | lr [%.6f] | D Loss: [%.4f] | G Loss: [%.4f] | L1 Loss: [%.4f] | Time: %.1fs' %
              (all_steps, gan.G_optim.param_groups[0]['lr'], gan.D_loss.item(), gan.G_loss.item(),
               gan.l1loss.item(), end_t - start_t))

        if all_steps % args.show_freq == 0:
            # Low-res input | generated | high-res target, mapped [-1, 1] -> [0, 1].
            fig = plt.figure(figsize=(8, 8))
            for pos, sample in enumerate((d_img[0], gan.G_img[0], o_img[0]), start=1):
                fig.add_subplot(1, 3, pos)
                plt.imshow(to_img(sample.detach().cpu() * 0.5 + 0.5))
            plt.show()

        # Independent of show_freq: previously these were nested inside the
        # show_freq check, so other frequency settings silently skipped saves.
        if all_steps % args.img_save_freq == 0:
            gan.image_save(all_steps)
            gan.plot_all_loss('Training')
        if all_steps % args.model_save_freq == 0:
            gan.model_save(all_steps)

        all_steps += 1
        if all_steps > MAX_STEPS:
            done = True
            break

    if not done:  # only count epochs that ran to completion
        epoch += 1


In [None]:
# Manually override the learning rates, e.g. before resuming training.
for _optimizer, _lr in ((gan.G_optim, 0.0001), (gan.D_optim, 0.0004)):
    _optimizer.param_groups[0]['lr'] = _lr