<a href="https://colab.research.google.com/github/Disha-Sikka/SAR-to-EO-CycleGAN/blob/main/cycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

# Google Drive contents become visible under this mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Copying Extracted Files to Colab's Local Disk

In [1]:
import shutil
import os

# Source folders (already extracted) on Google Drive.
drive_path_s1 = '/content/drive/MyDrive/CycleGAN/winter_s1'
drive_path_s2 = '/content/drive/MyDrive/CycleGAN/winter_s2'

# Local destination; Config.data_root_dir points at this same directory.
Data_Root_Dir = '/content/my_sen12ms_data_subset'
os.makedirs(Data_Root_Dir, exist_ok=True)

colab_path_s1 = os.path.join(Data_Root_Dir, 'winter_s1')
colab_path_s2 = os.path.join(Data_Root_Dir, 'winter_s2')


def copy_folder_to_local(src_dir, dst_dir, label):
    """Copy the folder ``src_dir`` (recursively) to ``dst_dir``.

    Args:
        src_dir: folder to copy (on the mounted Drive).
        dst_dir: destination folder on the local disk.
        label: short name used in log messages (e.g. "S1").
    Returns:
        True if the copy ran, False if ``src_dir`` is not an existing folder.
    """
    print(f"Copying {label} folder {src_dir} -> {dst_dir} ...")
    if not os.path.isdir(src_dir):
        print(f"Wrong Path: {src_dir} is not an existing folder (is Drive mounted?)")
        return False
    # dirs_exist_ok=True keeps the cell re-runnable: a second run merges into the
    # existing local copy instead of raising FileExistsError.
    shutil.copytree(src_dir, dst_dir, dirs_exist_ok=True)
    print(f"{label} folder copied.")
    return True


copy_results = [
    copy_folder_to_local(drive_path_s1, colab_path_s1, "S1"),
    copy_folder_to_local(drive_path_s2, colab_path_s2, "S2"),
]
print("Copying complete." if all(copy_results) else "Copying finished with errors; see messages above.")

Copying ROIs2017_winter_s1.tar.gz from Drive to Colab local disk...
S1 file copied.
Copying ROIs2017_winter_s2.tar.gz from Drive to Colab local disk...
S2 file copied.
Copying complete.


# Importing Libraries

In [2]:
import os
import glob
import random
import numpy as np
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
%pip install -q rasterio==1.4.3
import rasterio

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.nn.functional as F

from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

Collecting rasterio
  Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Collecting cligj>=0.5 (from rasterio)
  Downloading cligj-0.7.2-py3-none-any.whl.metadata (5.0 kB)
Collecting click-plugins (from rasterio)
  Downloading click_plugins-1.1.1.2-py2.py3-none-any.whl.metadata (6.5 kB)
Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.2/22.2 MB[0m [31m84.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading cligj-0.7.2-py3-none-any.whl (7.1 kB)
Downloading affine-2.4.0-py3-none-any.whl (15 kB)
Downloading click_plugins-1.1.1.2-py2.py3-none-any.whl (11 kB)
Installing collected packages: cligj, click-plugins, affine, rasterio
Successfully installed affine-2.4.0 click-plugins-1.1.1.2 cligj-0.7.2 rasterio-1.4.3


# Configuration

- Defines all paths and hyper-parameters in one place, so they can be changed easily.

In [17]:
class Config:
    """Central place for every tunable path and hyper-parameter of the SAR -> EO CycleGAN.

    All values are class attributes; ``config = Config()`` just gives later cells a
    short handle to them.
    """

    # Data paths (local copy made by the Drive -> Colab copy cell)
    data_root_dir='/content/my_sen12ms_data_subset'
    SAR_DIR = 'winter_s1'
    EO_DIR = 'winter_s2'

    # Model parameters
    # Number of SAR channels fed to the generator. Sentinel-1 GRD usually has only
    # 2 bands (VV, VH); Sen12MSDataset pads (or trims) every SAR image to this count,
    # so with 3 the extra channel is synthetic.
    # NOTE(review): confirm that 3 (rather than 2) is intentional.
    INPUT_NC = 3
    NGF = 64 # Number of generator filters in the first conv layer
    NDF = 64 # Number of discriminator filters in the first conv layer
    N_RESNET_BLOCKS = 6 # Number of ResNet blocks in the generator

    # Training parameters
    BATCH_SIZE = 1 # CycleGAN typically uses batch size 1
    NUM_EPOCHS = 5
    LR = 0.0002 # Learning rate
    BETA1 = 0.5 # Adam optimizer beta1
    LAMBDA_CYCLE = 10.0 # Weight for cycle consistency loss
    LAMBDA_IDENTITY = 5.0 # Weight for identity mapping loss (helps stabilize training)

    # Image parameters
    IMAGE_SIZE = 256 # Images are resized to IMAGE_SIZE x IMAGE_SIZE
    NUM_WORKERS = 4 # DataLoader worker processes

    # Output and logging
    # Save outputs and checkpoints to Google Drive for persistence across sessions
    OUTPUT_BASE_DIR = '/content/drive/MyDrive/CycleGAN/SAR_EO_Project_Outputs' # Base directory in Drive
    OUTPUT_DIR = os.path.join(OUTPUT_BASE_DIR, 'output_cyclegan') # Specific output for images
    CHECKPOINT_DIR = os.path.join(OUTPUT_BASE_DIR, 'checkpoints_cyclegan') # Specific output for models

    SAVE_EPOCH_FREQ = 5 # Save model checkpoints every N epochs
    PRINT_FREQ = 1 # Print training loss every N batches

    # Device
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # EO Band Configurations (Sentinel-2 band numbers)
    # B1 (Coastal Aerosol), B2 (Blue), B3 (Green), B4 (Red), B5 (Red Edge 1),
    # B6 (Red Edge 2), B7 (Red Edge 3), B8 (NIR), B8A (NIR Narrow), B9 (Water Vapour),
    # B10 (SWIR - Cirrus), B11 (SWIR 1), B12 (SWIR 2)
    # NOTE: Sen12MSDataset currently maps only B4/B3/B2 to channels of the EO files;
    # any other band (B8, B11, B5 below) is zero-filled, so the NIR configurations
    # need that mapping extended before they carry real data.
    EO_BAND_CONFIGS = {
        "RGB": [4, 3, 2], # B4, B3, B2 (Red, Green, Blue)
        "NIR_SWIR_RedEdge": [8, 11, 5], # B8, B11, B5 (NIR, SWIR1, Red Edge 1)
        "RGB_NIR": [4, 3, 2, 8] # B4, B3, B2, B8 (Red, Green, Blue, NIR)
    }

    CURRENT_EO_CONFIG_NAME = "RGB"
    # Generator output / EO channel count follows the selected band configuration.
    OUTPUT_NC = len(EO_BAND_CONFIGS[CURRENT_EO_CONFIG_NAME])


# Instantiate the configuration shared by every later cell.
config = Config()

# Ensure the Google Drive folders for sample images and checkpoints exist.
for _output_folder in (config.OUTPUT_DIR, config.CHECKPOINT_DIR):
    os.makedirs(_output_folder, exist_ok=True)


In [4]:
# Re-initialize the shared configuration and its Drive output folders. This repeats
# the setup at the end of the Config cell; both steps are idempotent, so re-running is harmless.
config = Config()
output_folders = [config.OUTPUT_DIR, config.CHECKPOINT_DIR]
for folder in output_folders:
    os.makedirs(folder, exist_ok=True)


# DataLoader and Preprocessing Custom Class


In [15]:
class Sen12MSDataset(Dataset):
    """Unpaired Sentinel-1 SAR / Sentinel-2 EO dataset for CycleGAN training.

    Recursively collects ``*.tif`` files under ``root_dir/sar_dir`` and
    ``root_dir/eo_dir``, drops files rasterio cannot read, and pairs each SAR
    image with a randomly chosen EO image. ``__getitem__`` returns ``(sar, eo)``
    float tensors of shape ``(config.INPUT_NC, image_size, image_size)`` and
    ``(config.OUTPUT_NC, image_size, image_size)``, both scaled to [-1, 1].
    All tensors are created on the CPU; moving batches to the GPU is left to the caller.
    """

    # Sentinel-2 band number -> channel index inside the EO GeoTIFFs, found manually
    # (QGIS) for the 3-band RGB exports used here. Bands not listed (e.g. B8, B11, B5)
    # are filled with zeros.
    BAND_NUMBER_TO_CHANNEL_INDEX = {
        4: 0,  # Sentinel-2 Band 4 (Red)
        3: 1,  # Sentinel-2 Band 3 (Green)
        2: 2,  # Sentinel-2 Band 2 (Blue)
    }

    # EO digital numbers are clamped to [0, EO_MAX_VAL] before scaling to [-1, 1].
    EO_MAX_VAL = 10000.0

    def __init__(self, root_dir, sar_dir, eo_dir, eo_bands, image_size=256):
        self.sar_root = os.path.join(root_dir, sar_dir)
        self.eo_root = os.path.join(root_dir, eo_dir)
        self.eo_bands = eo_bands  # Sentinel-2 band numbers to load, e.g. [4, 3, 2]
        self.image_size = image_size
        self._warned_keys = set()  # warning categories already printed (see _warn_once)

        sar_pattern = os.path.join(self.sar_root, '**', '*.tif')
        eo_pattern = os.path.join(self.eo_root, '**', '*.tif')
        self.sar_image_paths = self._filter_unreadable_images(
            sorted(glob.glob(sar_pattern, recursive=True)), "SAR")
        self.eo_image_paths = self._filter_unreadable_images(
            sorted(glob.glob(eo_pattern, recursive=True)), "EO")

        self.pairs = self._match_sar_eo_pairs()

        print(f"Found {len(self.pairs)} matched SAR-EO pairs.")
        if not self.pairs:
            print("WARNING: No SAR-EO pairs found. Please check your data paths and extraction.")
            print(f"SAR root: {self.sar_root}")
            print(f"EO root: {self.eo_root}")
            print(f"Example SAR path search: {sar_pattern}")
            print(f"Example EO path search: {eo_pattern}")

    def _filter_unreadable_images(self, image_paths, image_type="Image"):
        """Return the subset of ``image_paths`` whose first band rasterio can read.

        Unreadable files are reported and skipped (deliberate best-effort).
        """
        readable_paths = []
        for path in tqdm(image_paths, desc=f"Checking {image_type} readability"):
            try:
                with rasterio.open(path) as src:
                    _ = src.read(1)  # actually decode a band, not just the header
                readable_paths.append(path)
            except Exception as e:  # rasterio raises several exception types
                print(f"Skipping unreadable {image_type} file: {path} (Error: {e})")
        return readable_paths

    def _match_sar_eo_pairs(self):
        """Pair every SAR path with a randomly chosen EO path (unpaired training).

        NOTE(review): the random EO choice is made once here, so each SAR image
        sees the same EO partner in every epoch and some EO files may never be
        used; resampling per ``__getitem__`` call would be closer to standard
        unaligned CycleGAN loaders.
        """
        if not self.sar_image_paths or not self.eo_image_paths:
            return []
        return [(sar_path, random.choice(self.eo_image_paths)) for sar_path in self.sar_image_paths]

    def __len__(self):
        return len(self.pairs)

    def _warn_once(self, key, message):
        """Print ``message`` only the first time ``key`` is seen by this dataset object.

        Each DataLoader worker holds its own copy, so a key can print once per worker.
        """
        warned = self.__dict__.setdefault('_warned_keys', set())
        if key not in warned:
            warned.add(key)
            print(f"{message} (further '{key}' warnings suppressed)")

    def _fit_sar_channels(self, sar_raw, sar_path, target_nc):
        """Return the (C, H, W) array ``sar_raw`` with exactly ``target_nc`` channels.

        A single channel is repeated, surplus channels are dropped (the first
        ``target_nc`` are kept) and missing channels are zero-padded.
        """
        n_channels = sar_raw.shape[0]
        if n_channels == target_nc:
            return sar_raw
        name = os.path.basename(sar_path)
        if n_channels == 1:
            self._warn_once(f"sar_{n_channels}ch",
                            f"Warning: SAR image {name} is 1-channel ({n_channels} channels), but INPUT_NC is {target_nc}. Repeating channel.")
            return np.repeat(sar_raw, target_nc, axis=0)
        if n_channels > target_nc:
            self._warn_once(f"sar_{n_channels}ch",
                            f"Warning: SAR image {name} has {n_channels} channels, but INPUT_NC is {target_nc}. Taking first {target_nc} channels.")
            return sar_raw[:target_nc, :, :]
        self._warn_once(f"sar_{n_channels}ch",
                        f"Warning: SAR image {name} has {n_channels} channels, but INPUT_NC is {target_nc}. Padding with zeros.")
        return np.pad(sar_raw, ((0, target_nc - n_channels), (0, 0), (0, 0)), mode='constant')

    @staticmethod
    def _preprocess_sar_tensor(sar, image_size):
        """Clean non-finite values, resize to ``image_size`` and min-max scale to [-1, 1]."""
        sar = torch.nan_to_num(sar, nan=0.0, posinf=0.0, neginf=0.0)
        sar = F.interpolate(sar.unsqueeze(0), size=(image_size, image_size),
                            mode='bicubic', align_corners=False).squeeze(0)
        sar_min, sar_max = sar.min(), sar.max()
        if sar_max > sar_min:
            sar = (sar - sar_min) / (sar_max - sar_min)  # [0, 1]
        else:
            sar = torch.zeros_like(sar)  # constant image
        return sar * 2.0 - 1.0

    def _select_eo_bands(self, eo_full, eo_path, output_nc):
        """Stack the configured bands of the (C, H, W) tensor ``eo_full`` into ``output_nc`` channels.

        Bands missing from BAND_NUMBER_TO_CHANNEL_INDEX, or beyond the file's
        channel count, become zero channels; the result is zero-padded or
        truncated to exactly ``output_nc`` channels.
        """
        _, height, width = eo_full.shape
        name = os.path.basename(eo_path)
        selected = []
        for band_num in self.eo_bands:
            channel_idx = self.BAND_NUMBER_TO_CHANNEL_INDEX.get(band_num, -1)
            if 0 <= channel_idx < eo_full.shape[0]:
                selected.append(eo_full[channel_idx:channel_idx + 1, :, :])
            else:
                self._warn_once(f"eo_missing_B{band_num}",
                                f"Warning: Desired band B{band_num} not found or index out of range in EO image {name}. Filling with zeros.")
                # new_zeros keeps eo_full's device/dtype (CPU here), so torch.cat below
                # never mixes devices and no CUDA tensor is created in a loader worker.
                selected.append(eo_full.new_zeros(1, height, width))

        if not selected:
            print(f"Error: No valid bands selected for EO image {name}. Returning zeros.")
            eo = eo_full.new_zeros(output_nc, height, width)
        else:
            eo = torch.cat(selected, dim=0)

        if eo.shape[0] != output_nc:
            print(f"Error: Final EO tensor has {eo.shape[0]} channels, but expected {output_nc}. This indicates an issue with band selection or config.OUTPUT_NC.")
            if eo.shape[0] < output_nc:
                eo = torch.cat([eo, eo.new_zeros(output_nc - eo.shape[0], height, width)], dim=0)
            else:
                eo = eo[:output_nc, :, :]
        return eo

    @staticmethod
    def _preprocess_eo_tensor(eo, image_size, max_val=10000.0):
        """Clamp to [0, max_val], scale to [-1, 1] and resize to ``image_size``."""
        eo = torch.nan_to_num(eo, nan=0.0)  # NaN would otherwise spread through bicubic resizing
        eo = torch.clamp(eo, 0, max_val)
        eo = (eo / max_val) * 2.0 - 1.0
        eo = F.interpolate(eo.unsqueeze(0), size=(image_size, image_size),
                           mode='bicubic', align_corners=False).squeeze(0)
        # Bicubic overshoots at sharp edges; keep targets inside the generator's Tanh range.
        return eo.clamp(-1.0, 1.0)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(sar_tensor, eo_tensor)``, both in [-1, 1]."""
        sar_path, eo_path = self.pairs[idx]

        with rasterio.open(sar_path) as src:
            sar_raw = src.read().astype(np.float32)  # (C, H, W)
        sar_np = self._fit_sar_channels(sar_raw, sar_path, config.INPUT_NC)
        sar_tensor = self._preprocess_sar_tensor(torch.from_numpy(sar_np).float(), self.image_size)

        with rasterio.open(eo_path) as src:
            eo_full = torch.from_numpy(src.read().astype(np.float32))  # (C, H, W)
        eo_tensor = self._select_eo_bands(eo_full, eo_path, config.OUTPUT_NC)
        eo_tensor = self._preprocess_eo_tensor(eo_tensor, self.image_size, self.EO_MAX_VAL)

        return sar_tensor, eo_tensor

In [6]:
# Helper function for Convolutional Block
def conv_block(in_channels, out_channels, kernel_size, stride, padding, use_bias=False, norm_layer=nn.InstanceNorm2d, activation=nn.ReLU(True)):
    """Build ``Conv2d -> [norm_layer] -> [activation]`` as an ``nn.Sequential``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        use_bias: whether the convolution has a bias term.
        norm_layer: normalization class, instantiated as ``norm_layer(out_channels)``;
            pass a falsy value (e.g. ``None``) to skip normalization.
        activation: activation module; pass a falsy value (e.g. ``None``) to skip it.
            A copy is inserted, so every block owns its own activation module.
    Returns:
        nn.Sequential containing the requested layers.
    """
    import copy  # local import keeps this helper self-contained

    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=use_bias)
    ]
    if norm_layer:
        layers.append(norm_layer(out_channels))
    if activation:
        # The default nn.ReLU(True) is created once, when this function is defined;
        # copying it stops every block in every model from sharing that one instance.
        layers.append(copy.deepcopy(activation))
    return nn.Sequential(*layers)

# Helper function for Transposed Convolutional Block (for upsampling)
def deconv_block(in_channels, out_channels, kernel_size, stride, padding, output_padding, use_bias=False, norm_layer=nn.InstanceNorm2d, activation=nn.ReLU(True)):
    """Build ``ConvTranspose2d -> [norm_layer] -> [activation]`` as an ``nn.Sequential``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, output_padding:
            forwarded to ``nn.ConvTranspose2d``.
        use_bias: whether the transposed convolution has a bias term.
        norm_layer: normalization class, instantiated as ``norm_layer(out_channels)``;
            pass a falsy value (e.g. ``None``) to skip normalization.
        activation: activation module; pass a falsy value (e.g. ``None``) to skip it.
            A copy is inserted, so every block owns its own activation module.
    Returns:
        nn.Sequential containing the requested layers.
    """
    import copy  # local import keeps this helper self-contained

    layers = [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=use_bias)
    ]
    if norm_layer:
        layers.append(norm_layer(out_channels))
    if activation:
        # The default nn.ReLU(True) is created once, when this function is defined;
        # copying it stops every block from sharing that one instance.
        layers.append(copy.deepcopy(activation))
    return nn.Sequential(*layers)


# Class ResNet


In [7]:
class ResnetBlock(nn.Module):
    """Residual block computing ``x + branch(x)``.

    The branch is: reflection pad -> 3x3 conv -> norm -> ReLU -> Dropout(0.5)
    -> reflection pad -> 3x3 conv -> norm. The padding keeps the spatial size,
    and ``dim`` is both the input and output channel count, so the sum is valid.
    """

    def __init__(self, dim, norm_layer=nn.InstanceNorm2d, use_bias=False):
        super().__init__()
        first_conv = conv_block(dim, dim, kernel_size=3, stride=1, padding=0,
                                use_bias=use_bias, norm_layer=norm_layer)
        # The second conv has no activation; the branch ends at the norm layer.
        second_conv = conv_block(dim, dim, kernel_size=3, stride=1, padding=0,
                                 use_bias=use_bias, norm_layer=norm_layer, activation=None)
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            first_conv,
            nn.Dropout(0.5),  # regularization between the two convolutions
            nn.ReflectionPad2d(1),
            second_conv,
        )

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual


# Class Generator
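- ResNet-based encoder–decoder (not a U-Net: there are no skip connections between encoder and decoder; the only skips are inside the ResNet blocks)
- Translates images from Domain A to Domain B
- The image is first downsampled, then passed through the ResNet blocks, and finally decoded by upsampling; a last convolution followed by Tanh maps it to the output channels (e.g. RGB) in the range [-1, 1].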

In [8]:
class Generator(nn.Module):
    """ResNet-based encoder-decoder generator (CycleGAN style).

    Translates images from domain A (``input_nc`` channels) to domain B
    (``output_nc`` channels). It is not a U-Net: encoder and decoder are not
    linked by skip connections; the only skips are inside the ResnetBlocks.

    Layout: reflection-padded 7x7 conv -> two stride-2 downsampling convs ->
    ``n_blocks`` ResnetBlocks -> two stride-2 transposed convs ->
    reflection-padded 7x7 conv -> Tanh, so outputs lie in [-1, 1].
    Inputs whose height and width are divisible by 4 keep their spatial size.

    Args:
        input_nc: channels of the input image.
        output_nc: channels of the generated image.
        ngf: filters in the first conv layer (multiplied by 2 and 4 when downsampling).
        n_blocks: number of ResnetBlocks at the bottleneck.
        norm_layer: normalization class used throughout.
    """
    def __init__(self, input_nc, output_nc, ngf=64, n_blocks=9, norm_layer=nn.InstanceNorm2d):
        super(Generator, self).__init__()
        # nn.InstanceNorm2d defaults to affine=False (no learnable shift), so the
        # convolutions keep their own bias in that case.
        use_bias = norm_layer == nn.InstanceNorm2d

        model = [
            nn.ReflectionPad2d(3),
            conv_block(input_nc, ngf, kernel_size=7, stride=1, padding=0, use_bias=use_bias, norm_layer=norm_layer)
        ]

        # Downsampling: each stage halves H/W and doubles the channel count.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [
                conv_block(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, use_bias=use_bias, norm_layer=norm_layer)
            ]

        # ResNet blocks at the lowest resolution.
        mult = 2 ** n_downsampling
        for _ in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, use_bias=use_bias)]

        # Upsampling: each stage doubles H/W and halves the channel count.
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                deconv_block(ngf * mult, ngf * mult // 2, kernel_size=3, stride=2, padding=1, output_padding=1, use_bias=use_bias, norm_layer=norm_layer)
            ]

        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh() # Output activation to map to [-1, 1]
        ]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)


# Discriminator Class

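- PatchGAN based
- Classifies each image patch as real or fake
- Convolutional layers with LeakyReLU (and InstanceNorm on all but the first) while downscaling.
- The last layer is a plain convolution that outputs a 1-channel map of raw scores (no Sigmoid; the LSGAN/MSE loss is applied directly to these scores).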

In [9]:

class Discriminator(nn.Module):
    """PatchGAN discriminator.

    Instead of one real/fake score per image it outputs a 1-channel map in
    which every value scores one receptive-field patch of the input. The map
    holds raw scores: there is no final sigmoid.

    Args:
        input_nc: channels of the images being judged.
        ndf: filters in the first conv; doubled at every later stage.
        n_layers: number of stride-2 stages (the first one has no normalization).
        norm_layer: normalization class used from the second stage onward.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d):
        super().__init__()
        use_bias = norm_layer == nn.InstanceNorm2d
        kernel_size, pad = 4, 1

        def leaky_stage(c_in, c_out, stride):
            # conv -> norm -> LeakyReLU(0.2); a fresh activation module per stage.
            return conv_block(c_in, c_out, kernel_size=kernel_size, stride=stride, padding=pad,
                              use_bias=use_bias, norm_layer=norm_layer,
                              activation=nn.LeakyReLU(0.2, True))

        # First stage: plain conv + LeakyReLU (keeps gradients for negative inputs), no normalization.
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel_size, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]
        # Further stride-2 stages, doubling the channel count each time.
        for stage in range(1, n_layers):
            layers.append(leaky_stage(ndf * 2 ** (stage - 1), ndf * 2 ** stage, stride=2))
        # One stride-1 stage, then a conv down to the single-channel score map.
        widest = ndf * 2 ** n_layers
        layers.append(leaky_stage(widest // 2, widest, stride=1))
        layers.append(nn.Conv2d(widest, 1, kernel_size=kernel_size, stride=1, padding=pad))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)


# Adversarial Loss Function

- It measures how well the discriminator distinguishes real images from fake ones.
- It also measures how well the generator fools the discriminator when translating images from one domain to the other.


In [10]:
class GANLoss(nn.Module):
    """Adversarial loss against constant real/fake target labels.

    gan_mode:
        'mse' or 'lsgan' -- least-squares GAN: mean squared error between the
                            discriminator output and the target label (default).
        'vanilla'        -- standard GAN: BCEWithLogitsLoss, so the discriminator
                            must output raw logits.
    Call as ``criterion(prediction, target_is_real)``; returns a scalar loss tensor.

    Raises:
        NotImplementedError: for any other ``gan_mode``.
    """

    # Both names select the least-squares (MSE) objective.
    _MSE_MODES = ('mse', 'lsgan')

    def __init__(self, gan_mode='mse', target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        # Buffers move with the module on .to(device), so once the criterion is on the
        # predictions' device the targets are too.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode in self._MSE_MODES:
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        else:
            raise NotImplementedError(f'GAN mode {gan_mode} not implemented.')

    def get_target_tensor(self, prediction, target_is_real):
        """Return the real or fake label broadcast (without copying) to ``prediction``'s shape."""
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def forward(self, prediction, target_is_real):
        # Implemented as forward() (not __call__) so nn.Module hooks keep working.
        target_tensor = self.get_target_tensor(prediction, target_is_real)
        return self.loss(prediction, target_tensor)


# Cycle Consistency Loss

- It measures how well an image is reconstructed after a round trip through both generators (e.g. SAR → EO → SAR).
- Uses the L1 loss (mean absolute error).

In [11]:
class CycleConsistencyLoss(nn.Module):
    """Weighted L1 reconstruction loss: ``lambda_cycle * mean(|real - cycled|)``.

    A larger ``lambda_cycle`` pushes harder for a round-trip translation
    (e.g. SAR -> EO -> SAR) to reproduce the original image pixel by pixel.
    """

    def __init__(self, lambda_cycle=10.0):
        super().__init__()
        self.loss = nn.L1Loss()  # mean absolute error over all elements
        self.lambda_cycle = lambda_cycle

    def forward(self, real_image, cycled_image):
        l1_error = self.loss(real_image, cycled_image)
        return l1_error * self.lambda_cycle


# Identity Loss

- Optional
- Helps preserve colour composition
- Improves training stability


In [12]:
class IdentityLoss(nn.Module):
    """Weighted L1 identity loss: ``lambda_identity * mean(|real - identity|)``.

    Applied when a generator is fed an image that is already in its target
    domain; penalising any change encourages it to keep colour composition.
    """

    def __init__(self, lambda_identity=5.0):
        super().__init__()
        self.loss = nn.L1Loss()  # mean absolute error over all elements
        self.lambda_identity = lambda_identity

    def forward(self, real_image, identity_image):
        l1_error = self.loss(real_image, identity_image)
        return l1_error * self.lambda_identity

# Utilities for Postprocessing and Metrics

- Convert pixel values from the model's range [-1, 1] to [0, 1] (multiply by 255 for 8-bit output).
- Save the SAR, generated EO and real EO images side by side.
- Measure the structural similarity (SSIM) between two images.
- Calculate the peak signal-to-noise ratio (PSNR).
- Calculate NDVI from an EO image.

In [13]:
def denormalize_image(tensor):
    """Map values from the model's [-1, 1] range back to [0, 1].

    Element-wise ``(x + 1) / 2``, so any tensor shape is accepted.

    Args:
        tensor (torch.Tensor): values in [-1, 1].
    Returns:
        torch.Tensor: same shape, values in [0, 1].
    """
    shifted = tensor + 1
    return shifted / 2.0

def _to_display_rgb(img):
    """Return a 3-channel (N, 3, H, W) version of the (N, C, H, W) tensor ``img`` for display.

    1 or 2 channels -> channel mean repeated as grayscale; 3 channels -> unchanged;
    more than 3 -> the first three (e.g. the RGB part of an RGB+NIR stack).
    """
    n_channels = img.shape[1]
    if n_channels == 3:
        return img
    if n_channels > 3:
        return img[:, :3, :, :]
    return img.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)


def save_combined_image(sar_img, gen_eo_img, real_eo_img, filename):
    """Save Real SAR | Generated EO | Real EO side by side in one image file.

    Args:
        sar_img, gen_eo_img, real_eo_img: (N, C, H, W) tensors with equal N, H, W,
            expected to be denormalized to [0, 1]; each is converted to 3 channels.
        filename: output path passed to ``torchvision.utils.save_image``.
    """
    panels = [_to_display_rgb(img) for img in (sar_img, gen_eo_img, real_eo_img)]
    combined_image = torch.cat(panels, dim=3)  # concatenate along width
    # normalize=True min-max rescales the whole combined tensor with one shared
    # min/max (scale_each defaults to False), so the panels keep relative brightness.
    vutils.save_image(combined_image, filename, normalize=True, nrow=1)

def calculate_ssim(img1, img2, data_range=1.0, multichannel=True):
    """
    Calculates SSIM between two images.
    Args:
        img1 (torch.Tensor): First image (1, C, H, W) or (C, H, W) in range [0,1].
        img2 (torch.Tensor): Second image, same shape as img1.
        data_range (float): The range of the data (e.g., 1.0 for [0,1], 255 for [0,255]).
        multichannel (bool): Treat the C axis as colour channels (SSIM averaged over
            channels). Forced off when an image has a single channel.
    Returns:
        float: SSIM score.
    Note:
        Uses skimage's ``channel_axis`` argument (skimage >= 0.19); the older
        ``multichannel`` keyword is no longer honoured by recent skimage releases.
    """
    # detach() so tensors that still track gradients can be converted to numpy;
    # squeeze(0) drops a batch dimension of size 1.
    img1_np = img1.detach().squeeze(0).cpu().numpy()
    img2_np = img2.detach().squeeze(0).cpu().numpy()

    # Single-channel images become plain 2-D arrays.
    if img1_np.ndim == 3 and img1_np.shape[0] == 1:
        img1_np = img1_np[0]
        multichannel = False
    if img2_np.ndim == 3 and img2_np.shape[0] == 1:
        img2_np = img2_np[0]
        multichannel = False

    channel_axis = None
    if multichannel and img1_np.ndim == 3 and img1_np.shape[0] > 1:
        # skimage expects channels last: (C, H, W) -> (H, W, C).
        img1_np = np.transpose(img1_np, (1, 2, 0))
        img2_np = np.transpose(img2_np, (1, 2, 0))
        channel_axis = -1

    return ssim(img1_np, img2_np, data_range=data_range, channel_axis=channel_axis)

def calculate_psnr(img1, img2, data_range=1.0):
    """
    Calculates PSNR between two images.
    Args:
        img1 (torch.Tensor): First image (1, C, H, W) or (C, H, W) in range [0,1].
        img2 (torch.Tensor): Second image, same shape as img1.
        data_range (float): The range of the data (e.g., 1.0 for [0,1], 255 for [0,255]).
    Returns:
        float: PSNR score.
    """
    # detach() so tensors that still track gradients (e.g. raw generator outputs)
    # can be converted to numpy; squeeze(0) drops a batch dimension of size 1.
    img1_np = img1.detach().squeeze(0).cpu().numpy()
    img2_np = img2.detach().squeeze(0).cpu().numpy()

    # PSNR does not depend on layout; squeezing a singleton channel keeps grayscale 2-D.
    if img1_np.ndim == 3 and img1_np.shape[0] == 1:
        img1_np = img1_np[0]
    if img2_np.ndim == 3 and img2_np.shape[0] == 1:
        img2_np = img2_np[0]

    return psnr(img1_np, img2_np, data_range=data_range)

def calculate_ndvi(eo_image_tensor, eo_band_config):
    """
    Calculates Normalized Difference Vegetation Index (NDVI) from an EO image tensor.
    NDVI = (NIR - Red) / (NIR + Red)
    Args:
        eo_image_tensor (torch.Tensor): EO image (C, H, W) or batch (N, C, H, W),
            expected in [0, 1] (NDVI is only bounded to [-1, 1] for non-negative inputs).
        eo_band_config (sequence): Sentinel-2 band numbers in the tensor's channel order.
    Returns:
        torch.Tensor: NDVI map (1, H, W), or (N, 1, H, W) for a batch.
                      Returns None if the NIR (B8) or Red (B4) band is not present.
    """
    try:
        nir_idx = eo_band_config.index(8)  # Sentinel-2 B8 is NIR
        red_idx = eo_band_config.index(4)  # Sentinel-2 B4 is Red
    except ValueError:
        print("Warning: NIR (B8) or Red (B4) band not found in current EO configuration. Cannot calculate NDVI.")
        return None

    # Index the channel axis from the end so both (C, H, W) and (N, C, H, W) work.
    nir_band = eo_image_tensor[..., nir_idx, :, :]
    red_band = eo_image_tensor[..., red_idx, :, :]

    # Small epsilon avoids division by zero where NIR + Red == 0.
    epsilon = 1e-6
    ndvi = (nir_band - red_band) / (nir_band + red_band + epsilon)

    return ndvi.unsqueeze(-3)  # restore a singleton channel axis

# Training Function

In [14]:

def _set_requires_grad(nets, requires_grad):
    """Set `requires_grad` on every parameter of each network in `nets`.

    Used to freeze the discriminators while the generators are updated, so that
    loss_G.backward() does not accumulate (and later discard) discriminator
    gradients. Gradients still flow *through* frozen networks to their inputs.
    """
    for net in nets:
        for param in net.parameters():
            param.requires_grad = requires_grad


def _save_epoch_images(epoch, images):
    """Save each tensor in `images` as `epoch_{epoch+1}_{suffix}.png` in config.OUTPUT_DIR.

    Args:
        epoch (int): Zero-based epoch index (file names use epoch + 1).
        images (dict): Maps a file-name suffix (e.g. "real_sar") to an image tensor.
    """
    for suffix, image in images.items():
        vutils.save_image(image, os.path.join(config.OUTPUT_DIR, f"epoch_{epoch+1}_{suffix}.png"), normalize=True)


def _evaluate_epoch(epoch, G_A, G_B, sample_sar, sample_eo):
    """Translate a fixed sample batch, save visualisations and print SAR -> EO metrics.

    Args:
        epoch (int): Zero-based epoch index.
        G_A: SAR -> EO generator.
        G_B: EO -> SAR generator.
        sample_sar (torch.Tensor): Fixed SAR evaluation batch, already on config.DEVICE.
        sample_eo (torch.Tensor): Fixed EO evaluation batch, already on config.DEVICE.
    """
    eo_name = config.CURRENT_EO_CONFIG_NAME
    eo_bands = config.EO_BAND_CONFIGS[eo_name]

    G_A.eval()
    G_B.eval()
    with torch.no_grad():
        # SAR -> EO -> SAR and EO -> SAR -> EO round trips.
        generated_eo = G_A(sample_sar)
        cycled_sar_from_eo = G_B(generated_eo)
        generated_sar = G_B(sample_eo)
        cycled_eo_from_sar = G_A(generated_sar)

        # Denormalize for saving and metric calculation (from [-1, 1] to [0, 1])
        real_sar_display = denormalize_image(sample_sar)
        real_eo_display = denormalize_image(sample_eo)
        generated_eo_display = denormalize_image(generated_eo)
        cycled_sar_display = denormalize_image(cycled_sar_from_eo)
        generated_sar_display = denormalize_image(generated_sar)
        cycled_eo_display = denormalize_image(cycled_eo_from_sar)

        _save_epoch_images(epoch, {
            "real_sar": real_sar_display,
            f"{eo_name}_real_eo": real_eo_display,
            f"{eo_name}_generated_eo": generated_eo_display,
            "cycled_sar": cycled_sar_display,
            "generated_sar": generated_sar_display,
            "cycled_eo": cycled_eo_display,
        })

        # --- SSIM / PSNR for SAR -> EO, on the first image of the batch ---
        ssim_score = calculate_ssim(generated_eo_display[0], real_eo_display[0], multichannel=(config.OUTPUT_NC > 1))
        psnr_score = calculate_psnr(generated_eo_display[0], real_eo_display[0])
        print(f"Epoch {epoch+1} Metrics (SAR -> EO):")
        print(f"  SSIM: {ssim_score:.4f}")
        print(f"  PSNR: {psnr_score:.4f}")

        # --- NDVI (requires NIR (B8) and Red (B4) bands) ---
        if 8 in eo_bands and 4 in eo_bands:
            real_ndvi = calculate_ndvi(real_eo_display[0], eo_bands)
            generated_ndvi = calculate_ndvi(generated_eo_display[0], eo_bands)

            if real_ndvi is not None and generated_ndvi is not None:
                # Save NDVI maps visualized as grayscale images
                _save_epoch_images(epoch, {
                    "real_ndvi": real_ndvi,
                    "generated_ndvi": generated_ndvi,
                })

                # NDVI range is [-1, 1], hence data_range=2.0
                ndvi_ssim = calculate_ssim(generated_ndvi[0], real_ndvi[0], data_range=2.0, multichannel=False)
                ndvi_psnr = calculate_psnr(generated_ndvi[0], real_ndvi[0], data_range=2.0)
                print(f"  NDVI SSIM: {ndvi_ssim:.4f}")
                print(f"  NDVI PSNR: {ndvi_psnr:.4f}")
        else:
            print("  NDVI not calculated: Required NIR (B8) or Red (B4) band missing in current EO configuration.")


def train_cyclegan():
    """
    Main function to design and train the CycleGAN model.

    G_A translates SAR -> EO and G_B translates EO -> SAR. D_A discriminates
    real vs. fake EO and D_B discriminates real vs. fake SAR.

    After every epoch, a fixed evaluation batch is translated. Sample images
    (and NDVI maps when bands B8/B4 are present) are written to
    config.OUTPUT_DIR, and SSIM/PSNR are printed. Model weights are written to
    config.CHECKPOINT_DIR every config.SAVE_EPOCH_FREQ epochs.

    Raises:
        ValueError: If the dataset contains no SAR-EO pairs.
    """
    eo_name = config.CURRENT_EO_CONFIG_NAME
    eo_bands = config.EO_BAND_CONFIGS[eo_name]

    print(f"Using device: {config.DEVICE}")
    print(f"Current EO Output Configuration: {eo_name} (Bands: {eo_bands})")
    print(f"Input Channels (SAR): {config.INPUT_NC}")
    print(f"Output Channels (EO): {config.OUTPUT_NC}")

    os.makedirs(config.OUTPUT_DIR, exist_ok=True)
    os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)

    # Initialize Dataset and DataLoader
    dataset = Sen12MSDataset(
        root_dir=config.data_root_dir,
        sar_dir=config.SAR_DIR,
        eo_dir=config.EO_DIR,
        eo_bands=eo_bands,
        image_size=config.IMAGE_SIZE
    )
    if len(dataset) == 0:
        raise ValueError("Dataset is empty: no matched SAR-EO pairs found. Check data_root_dir, SAR_DIR and EO_DIR.")
    dataloader = DataLoader(dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS)

    # Fixed evaluation batch, fetched once. With shuffle=False the first batch always
    # covers the same indices, so there is no need to rebuild a DataLoader (and its
    # worker processes) every epoch.
    eval_loader = DataLoader(dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=config.NUM_WORKERS)
    sample_sar, sample_eo = next(iter(eval_loader))
    sample_sar = sample_sar.to(config.DEVICE)
    sample_eo = sample_eo.to(config.DEVICE)

    # Initialize Generators and Discriminators
    # G_A: SAR -> EO, G_B: EO -> SAR
    G_A = Generator(config.INPUT_NC, config.OUTPUT_NC, config.NGF, config.N_RESNET_BLOCKS).to(config.DEVICE)
    G_B = Generator(config.OUTPUT_NC, config.INPUT_NC, config.NGF, config.N_RESNET_BLOCKS).to(config.DEVICE)
    # D_A: Discriminates real EO vs fake EO, D_B: Discriminates real SAR vs fake SAR
    D_A = Discriminator(config.OUTPUT_NC, config.NDF).to(config.DEVICE)
    D_B = Discriminator(config.INPUT_NC, config.NDF).to(config.DEVICE)

    # Initialize Optimizers
    optimizer_G = optim.Adam(list(G_A.parameters()) + list(G_B.parameters()), lr=config.LR, betas=(config.BETA1, 0.999))
    optimizer_D_A = optim.Adam(D_A.parameters(), lr=config.LR, betas=(config.BETA1, 0.999))
    optimizer_D_B = optim.Adam(D_B.parameters(), lr=config.LR, betas=(config.BETA1, 0.999))

    # Initialize Loss Functions
    criterion_GAN = GANLoss(gan_mode='lsgan').to(config.DEVICE)
    criterion_cycle = CycleConsistencyLoss(lambda_cycle=config.LAMBDA_CYCLE).to(config.DEVICE)
    criterion_identity = IdentityLoss(lambda_identity=config.LAMBDA_IDENTITY).to(config.DEVICE)

    # Identity loss feeds EO images into G_A (which expects INPUT_NC channels) and SAR
    # images into G_B (which expects OUTPUT_NC channels). That is only possible when
    # both domains have the same number of channels.
    identity_enabled = config.INPUT_NC == config.OUTPUT_NC
    if not identity_enabled:
        print("Identity loss disabled: INPUT_NC != OUTPUT_NC, so generators cannot take the other domain's images as input.")

    models = {'G_A': G_A, 'G_B': G_B, 'D_A': D_A, 'D_B': D_B}

    # Training Loop
    print("Starting Training Loop...")
    for epoch in range(config.NUM_EPOCHS):
        for model in models.values():
            model.train()

        for i, (real_sar, real_eo) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{config.NUM_EPOCHS}")):
            real_sar = real_sar.to(config.DEVICE)
            real_eo = real_eo.to(config.DEVICE)

            # --- Train Generators G_A and G_B ---
            # Freeze discriminators: the generator update only needs gradients w.r.t. G params.
            _set_requires_grad([D_A, D_B], False)
            optimizer_G.zero_grad()

            # Identity loss (helps preserve color composition): G_A(real_eo) ~ real_eo, G_B(real_sar) ~ real_sar
            if identity_enabled:
                loss_identity_A = criterion_identity(G_A(real_eo), real_eo)
                loss_identity_B = criterion_identity(G_B(real_sar), real_sar)
            else:
                loss_identity_A = torch.zeros((), device=config.DEVICE)
                loss_identity_B = torch.zeros((), device=config.DEVICE)

            # Adversarial losses: each generator wants its discriminator to predict "real"
            fake_eo = G_A(real_sar)
            loss_GAN_A = criterion_GAN(D_A(fake_eo), True)
            fake_sar = G_B(real_eo)
            loss_GAN_B = criterion_GAN(D_B(fake_sar), True)

            # Cycle consistency: SAR -> EO -> SAR and EO -> SAR -> EO
            loss_cycle_sar = criterion_cycle(G_B(fake_eo), real_sar)
            loss_cycle_eo = criterion_cycle(G_A(fake_sar), real_eo)

            # Total Generator Loss
            loss_G = loss_GAN_A + loss_GAN_B + loss_cycle_sar + loss_cycle_eo + \
                     loss_identity_A + loss_identity_B
            loss_G.backward()
            optimizer_G.step()

            # --- Train Discriminators (fakes detached so no gradient reaches the generators) ---
            _set_requires_grad([D_A, D_B], True)

            optimizer_D_A.zero_grad()
            loss_D_A_real = criterion_GAN(D_A(real_eo), True)
            loss_D_A_fake = criterion_GAN(D_A(fake_eo.detach()), False)
            loss_D_A = (loss_D_A_real + loss_D_A_fake) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()

            optimizer_D_B.zero_grad()
            loss_D_B_real = criterion_GAN(D_B(real_sar), True)
            loss_D_B_fake = criterion_GAN(D_B(fake_sar.detach()), False)
            loss_D_B = (loss_D_B_real + loss_D_B_fake) * 0.5
            loss_D_B.backward()
            optimizer_D_B.step()

            if i % config.PRINT_FREQ == 0:
                tqdm.write(f"Epoch [{epoch+1}/{config.NUM_EPOCHS}], Step [{i}/{len(dataloader)}]\n"
                           f"Loss_G: {loss_G.item():.4f} | Loss_G_GAN_A: {loss_GAN_A.item():.4f} | Loss_G_GAN_B: {loss_GAN_B.item():.4f}\n"
                           f"Loss_cycle_SAR: {loss_cycle_sar.item():.4f} | Loss_cycle_EO: {loss_cycle_eo.item():.4f}\n"
                           f"Loss_identity_A: {loss_identity_A.item():.4f} | Loss_identity_B: {loss_identity_B.item():.4f}\n"
                           f"Loss_D_A: {loss_D_A.item():.4f} | Loss_D_B: {loss_D_B.item():.4f}")

        # --- Save generated images and evaluate metrics at end of epoch ---
        _evaluate_epoch(epoch, G_A, G_B, sample_sar, sample_eo)

        # Save model checkpoints
        if (epoch + 1) % config.SAVE_EPOCH_FREQ == 0:
            for name, model in models.items():
                torch.save(model.state_dict(), os.path.join(config.CHECKPOINT_DIR, f'{name}_epoch_{epoch+1}.pth'))
            print(f"Models saved after epoch {epoch+1}")

    print("Training Complete!")

# To Run

In [None]:
# Inside a notebook __name__ is always "__main__"; the guard only matters if this
# cell is exported to a standalone script and imported elsewhere.
if __name__ == "__main__":
    train_cyclegan()


Using device: cpu
Current EO Output Configuration: RGB (Bands: [4, 3, 2])
Input Channels (SAR): 3
Output Channels (EO): 3


Checking SAR readability: 100%|██████████| 1947/1947 [00:13<00:00, 141.10it/s]
Checking EO readability: 100%|██████████| 1951/1951 [00:29<00:00, 66.33it/s]


Found 1947 matched SAR-EO pairs.
Starting Training Loop...


Epoch 1/5:   0%|          | 0/1947 [00:00<?, ?it/s]




Epoch 1/5:   0%|          | 1/1947 [00:36<19:55:51, 36.87s/it]

Epoch [1/5], Step [0/1947]
Loss_G: 17.9890 | Loss_G_GAN_A: 1.0083 | Loss_G_GAN_B: 1.2274
Loss_cycle_SAR: 5.5331 | Loss_cycle_EO: 4.9577
Loss_identity_A: 2.4995 | Loss_identity_B: 2.7630
Loss_D_A: 0.5590 | Loss_D_B: 0.6567


Epoch 1/5:   0%|          | 2/1947 [01:09<18:28:32, 34.20s/it]

Epoch [1/5], Step [1/1947]
Loss_G: 11.6449 | Loss_G_GAN_A: 0.2981 | Loss_G_GAN_B: 0.4195
Loss_cycle_SAR: 4.1571 | Loss_cycle_EO: 3.0779
Loss_identity_A: 1.6226 | Loss_identity_B: 2.0697
Loss_D_A: 0.5592 | Loss_D_B: 0.6433


Epoch 1/5:   0%|          | 3/1947 [01:41<17:54:11, 33.15s/it]

Epoch [1/5], Step [2/1947]
Loss_G: 8.4642 | Loss_G_GAN_A: 0.5459 | Loss_G_GAN_B: 0.2512
Loss_cycle_SAR: 3.3434 | Loss_cycle_EO: 1.7814
Loss_identity_A: 0.8884 | Loss_identity_B: 1.6538
Loss_D_A: 0.4227 | Loss_D_B: 1.0588


Epoch 1/5:   0%|          | 4/1947 [02:13<17:42:16, 32.80s/it]

Epoch [1/5], Step [3/1947]
Loss_G: 8.7172 | Loss_G_GAN_A: 0.9198 | Loss_G_GAN_B: 0.8910
Loss_cycle_SAR: 2.7858 | Loss_cycle_EO: 1.7975
Loss_identity_A: 0.9317 | Loss_identity_B: 1.3914
Loss_D_A: 0.2335 | Loss_D_B: 0.7890


Epoch 1/5:   0%|          | 5/1947 [02:46<17:42:00, 32.81s/it]

Epoch [1/5], Step [4/1947]
Loss_G: 9.0904 | Loss_G_GAN_A: 0.8261 | Loss_G_GAN_B: 1.6111
Loss_cycle_SAR: 2.7553 | Loss_cycle_EO: 1.6926
Loss_identity_A: 0.8410 | Loss_identity_B: 1.3644
Loss_D_A: 0.1267 | Loss_D_B: 0.8487


Epoch 1/5:   0%|          | 6/1947 [03:18<17:31:19, 32.50s/it]

Epoch [1/5], Step [5/1947]
Loss_G: 8.6017 | Loss_G_GAN_A: 0.7826 | Loss_G_GAN_B: 0.9054
Loss_cycle_SAR: 2.6913 | Loss_cycle_EO: 1.9406
Loss_identity_A: 0.9441 | Loss_identity_B: 1.3378
Loss_D_A: 0.1094 | Loss_D_B: 0.5005


Epoch 1/5:   0%|          | 7/1947 [03:49<17:20:37, 32.18s/it]

Epoch [1/5], Step [6/1947]
Loss_G: 7.9329 | Loss_G_GAN_A: 0.8109 | Loss_G_GAN_B: 0.9703
Loss_cycle_SAR: 2.1446 | Loss_cycle_EO: 1.9735
Loss_identity_A: 0.9848 | Loss_identity_B: 1.0488
Loss_D_A: 0.0784 | Loss_D_B: 0.5191


Epoch 1/5:   0%|          | 8/1947 [04:21<17:17:37, 32.11s/it]

Epoch [1/5], Step [7/1947]
Loss_G: 6.3770 | Loss_G_GAN_A: 1.0017 | Loss_G_GAN_B: 0.4638
Loss_cycle_SAR: 2.2514 | Loss_cycle_EO: 1.0239
Loss_identity_A: 0.5330 | Loss_identity_B: 1.1032
Loss_D_A: 0.0750 | Loss_D_B: 0.3071


Epoch 1/5:   0%|          | 9/1947 [04:54<17:28:24, 32.46s/it]

Epoch [1/5], Step [8/1947]
Loss_G: 6.5863 | Loss_G_GAN_A: 0.9694 | Loss_G_GAN_B: 0.4355
Loss_cycle_SAR: 2.5175 | Loss_cycle_EO: 0.9453
Loss_identity_A: 0.4702 | Loss_identity_B: 1.2485
Loss_D_A: 0.0859 | Loss_D_B: 0.3058


Epoch 1/5:   1%|          | 10/1947 [05:26<17:22:08, 32.28s/it]

Epoch [1/5], Step [9/1947]
Loss_G: 6.1479 | Loss_G_GAN_A: 0.8003 | Loss_G_GAN_B: 0.4340
Loss_cycle_SAR: 2.4031 | Loss_cycle_EO: 0.8330
Loss_identity_A: 0.4950 | Loss_identity_B: 1.1825
Loss_D_A: 0.0910 | Loss_D_B: 0.2678


Epoch 1/5:   1%|          | 11/1947 [05:58<17:14:45, 32.07s/it]

Epoch [1/5], Step [10/1947]
Loss_G: 7.4798 | Loss_G_GAN_A: 0.7812 | Loss_G_GAN_B: 0.4943
Loss_cycle_SAR: 1.7567 | Loss_cycle_EO: 2.3935
Loss_identity_A: 1.1934 | Loss_identity_B: 0.8608
Loss_D_A: 0.1035 | Loss_D_B: 0.2634


Epoch 1/5:   1%|          | 12/1947 [06:30<17:16:42, 32.15s/it]

Epoch [1/5], Step [11/1947]
Loss_G: 7.8608 | Loss_G_GAN_A: 1.0238 | Loss_G_GAN_B: 0.4439
Loss_cycle_SAR: 1.7408 | Loss_cycle_EO: 2.5406
Loss_identity_A: 1.2633 | Loss_identity_B: 0.8484
Loss_D_A: 0.0731 | Loss_D_B: 0.2515


Epoch 1/5:   1%|          | 13/1947 [07:03<17:21:21, 32.31s/it]

Epoch [1/5], Step [12/1947]
Loss_G: 8.0988 | Loss_G_GAN_A: 1.0283 | Loss_G_GAN_B: 0.3867
Loss_cycle_SAR: 2.0103 | Loss_cycle_EO: 2.4617
Loss_identity_A: 1.2318 | Loss_identity_B: 0.9800
Loss_D_A: 0.0362 | Loss_D_B: 0.2290


Epoch 1/5:   1%|          | 14/1947 [07:34<17:13:09, 32.07s/it]

Epoch [1/5], Step [13/1947]
Loss_G: 5.1260 | Loss_G_GAN_A: 0.8834 | Loss_G_GAN_B: 0.3967
Loss_cycle_SAR: 1.7269 | Loss_cycle_EO: 0.8601
Loss_identity_A: 0.4293 | Loss_identity_B: 0.8296
Loss_D_A: 0.0420 | Loss_D_B: 0.2295


Epoch 1/5:   1%|          | 15/1947 [08:06<17:05:37, 31.85s/it]

Epoch [1/5], Step [14/1947]
Loss_G: 4.9880 | Loss_G_GAN_A: 1.0005 | Loss_G_GAN_B: 0.4175
Loss_cycle_SAR: 1.7637 | Loss_cycle_EO: 0.6259
Loss_identity_A: 0.3107 | Loss_identity_B: 0.8697
Loss_D_A: 0.0457 | Loss_D_B: 0.1998


Epoch 1/5:   1%|          | 16/1947 [08:37<16:59:46, 31.69s/it]

Epoch [1/5], Step [15/1947]
Loss_G: 5.5379 | Loss_G_GAN_A: 0.6813 | Loss_G_GAN_B: 0.4665
Loss_cycle_SAR: 1.9697 | Loss_cycle_EO: 0.9858
Loss_identity_A: 0.4710 | Loss_identity_B: 0.9636
Loss_D_A: 0.0833 | Loss_D_B: 0.1851
