In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torchvision.utils import make_grid
import torchvision.utils as vutils
import os


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

# import warnings
# warnings.filterwarnings("ignore")

In [2]:
# UTILS

from os import listdir
from os.path import join

from PIL import Image
from torch.utils.data.dataset import Dataset


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])

def calculate_valid_crop_size(crop_size, upscale_factor):
    return crop_size - (crop_size % upscale_factor)

def train_hr_transform(crop_size):
    return transforms.Compose([
        transforms.RandomCrop(crop_size),
        transforms.ToTensor()])

def train_lr_transform(crop_size, upscale_factor):
    return transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(crop_size // upscale_factor, interpolation = Image.BICUBIC),
        transforms.ToTensor()
    ])

def display_transform():
    return transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(400),
        transforms.CenterCrop(400),
        transforms.ToTensor()
    ])

class TrainDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDatasetFromFolder, self).__init__()
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)
        
    def __getitem__(self, index):
        hr_image = self.hr_transform(Image.open(self.image_filenames[index]))
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image
    
    def __len__(self):
        return len(self.image_filenames)
    
class ValDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, upscale_factor):
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
        
    def __getitem__(self, index):
        hr_image = Image.open(self.image_filenames[index])
        w, h = hr_image.size
        crop_size = calculate_valid_crop_size(min(w,h),self.upscale_factor)
        lr_scale = transforms.Resize(crop_size // self.upscale_factor, interpolation = Image.BICUBIC)
        hr_scale = transforms.Resize(crop_size, interpolation = Image.BICUBIC)
        hr_image = transforms.CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        hr_restore_img = hr_scale(lr_image)
        return transforms.ToTensor()(lr_image), transforms.ToTensor()(hr_restore_img), transforms.ToTensor()(hr_image)
    
    def __len__(self):
        return len(self.image_filenames)
    
class TestDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, upscale_factor):
        super(TestDatasetFromFolder, self).__init__()
        self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
        self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
        self.upscale_factor = upscale_factor
        self.lr_filenames = [join(self.lr_path, x) for x in listdir(self.lr_path) if is_image_file(x)]
        self.hr_filenames = [join(self.hr_path, x) for x in listdir(self.hr_path) if is_image_file(x)]
        
    def __getitem__(self, index):
        image_name = self.lr_filenames[index].split('/')[-1]
        lr_image = Image.open(self.lr_filenames[index])
        w,h = lr_image.size
        hr_image = Image.open(self.hr_filenames[index])
        hr_scale = transforms.Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=Image.BICUBIC)
        hr_restore_img = hr_scale(lr_image)
        return image_name, transforms.ToTensor()(lr_image), transforms.ToTensor()(hr_restore_img), transforms.ToTensor()(hr_image)
        

In [3]:
# LOSS

from torchvision.models.vgg import vgg19


class GeneratorLoss(nn.Module):
    def __init__(self):
        super(GeneratorLoss, self).__init__()
        vgg = vgg19(pretrained = True)
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()
        
    def forward(self, out_labels, out_images, target_images):
        adversarial_loss = torch.mean(1 - out_labels)
        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        image_loss = self.mse_loss(out_images, target_images)
        tv_loss = self.tv_loss(out_images)
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss
    
class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight
        
    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self.tensor_size(x[:,:,1:,:])
        count_w = self.tensor_size(x[:,:,:,1:])
        h_tv = torch.pow((x[:,:,1:,:] - x[:,:,:h_x - 1,:]),2).sum()
        w_tv = torch.pow((x[:,:,:,1:] - x[:,:,:,:w_x - 1]),2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size
    
    @staticmethod
    def tensor_size(t):
        return t.size()[1] * t.size()[2] * t.size()[3]
    
if __name__ == "__main__":
    g_loss = GeneratorLoss()
    print(g_loss)

GeneratorLoss(
  (loss_network): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), strid

In [4]:
# RESIDUAL BLOCK

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels,channels, kernel_size=3,padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3,padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        
    def forward(self,x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)
        
        return x + residual
    
    
    
# UPSAMPLE BLOCK

class UpsampleBlock(nn.Module):
    def __init__(self,in_channels, up_scale):
        super(UpsampleBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels,in_channels * up_scale ** 2, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()
        
    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        self.prelu(x)
        return x
    
    
    
# GENERATOR

import math


class Generator(nn.Module):
    
    def __init__(self, scale_factor):
        upsample_block_num = int(math.log(scale_factor,2))
        
        super(Generator, self).__init__()
        
        self.block1 = nn.Sequential(nn.Conv2d(3,64,kernel_size=9,padding=4),nn.PReLU())
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(nn.Conv2d(64,64,kernel_size=3,padding=1), nn.BatchNorm2d(64))
        block8 = [UpsampleBlock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64,3,kernel_size=9,padding=4))
        self.block8 = nn.Sequential(*block8)
        
    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        block8 = self.block8(block1 + block7)
        
        return (torch.tanh(block8) + 1) / 2


    
    
# DISCRIMINATOR

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.net = nn.Sequential(nn.Conv2d(3,64,kernel_size=3,padding=1),
                                 nn.LeakyReLU(0.2),
                                 
                                nn.Conv2d(64,64,kernel_size=3,stride=2,padding=1),
                                nn.BatchNorm2d(64), 
                                nn.LeakyReLU(0.2),
                                
                                nn.Conv2d(64,128,kernel_size=3,padding=1),
                                nn.BatchNorm2d(128), 
                                nn.LeakyReLU(0.2),
                                 
                                nn.Conv2d(128,128,kernel_size=3,stride=2,padding=1),
                                nn.BatchNorm2d(128), 
                                nn.LeakyReLU(0.2),
                                 
                                nn.Conv2d(128,256,kernel_size=3,padding=1),
                                nn.BatchNorm2d(256), 
                                nn.LeakyReLU(0.2),
                                 
                                nn.Conv2d(256,256,kernel_size=3,stride=2,padding=1),
                                nn.BatchNorm2d(256), 
                                nn.LeakyReLU(0.2),
                                 
                                nn.Conv2d(256,512,kernel_size=3,padding=1),
                                nn.BatchNorm2d(512), 
                                nn.LeakyReLU(0.2),
                                 
                                nn.Conv2d(512,512,kernel_size=3,stride=2,padding=1),
                                nn.BatchNorm2d(512), 
                                nn.LeakyReLU(0.2),
                                 
                                nn.AdaptiveAvgPool2d(1),
                                nn.Conv2d(512,1024,kernel_size=1),
                                nn.LeakyReLU(0.2),
                                nn.Conv2d(1024,1,kernel_size=1)
                                )
        
    def forward(self, x):
        batch_size = x.size(0)
        return torch.sigmoid(self.net(x).view(batch_size))

In [6]:
# netG = Generator(2)
# print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
# netD = Discriminator()
# print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

# generator_criterion = GeneratorLoss()

# if torch.cuda.is_available():
#     netG.cuda()
#     netD.cuda()
#     generator_criterion.cuda()

# generator parameters: 586506
# discriminator parameters: 5215425


In [5]:
netG = Generator(2).eval()
netG.cuda()
netG.load_state_dict(torch.load('modele/' + 'netG_epoch.pth'))

netD = Discriminator().eval()
netD.cuda()
netD.load_state_dict(torch.load('modele/' + 'netD_epoch.pth'))

generator_criterion = GeneratorLoss()
generator_criterion.cuda()

GeneratorLoss(
  (loss_network): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), strid

In [6]:
from math import log10
import argparse

import torch.utils.data
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader 
from tqdm import tqdm
import torch
from PIL import Image
from torchvision.transforms import ToTensor, ToPILImage


from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
import time

s = time.time()

# torch.autograd.set_detect_anomaly(True)
if __name__ == "__main__":
    CROP_SIZE = 88
    UPSCALE_FACTOR = 2
    NUM_EPOCHS = 1000
    out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
    
    train_set = TrainDatasetFromFolder('F:\CelebAMask-HQ\VOCdevkit\VOC2012\JPEGImages', crop_size = CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('F:/CelebAMask-HQ/cacatest', upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set, num_workers=0,batch_size=20,shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=0, batch_size=1, shuffle=False)
    
    
    optimizerG = torch.optim.Adam(netG.parameters())
    optimizerD = torch.optim.Adam(netD.parameters())
    
    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}
    
    for epoch in range(1, NUM_EPOCHS + 1):
        train_bar = tqdm(train_loader)
        running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}
        
        netG.train()
        netD.train()
        for data, target in train_bar:
            g_update_first = True
            batch_size = data.size(0)
            running_results['batch_sizes'] += batch_size
            
            real_img = Variable(target)
            if torch.cuda.is_available():
                real_img = real_img.cuda()
            z = Variable(data)
            if torch.cuda.is_available():
                z = z.cuda()
            fake_img = netG(z)
            
            #UPDATE D network
            netD.zero_grad()
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()
            d_loss = 1 - real_out + fake_out
            d_loss.backward(retain_graph=True)
            
            
            #UPDATE G network
            netG.zero_grad()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()
            
            fake_img = netG(z)
            fake_out = netD(fake_img).mean()
            
            optimizerG.step()
            
            
            optimizerD.step()
            
            #loss for current batch before oprimization
            running_results['g_loss'] += g_loss.item() * batch_size
            running_results['d_loss'] += d_loss.item() * batch_size
            running_results['d_score'] += real_out.item() * batch_size
            running_results['g_score'] += fake_out.item() * batch_size
            
            train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
                running_results['g_loss'] / running_results['batch_sizes'],
                running_results['d_score'] / running_results['batch_sizes'],
                running_results['g_score'] / running_results['batch_sizes']))
#             print(f'g loss: {running_results["g_loss"]} g score: {running_results["g_score"]}  d loss: {running_results["d_loss"]}  d score: {running_results["d_score"]}')
    
        netG.eval()
        if not os.path.exists(out_path):
            os.makedirs(out_path)

        with torch.no_grad():
            image = Image.open('testam.jpg')
            image = Variable(ToTensor()(image)).unsqueeze(0)
            image = image.cuda()
            out = netG(image)
            
            out_img = ToPILImage()(out[0].data.cpu())
            out_img.save('test/testam.jpg')
            
            print((time.time()-s)/60)
#             val_bar = tqdm(val_loader)
#             valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
# #             val_images = []
#             for val_lr, val_hr_restore, val_hr in val_bar:
#                 batch_size = val_lr.size(0)
#                 valing_results['batch_sizes'] += batch_size
#                 lr = val_lr
#                 hr = val_hr
#                 if torch.cuda.is_available():
#                     lr = lr.cuda()
#                     hr = hr.cuda()
#                 sr = netG(lr)

#                 batch_mse = ((sr - hr) ** 2).data.mean()
#                 valing_results['mse'] += batch_mse * batch_size
# #                 batch_ssim = ssim(sr, hr).item()
# #                 valing_results['ssims'] += batch_ssim * batch_size
#                 valing_results['psnr'] = 10 * log10((hr.max()**2) / (valing_results['mse'] / valing_results['batch_sizes']))
#                 valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
#                 val_bar.set_description(
#                     desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
#                         valing_results['psnr'], valing_results['ssim']))
# #                 val_images.extend(
# #                     [display_transform()(val_hr_restore.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)),
# #                      display_transform()(sr.data.cpu().squeeze(0))])
# #             val_images = torch.stack(val_images)
# #             val_images = torch.chunk(val_images,val_images.size(0) // 1)
# #             val_save_bar = tqdm(val_images, desc='savin training results')
# #             index = 1
# #             for image in val_images:
# #                 image = utils.make_grid(image, nrow=3,padding=5)
# #                 utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
# #                 index += 1
                
    
        torch.save(netG.state_dict(), 'modele/netG_epoch.pth')
        torch.save(netD.state_dict(), 'modele/netD_epoch.pth')

        results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
        results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
        results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
        results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
#         results['psnr'].append(valing_results['psnr'])
#         results['ssim'].append(valing_results['ssim'])

#         if epoch % 1 == 0 and epoch != 0:
#             out_path = 'statistics/'
        data_frame = pd.DataFrame(
            data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                  'Score_G': results['g_score']},
                index=range(1, epoch + 1))
#                 data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')
        print(data_frame.iloc[[-1]])

print(f'mins:  {(time.time()-s)/60}')

[1/1000] Loss_D: 1.0000 Loss_G: 0.0069 D(x): 0.0000 D(G(z)): 0.0000: 100%|██████████| 857/857 [08:55<00:00,  1.60it/s]
  0%|          | 0/857 [00:00<?, ?it/s]8.935074762503307
   Loss_D   Loss_G       Score_D       Score_G
1     1.0  0.00689  1.782052e-40  1.529065e-40
[2/1000] Loss_D: 1.0000 Loss_G: 0.0077 D(x): 0.0000 D(G(z)): 0.0000: 100%|██████████| 857/857 [08:39<00:00,  1.65it/s]
  0%|          | 0/857 [00:00<?, ?it/s]17.599195579687755
   Loss_D   Loss_G       Score_D       Score_G
2     1.0  0.00766  2.039022e-34  1.212483e-35
[3/1000] Loss_D: 1.0000 Loss_G: 0.0068 D(x): 0.0000 D(G(z)): 0.0000: 100%|██████████| 857/857 [08:39<00:00,  1.65it/s]
  0%|          | 0/857 [00:00<?, ?it/s]26.261258280277254
   Loss_D    Loss_G       Score_D       Score_G
3     1.0  0.006768  9.814358e-39  2.887645e-38
[4/1000] Loss_D: 1.0000 Loss_G: 0.0068 D(x): 0.0000 D(G(z)): 0.0000: 100%|██████████| 857/857 [08:37<00:00,  1.65it/s]
  0%|          | 0/857 [00:00<?, ?it/s]34.89535812139511
   Loss_D 

KeyboardInterrupt: 

In [7]:
# TEST
import argparse
import time

import torch
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage

UPSCALE_FACTOR = 2
TEST_MODE = True
IMAGE_NAME = 'pu.jpg'
MODEL_NAME = 'netG_epoch.pth'

model = Generator(UPSCALE_FACTOR).eval()
model.cuda()
model.load_state_dict(torch.load('modele/' + MODEL_NAME))

image = Image.open(IMAGE_NAME)
image = Variable(transforms.ToTensor()(image)).unsqueeze(0)
image = image.cuda()

with torch.no_grad():
    out = model(image)
    
out_img = ToPILImage()(out[0].data.cpu())
out_img.save('test/' + str('20') + '_' + IMAGE_NAME)

In [26]:
# path = 'F:/CelebAMask-HQ/VOCdevkit/VOC2012/JPEGImages'
# img_names = []

# for folder,subfolders,filenames in os.walk(path):
#     for img in filenames:
#         img_names.append(folder+'/'+img)
        
# len(img_names)

# img_sizes = []
# rejected = []

# for item in img_names:
#     try:
#         with Image.open(item) as img:
#             img_sizes.append(img.size)
            
#     except:
#         rejected.append(item)
        
# df = pd.DataFrame(img_sizes)
# x = 0
# for i, row in df.iterrows():
#     if row[1] < 88:
#         print(i)

1662


Unnamed: 0,a,b,c,d
2,1000,2000,3000,4000
