In [None]:
%%capture

# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
from getpass import getpass

username = input("Enter your GitHub username: ")
token = getpass("Enter your GitHub PAT: ")
repo_name = input("Enter the name of the repository: ")
branch_name = input("Enter the name of the branch: ")

!git clone https://Narender{username}:{token}@github.com/{username}/{repo_name}.git

file_name = "GAN-pix2pix.py"
with open(file_name, "w") as f:
    f.write("print('GAN')")

!mv {file_name} .

%cd CLEAR-VISION
!git checkout {branch_name}


!git config --global user.email "formyproject2402@gmail.com"
!git config --global user.name "formyproject2402"
!git add .
!git commit -m "added GAN-pix2pix.py to {branch_name} from Colab"
!git push origin {branch_name}

Enter your GitHub username:  Narender-0
Enter your GitHub PAT:  ········
Enter the name of the repository:  CLEAR-VISION
Enter the name of the branch:  Narender


Cloning into 'CLEAR-VISION'...
remote: Enumerating objects: 20, done.[K
remote: Counting objects: 100% (20/20), done.[K
remote: Compressing objects: 100% (17/17), done.[K
remote: Total 20 (delta 4), reused 6 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (20/20), 10.07 KiB | 2.01 MiB/s, done.
Resolving deltas: 100% (4/4), done.
mv: 'GAN-pix2pix.py' and './GAN-pix2pix.py' are the same file
/kaggle/working/CLEAR-VISION
Branch 'Narender' set up to track remote branch 'Narender' from 'origin'.
Switched to a new branch 'Narender'
On branch Narender
Your branch is up to date with 'origin/Narender'.

nothing to commit, working tree clean
Everything up-to-date


In [None]:
%%capture
!pip install piq

In [None]:
import torch
import torch.nn as nn
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch.optim as optim
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader
import piq
from piq import SSIMLoss, MultiScaleSSIMLoss
from torchmetrics.functional import peak_signal_noise_ratio as psnr
from tqdm import tqdm

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # prefer the GPU when present

In [None]:
class ImageRestorationDataset(Dataset):
    """Paired (corrupted, clean) image dataset matched by filename.

    Every file in ``corrupted_dir`` must have a file with the same name in
    ``clean_dir``. Both images are loaded as RGB and passed through ``transform``
    independently, so the transform must be deterministic (random augmentations
    would de-synchronise the pair).

    Args:
        corrupted_dir: directory of degraded input images.
        clean_dir: directory of ground-truth images with matching filenames.
        transform: optional callable applied to each PIL image.

    Raises:
        FileNotFoundError: if some corrupted image has no clean counterpart.
    """

    def __init__(self, corrupted_dir, clean_dir, transform=None):
        self.corrupted_dir = corrupted_dir
        self.clean_dir = clean_dir
        self.transform = transform

        # Sorted so the index -> file mapping is stable across runs.
        self.filenames = sorted(os.listdir(corrupted_dir))

        # Fail fast here rather than inside a DataLoader worker part-way through an epoch.
        missing = set(self.filenames).difference(os.listdir(clean_dir))
        if missing:
            examples = ", ".join(sorted(missing)[:5])
            raise FileNotFoundError(
                f"{len(missing)} corrupted image(s) have no clean counterpart in "
                f"{clean_dir!r}, e.g. {examples}"
            )

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        name = self.filenames[idx]
        corrupted_image = self._load_rgb(os.path.join(self.corrupted_dir, name))
        clean_image = self._load_rgb(os.path.join(self.clean_dir, name))

        if self.transform:
            corrupted_image = self.transform(corrupted_image)
            clean_image = self.transform(clean_image)

        return corrupted_image, clean_image

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return an RGB copy, and close the file handle deterministically."""
        with Image.open(path) as img:
            return img.convert("RGB")

In [None]:
# Paired training data: each corrupted image has a clean target with the same filename.
corrupted_dir = "/kaggle/input/clearvision-image-dataset/corrupted__images"
clean_dir = "/kaggle/input/clearvision-image-dataset/clean_images"

# ToTensor maps pixels to [0, 1]; Normalize with mean = std = 0.5 then maps them to [-1, 1],
# the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

dataset = ImageRestorationDataset(
    corrupted_dir=corrupted_dir,
    clean_dir=clean_dir,
    transform=transform,
)
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

In [None]:
# Held-out image pairs used for validation and checkpoint selection.
val_corrupted_dir, val_clean_dir = (
    "/kaggle/input/clearvision-image-dataset/val_corrupted_images",
    "/kaggle/input/clearvision-image-dataset/val_clean_images",
)

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers wrapped in an identity skip connection.

    Channel count and spatial size are preserved (stride 1, padding 1), so the
    branch output is added back onto the input before a final ReLU.

    Args:
        channels: number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        layers = [
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.BatchNorm2d(channels),
        ]
        # Attribute names and layer order are kept so saved state_dict keys still load.
        self.block = nn.Sequential(*layers)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = self.block(x)
        return self.relu(residual + x)

In [None]:

class UNetGenerator(nn.Module):
    """U-Net generator with a residual block after every down/up-sampling stage.

    Four stride-2 encoder stages (``down1``..``down4``) halve the resolution each
    time; three transposed-conv decoder stages (``up1``..``up3``) plus ``final``
    restore it. Encoder features are concatenated onto the decoder path (skip
    connections), and ``final`` ends in Tanh, so outputs lie in [-1, 1].
    Input height/width must be divisible by 16 for the skip tensors to align.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        features: width of the first encoder stage; later stages use 2x, 4x, 8x.
    """

    def __init__(self, in_channels=3, out_channels=3, features=64):
        super().__init__()

        # Encoder stages as (in, out, use_batchnorm); the first stage skips BatchNorm.
        encoder_spec = [
            (in_channels, features, False),
            (features, features * 2, True),
            (features * 2, features * 4, True),
            (features * 4, features * 8, True),
        ]
        for stage, (c_in, c_out, with_bn) in enumerate(encoder_spec, start=1):
            setattr(self, f"down{stage}", nn.Sequential(
                self._contract_block(c_in, c_out, use_batchnorm=with_bn),
                ResidualBlock(c_out),
            ))

        # Decoder stages as (in, out); after up1 the input width is doubled by the skip concat.
        decoder_spec = [
            (features * 8, features * 4),
            (features * 8, features * 2),
            (features * 4, features),
        ]
        for stage, (c_in, c_out) in enumerate(decoder_spec, start=1):
            setattr(self, f"up{stage}", nn.Sequential(
                self._expand_block(c_in, c_out),
                ResidualBlock(c_out),
            ))

        self.final = nn.Sequential(
            nn.ConvTranspose2d(features * 2, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def _contract_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, use_batchnorm=True):
        """Conv (downsampling by default) -> optional BatchNorm -> LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        norm = [nn.BatchNorm2d(out_channels)] if use_batchnorm else []
        return nn.Sequential(conv, *norm, nn.LeakyReLU(0.2, inplace=True))

    def _expand_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1):
        """Transposed conv (upsampling by default) -> BatchNorm -> ReLU."""
        stages = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*stages)

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        bottleneck = self.down4(d3)

        y = self.up1(bottleneck)
        for up_stage, skip in ((self.up2, d3), (self.up3, d2)):
            y = up_stage(torch.cat([y, skip], dim=1))

        return self.final(torch.cat([y, d1], dim=1))

In [None]:
class Discriminator(nn.Module):
    """Conditional convolutional discriminator producing a spatial map of logits.

    Takes the channel-wise concatenation of (corrupted, candidate) images and
    returns a 1-channel map of real/fake logits. There is no final sigmoid;
    pair it with ``BCEWithLogitsLoss``.

    Args:
        in_channels: channels of the concatenated input (6 for two RGB images).
        features: output channels of the successive stride-2 conv stages. Any
            sequence works; stage widths need not double.

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(self, in_channels=6, features=(64, 128, 256, 512)):
        super(Discriminator, self).__init__()
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        layers = []

        # Initial convolution block (no BatchNorm on the first layer).
        layers.append(nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.2))

        # Hidden conv blocks: each stage consumes the previous stage's channel count.
        # (Deriving it as feature // 2 only worked when every width doubled.)
        for prev_feature, feature in zip(features, features[1:]):
            layers.append(nn.Conv2d(prev_feature, feature, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(feature))
            layers.append(nn.LeakyReLU(0.2))

        # Final output layer: 1-channel prediction map.
        layers.append(nn.Conv2d(features[-1], 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
# Build both networks on the selected device. The discriminator sees (corrupted, candidate)
# RGB pairs concatenated along the channel axis, hence 6 input channels.
generator, discriminator = (
    UNetGenerator().to(device),
    Discriminator(in_channels=6).to(device),
)

In [None]:
# Adversarial loss on raw discriminator logits (the discriminator ends without a sigmoid).
adversarial_loss = torch.nn.BCEWithLogitsLoss()
pixelwise_loss   = torch.nn.L1Loss()

# SSIM / MS-SSIM are only ever evaluated on images rescaled to [0, 1] via rescale_to_01(),
# so data_range must be 1.0. Using 2.0 (the width of the [-1, 1] network range) makes
# SSIM's stabilising constants (K * L)^2 four times too large and inflates the scores.
# NOTE: the results logged in this notebook were produced with data_range=2.0 and are not
# directly comparable with runs using this setting.
ssim_loss        = SSIMLoss(data_range=1.0).to(device)
ms_ssim_loss     = MultiScaleSSIMLoss(data_range=1.0).to(device)

# Adam with beta1 = 0.5 for both networks (same settings for G and D).
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))


In [None]:
# Directory for every checkpoint written by the training loop (under /kaggle/working).
OUTPUT = "/kaggle/working/gan-checkpoints-new"
os.makedirs(OUTPUT, exist_ok=True)

In [None]:
# Training schedule and early-stopping state (the training cell may overwrite these when
# it resumes from a checkpoint).
num_epochs = 100
patience = 10  # epochs without validation improvement before stopping
best_val_score = -float("inf")
epochs_no_improve = 0
start_epoch = 0

val_dataset = ImageRestorationDataset(val_corrupted_dir, val_clean_dir, transform=transform)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

# One loss scaler per optimizer for mixed-precision training.
# torch.cuda.amp.GradScaler() is deprecated in recent PyTorch; torch.amp.GradScaler("cuda", ...)
# is its replacement. Scaling is explicitly disabled when no GPU is present.
_amp_enabled = torch.cuda.is_available()
scaler_G = torch.amp.GradScaler("cuda", enabled=_amp_enabled)
scaler_D = torch.amp.GradScaler("cuda", enabled=_amp_enabled)

  scaler_G = torch.cuda.amp.GradScaler()
  scaler_D = torch.cuda.amp.GradScaler()


In [None]:
def rescale_to_01(x):
    """Map values from [-1, 1] (the Tanh / Normalize(0.5, 0.5) range) onto [0, 1].

    Works for Python numbers and tensors alike. It undoes the
    ``(x - 0.5) / 0.5`` normalisation applied by the dataset transform.
    """
    return 0.5 * (x + 1)

In [None]:
# Single resume point: written at the end of every epoch and read back here on re-run.
latest_ckpt = f"{OUTPUT}/latest_checkpoint.pth"
# Mixed precision is only used on a GPU; on CPU autocast is simply disabled.
use_amp = device.type == "cuda"


def _checkpoint_state(epoch_idx):
    """Collect everything needed to resume training after (0-based) epoch ``epoch_idx``.

    Reads the module-level models, optimizers, AMP scalers and early-stopping
    counters at call time, so call it after those have been updated.
    """
    return {
        'epoch': epoch_idx,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        'scaler_G_state_dict': scaler_G.state_dict(),
        'scaler_D_state_dict': scaler_D.state_dict(),
        'best_val_score': best_val_score,
        'epochs_no_improve': epochs_no_improve,
    }


def _run_validation():
    """Run the generator over ``val_dataloader`` without gradients.

    Returns:
        (l1, ssim, ms_ssim, psnr): each the mean of its per-batch values.

    Raises:
        ValueError: if the validation loader yields no batches.
    """
    n_batches = len(val_dataloader)
    if n_batches == 0:
        raise ValueError("val_dataloader is empty; cannot compute validation metrics")

    generator.eval()
    total_l1 = total_ssim = total_ms_ssim = total_psnr = 0.0
    with torch.no_grad():
        for corrupted, clean in tqdm(val_dataloader):
            corrupted = corrupted.to(device).float()
            clean = clean.to(device).float()
            fake = generator(corrupted).float()

            fake_res = rescale_to_01(fake)
            clean_res = rescale_to_01(clean)

            total_l1 += pixelwise_loss(fake, clean).item()
            total_ssim += (1 - ssim_loss(fake_res, clean_res)).item()
            total_ms_ssim += (1 - ms_ssim_loss(fake_res, clean_res)).item()
            # Fixed data_range: without it torchmetrics derives the range from each batch's
            # target (max - min), which makes PSNR depend on batch content.
            total_psnr += psnr(fake_res, clean_res, data_range=1.0).item()

    return (total_l1 / n_batches, total_ssim / n_batches,
            total_ms_ssim / n_batches, total_psnr / n_batches)


if os.path.exists(latest_ckpt):
    print("Loading latest checkpoint...")
    checkpoint = torch.load(latest_ckpt, map_location=device)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
    optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
    # Older checkpoints lack scaler state, and a disabled scaler saves an empty dict that
    # load_state_dict rejects, so only restore non-empty entries.
    if checkpoint.get('scaler_G_state_dict'):
        scaler_G.load_state_dict(checkpoint['scaler_G_state_dict'])
    if checkpoint.get('scaler_D_state_dict'):
        scaler_D.load_state_dict(checkpoint['scaler_D_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_val_score = checkpoint.get('best_val_score', best_val_score)
    epochs_no_improve = checkpoint.get('epochs_no_improve', epochs_no_improve)
    print(f"Resuming from epoch {start_epoch}")


epochs_to_run = range(start_epoch, num_epochs)
if epochs_no_improve >= patience:
    # The restored run had already early-stopped; don't train (and checkpoint) another epoch.
    print(f"Early-stopping criterion already met ({epochs_no_improve} epochs without improvement); skipping training.")
    epochs_to_run = range(0)

last_epoch = start_epoch - 1  # recorded in the final checkpoint even if no epoch runs
for epoch in epochs_to_run:
    last_epoch = epoch
    generator.train()
    discriminator.train()
    running_loss_D, running_loss_G = 0.0, 0.0

    for corrupted, clean in tqdm(dataloader):
        corrupted = corrupted.to(device)
        clean = clean.to(device)

        # ---- Discriminator step: real pairs -> 1, generated pairs -> 0 ----
        optimizer_D.zero_grad()
        with torch.autocast(device_type="cuda", enabled=use_amp):
            fake = generator(corrupted)
            pred_real = discriminator(torch.cat([corrupted, clean], dim=1))
            pred_fake = discriminator(torch.cat([corrupted, fake.detach()], dim=1))
            loss_D_real = adversarial_loss(pred_real, torch.ones_like(pred_real))
            loss_D_fake = adversarial_loss(pred_fake, torch.zeros_like(pred_fake))
            loss_D = (loss_D_real + loss_D_fake) / 2

        scaler_D.scale(loss_D).backward()
        scaler_D.step(optimizer_D)
        scaler_D.update()
        running_loss_D += loss_D.item()

        # ---- Generator step: fool D while matching the clean target ----
        # SSIM-based terms work on [0, 1] images; L1 and the adversarial term use [-1, 1].
        fake_rescal = rescale_to_01(fake)
        clean_rescal = rescale_to_01(clean)

        optimizer_G.zero_grad()
        with torch.autocast(device_type="cuda", enabled=use_amp):
            pred_fake_for_G = discriminator(torch.cat([corrupted, fake], dim=1))

            loss_G_adv   = adversarial_loss(pred_fake_for_G, torch.ones_like(pred_fake_for_G))
            loss_G_pixel = pixelwise_loss(fake, clean)
            loss_ssim    = ssim_loss(fake_rescal, clean_rescal)
            loss_ms_ssim = ms_ssim_loss(fake_rescal, clean_rescal)
            # PSNR is only reported during validation; it is not part of this objective.

            loss_G = (
                0.2 * loss_G_adv +
                0.2 * loss_G_pixel +
                0.3 * loss_ssim +
                0.3 * loss_ms_ssim
            )

        scaler_G.scale(loss_G).backward()
        scaler_G.step(optimizer_G)
        scaler_G.update()
        running_loss_G += loss_G.item()

    val_loss, val_ssim, val_ms_ssim, val_psnr = _run_validation()

    # Combined model-selection score (higher is better). L1 is mapped into (0, 1] and PSNR is
    # divided by 40 dB, presumably to keep every term on a comparable scale.
    l1_norm = 1.0 / (1.0 + val_loss)
    val_score = (0.25 * l1_norm +  0.30 * val_ssim + 0.30 * val_ms_ssim +  0.15 * (val_psnr / 40.0))

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Discriminator Loss: {running_loss_D/len(dataloader):.4f}")
    print(f"Generator Loss: {running_loss_G/len(dataloader):.4f}")
    print(f"Validation L1 Loss: {val_loss:.4f}")
    print(f"Validation SSIM: {val_ssim:.4f}")
    print(f"Validation MS-SSIM: {val_ms_ssim:.4f}")
    print(f"Validation PSNR: {val_psnr:.2f} dB")
    print(f"Validation Combined Score: {val_score:.4f}")

    if val_score > best_val_score:
        best_val_score = val_score
        epochs_no_improve = 0
        print("Validation score improved, saving best checkpoint...")
        torch.save(_checkpoint_state(epoch), f"{OUTPUT}/best_checkpoint_epoch{epoch+1}.pth")
    else:
        epochs_no_improve += 1
        print(f"No improvement for {epochs_no_improve} epochs.")

    # Overwrite the single resume point that the loading code above looks for.
    torch.save(_checkpoint_state(epoch), latest_ckpt)

    if epochs_no_improve >= patience:
        print(f"Early stopping at epoch {epoch+1}")
        break

torch.save(_checkpoint_state(last_epoch), f"{OUTPUT}/final_model_new.pth")
print(f"Final model saved to {OUTPUT}/final_model_new.pth")


torch.save(generator.state_dict(), f"{OUTPUT}/final_generator_only_new.pth")
print(f"Final generator weights saved to {OUTPUT}/final_generator_only_new.pth")


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
100%|██████████| 690/690 [11:07<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 1/100
Discriminator Loss: 0.7005
Generator Loss: 0.2133
Validation L1 Loss: 0.0639
Validation SSIM: 0.8564
Validation MS-SSIM: 0.9590
Validation PSNR: 25.25 dB
Validation Combined Score: 0.8743
Validation score improved, saving best checkpoint...


100%|██████████| 690/690 [11:05<00:00,  1.04it/s]
100%|██████████| 321/321 [01:24<00:00,  3.78it/s]


Epoch 2/100
Discriminator Loss: 0.6967
Generator Loss: 0.2070
Validation L1 Loss: 0.0619
Validation SSIM: 0.8623
Validation MS-SSIM: 0.9612
Validation PSNR: 25.24 dB
Validation Combined Score: 0.8771
Validation score improved, saving best checkpoint...


100%|██████████| 690/690 [11:06<00:00,  1.04it/s]
100%|██████████| 321/321 [01:24<00:00,  3.78it/s]


Epoch 3/100
Discriminator Loss: 0.6954
Generator Loss: 0.2011
Validation L1 Loss: 0.0584
Validation SSIM: 0.8739
Validation MS-SSIM: 0.9663
Validation PSNR: 26.04 dB
Validation Combined Score: 0.8859
Validation score improved, saving best checkpoint...


100%|██████████| 690/690 [11:06<00:00,  1.04it/s]
100%|██████████| 321/321 [01:25<00:00,  3.78it/s]


Epoch 4/100
Discriminator Loss: 0.6947
Generator Loss: 0.1969
Validation L1 Loss: 0.0561
Validation SSIM: 0.8806
Validation MS-SSIM: 0.9671
Validation PSNR: 26.11 dB
Validation Combined Score: 0.8889
Validation score improved, saving best checkpoint...


100%|██████████| 690/690 [11:08<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.78it/s]


Epoch 5/100
Discriminator Loss: 0.6942
Generator Loss: 0.1934
Validation L1 Loss: 0.0546
Validation SSIM: 0.8847
Validation MS-SSIM: 0.9703
Validation PSNR: 26.61 dB
Validation Combined Score: 0.8934
Validation score improved, saving best checkpoint...


100%|██████████| 690/690 [11:09<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 6/100
Discriminator Loss: 0.6940
Generator Loss: 0.1905
Validation L1 Loss: 0.0543
Validation SSIM: 0.8903
Validation MS-SSIM: 0.9718
Validation PSNR: 26.84 dB
Validation Combined Score: 0.8964
Validation score improved, saving best checkpoint...


100%|██████████| 690/690 [11:09<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 7/100
Discriminator Loss: 0.6938
Generator Loss: 0.1882
Validation L1 Loss: 0.0552
Validation SSIM: 0.8769
Validation MS-SSIM: 0.9734
Validation PSNR: 27.06 dB
Validation Combined Score: 0.8935
No improvement for 1 epochs.


100%|██████████| 690/690 [11:08<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.78it/s]


Epoch 8/100
Discriminator Loss: 0.6937
Generator Loss: 0.1861
Validation L1 Loss: 0.0505
Validation SSIM: 0.8937
Validation MS-SSIM: 0.9755
Validation PSNR: 27.50 dB
Validation Combined Score: 0.9019
Validation score improved, saving best checkpoint...


100%|██████████| 690/690 [11:10<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 9/100
Discriminator Loss: 0.6938
Generator Loss: 0.1840
Validation L1 Loss: 0.0511
Validation SSIM: 0.8983
Validation MS-SSIM: 0.9760
Validation PSNR: 27.60 dB
Validation Combined Score: 0.9036
Validation score improved, saving best checkpoint...


100%|██████████| 690/690 [11:11<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 10/100
Discriminator Loss: 0.6946
Generator Loss: 0.1829
Validation L1 Loss: 0.0483
Validation SSIM: 0.9084
Validation MS-SSIM: 0.9769
Validation PSNR: 27.85 dB
Validation Combined Score: 0.9085
Validation score improved, saving best checkpoint...


100%|██████████| 690/690 [11:12<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 11/100
Discriminator Loss: 0.6960
Generator Loss: 0.1812
Validation L1 Loss: 0.0462
Validation SSIM: 0.9124
Validation MS-SSIM: 0.9786
Validation PSNR: 28.16 dB
Validation Combined Score: 0.9119
Validation score improved, saving best checkpoint...


100%|██████████| 690/690 [11:11<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 12/100
Discriminator Loss: 0.6965
Generator Loss: 0.1804
Validation L1 Loss: 0.0455
Validation SSIM: 0.9134
Validation MS-SSIM: 0.9795
Validation PSNR: 28.32 dB
Validation Combined Score: 0.9132
Validation score improved, saving best checkpoint...


100%|██████████| 690/690 [11:11<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 13/100
Discriminator Loss: 0.6965
Generator Loss: 0.1804
Validation L1 Loss: 0.0572
Validation SSIM: 0.8804
Validation MS-SSIM: 0.9719
Validation PSNR: 27.04 dB
Validation Combined Score: 0.8936
No improvement for 1 epochs.


100%|██████████| 690/690 [11:11<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 14/100
Discriminator Loss: 0.6964
Generator Loss: 0.1801
Validation L1 Loss: 0.0543
Validation SSIM: 0.8845
Validation MS-SSIM: 0.9768
Validation PSNR: 27.62 dB
Validation Combined Score: 0.8991
No improvement for 2 epochs.


100%|██████████| 690/690 [11:12<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 15/100
Discriminator Loss: 0.6964
Generator Loss: 0.1789
Validation L1 Loss: 0.0547
Validation SSIM: 0.8894
Validation MS-SSIM: 0.9771
Validation PSNR: 27.45 dB
Validation Combined Score: 0.8999
No improvement for 3 epochs.


100%|██████████| 690/690 [11:13<00:00,  1.02it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 16/100
Discriminator Loss: 0.6963
Generator Loss: 0.1780
Validation L1 Loss: 0.0430
Validation SSIM: 0.9240
Validation MS-SSIM: 0.9817
Validation PSNR: 28.52 dB
Validation Combined Score: 0.9184
Validation score improved, saving best checkpoint...


100%|██████████| 690/690 [11:11<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.78it/s]


Epoch 17/100
Discriminator Loss: 0.6961
Generator Loss: 0.1773
Validation L1 Loss: 0.0484
Validation SSIM: 0.9110
Validation MS-SSIM: 0.9806
Validation PSNR: 28.23 dB
Validation Combined Score: 0.9118
No improvement for 1 epochs.


100%|██████████| 690/690 [11:12<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.76it/s]


Epoch 18/100
Discriminator Loss: 0.6959
Generator Loss: 0.1778
Validation L1 Loss: 0.0459
Validation SSIM: 0.9201
Validation MS-SSIM: 0.9806
Validation PSNR: 27.90 dB
Validation Combined Score: 0.9139
No improvement for 2 epochs.


100%|██████████| 690/690 [11:15<00:00,  1.02it/s]
100%|██████████| 321/321 [01:25<00:00,  3.75it/s]


Epoch 19/100
Discriminator Loss: 0.6959
Generator Loss: 0.1762
Validation L1 Loss: 0.0439
Validation SSIM: 0.9190
Validation MS-SSIM: 0.9826
Validation PSNR: 28.66 dB
Validation Combined Score: 0.9174
No improvement for 3 epochs.


100%|██████████| 690/690 [11:15<00:00,  1.02it/s]
100%|██████████| 321/321 [01:25<00:00,  3.74it/s]


Epoch 20/100
Discriminator Loss: 0.6957
Generator Loss: 0.1758
Validation L1 Loss: 0.0462
Validation SSIM: 0.9195
Validation MS-SSIM: 0.9821
Validation PSNR: 28.59 dB
Validation Combined Score: 0.9167
No improvement for 4 epochs.


100%|██████████| 690/690 [11:13<00:00,  1.02it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 21/100
Discriminator Loss: 0.6956
Generator Loss: 0.1757
Validation L1 Loss: 0.0493
Validation SSIM: 0.9053
Validation MS-SSIM: 0.9817
Validation PSNR: 28.26 dB
Validation Combined Score: 0.9103
No improvement for 5 epochs.


100%|██████████| 690/690 [11:12<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.78it/s]


Epoch 22/100
Discriminator Loss: 0.6956
Generator Loss: 0.1760
Validation L1 Loss: 0.0461
Validation SSIM: 0.9185
Validation MS-SSIM: 0.9817
Validation PSNR: 28.39 dB
Validation Combined Score: 0.9155
No improvement for 6 epochs.


100%|██████████| 690/690 [11:14<00:00,  1.02it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 23/100
Discriminator Loss: 0.6956
Generator Loss: 0.1765
Validation L1 Loss: 0.0527
Validation SSIM: 0.8999
Validation MS-SSIM: 0.9808
Validation PSNR: 27.75 dB
Validation Combined Score: 0.9058
No improvement for 7 epochs.


100%|██████████| 690/690 [11:11<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 24/100
Discriminator Loss: 0.6956
Generator Loss: 0.1777
Validation L1 Loss: 0.0431
Validation SSIM: 0.9264
Validation MS-SSIM: 0.9832
Validation PSNR: 28.83 dB
Validation Combined Score: 0.9207
Validation score improved, saving best checkpoint...


100%|██████████| 690/690 [11:11<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 25/100
Discriminator Loss: 0.6951
Generator Loss: 0.1741
Validation L1 Loss: 0.0414
Validation SSIM: 0.9294
Validation MS-SSIM: 0.9837
Validation PSNR: 29.22 dB
Validation Combined Score: 0.9236
Validation score improved, saving best checkpoint...


100%|██████████| 690/690 [11:12<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 26/100
Discriminator Loss: 0.6685
Generator Loss: 0.2215
Validation L1 Loss: 0.0926
Validation SSIM: 0.7796
Validation MS-SSIM: 0.9483
Validation PSNR: 23.53 dB
Validation Combined Score: 0.8354
No improvement for 1 epochs.


100%|██████████| 690/690 [11:06<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 27/100
Discriminator Loss: 0.3687
Generator Loss: 0.6151
Validation L1 Loss: 0.0836
Validation SSIM: 0.7780
Validation MS-SSIM: 0.9421
Validation PSNR: 23.08 dB
Validation Combined Score: 0.8333
No improvement for 2 epochs.


100%|██████████| 690/690 [11:08<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 28/100
Discriminator Loss: 0.4020
Generator Loss: 0.6266
Validation L1 Loss: 0.0670
Validation SSIM: 0.8423
Validation MS-SSIM: 0.9596
Validation PSNR: 25.29 dB
Validation Combined Score: 0.8697
No improvement for 3 epochs.


100%|██████████| 690/690 [11:09<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 29/100
Discriminator Loss: 0.4234
Generator Loss: 0.5972
Validation L1 Loss: 0.0636
Validation SSIM: 0.8557
Validation MS-SSIM: 0.9622
Validation PSNR: 25.76 dB
Validation Combined Score: 0.8770
No improvement for 4 epochs.


100%|██████████| 690/690 [11:09<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 30/100
Discriminator Loss: 0.4467
Generator Loss: 0.5858
Validation L1 Loss: 0.0724
Validation SSIM: 0.8263
Validation MS-SSIM: 0.9550
Validation PSNR: 24.19 dB
Validation Combined Score: 0.8583
No improvement for 5 epochs.


100%|██████████| 690/690 [11:09<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 31/100
Discriminator Loss: 0.4957
Generator Loss: 0.5598
Validation L1 Loss: 0.0656
Validation SSIM: 0.8488
Validation MS-SSIM: 0.9607
Validation PSNR: 25.07 dB
Validation Combined Score: 0.8715
No improvement for 6 epochs.


100%|██████████| 690/690 [11:07<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.78it/s]


Epoch 32/100
Discriminator Loss: 0.4932
Generator Loss: 0.5345
Validation L1 Loss: 0.0670
Validation SSIM: 0.8407
Validation MS-SSIM: 0.9605
Validation PSNR: 24.99 dB
Validation Combined Score: 0.8684
No improvement for 7 epochs.


100%|██████████| 690/690 [11:08<00:00,  1.03it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 33/100
Discriminator Loss: 0.5001
Generator Loss: 0.5463
Validation L1 Loss: 0.0607
Validation SSIM: 0.8636
Validation MS-SSIM: 0.9639
Validation PSNR: 25.93 dB
Validation Combined Score: 0.8812
No improvement for 8 epochs.


100%|██████████| 690/690 [11:06<00:00,  1.04it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 34/100
Discriminator Loss: 0.4495
Generator Loss: 0.5955
Validation L1 Loss: 0.0630
Validation SSIM: 0.8552
Validation MS-SSIM: 0.9625
Validation PSNR: 25.52 dB
Validation Combined Score: 0.8762
No improvement for 9 epochs.


100%|██████████| 690/690 [11:06<00:00,  1.04it/s]
100%|██████████| 321/321 [01:25<00:00,  3.77it/s]


Epoch 35/100
Discriminator Loss: 0.4704
Generator Loss: 0.5860
Validation L1 Loss: 0.0674
Validation SSIM: 0.8376
Validation MS-SSIM: 0.9594
Validation PSNR: 24.90 dB
Validation Combined Score: 0.8667
No improvement for 10 epochs.
Early stopping at epoch 35
Final model saved to /kaggle/working/gan-checkpoints-new/final_model_new.pth
Final generator weights saved to /kaggle/working/gan-checkpoints-new/final_generator_only_new.pth


In [None]:

import glob
import re

CKPT_DIR = "/kaggle/working/gan-checkpoints-new"
# Set to an int to export a specific epoch; None picks the newest best checkpoint.
BEST_EPOCH = None
new_generator_path = f"{CKPT_DIR}/final_generator_best.pth"
new_full_model_path = f"{CKPT_DIR}/final_model_best.pth"


def _best_epoch_number(path):
    """Return N from '.../best_checkpoint_epoch{N}.pth', or -1 if the name does not match."""
    match = re.search(r"best_checkpoint_epoch(\d+)\.pth$", path)
    return int(match.group(1)) if match else -1


if BEST_EPOCH is not None:
    best_checkpoint_path = f"{CKPT_DIR}/best_checkpoint_epoch{BEST_EPOCH}.pth"
else:
    # The training loop writes best_checkpoint_epoch{N}.pth only when the validation score
    # improves, so within one run the highest N is the best-scoring checkpoint.
    # NOTE(review): if several runs shared this directory, pin BEST_EPOCH instead.
    candidates = glob.glob(f"{CKPT_DIR}/best_checkpoint_epoch*.pth")
    if not candidates:
        raise FileNotFoundError(f"No best_checkpoint_epoch*.pth files found in {CKPT_DIR}")
    best_checkpoint_path = max(candidates, key=_best_epoch_number)
print(f"Exporting {best_checkpoint_path}")

checkpoint = torch.load(best_checkpoint_path, map_location=device)  # or 'cpu'

# Rebuild both networks and load the weights (this also checks the checkpoint matches
# the current architectures).
generator = UNetGenerator().to(device)
discriminator = Discriminator().to(device)
generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
generator.eval()

torch.save(generator.state_dict(), new_generator_path)
print(f"Standalone generator saved to {new_generator_path}")

full_model_keys = (
    'epoch',
    'generator_state_dict',
    'discriminator_state_dict',
    'optimizer_G_state_dict',
    'optimizer_D_state_dict',
    'best_val_score',
)
full_model_dict = {key: checkpoint[key] for key in full_model_keys}
torch.save(full_model_dict, new_full_model_path)
print(f"Full model saved to {new_full_model_path}")

Standalone generator saved to /kaggle/working/gan-checkpoints-new/final_generator_best.pth
Full model saved to /kaggle/working/gan-checkpoints-new/final_model_best.pth


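# **Best validation scores were observed at epoch 25**

Training destabilised from epoch 26 onward. Validation metrics dropped sharply, and from epoch 27 the discriminator loss fell from ≈0.69 to ≈0.37–0.50 while the generator loss rose to ≈0.53–0.63. Early stopping therefore triggered at epoch 35, and the epoch-25 checkpoint is the one exported above as `final_generator_best.pth`.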

# **Epoch-25 scores:**

**Discriminator Loss: 0.6951**

**Generator Loss: 0.1741**

**Validation L1 Loss: 0.0414**

**Validation SSIM: 0.9294**

**Validation MS-SSIM: 0.9837**

**Validation PSNR: 29.22 dB**

**Validation Combined Score: 0.9236**