In [1]:
from os import listdir
from os.path import join
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose , RandomCrop , ToTensor , ToPILImage ,CenterCrop ,Resize
import torch
torch.autograd.set_detect_anomaly(True)

def is_image_file (filename):
    return any(filename.endswith(extension) for extension in ['.png' , '.jpg' , '.jpeg'])

def calculate_valid_crop_size(crop_size , upscale_factor):
    return crop_size - (crop_size % upscale_factor)

def train_hr_transform(crop_size):
    return Compose([
        RandomCrop(crop_size),
        ToTensor(),
    ])

def train_lr_transform(crop_size , upscale_factor):
    return Compose([
        ToPILImage(),
        Resize(crop_size // upscale_factor , interpolation=Image.BICUBIC),
        ToTensor()
    ])

def display_transform():
    return Compose([
        ToPILImage(),
        Resize(400),
        CenterCrop(400),
        ToTensor()
    ])

class TrainDatasetFromFolder(Dataset):
    def __init__(self , dataset_dir , crop_size , upscale_factor):
        super(TrainDatasetFromFolder , self).__init__()
        self.image_filenames = [join(dataset_dir , x) for x in listdir(dataset_dir) if is_image_file(x)]
        crop_size = calculate_valid_crop_size(crop_size , upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size , upscale_factor)

    def __getitem__(self, index):
        hr_image = self.hr_transform(Image.open(self.image_filenames[index]))
        lr_image = self.lr_transform(hr_image)
        return lr_image , hr_image
    
    def __len__(self):
        return len(self.image_filenames)

class ValDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, upscale_factor):
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]

    def __getitem__(self, index):
        hr_image = Image.open(self.image_filenames[index])
        w, h = hr_image.size
        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        hr_restore_img = hr_scale(lr_image)
        return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)
    
    def __len__(self):
        return len(self.image_filenames)


In [2]:
import torch
import math
import torch.nn as nn
torch.autograd.set_detect_anomaly(True)


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(
            channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        return x + residual


class UpsampleBlock(nn.Module):
    def __init__(self, in_channels, up_scale):
        super(UpsampleBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels *
                              (up_scale ** 2), kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self,  x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x


class Generator(nn.Module):
    def __init__(self, scale_factor) -> None:
        upsample_block_num = int(math.log(scale_factor, 2))
        super(Generator, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )

        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        block8 = [UpsampleBlock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        block8 = self.block8(block1 + block7)

        return (torch.tanh(block8) + 1) / 2


class Discriminator(nn.Module):
    def __init__(self) -> None:
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1)
        )

    def forward(self, x):
        batch_size = x.size(0)
        x = self.net(x)
        x = x.view(batch_size, -1)  # Reshape x without modifying it in-place
        x = torch.sigmoid(x)
        return x




In [3]:
import torch
from torch import nn
from torchvision.models.vgg import vgg16
torch.autograd.set_detect_anomaly(True)


class GeneratorLoss(nn.Module):
    def __init__(self):
        super(GeneratorLoss, self).__init__()
        vgg = vgg16(pretrained=True)
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        # Adversarial Loss
        # we want real_out to be close 1, and fake_out to be close 0
        adversarial_loss = torch.mean(1 - out_labels)
        # Perception Loss
        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        # Image Loss
        image_loss = self.mse_loss(out_images, target_images)
        # TV Loss 
        tv_loss = self.tv_loss(out_images)
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss


class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        return self.tv_loss_weight * 0.5 * (
            torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]).mean() + 
            torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]).mean()
        )
    

In [4]:
import argparse
import torch
import time 
from PIL import Image 
import os
import pandas as pd
import torch.optim as optim
import torch.utils.data 
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
import pytorch_ssim
from data_utils import TrainDatasetFromFolder , ValDatasetFromFolder , display_transform
from srgan_loss import GeneratorLoss
from srgan import Generator , Discriminator
from math import log10
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1a7ec87d780>

In [5]:
parser = argparse.ArgumentParser(description='Train Super Resolution Models')
parser.add_argument('--crop_size' , default=88 , type=int , help='training images crop size')
parser.add_argument('--upsacale_factor' , default= 4 , type=int, choices=[2,4,8] , help="Super Resolution Upscale Factor")
parser.add_argument("--num_epochs" , default=100 , type= int , help='Train epoch number')

_StoreAction(option_strings=['--num_epochs'], dest='num_epochs', nargs=None, const=None, default=100, type=<class 'int'>, choices=None, required=False, help='Train epoch number', metavar=None)

In [10]:
import argparse

# Define a list of command-line arguments
args = ['--crop_size', '88', '--upsacale_factor', '4', '--num_epochs', '100']

# Create an ArgumentParser and parse the arguments
parser = argparse.ArgumentParser(description='Train Super Resolution Models')
parser.add_argument('--crop_size', default=88, type=int, help='training images crop size')
parser.add_argument('--upsacale_factor', default=4, type=int, choices=[2, 4, 8], help="Super Resolution Upscale Factor")
parser.add_argument("--num_epochs", default=100, type=int, help='Train epoch number')

opt = parser.parse_args(args)

# Now you can access the parsed arguments as opt.crop_size, opt.upsacale_factor, etc.
print("Crop Size:", opt.crop_size)
print("Upscale Factor:", opt.upsacale_factor)
print("Number of Epochs:", opt.num_epochs)



CROP_SIZE = opt.crop_size
UPSCALE_FACTOR = opt.upsacale_factor
NUM_EPOCHS = opt.num_epochs

train_set = TrainDatasetFromFolder('datasets/div2k/train' , crop_size=CROP_SIZE , upscale_factor=UPSCALE_FACTOR)
val_set = ValDatasetFromFolder('datasets/div2k/valid' , upscale_factor=UPSCALE_FACTOR)
train_loader = DataLoader(dataset=train_set , batch_size=64 , shuffle=True , num_workers=4)
val_loader = DataLoader(dataset=val_set , num_workers=4 , batch_size=1 , shuffle=False)

Crop Size: 88
Upscale Factor: 4
Number of Epochs: 100


In [11]:

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

# generator parameters: 733579
# discriminator parameters: 5215425


In [12]:
generator_criterion = GeneratorLoss()
if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()



In [13]:
optimizerG = optim.Adam(netG.parameters())    
optimizerD = optim.Adam(netD.parameters())

In [14]:
 results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [] , 'ssim': []}

In [17]:
for epoch in range( 1 , NUM_EPOCHS + 1):
        train_bar = tqdm(train_loader)
        running_results = {'batch_sizes': 0 , 'd_loss': 0 , 'g_loss': 0 , 'd_score': 0 , 'g_score': 0}
        netG.train()
        netD.train()
        for data , target in train_bar:
            g_update_first = True
            batch_size = data.size(0)
            running_results['batch_sizes'] += batch_size
            # Update D network
            # We want our real to be close to 1 and fake to be close to 0

            real_img = Variable(target)
            if torch.cuda.is_available():
                real_img = real_img.cuda()
            z = Variable(data)
            if torch.cuda.is_available():
                z = z.cuda()
            fake_img = netG(z)

           # Update D network
            netD.zero_grad()
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()
            d_loss = 1 - real_out + fake_out
            d_loss.backward(retain_graph=True)  # Specify retain_graph=True

            # Update G network 
            netG.zero_grad()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()

            optimizerG.step()
            optimizerD.step()  # You can move this line here to update the discriminator after computing both losses

            # loss for current batch before optimization
            running_results['g_loss'] += g_loss.item() * batch_size
            running_results['d_loss'] += d_loss.item() * batch_size
            running_results['d_score'] += real_out.item() * batch_size
            running_results['g_score'] += fake_out.item() * batch_size


            train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'], 
                running_results['g_loss'] / running_results['batch_sizes'], 
                running_results['d_score'] / running_results['batch_sizes'], 
                running_results['g_score'] / running_results['batch_sizes']
            ))

        netG.eval()
        out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
        if not os.path.exists(out_path):
            os.makedirs(out_path)

        with torch.no_grad():
            val_bar = tqdm(val_loader)
            valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
            val_images = []
            for val_lr, val_hr_restore, val_hr in val_bar:
                batch_size = val_lr.size(0)
                valing_results['batch_sizes'] += batch_size
                lr = val_lr
                hr = val_hr
                if torch.cuda.is_available():
                    lr = lr.cuda()
                    hr = hr.cuda()
                sr = netG(lr)

                batch_mse = ((sr - hr) ** 2).data.mean()
                valing_results['mse'] += batch_mse * batch_size
                batch_ssim = pytorch_ssim.ssim(sr, hr).item()
                valing_results['ssims'] += batch_ssim * batch_size
                valing_results['psnr'] = 10 * log10((hr.max()**2) / (valing_results['mse'] / valing_results['batch_sizes']))
                valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
                val_bar.set_description(
                    desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                        valing_results['psnr'], valing_results['ssim']
                    )
                )

                val_images.extend(
                    [display_transform()(val_hr_restore.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)), 
                     display_transform()(sr.data.cpu().squeeze(0))]
                )
            val_images = torch.stack(val_images)
            val_images = torch.chunk(val_images, val_images.size(0) // 15)
            val_save_bar = tqdm(val_images, desc='[saving training results]')
            index = 1
            for image in val_save_bar:
                image = utils.make_grid(image, nrow=3, padding=5)
                utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
                index += 1
        
        # save model parameters
        torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
        torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
        # save loss\scores\psnr\ssim
        results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
        results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
        results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
        results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
        results['psnr'].append(valing_results['psnr'])
        results['ssim'].append(valing_results['ssim'])

        if epoch % 10 == 0 and epoch != 0:
            out_path = "statistics/"
            data_frame = pd.DataFrame(
                data = {'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'], 
                        'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim'], }, 
                        index=range(1, epoch+1)
            )
            data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')








[1/100] Loss_D: 0.5155 Loss_G: 0.0450 D(x): 0.6553 D(G(z)): 0.1708: 100%|██████████████| 13/13 [03:24<00:00, 15.76s/it]
[converting LR images to SR images] PSNR: 12.6945 dB SSIM: 0.3827: 100%|█████████████| 100/100 [29:16<00:00, 17.57s/it]
[saving training results]: 100%|███████████████████████████████████████████████████████| 20/20 [00:30<00:00,  1.54s/it]


RuntimeError: Parent directory epochs does not exist.