<a href="https://colab.research.google.com/github/Disha-Sikka/SAR-to-EO-CycleGAN/blob/main/cycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive so the notebook can read the dataset
# from Drive and write outputs/checkpoints back to it (Colab-only API).
from google.colab import drive
drive.mount('/content/drive')


# Copying Extracted Files to Colab's Local Disk

In [None]:
import shutil
import os

# Locations of the already-extracted dataset folders on Google Drive (copy sources).
drive_path_s1 = '/content/drive/MyDrive/CycleGAN/winter_s1'
drive_path_s2 = '/content/drive/MyDrive/CycleGAN/winter_s2'

Data_Root_Dir= '/content/my_sen12ms_data_subset'
os.makedirs(Data_Root_Dir, exist_ok=True)

# Destination folders on Colab's local disk (copy targets).
colab_path_s1 = os.path.join(Data_Root_Dir, 'winter_s1')
colab_path_s2 = os.path.join(Data_Root_Dir, 'winter_s2')

print("Copying ROIs2017_winter_s1.tar.gz from Drive to Colab local disk...")
if os.path.isdir(drive_path_s1):
    # copytree copies a whole folder (shutil.copy would be used for a single file).
    # dirs_exist_ok=True makes the cell re-runnable: without it copytree raises
    # FileExistsError as soon as the destination folder already exists.
    shutil.copytree(drive_path_s1, colab_path_s1, dirs_exist_ok=True)
    print("S1 file copied.")
else:
    print("Wrong Path")

print("Copying ROIs2017_winter_s2.tar.gz from Drive to Colab local disk...")
if os.path.isdir(drive_path_s2):
    shutil.copytree(drive_path_s2, colab_path_s2, dirs_exist_ok=True)
    print("S2 file copied.")
else:
    print("Wrong Path")

print("Copying complete.")

# Importing Libraries

In [None]:
import os
import glob
import random
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils

from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

# Configuration

- Defines all parameters in one place so that they can easily be changed when needed.

In [None]:
class Config:
    """Single container for every tunable setting of the SAR -> EO CycleGAN run.

    All values are plain class attributes, so they can be read either from the
    class itself or from an instance.
    """

    # ---- Dataset location ----
    data_root_dir = '/content/my_sen12ms_data_subset'
    SAR_DIR = 'winter_s1'
    EO_DIR = 'winter_s2'

    # ---- Network architecture ----
    INPUT_NC = 2          # SAR input channels (Sentinel-1 GRD usually has 2: VV, VH)
    NGF = 64              # generator filters in the first conv layer
    NDF = 64              # discriminator filters in the first conv layer
    N_RESNET_BLOCKS = 6   # residual blocks inside the generator

    # ---- Optimisation ----
    BATCH_SIZE = 1        # CycleGAN typically uses batch size 1
    NUM_EPOCHS = 20
    LR = 0.0002           # learning rate
    BETA1 = 0.5           # Adam beta1
    LAMBDA_CYCLE = 10.0   # weight of the cycle-consistency loss
    LAMBDA_IDENTITY = 5.0 # weight of the identity-mapping loss

    # ---- Image handling ----
    IMAGE_SIZE = 256
    NUM_WORKERS = 4

    # ---- Outputs ----
    # Everything is written to Google Drive so it survives the Colab session.
    OUTPUT_BASE_DIR = '/content/drive/MyDrive/CycleGAN/SAR_EO_Project_Outputs'
    OUTPUT_DIR = os.path.join(OUTPUT_BASE_DIR, 'output_cyclegan')          # sample images
    CHECKPOINT_DIR = os.path.join(OUTPUT_BASE_DIR, 'checkpoints_cyclegan') # model weights

    SAVE_EPOCH_FREQ = 5   # write checkpoints every N epochs
    PRINT_FREQ = 1        # log training losses every N batches

    # ---- Compute device ----
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # ---- Sentinel-2 band presets ----
    # Band reference: B1 (Coastal Aerosol), B2 (Blue), B3 (Green), B4 (Red),
    # B5-B7 (Red Edge 1-3), B8 (NIR), B8A (NIR Narrow), B9 (Water Vapour),
    # B10 (SWIR - Cirrus), B11 (SWIR 1), B12 (SWIR 2)
    EO_BAND_CONFIGS = dict(
        RGB=[4, 3, 2],                # Red, Green, Blue
        NIR_SWIR_RedEdge=[8, 11, 5],  # NIR, SWIR 1, Red Edge 1
        RGB_NIR=[4, 3, 2, 8],         # Red, Green, Blue, NIR
    )

    # Preset in use; the EO channel count follows from its band list.
    CURRENT_EO_CONFIG_NAME = "RGB"
    OUTPUT_NC = len(EO_BAND_CONFIGS[CURRENT_EO_CONFIG_NAME])


# Instantiate the configuration. All settings are class attributes of Config,
# so the instance simply exposes them as `config.<NAME>` to the rest of the notebook.
config = Config()

# Create the image-output and checkpoint directories (on Google Drive) if they
# don't exist yet; exist_ok=True makes this safe to re-run.
os.makedirs(config.OUTPUT_DIR, exist_ok=True)
os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)


In [None]:
# NOTE(review): this cell repeats the configuration setup that already follows
# the Config class definition. Re-running it is harmless (makedirs uses
# exist_ok=True), and it lets the config be refreshed without redefining Config.
config = Config()

# Create output directories if they don't exist in Google Drive
os.makedirs(config.OUTPUT_DIR, exist_ok=True)
os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)


# DataLoader and Preprocessing Custom Class


In [None]:
class Sen12MSDataset(Dataset):
    """Dataset yielding ``(sar, eo)`` tensor pairs read from two trees of ``.tif`` files.

    Every item is a tuple of float tensors in ``(C, H, W)`` layout scaled to
    ``[-1, 1]`` and resized to ``image_size``.  The channel counts follow the
    module-level ``config`` (``config.INPUT_NC`` for SAR, ``config.OUTPUT_NC``
    for EO).

    Args:
        root_dir: Directory containing the SAR and EO sub-directories.
        sar_dir: Name of the SAR sub-directory (searched recursively for ``*.tif``).
        eo_dir: Name of the EO sub-directory (searched recursively for ``*.tif``).
        eo_bands: Sentinel-2 band numbers to extract from a full band stack.
        image_size: Target size handed to ``transforms.Resize``.
    """

    # Channel order assumed for a full 13-band Sentinel-2 stack
    # (B1..B8, B8A, B9..B12).  NOTE(review): confirm against the actual files.
    _S2_BAND_ORDER = (1, 2, 3, 4, 5, 6, 7, 8, '8A', 9, 10, 11, 12)
    # Ceiling used to scale EO data that is not 8-bit.
    _EO_REFLECTANCE_MAX = 10000.0

    def __init__(self, root_dir, sar_dir, eo_dir, eo_bands, image_size=256):
        self.sar_root = os.path.join(root_dir, sar_dir)
        self.eo_root = os.path.join(root_dir, eo_dir)
        self.eo_bands = eo_bands
        self.image_size = image_size

        # Find all .tif files below each root, recursively.
        self.sar_image_paths = sorted(glob.glob(os.path.join(self.sar_root, '**', '*.tif'), recursive=True))
        self.eo_image_paths = sorted(glob.glob(os.path.join(self.eo_root, '**', '*.tif'), recursive=True))

        self.pairs = self._match_sar_eo_pairs()

        # Kept for backward compatibility; __getitem__ works on tensors and
        # uses the tensor-only resize below instead.
        self.transform = transforms.Compose([
            transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(), # Converts to [0, 1]
        ])
        self._resize = transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC)

        print(f"Found {len(self.pairs)} matched SAR-EO pairs.")
        if self.pairs:
            unmatched = len(self.pairs) - self.num_id_matched
            print(f"  {self.num_id_matched} paired by image ID, {unmatched} paired with a random EO image.")
        else:
            print("WARNING: No SAR-EO pairs found. Please check your data paths and extraction.")
            print(f"SAR root: {self.sar_root}")
            print(f"EO root: {self.eo_root}")
            print(f"Example SAR path search: {os.path.join(self.sar_root, '**', '*.tif')}")
            print(f"Example EO path search: {os.path.join(self.eo_root, '**', '*.tif')}")

    @staticmethod
    def _clean_image_id(path, sensor_tag):
        """Return the file name without ``.tif`` and with ``sensor_tag`` collapsed to ``_``.

        e.g. ``ROIs2017_winter_s1_21_p91.tif`` with tag ``_s1_`` -> ``ROIs2017_winter_21_p91``.
        """
        return os.path.basename(path).replace('.tif', '').replace(sensor_tag, '_')

    def _match_sar_eo_pairs(self):
        """Pair every SAR path with an EO path.

        A SAR image is paired with the EO image sharing its cleaned image ID
        when one exists; otherwise it falls back to a randomly chosen EO image
        so that the dataset length always equals the number of SAR images.
        The number of ID-based matches is stored in ``self.num_id_matched``.

        Returns:
            list[tuple[str, str]]: ``(sar_path, eo_path)`` tuples, or ``[]``
            when either side has no images.
        """
        self.num_id_matched = 0
        if not self.sar_image_paths or not self.eo_image_paths:
            return []  # No data found

        eo_by_id = {self._clean_image_id(p, '_s2_'): p for p in self.eo_image_paths}

        matched_pairs = []
        for sar_path in self.sar_image_paths:
            eo_path = eo_by_id.get(self._clean_image_id(sar_path, '_s1_'))
            if eo_path is None:
                # No EO counterpart: fall back to an unpaired (random) EO sample.
                eo_path = random.choice(self.eo_image_paths)
            else:
                self.num_id_matched += 1
            matched_pairs.append((sar_path, eo_path))
        return matched_pairs

    def __len__(self):
        return len(self.pairs)

    @staticmethod
    def _read_hwc(path):
        """Load an image file into an ``(H, W, C)`` numpy array, closing the file afterwards."""
        with Image.open(path) as img:
            array = np.array(img)
        if array.ndim == 2:  # single-band image: add the channel axis
            array = array[:, :, np.newaxis]
        return array

    def _load_sar(self, sar_path):
        """Read a SAR file and return a ``(C, H, W)`` tensor min-max scaled to ``[-1, 1]``."""
        sar_np = self._read_hwc(sar_path)
        n_channels = sar_np.shape[2]

        if n_channels == 1 and config.INPUT_NC == 2:
            sar_np = np.repeat(sar_np, 2, axis=2)
            print(f"Warning: SAR image {os.path.basename(sar_path)} is 1-channel, but INPUT_NC is 2. Repeating channel.")
        elif n_channels > config.INPUT_NC:
            # Report the ORIGINAL channel count, then keep the first INPUT_NC channels.
            print(f"Warning: SAR image {os.path.basename(sar_path)} has {n_channels} channels, but INPUT_NC is {config.INPUT_NC}. Taking first {config.INPUT_NC} channels.")
            sar_np = sar_np[:, :, :config.INPUT_NC]

        if sar_np.dtype == np.uint16:
            sar = torch.from_numpy(sar_np.astype(np.float32)) / 65535.0
        else:
            sar = torch.from_numpy(sar_np).float()

        sar = torch.nan_to_num(sar, nan=0.0, posinf=0.0, neginf=0.0)
        sar = sar.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)

        # Per-image min-max scaling to [0, 1], then shift to [-1, 1].
        sar_min, sar_max = sar.min(), sar.max()
        if sar_max > sar_min:
            sar = (sar - sar_min) / (sar_max - sar_min)
        else:
            sar = torch.zeros_like(sar)  # constant image
        sar = sar * 2.0 - 1.0

        # Bicubic resampling can overshoot, so clamp back into [-1, 1].
        return self._resize(sar).clamp(-1.0, 1.0)

    def _load_eo(self, eo_path):
        """Read an EO file and return a ``(config.OUTPUT_NC, H, W)`` tensor scaled to ``[-1, 1]``."""
        eo_np = self._read_hwc(eo_path)

        # 8-bit imagery spans [0, 255]; anything else is treated as raw
        # reflectance capped at 10000.  The raw values are scaled exactly once
        # (dividing uint16 data by 65535 first would squash everything to ~-1).
        eo_max_val = 255.0 if eo_np.dtype == np.uint8 else self._EO_REFLECTANCE_MAX

        eo_full = torch.from_numpy(eo_np.astype(np.float32))
        eo_full = torch.nan_to_num(eo_full, nan=0.0, posinf=0.0, neginf=0.0)
        eo_full = eo_full.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)

        # Pick the requested bands only when the file is a full band stack, so
        # that files which already hold just the wanted channels are untouched.
        bands = list(self.eo_bands or [])
        if (bands and eo_full.shape[0] == len(self._S2_BAND_ORDER)
                and all(b in self._S2_BAND_ORDER for b in bands)):
            eo_full = eo_full[[self._S2_BAND_ORDER.index(b) for b in bands]]

        num_channels = eo_full.shape[0]
        if num_channels < config.OUTPUT_NC:
            print(f"Warning: EO image {os.path.basename(eo_path)} has {num_channels} channels, but config.OUTPUT_NC is {config.OUTPUT_NC}. Filling missing channels with zeros.")
            # new_zeros keeps the padding on the same device/dtype as the image
            # (the image lives on the CPU inside DataLoader workers).
            padding = eo_full.new_zeros(config.OUTPUT_NC - num_channels, *eo_full.shape[1:])
            eo = torch.cat([eo_full, padding], dim=0)
        elif num_channels > config.OUTPUT_NC:
            print(f"Warning: EO image {os.path.basename(eo_path)} has {num_channels} channels, but config.OUTPUT_NC is {config.OUTPUT_NC}. Taking first {config.OUTPUT_NC} channels.")
            eo = eo_full[:config.OUTPUT_NC, :, :]
        else:  # Perfect match
            eo = eo_full

        # Normalize EO to [-1, 1]
        eo = torch.clamp(eo, 0, eo_max_val)
        eo = (eo / eo_max_val) * 2.0 - 1.0

        # Resize like the SAR branch so both tensors share the same spatial size.
        return self._resize(eo).clamp(-1.0, 1.0)

    def __getitem__(self, idx):
        """Return the ``(sar, eo)`` tensor pair stored at position ``idx``."""
        sar_path, eo_path = self.pairs[idx]
        return self._load_sar(sar_path), self._load_eo(eo_path)


In [None]:
# Building block: convolution, optionally followed by normalization and activation
def conv_block(in_channels, out_channels, kernel_size, stride, padding, use_bias=False, norm_layer=nn.InstanceNorm2d, activation=nn.ReLU(True)):
    """Assemble ``Conv2d -> [norm] -> [activation]`` into one ``nn.Sequential``.

    ``norm_layer`` is a callable taking the channel count and ``activation`` an
    already-built module; passing a falsy value for either leaves that stage out.
    """
    convolution = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=use_bias)
    normalization = [norm_layer(out_channels)] if norm_layer else []
    nonlinearity = [activation] if activation else []
    return nn.Sequential(convolution, *normalization, *nonlinearity)

# Building block for upsampling: transposed convolution, optionally followed by normalization and activation
def deconv_block(in_channels, out_channels, kernel_size, stride, padding, output_padding, use_bias=False, norm_layer=nn.InstanceNorm2d, activation=nn.ReLU(True)):
    """Assemble ``ConvTranspose2d -> [norm] -> [activation]`` into one ``nn.Sequential``.

    ``norm_layer`` is a callable taking the channel count and ``activation`` an
    already-built module; passing a falsy value for either leaves that stage out.
    """
    upsampler = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=use_bias)
    normalization = [norm_layer(out_channels)] if norm_layer else []
    nonlinearity = [activation] if activation else []
    return nn.Sequential(upsampler, *normalization, *nonlinearity)


# Class ResNet


In [None]:
class ResnetBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 conv stages with dropout in between.

    The second conv stage has no activation; the stack's output is added to
    the block input, so the channel count (``dim``) is preserved.
    """

    def __init__(self, dim, norm_layer=nn.InstanceNorm2d, use_bias=False):
        super().__init__()
        # Settings shared by both 3x3 convolutions (padding is done by ReflectionPad2d).
        shared = dict(kernel_size=3, stride=1, padding=0, use_bias=use_bias, norm_layer=norm_layer)
        first_stage = conv_block(dim, dim, **shared)
        second_stage = conv_block(dim, dim, activation=None, **shared)
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            first_stage,
            nn.Dropout(0.5),  # regularization between the two conv stages
            nn.ReflectionPad2d(1),
            second_stage,
        )

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual


# Class Generator
- ResNet-based (downsampling, residual blocks, upsampling; there are no U-Net skip connections)
- Translates images from Domain A to Domain B
- First the image is downsampled, then it is passed through the ResNet blocks, and then it is decoded by upsampling; a final conv layer followed by Tanh maps it to the output channels (e.g. RGB) in the range [-1, 1].

In [None]:
class Generator(nn.Module):
    """
    Image-to-image generator: 7x7 stem, two stride-2 downsampling convs,
    ``n_blocks`` residual blocks, two transposed-conv upsampling stages and a
    7x7 output conv followed by Tanh.
    Translates images from domain A (``input_nc`` channels) to domain B
    (``output_nc`` channels).
    """
    def __init__(self, input_nc, output_nc, ngf=64, n_blocks=9, norm_layer=nn.InstanceNorm2d):
        super().__init__()
        use_bias = norm_layer == nn.InstanceNorm2d
        n_downsampling = 2
        common = dict(use_bias=use_bias, norm_layer=norm_layer)

        # Stem: full-resolution 7x7 convolution.
        stem = [
            nn.ReflectionPad2d(3),
            conv_block(input_nc, ngf, kernel_size=7, stride=1, padding=0, **common),
        ]

        # Encoder: every stage halves the resolution and doubles the width.
        encoder = [
            conv_block(ngf * 2 ** level, ngf * 2 ** level * 2, kernel_size=3, stride=2, padding=1, **common)
            for level in range(n_downsampling)
        ]

        # Bottleneck: residual blocks at the lowest resolution.
        bottleneck_width = ngf * 2 ** n_downsampling
        bottleneck = [
            ResnetBlock(bottleneck_width, norm_layer=norm_layer, use_bias=use_bias)
            for _ in range(n_blocks)
        ]

        # Decoder: every stage doubles the resolution and halves the width.
        decoder = []
        for level in range(n_downsampling, 0, -1):
            width = ngf * 2 ** level
            decoder.append(
                deconv_block(width, int(width / 2), kernel_size=3, stride=2, padding=1, output_padding=1, **common)
            )

        # Head: project to the target channel count; Tanh maps to [-1, 1].
        head = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*stem, *encoder, *bottleneck, *decoder, *head)

    def forward(self, input):
        return self.model(input)


# Discriminator Class

- PatchGAN based
- Classifies whether the provided image is real or fake
- Convolutional layers with normalization while downscaling.
- The last layer is a Conv layer that outputs a 1-channel patch prediction map (raw scores; no Sigmoid is applied inside the model).

In [None]:
class Discriminator(nn.Module):
    """
    Patch-level discriminator (PatchGAN style).
    Produces a 1-channel map of real/fake scores, one per image patch.
    """
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d):
        super().__init__()
        use_bias = norm_layer == nn.InstanceNorm2d
        kw, padw = 4, 1  # kernel size and padding shared by every conv
        common = dict(kernel_size=kw, padding=padw, use_bias=use_bias, norm_layer=norm_layer)

        # First layer: plain conv + LeakyReLU, no normalization.
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),  # keeps gradients flowing for negative inputs
        ]

        # Stride-2 stages, each doubling the channel count.
        for depth in range(1, n_layers):
            width = ndf * 2 ** depth
            layers.append(
                conv_block(width // 2, width, stride=2, activation=nn.LeakyReLU(0.2, True), **common)
            )

        # One more doubling at stride 1, then the 1-channel prediction map.
        width = ndf * 2 ** n_layers
        layers.append(
            conv_block(width // 2, width, stride=1, activation=nn.LeakyReLU(0.2, True), **common)
        )
        layers.append(nn.Conv2d(width, 1, kernel_size=kw, stride=1, padding=padw))

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)


# Adversarial Loss Function

- It tells how well discriminator is working in distinguishing real and fake.
- It also tells how well generator is working in translating the image from one domain to other domain.


In [None]:
class GANLoss(nn.Module):
    """
    Adversarial loss shared by the generators and discriminators.

    The target labels are created internally, so callers only say whether the
    prediction should be judged as real or fake.

    Args:
        gan_mode: ``'lsgan'`` (alias ``'mse'``, the default) for a mean-squared
            error objective, or ``'vanilla'`` for binary cross-entropy on logits.
        target_real_label: Label value used for real targets.
        target_fake_label: Label value used for fake targets.

    Raises:
        NotImplementedError: If ``gan_mode`` is not one of the supported modes.
    """
    def __init__(self, gan_mode='mse', target_real_label=1.0, target_fake_label=0.0):
        super().__init__()
        # Buffers follow the module across .to(device) calls.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode in ('lsgan', 'mse'):
            # 'mse' is the default argument value, so it must be accepted;
            # it is the same objective as the least-squares GAN.
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla': # Standard GAN (Binary Cross Entropy)
            self.loss = nn.BCEWithLogitsLoss()
        else:
            raise NotImplementedError(f'GAN mode {gan_mode} not implemented.')

    def get_target_tensor(self, prediction, target_is_real):
        """Creates label tensors with the same size as the input prediction."""
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def forward(self, prediction, target_is_real):
        """Return the loss of ``prediction`` against an all-real or all-fake target.

        Implemented as ``forward`` (rather than overriding ``__call__``) so the
        regular ``nn.Module`` call path is used; ``loss(pred, flag)`` still works.
        """
        target_tensor = self.get_target_tensor(prediction, target_is_real)
        return self.loss(prediction, target_tensor)


# Cycle Consistency Loss

- It calculates the loss when translating SAR to EO and then back to SAR (and likewise EO to SAR and back to EO)
- Mean Absolute Error

In [None]:
class CycleConsistencyLoss(nn.Module):
    """
    Weighted mean absolute error between an image and its round-trip reconstruction.
    """
    def __init__(self, lambda_cycle=10.0):
        super().__init__()
        self.loss = nn.L1Loss()           # pixel-wise absolute difference, averaged
        self.lambda_cycle = lambda_cycle  # larger weight -> stronger cycle constraint

    def forward(self, real_image, cycled_image):
        l1_distance = self.loss(real_image, cycled_image)
        return l1_distance * self.lambda_cycle


# Identity Loss

- Optional
- Done for colour preservation
- And training Stability


In [None]:
class IdentityLoss(nn.Module):
    """
    Weighted mean absolute error for the identity-mapping term.
    Penalizes a generator for altering an image that already belongs to its
    target domain, which encourages it to preserve color composition.
    """
    def __init__(self, lambda_identity=5.0):
        super().__init__()
        self.loss = nn.L1Loss()
        self.lambda_identity = lambda_identity

    def forward(self, real_image, identity_image):
        l1_distance = self.loss(real_image, identity_image)
        return l1_distance * self.lambda_identity

# Utilities for Postprocessing and Metrics

- To convert pixel values from the range [-1, 1] to [0, 1].
- To save the SAR, generated EO and real EO images side by side.
- To measure the structural similarity (SSIM) between two images.
- To calculate the peak signal-to-noise ratio (PSNR).
- To calculate NDVI from EO image

In [None]:
def denormalize_image(tensor):
    """
    Map values from the [-1, 1] range onto the [0, 1] range.
    Args:
        tensor (torch.Tensor): Image tensor with values in [-1, 1].
    Returns:
        torch.Tensor: The same image with values in [0, 1].
    """
    shifted = tensor + 1   # [-1, 1] -> [0, 2]
    return shifted / 2.0   # [0, 2]  -> [0, 1]

def save_combined_image(sar_img, gen_eo_img, real_eo_img, filename):
    """
    Write one image file showing real SAR, generated EO and real EO next to each other.
    All three inputs are (N, C, H, W) batches, expected to be denormalized to [0, 1].
    """
    # SAR panel: turn 1- or 2-channel SAR into a 3-channel grayscale image.
    sar_channels = sar_img.shape[1]
    if sar_channels == 2:
        sar_panel = sar_img.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)
    elif sar_channels == 1:
        sar_panel = sar_img.repeat(1, 3, 1, 1)
    else:
        sar_panel = sar_img

    # EO panels: a 4-channel (RGB+NIR) image is shown through its first 3 channels.
    gen_panel, real_panel = gen_eo_img, real_eo_img
    if gen_eo_img.shape[1] == 4:
        gen_panel = gen_eo_img[:, :3, :, :]
        real_panel = real_eo_img[:, :3, :, :]

    # Stitch the panels together along the width axis and save; normalize=True
    # is forwarded to torchvision's save_image.
    side_by_side = torch.cat([sar_panel, gen_panel, real_panel], dim=3)
    vutils.save_image(side_by_side, filename, normalize=True, nrow=1)

def calculate_ssim(img1, img2, data_range=1.0, multichannel=True):
    """
    Calculates SSIM between two images.
    Args:
        img1 (torch.Tensor): First image (1, C, H, W) or (C, H, W) in range [0,1].
        img2 (torch.Tensor): Second image (1, C, H, W) or (C, H, W) in range [0,1].
        data_range (float): The range of the data (e.g., 1.0 for [0,1], 255 for [0,255]).
        multichannel (bool): Set to True if images have multiple channels.
    Returns:
        float: SSIM score.
    """
    # SSIM works on numpy arrays: drop a leading batch dim of size 1, detach
    # from autograd (numpy() refuses tensors that require grad), move to CPU.
    img1_np = img1.detach().squeeze(0).cpu().numpy()
    img2_np = img2.detach().squeeze(0).cpu().numpy()

    # Handle single channel images (remove channel dimension if it's 1 for SSIM)
    if img1_np.ndim == 3 and img1_np.shape[0] == 1:
        img1_np = img1_np[0]
        multichannel = False # Override if it's actually single channel
    if img2_np.ndim == 3 and img2_np.shape[0] == 1:
        img2_np = img2_np[0]
        multichannel = False

    # Transpose to (H, W, C) if it's (C, H, W) for multichannel
    if multichannel and img1_np.ndim == 3 and img1_np.shape[0] > 1:
        img1_np = np.transpose(img1_np, (1, 2, 0))
        img2_np = np.transpose(img2_np, (1, 2, 0))
    elif multichannel and img1_np.ndim == 2: # If it's 2D, it's not multichannel
        multichannel = False

    # Newer scikit-image releases take `channel_axis` and reject the old
    # `multichannel` keyword; older releases only know `multichannel`.
    channel_axis = -1 if multichannel else None
    try:
        return ssim(img1_np, img2_np, data_range=data_range, channel_axis=channel_axis)
    except TypeError:
        return ssim(img1_np, img2_np, data_range=data_range, multichannel=multichannel)

def calculate_psnr(img1, img2, data_range=1.0):
    """
    Compute the peak signal-to-noise ratio between two images.
    Args:
        img1 (torch.Tensor): First image (N, C, H, W) or (C, H, W) in range [0,1].
        img2 (torch.Tensor): Second image (N, C, H, W) or (C, H, W) in range [0,1].
        data_range (float): Span of the pixel values (1.0 for [0,1], 255 for [0,255]).
    Returns:
        float: PSNR score.
    """
    def _to_array(image):
        # Drop a size-1 batch dim, move to CPU, convert to numpy; a lone
        # channel dimension is removed as well.
        array = image.squeeze(0).cpu().numpy()
        if array.ndim == 3 and array.shape[0] == 1:
            array = array[0]
        return array

    first, second = _to_array(img1), _to_array(img2)
    return psnr(first, second, data_range=data_range)

def calculate_ndvi(eo_image_tensor, eo_band_config):
    """
    Compute the Normalized Difference Vegetation Index of an EO image.
    NDVI = (NIR - Red) / (NIR + Red)
    Args:
        eo_image_tensor (torch.Tensor): EO image tensor (C, H, W) in range [0, 1].
        eo_band_config (list): Band numbers of the image's channels, in channel order.
    Returns:
        torch.Tensor: NDVI map of shape (1, H, W), or None when band 8 (NIR)
                      or band 4 (Red) is missing from ``eo_band_config``.
    """
    # Both Sentinel-2 B8 (NIR) and B4 (Red) must be part of the configuration.
    if 8 not in eo_band_config or 4 not in eo_band_config:
        print("Warning: NIR (B8) or Red (B4) band not found in current EO configuration. Cannot calculate NDVI.")
        return None

    nir_band = eo_image_tensor[eo_band_config.index(8), :, :]
    red_band = eo_image_tensor[eo_band_config.index(4), :, :]

    # A small epsilon on the denominator guards against division by zero.
    epsilon = 1e-6
    band_sum = nir_band + red_band
    band_difference = nir_band - red_band
    ndvi = band_difference / (band_sum + epsilon)

    return ndvi.unsqueeze(0) # restore the channel dimension

# Training Function

In [None]:
def _save_preview(image, path):
    """Save one [0, 1] image tensor ((C, H, W) or (N, C, H, W)) after reducing it to 1 or 3 channels."""
    channel_dim = image.dim() - 3  # 0 for (C, H, W), 1 for (N, C, H, W)
    n_channels = image.shape[channel_dim]
    if n_channels == 2:
        # Two-channel SAR cannot be written as an image file: show the channel mean.
        image = image.mean(dim=channel_dim, keepdim=True)
    elif n_channels > 3:
        # e.g. RGB+NIR: keep the first three channels for display.
        image = image.narrow(channel_dim, 0, 3)
    vutils.save_image(image.clamp(0, 1), path)


def train_cyclegan():
    """
    Build the CycleGAN (two generators, two discriminators) and train it.

    Reads every setting from the module-level ``config``.  After each epoch a
    fixed sample is translated, preview images are written to
    ``config.OUTPUT_DIR`` and SSIM/PSNR (plus NDVI metrics when the band preset
    contains B8 and B4) are printed.  Checkpoints are written to
    ``config.CHECKPOINT_DIR`` every ``config.SAVE_EPOCH_FREQ`` epochs.

    Raises:
        ValueError: If the dataset contains no samples.
    """
    eo_bands = config.EO_BAND_CONFIGS[config.CURRENT_EO_CONFIG_NAME]

    print(f"Using device: {config.DEVICE}")
    print(f"Current EO Output Configuration: {config.CURRENT_EO_CONFIG_NAME} (Bands: {eo_bands})")
    print(f"Input Channels (SAR): {config.INPUT_NC}")
    print(f"Output Channels (EO): {config.OUTPUT_NC}")

    # Initialize Dataset and DataLoader
    dataset = Sen12MSDataset(
        root_dir=config.data_root_dir,
        sar_dir=config.SAR_DIR,
        eo_dir=config.EO_DIR,
        eo_bands=eo_bands,
        image_size=config.IMAGE_SIZE
    )
    if len(dataset) == 0:
        raise ValueError("The dataset is empty: no SAR/EO .tif files were found under "
                         f"{config.data_root_dir}.")
    dataloader = DataLoader(dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS)

    # Fixed sample (first dataset item) reused every epoch, so previews and
    # metrics are comparable across epochs.
    fixed_sar, fixed_eo = dataset[0]
    sample_sar = fixed_sar.unsqueeze(0).to(config.DEVICE)
    sample_eo = fixed_eo.unsqueeze(0).to(config.DEVICE)

    # Initialize Generators and Discriminators
    # G_A: SAR -> EO, G_B: EO -> SAR
    G_A = Generator(config.INPUT_NC, config.OUTPUT_NC, config.NGF, config.N_RESNET_BLOCKS).to(config.DEVICE)
    G_B = Generator(config.OUTPUT_NC, config.INPUT_NC, config.NGF, config.N_RESNET_BLOCKS).to(config.DEVICE)
    # D_A: Discriminates real EO vs fake EO, D_B: Discriminates real SAR vs fake SAR
    D_A = Discriminator(config.OUTPUT_NC, config.NDF).to(config.DEVICE)
    D_B = Discriminator(config.INPUT_NC, config.NDF).to(config.DEVICE)

    # Initialize Optimizers
    optimizer_G = optim.Adam(list(G_A.parameters()) + list(G_B.parameters()), lr=config.LR, betas=(config.BETA1, 0.999))
    optimizer_D_A = optim.Adam(D_A.parameters(), lr=config.LR, betas=(config.BETA1, 0.999))
    optimizer_D_B = optim.Adam(D_B.parameters(), lr=config.LR, betas=(config.BETA1, 0.999))

    # Initialize Loss Functions
    criterion_GAN = GANLoss(gan_mode='lsgan').to(config.DEVICE)
    criterion_cycle = CycleConsistencyLoss(lambda_cycle=config.LAMBDA_CYCLE).to(config.DEVICE)
    criterion_identity = IdentityLoss(lambda_identity=config.LAMBDA_IDENTITY).to(config.DEVICE)

    # The identity term feeds a target-domain image into a generator, which is
    # only possible when both domains have the same channel count: G_A expects
    # INPUT_NC channels, but an EO image has OUTPUT_NC channels.
    use_identity = config.LAMBDA_IDENTITY > 0 and config.INPUT_NC == config.OUTPUT_NC
    if not use_identity:
        print("Identity loss disabled (SAR and EO channel counts differ or LAMBDA_IDENTITY is 0).")

    # Training Loop
    print("Starting Training Loop...")
    for epoch in range(config.NUM_EPOCHS):
        G_A.train()
        G_B.train()
        D_A.train()
        D_B.train()

        for i, (real_sar, real_eo) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{config.NUM_EPOCHS}")):
            real_sar = real_sar.to(config.DEVICE)
            real_eo = real_eo.to(config.DEVICE)

            # --- Train Generators G_A and G_B ---
            optimizer_G.zero_grad()

            if use_identity:
                # G_A should produce real_eo when given real_eo as input
                loss_identity_A = criterion_identity(G_A(real_eo), real_eo)
                # G_B should produce real_sar when given real_sar as input
                loss_identity_B = criterion_identity(G_B(real_sar), real_sar)
            else:
                # Zero tensors keep the loss sum and the logging code uniform.
                loss_identity_A = torch.zeros((), device=config.DEVICE)
                loss_identity_B = torch.zeros((), device=config.DEVICE)

            # GAN loss D_A(G_A(real_sar))
            fake_eo = G_A(real_sar)
            pred_fake_eo = D_A(fake_eo)
            loss_GAN_A = criterion_GAN(pred_fake_eo, True) # G_A wants to fool D_A

            # GAN loss D_B(G_B(real_eo))
            fake_sar = G_B(real_eo)
            pred_fake_sar = D_B(fake_sar)
            loss_GAN_B = criterion_GAN(pred_fake_sar, True) # G_B wants to fool D_B

            # Cycle consistency loss
            # Cycle SAR -> EO -> SAR
            cycled_sar = G_B(fake_eo)
            loss_cycle_sar = criterion_cycle(cycled_sar, real_sar)
            # Cycle EO -> SAR -> EO
            cycled_eo = G_A(fake_sar)
            loss_cycle_eo = criterion_cycle(cycled_eo, real_eo)

            # Total Generator Loss
            loss_G = loss_GAN_A + loss_GAN_B + loss_cycle_sar + loss_cycle_eo + \
                     loss_identity_A + loss_identity_B
            loss_G.backward()
            optimizer_G.step()

            # --- Train Discriminator D_A (real EO vs fake EO) ---
            optimizer_D_A.zero_grad()
            # Real loss
            pred_real_eo = D_A(real_eo)
            loss_D_A_real = criterion_GAN(pred_real_eo, True)
            # Fake loss (detach fake_eo to stop gradients from flowing to G_A)
            pred_fake_eo = D_A(fake_eo.detach())
            loss_D_A_fake = criterion_GAN(pred_fake_eo, False)
            # Total D_A loss
            loss_D_A = (loss_D_A_real + loss_D_A_fake) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()

            # --- Train Discriminator D_B (real SAR vs fake SAR) ---
            optimizer_D_B.zero_grad()
            # Real loss
            pred_real_sar = D_B(real_sar)
            loss_D_B_real = criterion_GAN(pred_real_sar, True)
            # Fake loss (detach fake_sar to stop gradients from flowing to G_B)
            pred_fake_sar = D_B(fake_sar.detach())
            loss_D_B_fake = criterion_GAN(pred_fake_sar, False)
            # Total D_B loss
            loss_D_B = (loss_D_B_real + loss_D_B_fake) * 0.5
            loss_D_B.backward()
            optimizer_D_B.step()

            if i % config.PRINT_FREQ == 0:
                tqdm.write(f"Epoch [{epoch+1}/{config.NUM_EPOCHS}], Step [{i}/{len(dataloader)}]\n"
                           f"Loss_G: {loss_G.item():.4f} | Loss_G_GAN_A: {loss_GAN_A.item():.4f} | Loss_G_GAN_B: {loss_GAN_B.item():.4f}\n"
                           f"Loss_cycle_SAR: {loss_cycle_sar.item():.4f} | Loss_cycle_EO: {loss_cycle_eo.item():.4f}\n"
                           f"Loss_identity_A: {loss_identity_A.item():.4f} | Loss_identity_B: {loss_identity_B.item():.4f}\n"
                           f"Loss_D_A: {loss_D_A.item():.4f} | Loss_D_B: {loss_D_B.item():.4f}")

        # --- Save generated images and evaluate metrics at end of epoch ---
        G_A.eval()
        G_B.eval()
        with torch.no_grad():
            # Generate fake EO from real SAR, then cycle back to SAR
            generated_eo = G_A(sample_sar)
            cycled_sar_from_eo = G_B(generated_eo)

            # Generate fake SAR from real EO, then cycle back to EO
            generated_sar = G_B(sample_eo)
            cycled_eo_from_sar = G_A(generated_sar)

            # Denormalize for saving and metric calculation (from [-1, 1] to [0, 1])
            real_sar_display = denormalize_image(sample_sar)
            real_eo_display = denormalize_image(sample_eo)
            generated_eo_display = denormalize_image(generated_eo)
            cycled_sar_display = denormalize_image(cycled_sar_from_eo)
            generated_sar_display = denormalize_image(generated_sar)
            cycled_eo_display = denormalize_image(cycled_eo_from_sar)

            # Save the individual sample images
            out_dir = config.OUTPUT_DIR
            eo_name = config.CURRENT_EO_CONFIG_NAME
            _save_preview(real_sar_display, os.path.join(out_dir, f"epoch_{epoch+1}_real_sar.png"))
            _save_preview(real_eo_display, os.path.join(out_dir, f"epoch_{epoch+1}_{eo_name}_real_eo.png"))
            _save_preview(generated_eo_display, os.path.join(out_dir, f"epoch_{epoch+1}_{eo_name}_generated_eo.png"))
            _save_preview(cycled_sar_display, os.path.join(out_dir, f"epoch_{epoch+1}_cycled_sar.png"))
            _save_preview(generated_sar_display, os.path.join(out_dir, f"epoch_{epoch+1}_generated_sar.png"))
            _save_preview(cycled_eo_display, os.path.join(out_dir, f"epoch_{epoch+1}_cycled_eo.png"))

            # Side-by-side panel: real SAR | generated EO | real EO
            save_combined_image(real_sar_display, generated_eo_display, real_eo_display,
                                os.path.join(out_dir, f"epoch_{epoch+1}_{eo_name}_combined.png"))

            # --- Calculate Performance Metrics ---
            # SSIM and PSNR for SAR -> EO translation.  These compare the
            # generated EO with the sample's EO image, so they are only
            # meaningful when that EO image is the true counterpart of the SAR input.
            ssim_score = calculate_ssim(generated_eo_display, real_eo_display, multichannel=(config.OUTPUT_NC > 1))
            psnr_score = calculate_psnr(generated_eo_display, real_eo_display)
            print(f"Epoch {epoch+1} Metrics (SAR -> EO):")
            print(f"  SSIM: {ssim_score:.4f}")
            print(f"  PSNR: {psnr_score:.4f}")

            # NDVI calculation for generated EO and real EO
            # NDVI requires NIR (B8) and Red (B4) bands
            if 8 in eo_bands and 4 in eo_bands:
                # NDVI of the first (only) image in the sample batch
                real_ndvi = calculate_ndvi(real_eo_display[0], eo_bands)
                generated_ndvi = calculate_ndvi(generated_eo_display[0], eo_bands)

                if real_ndvi is not None and generated_ndvi is not None:
                    # NDVI lies in [-1, 1]; shift to [0, 1] to store it as a grayscale image.
                    _save_preview((real_ndvi + 1) / 2.0, os.path.join(out_dir, f"epoch_{epoch+1}_real_ndvi.png"))
                    _save_preview((generated_ndvi + 1) / 2.0, os.path.join(out_dir, f"epoch_{epoch+1}_generated_ndvi.png"))

                    ndvi_ssim = calculate_ssim(generated_ndvi, real_ndvi, data_range=2.0, multichannel=False) # NDVI range is [-1, 1]
                    ndvi_psnr = calculate_psnr(generated_ndvi, real_ndvi, data_range=2.0)
                    print(f"  NDVI SSIM: {ndvi_ssim:.4f}")
                    print(f"  NDVI PSNR: {ndvi_psnr:.4f}")
            else:
                print("  NDVI not calculated: Required NIR (B8) or Red (B4) band missing in current EO configuration.")

        # Save model checkpoints
        if (epoch + 1) % config.SAVE_EPOCH_FREQ == 0:
            # Checkpoint directory is already set to Drive path in Config
            torch.save(G_A.state_dict(), os.path.join(config.CHECKPOINT_DIR, f'G_A_epoch_{epoch+1}.pth'))
            torch.save(G_B.state_dict(), os.path.join(config.CHECKPOINT_DIR, f'G_B_epoch_{epoch+1}.pth'))
            torch.save(D_A.state_dict(), os.path.join(config.CHECKPOINT_DIR, f'D_A_epoch_{epoch+1}.pth'))
            torch.save(D_B.state_dict(), os.path.join(config.CHECKPOINT_DIR, f'D_B_epoch_{epoch+1}.pth'))
            print(f"Models saved after epoch {epoch+1}")

    print("Training Complete!")


# To Run

In [None]:
# Entry point: start training when this file (or notebook cell) is executed as
# the main program rather than imported as a module.
if __name__ == '__main__':
    train_cyclegan()
