In [None]:
# Pin an older PyTorch release together with the matching torchvision pin used below.
# The install log shows this replaces the preinstalled torch 1.9.0 and leaves the
# preinstalled torchtext 0.10.0 with an unmet torch==1.9.0 requirement.
!pip install torch==1.2.0
!pip install torchvision==0.4.0
#torch.__version__

Collecting torch==1.2.0
  Downloading torch-1.2.0-cp37-cp37m-manylinux1_x86_64.whl (748.9 MB)
[K     |████████████████████████████████| 748.9 MB 636 bytes/s 
Installing collected packages: torch
  Attempting uninstall: torch
    Found existing installation: torch 1.9.0+cu111
    Uninstalling torch-1.9.0+cu111:
      Successfully uninstalled torch-1.9.0+cu111
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchvision 0.10.0+cu111 requires torch==1.9.0, but you have torch 1.2.0 which is incompatible.
torchtext 0.10.0 requires torch==1.9.0, but you have torch 1.2.0 which is incompatible.[0m
Successfully installed torch-1.2.0
Collecting torchvision==0.4.0
  Downloading torchvision-0.4.0-cp37-cp37m-manylinux1_x86_64.whl (8.8 MB)
[K     |████████████████████████████████| 8.8 MB 4.7 MB/s 
Installing collected packages: torchvision
  Attempting uninstall: torc

In [None]:
import torch
from torch import nn
import argparse
import os
from math import log10

In [None]:
import pandas as pd
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
import pytorch_ssim

In [None]:
''' DATA LOADER'''

from os import listdir
from os.path import join

from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize

def is_image_file(filename):
    """Return True if ``filename`` ends with a PNG or JPEG extension.

    The comparison is case-insensitive, so mixed-case spellings such as
    ``.Jpg`` or ``.Png`` are accepted as well as all-lower / all-upper ones.
    """
    return filename.lower().endswith(('.png', '.jpg', '.jpeg'))

def calculate_valid_crop_size(crop_size, upscale_factor):
    """Round ``crop_size`` down to the nearest multiple of ``upscale_factor``.

    Guarantees that an HR crop divides evenly into its LR counterpart.
    """
    _, leftover = divmod(crop_size, upscale_factor)
    return crop_size - leftover

def train_hr_transform(crop_size):
    """Build the HR training pipeline: a random ``crop_size`` crop, then to-tensor."""
    steps = [RandomCrop(crop_size), ToTensor()]
    return Compose(steps)

def train_lr_transform(crop_size, upscale_factor):
    """Build the LR pipeline applied to an HR tensor.

    The tensor goes back to PIL so ``Resize`` can bicubically shrink its smaller
    edge to ``crop_size // upscale_factor``, and is then converted to a tensor again.
    """
    lr_side = crop_size // upscale_factor
    pipeline = [
        ToPILImage(),
        Resize(lr_side, interpolation=Image.BICUBIC),
        ToTensor(),
    ]
    return Compose(pipeline)

def display_transform():
    """Tensor -> 400x400 tensor for comparison grids.

    Resizes the smaller edge to 400, centre-crops to 400x400, and converts back to a tensor.
    """
    side = 400
    return Compose([ToPILImage(), Resize(side), CenterCrop(side), ToTensor()])

class TrainDatasetFromFolder(Dataset):
    """Training pairs generated on the fly from a folder of HR images.

    Each item is a random crop of the HR image whose side is ``crop_size`` rounded
    down to a multiple of ``upscale_factor``, plus its bicubic-downscaled LR version.
    """

    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super().__init__()
        # Sorted so the index -> file mapping does not depend on listdir() order.
        self.image_filenames = [join(dataset_dir, x) for x in sorted(listdir(dataset_dir)) if is_image_file(x)]
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)

    def __getitem__(self, index):
        """Return ``(lr_image, hr_image)`` tensors for the file at ``index``."""
        # Force 3-channel RGB: grayscale, RGBA or palette files would otherwise yield
        # tensors with 1 or 4 channels that the 3-channel networks cannot consume.
        # The context manager closes the underlying file once the pixels are decoded.
        with Image.open(self.image_filenames[index]) as img:
            rgb_image = img.convert('RGB')
        hr_image = self.hr_transform(rgb_image)
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        return len(self.image_filenames)

class ValDatasetFromFolder(Dataset):
    """Deterministic validation triples from a folder of HR images.

    Each image is centre-cropped to the largest square whose side is a multiple of
    ``upscale_factor`` and returned as ``(lr, bicubic_restored_hr, hr)`` tensors.
    """

    def __init__(self, dataset_dir, upscale_factor):
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        # Sorted so validation order (and saved grid contents) is reproducible.
        self.image_filenames = [join(dataset_dir, x) for x in sorted(listdir(dataset_dir)) if is_image_file(x)]

    def __getitem__(self, index):
        # Force 3-channel RGB and close the file handle once decoded.
        with Image.open(self.image_filenames[index]) as img:
            hr_image = img.convert('RGB')
        w, h = hr_image.size
        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        # Bicubic upscaling of the LR image: the baseline the generator should beat.
        hr_restore_img = hr_scale(lr_image)
        return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        return len(self.image_filenames)

class TestDatasetFromFolder(Dataset):
    """Paired benchmark images laid out as ``<dir>/SRF_<k>/data`` (LR) and ``<dir>/SRF_<k>/target`` (HR).

    Items are ``(image_name, lr, bicubic_restored_hr, hr)``.

    Raises:
        ValueError: if the LR and HR folders contain different numbers of images.
    """

    def __init__(self, dataset_dir, upscale_factor):
        super(TestDatasetFromFolder, self).__init__()
        self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
        self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
        self.upscale_factor = upscale_factor
        # LR and HR files are paired purely by position, so both listings must be
        # sorted; listdir() order is arbitrary and can differ between directories.
        self.lr_filenames = [join(self.lr_path, x) for x in sorted(listdir(self.lr_path)) if is_image_file(x)]
        self.hr_filenames = [join(self.hr_path, x) for x in sorted(listdir(self.hr_path)) if is_image_file(x)]
        if len(self.lr_filenames) != len(self.hr_filenames):
            raise ValueError('LR/HR image count mismatch: %d in %s vs %d in %s' % (
                len(self.lr_filenames), self.lr_path, len(self.hr_filenames), self.hr_path))

    def __getitem__(self, index):
        image_name = self.lr_filenames[index].split('/')[-1]
        # Force RGB and close file handles once the pixels are decoded.
        with Image.open(self.lr_filenames[index]) as img:
            lr_image = img.convert('RGB')
        with Image.open(self.hr_filenames[index]) as img:
            hr_image = img.convert('RGB')
        w, h = lr_image.size
        # Bicubic upscale of the LR input to the target size (the baseline).
        hr_scale = Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=Image.BICUBIC)
        hr_restore_img = hr_scale(lr_image)
        return image_name, ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        return len(self.lr_filenames)



In [None]:
''' Generator Loss'''

#import torch
#from torch import nn
from torchvision.models.vgg import vgg16
class TVLoss(nn.Module):
    """Total-variation regulariser.

    Sums squared differences between vertically and horizontally adjacent pixels,
    normalises each sum by the per-sample element count of the shifted slice,
    and scales by ``2 * tv_loss_weight / batch_size``. Indexing assumes a 4-D input
    whose dims 2 and 3 are height and width.
    """

    def __init__(self, tv_loss_weight=1):
        super().__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        n_samples = x.size()[0]
        below = x[:, :, 1:, :]
        right = x[:, :, :, 1:]
        vertical_sq = (below - x[:, :, :-1, :]).pow(2).sum()
        horizontal_sq = (right - x[:, :, :, :-1]).pow(2).sum()
        norm_h = self.tensor_size(below)
        norm_w = self.tensor_size(right)
        return self.tv_loss_weight * 2 * (vertical_sq / norm_h + horizontal_sq / norm_w) / n_samples

    @staticmethod
    def tensor_size(t):
        """Number of elements per sample, computed from dims 1..3."""
        dims = t.size()
        return dims[1] * dims[2] * dims[3]
class GeneratorLoss(nn.Module):
    """SRGAN generator objective.

    Weighted sum of:
      * pixel MSE between generated and target images (weight 1),
      * adversarial term ``mean(1 - out_labels)`` (weight 0.001),
      * perceptual MSE between frozen VGG16 features (``vgg.features[:31]``) (weight 0.006),
      * total-variation penalty on the generated images (weight 2e-8).

    The VGG feature extractor is a registered sub-module. It therefore moves with the
    criterion when the caller uses ``.cuda()`` / ``.to(device)``. No device is forced
    here, so construction also works on CPU-only machines.
    """

    def __init__(self):
        super().__init__()
        vgg = vgg16(pretrained=True)
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        # Frozen: the perceptual network is only a fixed feature extractor.
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        """Return the scalar generator loss.

        Args:
            out_labels: discriminator output(s) for the generated images.
            out_images: generator output.
            target_images: ground-truth HR images, same shape as ``out_images``.
        """
        adversarial_loss = torch.mean(1 - out_labels)
        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        image_loss = self.mse_loss(out_images, target_images)
        tv_loss = self.tv_loss(out_images)
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss



    
    


In [None]:
''' Models '''
#import torch
import math
#from torch import nn

class Generator(nn.Module):
    """SRGAN-style generator producing an image ``scale_factor`` times larger than its input.

    Layout:
      * block1: 9x9 conv (3 -> 64) + PReLU
      * block2-6: residual blocks
      * block7: 3x3 conv + PReLU
      * block8: ``log2(scale_factor)`` x2 upsampling blocks, then a 9x9 conv back to 3 channels

    Attribute names are kept stable because saved ``state_dict`` checkpoints are keyed by them.

    Args:
        scale_factor: upscaling factor; must be a power of two (1, 2, 4, 8, ...).

    Raises:
        ValueError: if ``scale_factor`` is not a positive power of two.
    """

    def __init__(self, scale_factor):
        # Each UpsampleBLock doubles height and width, so only powers of two are
        # reachable. A truncating log would silently turn e.g. 3 into a 2x model.
        if scale_factor < 1:
            raise ValueError('scale_factor must be a power of two >= 1, got %r' % (scale_factor,))
        upsample_block_num = int(math.log2(scale_factor))
        if 2 ** upsample_block_num != scale_factor:
            raise ValueError('scale_factor must be a power of two >= 1, got %r' % (scale_factor,))

        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.PReLU()
        )
        block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        """Return the super-resolved image, mapped into [0, 1] by ``(tanh + 1) / 2``."""
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        # Long skip connection: head features are added to the residual-trunk output.
        block8 = self.block8(block1 + block7)

        return (torch.tanh(block8) + 1) / 2

class ResidualBlock(nn.Module):
    """Conv-BN-PReLU-Conv-BN branch whose output is added back onto the input.

    Spatial size and channel count are preserved (3x3 convs with padding 1).
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        # Creation order (conv1, bn1, prelu, conv2, bn2) fixes parameter
        # initialisation order and the state_dict keys.
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        branch = x
        for layer in (self.conv1, self.bn1, self.prelu, self.conv2, self.bn2):
            branch = layer(branch)
        return x + branch


class UpsampleBLock(nn.Module):
    """Multiply height and width by ``up_scale`` with conv + PixelShuffle + PReLU.

    The conv expands channels by ``up_scale ** 2`` and PixelShuffle folds them back
    into space, so the output has ``in_channels`` channels.
    """

    def __init__(self, in_channels, up_scale):
        super().__init__()
        expanded_channels = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        return self.prelu(self.pixel_shuffle(self.conv(x)))

class Discriminator(nn.Module):
    """SRGAN discriminator: conv feature stack, global average pool, then 1x1-conv head.

    ``forward`` returns one sigmoid probability per sample (shape ``(batch,)``).
    """

    # (in_channels, out_channels, stride) for each Conv-BN-LeakyReLU stage that
    # follows the first, BN-free conv layer.
    _STAGES = (
        (64, 64, 2),
        (64, 128, 1),
        (128, 128, 2),
        (128, 256, 1),
        (256, 256, 2),
        (256, 512, 1),
        (512, 512, 2),
    )

    def __init__(self):
        super().__init__()
        layers = [nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.LeakyReLU(0.2)]
        for c_in, c_out, stride in self._STAGES:
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ])
        layers.extend([
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
        ])
        # A single Sequential keeps the state_dict keys (net.0, net.1, ...) unchanged.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x).view(x.size(0))
        return torch.sigmoid(logits)


In [None]:
def _validate_epoch(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Returns ``(psnr, ssim, val_images)``:
      * ``psnr`` and ``ssim`` are sample-weighted averages over the whole loader;
        the PSNR peak is the last batch's maximum HR value.
      * ``val_images`` holds, per sample, the bicubic baseline, HR target and SR
        output after ``display_transform()``.
    """
    model.eval()
    totals = {'mse': 0.0, 'ssims': 0.0, 'batch_sizes': 0}
    psnr = ssim = 0.0
    val_images = []
    with torch.no_grad():
        val_bar = tqdm(val_loader)
        for val_lr, val_hr_restore, val_hr in val_bar:
            batch_size = val_lr.size(0)
            totals['batch_sizes'] += batch_size
            lr = val_lr.to(device)
            hr = val_hr.to(device)
            sr = model(lr)

            batch_mse = ((sr - hr) ** 2).mean().item()
            totals['mse'] += batch_mse * batch_size
            totals['ssims'] += pytorch_ssim.ssim(sr, hr).item() * batch_size
            mean_mse = totals['mse'] / totals['batch_sizes']
            peak_sq = (hr.max() ** 2).item()
            psnr = 10 * log10(peak_sq / mean_mse) if mean_mse > 0 else float('inf')
            ssim = totals['ssims'] / totals['batch_sizes']
            val_bar.set_description(
                desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (psnr, ssim))

            # squeeze(0) assumes the validation loader uses batch_size=1.
            val_images.extend([
                display_transform()(val_hr_restore.squeeze(0)),
                display_transform()(hr.cpu().squeeze(0)),
                display_transform()(sr.cpu().squeeze(0)),
            ])
    return psnr, ssim, val_images


def _save_val_images(val_images, out_path, epoch):
    """Save validation triplets as 3-column grids, split into ``len // 15`` chunks (at least one)."""
    if not val_images:
        return
    stacked = torch.stack(val_images)
    # max(1, ...) keeps torch.chunk valid when there are fewer than 15 images.
    chunks = torch.chunk(stacked, max(1, stacked.size(0) // 15))
    for index, image in enumerate(tqdm(chunks, desc='[saving training results]'), start=1):
        grid = utils.make_grid(image, nrow=3, padding=5)
        utils.save_image(grid, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)


def train(batch, epochs, train_loader, val_loader, upscale_factor=2):
    """Adversarially train the notebook-global ``netG``/``netD`` and validate after each epoch.

    Depends on names defined in other cells: ``netG``, ``netD``, ``optimizerG``,
    ``optimizerD`` and ``generator_criterion``. It also uses ``pytorch_ssim``,
    ``display_transform``, ``utils``, ``tqdm``, ``pd`` and ``log10``.

    Outputs written per epoch:
      * validation grids to ``training_results/SRF_<upscale_factor>/``;
      * ``netG``/``netD`` checkpoints to ``epochs/``;
      * every 10th epoch and after the last epoch, a metrics CSV to ``statistics/``.

    Args:
        batch: nominal training batch size. Kept for backward compatibility; running
            statistics are weighted by each batch's actual size (the last one may be smaller).
        epochs: number of epochs; epochs are numbered 1..epochs inclusive.
        train_loader: yields ``(lr, hr)`` tensor pairs.
        val_loader: yields ``(lr, hr_restore, hr)`` triples (batch_size=1 expected).
        upscale_factor: used to name the output directory and files.

    Returns:
        dict of per-epoch lists: ``d_loss``, ``g_loss``, ``d_score``, ``g_score``, ``psnr``, ``ssim``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}
    image_dir = 'training_results/SRF_' + str(upscale_factor) + '/'
    for directory in (image_dir, 'epochs', 'statistics'):
        os.makedirs(directory, exist_ok=True)

    for epoch in range(1, epochs + 1):
        running_results = {'batch_sizes': 0, 'd_loss': 0.0, 'g_loss': 0.0, 'd_score': 0.0, 'g_score': 0.0}
        netG.train()
        netD.train()
        train_bar = tqdm(train_loader)
        for data, target in train_bar:
            batch_size = data.size(0)
            running_results['batch_sizes'] += batch_size
            real_img = target.to(device, dtype=torch.float32)
            z = data.to(device, dtype=torch.float32)

            # (1) Discriminator step: minimise 1 - D(x) + D(G(z)).
            # fake_img is detached so this backward pass stops at netD.
            fake_img = netG(z)
            netD.zero_grad()
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img.detach()).mean()
            d_loss = 1 - real_out + fake_out
            d_loss.backward()
            optimizerD.step()

            # (2) Generator step against the just-updated discriminator, reusing the
            # generator graph built above rather than running netG again.
            netG.zero_grad()
            fake_out = netD(fake_img).mean()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()
            optimizerG.step()

            # Sample-weighted sums; dividing by the samples seen gives true means.
            running_results['g_loss'] += g_loss.item() * batch_size
            running_results['d_loss'] += d_loss.item() * batch_size
            running_results['d_score'] += real_out.item() * batch_size
            running_results['g_score'] += fake_out.item() * batch_size

            seen = running_results['batch_sizes']
            train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, epochs, running_results['d_loss'] / seen,
                running_results['g_loss'] / seen,
                running_results['d_score'] / seen,
                running_results['g_score'] / seen))

        psnr, ssim, val_images = _validate_epoch(netG, val_loader, device)
        _save_val_images(val_images, image_dir, epoch)

        torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (upscale_factor, epoch))
        torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (upscale_factor, epoch))

        seen = max(running_results['batch_sizes'], 1)  # guards an empty train_loader
        results['d_loss'].append(running_results['d_loss'] / seen)
        results['g_loss'].append(running_results['g_loss'] / seen)
        results['d_score'].append(running_results['d_score'] / seen)
        results['g_score'].append(running_results['g_score'] / seen)
        results['psnr'].append(psnr)
        results['ssim'].append(ssim)

        if epoch % 10 == 0 or epoch == epochs:
            data_frame = pd.DataFrame(
                data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                      'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(1, epoch + 1))
            data_frame.to_csv('statistics/srf_' + str(upscale_factor) + '_train_results.csv', index_label='Epoch')

    return results

In [None]:
'''Download images'''

!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip

In [None]:
!unzip /content/DIV2K_train_HR.zip
!unzip /content/DIV2K_valid_HR.zip

In [None]:
'''using Dataloader'''
from torch.utils.data import DataLoader
# Super-resolution factor and HR training-patch size shared by the datasets below.
UPSCALE_FACTOR = 2
CROP_SIZE = 88

train_img = TrainDatasetFromFolder(
    dataset_dir='/content/DIV2K_train_HR',
    crop_size=CROP_SIZE,
    upscale_factor=UPSCALE_FACTOR,
)
val_img = ValDatasetFromFolder(
    dataset_dir='/content/DIV2K_valid_HR',
    upscale_factor=UPSCALE_FACTOR,
)
# Fail fast on a wrong path or an unfinished unzip (printing the dataset object only shows its repr).
assert len(train_img) > 0, 'no training images found in /content/DIV2K_train_HR'
assert len(val_img) > 0, 'no validation images found in /content/DIV2K_valid_HR'
print('train images:', len(train_img), '| val images:', len(val_img))

train_loader = DataLoader(dataset=train_img, num_workers=4, batch_size=32, shuffle=True)
# batch_size=1: each validation item is cropped to its own image's size, so items can
# differ in shape, and train() squeezes the batch dimension when saving grids.
val_loader = DataLoader(dataset=val_img, num_workers=4, batch_size=1, shuffle=False)

<__main__.TrainDatasetFromFolder object at 0x7f1b8cc57750>


In [None]:
# Build the 2x generator first, then the discriminator, and report their sizes.
netG = Generator(2)
netD = Discriminator()
for label, net in (('generator', netG), ('discriminator', netD)):
    print('# %s parameters:' % label, sum(param.numel() for param in net.parameters()))

# generator parameters: 586379
# discriminator parameters: 5215425


In [None]:
# Perceptual + adversarial + pixel + TV criterion for the generator; constructing it
# loads the pretrained VGG16 weights (downloaded on first use, as the log below shows).
generator_criterion = GeneratorLoss()

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:32<00:00, 17.2MB/s]


In [None]:
# Move both networks and the loss (with its frozen VGG feature extractor) to the GPU when present.
if torch.cuda.is_available():
    for module in (netG, netD, generator_criterion):
        module.cuda()

In [None]:
# Separate Adam optimisers for the generator and the discriminator. No lr/betas are
# passed, so torch.optim.Adam's defaults apply.
optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

In [None]:
# Uses netG, netD, optimizerG, optimizerD and generator_criterion defined in the cells above.
train(batch=32,epochs = 50, train_loader = train_loader,val_loader = val_loader)

[1/50] Loss_D: 24.9454 Loss_G: 0.2062 D(x): 10.8354 D(G(z)): 10.5746: 100%|██████████| 25/25 [01:08<00:00,  2.75s/it]
[converting LR images to SR images] PSNR: 22.6442 dB SSIM: 0.6851: 100%|██████████| 100/100 [02:34<00:00,  1.54s/it]
[saving training results]: 100%|██████████| 20/20 [00:18<00:00,  1.10it/s]
[2/50] Loss_D: 25.0826 Loss_G: 0.1799 D(x): 10.6106 D(G(z)): 10.4817: 100%|██████████| 25/25 [01:09<00:00,  2.80s/it]
[converting LR images to SR images] PSNR: 23.3863 dB SSIM: 0.6975: 100%|██████████| 100/100 [02:36<00:00,  1.56s/it]
[saving training results]: 100%|██████████| 20/20 [00:18<00:00,  1.10it/s]
[3/50] Loss_D: 25.1010 Loss_G: 0.1777 D(x): 7.7174 D(G(z)): 7.7860: 100%|██████████| 25/25 [01:09<00:00,  2.77s/it]
[converting LR images to SR images] PSNR: 23.5094 dB SSIM: 0.7046: 100%|██████████| 100/100 [02:34<00:00,  1.55s/it]
[saving training results]: 100%|██████████| 20/20 [00:18<00:00,  1.09it/s]
[4/50] Loss_D: 25.1055 Loss_G: 0.1626 D(x): 7.6415 D(G(z)): 7.7387: 100%

In [None]:
def test(model, UPSCALE_FACTOR, dataset_dir='/content/DIV2K_valid_HR'):
    """Benchmark ``model`` on paired LR/HR images and save side-by-side comparison grids.

    Reads ``dataset_dir/SRF_<UPSCALE_FACTOR>/data`` (LR) and ``.../target`` (HR) via
    TestDatasetFromFolder, and writes grids (bicubic | HR | SR) to
    ``benchmark_results/SRF_<UPSCALE_FACTOR>/``.
    NOTE(review): the default folder is unpacked from DIV2K_valid_HR.zip and is used
    elsewhere as a flat folder of HR images. Confirm it has the SRF_* layout.

    Returns:
        dict mapping each image-name prefix (text before the first '_') to
        ``{'psnr': [...], 'ssim': [...]}``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.eval().to(device)
    test_set = TestDatasetFromFolder(dataset_dir, upscale_factor=UPSCALE_FACTOR)
    test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)
    test_bar = tqdm(test_loader, desc='[testing benchmark datasets]')

    out_path = 'benchmark_results/SRF_' + str(UPSCALE_FACTOR) + '/'
    os.makedirs(out_path, exist_ok=True)

    results = {}
    # no_grad replaces the deprecated Variable(..., volatile=True), which has no effect.
    with torch.no_grad():
        for image_name, lr_image, hr_restore_img, hr_image in test_bar:
            image_name = image_name[0]
            lr_image = lr_image.to(device)
            hr_image = hr_image.to(device)

            sr_image = model(lr_image)
            mse = ((hr_image - sr_image) ** 2).mean().item()
            psnr = 10 * log10(1 / mse) if mse > 0 else float('inf')
            # .item(): indexing a 0-d tensor with [0] raises IndexError.
            ssim = pytorch_ssim.ssim(sr_image, hr_image).item()

            test_images = torch.stack([
                display_transform()(hr_restore_img.squeeze(0)),
                display_transform()(hr_image.cpu().squeeze(0)),
                display_transform()(sr_image.cpu().squeeze(0)),
            ])
            image = utils.make_grid(test_images, nrow=3, padding=5)
            utils.save_image(image, out_path + image_name.split('.')[0] + '_psnr_%.4f_ssim_%.4f.' % (psnr, ssim) +
                             image_name.split('.')[-1], padding=5)

            # Accumulate psnr/ssim per name prefix.
            scores = results.setdefault(image_name.split('_')[0], {'psnr': [], 'ssim': []})
            scores['psnr'].append(psnr)
            scores['ssim'].append(ssim)
    return results

In [None]:
# Fresh 2x generator in eval mode; trained weights are loaded separately via load_state_dict.
model = Generator(2).eval()

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
# map_location lets a checkpoint saved from GPU tensors load on a CPU-only runtime too.
# NOTE(review): the filename has a space before '.pth'; confirm it matches the uploaded file.
model.load_state_dict(torch.load('/content/netG_epoch_2_49 .pth', map_location=device))

In [None]:
# convert('RGB') guarantees 3 channels even for grayscale/RGBA inputs.
image = Image.open('/content/input/pic11.jpg').convert('RGB')
# Variable(..., volatile=True) has no effect in modern PyTorch; a plain tensor suffices.
image = ToTensor()(image).unsqueeze(0)
# Match the model's device (GPU only when available) instead of always calling .cuda().
if torch.cuda.is_available():
    image = image.cuda()

  


In [None]:
# Release unoccupied GPU memory cached by PyTorch's allocator before running inference.
torch.cuda.empty_cache()

In [None]:
# Create the output folder so the save below cannot fail on a fresh runtime.
os.makedirs('/content/result', exist_ok=True)
with torch.no_grad():
    out = model(image)
    # Generator output lies in [0, 1], the float range ToPILImage expects.
    out_img = ToPILImage()(out[0].cpu())
    out_img.save('/content/result/pic11_' + str(2) + '.jpeg')

' what to do\n train for 2X\n complete test benchmark\n update documentation and individual files\n update github\n '