In [None]:
!pip install -U albumentations
!pip install lpips
import torch
import torch.nn as nn
from torch import autograd
from torchvision.models import vgg19, VGG19_Weights
import lpips
from torchvision import transforms
import os
import cv2
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from enum import Enum
from PIL import Image
import torch.nn.functional as F
import time
import sys
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
Device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# torch.cuda.init() raises when no usable CUDA runtime is present, which would
# defeat the CPU fallback above, so only initialise CUDA when it is available.
if torch.cuda.is_available():
    torch.cuda.init()

In [None]:
class convelutional_block(nn.Module):
    """A single 2-D convolution followed by a PReLU activation.

    With the defaults (3 -> 64 channels, 9x9 kernel, padding 4, stride 1) the
    spatial size of the input is preserved.
    """

    def __init__(self, in_channels = 3, out_channels = 64,kernel_size = 9 ,padding = 4, stride = 1):
        super().__init__()
        self.convo1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.PR = nn.PReLU()

    def forward(self, x):
        """Return PReLU(conv(x))."""
        return self.PR(self.convo1(x))

class Resedual_connection(nn.Module):
    """Residual block: conv -> BN -> PReLU -> conv -> BN, plus the block input.

    Both convolutions keep the channel count, so the skip connection can be
    added directly to the result.
    """

    def __init__(self,kernel_size = 3,channels = 64,padding = 1, stride = 1):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv1 = nn.Conv2d(channels, channels, **conv_options)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, **conv_options)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Run the two conv stages and add the untouched input back on."""
        out = self.prelu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # In-place add onto the freshly computed tensor; `x` itself is not modified.
        out += x
        return out

class Post_residual_convolution(nn.Module):
    """Channel-preserving convolution followed by batch normalisation."""

    def __init__(self, in_channels = 64,  kernel_size = 3, padding = 1, stride = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn1 = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        """Return BN(conv(x))."""
        return self.bn1(self.conv1(x))

class Upsample_block(nn.Module):
    """Convolution -> PixelShuffle -> PReLU upsampling stage.

    PixelShuffle divides the channel count by `upscale_factor ** 2`, so the
    defaults (64 -> 256 channels, factor 2) yield 64 channels at twice the
    spatial resolution.
    """

    def __init__(self, in_channels = 64, out_channels = 256, kernel_size = 3, padding = 1, stride = 1,upscale_factor = 2):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Expand channels, rearrange them into space, then activate."""
        expanded = self.conv1(x)
        shuffled = self.pixel_shuffle(expanded)
        return self.prelu(shuffled)

class Genarator(nn.Module):
    """Super-resolution generator.

    Pipeline: 9x9 conv + PReLU -> `num_res_blocks` residual blocks -> 3x3 conv + BN
    (added to the output of the first conv) -> 2x pixel-shuffle stages ->
    9x9 conv to 3 channels -> tanh.

    Args:
        num_res_blocks: number of `Resedual_connection` blocks in the trunk.
        upsampling_factor: total spatial upscaling factor. Must be a positive
            power of two; one 2x `Upsample_block` is created per factor of two
            (the default of 4 gives two stages).

    Raises:
        ValueError: if `upsampling_factor` is not a positive power-of-two integer.
    """
    # Channel widths and kernel sizes are deliberately not exposed here; to change
    # them, edit the defaults on the individual block classes.
    def __init__(self,num_res_blocks = 16, upsampling_factor = 4):
        super(Genarator, self).__init__()
        is_power_of_two = (
            isinstance(upsampling_factor, int)
            and upsampling_factor >= 1
            and upsampling_factor & (upsampling_factor - 1) == 0
        )
        if not is_power_of_two:
            raise ValueError(
                f"upsampling_factor must be a positive power of two, got {upsampling_factor!r}"
            )
        self.upsampling_factor = upsampling_factor
        self.conv1 = convelutional_block()
        reasuidual_layer_list =[Resedual_connection() for _ in range(num_res_blocks)]
        self.residual_layers = nn.Sequential(*reasuidual_layer_list)
        self.post_residual_convolution = Post_residual_convolution()
        # Previously two 2x stages were always built regardless of the requested
        # factor; derive the stage count from it instead (4 -> 2 stages, as before).
        num_upsample_stages = upsampling_factor.bit_length() - 1
        self.upsampling = nn.Sequential(
            *[Upsample_block(upscale_factor=2) for _ in range(num_upsample_stages)]
        )
        self.final_layer = nn.Conv2d(64, 3, 9, 1, 4)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Upscale `x` and return an image squashed to [-1, 1] by tanh."""
        x = self.conv1(x)
        # Long skip connection around the whole residual trunk.
        residual = x
        x = self.residual_layers(x)
        x = self.post_residual_convolution(x)
        x += residual
        x = self.upsampling(x)
        x = self.final_layer(x)
        x = self.tanh(x)
        return x

In [None]:
class Critic(nn.Module):
    """Convolutional critic that maps an RGB image to a single unbounded score.

    The feature extractor halves the spatial size four times (stride-2 blocks),
    so a 96x96 input reaches the classifier as a 512x6x6 map. Inputs of any
    other size are adaptively average-pooled to 6x6 first, so the fixed-size
    linear head still applies.
    """

    def __init__(self):
        super(Critic, self).__init__()

        self.features = nn.Sequential(

            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            self._block(64, 64, stride=2),
            self._block(64, 128, stride=1),
            self._block(128, 128, stride=2),
            self._block(128, 256, stride=1),
            self._block(256, 256, stride=2),
            self._block(256, 512, stride=1),
            self._block(512, 512, stride=2)
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1)
        )

    def _block(self, in_channels, out_channels, stride):
        """Conv(3x3) -> InstanceNorm -> LeakyReLU(0.2) building block."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride, 1),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x):
        """Return a (batch, 1) tensor of critic scores for the images in `x`."""
        x = self.features(x)
        # The first Linear layer is sized for a 6x6 feature map (96x96 input).
        # Pooling is done here rather than as a Sequential member so the
        # state_dict keys of existing checkpoints stay valid; it is skipped
        # entirely when the map is already 6x6.
        if tuple(x.shape[-2:]) != (6, 6):
            x = nn.functional.adaptive_avg_pool2d(x, (6, 6))
        return self.classifier(x)

In [None]:
class SRdataset(Dataset):
    """Dataset of paired (low-resolution, high-resolution) images.

    Each entry of `High_resalution_paths` is a directory that is paired with the
    directory at the same position in `Low_resalution_paths`. Inside a directory
    pair, files are matched by position after sorting the file names.

    Args:
        High_resalution_paths: list of directories with high-resolution images.
        Low_resalution_paths: list of directories with the low-resolution versions.
        Scale: size ratio between the high- and low-resolution images.
        crop_size: side length (in HR pixels) of the random training crop.
        is_traning: when True, return aligned random crops with a random
            horizontal flip; when False, return the full images.

    Items are returned as `(low_res_tensor, high_res_tensor)`.
    """

    def __init__(self,High_resalution_paths,Low_resalution_paths,Scale = 4,crop_size = 96,is_traning =True):
        super(SRdataset,self).__init__()

        self.High_paths = High_resalution_paths
        self.Low_paths = Low_resalution_paths
        self.scale = Scale
        self.crop_size = crop_size
        self.is_traning = is_traning
        self.path_and_image_pairs = self.create_image_path_pairs()
        # HR targets use mean 0.5 / std 0.5 and LR inputs mean 0 / std 1.
        # NOTE(review): with albumentations' default max_pixel_value of 255 this
        # should give HR in [-1, 1] and LR in [0, 1] — confirm against the
        # installed albumentations version.
        self.High_Resalution_image_transformations = A.Compose([A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ToTensorV2()])
        self.Low_Resalution_image_transformations = A.Compose([A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]), ToTensorV2()])


    def create_image_path_pairs(self):
        """Build the list of ((hr_dir, hr_file), (lr_dir, lr_file)) pairs.

        Raises:
            RuntimeError: if a directory pair holds a different number of files.
        """
        pairs = []

        for high_path,low_path in zip(self.High_paths,self.Low_paths):

            hr_file_names = sorted(os.listdir(high_path))
            lr_file_names = sorted(os.listdir(low_path))

            if len(hr_file_names) != len(lr_file_names):
                raise RuntimeError(f"Mismatched file counts in {high_path} ({len(hr_file_names)}) and {low_path} ({len(lr_file_names)})")

            dir_pairs_with_imges = list(zip(
                                        [(high_path,hr_file_name) for hr_file_name in hr_file_names],
                                        [(low_path,lr_file_name) for lr_file_name in lr_file_names])
                                        )

            pairs.extend(dir_pairs_with_imges)

        return pairs

    def load_image(self,path):
        """Read an image from disk as a float32 RGB array.

        Raises:
            ValueError: if OpenCV cannot read the file.
        """
        img = cv2.imread(path)
        if img is None:
            raise ValueError(f"Could not Load the image of path : {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return  img.astype(np.float32)

    def Transformation(self,high_resalution_image,Low_resalution_image):
        """Turn an HR/LR array pair into `(low_res_tensor, high_res_tensor)`.

        In training mode a random crop is taken. The crop origin is sampled on
        the LR pixel grid and multiplied by `scale` for the HR image, so both
        crops cover exactly the same region. Sampling the HR origin freely and
        floor-dividing it for the LR image would shift the pair by up to
        `scale - 1` HR pixels.

        Raises:
            ValueError: if either image is smaller than the requested crop.
        """
        if not self.is_traning:
            high_resalution_Tensor = self.High_Resalution_image_transformations(image = high_resalution_image)["image"]
            Low_resalution_Tensor = self.Low_Resalution_image_transformations(image = Low_resalution_image)["image"]

            return Low_resalution_Tensor,high_resalution_Tensor

        hr_h, hr_w = high_resalution_image.shape[:2]
        lr_h, lr_w = Low_resalution_image.shape[:2]
        low_res_crop_size = self.crop_size // self.scale

        # Largest LR origin for which both crops stay inside their images.
        max_lr_x = min(lr_w - low_res_crop_size, (hr_w - self.crop_size) // self.scale)
        max_lr_y = min(lr_h - low_res_crop_size, (hr_h - self.crop_size) // self.scale)
        if max_lr_x < 0 or max_lr_y < 0:
            raise ValueError(
                f"Image pair too small for crop_size={self.crop_size}: "
                f"HR is {hr_w}x{hr_h}, LR is {lr_w}x{lr_h}"
            )

        low_res_crop_x = np.random.randint(0, max_lr_x + 1)
        low_res_crop_y = np.random.randint(0, max_lr_y + 1)
        x = low_res_crop_x * self.scale
        y = low_res_crop_y * self.scale

        highres_crop = high_resalution_image[y:y+self.crop_size, x:x+self.crop_size,:]
        lowres_crop = Low_resalution_image[low_res_crop_y:low_res_crop_y+low_res_crop_size, low_res_crop_x:low_res_crop_x+low_res_crop_size,:]

        # Apply the same horizontal flip to both crops; copy() makes the flipped
        # views contiguous arrays again.
        do_flip = np.random.random() < 0.5
        if do_flip:
            highres_crop = np.fliplr(highres_crop).copy()
            lowres_crop = np.fliplr(lowres_crop).copy()

        high_resalution_Tensor = self.High_Resalution_image_transformations(image=highres_crop)['image']
        Low_resalution_Tensor = self.Low_Resalution_image_transformations(image=lowres_crop)['image']

        return Low_resalution_Tensor, high_resalution_Tensor


    def __getitem__(self, idx):
        """Load pair `idx` from disk and return `(low_res, high_res)` tensors."""
        (hr_path,hr_name),(lr_path,lr_name) = self.path_and_image_pairs[idx]

        hr_path = os.path.join(hr_path, hr_name)
        lr_path = os.path.join(lr_path, lr_name)

        hr_image = self.load_image(hr_path)
        lr_image = self.load_image(lr_path)

        return self.Transformation(hr_image, lr_image) #return lower_res , Hig_res

    def __len__(self):
        """Number of image pairs across all directory pairs."""
        return len(self.path_and_image_pairs)

In [None]:
class VGGLoss(nn.Module):
    """Perceptual loss on frozen VGG19 feature maps.

    Both inputs are expected in [-1, 1] (the range of the generator's tanh
    output and of the HR targets produced by `SRdataset`). They are mapped to
    [0, 1], normalised with the ImageNet mean/std stored below and passed
    through the first 36 layers of `vgg19().features`.

    Args:
        device: preferred device for the VGG network. Falls back to the CPU when
            CUDA is requested but not available.
        scale_factor: constant multiplier applied to the feature MSE.
    """

    def __init__(self, device='cuda', scale_factor=1/12.75):
        super(VGGLoss, self).__init__()
        # Honour the `device` argument (it used to be ignored in favour of a
        # notebook global), while still working on CPU-only machines.
        target_device = torch.device(device)
        if target_device.type == 'cuda' and not torch.cuda.is_available():
            target_device = torch.device('cpu')

        vgg = vgg19(weights=VGG19_Weights.DEFAULT).features[:36].eval().to(target_device)
        # The feature extractor is fixed; only the generator should receive gradients.
        for param in vgg.parameters():
            param.requires_grad = False

        self.vgg = vgg
        self.scale_factor = scale_factor

        # Registered as buffers so they follow the module across .to(device) calls.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406], device=target_device).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225], device=target_device).view(1, 3, 1, 1))

    def forward(self, sr, hr):
        """Return the scaled feature-space MSE between `sr` and `hr`."""
        # Bring BOTH images from [-1, 1] to [0, 1]. Previously only `sr` was
        # rescaled, so the two feature maps were computed on different ranges.
        sr = (sr + 1) / 2
        hr = (hr + 1) / 2

        sr = (sr - self.mean) / self.std
        hr = (hr - self.mean) / self.std

        sr_features = self.vgg(sr)
        hr_features = self.vgg(hr)

        H, W = sr_features.shape[-2:]

        # NOTE(review): torch.mean already averages over every element; the extra
        # division by H*W is kept so existing loss weights keep their meaning.
        loss = torch.mean((sr_features - hr_features) ** 2) * self.scale_factor / (H * W)
        return loss

In [None]:
class TrainingPhase(Enum):
    """Training stage; the value doubles as the checkpoint sub-directory name."""
    PRETRAIN = "pretrain"
    SRGAN = "srgan"


class CheckpointHandler:
    """Saves and restores training state under `<primary_path>/<phase>/`.

    Three sub-directories are maintained:
      * `latest/`   - overwritten on every save,
      * `numbered/` - periodic snapshots named `checkpoint_<iteration>.pt`,
      * `best/`     - the checkpoint with the highest PSNR reported so far.

    `iteration` is treated as the 0-based index of the last completed
    iteration, matching how the training loops in this notebook resume from
    `load_checkpoint(...) + 1`.
    """

    # A numbered snapshot is written after every NUMBERED_EVERY completed iterations.
    NUMBERED_EVERY = 50

    def __init__(self,primary_path,phase=TrainingPhase.PRETRAIN):
        self.base_dir = Path(primary_path)
        self.phase = phase

        self.latest_dir = self.base_dir / phase.value / 'latest'
        self.best_dir = self.base_dir / phase.value / 'best'
        self.numbered_dir = self.base_dir / phase.value / 'numbered'

        for dir_path in [self.latest_dir, self.best_dir, self.numbered_dir]:
            dir_path.mkdir(parents=True, exist_ok=True)

        # Best PSNR seen by this handler instance (not persisted across sessions).
        self.best_psnr = 0.0

    @staticmethod
    def _state_or_none(obj):
        """Return `obj.state_dict()`, or None when `obj` was not supplied."""
        return obj.state_dict() if obj is not None else None

    def save_checkpoint(self, generator,g_optimizer=None,g_scheduler = None, discriminator=None,
                        d_optimizer=None, d_scheduler = None,
                       iteration=0, psnr=None, is_best=False):
        """Write the current training state to disk.

        Always refreshes `latest/latest_checkpoint.pt`. Additionally writes a
        numbered snapshot after every `NUMBERED_EVERY` completed iterations and
        updates `best/best_model.pt` when `is_best` is set and `psnr` beats the
        best value seen so far. Discriminator entries are only stored outside
        the PRETRAIN phase.
        """
        checkpoint = {
            'iteration': iteration,
            'generator_state': generator.state_dict(),
            'g_optimizer_state': self._state_or_none(g_optimizer),
            'g_scheduler_state': self._state_or_none(g_scheduler),
        }
        if self.phase != TrainingPhase.PRETRAIN:
            checkpoint.update({
                'discriminator_state': self._state_or_none(discriminator),
                'd_optimizer_state': self._state_or_none(d_optimizer),
                'd_scheduler_state': self._state_or_none(d_scheduler),
            })

        latest_path = self.latest_dir / 'latest_checkpoint.pt'
        torch.save(checkpoint, latest_path)

        # `iteration` is 0-based, so count completed iterations. The old test
        # `iteration % 50 == 0` never fired for callers that save after epochs
        # 9, 19, 29, ...
        if (iteration + 1) % self.NUMBERED_EVERY == 0:
            numbered_path = self.numbered_dir / f'checkpoint_{iteration}.pt'
            torch.save(checkpoint, numbered_path)

        if is_best and psnr is not None and psnr > self.best_psnr:
            self.best_psnr = psnr
            best_path = self.best_dir / 'best_model.pt'
            torch.save(checkpoint, best_path)


    def load_checkpoint(self, generator,g_optimizer=None,g_scheduler = None, discriminator=None,
                        d_optimizer=None,d_scheduler = None,
                       checkpoint_type='latest'):
        """Restore state from a checkpoint and return its stored iteration.

        Args:
            checkpoint_type: 'latest', 'best', or the iteration number of a
                numbered snapshot.

        Returns:
            The `iteration` value stored in the checkpoint, or 0 when the
            checkpoint file does not exist.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if checkpoint_type == 'latest':
            checkpoint_path = self.latest_dir / 'latest_checkpoint.pt'
        elif checkpoint_type == 'best':
            checkpoint_path = self.best_dir / 'best_model.pt'
        else:
            checkpoint_path = self.numbered_dir / f'checkpoint_{checkpoint_type}.pt'

        if not checkpoint_path.exists():
            print(f"No checkpoint found at {checkpoint_path}")
            return 0

        checkpoint = torch.load(checkpoint_path, map_location=device)
        generator.to(device)
        generator.load_state_dict(checkpoint['generator_state'])

        def restore(target, key):
            # A state is stored as None when the object was not passed to
            # save_checkpoint; skip those instead of calling load_state_dict(None).
            state = checkpoint.get(key)
            if target is not None and state is not None:
                target.load_state_dict(state)

        restore(g_scheduler, 'g_scheduler_state')
        restore(g_optimizer, 'g_optimizer_state')

        if self.phase == TrainingPhase.SRGAN:
            if discriminator is not None and checkpoint.get('discriminator_state') is not None:
                discriminator.to(device)
            restore(discriminator, 'discriminator_state')
            restore(d_optimizer, 'd_optimizer_state')
            restore(d_scheduler, 'd_scheduler_state')

        return checkpoint['iteration']

    def clean_old_checkpoints(self, keep_last_n=3):
        """Delete all but the `keep_last_n` most recent numbered snapshots.

        Snapshots are ordered by the integer iteration in their file name, not
        alphabetically (which would put checkpoint_100 before checkpoint_50).
        `keep_last_n <= 0` removes every numbered snapshot.
        """
        def iteration_of(path):
            try:
                return int(path.stem.rsplit('_', 1)[-1])
            except ValueError:
                # Unparseable names sort first, i.e. they are treated as oldest.
                return -1

        checkpoint_files = sorted(
            self.numbered_dir.glob('checkpoint_*.pt'),
            key=lambda path: (iteration_of(path), path.name),
        )
        # An explicit count avoids the `[:-0]` pitfall, which selects nothing.
        num_to_delete = max(len(checkpoint_files) - max(keep_last_n, 0), 0)
        for checkpoint_file in checkpoint_files[:num_to_delete]:
            checkpoint_file.unlink()

    def ischeackpoint(self):
        """Return True when the `latest/` directory exists and is not empty."""
        return self.latest_dir.is_dir() and any(self.latest_dir.iterdir())



In [None]:
class LossCheckpointHandler:
    """Accumulates a loss history and persists it to `<primary_path>/losses/`.

    Values are stored as plain Python floats / lists so the saved file contains
    no tensors or NumPy objects, and loads cleanly with `torch.load`'s
    restricted (weights-only) unpickler.
    """

    def __init__(self, primary_path):
        self.base_dir = Path(primary_path)
        self.loss_dir = self.base_dir / 'losses'
        self.loss_dir.mkdir(parents=True, exist_ok=True)
        self.losses = []

    @staticmethod
    def _to_plain(value):
        """Convert tensors / NumPy values (possibly nested in lists) to Python types."""
        if torch.is_tensor(value):
            return value.detach().cpu().tolist()
        if isinstance(value, (list, tuple)):
            return [LossCheckpointHandler._to_plain(item) for item in value]
        if hasattr(value, 'tolist'):
            # NumPy scalars and arrays expose tolist(), which yields Python numbers.
            return value.tolist()
        return value

    def append_loss(self, new_loss):
        """Append one loss record (scalar, tensor, array or list of those)."""
        self.losses.append(self._to_plain(new_loss))

    def save_checkpoint(self, iteration):
        """Write the full history together with `iteration` to loss_history.pt."""
        checkpoint = {
            'iteration': iteration,
            'loss_history': self.losses
        }
        torch.save(checkpoint, self.loss_dir / 'loss_history.pt')

    def load_checkpoint(self):
        """Reload the saved history.

        Returns:
            `(iteration, losses)`; `(0, [])` when nothing has been saved yet.
            The returned list is a copy, so callers that keep appending to it
            do not also grow this handler's internal history.
        """
        path = self.loss_dir / 'loss_history.pt'
        if path.exists():
            checkpoint = torch.load(path)
            self.losses = checkpoint['loss_history']
            return checkpoint['iteration'], list(self.losses)
        return 0 ,[]

In [None]:
class ImageEvaluationMetrics:
    """Computes MSE, PSNR, SSIM and LPIPS between generated and reference images.

    Args:
        device: device on which images and the metric networks are placed.
    """

    def __init__(self, device):
        self.device = device
        self.mse_criterion = nn.MSELoss().to(device)
        self.lpips_criterion = lpips.LPIPS(net='alex').to(device)

        # ToTensor gives [0, 1]; Normalize(0.5, 0.5) then maps that to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def load_and_prepare_image(self, image_path, normalize=True):
        """Load an image as a (1, 3, H, W) tensor on `self.device`.

        Args:
            image_path: path of the image file.
            normalize: when True (default) the tensor is scaled to [-1, 1];
                when False it is left in [0, 1].
        """
        # The context manager closes the file handle; convert() returns an
        # independent in-memory RGB copy.
        with Image.open(image_path) as handle:
            img = handle.convert('RGB')
        if normalize:
            img_tensor = self.transform(img)
        else:
            img_tensor = transforms.ToTensor()(img)
        return img_tensor.unsqueeze(0).to(self.device)

    def prepare_images(self, sr, hr):
        """Return `(sr_01, hr_01, sr_lpips, hr_lpips)`.

        The `_01` tensors are the inputs mapped from [-1, 1] to [0, 1] and
        clamped; the `_lpips` tensors are those clamped values mapped back to
        [-1, 1].
        """
        sr_01 = torch.clamp((sr + 1) / 2, 0, 1)
        hr_01 = torch.clamp((hr + 1) / 2, 0, 1)

        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        sr_lpips = normalize(sr_01)
        hr_lpips = normalize(hr_01)

        return sr_01, hr_01, sr_lpips, hr_lpips

    def calculate_ssim(self, sr, hr, window_size=11):
        """Mean SSIM of two image batches in [0, 1], using a uniform window.

        Local statistics come from `avg_pool2d` with zero padding, so this is a
        box-filter SSIM.
        """
        C1 = (0.01 * 1) ** 2
        C2 = (0.03 * 1) ** 2

        def local_mean(t):
            return F.avg_pool2d(t, window_size, stride=1, padding=window_size//2)

        mu1 = local_mean(sr)
        mu2 = local_mean(hr)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = local_mean(sr * sr) - mu1_sq
        sigma2_sq = local_mean(hr * hr) - mu2_sq
        sigma12 = local_mean(sr * hr) - mu1_mu2

        ssim = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
        return ssim.mean()

    def evaluate_images(self, generator, lr_path, hr_path):
        """Super-resolve `lr_path` with `generator` and score it against `hr_path`.

        Returns:
            dict with keys 'mse', 'psnr', 'ssim' and 'lpips' (Python floats).
        """
        metrics = {}

        # Remember the caller's mode instead of unconditionally forcing train()
        # afterwards.
        was_training = generator.training
        generator.eval()
        try:
            with torch.no_grad():
                # The generator is trained on LR inputs in [0, 1] (see SRdataset),
                # so the LR image must not be shifted to [-1, 1]; only the HR
                # reference is, to match the generator's tanh output range.
                lr = self.load_and_prepare_image(lr_path, normalize=False)
                hr = self.load_and_prepare_image(hr_path)

                sr = generator(lr)

                # Crop both to the common area in case the HR size is not an
                # exact multiple of the LR size; a no-op when the sizes agree.
                common_h = min(sr.shape[-2], hr.shape[-2])
                common_w = min(sr.shape[-1], hr.shape[-1])
                sr = sr[..., :common_h, :common_w]
                hr = hr[..., :common_h, :common_w]

                sr_01, hr_01, sr_lpips, hr_lpips = self.prepare_images(sr, hr)

                mse = self.mse_criterion(sr_01, hr_01).item()
                # PSNR for a peak value of 1; the epsilon avoids log10(0).
                psnr = -10 * torch.log10(torch.tensor(mse + 1e-8))
                ssim = self.calculate_ssim(sr_01, hr_01)
                lpips_value = self.lpips_criterion(sr_lpips, hr_lpips).mean()

                metrics = {
                    'mse': mse,
                    'psnr': psnr.item(),
                    'ssim': ssim.item(),
                    'lpips': lpips_value.item()
                }
        finally:
            generator.train(was_training)
        return metrics

    def evaluate_directory(self, generator, lr_dir, hr_dir):
        """Average the metrics over all `*.png` pairs in two directories.

        Files are paired by position after sorting.

        Raises:
            RuntimeError: if the directories hold a different number of PNGs.
            ValueError: if no PNG files are found.
        """
        lr_files = sorted(Path(lr_dir).glob('*.png'))
        hr_files = sorted(Path(hr_dir).glob('*.png'))

        if len(lr_files) != len(hr_files):
            raise RuntimeError(f"Mismatched file counts in {lr_dir} ({len(lr_files)}) and {hr_dir} ({len(hr_files)})")
        if not lr_files:
            raise ValueError(f"No .png images found in {lr_dir} and {hr_dir}")

        total_metrics = {'psnr': 0, 'ssim': 0, 'mse': 0, 'lpips': 0}
        for lr_path, hr_path in zip(lr_files, hr_files):
            metrics = self.evaluate_images(generator, lr_path, hr_path)
            for k, v in metrics.items():
                total_metrics[k] += v

        n_samples = len(lr_files)
        avg_metrics = {k: v/n_samples for k, v in total_metrics.items()}
        return avg_metrics

In [None]:
def gradient_penalty(critc,real,fake,device):
    """Gradient penalty term for a critic, evaluated at random interpolates.

    A random point on the line between each real and fake sample is scored by
    the critic; the penalty is the mean squared deviation of the gradient's
    L2 norm (with respect to that point) from 1.

    Args:
        critc: callable mapping a batch to per-sample scores.
        real: batch of real samples, shape (batch, ...).
        fake: batch of generated samples, same shape as `real`.
        device: device on which the interpolation weights are created.

    Returns:
        `(penalty, gradient_norm)`: the scalar penalty and the per-sample
        gradient norms.
    """
    batch_size = real.shape[0]
    # One interpolation weight per sample, broadcast over all remaining dims
    # (works for inputs of any rank, not only 4-D image batches).
    epsilon = torch.rand((batch_size,) + (1,) * (real.dim() - 1), device=device)
    # Detach so the interpolate is a leaf tensor: the penalty then only reaches
    # the critic's parameters and never back-propagates into whatever produced
    # `fake`.
    interpolated_images = (real * epsilon + fake * (1 - epsilon)).detach().requires_grad_(True)
    mixed_scores = critc(interpolated_images)
    gradient = autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,   # the penalty itself must be differentiable
        retain_graph=True,
    )[0]
    gradient = gradient.reshape(batch_size, -1)
    gradient_norm = gradient.norm(2, dim=1)
    penalty = torch.mean((gradient_norm - 1) ** 2)
    return penalty,gradient_norm

In [None]:
def show_model_results(model, image_path):
    """Run `model` on one image, save the result to 'output.png' and display it.

    The image is loaded as RGB in [0, 1], sent to the device the model lives on,
    and the model output is mapped from [-1, 1] back to 8-bit pixel values.
    The model's train/eval mode is restored afterwards.

    Args:
        model: generator network producing outputs in [-1, 1].
        image_path: path of the low-resolution input image.
    """
    # matplotlib is only needed here and is not imported at module level.
    import matplotlib.pyplot as plt

    def deprocess_image(image):
        # [-1, 1] -> [0, 255] uint8
        image = image / 2 + 0.5
        image = np.clip(image * 255, 0, 255)
        return image.astype(np.uint8)

    # Follow the model's own device instead of assuming CUDA is present.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cpu')

    # Force 3 channels so grayscale / RGBA files match the generator's input.
    with Image.open(image_path) as handle:
        img = handle.convert('RGB')

    input_tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(input_tensor)
    finally:
        model.train(was_training)

    output = deprocess_image(output.cpu().squeeze(0).permute(1, 2, 0).numpy())
    Image.fromarray(output).save('output.png', quality=100, subsampling=0)
    plt.figure(figsize=(10, 10))
    plt.imshow(output)
    plt.axis('off')
    plt.show()


In [None]:
# Training data: each HR directory is paired with the LR directory at the same index.
Train_High_resalution_image_paths = [
    "/kaggle/input/flickr2k/Flickr2K/Flickr2K_HR",
    "/kaggle/input/div2k-dataset-for-super-resolution/Dataset/DIV2K_train_HR",
    "/kaggle/input/srgan-faces-4x/Faces_with_4x_downsamples/high_resolution",
    "/kaggle/input/urban100/Urban 100/X4 Urban100/X4/HIGH x4 URban100",
]
Train_Low_resalution_image_paths = [
    "/kaggle/input/flickr2k/Flickr2K/Flickr2K_LR_bicubic/X4",
    "/kaggle/input/div2k-dataset-for-super-resolution/Dataset/DIV2K_train_LR_bicubic_X4/X4",
    "/kaggle/input/srgan-faces-4x/Faces_with_4x_downsamples/low_resolution",
    "/kaggle/input/urban100/Urban 100/X4 Urban100/X4/LOW x4 URban100",
]

Train_dataset = SRdataset(Train_High_resalution_image_paths, Train_Low_resalution_image_paths)
Train_dataloader = DataLoader(
    Train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
)

# Validation data (DIV2K validation split). SRdataset is built with its default
# arguments here, so it also yields random crops.
valid_High_resalution_image_paths = [
    "/kaggle/input/div2k-dataset-for-super-resolution/Dataset/DIV2K_valid_HR",
]
valid_low_resalution_image_paths = [
    "/kaggle/input/div2k-dataset-for-super-resolution/Dataset/DIV2K_valid_LR_bicubic_X4/X4",
]

valid_dataset = SRdataset(valid_High_resalution_image_paths, valid_low_resalution_image_paths)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
)


In [None]:
# --- Pre-training setup: generator only, content losses only ---
learning_rate = 1e-4
Generator_model = Genarator()
Genarator_optimizer = torch.optim.Adam(
    Generator_model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
)
g_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    Genarator_optimizer,
    mode='min',
    factor=0.5,
    patience=5000,
    verbose=True,
    min_lr=1e-6,
)
# Perceptual loss; called as loss_function(prediction, target).
loss_function = VGGLoss()
mseloss = nn.MSELoss()
g_CheckpointHandler = CheckpointHandler("/kaggle/working/", TrainingPhase.PRETRAIN)
LossHandler = LossCheckpointHandler("/kaggle/working/")

In [None]:
def Full_model_gd_norm(model):
    """Return the global L2 norm of the gradients currently stored on `model`.

    Parameters whose `.grad` is None are skipped; the result is 0.0 when no
    parameter has a gradient.
    """
    squared_norms = (
        param.grad.data.norm(2).item() ** 2
        for param in model.parameters()
        if param.grad is not None
    )
    return sum(squared_norms) ** 0.5

In [None]:
# --- Generator pre-training loop (MSE + perceptual loss) ---
current_iteration = 0
max_iterations = 800
Generator_model.to(Device)
loss_function.to(Device)
losses = []
if g_CheckpointHandler.ischeackpoint():
    # Checkpoints store the index of the last finished epoch, so resume one later.
    current_iteration = g_CheckpointHandler.load_checkpoint(Generator_model,Genarator_optimizer,g_scheduler) +1
    loss_itration,losses = LossHandler.load_checkpoint()
    # Work on a private copy: if `losses` were the very list LossHandler keeps
    # internally, the two appends per epoch below would record every loss twice.
    losses = list(losses)
    # for param_group in Genarator_optimizer.param_groups:
    #     param_group['lr'] = 1e-3

for epoc in range(current_iteration,max_iterations):
    batch_loss = []
    t1 = time.time()
    for data ,lables in Train_dataloader:

        data = data.to(Device)
        lables = lables.to(Device)

        predict = Generator_model(data)

        # Pixel-wise MSE plus the (weighted) VGG feature loss.
        loss = mseloss(predict,lables) + 10*loss_function(predict,lables)

        batch_loss.append(loss.item())

        Genarator_optimizer.zero_grad()
        loss.backward()
        Genarator_optimizer.step()
    mean_batch_loss = np.mean(batch_loss)
    # The LR scheduler is currently not stepped:
    # g_scheduler.step(mean_batch_loss)
    losses.append(mean_batch_loss)
    LossHandler.append_loss(mean_batch_loss)
    t2 = time.time()

    # Persist model and loss history every 10 epochs.
    if (epoc + 1) % 10 == 0:
        g_CheckpointHandler.save_checkpoint(Generator_model,Genarator_optimizer,g_scheduler,iteration = epoc)
        LossHandler.save_checkpoint(epoc)

    # The reported gradient norm is that of the last batch of the epoch.
    sys.stdout.write(f"\r{epoc + 1} / {max_iterations} epocs ,time:{t2-t1} ,LR: {Genarator_optimizer.param_groups[0]['lr']} ,Loss : {mean_batch_loss} , grad :{Full_model_gd_norm(Generator_model)}")
    sys.stdout.flush()



In [None]:
# Visual check of the pre-trained generator on one DIV2K validation image.
show_model_results(
    model=Generator_model.to(Device),
    image_path="/kaggle/input/div2k-dataset-for-super-resolution/Dataset/DIV2K_valid_LR_bicubic_X4/X4/0850x4.png",
)

In [None]:
# --- Adversarial (SRGAN) phase setup ---
learning_rate = 1e-4
Generator_model = Genarator()
# Warm-start from the pre-training phase when its checkpoint exists. Creating a
# fresh Genarator() alone would discard the pre-trained weights and start the
# adversarial phase from random initialisation. An SRGAN checkpoint, if present,
# is loaded later by the training cell and takes precedence.
pretrain_CheckpointHandler = CheckpointHandler("/kaggle/working/",TrainingPhase.PRETRAIN)
if pretrain_CheckpointHandler.ischeackpoint():
    pretrain_CheckpointHandler.load_checkpoint(Generator_model)
Critic_model = Critic()
# Adam with beta1 = 0 for both networks.
Genarator_optimizer = torch.optim.Adam(Generator_model.parameters(),lr = learning_rate, betas=(0, 0.9))
Critic_optimizer = torch.optim.Adam(Critic_model.parameters(),lr = learning_rate, betas=(0, 0.9))
g_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(Genarator_optimizer,mode='min',factor=0.5,patience=5000,verbose=True,min_lr=1e-6)
c_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(Critic_optimizer,mode='min',factor=0.5,patience=5000,verbose=True,min_lr=1e-6)
# Perceptual loss; called as loss_function(prediction, target).
loss_function = VGGLoss()
mseLoss = nn.MSELoss()
srgan_CheckpointHandler = CheckpointHandler("/kaggle/working/",TrainingPhase.SRGAN)
LossHandlerSrgan = LossCheckpointHandler("/kaggle/working/SRGANLoss")


In [None]:
# --- Adversarial training loop: critic with gradient penalty, generator with
# --- adversarial + MSE + perceptual loss.
current_iteration = 0
max_iterations = 100
Generator_model.to(Device)
Critic_model.to(Device)
loss_function.to(Device)
mseLoss.to(Device)
# Resume from the latest SRGAN checkpoint if one exists; the stored iteration is
# the last finished epoch, hence the +1.
if srgan_CheckpointHandler.ischeackpoint():
    current_iteration = srgan_CheckpointHandler.load_checkpoint(Generator_model,Genarator_optimizer,g_scheduler,Critic_model,Critic_optimizer,c_scheduler) +1
    loss_itration,losses = LossHandlerSrgan.load_checkpoint()

# Per-epoch loss histories kept in the notebook namespace.
Ganlosses_Generator = []
Ganlosses_Critic = []
Valid_Ganlosses_Generator = []
Valid_Ganlosses_Critic = []
# Number of critic updates performed for every generator update.
ncritic_epoc = 5
for epoc in range(current_iteration,max_iterations):
    Gan_batch_G = []
    Gan_batch_C = []
    critic_real_loss_batch = []
    # Weight of the gradient penalty in the critic loss.
    lamda = 10
    # Separate iterator over the training data that feeds the critic updates.
    critic_data = iter(Train_dataloader)
    t1 = time.time()
    for data ,lables in Train_dataloader:

        critic_nepoc_loss = []
        critic_nepoc_loss_real = []
        for _ in range(ncritic_epoc):
            # Draw the next critic batch, restarting the iterator when exhausted.
            try:
                crtic_data,critc_labels = next(critic_data)

            except StopIteration:
                critic_data = iter(Train_dataloader)
                crtic_data,critc_labels = next(critic_data)

            crtic_data = crtic_data.to(Device)
            critc_labels = critc_labels.to(Device)

            # The generator is not updated in the critic step, so no graph is built for it.
            with torch.no_grad():
                fake = Generator_model(crtic_data)
            fake_score = Critic_model(fake)
            real_score = Critic_model(critc_labels)
            grad_penalty,gradient_norm = gradient_penalty(Critic_model,critc_labels,fake,Device)
            # Critic loss without the penalty: mean(fake score) - mean(real score).
            real_critc_loss = -(torch.mean(real_score) - torch.mean(fake_score))
            critic_loss = real_critc_loss + lamda * grad_penalty
            critic_nepoc_loss_real.append(real_critc_loss.item())
            critic_nepoc_loss.append(critic_loss.item())
            Critic_optimizer.zero_grad()
            critic_loss.backward()
            Critic_optimizer.step()

        data = data.to(Device)
        lables = lables.to(Device)
        Gan_batch_C.append(np.mean(critic_nepoc_loss))
        critic_real_loss_batch.append(np.mean(critic_nepoc_loss_real))
        # Generator step: adversarial term (weight 1e-2) + MSE (2.0) + perceptual loss (10.0).
        fake = Generator_model(data)
        fake_score = Critic_model(fake)
        generator_loss = -1e-2 * torch.mean(fake_score) + 2.0 *mseLoss(fake,lables) + 10.0 * loss_function(fake,lables)
        Genarator_optimizer.zero_grad()
        # This backward pass also writes gradients into the critic's parameters;
        # they are cleared by Critic_optimizer.zero_grad() before the next critic step.
        generator_loss.backward()
        Genarator_optimizer.step()
        Gan_batch_G.append(generator_loss.item())

    # Epoch-level averages of the per-batch values.
    mean_real_critc_loss_full = np.mean(critic_real_loss_batch)
    mean_batch_loss_G = np.mean(Gan_batch_G)
    mean_batch_loss_C = np.mean(Gan_batch_C)
    Ganlosses_Generator.append(mean_batch_loss_G)
    Ganlosses_Critic.append(mean_batch_loss_C)
    LossHandlerSrgan.append_loss([mean_batch_loss_G,mean_batch_loss_C])

    # Validation pass, currently disabled:
    # with torch.no_grad():
    #     valid_Gan_batch_G = []
    #     valid_Gan_batch_C = []
    #     for data ,lables in valid_dataloader:
    #         data = data.to(Device)
    #         lables = lables.to(Device)
    #         fake = Generator_model(data)
    #         fake_score = Critic_model(fake)
    #         real_score = Critic_model(lables)
    #         critic_loss = -(torch.mean(real_score) - torch.mean(fake_score))
    #         fake = Generator_model(data)
    #         fake_score = Critic_model(fake)
    #         generator_loss = -1e-3 * torch.mean(fake_score) + mseLoss(fake,lables) + 10 * loss_function(fake,lables)
    #         valid_Gan_batch_G.append(generator_loss.item())
    #         valid_Gan_batch_C.append(critic_loss.item())
    #     Valid_Ganlosses_Generator.append(np.mean(valid_Gan_batch_G))
    #     Valid_Ganlosses_Critic.append(np.mean(valid_Gan_batch_C))

    t2 = time.time()
    # Persist models, optimizers and loss history every 10 epochs.
    if (epoc + 1) % 10 == 0:
        srgan_CheckpointHandler.save_checkpoint(Generator_model,Genarator_optimizer,g_scheduler,Critic_model,Critic_optimizer,c_scheduler,iteration = epoc)
        LossHandlerSrgan.save_checkpoint(epoc)
    # NOTE(review): `cgrad` is read after the generator's backward pass, so it
    # reflects the critic gradients produced by the generator loss of the last batch.
    sys.stdout.write(f"\r{epoc + 1}/{max_iterations}epocs,time:{t2-t1},GLoss:{mean_batch_loss_G},CLoss:{mean_batch_loss_C},grad:{Full_model_gd_norm(Generator_model)},cgrad:{Full_model_gd_norm(Critic_model)},RCLoss {mean_real_critc_loss_full} ")
    sys.stdout.flush()



In [None]:

# Visual check of the adversarially trained generator on one face image.
show_model_results(
    model=Generator_model.to(Device),
    image_path="/kaggle/input/srgan-faces-4x/Faces_with_4x_downsamples/low_resolution/00403.png",
)