In [1]:
!unzip train3.zip -d train1

Archive:  train3.zip
   creating: train1/train3/
  inflating: train1/train3/.DS_Store  
  inflating: train1/__MACOSX/train3/._.DS_Store  
   creating: train1/train3/input/
   creating: train1/train3/output/
  inflating: train1/train3/input/.DS_Store  
  inflating: train1/__MACOSX/train3/input/._.DS_Store  
  inflating: train1/train3/input/s1.jpg  
  inflating: train1/__MACOSX/train3/input/._s1.jpg  
  inflating: train1/train3/input/s3.jpg  
  inflating: train1/__MACOSX/train3/input/._s3.jpg  
  inflating: train1/train3/input/s2.jpg  
  inflating: train1/__MACOSX/train3/input/._s2.jpg  
  inflating: train1/train3/output/.DS_Store  
  inflating: train1/__MACOSX/train3/output/._.DS_Store  
  inflating: train1/train3/output/o2.jpg  
  inflating: train1/__MACOSX/train3/output/._o2.jpg  
  inflating: train1/train3/output/o3.jpg  
  inflating: train1/__MACOSX/train3/output/._o3.jpg  
  inflating: train1/train3/output/o1.jpg  
  inflating: train1/__MACOSX/train3/output/._o1.jpg  


In [19]:
from torch.utils.data import Dataset
import torch
from PIL import Image
import os
import torchvision.transforms as T

class PairedSAROpticalDataset(Dataset):
    def __init__(self, input_dir, output_dir, transform=None):
        self.input_dir = input_dir
        self.output_dir = output_dir

        # Only allow valid image extensions
        valid_extensions = ['.jpg', '.jpeg', '.png', '.tif', '.tiff', '.bmp']
        self.input_files = sorted([
            f for f in os.listdir(input_dir) if os.path.splitext(f)[1].lower() in valid_extensions
        ])
        self.output_files = sorted([
            f for f in os.listdir(output_dir) if os.path.splitext(f)[1].lower() in valid_extensions
        ])

        # Resize to (height, width) = (1733, 2500)
        self.transform = T.Compose([
            T.Resize((1728, 2496)),  # (H, W)
            T.ToTensor(),
            T.Normalize((0.5,), (0.5,))  # Normalize for SAR and Optical
        ])

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        sar_path = os.path.join(self.input_dir, self.input_files[idx])
        opt_path = os.path.join(self.output_dir, self.output_files[idx])

        sar = Image.open(sar_path).convert('L')     # SAR: Grayscale
        opt = Image.open(opt_path).convert('RGB')   # Optical: RGB

        sar = self.transform(sar)
        opt = self.transform(opt)

        return sar, opt


In [20]:
# Simplified UNet generator (input: SAR, output: Optical)
from torchvision.models import resnet18
import torch.nn as nn

class UNetGenerator(nn.Module):
    def __init__(self, in_channels=1, out_channels=3, features=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features, features*2, 4, 2, 1),
            nn.BatchNorm2d(features*2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features*2, features*4, 4, 2, 1),
            nn.BatchNorm2d(features*4),
            nn.LeakyReLU(0.2),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(features*4, features*2, 4, 2, 1),
            nn.BatchNorm2d(features*2),
            nn.ReLU(),
            nn.ConvTranspose2d(features*2, features, 4, 2, 1),
            nn.BatchNorm2d(features),
            nn.ReLU(),
            nn.ConvTranspose2d(features, out_channels, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


In [21]:
class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels=4, features=64):  # SAR (1) + Optical (3)
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features, features*2, 4, 2, 1),
            nn.BatchNorm2d(features*2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features*2, features*4, 4, 2, 1),
            nn.BatchNorm2d(features*4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features*4, 1, 4, 1, 1)
        )

    def forward(self, x, y):
        input = torch.cat([x, y], dim=1)  # concatenate along channel dim
        return self.model(input)


In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import os
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import piq

# --- Initialize Models ---
generator = UNetGenerator().to(device)
discriminator = PatchDiscriminator().to(device)

# --- Losses ---
loss_gan = nn.BCEWithLogitsLoss()
loss_l1 = nn.L1Loss()

# --- Optimizers ---
opt_g = optim.Adam(generator.parameters(), lr=9e-4, betas=(0.5, 0.999))
opt_d = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# --- Learning Rate Schedulers (Cosine Annealing with Warm Restarts) ---
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
scheduler_g = CosineAnnealingWarmRestarts(opt_g, T_0=10, T_mult=2)
scheduler_d = CosineAnnealingWarmRestarts(opt_d, T_0=10, T_mult=2)

# --- Dataset ---
dataloader = torch.utils.data.DataLoader(
    PairedSAROpticalDataset('train1/train3/input', 'train1/train3/output'),
    batch_size=8, shuffle=True
)

# --- Create directories ---
os.makedirs("samples", exist_ok=True)
os.makedirs("curves", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)

# --- Loss Trackers ---
g_losses = []
d_losses = []

# --- Training Loop ---
for epoch in range(1500):
    generator.train()
    discriminator.train()

    total_g_loss = 0
    total_d_loss = 0

    for i, (sar, opt_real) in enumerate(dataloader):
        sar, opt_real = sar.to(device), opt_real.to(device)

        # === Train Discriminator ===
        with torch.no_grad():
            fake_opt = generator(sar).detach()
        real_pred = discriminator(sar, opt_real)
        fake_pred = discriminator(sar, fake_opt)

        d_loss_real = loss_gan(real_pred, torch.ones_like(real_pred))
        d_loss_fake = loss_gan(fake_pred, torch.zeros_like(fake_pred))
        d_loss = (d_loss_real + d_loss_fake) / 2

        opt_d.zero_grad()
        d_loss.backward()
        opt_d.step()

        # === Train Generator ===
        fake_opt = generator(sar)
        pred_fake = discriminator(sar, fake_opt)
        adv_loss = loss_gan(pred_fake, torch.ones_like(pred_fake))

        # Normalize to [0, 1] for SSIM
        fake_opt_norm = (fake_opt + 1) / 2
        opt_real_norm = (opt_real + 1) / 2

        ssim_loss = 1 - piq.ssim(fake_opt_norm, opt_real_norm, data_range=1.0)
        l1_loss = loss_l1(fake_opt, opt_real)

        g_loss = adv_loss + 100 * ssim_loss + 10 * l1_loss

        opt_g.zero_grad()
        g_loss.backward()
        opt_g.step()

        # Warm Restart Scheduler Step (per batch)
        scheduler_g.step(epoch + i / len(dataloader))
        scheduler_d.step(epoch + i / len(dataloader))

        total_d_loss += d_loss.item()
        total_g_loss += g_loss.item()

        if i % 100 == 0:
            print(f"[Epoch {epoch:03d}] Batch {i:03d} - D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

        # Save sample from the 1st batch of each epoch
        if i == 0:
            generator.eval()
            with torch.no_grad():
                sample_sar = sar[:3]
                sample_real = opt_real[:3]
                sample_fake = generator(sample_sar)

                sample_sar_3ch = sample_sar.repeat(1, 3, 1, 1)
                combined = torch.cat([sample_sar_3ch, sample_real, sample_fake], dim=0)

                grid = vutils.make_grid(combined, nrow=4, normalize=True, scale_each=True)
                save_image(grid, f"samples/epoch_{epoch:03d}.png")

    # === Store average losses ===
    avg_g_loss = total_g_loss / len(dataloader)
    avg_d_loss = total_d_loss / len(dataloader)
    g_losses.append(avg_g_loss)
    d_losses.append(avg_d_loss)

    # === Save model checkpoint ===
    torch.save(generator.state_dict(), f"checkpoints/generator_epoch_{epoch:03d}.pth")
    torch.save(discriminator.state_dict(), f"checkpoints/discriminator_epoch_{epoch:03d}.pth")

# --- Plot Loss Curves ---
plt.figure()
plt.plot(g_losses, label="Generator Loss")
plt.plot(d_losses, label="Discriminator Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Curves")
plt.legend()
plt.grid(True)
plt.savefig("curves/loss_curve.png")
plt.close()


# Save final models
torch.save(generator.state_dict(), 'generator_final.pth')
torch.save(discriminator.state_dict(), 'discriminator_final.pth')



In [None]:
# --- Mount and Upload to Google Drive ---
from google.colab import drive
drive.mount('/content/drive')

drive_samples_path = "/content/drive/MyDrive/sar_to_optical1/samples"
drive_curves_path = "/content/drive/MyDrive/sar_to_optical1/curves"
drive_checkpoints_path = "/content/drive/MyDrive/sar_to_optical1/checkpoints"

for path in [drive_samples_path, drive_curves_path, drive_checkpoints_path]:
    if os.path.exists(path):
        shutil.rmtree(path)

shutil.copytree("samples", drive_samples_path)
shutil.copytree("curves", drive_curves_path)
shutil.copytree("checkpoints", drive_checkpoints_path)
# Also save to Google Drive
shutil.copy('generator_final.pth', '/content/drive/MyDrive/sar_to_optical/generator_final.pth')
shutil.copy('discriminator_final.pth', '/content/drive/MyDrive/sar_to_optical/discriminator_final.pth')

print("✅ Training complete. Samples, curves, and models uploaded to Google Drive.")


In [None]:
from google.colab import drive
drive.mount('/content/drive')

drive_samples_path = "/content/drive/MyDrive/sar_to_optical1/samples"
drive_curves_path = "/content/drive/MyDrive/sar_to_optical1/curves"
drive_checkpoints_path = "/content/drive/MyDrive/sar_to_optical1/checkpoints"

for path in [drive_samples_path, drive_curves_path, drive_checkpoints_path]:
    if os.path.exists(path):
        shutil.rmtree(path)

shutil.copytree("samples", drive_samples_path)
shutil.copytree("curves", drive_curves_path)



print("✅ Training complete. Samples, curves, and models uploaded to Google Drive.")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Training complete. Samples, curves, and models uploaded to Google Drive.


In [None]:
import os
import torch
from torchvision.utils import save_image
import torchvision.utils as vutils

# Set model to eval mode
generator.eval()

# Validation data loader
val_loader = torch.utils.data.DataLoader(
    PairedSAROpticalDataset('val/val/input', 'val/val/output'),
    batch_size=8, shuffle=False
)

# Create folder for validation results
os.makedirs("val_samples", exist_ok=True)

# Disable gradient calculations
with torch.no_grad():
    for i, (sar, opt_real) in enumerate(val_loader):
        sar, opt_real = sar.to(device), opt_real.to(device)
        fake_opt = generator(sar)

        # Convert SAR 1ch → 3ch to make visual comparison better
        sar_rgb = sar.repeat(1, 3, 1, 1)

        # Concatenate SAR, GT, and generated images for visualization
        comparison = torch.cat([sar_rgb, opt_real, fake_opt], dim=0)
        grid = vutils.make_grid(comparison, nrow=8, normalize=True, scale_each=True)
        save_image(grid, f"val_samples/val_batch_{i:03d}.png")

print("✅ Validation outputs saved to val_samples/")


In [None]:
# Save final models
torch.save(generator.state_dict(), 'generator_final.pth')
torch.save(discriminator.state_dict(), 'discriminator_final.pth')

# Also save to Google Drive
shutil.copy('generator_final.pth', '/content/drive/MyDrive/sar_to_optical/generator_final.pth')
shutil.copy('discriminator_final.pth', '/content/drive/MyDrive/sar_to_optical/discriminator_final.pth')


'/content/drive/MyDrive/sar_to_optical/discriminator_final.pth'

In [None]:
!unzip Test_Pairs.zip -d test

Archive:  Test_Pairs.zip
   creating: test/Test_Pairs/Pair-3-Bhopal/
  inflating: test/Test_Pairs/Pair-3-Bhopal/2024-05-20-00 00_2024-05-20-23 59_Sentinel-2_L2A_True_color.jpg  
  inflating: test/Test_Pairs/Pair-3-Bhopal/2024-05-26-00 00_2024-05-26-23 59_Sentinel-1_IW_VV+VH_VH_-_decibel_gamma0.jpg  
   creating: test/Test_Pairs/Pair-4-Kolkata/
  inflating: test/Test_Pairs/Pair-4-Kolkata/2024-05-01-00 00_2024-05-01-23 59_Sentinel-1_IW_VV+VH_VH_-_decibel_gamma0.jpg  
  inflating: test/Test_Pairs/Pair-4-Kolkata/2024-05-03-00 00_2024-05-03-23 59_Sentinel-2_L2A_True_color.jpg  
   creating: test/Test_Pairs/Pair-6-Dehradun/
  inflating: test/Test_Pairs/Pair-6-Dehradun/2024-05-26-00 00_2024-05-26-23 59_Sentinel-1_IW_VV+VH_VH_-_decibel_gamma0.jpg  
  inflating: test/Test_Pairs/Pair-6-Dehradun/2024-05-28-00 00_2024-05-28-23 59_Sentinel-2_L2A_True_color.jpg  


In [None]:
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
import piq
from PIL import Image
from tqdm import tqdm

# === Load Model ===
generator = UNetGenerator().to(device)
generator.load_state_dict(torch.load('generator_final.pth'))
generator.eval()

# === Load and Prepare SAR Image ===
def load_image(path, size=None, grayscale=False):
    img = Image.open(path).convert('L' if grayscale else 'RGB')
    if size: img = img.resize(size)
    return transforms.ToTensor()(img).unsqueeze(0).to(device)  # (1, C, H, W)

# Example usage:
sar_img = load_image("/content/test/Test_Pairs/Pair-3-Bhopal/2024-05-26-00 00_2024-05-26-23 59_Sentinel-1_IW_VV+VH_VH_-_decibel_gamma0.jpg", grayscale=True)  # Shape: (1, 1, 2500, 1800)
opt_img = load_image("/content/test/Test_Pairs/Pair-3-Bhopal/2024-05-20-00 00_2024-05-20-23 59_Sentinel-2_L2A_True_color.jpg")              # Shape: (1, 3, 2500, 1800)

# === Patch-wise Inference ===
def infer_by_patch(model, img_tensor, patch_size=256, stride=256):
    b, c, h, w = img_tensor.shape
    output = torch.zeros((1, 3, h, w)).to(device)
    count_map = torch.zeros((1, 3, h, w)).to(device)

    for i in tqdm(range(0, h - patch_size + 1, stride)):
        for j in range(0, w - patch_size + 1, stride):
            patch = img_tensor[:, :, i:i+patch_size, j:j+patch_size]
            with torch.no_grad():
                out_patch = model(patch)
            output[:, :, i:i+patch_size, j:j+patch_size] += out_patch
            count_map[:, :, i:i+patch_size, j:j+patch_size] += 1

    # Normalize overlapping patches
    output /= count_map
    return output

# Run inference
gen_optical = infer_by_patch(generator, sar_img, patch_size=256, stride=256)

# === Compute SSIM and PSNR ===
# Resize both to ensure dimensions match exactly (in case of rounding)
h, w = opt_img.shape[2:]
gen_optical = F.interpolate(gen_optical, size=(h, w), mode='bilinear', align_corners=False)

# Clamp to valid range [-1, 1] or [0, 1] depending on training
# Ensure output is valid and safe for SSIM
gen_optical = torch.nan_to_num(gen_optical, nan=0.0)
gen_clamped = torch.clamp(gen_optical, 0, 1)
opt_clamped = torch.clamp(opt_img, 0, 1)

# Compute Metrics
ssim_val = piq.ssim(gen_clamped, opt_clamped, data_range=1.0).item()
psnr_val = piq.psnr(gen_clamped, opt_clamped, data_range=1.0).item()




print(f"✅ SSIM: {ssim_val:.4f}, PSNR: {psnr_val:.2f} dB")

# === Save Output Image ===
from torchvision.utils import save_image
save_image(gen_clamped, "generated_output_full.png")


100%|██████████| 6/6 [00:00<00:00, 85.29it/s]


✅ SSIM: 0.0152, PSNR: 5.58 dB


In [None]:
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# --- Load the trained model ---
generator = UNetGenerator().to('cuda')
generator.load_state_dict(torch.load("generator_final.pth"))
generator.eval()

# --- Load and preprocess SAR image ---
sar_image_path = "/content/2.jpg"  # Your SAR image path (256x256)
img = Image.open(sar_image_path).convert("L")  # Ensure single channel (SAR)

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),           # Converts to [0,1]
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

sar_tensor = transform(img).unsqueeze(0).to('cuda')  # Add batch dimension

# --- Generate Optical Image ---
with torch.no_grad():
    output = generator(sar_tensor)

# --- Postprocess to display/save ---
output = (output.squeeze(0).cpu() + 1) / 2  # [-1, 1] → [0, 1]
output_image = transforms.ToPILImage()(output)

# --- Show or Save ---
output_image.show()
output_image.save("output_optical2.png")


In [7]:
pip install piq


Collecting piq
  Downloading piq-0.8.0-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.6.0->torchvision>=0.10.0->piq)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.6.0->torchvision>=0.10.0->piq)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.6.0->torchvision>=0.10.0->piq)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.6.0->torchvision>=0.10.0->piq)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2.6.0->torchvision>=0.10.0->piq)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
C