**Libraries**

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!unzip /content/drive/MyDrive/x.zip

In [None]:
!mv /content/x/* /content/drive/MyDrive/prototype/Train/x/
!mv /content/y/* /content/drive/MyDrive/prototype/Train/y/

In [3]:
import os

print(len(os.listdir('/content/drive/MyDrive/prototype/Train/x')))
print(len(os.listdir('/content/drive/MyDrive/prototype/Train/y')))

1000
1000


In [4]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
import numpy as np
from PIL import Image
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import models
from tqdm import tqdm

# Building Generator

In [None]:

# class DownSample(nn.Module):
#     def __init__(self, in_channels, out_channels, apply_batchnorm=True):
#         super(DownSample, self).__init__()
#         layers = [
#             nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
#         ]
#         if apply_batchnorm:
#             layers.append(nn.BatchNorm2d(out_channels))
#         layers.append(nn.LeakyReLU(0.2))
#         self.block = nn.Sequential(*layers)

#     def forward(self, x):
#         return self.block(x)

# # Define the upsampling block
# class Upsample(nn.Module):
#     def __init__(self, in_channels, out_channels, apply_dropout=False):
#         super(Upsample, self).__init__()
#         layers = [
#             nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU()
#         ]
#         if apply_dropout:
#             layers.append(nn.Dropout(0.5))
#         self.block = nn.Sequential(*layers)

#     def forward(self, x, skip_input):
#         x = self.block(x)
#         x = torch.cat((x, skip_input), 1)
#         return x

# # Generator adapted for SAR images
# class Generator(nn.Module):
#     def __init__(self, in_channels=1, out_channels=3):  # 1 input channel (SAR), 3 output channels (RGB)
#         super(Generator, self).__init__()
#         self.down1 = DownSample(in_channels, 64)
#         self.down2 = DownSample(64, 128)
#         self.down3 = DownSample(128, 256)
#         self.down4 = DownSample(256, 512)
#         self.down5 = DownSample(512, 512)
#         self.down6 = DownSample(512, 512)
#         self.down7 = DownSample(512, 512)
#         self.down8 = DownSample(512, 512)

#         self.up1 = Upsample(512, 512)
#         self.up2 = Upsample(1024, 512)
#         self.up3 = Upsample(1024, 512)
#         self.up4 = Upsample(1024, 512)
#         self.up5 = Upsample(1024, 256)
#         self.up6 = Upsample(512, 128)
#         self.up7 = Upsample(256, 64)

#         self.final = nn.Sequential(
#             nn.Upsample(scale_factor=2),
#             nn.ZeroPad2d((1, 0, 1, 0)),
#             nn.Conv2d(128, out_channels, kernel_size=4, padding=1),  # 3 output channels (RGB)
#             nn.Tanh(),
#         )

#     def forward(self, x):
#         # U-NET generator with skip connections from encoder to decoder
#         d1 = self.down1(x)
#         d2 = self.down2(d1)
#         d3 = self.down3(d2)
#         d4 = self.down4(d3)
#         d5 = self.down5(d4)
#         d6 = self.down6(d5)
#         d7 = self.down7(d6)
#         d8 = self.down8(d7)

#         u1 = self.up1(d8, d7)
#         u2 = self.up2(u1, d6)
#         u3 = self.up3(u2, d5)
#         u4 = self.up4(u3, d4)
#         u5 = self.up5(u4, d3)
#         u6 = self.up6(u5, d2)
#         u7 = self.up7(u6, d1)
#         u8 = self.final(u7)

#         return u8


In [5]:
class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels, apply_batchnorm=True):
        super(DownSample, self).__init__()
        layers = [
            nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1, groups=in_channels, bias=False),  # Depthwise
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)  # Pointwise
        ]
        if apply_batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

# Define an upsampling block using interpolation followed by depthwise separable convolution
class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels, apply_dropout=False):
        super(Upsample, self).__init__()
        layers = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, in_channels, kernel_size=4, padding='same', groups=in_channels, bias=False),  # Depthwise
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),  # Pointwise
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ]
        if apply_dropout:
            layers.append(nn.Dropout(0.5))
        self.block = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.block(x)
        x = torch.cat((x, skip_input), 1)
        return x

# Generator adapted for SAR images using depthwise separable convolutions
class Generator(nn.Module):
    def __init__(self, in_channels=1, out_channels=3):
        super(Generator, self).__init__()
        self.down1 = DownSample(in_channels, 64)
        self.down2 = DownSample(64, 128)
        self.down3 = DownSample(128, 256)
        self.down4 = DownSample(256, 512)
        self.down5 = DownSample(512, 512)
        self.down6 = DownSample(512, 512)
        self.down7 = DownSample(512, 512)
        # self.down8 = DownSample(512, 512)

        # self.up1 = Upsample(512, 512)
        self.up2 = Upsample(512, 512)
        self.up3 = Upsample(1024, 512)
        self.up4 = Upsample(1024, 512)
        self.up5 = Upsample(1024, 256)
        self.up6 = Upsample(512, 128)
        self.up7 = Upsample(256, 64)

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(128, 128, kernel_size=4, padding='same', groups=128, bias=False),  # Depthwise
            nn.Conv2d(128, out_channels, kernel_size=1, stride=1),  # Pointwise
            nn.Tanh()
        )

    def forward(self, x):
        d1 = self.down1(x)  #128, 64
        d2 = self.down2(d1) #64, 128
        d3 = self.down3(d2) #32, 256
        d4 = self.down4(d3) #16, 512
        d5 = self.down5(d4) #8, 512
        d6 = self.down6(d5) #4, 512
        d7 = self.down7(d6) #2, 512
        # d8 = self.down8(d7) #1, 512

        # u1 = self.up1(d8, d7)
        u2 = self.up2(d7, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1) #128x128x128c
        u8 = self.final(u7)

        return u8

# **Building Discriminator**

In [6]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=4):
        super(Discriminator, self).__init__()
        def discriminator_block(in_filters, out_filters, stride=2, normalize=True):
            layers = [
                nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=stride, padding=1)
                ]
            if normalize:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(in_channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512, stride=1),
            nn.Conv2d(512, 1, kernel_size=4, padding=1)
        )

    def forward(self, img_A, img_B):
        # Concatenate image and condition image by channels to produce input
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
device

device(type='cuda')

In [9]:
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

In [10]:
generator = generator.to(device)
discriminator = discriminator.to(device)

In [11]:
print(discriminator)

Discriminator(
  (model): Sequential(
    (0): Conv2d(4, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)


In [12]:
print(generator)

Generator(
  (down1): DownSample(
    (block): Sequential(
      (0): Conv2d(1, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): Conv2d(1, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): LeakyReLU(negative_slope=0.2)
    )
  )
  (down2): DownSample(
    (block): Sequential(
      (0): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), groups=64, bias=False)
      (1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): LeakyReLU(negative_slope=0.2)
    )
  )
  (down3): DownSample(
    (block): Sequential(
      (0): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), groups=128, bias=False)
      (1): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): BatchNorm2d(256, eps=1e-05, momentum=0

In [13]:
from torchsummary import summary

summary(generator, input_size=(1, 256, 256))
sample_input = torch.randn(1, 1, 256, 256).to(device)
output = generator(sample_input)
print(output.shape)


summary(discriminator, input_size=[(1, 256, 256), (3, 256, 256)])
img1 = torch.randn(1, 1, 256, 256).to(device)
img2 = torch.randn(1, 3, 256, 256).to(device)
output = discriminator(img1, img2)
print(output.shape)

  return F.conv2d(


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 1, 128, 128]              16
            Conv2d-2         [-1, 64, 128, 128]              64
       BatchNorm2d-3         [-1, 64, 128, 128]             128
         LeakyReLU-4         [-1, 64, 128, 128]               0
        DownSample-5         [-1, 64, 128, 128]               0
            Conv2d-6           [-1, 64, 64, 64]           1,024
            Conv2d-7          [-1, 128, 64, 64]           8,192
       BatchNorm2d-8          [-1, 128, 64, 64]             256
         LeakyReLU-9          [-1, 128, 64, 64]               0
       DownSample-10          [-1, 128, 64, 64]               0
           Conv2d-11          [-1, 128, 32, 32]           2,048
           Conv2d-12          [-1, 256, 32, 32]          32,768
      BatchNorm2d-13          [-1, 256, 32, 32]             512
        LeakyReLU-14          [-1, 256,

In [14]:
import os

torch.save(generator.state_dict(), "model.pth")
print(f"Model size: {os.path.getsize('model.pth') / (1024 * 1024):.2f} MB")

torch.save(discriminator.state_dict(), "model.pth")
print(f"Model size: {os.path.getsize('model.pth') / (1024 * 1024):.2f} MB")

Model size: 10.46 MB
Model size: 10.57 MB


# **Defining Loss Functions**

In [15]:
# # criterion_GAN = nn.MSELoss()
# # criterion_pixelwise = nn.L1Loss()

# """# Chromatic Aberration Loss definition
# class ChromaticAberrationLoss(nn.Module):
#     def __init__(self, lambda_color=1.0, lambda_spatial=1.0, lambda_perceptual=1.0, lambda_edge=1.0):
#         super(ChromaticAberrationLoss, self).__init__()
#         self.lambda_color = lambda_color
#         self.lambda_spatial = lambda_spatial
#         self.lambda_perceptual = lambda_perceptual
#         self.lambda_edge = lambda_edge

#         # Pre-trained VGG19 model for perceptual loss
#         vgg19 = models.vgg19(pretrained=True).features
#         self.vgg19_block4_conv4 = nn.Sequential(*list(vgg19[:21])).eval()  # Block 4 conv 4
#         for param in self.vgg19_block4_conv4.parameters():
#             param.requires_grad = False

#     # Convert from RGB to YUV (PyTorch version)
#     def rgb_to_yuv(self, image):
#         r, g, b = image[:, 0:1], image[:, 1:1], image[:, 2:1]
#         y = 0.299 * r + 0.587 * g + 0.114 * b
#         u = -0.14713 * r - 0.28886 * g + 0.436 * b
#         v = 0.615 * r - 0.51499 * g - 0.10001 * b
#         return torch.cat([y, u, v], dim=1)

#     # 1. Color Discrepancy Loss (L2 Norm in YUV space)
#     def color_loss(self, y_true, y_pred):
#         y_true_yuv = self.rgb_to_yuv(y_true)
#         y_pred_yuv = self.rgb_to_yuv(y_pred)
#         return torch.mean((y_true_yuv - y_pred_yuv) ** 2)

#     # 2. Spatial Consistency Loss (L1 Norm between neighboring pixels)
#     def spatial_loss(self, image):
#         loss_vertical = torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))
#         loss_horizontal = torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:]))
#         return loss_vertical + loss_horizontal

#     # 3. Perceptual Loss (using VGG19)
#     def perceptual_loss(self, y_true, y_pred):
#         y_true_vgg = self.vgg19_block4_conv4(y_true)
#         y_pred_vgg = self.vgg19_block4_conv4(y_pred)
#         return torch.mean((y_true_vgg - y_pred_vgg) ** 2)

#     # 4. Edge-Aware Loss (gradient difference in edge areas)
#     def edge_aware_loss(self, y_true, y_pred):
#         grad_true_x = y_true[:, :, 1:, :] - y_true[:, :, :-1, :]
#         grad_true_y = y_true[:, :, :, 1:] - y_true[:, :, :, :-1]
#         grad_pred_x = y_pred[:, :, 1:, :] - y_pred[:, :, :-1, :]
#         grad_pred_y = y_pred[:, :, :, 1:] - y_pred[:, :, :, :-1]

#         edge_loss = torch.mean(torch.abs(grad_true_x - grad_pred_x)) + \
#                     torch.mean(torch.abs(grad_true_y - grad_pred_y))
#         return edge_loss

#     # Total Chromatic Aberration Loss
#     def forward(self, y_true, y_pred):
#         color_loss_value = self.color_loss(y_true, y_pred)
#         spatial_loss_value = self.spatial_loss(y_pred)
#         perceptual_loss_value = self.perceptual_loss(y_true, y_pred)
#         edge_loss_value = self.edge_aware_loss(y_true, y_pred)

#         total_loss = (self.lambda_color * color_loss_value) + \
#                      (self.lambda_spatial * spatial_loss_value) + \
#                      (self.lambda_perceptual * perceptual_loss_value) + \
#                      (self.lambda_edge * edge_loss_value)

#         return total_loss"""

# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torchvision import models

# # Convert RGB to LAB using PyTorch
# def rgb_to_lab(image):
#     # The conversion logic from RGB to LAB will be approximated.
#     # For simplicity, here we assume the image is normalized to [0, 1].
#     image = (image + 1) / 2  # Denormalize from [-1, 1] to [0, 1]
#     image = image.clamp(0, 1)  # Ensure pixel values are within [0, 1]

#     # You can use a library like OpenCV to convert to LAB, but here's a placeholder
#     # You would need to use `cv2.cvtColor(image, cv2.COLOR_RGB2LAB)` with actual data.
#     # Placeholder conversion: Assume yuv is LAB (for simplicity)
#     return image  # This would be replaced by actual LAB conversion logic

# # # Chromatic Aberration Loss in PyTorch
# # class ChromaticAberrationLoss(nn.Module):
# #     def __init__(self, lambda_color=1.0, lambda_spatial=1.0, lambda_perceptual=1.0, lambda_edge=1.0):
# #         super(ChromaticAberrationLoss, self).__init__()
# #         self.lambda_color = lambda_color
# #         self.lambda_spatial = lambda_spatial
# #         self.lambda_perceptual = lambda_perceptual
# #         self.lambda_edge = lambda_edge

# #         # Load the pre-trained VGG19 model for perceptual loss
# #         vgg = models.vgg19(pretrained=True).features
# #         self.vgg = nn.Sequential(*list(vgg.children())[:22])  # Extract up to 'block4_conv4'
# #         for param in self.vgg.parameters():
# #             param.requires_grad = False  # Freeze VGG19 parameters

# #     def forward(self, y_true, y_pred):
# #         # 1. Color Discrepancy Loss (L2 Norm in LAB space)
# #         y_true_lab = rgb_to_lab(y_true)
# #         y_pred_lab = rgb_to_lab(y_pred)
# #         color_loss = F.mse_loss(y_true_lab, y_pred_lab)

# #         # 2. Spatial Consistency Loss (L1 Norm between neighboring pixels)
# #         def spatial_loss(image):
# #             loss_x = F.l1_loss(image[:, :, :-1, :], image[:, :, 1:, :])
# #             loss_y = F.l1_loss(image[:, :, :, :-1], image[:, :, :, 1:])
# #             return loss_x + loss_y

# #         spatial_loss_value = spatial_loss(y_pred)

# #         # 3. Perceptual Loss using VGG19 features
# #         def perceptual_loss(y_true, y_pred):
# #             y_true_vgg = self.vgg(y_true)
# #             y_pred_vgg = self.vgg(y_pred)
# #             return F.mse_loss(y_true_vgg, y_pred_vgg)

# #         perceptual_loss_value = perceptual_loss(y_true, y_pred)

# #         # 4. Edge-Aware Loss (gradient difference in edge areas)
# #         def edge_aware_loss(y_true, y_pred):
# #             grad_true_x = y_true[:, :, 1:, :] - y_true[:, :, :-1, :]
# #             grad_true_y = y_true[:, :, :, 1:] - y_true[:, :, :, :-1]
# #             grad_pred_x = y_pred[:, :, 1:, :] - y_pred[:, :, :-1, :]
# #             grad_pred_y = y_pred[:, :, :, 1:] - y_pred[:, :, :, :-1]
# #             edge_loss = F.l1_loss(grad_true_x, grad_pred_x) + F.l1_loss(grad_true_y, grad_pred_y)
# #             return edge_loss

# #         edge_loss_value = edge_aware_loss(y_true, y_pred)

# #         # Total Chromatic Aberration Loss
# #         total_loss = (self.lambda_color * color_loss) + \
# #                      (self.lambda_spatial * spatial_loss_value) + \
# #                      (self.lambda_perceptual * perceptual_loss_value) + \
# #                      (self.lambda_edge * edge_loss_value)

# #         return total_loss

# class ChromaticAberrationLoss(nn.Module):
#     def __init__(self, device, lambda_color=1.0, lambda_spatial=1.0, lambda_perceptual=1.0, lambda_edge=1.0):
#         super(ChromaticAberrationLoss, self).__init__()
#         self.lambda_color = lambda_color
#         self.lambda_spatial = lambda_spatial
#         self.lambda_perceptual = lambda_perceptual
#         self.lambda_edge = lambda_edge

#         # Load the pre-trained VGG19 model for perceptual loss and move it to the appropriate device
#         vgg = models.vgg19(pretrained=True).features
#         self.vgg = nn.Sequential(*list(vgg.children())[15:22])  # Extract up to 'block4_conv4'
#         self.vgg.to(device)  # Move the VGG model to the appropriate device (GPU/CPU)

#         for param in self.vgg.parameters():
#             param.requires_grad = False  # Freeze VGG19 parameters

#         self.device = device  # Store device for later use

#     def forward(self, y_true, y_pred):
#         # Move inputs to the same device as the model (VGG)
#         y_true = y_true.to(self.device)
#         y_pred = y_pred.to(self.device)

#         # 1. Color Discrepancy Loss (L2 Norm in LAB space)
#         y_true_lab = rgb_to_lab(y_true)
#         y_pred_lab = rgb_to_lab(y_pred)
#         color_loss = F.mse_loss(y_true_lab, y_pred_lab)

#         # 2. Spatial Consistency Loss (L1 Norm between neighboring pixels)
#         def spatial_loss(image):
#             loss_x = F.l1_loss(image[:, :, :-1, :], image[:, :, 1:, :])
#             loss_y = F.l1_loss(image[:, :, :, :-1], image[:, :, :, 1:])
#             return loss_x + loss_y

#         spatial_loss_value = spatial_loss(y_pred)

#         # 3. Perceptual Loss using VGG19 features
#         def perceptual_loss(y_true, y_pred):
#             y_true_vgg = self.vgg(y_true)
#             y_pred_vgg = self.vgg(y_pred)
#             return F.mse_loss(y_true_vgg, y_pred_vgg)

#         perceptual_loss_value = perceptual_loss(y_true, y_pred)

#         # 4. Edge-Aware Loss (gradient difference in edge areas)
#         def edge_aware_loss(y_true, y_pred):
#             grad_true_x = y_true[:, :, 1:, :] - y_true[:, :, :-1, :]
#             grad_true_y = y_true[:, :, :, 1:] - y_true[:, :, :, :-1]
#             grad_pred_x = y_pred[:, :, 1:, :] - y_pred[:, :, :-1, :]
#             grad_pred_y = y_pred[:, :, :, 1:] - y_pred[:, :, :, :-1]
#             edge_loss = F.l1_loss(grad_true_x, grad_pred_x) + F.l1_loss(grad_true_y, grad_pred_y)
#             return edge_loss

#         edge_loss_value = edge_aware_loss(y_true, y_pred)

#         # Total Chromatic Aberration Loss
#         total_loss = (self.lambda_color * color_loss) + \
#                      (self.lambda_spatial * spatial_loss_value) + \
#                      (self.lambda_perceptual * perceptual_loss_value) + \
#                      (self.lambda_edge * edge_loss_value)

#         return total_loss

# criterion_GAN = nn.MSELoss()
# criterion_pixelwise = ChromaticAberrationLoss(device, lambda_color=1.0, lambda_spatial=1.0, lambda_perceptual=1.0, lambda_edge=1.0)

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms

class ChromaticAberrationLoss(nn.Module):
    def __init__(self, device, lambda_L1=1.0, lambda_CA=1.0, lambda_perceptual=1.0, kernel_size=5, sigma=3.0):
        super(ChromaticAberrationLoss, self).__init__()

        # Initialize weights for the loss terms
        self.lambda_L1 = lambda_L1
        self.lambda_CA = lambda_CA
        self.lambda_perceptual = lambda_perceptual

        # L1 Loss
        self.l1_loss = nn.L1Loss()

        # Gaussian Blur using torchvision.transforms
        self.gaussian_blur = transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma)

        # VGG19 for perceptual loss
        vgg = models.vgg19(pretrained=True).features
        self.vgg19 = vgg[:21].eval().to(device)  # Use layers up to block4_conv4
        for param in self.vgg19.parameters():
            param.requires_grad = False  # Freeze VGG19

    def forward(self, generated, target):
        """
        Calculate the combined loss.
        """
        # L1 Loss
        loss_L1 = self.lambda_L1 * self.l1_loss(generated, target)

        # Chromatic Aberration Loss
        loss_CA = self.lambda_CA * self.chromatic_aberration_loss(generated, target)

        # Perceptual Loss
        loss_perceptual = self.lambda_perceptual * self.perceptual_loss(generated, target)

        # Total Loss
        total_loss = loss_L1 + loss_CA + loss_perceptual
        return total_loss

    def chromatic_aberration_loss(self, generated, target):
        """
        Calculate chromatic aberration loss using Gaussian blur.
        """
        blurred_gen = self.gaussian_blur(generated)
        blurred_target = self.gaussian_blur(target)
        return F.mse_loss(blurred_gen, blurred_target)

    def perceptual_loss(self, generated, target):
        """
        Calculate perceptual loss using pre-trained VGG19.
        """
        gen_features = self.vgg19(generated)
        target_features = self.vgg19(target)
        return F.mse_loss(gen_features, target_features)

criterion_GAN = nn.BCEWithLogitsLoss()
criterion_pixelwise = ChromaticAberrationLoss(device, lambda_L1=1.0, lambda_CA=1.0, lambda_perceptual=1.0, kernel_size=5, sigma=3.0)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:03<00:00, 176MB/s]


# Optimizers and Transforms

In [17]:
optimizer_G = optim.Adam(generator.parameters(), lr=0.002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.002, betas=(0.5, 0.999))

# Dataset class for handling SAR input and DATALOADER

In [18]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import torchvision.transforms as transforms

# Define the dataset class
class ImageDataset(Dataset):
    def __init__(self, SAR_root, color_root=None, transforms_=None):
        self.SAR_root = SAR_root
        self.color_root = color_root
        self.transforms = transforms_

        # Get sorted list of SAR images
        self.SAR_images = sorted([f for f in os.listdir(SAR_root) if f.endswith('.png')])

        # If color images are provided, get sorted list of color images
        if color_root:
            self.color_images = sorted([f for f in os.listdir(color_root) if f.endswith('.png')])
            # Ensure the number of SAR and color images are the same
            assert len(self.SAR_images) == len(self.color_images), "Mismatch between SAR and color image count"

    def __len__(self):
        return len(self.SAR_images)

    def __getitem__(self, idx):
        # Load SAR image
        SAR_image_path = os.path.join(self.SAR_root, self.SAR_images[idx])
        SAR_image = Image.open(SAR_image_path).convert('L')  # Grayscale (1 channel)

        if self.color_root:
            # Load corresponding color image
            color_image_path = os.path.join(self.color_root, self.color_images[idx])
            color_image = Image.open(color_image_path).convert('RGB')  # RGB (3 channels)

            # Apply transformations if available
            if self.transforms:
                SAR_image = self.transforms(SAR_image)
                color_image = self.transforms(color_image)

            return SAR_image, color_image
        else:
            # Apply transformations to SAR image only
            if self.transforms:
                SAR_image = self.transforms(SAR_image)
            return SAR_image

# Define the transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize to match your model input size
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize (use 3 channels normalization if needed for color)
])

# DataLoader for training
dataloader = DataLoader(
    ImageDataset(
        SAR_root="/content/drive/MyDrive/prototype/Train/x",  # Path to SAR images
        color_root="/content/drive/MyDrive/prototype/Train/y",  # Path to color images
        transforms_=transform
    ),
    batch_size=32,  # Adjust the batch size as needed
    shuffle=True
)

# Training

In [None]:
# Paths to save the best models
save_path_generator = '/content/drive/MyDrive/prototype2/generator2.pth'
save_path_discriminator = '/content/drive/MyDrive/prototype2/discriminator2.pth'

# Early stopping parameters
patience = 5
best_loss_G = float('inf')
best_loss_D = float('inf')
no_improvement_G = 0
no_improvement_D = 0
no_of_epochs = 200

# Get the total number of batches
total_batches = len(dataloader)

# Start the training loop
for epoch in range(no_of_epochs):
    start_time = time.time()  # Record start time for ETA calculation
    epoch_loss_G = 0
    epoch_loss_D = 0

    dataloader_tqdm = tqdm(dataloader, desc=f'Epoch {epoch+1}/{no_of_epochs}', leave=False)

    for i, (SAR_imgs, color_imgs) in enumerate(dataloader_tqdm):
        SAR_imgs = SAR_imgs.to(device)
        color_imgs = color_imgs.to(device)

        valid = torch.ones((SAR_imgs.size(0), 1, 30, 30), requires_grad=False).to(device)  # Adjust to match image size
        fake = torch.zeros((SAR_imgs.size(0), 1, 30, 30), requires_grad=False).to(device)  # Adjust to match image size

        # ------------------
        # Train Generator
        # ------------------
        optimizer_G.zero_grad()

        fake_imgs = generator(SAR_imgs)
        # Resize fake images to match color images size if necessary
        fake_imgs_resized = F.interpolate(fake_imgs, size=color_imgs.size()[2:], mode='bilinear', align_corners=False)

        # Ensure that fake_imgs_resized and color_imgs have the same size
        assert fake_imgs_resized.size() == color_imgs.size(), f"Size mismatch: {fake_imgs_resized.size()} vs {color_imgs.size()}"

        # GAN loss (Discriminator should classify fake images as valid)
        loss_GAN = criterion_GAN(discriminator(fake_imgs_resized, SAR_imgs), valid)

        # Pixel-wise loss
        loss_pixelwise = criterion_pixelwise(fake_imgs_resized, color_imgs)

        # Total generator loss
        loss_G = loss_GAN + loss_pixelwise
        loss_G.backward()

        # Gradient clipping
        nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
        optimizer_G.step()

        # ------------------
        # Train Discriminator
        # ------------------
        optimizer_D.zero_grad()

        # Real images (Discriminator should classify real images as valid)
        loss_real = criterion_GAN(discriminator(color_imgs, SAR_imgs), valid)

        # Fake images (Discriminator should classify generated images as fake)
        loss_fake = criterion_GAN(discriminator(fake_imgs_resized.detach(), SAR_imgs), fake)

        # Total discriminator loss
        loss_D = 0.5 * (loss_real + loss_fake)
        loss_D.backward()

        # Gradient clipping
        nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
        optimizer_D.step()

        epoch_loss_G += loss_G.item()
        epoch_loss_D += loss_D.item()

        # Calculate elapsed time and ETA
        elapsed_time = time.time() - start_time
        avg_time_per_batch = elapsed_time / (i + 1)
        remaining_batches = total_batches - (i + 1)
        eta = avg_time_per_batch * remaining_batches

        # Update the tqdm description with ETA
        dataloader_tqdm.set_postfix({
            'D loss': f'{loss_D.item():.4f}',
            'G loss': f'{loss_G.item():.4f}',
            'ETA': f'{eta:.2f}s'
        })

    # Calculate average loss for the epoch
    avg_loss_G = epoch_loss_G / len(dataloader)
    avg_loss_D = epoch_loss_D / len(dataloader)

    print(f"\n[Epoch {epoch+1}/{no_of_epochs}] [Avg D loss: {avg_loss_D:.4f}] [Avg G loss: {avg_loss_G:.4f}]")
    print(f"Elapsed Time: {elapsed_time:.2f} seconds | ETA: {eta:.2f} seconds")

    # Early stopping and model checkpointing
    if avg_loss_G < best_loss_G:
        best_loss_G = avg_loss_G
        no_improvement_G = 0
        best_generator_state = generator.state_dict()
    else:
        no_improvement_G += 1

    if avg_loss_D < best_loss_D:
        best_loss_D = avg_loss_D
        no_improvement_D = 0
        best_discriminator_state = discriminator.state_dict()
    else:
        no_improvement_D += 1

    # Check for early stopping
    if no_improvement_G >= patience and no_improvement_D >= patience:
        print("Early stopping triggered. Training stopped.")
        break

# Save the final best models with architecture and optimizer state
torch.save({
    'model_state_dict': best_generator_state,
    'optimizer_state_dict': optimizer_G.state_dict(),
    'loss': best_loss_G,
}, save_path_generator)

torch.save({
    'model_state_dict': best_discriminator_state,
    'optimizer_state_dict': optimizer_D.state_dict(),
    'loss': best_loss_D,
}, save_path_discriminator)

print("Training complete. Best models saved.")

Epoch 1/200:  22%|██▏       | 7/32 [03:49<12:30, 30.02s/it, D loss=0.9382, G loss=6.9745, ETA=820.31s]

# Testing


In [None]:
'''
# Load the generator model for testing
checkpoint_G = torch.load("/content/drive/MyDrive/prototype/Models/generator/generator.pth")
generator.load_state_dict(checkpoint_G['model_state_dict'])
optimizer_G.load_state_dict(checkpoint_G['optimizer_state_dict'])
generator.eval()  # Set the generator to evaluation mode


# Load test dataset for inference
test_dataset = ImageDataset(
    SAR_root="/content/drive/MyDrive/prototype/Test",
    color_root=None,  # Not needed for testing
    transforms_=transform
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False
)

# Function to test the generator and display images
def test_generator(generator, test_loader, num_images=5):
    generator.eval()  # Set the generator to evaluation mode
    with torch.no_grad():  # Disable gradient calculation for faster inference
        for i, SAR_imgs in enumerate(test_loader):
            SAR_imgs = SAR_imgs.to(device)  # Move SAR images to the device
            generated_imgs = generator(SAR_imgs)
            generated_imgs = 0.5 * (generated_imgs + 1)  # Denormalize from [-1, 1] to [0, 1]
            SAR_imgs = SAR_imgs.cpu()
            generated_imgs = generated_imgs.cpu()

            # Display the first few images in the batch
            for j in range(min(num_images, SAR_imgs.size(0))):
                fig, axes = plt.subplots(1, 2, figsize=(10, 5))

                # Input SAR image
                axes[0].imshow(SAR_imgs[j].squeeze(0), cmap='gray')  # Display as grayscale
                axes[0].set_title('Input SAR Image')
                axes[0].axis('off')

                # Output colorized image
                axes[1].imshow(transforms.ToPILImage()(generated_imgs[j]))
                axes[1].set_title('Generated Color Image')
                axes[1].axis('off')

                plt.show()

# Run testing and visualization
test_generator(generator, test_dataloader, num_images=5)
'''

In [None]:
import os
from PIL import Image

# Create Testpred directory if it doesn't exist
save_dir = "/content/drive/MyDrive/prototype2/Testpred"
os.makedirs(save_dir, exist_ok=True)

# Load the generator model for testing
checkpoint_G = torch.load("/content/drive/MyDrive/prototype2/generator2.pth", map_location=torch.device('cpu'))
generator.load_state_dict(checkpoint_G['model_state_dict'])
optimizer_G.load_state_dict(checkpoint_G['optimizer_state_dict'])
generator.eval()  # Set the generator to evaluation mode


# Load test dataset for inference
test_dataset = ImageDataset(
    SAR_root="/content/drive/MyDrive/prototype/Test",
    color_root=None,  # Not needed for testing
    transforms_=transform
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False
)


# Function to test the generator, display images, and save them
def test_generator(generator, test_loader, num_images=5, save_dir=save_dir):
    generator.eval()  # Set the generator to evaluation mode
    with torch.no_grad():  # Disable gradient calculation for faster inference
        for i, SAR_imgs in enumerate(test_loader):
            SAR_imgs = SAR_imgs.to(device)  # Move SAR images to the device
            generated_imgs = generator(SAR_imgs)
            generated_imgs = 0.5 * (generated_imgs + 1)  # Denormalize from [-1, 1] to [0, 1]
            SAR_imgs = SAR_imgs.cpu()
            generated_imgs = generated_imgs.cpu()

            # Display the first few images in the batch
            for j in range(min(num_images, SAR_imgs.size(0))):
                fig, axes = plt.subplots(1, 2, figsize=(10, 5))

                # Input SAR image
                axes[0].imshow(SAR_imgs[j].squeeze(0), cmap='gray')  # Display as grayscale
                axes[0].set_title('Input SAR Image')
                axes[0].axis('off')

                # Output colorized image
                axes[1].imshow(transforms.ToPILImage()(generated_imgs[j]))
                axes[1].set_title('Generated Color Image')
                axes[1].axis('off')

                plt.show()

                # Save the generated image
                generated_img_pil = transforms.ToPILImage()(generated_imgs[j])
                img_save_path = os.path.join(save_dir, f'generated_image_{i}_{j}.png')
                generated_img_pil.save(img_save_path)
                print(f"Saved: {img_save_path}")

# Run testing, visualization, and saving
test_generator(generator, test_dataloader, num_images=6)
