In [1]:
# !pip install rasterio --quiet

In [2]:
# !pip install mlflow --quiet
# !pip install dagshub --quiet

In [3]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image
import matplotlib.pyplot as plt
import random
import time
import rasterio
import dagshub
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from torch.cuda.amp import autocast, GradScaler
from datetime import timedelta
import sys
import torchvision.utils as vutils
from tqdm.notebook import tqdm

import dagshub
import mlflow

# Seed every RNG the notebook draws from so that a fresh-kernel run is repeatable.
seed = 42
random.seed(seed)        # Python stdlib RNG (used by the dataset for shuffling and flips)
np.random.seed(seed)     # NumPy global RNG
torch.manual_seed(seed)  # PyTorch RNG

# cuDNN autotuning is switched off; the stricter deterministic flag is left disabled.
# torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# Check for GPU availability; everything below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define constants (hyper-parameters a reader may want to tune)
# NOTE(review): an earlier comment here said "As per the paper, images are resized to 400x400",
# which contradicts the value 512 (and the "_512" suffix in run_name). Where IMG_SIZE is
# consumed is not visible in this part of the notebook — confirm against the transform cell.
IMG_SIZE = 512
BATCH_SIZE = 4  # As per the paper it was 1
NUM_EPOCHS = 500  # Default number of epochs for train_cyclegan
LEARNING_RATE = 0.0002  # Shared by the generator and both discriminator Adam optimizers
BETA1 = 0.5  # Adam optimizer beta1 parameter
BETA2 = 0.999  # Adam optimizer beta2 parameter
LAMBDA_CYCLE = 10.0  # Weight for cycle consistency loss
LAMBDA_IDENTITY = 0.5  # Weight for identity loss (applied directly to the summed identity L1 terms)
run_name = "CycleGAN_training_up_512"  # MLflow run name used by train_cyclegan

Using device: cuda


In [5]:
# Dataset class for SAR to Optical conversion
class SAROpticalDataset(Dataset):
    """SAR/optical image dataset read from a single flat directory.

    Every file whose name contains ``"sar"`` is treated as a SAR image; its
    optical counterpart is the file with ``"sar"`` replaced by ``"opt"`` in the
    name. SAR files without an existing optical counterpart are skipped.

    When ``root_dir`` contains the substring ``"train"`` the SAR and optical
    path lists are shuffled *independently* each time the internal access
    counter wraps to zero, so training samples are deliberately unpaired.
    For any other directory the SAR/optical pairing is kept.

    Args:
        root_dir: Directory holding both SAR and optical image files.
        transform: Optional callable applied to each PIL image after the
            random flips.

    Each item is a dict ``{"SAR": image, "Optical": image}``.
    """

    # How many randomly chosen substitute samples are tried after a load failure
    # before giving up; bounds what used to be unbounded recursion.
    MAX_LOAD_RETRIES = 5

    def __init__(self, root_dir, transform=None):
        self.transform = transform
        self.data_x = []  # SAR image paths
        self.data_y = []  # optical image paths
        self.isTrain = "train" in root_dir
        self.count = 0    # number of __getitem__ calls modulo dataset length

        missing_optical = 0
        for base_name in sorted(os.listdir(root_dir)):
            if "sar" not in base_name:
                continue
            # Assuming the optical file name only differs by "sar" -> "opt"
            opt_path = os.path.join(root_dir, base_name.replace("sar", "opt"))
            if not os.path.isfile(opt_path):
                missing_optical += 1
                continue
            self.data_x.append(os.path.join(root_dir, base_name))
            self.data_y.append(opt_path)

        if missing_optical:
            print(f"Warning: skipped {missing_optical} SAR files without a matching optical file")

        print(f'Total: Found {len(self.data_x)} valid image pairs')

    def __len__(self):
        return len(self.data_x)

    def __getitem__(self, idx):
        """Load, randomly flip and transform the sample at ``idx``.

        Raises:
            RuntimeError: if the requested sample and ``MAX_LOAD_RETRIES``
                random substitutes all fail to load.
        """
        if self.count == 0 and self.isTrain:
            # Independent shuffles break the SAR/optical pairing on purpose
            # (unpaired training data).
            random.shuffle(self.data_x)
            random.shuffle(self.data_y)
        self.count = (self.count + 1) % len(self.data_y)

        load_error = None
        for _attempt in range(self.MAX_LOAD_RETRIES + 1):
            sar_path, opt_path = self.data_x[idx], self.data_y[idx]
            try:
                sar_image = Image.open(sar_path).convert("RGB")
                optical_image = Image.open(opt_path).convert("RGB")
                break
            except Exception as e:
                print(f"Error loading images: {e}")
                print(f"SAR path: {sar_path}")
                print(f"Optical path: {opt_path}")
                load_error = e
                # Fall back to a random sample instead of failing the whole batch.
                idx = random.randint(0, len(self) - 1)
        else:
            raise RuntimeError(
                f"Could not load an image pair after {self.MAX_LOAD_RETRIES + 1} attempts"
            ) from load_error

        # The same flip decision is applied to both images of the sample.
        if random.random() > 0.5:
            sar_image = transforms.functional.hflip(sar_image)
            optical_image = transforms.functional.hflip(optical_image)

        if random.random() > 0.5:
            sar_image = transforms.functional.vflip(sar_image)
            optical_image = transforms.functional.vflip(optical_image)

        if self.transform:
            sar_image = self.transform(sar_image)
            optical_image = self.transform(optical_image)

        return {"SAR": sar_image, "Optical": optical_image}
#train_dataset = SAROpticalDataset(root_dir="/home/bhargav/Documents/SAR2OPT/OS-dataset512/train")

In [6]:
# Generator Network (ResNet-based)
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions plus a skip connection.

    The channel count and spatial size are unchanged, so the block's output
    can be added directly to its input.

    Args:
        in_features: Number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()

        def padded_conv_norm():
            # Reflection padding of 1 offsets the shrinkage of the unpadded 3x3 conv.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features),
            ]

        # conv-norm, ReLU, conv-norm (no activation after the second norm)
        layers = padded_conv_norm() + [nn.ReLU(inplace=True)] + padded_conv_norm()
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual

In [7]:
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 stem (128 channels) -> two stride-2 convolutions
    (128 -> 256 -> 512) -> ``num_residual_blocks`` residual blocks ->
    two stride-2 transposed convolutions (512 -> 256 -> 128) -> 7x7 output
    convolution followed by ``tanh``.

    Args:
        input_channels: Channels of the input image.
        output_channels: Channels of the generated image.
        num_residual_blocks: Number of ``ResidualBlock`` modules in the bottleneck.
    """

    def __init__(self, input_channels=3, output_channels=3, num_residual_blocks=12):
        super().__init__()

        channels = 128  # stem width (increased from 64 to 128)

        # Stem: reflection-padded 7x7 convolution.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, channels, 7),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
        ]

        # Encoder: each stage halves the resolution and doubles the width.
        for _ in range(2):
            widened = channels * 2
            layers.extend([
                nn.Conv2d(channels, widened, 3, stride=2, padding=1),
                nn.InstanceNorm2d(widened),
                nn.ReLU(inplace=True),
            ])
            channels = widened

        # Bottleneck of residual blocks at the widest point.
        for _ in range(num_residual_blocks):
            layers.append(ResidualBlock(channels))

        # Decoder: each stage doubles the resolution and halves the width.
        for _ in range(2):
            narrowed = channels // 2
            layers.extend([
                nn.ConvTranspose2d(channels, narrowed, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(narrowed),
                nn.ReLU(inplace=True),
            ])
            channels = narrowed

        # Output head; `channels` is back at the stem width (128) here.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels, output_channels, 7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Translate ``x``; the final ``tanh`` bounds the output to [-1, 1]."""
        return self.model(x)


In [8]:
class Discriminator(nn.Module):
    """Convolutional patch discriminator.

    Four stride-2 4x4 convolution stages (``input_channels`` -> 128 -> 256 ->
    512 -> 1024, LeakyReLU(0.2), instance norm on all but the first stage)
    followed by an asymmetric zero pad and a 4x4 convolution that produces a
    single-channel score map.

    Args:
        input_channels: Channels of the images being judged.
    """

    def __init__(self, input_channels=3):
        super().__init__()

        widths = (input_channels, 128, 256, 512, 1024)
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            if stage > 0:
                # The first stage works on raw pixels and is left un-normalised.
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Pad one pixel on the left and top only, then score each patch.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(1024, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the (N, 1, h, w) map of raw (un-squashed) patch scores."""
        return self.model(img)


# Initialize weights function
def weights_init_normal(m):
    """Initialise one module in place; intended for ``nn.Module.apply``.

    * Class name contains ``"Conv"``: weights ~ N(0, 0.02), bias (if any) = 0.
    * Class name contains ``"BatchNorm2d"``: weights ~ N(1, 0.02), bias = 0.
    * Anything else is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        # Convolutions built with bias=False expose `bias` as None.
        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias.data, 0.0)
    elif 'BatchNorm2d' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)

In [9]:
# Initialize generators and discriminators
# Two generators (one per translation direction) and one discriminator per image domain.
# These module-level objects are read directly (as globals) by train_cyclegan.
G_SAR_to_Optical = Generator().to(device)
G_Optical_to_SAR = Generator().to(device)
D_Optical = Discriminator().to(device)
D_SAR = Discriminator().to(device)

# Initialize weights
# `apply` visits every sub-module, so each Conv/ConvTranspose layer gets N(0, 0.02) weights.
G_SAR_to_Optical.apply(weights_init_normal)
G_Optical_to_SAR.apply(weights_init_normal)
D_Optical.apply(weights_init_normal)
D_SAR.apply(weights_init_normal)

# Loss functions
# MSE on the discriminator score map (least-squares adversarial loss); L1 for cycle and identity terms.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

# Optimizers
# A single Adam instance updates both generators together; each discriminator has its own.
optimizer_G = optim.Adam(
    list(G_SAR_to_Optical.parameters()) + list(G_Optical_to_SAR.parameters()),
    lr=LEARNING_RATE,
    betas=(BETA1, BETA2)
)
optimizer_D_Optical = optim.Adam(D_Optical.parameters(), lr=LEARNING_RATE, betas=(BETA1, BETA2))
optimizer_D_SAR = optim.Adam(D_SAR.parameters(), lr=LEARNING_RATE, betas=(BETA1, BETA2))

# NOTE(review): gamma=1 multiplies the learning rate by 1 every 5 epochs, i.e. the
# learning rate never changes. The schedulers are still stepped and checkpointed by
# train_cyclegan; lower gamma here if decay is actually wanted.
scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size=5, gamma=1) 
scheduler_D_Optical = torch.optim.lr_scheduler.StepLR(optimizer_D_Optical, step_size=5, gamma=1)
scheduler_D_SAR = torch.optim.lr_scheduler.StepLR(optimizer_D_SAR, step_size=5, gamma=1)


In [10]:
def plot_metrics(metrics_dict, output_path, title="Training Metrics"):
    """Plot several metric series on one set of axes and save the figure.

    NOTE: a later cell of this notebook defines another ``plot_metrics`` with a
    different signature (``history, output_dir``); once that cell has run, this
    version is shadowed.

    Args:
        metrics_dict: Mapping of legend label -> sequence of values.
        output_path: File the figure is written to.
        title: Figure title. If it contains ``"Loss"`` the x-axis is labelled
            "Iterations", otherwise "Epochs".
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.set_title(title)

    for series_name in metrics_dict:
        ax.plot(metrics_dict[series_name], label=series_name)

    x_label = "Iterations" if "Loss" in title else "Epochs"
    ax.set_xlabel(x_label)
    ax.set_ylabel("Value")
    ax.legend()
    ax.grid(True, alpha=0.3)

    fig.savefig(output_path)
    plt.close(fig)  # release the figure so repeated calls do not accumulate memory

In [11]:
def visualize_results(images, save_path=None, nrow=5, figsize=(15, 12)):
    """Show SAR inputs, generated optical images and real optical images in a 3-row grid.

    Args:
        images: Tuple ``(sar_images, generated_optical, real_optical)`` of image
            batches indexed as ``batch[i]`` -> (C, H, W), with values in [-1, 1].
        save_path: If given, the figure is saved there and closed; otherwise it
            is displayed with ``plt.show()``.
        nrow: Maximum number of columns (samples) to draw.
        figsize: Matplotlib figure size.
    """
    sar_images, generated_optical, real_optical = images
    n_images = min(nrow, sar_images.size(0))

    # squeeze=False keeps `axs` two-dimensional even when there is a single column.
    fig, axs = plt.subplots(3, n_images, figsize=figsize, squeeze=False)

    def to_display(tensor):
        # (C, H, W) in [-1, 1]  ->  (H, W, C) in [0, 1]
        hwc = tensor.cpu().detach().numpy().transpose(1, 2, 0)
        return np.clip((hwc + 1) / 2, 0, 1)

    for col in range(n_images):
        sar_img = to_display(sar_images[col])
        gen_img = to_display(generated_optical[col])
        real_img = to_display(real_optical[col])

        # A 3-channel SAR image whose channels are all equal is drawn as grayscale.
        cmap_sar = None
        if (
            sar_img.shape[2] == 3
            and np.allclose(sar_img[:, :, 0], sar_img[:, :, 1])
            and np.allclose(sar_img[:, :, 0], sar_img[:, :, 2])
        ):
            sar_img = sar_img[:, :, 0]
            cmap_sar = 'gray'

        panels = (
            (sar_img, cmap_sar, "SAR Input"),
            (gen_img, None, "Generated Optical"),
            (real_img, None, "Real Optical"),
        )
        for row, (panel_img, panel_cmap, caption) in enumerate(panels):
            ax = axs[row, col]
            ax.imshow(panel_img, cmap=panel_cmap)
            ax.set_title(caption)
            ax.axis('off')

    plt.tight_layout()
    if not save_path:
        plt.show()
        return
    plt.savefig(save_path)
    plt.close()

In [12]:
def plot_metrics(history, output_dir='.'):
    """Write one PNG line chart per training/validation metric into ``output_dir``.

    This definition replaces the earlier ``plot_metrics(metrics_dict,
    output_path, title)`` defined in a previous cell.

    Args:
        history: Dict with the list-valued keys ``g_losses``, ``d_losses``,
            ``cycle_losses``, ``identity_losses``, ``val_psnr`` and ``val_ssim``.
            The x-axis of every chart is ``1..len(history['g_losses'])``.
        output_dir: Directory for the PNG files; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    epoch_axis = range(1, len(history['g_losses']) + 1)

    # (history key, chart title / legend label, y-axis label, output file name)
    charts = (
        ('g_losses', 'Generator Loss', 'Loss', 'generator_loss.png'),
        ('d_losses', 'Discriminator Loss', 'Loss', 'discriminator_loss.png'),
        ('cycle_losses', 'Cycle Loss', 'Loss', 'cycle_loss.png'),
        ('identity_losses', 'Identity Loss', 'Loss', 'identity_loss.png'),
        ('val_psnr', 'Validation PSNR', 'PSNR (dB)', 'val_psnr.png'),
        ('val_ssim', 'Validation SSIM', 'SSIM', 'val_ssim.png'),
    )

    for key, chart_title, y_label, file_name in charts:
        plt.figure()
        plt.plot(epoch_axis, history[key], label=chart_title)
        plt.title(chart_title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.savefig(os.path.join(output_dir, file_name))
        plt.close()  # one figure per chart; close it so plots never overlap

In [13]:
def load_sar_images(folder_path, image_size=(256, 256)):
    """Load every PNG/JPEG in ``folder_path`` as a transformed RGB image.

    Files are processed in sorted name order so that the returned list (and
    any sample grid built from it) is the same on every run and machine;
    ``os.listdir`` alone gives no ordering guarantee.

    Args:
        folder_path: Directory to scan (not recursive).
        image_size: Currently unused — kept for backward compatibility. Any
            resizing is whatever the global ``transform`` does.

    Returns:
        List with one ``transform(image)`` result per image file.

    NOTE(review): relies on a module-level ``transform`` that must be defined
    before this function is called; it is not defined in this cell.
    """
    image_extensions = (".png", ".jpg", ".jpeg")
    image_paths = [
        os.path.join(folder_path, name)
        for name in sorted(os.listdir(folder_path))
        if name.lower().endswith(image_extensions)
    ]

    images = []
    for path in image_paths:
        image = Image.open(path).convert("RGB")  # Use "L" for grayscale if SAR images are single-channel
        images.append(transform(image))

    return images  # List of tensors


In [14]:
import matplotlib.pyplot as plt
import numpy as np

def visualize_results_sar(sar_images, generated_optical, save_path=None, nrow=5, figsize_per_img=(2.5, 2.5)):
    """Draw SAR inputs (top row) above their generated optical images (bottom row).

    Args:
        sar_images: Batch of SAR images indexed as ``batch[i]`` -> (C, H, W),
            values in [-1, 1].
        generated_optical: Matching batch of generated images, same layout.
        save_path: If given, the figure is written there (300 dpi, tight
            bounding box). If omitted, the figure is displayed instead —
            previously it was closed without being shown or saved.
        nrow: Maximum number of columns (samples) to draw.
        figsize_per_img: (width, height) in inches allotted to each cell.
    """
    n_images = min(nrow, sar_images.size(0))
    figsize = (figsize_per_img[0] * n_images, figsize_per_img[1] * 2)

    # squeeze=False keeps `axs` two-dimensional even when there is a single column.
    fig, axs = plt.subplots(2, n_images, figsize=figsize, squeeze=False)

    def to_display(tensor):
        # (C, H, W) in [-1, 1]  ->  (H, W, C) in [0, 1]
        hwc = tensor.cpu().detach().numpy().transpose(1, 2, 0)
        return np.clip((hwc + 1) / 2, 0, 1)

    for i in range(n_images):
        sar_img = to_display(sar_images[i])
        gen_img = to_display(generated_optical[i])

        # A 3-channel SAR image whose channels are all equal is drawn as grayscale.
        if sar_img.shape[2] == 3 and np.allclose(sar_img[:, :, 0], sar_img[:, :, 1]) and np.allclose(sar_img[:, :, 0], sar_img[:, :, 2]):
            sar_img = sar_img[:, :, 0]
            cmap_sar = 'gray'
        else:
            cmap_sar = None

        axs[0, i].imshow(sar_img, cmap=cmap_sar)
        axs[0, i].set_title("SAR Input", fontsize=10)
        axs[0, i].axis('off')

        axs[1, i].imshow(gen_img)
        axs[1, i].set_title("Generated Optical", fontsize=10)
        axs[1, i].axis('off')

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.05, hspace=0.3)

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    else:
        # Same convention as visualize_results: no path means "show it".
        plt.show()
    plt.close(fig)


In [15]:
def save_generated_images(epoch, generator, sar_images, device, base_dir="results"):
    """Run ``generator`` on a fixed list of SAR images and save a comparison grid.

    The grid is written to ``<base_dir>/drdo_gen/epoch_<epoch+1>_gen.png``.

    Args:
        epoch: Zero-based epoch index (the file name uses ``epoch + 1``).
        generator: SAR -> optical generator; called under ``torch.no_grad()``.
            The caller is responsible for putting it in eval mode.
        sar_images: List of equally-shaped (C, H, W) tensors. If empty, nothing
            is generated and a message is printed, instead of ``torch.stack``
            raising and aborting the caller.
        device: Device the stacked batch is moved to.
        base_dir: Parent directory of the ``drdo_gen`` output folder.
    """
    if len(sar_images) == 0:
        print("save_generated_images: no SAR sample images supplied; skipping.")
        return

    out_dir = os.path.join(base_dir, "drdo_gen")
    os.makedirs(out_dir, exist_ok=True)

    final_path = os.path.join(out_dir, f"epoch_{epoch+1}_gen.png")

    with torch.no_grad():
        real_sar = torch.stack(sar_images).to(device)  # Convert list of tensors to batch tensor
        fake_optical = generator(real_sar)

        visualize_results_sar(real_sar, fake_optical, final_path)

In [16]:
# Per-epoch metric series kept in the training history (and in checkpoints).
_HISTORY_KEYS = (
    'g_losses',
    'd_losses',
    'cycle_losses',
    'identity_losses',
    'val_psnr',
    'val_ssim',
    'val_identity_losses',
)

# MLflow metric name -> history key, used to replay a restored history into a new run.
_MLFLOW_HISTORY_METRICS = {
    "train_g_loss": 'g_losses',
    "train_d_loss": 'd_losses',
    "cycle_loss": 'cycle_losses',
    "identity_loss": 'identity_losses',
    "val_psnr": 'val_psnr',
    "val_ssim": 'val_ssim',
    "val_identity_losses": 'val_identity_losses',
}


def _new_history():
    """Return an empty history dict with one list per tracked metric."""
    return {key: [] for key in _HISTORY_KEYS}


def train_cyclegan(train_dataloader, val_dataloader, output_dir, epochs=NUM_EPOCHS, display_interval=503, patience=10,
                   resume_checkpoint_path='/home/bhargav/Documents/SAR2OPT/resume_ckpt/cycleGAN_checkpoint.pth',
                   sample_sar_dir="/home/bhargav/Documents/SAR2OPT/samples"):
    """Train the global CycleGAN models with mixed precision, validation and MLflow logging.

    Operates on the module-level models (``G_SAR_to_Optical``, ``G_Optical_to_SAR``,
    ``D_Optical``, ``D_SAR``), their optimizers, schedulers and loss criteria.

    Per epoch it: trains the generators and both discriminators, averages the
    losses, saves a sample grid for the images in ``sample_sar_dir``, computes
    validation PSNR/SSIM/identity loss, saves a resumable checkpoint, saves the
    best model (lowest average training cycle loss seen in this call) and logs
    metrics to MLflow.

    Args:
        train_dataloader: Yields dicts with ``'SAR'`` and ``'Optical'`` batches.
        val_dataloader: Same format, used for validation metrics and samples.
        output_dir: Root for ``models/``, ``samples/``, ``checkpoints/``,
            ``drdo_gen/`` and the final metric plots.
        epochs: Total number of epochs (a resumed run continues up to this number).
        display_interval: Every this many iterations, print losses and log a sample grid.
        patience: Currently unused (early stopping is disabled); kept for
            backward compatibility.
        resume_checkpoint_path: Checkpoint to resume from if the file exists.
        sample_sar_dir: Folder of SAR images rendered after every epoch.

    Returns:
        ``(G_SAR_to_Optical, G_Optical_to_SAR, history)`` with the generators
        holding the best-model weights when a best model was saved.
    """
    sar_images = None  # fixed sample images, loaded lazily on the first epoch

    # NOTE(review): the experiment name says "Pix2Pix" although this trains a CycleGAN;
    # left unchanged so runs keep landing in the same MLflow experiment.
    mlflow.set_experiment("SAR_to_Optical_Pix2Pix")

    # Create directories
    os.makedirs(output_dir, exist_ok=True)
    models_dir = os.path.join(output_dir, 'models')
    samples_dir = os.path.join(output_dir, 'samples')
    checkpoint_dir = os.path.join(output_dir, "checkpoints")
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(samples_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    latest_checkpoint_path = os.path.join(checkpoint_dir, "cycleGAN_checkpoint.pth")
    best_model_path = os.path.join(models_dir, "cyclegan_best_model.pth")

    # Initialize gradient scaler (shared by all three optimizers)
    scaler = GradScaler()

    start_epoch = 0

    # Load checkpoint if available
    if os.path.exists(resume_checkpoint_path):
        # SECURITY: weights_only=False unpickles arbitrary objects — only load trusted checkpoints.
        checkpoint = torch.load(resume_checkpoint_path, map_location=device, weights_only=False)
        G_SAR_to_Optical.load_state_dict(checkpoint['G_SAR_to_Optical'])
        G_Optical_to_SAR.load_state_dict(checkpoint['G_Optical_to_SAR'])
        D_Optical.load_state_dict(checkpoint['D_Optical'])
        D_SAR.load_state_dict(checkpoint['D_SAR'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D_Optical.load_state_dict(checkpoint['optimizer_D_Optical'])
        optimizer_D_SAR.load_state_dict(checkpoint['optimizer_D_SAR'])
        scaler.load_state_dict(checkpoint['scaler'])
        history = checkpoint.get('history', _new_history())
        # Checkpoints written before a metric was added may lack its key.
        for key in _HISTORY_KEYS:
            history.setdefault(key, [])
        # The checkpoint stores `epoch + 1`, i.e. the index of the next epoch to run.
        start_epoch = checkpoint['epoch']
        scheduler_G.load_state_dict(checkpoint['scheduler_G'])
        scheduler_D_Optical.load_state_dict(checkpoint['scheduler_D_Optical'])
        scheduler_D_SAR.load_state_dict(checkpoint['scheduler_D_SAR'])
        print(f"✅ Resuming from epoch {start_epoch+1}")
    else:
        print("🚀 Starting training from scratch")
        history = _new_history()

    # NOTE(review): the best-loss tracker restarts at +inf even when resuming, so the
    # first resumed epoch always overwrites the best-model file. Seeding it from the
    # restored history would only be safe if that file belongs to the same run.
    best_cycle_loss = float('inf')

    print(f"Starting training for {epochs} epochs...")

    with mlflow.start_run(run_name=run_name):
        # Replay the restored history so the new MLflow run shows the full curve.
        if history["g_losses"]:
            print("Logging prev data")
            print(history)
            for past_epoch in range(len(history["g_losses"])):
                replayed = {
                    metric_name: history[key][past_epoch]
                    for metric_name, key in _MLFLOW_HISTORY_METRICS.items()
                    if past_epoch < len(history[key])
                }
                replayed["learning_rate"] = optimizer_G.param_groups[0]['lr']  # current LR, not the historical one
                mlflow.log_metrics(replayed, step=past_epoch)

        for epoch in range(start_epoch, epochs):
            G_SAR_to_Optical.train()
            G_Optical_to_SAR.train()
            D_Optical.train()
            D_SAR.train()

            epoch_g_loss, epoch_d_loss, epoch_cycle_loss, epoch_identity_loss = 0, 0, 0, 0

            for i, batch in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}")):
                real_sar = batch['SAR'].to(device)
                real_optical = batch['Optical'].to(device)

                # Calculate patch size once per epoch from the discriminator's output shape
                if i == 0:
                    with torch.no_grad():
                        dummy_output = D_Optical(real_optical[:1])
                        patch_h, patch_w = dummy_output.size(2), dummy_output.size(3)
                        valid_prototype = torch.ones((1, 1, patch_h, patch_w), requires_grad=False).to(device)
                        fake_prototype = torch.zeros((1, 1, patch_h, patch_w), requires_grad=False).to(device)

                # Broadcast the targets to the actual batch size (the last batch may be smaller).
                valid = valid_prototype.expand(real_sar.size(0), -1, -1, -1)
                fake = fake_prototype.expand(real_sar.size(0), -1, -1, -1)

                # ------------------
                # Train Generators
                # ------------------
                optimizer_G.zero_grad()

                with torch.cuda.amp.autocast():

                    # Adversarial terms: each generator tries to make its discriminator say "valid".
                    fake_optical = G_SAR_to_Optical(real_sar)
                    loss_GAN_SAR2Optical = criterion_GAN(D_Optical(fake_optical), valid)

                    fake_sar = G_Optical_to_SAR(real_optical)
                    loss_GAN_Optical2SAR = criterion_GAN(D_SAR(fake_sar), valid)

                    # Identity terms: a generator fed an image of its target domain should return it unchanged.
                    loss_id_optical = criterion_identity(G_SAR_to_Optical(real_optical), real_optical)
                    loss_id_sar = criterion_identity(G_Optical_to_SAR(real_sar), real_sar)
                    loss_identity = (loss_id_optical + loss_id_sar) * LAMBDA_IDENTITY

                    loss_GAN = loss_GAN_SAR2Optical + loss_GAN_Optical2SAR

                    # Cycle terms: translating there and back should recover the input.
                    recovered_sar = G_Optical_to_SAR(fake_optical)
                    loss_cycle_sar = criterion_cycle(recovered_sar, real_sar)

                    recovered_optical = G_SAR_to_Optical(fake_sar)
                    loss_cycle_optical = criterion_cycle(recovered_optical, real_optical)

                    loss_cycle = (loss_cycle_sar + loss_cycle_optical) * LAMBDA_CYCLE

                    loss_G = loss_GAN + loss_cycle + loss_identity

                scaler.scale(loss_G).backward()
                scaler.step(optimizer_G)

                # -----------------------
                # Train Discriminator Optical
                # -----------------------
                # zero_grad also discards the gradients the generator backward pass left on D.
                optimizer_D_Optical.zero_grad()

                with torch.cuda.amp.autocast():
                    loss_real = criterion_GAN(D_Optical(real_optical), valid)
                    loss_fake = criterion_GAN(D_Optical(fake_optical.detach()), fake)
                    loss_D_optical = (loss_real + loss_fake) * 0.5

                scaler.scale(loss_D_optical).backward()
                scaler.step(optimizer_D_Optical)

                # -----------------------
                # Train Discriminator SAR
                # -----------------------
                optimizer_D_SAR.zero_grad()

                with torch.cuda.amp.autocast():
                    loss_real = criterion_GAN(D_SAR(real_sar), valid)
                    loss_fake = criterion_GAN(D_SAR(fake_sar.detach()), fake)
                    loss_D_sar = (loss_real + loss_fake) * 0.5

                scaler.scale(loss_D_sar).backward()
                scaler.step(optimizer_D_SAR)
                # One update per iteration, after every optimizer sharing the scaler has stepped.
                scaler.update()

                epoch_g_loss += loss_G.item()
                epoch_d_loss += 0.5 * (loss_D_optical.item() + loss_D_sar.item())
                epoch_cycle_loss += loss_cycle.item()
                epoch_identity_loss += loss_identity.item()

                if (i+1) % display_interval == 0:
                    print(f"\nEpoch {epoch+1}/{epochs}, Iteration {i+1}/{len(train_dataloader)}")
                    print(f"Generator Loss: {loss_G.item():.4f}, Discriminator Loss: {0.5 * (loss_D_optical.item() + loss_D_sar.item()):.4f}")
                    print(f"Cycle Loss: {loss_cycle.item():.4f}, Identity Loss: {loss_identity.item():.4f}")

                    with torch.no_grad():
                        num_samples = min(5, real_sar.size(0))
                        sample_sar = real_sar[:num_samples]
                        sample_optical = real_optical[:num_samples]
                        sample_fake_optical = G_SAR_to_Optical(sample_sar)

                        sample_path = os.path.join(samples_dir, f"epoch_{epoch+1}_batch_{i}.png")
                        visualize_results(
                            (sample_sar, sample_fake_optical, sample_optical),
                            save_path=sample_path,
                            nrow=num_samples
                        )
                        mlflow.log_artifact(sample_path, artifact_path="samples")

            num_batches = len(train_dataloader)
            avg_g_loss = epoch_g_loss / num_batches
            avg_d_loss = epoch_d_loss / num_batches
            avg_cycle_loss = epoch_cycle_loss / num_batches
            avg_identity_loss = epoch_identity_loss / num_batches

            history['g_losses'].append(avg_g_loss)
            history['d_losses'].append(avg_d_loss)
            history['cycle_losses'].append(avg_cycle_loss)
            history['identity_losses'].append(avg_identity_loss)

            # ---------------- Validation ----------------
            G_SAR_to_Optical.eval()
            G_Optical_to_SAR.eval()

            # Render the fixed SAR sample set (loaded once, reused every epoch).
            if sar_images is None:
                sar_images = load_sar_images(sample_sar_dir)
            save_generated_images(epoch, G_SAR_to_Optical, sar_images, device, base_dir=output_dir)

            val_psnr = []
            val_ssim = []
            val_loss_identities = []
            with torch.no_grad():
                for batch in tqdm(val_dataloader, desc="Validation"):
                    real_sar = batch['SAR'].to(device)
                    real_optical = batch['Optical'].to(device)
                    batch_size = real_sar.size(0)

                    fake_optical = G_SAR_to_Optical(real_sar)

                    # The identity loss is a whole-batch quantity: compute it once per batch
                    # (it was recomputed, with two generator passes, for every sample).
                    val_loss_id_optical = criterion_identity(G_SAR_to_Optical(real_optical), real_optical)
                    val_loss_id_sar = criterion_identity(G_Optical_to_SAR(real_sar), real_sar)
                    val_loss_identity = ((val_loss_id_optical + val_loss_id_sar) * LAMBDA_IDENTITY).item()
                    # One entry per sample keeps the epoch average weighted by batch size, as before.
                    val_loss_identities.extend([val_loss_identity] * batch_size)

                    # (N, C, H, W) in [-1, 1]  ->  (N, H, W, C) in [0, 1]
                    real_optical_np = np.clip((real_optical.cpu().numpy().transpose(0, 2, 3, 1) + 1) / 2, 0, 1)
                    fake_optical_np = np.clip((fake_optical.cpu().numpy().transpose(0, 2, 3, 1) + 1) / 2, 0, 1)

                    for j in range(batch_size):
                        real_np = real_optical_np[j]
                        fake_np = fake_optical_np[j]

                        psnr_value = psnr(real_np, fake_np, data_range=1.0)
                        # SSIM per colour channel, then averaged
                        ssim_value = np.mean([ssim(real_np[:,:,c], fake_np[:,:,c], data_range=1.0) for c in range(3)])
                        val_psnr.append(psnr_value)
                        val_ssim.append(ssim_value)

            avg_val_psnr = np.mean(val_psnr)
            avg_val_ssim = np.mean(val_ssim)
            avg_val_identity = float(np.mean(val_loss_identities))

            history['val_psnr'].append(avg_val_psnr)
            history['val_ssim'].append(avg_val_ssim)
            history['val_identity_losses'].append(avg_val_identity)

            print(f"\nEpoch {epoch+1}/{epochs} completed")
            print(f"Training losses - Generator: {avg_g_loss:.4f}, Discriminator: {avg_d_loss:.4f}")
            print(f"Cycle Loss: {avg_cycle_loss:.4f}, Identity Loss: {avg_identity_loss:.4f}")
            print(f"Validation PSNR: {avg_val_psnr:.4f}, SSIM: {avg_val_ssim:.4f}")

            # Sample grid from the first validation batch
            val_batch = next(iter(val_dataloader))
            with torch.no_grad():
                real_sar = val_batch['SAR'].to(device)[:5]
                real_optical = val_batch['Optical'].to(device)[:5]
                fake_optical = G_SAR_to_Optical(real_sar)

                val_sample_path = os.path.join(samples_dir, f"val_epoch_{epoch+1}.png")
                visualize_results(
                    (real_sar, fake_optical, real_optical),
                    save_path=val_sample_path,
                    nrow=real_sar.size(0)
                )
                mlflow.log_artifact(val_sample_path, artifact_path="validation_samples")

            # --- Save checkpoint ---
            model_and_optimizer_state = {
                'epoch': epoch + 1,
                'G_SAR_to_Optical': G_SAR_to_Optical.state_dict(),
                'G_Optical_to_SAR': G_Optical_to_SAR.state_dict(),
                'D_Optical': D_Optical.state_dict(),
                'D_SAR': D_SAR.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D_Optical': optimizer_D_Optical.state_dict(),
                'optimizer_D_SAR': optimizer_D_SAR.state_dict(),
            }
            torch.save({
                **model_and_optimizer_state,
                'scaler': scaler.state_dict(),
                'scheduler_G': scheduler_G.state_dict(),
                'scheduler_D_Optical': scheduler_D_Optical.state_dict(),
                'scheduler_D_SAR': scheduler_D_SAR.state_dict(),
                'history': history,
            }, latest_checkpoint_path)

            print(f"💾 Epoch {epoch+1} complete. Checkpoint saved.")

            # --- Save best model (selected on the average training cycle loss) ---
            if avg_cycle_loss < best_cycle_loss:
                best_cycle_loss = avg_cycle_loss

                torch.save({
                    **model_and_optimizer_state,
                    'history': history,
                }, best_model_path)

                print(f"✓ Saved best model with cycle loss: {best_cycle_loss:.4f}")

            # --- MLflow logging ---
            mlflow.log_metrics({
                "train_g_loss": avg_g_loss,
                "train_d_loss": avg_d_loss,
                "cycle_loss": avg_cycle_loss,
                "identity_loss": avg_identity_loss,
                "val_psnr": avg_val_psnr,
                "val_ssim": avg_val_ssim,
                "val_identity_losses": avg_val_identity,
                "learning_rate": optimizer_G.param_groups[0]['lr'],
            }, step=epoch)

            scheduler_G.step()
            scheduler_D_Optical.step()
            scheduler_D_SAR.step()

        # Keep the metric plots with the rest of the run's outputs.
        plot_metrics(history, output_dir)

        # Return (and log) the best weights. The file is absent when no epoch was run
        # or no finite cycle loss was ever recorded; fall back to the current weights then.
        if os.path.exists(best_model_path):
            checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
            G_SAR_to_Optical.load_state_dict(checkpoint['G_SAR_to_Optical'])
            G_Optical_to_SAR.load_state_dict(checkpoint['G_Optical_to_SAR'])
        else:
            print(f"Warning: no best model found at {best_model_path}; returning current weights")

        mlflow.pytorch.log_model(G_SAR_to_Optical, "G_SAR_to_Optical_final")
        mlflow.pytorch.log_model(G_Optical_to_SAR, "G_Optical_to_SAR_final")

        return G_SAR_to_Optical, G_Optical_to_SAR, history


In [17]:
def evaluate_model(model, test_dataloader, output_dir):
    """Score a SAR -> optical generator on a dataloader and save comparison images.

    Every sample of every batch is evaluated (not only the first element of each
    batch), so the function gives per-sample results for any ``batch_size``.
    With ``batch_size=1`` the saved files, printed lines and returned values are
    the same as before.

    Args:
        model: Generator called as ``model(real_sar)``; it is switched to eval
            mode and left in eval mode when the function returns.
        test_dataloader: Iterable yielding dicts with ``'SAR'`` and ``'Optical'``
            tensors laid out as (N, C, H, W) with values in [-1, 1].
        output_dir: Directory (created if missing) receiving one PNG per sample
            showing SAR input, generated optical and real optical side by side.

    Returns:
        Tuple ``(psnr_values, ssim_values)``: two lists with one entry per
        evaluated sample. Both are empty when the dataloader yields nothing.

    Note:
        Relies on the notebook-level ``device`` for tensor placement.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Lists to store metrics (one entry per evaluated sample)
    psnr_values = []
    ssim_values = []

    # Running sample index across batches; equals the batch index when batch_size == 1.
    sample_idx = 0

    model.eval()
    with torch.no_grad():
        for batch in test_dataloader:
            # Set model input
            real_sar = batch['SAR'].to(device)
            real_optical = batch['Optical'].to(device)

            # Generate optical image from SAR
            fake_optical = model(real_sar)

            # (N, C, H, W) -> (N, H, W, C), denormalized from [-1, 1] to [0, 1] and clipped
            real_batch_np = np.clip((real_optical.cpu().numpy().transpose(0, 2, 3, 1) + 1) / 2, 0, 1)
            fake_batch_np = np.clip((fake_optical.cpu().numpy().transpose(0, 2, 3, 1) + 1) / 2, 0, 1)

            for j in range(real_batch_np.shape[0]):
                real_optical_np = real_batch_np[j]
                fake_optical_np = fake_batch_np[j]

                # Calculate PSNR
                psnr_value = psnr(real_optical_np, fake_optical_np, data_range=1.0)
                psnr_values.append(psnr_value)

                # Calculate SSIM per channel and average; the channel count is taken
                # from the data instead of being fixed at 3.
                ssim_value = np.mean([
                    ssim(real_optical_np[:, :, c], fake_optical_np[:, :, c], data_range=1.0)
                    for c in range(real_optical_np.shape[-1])
                ])
                ssim_values.append(ssim_value)

                # Save a 3-image row for this sample: SAR input | generated | ground truth
                save_image(torch.cat((real_sar[j:j + 1], fake_optical[j:j + 1], real_optical[j:j + 1]), 0),
                          os.path.join(output_dir, f"test_sample_{sample_idx}_psnr_{psnr_value:.4f}_ssim_{ssim_value:.4f}.png"),
                          nrow=3, normalize=True)

                print(f"Test sample {sample_idx}: PSNR: {psnr_value:.4f}, SSIM: {ssim_value:.4f}")
                sample_idx += 1

    # Nothing evaluated: skip the averages instead of taking the mean of an empty list.
    if not psnr_values:
        print("No test samples were evaluated; average metrics are undefined.")
        return psnr_values, ssim_values

    # Calculate average metrics
    avg_psnr = np.mean(psnr_values)
    avg_ssim = np.mean(ssim_values)

    print(f"Average PSNR: {avg_psnr:.4f}")
    print(f"Average SSIM: {avg_ssim:.4f}")

    return psnr_values, ssim_values

In [18]:
# Define directories
root_dir = ""
# Results root. Set the SAR2OPT_OUTPUT_DIR environment variable to run the notebook on
# another machine; when it is unset or empty the original location is used.
output_dir = os.environ.get("SAR2OPT_OUTPUT_DIR") or "/home/bhargav/Documents/SAR2OPT/results"
# Per-sample test images are written to a sub-folder of the results root.
test_output_dir = os.path.join(output_dir, "test_results")

In [19]:
# Create transforms: resize to IMG_SIZE x IMG_SIZE (nearest-neighbour), convert to a
# tensor, then normalize every channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation = transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])

# Dataset root holding the train/val/test sub-folders. Set the SAR2OPT_DATA_DIR
# environment variable to point at another copy; when it is unset or empty the
# original location is used.
# NOTE(review): SAROpticalDataset decides isTrain by checking for "train" in the path,
# so a custom root containing "train" would mark every split as training data.
dataset_root = os.environ.get("SAR2OPT_DATA_DIR") or "/home/bhargav/Documents/SAR2OPT/OS-dataset512"

train_dataset = SAROpticalDataset(root_dir=os.path.join(dataset_root, "train"), transform=transform)
val_dataset = SAROpticalDataset(root_dir=os.path.join(dataset_root, "val"), transform=transform)
test_dataset = SAROpticalDataset(root_dir=os.path.join(dataset_root, "test"), transform=transform)
print(f"Dataset splits: Train={len(train_dataset)}, Validation={len(val_dataset)}, Test={len(test_dataset)}")

Total: Found 2011 valid image pairs
Total: Found 238 valid image pairs
Total: Found 424 valid image pairs
Dataset splits: Train=2011, Validation=238, Test=424


In [20]:
# from torch.utils.data import Subset
# import random

# def get_subset(dataset, fraction=0.1):
#     total_len = len(dataset)
#     subset_len = int(total_len * fraction)
#     indices = random.sample(range(total_len), subset_len)
#     return Subset(dataset, indices)

In [21]:
# # Create 1% subsets (fraction=0.01)
# train_dataset = get_subset(train_dataset, fraction=0.01)
# val_dataset = get_subset(val_dataset, fraction=0.01)
# test_dataset = get_subset(test_dataset, fraction=0.01)

In [22]:
# Train and validation loaders share the same settings (shuffled, BATCH_SIZE per batch);
# the test loader yields one un-shuffled sample at a time.
train_dataloader, val_dataloader = (
    DataLoader(split_dataset, num_workers=4, shuffle=True, batch_size=BATCH_SIZE)
    for split_dataset in (train_dataset, val_dataset)
)
test_dataloader = DataLoader(
    test_dataset,
    num_workers=4,
    shuffle=False,
    batch_size=1,
)

In [23]:
# Initialise the DagsHub integration for the repository below with MLflow enabled.
dagshub.init(
    repo_name='sar-to-opt',
    repo_owner='AyushprojectDR',
    mlflow=True,
)

In [None]:
# Run training on the train/validation loaders, writing under output_dir; keep the
# three returned objects for the evaluation cell.
(
    G_SAR_to_Optical,
    G_Optical_to_SAR,
    history,
) = train_cyclegan(train_dataloader, val_dataloader, output_dir)

  scaler = GradScaler()


✅ Resuming from epoch 8
Starting training for 500 epochs...
Logging prev data
{'g_losses': [6.817451230574318, 4.74901060318615, 4.446148287941874, 4.187768130842781, 4.003909141832507, 3.621926189179923, 3.6597083501265963], 'd_losses': [0.7982615522936609, 0.28684080832047204, 0.27633709896718767, 0.28020240516062994, 0.28022861778439867, 0.2529945527582946, 0.2420204299433568], 'cycle_losses': [4.847271245232159, 3.7013457570350905, 3.3868415715917206, 3.1049892926073928, 2.9046397014830267, 2.6592525381217182, 2.5685211680755464], 'identity_losses': [0.26737651593287, 0.20316651291686072, 0.1928616942723039, 0.1830502948010892, 0.17704118653450523, 0.169231021395379, 0.16510844233614314], 'val_psnr': [12.846827543874495, 13.136957311611814, 11.99291059993248, 12.69010044500092, 12.320546570976841, 12.622140782787364, 12.235689095036587], 'val_ssim': [0.18393463071240326, 0.16490674239459135, 0.1381940786778785, 0.1629804899127037, 0.1350909610741714, 0.1703723141779022, 0.162302985

Epoch 8/500:   0%|          | 0/503 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():



Epoch 8/500, Iteration 503/503
Generator Loss: 4.3354, Discriminator Loss: 0.3002
Cycle Loss: 2.4950, Identity Loss: 0.1945


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 8/500 completed
Training losses - Generator: 3.6126, Discriminator: 0.2393
Cycle Loss: 2.4665, Identity Loss: 0.1608
Validation PSNR: 12.4671, SSIM: 0.1508
💾 Epoch 8 complete. Checkpoint saved.
✓ Saved best model with cycle loss: 2.4665


Epoch 9/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 9/500, Iteration 503/503
Generator Loss: 3.7642, Discriminator Loss: 0.2952
Cycle Loss: 2.2033, Identity Loss: 0.1546


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 9/500 completed
Training losses - Generator: 3.5237, Discriminator: 0.2179
Cycle Loss: 2.3875, Identity Loss: 0.1634
Validation PSNR: 12.4887, SSIM: 0.1525
💾 Epoch 9 complete. Checkpoint saved.
✓ Saved best model with cycle loss: 2.3875


Epoch 10/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 10/500, Iteration 503/503
Generator Loss: 3.5945, Discriminator Loss: 0.1949
Cycle Loss: 2.6273, Identity Loss: 0.1550


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 10/500 completed
Training losses - Generator: 3.7548, Discriminator: 0.2077
Cycle Loss: 2.4937, Identity Loss: 0.1645
Validation PSNR: 12.1545, SSIM: 0.1477
💾 Epoch 10 complete. Checkpoint saved.


Epoch 11/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 11/500, Iteration 503/503
Generator Loss: 3.5372, Discriminator Loss: 0.1899
Cycle Loss: 2.3867, Identity Loss: 0.1555


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 11/500 completed
Training losses - Generator: 3.4551, Discriminator: 0.2585
Cycle Loss: 2.2935, Identity Loss: 0.1586
Validation PSNR: 12.5911, SSIM: 0.1478
💾 Epoch 11 complete. Checkpoint saved.
✓ Saved best model with cycle loss: 2.2935


Epoch 12/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 12/500, Iteration 503/503
Generator Loss: 4.5781, Discriminator Loss: 0.1730
Cycle Loss: 2.6647, Identity Loss: 0.1784


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 12/500 completed
Training losses - Generator: 3.5313, Discriminator: 0.1971
Cycle Loss: 2.3247, Identity Loss: 0.1621
Validation PSNR: 11.8535, SSIM: 0.1354
💾 Epoch 12 complete. Checkpoint saved.


Epoch 13/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 13/500, Iteration 503/503
Generator Loss: 3.5818, Discriminator Loss: 0.2342
Cycle Loss: 2.5846, Identity Loss: 0.1396


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 13/500 completed
Training losses - Generator: 3.6388, Discriminator: 0.3374
Cycle Loss: 2.3313, Identity Loss: 0.1599
Validation PSNR: 11.9425, SSIM: 0.1206
💾 Epoch 13 complete. Checkpoint saved.


Epoch 14/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 14/500, Iteration 503/503
Generator Loss: 2.7324, Discriminator Loss: 0.3432
Cycle Loss: 1.8824, Identity Loss: 0.1874


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 14/500 completed
Training losses - Generator: 2.8321, Discriminator: 0.2622
Cycle Loss: 2.0809, Identity Loss: 0.1491
Validation PSNR: 12.3014, SSIM: 0.1503
💾 Epoch 14 complete. Checkpoint saved.
✓ Saved best model with cycle loss: 2.0809


Epoch 15/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 15/500, Iteration 503/503
Generator Loss: 3.4138, Discriminator Loss: 0.4813
Cycle Loss: 1.7106, Identity Loss: 0.1113


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 15/500 completed
Training losses - Generator: 2.7326, Discriminator: 0.2692
Cycle Loss: 1.9368, Identity Loss: 0.1480
Validation PSNR: 12.2304, SSIM: 0.1602
💾 Epoch 15 complete. Checkpoint saved.
✓ Saved best model with cycle loss: 1.9368


Epoch 16/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 16/500, Iteration 503/503
Generator Loss: 2.5795, Discriminator Loss: 0.2108
Cycle Loss: 1.8699, Identity Loss: 0.1434


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 16/500 completed
Training losses - Generator: 2.8706, Discriminator: 0.2562
Cycle Loss: 2.0145, Identity Loss: 0.1486
Validation PSNR: 12.4184, SSIM: 0.1699
💾 Epoch 16 complete. Checkpoint saved.


Epoch 17/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 17/500, Iteration 503/503
Generator Loss: 3.1755, Discriminator Loss: 0.1439
Cycle Loss: 1.8512, Identity Loss: 0.1737


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 17/500 completed
Training losses - Generator: 3.1702, Discriminator: 0.2180
Cycle Loss: 2.0607, Identity Loss: 0.1512
Validation PSNR: 12.0796, SSIM: 0.1600
💾 Epoch 17 complete. Checkpoint saved.


Epoch 18/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 18/500, Iteration 503/503
Generator Loss: 3.2718, Discriminator Loss: 0.2839
Cycle Loss: 2.1712, Identity Loss: 0.0883


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 18/500 completed
Training losses - Generator: 3.0603, Discriminator: 0.2318
Cycle Loss: 2.0574, Identity Loss: 0.1529
Validation PSNR: 13.0274, SSIM: 0.2052
💾 Epoch 18 complete. Checkpoint saved.


Epoch 19/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 19/500, Iteration 503/503
Generator Loss: 3.0777, Discriminator Loss: 0.2194
Cycle Loss: 2.1977, Identity Loss: 0.1605


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 19/500 completed
Training losses - Generator: 2.8026, Discriminator: 0.2731
Cycle Loss: 1.9350, Identity Loss: 0.1473
Validation PSNR: 12.3746, SSIM: 0.1536
💾 Epoch 19 complete. Checkpoint saved.
✓ Saved best model with cycle loss: 1.9350


Epoch 20/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 20/500, Iteration 503/503
Generator Loss: 2.8173, Discriminator Loss: 0.2253
Cycle Loss: 2.1094, Identity Loss: 0.0870


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 20/500 completed
Training losses - Generator: 2.7568, Discriminator: 0.2498
Cycle Loss: 1.9315, Identity Loss: 0.1437
Validation PSNR: 12.7718, SSIM: 0.1957
💾 Epoch 20 complete. Checkpoint saved.
✓ Saved best model with cycle loss: 1.9315


Epoch 21/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 21/500, Iteration 503/503
Generator Loss: 2.3827, Discriminator Loss: 0.2320
Cycle Loss: 1.5112, Identity Loss: 0.1402


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 21/500 completed
Training losses - Generator: 2.6684, Discriminator: 0.2561
Cycle Loss: 1.8286, Identity Loss: 0.1416
Validation PSNR: 12.7179, SSIM: 0.1941
💾 Epoch 21 complete. Checkpoint saved.
✓ Saved best model with cycle loss: 1.8286


Epoch 22/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 22/500, Iteration 503/503
Generator Loss: 3.1238, Discriminator Loss: 0.3493
Cycle Loss: 2.0157, Identity Loss: 0.1412


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 22/500 completed
Training losses - Generator: 2.7000, Discriminator: 0.2503
Cycle Loss: 1.8429, Identity Loss: 0.1421
Validation PSNR: 12.8187, SSIM: 0.2038
💾 Epoch 22 complete. Checkpoint saved.


Epoch 23/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 23/500, Iteration 503/503
Generator Loss: 2.6788, Discriminator Loss: 0.1448
Cycle Loss: 1.7310, Identity Loss: 0.1172


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 23/500 completed
Training losses - Generator: 2.7020, Discriminator: 0.2455
Cycle Loss: 1.8193, Identity Loss: 0.1400
Validation PSNR: 12.8383, SSIM: 0.1892
💾 Epoch 23 complete. Checkpoint saved.
✓ Saved best model with cycle loss: 1.8193


Epoch 24/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 24/500, Iteration 503/503
Generator Loss: 3.8619, Discriminator Loss: 0.2281
Cycle Loss: 2.7725, Identity Loss: 0.1732


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 24/500 completed
Training losses - Generator: 2.8031, Discriminator: 0.2384
Cycle Loss: 1.8306, Identity Loss: 0.1361
Validation PSNR: 13.3728, SSIM: 0.2415
💾 Epoch 24 complete. Checkpoint saved.


Epoch 25/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 25/500, Iteration 503/503
Generator Loss: 2.6634, Discriminator Loss: 0.2108
Cycle Loss: 1.8856, Identity Loss: 0.1229


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 25/500 completed
Training losses - Generator: 2.8300, Discriminator: 0.2201
Cycle Loss: 1.7695, Identity Loss: 0.1356
Validation PSNR: 12.2766, SSIM: 0.1536
💾 Epoch 25 complete. Checkpoint saved.
✓ Saved best model with cycle loss: 1.7695


Epoch 26/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 26/500, Iteration 503/503
Generator Loss: 2.9496, Discriminator Loss: 0.2478
Cycle Loss: 2.1760, Identity Loss: 0.1499


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 26/500 completed
Training losses - Generator: 2.7721, Discriminator: 0.2229
Cycle Loss: 1.7564, Identity Loss: 0.1376
Validation PSNR: 12.7884, SSIM: 0.2016
💾 Epoch 26 complete. Checkpoint saved.
✓ Saved best model with cycle loss: 1.7564


Epoch 27/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 27/500, Iteration 503/503
Generator Loss: 2.6398, Discriminator Loss: 0.1992
Cycle Loss: 1.8197, Identity Loss: 0.1467


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 27/500 completed
Training losses - Generator: 2.8373, Discriminator: 0.2063
Cycle Loss: 1.7419, Identity Loss: 0.1371
Validation PSNR: 12.1034, SSIM: 0.1777
💾 Epoch 27 complete. Checkpoint saved.
✓ Saved best model with cycle loss: 1.7419


Epoch 28/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 28/500, Iteration 503/503
Generator Loss: 2.9808, Discriminator Loss: 0.1857
Cycle Loss: 2.2242, Identity Loss: 0.1436


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 28/500 completed
Training losses - Generator: 2.8159, Discriminator: 0.2094
Cycle Loss: 1.7474, Identity Loss: 0.1332
Validation PSNR: 12.6541, SSIM: 0.2168
💾 Epoch 28 complete. Checkpoint saved.


Epoch 29/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 29/500, Iteration 503/503
Generator Loss: 2.4905, Discriminator Loss: 0.1448
Cycle Loss: 1.4797, Identity Loss: 0.1252


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 29/500 completed
Training losses - Generator: 2.7524, Discriminator: 0.2092
Cycle Loss: 1.6759, Identity Loss: 0.1303
Validation PSNR: 12.9647, SSIM: 0.1977
💾 Epoch 29 complete. Checkpoint saved.
✓ Saved best model with cycle loss: 1.6759


Epoch 30/500:   0%|          | 0/503 [00:00<?, ?it/s]


Epoch 30/500, Iteration 503/503
Generator Loss: 2.5711, Discriminator Loss: 0.2553
Cycle Loss: 1.6043, Identity Loss: 0.1194


Validation:   0%|          | 0/60 [00:00<?, ?it/s]


Epoch 30/500 completed
Training losses - Generator: 2.8725, Discriminator: 0.1962
Cycle Loss: 1.7416, Identity Loss: 0.1327
Validation PSNR: 12.4392, SSIM: 0.1728
💾 Epoch 30 complete. Checkpoint saved.


Epoch 31/500:   0%|          | 0/503 [00:00<?, ?it/s]

In [None]:
# Score the SAR -> optical generator on the test set (per-sample PSNR / SSIM).
psnr_values, ssim_values = evaluate_model(G_SAR_to_Optical, test_dataloader, test_output_dir)

# Plot evaluation metrics: one bar chart per metric, saved under output_dir as
# "<metric>_values.png" and closed afterwards so no figure is left open.
for metric_label, metric_values in (("PSNR", psnr_values), ("SSIM", ssim_values)):
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.set_title(f"{metric_label} Values for Test Samples")
    ax.bar(range(len(metric_values)), metric_values)
    ax.set_xlabel("Test Sample")
    ax.set_ylabel(metric_label)
    fig.savefig(os.path.join(output_dir, f"{metric_label.lower()}_values.png"))
    plt.close(fig)