In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision.models as models
import torchvision.transforms as T

In [None]:
from collections import namedtuple

In [None]:
class Vgg19(nn.Module):
    """Feature extractor built from torchvision's pretrained VGG-19.

    The ``features`` stack of ``models.vgg19(pretrained=True)`` is split into
    two consecutive slices:

    * ``slice1`` holds layers 0-3 of the stack.
    * ``slice2`` holds layers 4-31 of the stack.

    ``forward`` runs the slices back to back and returns the output of each
    one in a ``VggOutputs`` namedtuple (fields ``relu_1_2`` and ``relu_5_2``).

    Args:
        requires_grad: when false (the default) every parameter of the
            extractor is frozen by setting ``requires_grad = False``.
    """

    # Result type, created once. Building the namedtuple class inside
    # ``forward`` would create a brand-new class on every call, which is
    # wasted work and makes the results of two calls have different types.
    VggOutputs = namedtuple("VggOutputs", ['relu_1_2', 'relu_5_2'])

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()

        # Modules keep their original index as name so state-dict keys stay
        # recognisable ("slice2.4", "slice2.5", ...).
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 32):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Return ``VggOutputs(relu_1_2, relu_5_2)`` for the input batch ``x``."""
        h_relu_1_2 = self.slice1(x)
        h_relu_5_2 = self.slice2(h_relu_1_2)
        return self.VggOutputs(h_relu_1_2, h_relu_5_2)

In [3]:
class ConvUp(nn.Module):
    """Convolution block: Conv2d, optional BatchNorm2d, then LeakyReLU(0.2).

    Args:
        input_nc: number of input channels of the convolution.
        output_nc: number of output channels of the convolution.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        batch_norm: when truthy, a ``BatchNorm2d(output_nc)`` follows the conv.
        dtype: tensor type the whole stack is cast to via ``Module.type``.
    """

    def __init__(self, input_nc, output_nc, kernel_size, stride, padding, batch_norm, dtype):
        super(ConvUp, self).__init__()
        layers = []
        layers.append(
            nn.Conv2d(
                in_channels=input_nc,
                out_channels=output_nc,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            )
        )
        if batch_norm:
            layers.append(nn.BatchNorm2d(output_nc))
        # In-place activation, slope 0.2 on the negative side.
        layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
        stack = nn.Sequential(*layers)
        self.features = stack.type(dtype)

    def forward(self, x):
        """Apply the block to ``x`` and return the activated feature map."""
        activated = self.features(x)
        return activated
    

class DeconvDown(nn.Module):
    """Transposed-convolution block fed by the concatenation of two tensors.

    Layer order: ConvTranspose2d, optional BatchNorm2d, optional
    Dropout2d(0.5), LeakyReLU(0.2).

    Args:
        input_nc: channel count of the concatenated input.
        output_nc: channel count produced by the transposed convolution.
        kernel_size, stride, padding: forwarded to ``nn.ConvTranspose2d``.
        batch_norm: when truthy, add ``BatchNorm2d(output_nc)``.
        dropout: when truthy, add ``Dropout2d(0.5)``.
        dtype: tensor type the whole stack is cast to via ``Module.type``.
    """

    def __init__(self, input_nc, output_nc, kernel_size, stride, padding, batch_norm, dropout, dtype):
        super(DeconvDown, self).__init__()
        layers = [
            nn.ConvTranspose2d(
                in_channels=input_nc,
                out_channels=output_nc,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            )
        ]
        if batch_norm:
            layers.append(nn.BatchNorm2d(output_nc))
        if dropout:
            layers.append(nn.Dropout2d(p=0.5))
        layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
        stack = nn.Sequential(*layers)
        self.features = stack.type(dtype)

    def forward(self, x, eo):
        """Concatenate ``x`` and ``eo`` along dim 1, then apply the block."""
        merged = torch.cat([x, eo], 1)
        return self.features(merged)

In [2]:
class Generator(nn.Module):
    """Encoder/decoder generator with skip connections.

    Eight ``ConvUp`` stages (``e1`` .. ``e8``) are followed by eight
    ``DeconvDown`` stages (``d1`` .. ``d8``). Every stage uses kernel 4,
    stride 2, padding 1 and batch norm. Decoder stage ``d<k>`` (k >= 2)
    receives the previous decoder output together with the output of encoder
    stage ``e<9-k>``; ``d1`` receives only the ``e8`` output. The first three
    decoder stages use dropout.

    Args:
        dtype: tensor type forwarded to every stage and used for the empty
            placeholder tensor handed to ``d1``.
    """

    def __init__(self, dtype):
        super(Generator, self).__init__()
        self.dtype = dtype

        # (input channels, output channels) for e1 .. e8
        encoder_channels = [
            (3, 64), (64, 128), (128, 256), (256, 512),
            (512, 512), (512, 512), (512, 512), (512, 512),
        ]
        for stage, (c_in, c_out) in enumerate(encoder_channels, 1):
            setattr(self, 'e%d' % stage, ConvUp(c_in, c_out, 4, 2, 1, True, dtype))

        # (input channels, output channels, dropout) for d1 .. d8; from d2 on
        # the input is twice as wide because of the concatenated skip tensor.
        decoder_specs = [
            (512, 512, True), (1024, 512, True), (1024, 512, True), (1024, 512, False),
            (1024, 256, False), (512, 128, False), (256, 64, False), (128, 3, False),
        ]
        for stage, (c_in, c_out, use_dropout) in enumerate(decoder_specs, 1):
            setattr(self, 'd%d' % stage, DeconvDown(c_in, c_out, 4, 2, 1, True, use_dropout, dtype))

    def forward(self, x):
        """Run ``x`` through the encoder and decoder and return the ``d8`` output."""
        skips = []
        hidden = x
        for stage in range(1, 9):
            hidden = getattr(self, 'e%d' % stage)(hidden)
            skips.append(hidden)

        # d1 has no skip partner: it gets the e8 output plus an empty tensor.
        hidden = self.d1(skips.pop(), torch.Tensor([]).type(self.dtype))
        for stage in range(2, 9):
            # skips.pop() yields e7, e6, ..., e1 in that order
            hidden = getattr(self, 'd%d' % stage)(hidden, skips.pop())

        return hidden

In [4]:
class Discriminator(nn.Module):
    """Convolutional discriminator ending in a sigmoid.

    Eight stride-2 convolutions (kernel 4, padding 1); the first seven are
    followed by BatchNorm2d and LeakyReLU(0.2), the last one by Sigmoid.
    The per-layer size comments below give the spatial size of each layer's
    output for a 256x256 input.

    Args:
        dtype: tensor type the whole stack is cast to via ``Module.type``.
    """

    def __init__(self, dtype):
        super(Discriminator, self).__init__()
        self.features = nn.Sequential(
            # Conv1 -> 128x128
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            # Conv2 -> 64x64
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            # Conv3 -> 32x32
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            # Conv4 -> 16x16
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # Conv5 -> 8x8
            nn.Conv2d(512, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # Conv6 -> 4x4
            nn.Conv2d(512, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # Conv7 -> 2x2
            nn.Conv2d(512, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # Conv8 -> 1x1
            nn.Conv2d(512, 1, 4, 2, 1),
            nn.Sigmoid()
        ).type(dtype)

    def forward(self, x):
        """Return the sigmoid scores for ``x`` with singleton dims removed.

        Every size-1 dimension except the batch dimension (dim 0) is
        squeezed. A bare ``squeeze()`` would also drop the batch dimension
        for a batch of one and return a 0-d tensor; keeping dim 0 makes the
        result shape independent of the batch size.
        """
        out = self.features(x)
        # Walk from the last dim down to dim 1 so earlier indices stay valid;
        # squeeze(dim) is a no-op when that dim is not of size 1.
        for dim in range(out.dim() - 1, 0, -1):
            out = out.squeeze(dim)
        return out
        

In [5]:
# Tensor type shared by both networks: CUDA floats when a GPU is usable,
# plain CPU floats otherwise, so the cell also runs on CPU-only machines.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

G = Generator(dtype)
D = Discriminator(dtype)

In [12]:
from torch.utils.data import DataLoader, Dataset
import matplotlib.image as mpimg
import os

In [9]:
class DistortionDataset(Dataset):
    """Paired image dataset read from two directories.

    Item ``i`` is the ``i``-th file of ``dis_dir`` together with the ``i``-th
    file of ``raw_dir``, both taken in sorted filename order. Pairing is by
    position only -- this assumes corresponding images sort to the same
    position in both directories (e.g. identical filenames); TODO confirm
    against the data layout.

    Each image is read with ``matplotlib.image.imread``, converted with
    ``ToTensor`` and normalised with mean 0.5 / std 0.5 per channel.

    Args:
        dis_dir: directory whose files become the first element of each pair.
        raw_dir: directory whose files become the second element of each pair.
    """

    def __init__(self, dis_dir, raw_dir):
        self.dis_dir = dis_dir
        self.raw_dir = raw_dir
        # os.listdir returns entries in arbitrary order, and the two
        # directories are listed independently; sorting makes index i refer
        # to the same position in both listings on every run.
        self.dis_files = sorted(os.listdir(dis_dir))
        self.raw_files = sorted(os.listdir(raw_dir))
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __getitem__(self, index):
        """Return the transformed ``(x, y)`` pair at ``index``."""
        # os.path.join works whether or not the directory has a trailing slash.
        x = mpimg.imread(os.path.join(self.dis_dir, self.dis_files[index]))
        y = mpimg.imread(os.path.join(self.raw_dir, self.raw_files[index]))
        x = self.transform(x)
        y = self.transform(y)
        return x, y

    def __len__(self):
        """Number of complete pairs (the shorter of the two listings)."""
        return min(len(self.dis_files), len(self.raw_files))

In [10]:
# Directories are passed by keyword so each one lands on the like-named
# parameter: the positional call had the "raw" folder in the dis_dir slot and
# the "dis" folder in the raw_dir slot.
# NOTE(review): judged from the folder/parameter names only -- confirm which
# folder holds the distorted images.
dataset = DistortionDataset(
    dis_dir='data/paired_dataset256/images_dis256/',
    raw_dir='data/paired_dataset256/images_raw256/',
)
dataloader = DataLoader(dataset, batch_size=5)

In [13]:
# Sanity check: fetch only the first batch (if there is one) and show the
# shapes of its two tensors.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    x, y = first_batch
    print(x.shape, y.shape)

torch.Size([5, 3, 256, 256]) torch.Size([5, 3, 256, 256])


https://github.com/aitorzip/PyTorch-SRGAN/blob/master/train

In [None]:
class Trainer:
    """Alternating training loop for a generator / discriminator pair.

    For every batch ``(dis_x, raw_y)`` the generator output
    ``fake_x = G(dis_x)`` is computed once. The discriminator is then updated
    ``d_steps`` times to score ``raw_y`` as 1 and the detached ``fake_x`` as
    0, after which the generator is updated once with the sum of

    * an adversarial BCE term that pushes ``D(fake_x)`` towards 1, and
    * two MSE terms between the VGG-19 ``relu_1_2`` / ``relu_5_2`` features
      of ``fake_x`` and ``raw_y``.

    The three generator terms are summed without weights.

    Args:
        G: generator module, called as ``G(dis_x)``.
        G_optimizer: optimizer over the generator parameters.
        D: discriminator module returning probabilities in [0, 1].
        D_optimizer: optimizer over the discriminator parameters.
        dtype: tensor type that inputs, targets, losses and the VGG network
            are cast to.
    """

    def __init__(self, G, G_optimizer, D, D_optimizer, dtype):
        self.G = G
        self.D = D
        self.G_optimizer = G_optimizer
        self.D_optimizer = D_optimizer
        self.dtype = dtype
        # Cast the feature extractor as well, otherwise it stays on the CPU
        # while the images it receives have been cast to ``dtype``.
        # NOTE(review): images arrive as the data loader produces them; check
        # whether the VGG network needs a different input normalisation.
        self.vgg19 = Vgg19().type(dtype)
        self.MSE_loss = nn.MSELoss().type(dtype) # Content Criterion
        self.BCE_loss = nn.BCELoss().type(dtype) # Adversarial criterion

    def _labels(self, like, value):
        """Return a constant target tensor (0.0 or 1.0) shaped like ``like``."""
        fill = torch.ones if value else torch.zeros
        return Variable(fill(like.size())).type(self.dtype)

    def train(self, num_epochs, dataloader, d_steps=10):
        """Train both networks for ``num_epochs`` passes over ``dataloader``.

        Args:
            num_epochs: number of passes over ``dataloader``.
            dataloader: iterable yielding ``(dis_x, raw_y)`` tensor batches.
            d_steps: discriminator updates per batch (default 10).
        """
        for epoch in range(num_epochs):
            print('Starting epoch %d/%d' %(epoch+1, num_epochs))
            self.G.train()
            self.D.train()
            for dis_x, raw_y in dataloader:
                dis_x = Variable(dis_x).type(self.dtype)
                raw_y = Variable(raw_y).type(self.dtype)

                fake_x = self.G(dis_x)

                ### Train Discriminator
                # Detach so the repeated backward passes below stop at the
                # discriminator and never walk (or free) the generator graph.
                fake_x_detached = fake_x.detach()
                for _ in range(d_steps):
                    real_result = self.D(raw_y)
                    fake_result = self.D(fake_x_detached)
                    real_loss = self.BCE_loss(real_result, self._labels(real_result, 1))
                    fake_loss = self.BCE_loss(fake_result, self._labels(fake_result, 0))
                    D_loss = real_loss + fake_loss
                    self.D.zero_grad()
                    D_loss.backward()
                    self.D_optimizer.step()

                ### Train Generator
                raw_y_feature = self.vgg19(raw_y)
                fake_x_feature = self.vgg19(fake_x)
                # Feature maps are unbounded, so they are compared with MSE;
                # BCE is only defined for inputs in [0, 1].
                relu_1_2_loss = self.MSE_loss(
                    fake_x_feature.relu_1_2,
                    raw_y_feature.relu_1_2.detach()
                )
                relu_5_2_loss = self.MSE_loss(
                    fake_x_feature.relu_5_2,
                    raw_y_feature.relu_5_2.detach()
                )
                # Adversarial term: the generator wants D to call its output real.
                fake_score = self.D(fake_x)
                G_loss = self.BCE_loss(fake_score, self._labels(fake_score, 1))
                G_total_loss = G_loss + relu_1_2_loss + relu_5_2_loss
                self.G.zero_grad()
                G_total_loss.backward()
                self.G_optimizer.step()
                