<a href="https://colab.research.google.com/github/Disha-Sikka/SAR-to-EO-CycleGAN/blob/main/cycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
# Mount Google Drive at /content/drive so later cells can read from and write to it.
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Copying Extracted Files to Colab's Local Disk

In [2]:
import shutil
import os

# Dataset folders on Google Drive (the copy sources).
drive_path_s1 = '/content/drive/MyDrive/CycleGAN/winter_s1'
drive_path_s2 = '/content/drive/MyDrive/CycleGAN/winter_s2'

# Destinations on Colab's local disk. Despite the ".tar.gz" suffix these are
# directories: copytree copies a whole folder tree, not an archive file.
colab_path_s1 = '/content/ROIs2017_winter_s1.tar.gz'
colab_path_s2 = '/content/ROIs2017_winter_s2.tar.gz'


def copy_dataset_dir(src_dir, dst_dir, label):
    """Copy one dataset folder tree from ``src_dir`` to ``dst_dir``.

    Args:
        src_dir: Folder to copy from.
        dst_dir: Folder to copy into; created if missing.
        label: Short tag ("S1" / "S2") used in the progress messages.

    Returns:
        True if the tree was copied, False if ``src_dir`` does not exist.
    """
    print(f"Copying {os.path.basename(dst_dir)} from Drive to Colab local disk...")
    if not os.path.exists(src_dir):
        print("Wrong Path")
        return False
    # dirs_exist_ok=True keeps the cell re-runnable: without it copytree raises
    # FileExistsError as soon as the destination folder already exists.
    shutil.copytree(src_dir, dst_dir, dirs_exist_ok=True)
    print(f"{label} file copied.")
    return True


copy_dataset_dir(drive_path_s1, colab_path_s1, "S1")
copy_dataset_dir(drive_path_s2, colab_path_s2, "S2")

print("Copying complete.")

Copying ROIs2017_winter_s1.tar.gz from Drive to Colab local disk...
S1 file copied.
Copying ROIs2017_winter_s2.tar.gz from Drive to Colab local disk...
S2 file copied.
Copying complete.


# Importing Libraries

In [3]:
import os
import glob
import random
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils

from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

# Configuration

- Defines all tunable parameters in one place so that they can be changed easily.

In [4]:
class Config:
    """Single container for every tunable setting used by the notebook."""

    # ---- Dataset locations ----
    SAR_DIR = '/content/ROIs2017_winter_s1.tar.gz'
    EO_DIR = '/content/ROIs2017_winter_s2.tar.gz'

    # ---- EO band selection (Sentinel-2) ----
    # Band numbering: B1 (Coastal Aerosol), B2 (Blue), B3 (Green), B4 (Red),
    # B5 (Red Edge 1), B6 (Red Edge 2), B7 (Red Edge 3), B8 (NIR),
    # B8A (NIR Narrow), B9 (Water Vapour), B10 (SWIR - Cirrus),
    # B11 (SWIR 1), B12 (SWIR 2)
    EO_BAND_CONFIGS = dict(
        RGB=[4, 3, 2],                # B4, B3, B2 (Red, Green, Blue)
        NIR_SWIR_RedEdge=[8, 11, 5],  # B8, B11, B5 (NIR, SWIR1, Red Edge 1)
        RGB_NIR=[4, 3, 2, 8],         # B4, B3, B2, B8 (Red, Green, Blue, NIR)
    )
    CURRENT_EO_CONFIG_NAME = "RGB"

    # ---- Network shape ----
    INPUT_NC = 2  # SAR input channels (Sentinel-1 GRD usually has 2: VV, VH)
    OUTPUT_NC = len(EO_BAND_CONFIGS[CURRENT_EO_CONFIG_NAME])  # one channel per selected EO band
    NGF = 64  # generator filters in the first conv layer
    NDF = 64  # discriminator filters in the first conv layer
    N_RESNET_BLOCKS = 6  # ResNet blocks inside the generator

    # ---- Optimisation ----
    BATCH_SIZE = 1  # CycleGAN typically uses batch size 1
    NUM_EPOCHS = 20
    LR = 2e-4  # learning rate
    BETA1 = 0.5  # Adam optimizer beta1
    LAMBDA_CYCLE = 10.0  # weight of the cycle-consistency loss
    LAMBDA_IDENTITY = 5.0  # weight of the identity-mapping loss (helps stabilize)

    # ---- Data loading ----
    IMAGE_SIZE = 256
    NUM_WORKERS = 4

    # ---- Outputs and logging ----
    # Everything is written under Google Drive so it persists across sessions.
    OUTPUT_BASE_DIR = '/content/drive/MyDrive/CycleGAN/SAR_EO_Project_Outputs'
    OUTPUT_DIR = os.path.join(OUTPUT_BASE_DIR, 'output_cyclegan')  # generated images
    CHECKPOINT_DIR = os.path.join(OUTPUT_BASE_DIR, 'checkpoints_cyclegan')  # model weights
    SAVE_EPOCH_FREQ = 5  # save model checkpoints every N epochs
    PRINT_FREQ = 1  # print training loss every N batches

    # ---- Compute device ----
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


# Build the configuration object
config = Config()

# Make sure both Google Drive output folders exist (exist_ok keeps this re-runnable)
for _out_dir in (config.OUTPUT_DIR, config.CHECKPOINT_DIR):
    os.makedirs(_out_dir, exist_ok=True)


In [5]:
# Re-create the configuration object (this cell repeats the initialization
# already performed at the end of the Config cell).
config = Config()

# Make sure both Google Drive output folders exist (exist_ok keeps this re-runnable)
for _out_dir in (config.OUTPUT_DIR, config.CHECKPOINT_DIR):
    os.makedirs(_out_dir, exist_ok=True)


# DataLoader and Preprocessing Custom Class


In [7]:
class Sen12MSDataset(Dataset):
    """Paired SAR / EO patch dataset built from two directory trees.

    SAR files are discovered with the pattern ``*_s1_*.tif`` and EO band files
    with ``*_s2_B*.tif`` (both searched recursively). EO band files are grouped
    per image, each group is matched to the SAR file of the same scene/patch,
    and ``__getitem__`` returns ``(sar_tensor, eo_tensor)`` scaled to [-1, 1].

    Args:
        sar_dir: Root folder searched for SAR ``.tif`` files.
        eo_dir: Root folder searched for per-band EO ``.tif`` files.
        eo_bands: Band labels to stack, in output-channel order. Each label is
            formatted into the file suffix ``_B{label}.tif``.
        image_size: Target size handed to ``transforms.Resize``.
    """

    # Sensor tags that appear as '_'-separated tokens in file names and must be
    # ignored when comparing SAR and EO identifiers.
    _SENSOR_TAGS = ('s1', 's2')

    def __init__(self, sar_dir, eo_dir, eo_bands, image_size=256):
        self.sar_root = os.fspath(sar_dir)
        self.eo_root = os.fspath(eo_dir)
        self.eo_bands = eo_bands  # band labels, e.g. [4, 3, 2] for B4, B3, B2
        self.image_size = image_size

        self.sar_image_paths = sorted(glob.glob(os.path.join(self.sar_root, '**', '*_s1_*.tif'), recursive=True))
        # EO imagery is stored one file per band; collect every band file and
        # group them per image afterwards.
        self.eo_base_paths = sorted(glob.glob(os.path.join(self.eo_root, '**', '*_s2_B*.tif'), recursive=True))

        self.eo_image_groups = self._group_eo_files(self.eo_base_paths)

        # Pair every EO image group with its SAR counterpart.
        self.pairs = self._match_sar_eo_pairs()

        # Applied to each single-band EO PIL image.
        # NOTE(review): the /10000 scaling in _load_eo assumes ToTensor leaves
        # the integer (mode 'I') pixel values unscaled - confirm against the
        # installed torchvision version.
        self.transform = transforms.Compose([
            transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
        ])

        print(f"Found {len(self.pairs)} matched SAR-EO pairs.")
        if len(self.pairs) == 0:
            print("WARNING: No SAR-EO pairs found. Please check your data paths and extraction.")
            print(f"SAR root: {self.sar_root}")
            print(f"EO root: {self.eo_root}")
            print(f"Example SAR path search: {os.path.join(self.sar_root, '**', '*_s1_*.tif')}")
            print(f"Example EO path search: {os.path.join(self.eo_root, '**', '*_s2_B*.tif')}")

    def _group_eo_files(self, eo_paths):
        """Group EO band files by image ID (file name minus its last '_' token).

        Returns:
            dict mapping image ID -> list of band file paths, in input order.
        """
        groups = {}
        for path in eo_paths:
            base_name = '_'.join(os.path.basename(path).split('_')[:-1])
            groups.setdefault(base_name, []).append(path)
        return groups

    @staticmethod
    def _pair_key(name):
        """Return the sensor-independent ID of a file name, or None.

        Sensor tags ('s1' / 's2') are dropped and everything after the last
        patch token (``p<digits>``) is discarded, so for example
        'ROIs2017_winter_s1_21_p92.tif' and 'ROIs2017_winter_s2_21_p92_s2'
        both map to 'ROIs2017_winter_21_p92'. Returns None when the name
        contains no patch token.
        """
        stem = name[:-len('.tif')] if name.endswith('.tif') else name
        tokens = [tok for tok in stem.split('_') if tok not in Sen12MSDataset._SENSOR_TAGS]
        for pos in range(len(tokens) - 1, -1, -1):
            tok = tokens[pos]
            if len(tok) > 1 and tok[0] == 'p' and tok[1:].isdigit():
                return '_'.join(tokens[:pos + 1])
        return None

    def _match_sar_eo_pairs(self):
        """Match each EO image group to the SAR file with the same ID.

        Pairing is required so that paired metrics can be computed on
        corresponding SAR / EO patches.

        Returns:
            list of ``(sar_path, eo_band_paths)`` tuples, in EO-group order.
        """
        # Map cleaned SAR ID -> SAR path. If several SAR files share an ID,
        # the last one wins.
        sar_ids_map = {}
        for sar_path in self.sar_image_paths:
            base_name = os.path.basename(sar_path)
            clean_id = self._pair_key(base_name)
            if clean_id is None:
                # Legacy fallback for names without a patch token.
                parts = base_name.rsplit('_', 1)
                if len(parts) > 1 and parts[-1].endswith('.tif'):
                    clean_id = parts[0].replace('_s1_', '_')
                else:
                    clean_id = base_name.replace('.tif', '').replace('_s1_', '_')
            sar_ids_map[clean_id] = sar_path

        matched_pairs = []
        for eo_id_raw, eo_band_paths in self.eo_image_groups.items():
            # e.g. 'ROIs2017_winter_s2_21_p10_s2' -> 'ROIs2017_winter_21_p10'
            clean_eo_id = self._pair_key(eo_id_raw)
            if clean_eo_id is None:
                # Legacy fallback for names without a patch token.
                clean_eo_id = eo_id_raw.replace('_s2', '')

            if clean_eo_id in sar_ids_map:
                matched_pairs.append((sar_ids_map[clean_eo_id], eo_band_paths))
            else:
                print(f"No matching SAR found for EO ID: {eo_id_raw} (cleaned: {clean_eo_id})")

        return matched_pairs

    def __len__(self):
        """Number of matched SAR-EO pairs."""
        return len(self.pairs)

    def _load_sar(self, sar_path):
        """Load one SAR image as a (C, H, W) float tensor scaled to [-1, 1]."""
        # The context manager releases the file handle once pixels are copied.
        with Image.open(sar_path) as sar_image_pil:
            sar_image_np = np.array(sar_image_pil)

        # PIL yields (H, W) or (H, W, C); give single-band images a channel axis.
        if sar_image_np.ndim == 2:
            sar_image_np = sar_image_np[:, :, np.newaxis]

        sar_image_tensor = torch.from_numpy(sar_image_np.astype(np.float32))
        if sar_image_np.dtype == np.uint16:
            sar_image_tensor = sar_image_tensor / 65535.0  # full uint16 range -> [0, 1]

        # Replace NaN / +-inf pixels so they cannot poison the min/max below.
        sar_image_tensor = torch.nan_to_num(sar_image_tensor, nan=0.0, posinf=0.0, neginf=0.0)

        # The array is always channel-last here, so the layout change is
        # unconditional (no guessing from the size of the first dimension).
        sar_image_tensor = sar_image_tensor.permute(2, 0, 1)

        # Per-image min/max normalisation to [0, 1], then shift to [-1, 1].
        # A constant image becomes all zeros before the shift, i.e. all -1.
        sar_min = sar_image_tensor.min()
        sar_max = sar_image_tensor.max()
        if sar_max > sar_min:
            sar_image_tensor = (sar_image_tensor - sar_min) / (sar_max - sar_min)
        else:
            sar_image_tensor = torch.zeros_like(sar_image_tensor)
        sar_image_tensor = sar_image_tensor * 2.0 - 1.0

        sar_image_tensor = transforms.Resize(
            self.image_size, interpolation=transforms.InterpolationMode.BICUBIC
        )(sar_image_tensor)
        # Bicubic interpolation can overshoot; keep the documented [-1, 1] range.
        return sar_image_tensor.clamp(-1.0, 1.0)

    def _load_eo(self, eo_band_paths, sar_path):
        """Stack the configured EO bands into a (C, H, W) tensor in [-1, 1]."""
        eo_images_list = []
        for band_idx in self.eo_bands:
            # Bands are looked up by file suffix, so no ordering of the paths
            # is needed (and labels such as 'B8A' in the group are harmless).
            suffix = f'_B{band_idx}.tif'
            band_path = next((p for p in eo_band_paths if suffix in p), None)
            if band_path is None:
                print(f"Warning: Band B{band_idx} not found for EO image group {os.path.basename(os.path.dirname(sar_path))}. Filling with zeros.")
                # int32 gives a mode 'I' image, the same mode as real bands.
                dummy_array = np.zeros((self.image_size, self.image_size), dtype=np.int32)
                eo_band_img = Image.fromarray(dummy_array)
            else:
                with Image.open(band_path) as band_file:
                    eo_band_img = band_file.convert('I')  # 32-bit signed integer pixels

            # Cast to float so every band has the same dtype before concatenation.
            eo_images_list.append(self.transform(eo_band_img).float())

        eo_image_tensor = torch.cat(eo_images_list, dim=0)

        # Clip to [0, 10000] and rescale linearly to [-1, 1].
        eo_max_val = 10000.0
        eo_image_tensor = torch.clamp(eo_image_tensor, 0, eo_max_val)
        return (eo_image_tensor / eo_max_val) * 2.0 - 1.0

    def __getitem__(self, idx):
        """Return ``(sar_tensor, eo_tensor)`` for pair ``idx``, both in [-1, 1]."""
        sar_path, eo_band_paths = self.pairs[idx]
        sar_image_tensor = self._load_sar(sar_path)
        eo_image_tensor = self._load_eo(eo_band_paths, sar_path)
        return sar_image_tensor, eo_image_tensor


In [8]:
# Helper function for Convolutional Block
def conv_block(in_channels, out_channels, kernel_size, stride, padding, use_bias=False, norm_layer=nn.InstanceNorm2d, activation=nn.ReLU(True)):
    """Build a Conv2d -> (norm) -> (activation) stack as an ``nn.Sequential``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: Passed to ``nn.Conv2d``.
        use_bias: Whether the convolution has a bias term.
        norm_layer: Callable taking ``out_channels`` and returning a module; a
            falsy value skips normalisation.
        activation: Module instance appended last; a falsy value skips it.
            The default instance is created once and shared by every call
            that relies on it.
    """
    modules = [
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=use_bias,
        )
    ]
    modules += [norm_layer(out_channels)] if norm_layer else []
    modules += [activation] if activation else []
    return nn.Sequential(*modules)

# Helper function for Transposed Convolutional Block (for upsampling)
def deconv_block(in_channels, out_channels, kernel_size, stride, padding, output_padding, use_bias=False, norm_layer=nn.InstanceNorm2d, activation=nn.ReLU(True)):
    """Build a ConvTranspose2d -> (norm) -> (activation) upsampling stack.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, output_padding:
            Passed to ``nn.ConvTranspose2d``.
        use_bias: Whether the transposed convolution has a bias term.
        norm_layer: Callable taking ``out_channels`` and returning a module; a
            falsy value skips normalisation.
        activation: Module instance appended last; a falsy value skips it.
            The default instance is created once and shared by every call
            that relies on it.

    Returns:
        ``nn.Sequential`` holding the assembled layers.
    """
    modules = [
        nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=use_bias,
        )
    ]
    modules += [norm_layer(out_channels)] if norm_layer else []
    modules += [activation] if activation else []
    return nn.Sequential(*modules)
