# SAR2Optical Fine-tuning with LoRA Adapters

This notebook uses **LoRA (Low-Rank Adaptation)** to fine-tune the pretrained SAR2Optical model on QXSLAB_SAROPT.

**Why LoRA?**
- Freezes pretrained weights (preserves learned features)
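- Only trains small adapter layers (about 88K parameters, roughly 0.16% of the ~54M-parameter generator, with the default encoder-only, rank-16 setup)
- Reduces the risk of catastrophic forgetting, because the pretrained weights are never overwritten
- Fewer trainable parameters: smaller optimizer state and typically less overfitting
- Adapters can be merged back into the base convolutions for inference (approximately, for the strided convolutions used here)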
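A minimal sketch of the update rule W' = W + (alpha/rank) * B @ A on a plain linear layer, with toy sizes rather than the notebook's layers. It shows two properties the bullets rely on: with B initialized to zero the adapted layer reproduces the frozen one, and after training the adapter can be folded into a single weight matrix.

```python
import torch

torch.manual_seed(0)
d_in, d_out, rank, alpha = 32, 16, 4, 8.0
scaling = alpha / rank                 # same alpha/rank scaling as LORA_CONFIG

W = torch.randn(d_out, d_in)           # frozen "pretrained" weight
A = torch.randn(rank, d_in) * 0.1      # down-projection (trainable)
B = torch.zeros(d_out, rank)           # up-projection (trainable), zero at init
x = torch.randn(5, d_in)

def lora_forward(x):
    # Frozen path plus the scaled low-rank correction; W itself is never modified.
    return x @ W.T + scaling * (x @ A.T) @ B.T

# 1) With B = 0 the adapted layer matches the frozen layer.
same_at_init = torch.allclose(lora_forward(x), x @ W.T)

# 2) After "training" (here: random B), the adapter folds into one weight.
B.normal_()
W_merged = W + scaling * B @ A
merged_matches = torch.allclose(lora_forward(x), x @ W_merged.T, atol=1e-4)

print(same_at_init, merged_matches)
print(W.numel(), A.numel() + B.numel())  # frozen vs. trainable parameter counts
```

At these toy sizes the saving is modest (192 trainable vs. 512 frozen values); it grows with layer width, since A and B scale with `rank * (d_in + d_out)` instead of `d_in * d_out`.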

## 1. Setup and GPU Check

In [None]:
# Check GPU
!nvidia-smi

In [None]:
# Install required packages
!pip install -q gdown tqdm pillow

## 2. Download Dataset and Checkpoint

In [None]:
# Download dataset
DATASET_FILE_ID = "1835G9HBouBqmk7tKNnIc5gkJ5B8-4v9I"  # <-- Your QXSLAB_SAROPT.zip ID
!gdown {DATASET_FILE_ID} -O /content/QXSLAB_SAROPT.zip
!unzip -q /content/QXSLAB_SAROPT.zip -d /content/
!ls -la /content/

In [None]:
# Download pretrained checkpoint
CHECKPOINT_FILE_ID = "1avb5ua7fYlgQOarS4Xvi3zpsX6s9Z7NV"  # <-- Your checkpoint ID
!gdown {CHECKPOINT_FILE_ID} -O /content/pix2pix_gen_180.pth
!ls -lh /content/pix2pix_gen_180.pth

## 3. Configuration

In [None]:
# ============================================================================
# CONFIGURATION
# ============================================================================

# Paths
DATASET_ROOT = "/content/QXSLAB_SAROPT"
SAR_FOLDER = "sar_256_oc_0.2"
OPT_FOLDER = "opt_256_oc_0.2"
PRETRAINED_CHECKPOINT = "/content/pix2pix_gen_180.pth"
OUTPUT_DIR = "/content/lora_checkpoints"

# LoRA Configuration
LORA_CONFIG = {
    "rank": 16,              # LoRA rank (typically 4-64; higher = more capacity)
    "alpha": 32,             # LoRA alpha; the update is scaled by alpha/rank (commonly alpha = 2x rank)
    "dropout": 0.1,          # Dropout applied to the LoRA adapter input
    "target_modules": ["encoder"],  # Blocks that get LoRA ("encoder" and/or "decoder"); encoder-only is used below
}

# Training parameters
CONFIG = {
    # Model
    "c_in": 3,
    "c_out": 3,
    "lambda_L1": 100.0,
    "use_upsampling": False,
    "mode": "nearest",
    
    # Training
    "num_epochs": 30,         # Few trainable params, so a longer schedule is affordable
    "batch_size": 32,
    "lr": 0.0002,             # pix2pix default; adapters usually tolerate a higher LR than full fine-tuning
    "beta1": 0.5,
    "beta2": 0.999,
    "save_freq": 5,
    "num_workers": 4,
    
    # Dataset split
    "split_ratio": [0.8, 0.1, 0.1],
    "seed": 42,
}

print("LoRA Configuration:")
print(f"  Rank: {LORA_CONFIG['rank']}")
print(f"  Alpha: {LORA_CONFIG['alpha']}")
print(f"  Dropout: {LORA_CONFIG['dropout']}")
print(f"\nTraining Configuration:")
print(f"  Epochs: {CONFIG['num_epochs']}")
print(f"  Batch size: {CONFIG['batch_size']}")
print(f"  Learning rate: {CONFIG['lr']}")
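The adapter weights do not depend on alpha (it only sets the scale, alpha/rank = 2.0 here), so the number of trainable parameters follows from rank and the channel widths alone. This sketch counts the weights `LoRAConv2d` adds (a 1x1 conv A of shape rank x c_in and a 1x1 conv B of shape c_out x rank, no biases), using the channel pairs from the Section 6 definitions and assuming the default `use_upsampling=False`, where each decoder block has one ConvTranspose2d.

```python
# (in_channels, out_channels) of each adapted conv, copied from Section 6
# (use_upsampling=False: one ConvTranspose2d per decoder block).
ENCODER = [(3, 64), (64, 128), (128, 256), (256, 512),
           (512, 512), (512, 512), (512, 512), (512, 512)]
DECODER = [(512, 512), (1024, 512), (1024, 512), (1024, 512),
           (1024, 256), (512, 128), (256, 64), (128, 64)]

def lora_params(layers, rank):
    """LoRAConv2d adds A (rank x c_in, 1x1) and B (c_out x rank, 1x1), no biases."""
    return sum(rank * c_in + c_out * rank for c_in, c_out in layers)

for rank in (4, 16, 64):
    enc = lora_params(ENCODER, rank)
    dec = lora_params(DECODER, rank)
    print(f"rank={rank:>2}: encoder-only {enc:>7,}  encoder+decoder {enc + dec:>7,}")
```

The count is linear in rank, so moving from rank 16 to 64 quadruples the adapter size while still staying far below the ~54M frozen parameters.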

## 4. Import Libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2
import torchvision.transforms.functional as TF
from PIL import Image
from pathlib import Path
from typing import Tuple, List, Optional
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import os
import math

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 5. LoRA Layer Implementation

In [None]:
# ============================================================================
# LoRA LAYER IMPLEMENTATIONS
# ============================================================================

class LoRALayer(nn.Module):
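    """
    LoRA (Low-Rank Adaptation) layer for linear (fully connected) inputs.
    Not used by the generator below, which uses LoRAConv2d; kept for reference.
    
    Instead of updating W directly, we learn a low-rank update:
        W' = W + (alpha/rank) * B @ A
    
    where A is (rank, in_features) and B is (out_features, rank).
    """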
    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 4,
        alpha: float = 1.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        
        # Low-rank matrices (A: down-projection, B: up-projection)
        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        
        # Initialize A with Kaiming, B with zeros (so B @ A = 0 and the adapted layer starts identical to the original)
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)
        
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: (..., in_features)
        # Returns only the low-rank delta; add it to the frozen layer's output
        return (self.dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling


class LoRAConv2d(nn.Module):
    """
    LoRA adapter for Conv2d (and ConvTranspose2d) layers.
    Wraps an existing convolution and adds a parallel low-rank path.
    """
    def __init__(
        self,
        conv: nn.Conv2d,
        rank: int = 4,
        alpha: float = 1.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.conv = conv
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        
        # Freeze original conv weights
        for param in self.conv.parameters():
            param.requires_grad = False
        
        in_channels = conv.in_channels
        out_channels = conv.out_channels
        
        # The LoRA path uses stride-1 1x1 convolutions for efficiency:
        # A: (rank, in_channels, 1, 1) - down projection
        # B: (out_channels, rank, 1, 1) - up projection
        # Its output is resized to the wrapped conv's output size in forward().
        self.lora_A = nn.Conv2d(in_channels, rank, kernel_size=1, bias=False)
        self.lora_B = nn.Conv2d(rank, out_channels, kernel_size=1, bias=False)
        
        # Initialize
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)
        
        self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Original conv output
        out = self.conv(x)
        
        # LoRA path: x -> A -> B -> scale (computed at the input resolution,
        # since the 1x1 adapter convs use stride 1)
        lora_out = self.lora_B(self.lora_A(self.dropout(x))) * self.scaling
        
        # Resize to the conv's output size: downsampled for strided Conv2d,
        # upsampled for ConvTranspose2d
        if lora_out.shape[2:] != out.shape[2:]:
            lora_out = F.interpolate(lora_out, size=out.shape[2:], mode='bilinear', align_corners=False)
        
        return out + lora_out
    
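    def merge_weights(self):
        """Fold the LoRA update into the wrapped Conv2d and return it.

        The LoRA path is a stride-1 1x1 conv followed by bilinear resizing, so
        adding its 1x1 kernel to the (near-)center tap of the original kernel is
        exact only for stride-1 convs with odd kernels and 'same' padding; for
        the 4x4 stride-2 encoder convs it is an approximation.
        """
        if not isinstance(self.conv, nn.Conv2d):
            raise NotImplementedError("merge_weights only supports nn.Conv2d (not ConvTranspose2d)")
        with torch.no_grad():
            # lora_B: (out_ch, rank, 1, 1), lora_A: (rank, in_ch, 1, 1)
            # Combined 1x1 kernel: (out_ch, rank) @ (rank, in_ch) -> (out_ch, in_ch, 1, 1)
            lora_weight = (
                self.lora_B.weight.view(self.conv.out_channels, self.rank)
                @ self.lora_A.weight.view(self.rank, self.conv.in_channels)
            ).view(self.conv.out_channels, self.conv.in_channels, 1, 1) * self.scaling
            
            # Add to the (near-)center tap of the original kernel
            kh, kw = self.conv.kernel_size
            ch, cw = kh // 2, kw // 2
            self.conv.weight[:, :, ch:ch+1, cw:cw+1] += lora_weight
        
        return self.conv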


print("LoRA layers defined!")
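To gauge how approximate `merge_weights` is, this sketch reproduces the adapter path of `LoRAConv2d.forward` with plain tensors (toy sizes, random weights standing in for trained adapters) and compares it with the center-tap merge. For a 3x3, stride-1, padding-1 conv the two should agree up to float rounding, because the center tap reads exactly the pixel a 1x1 conv reads. For the encoder's 4x4, stride-2 conv, bilinear downsampling averages a 2x2 neighborhood while the merged tap reads a single input pixel, so the outputs differ.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
c_in, c_out, rank, scaling = 8, 8, 4, 2.0
x = torch.randn(1, c_in, 16, 16)
A = torch.randn(rank, c_in, 1, 1)    # plays the role of lora_A.weight
B = torch.randn(c_out, rank, 1, 1)   # plays the role of a trained (nonzero) lora_B.weight

def lora_path(x, out_hw):
    """Mimic LoRAConv2d.forward's adapter path: 1x1 A, 1x1 B, scale, resize."""
    y = F.conv2d(F.conv2d(x, A), B) * scaling
    if y.shape[2:] != out_hw:
        y = F.interpolate(y, size=tuple(out_hw), mode="bilinear", align_corners=False)
    return y

def merge_gap(kernel_size, stride, padding):
    """Max |unmerged - merged| after adding the 1x1 update to the center tap."""
    W = torch.randn(c_out, c_in, kernel_size, kernel_size)
    base = F.conv2d(x, W, stride=stride, padding=padding)
    unmerged = base + lora_path(x, base.shape[2:])
    delta = (B.view(c_out, rank) @ A.view(rank, c_in)) * scaling
    W_merged = W.clone()
    c = kernel_size // 2
    W_merged[:, :, c, c] += delta
    merged = F.conv2d(x, W_merged, stride=stride, padding=padding)
    return (unmerged - merged).abs().max().item()

gap_3x3 = merge_gap(3, 1, 1)   # stride 1, odd kernel, "same" padding
gap_4x4 = merge_gap(4, 2, 1)   # the encoder's 4x4 stride-2 conv
print(f"3x3/s1 gap: {gap_3x3:.1e}   4x4/s2 gap: {gap_4x4:.1e}")
```

In practice this means a merged generator should be checked against the LoRA-wrapped one on a few validation images before it replaces it.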

## 6. Base Network Definitions

In [None]:
# ============================================================================
# BASE NETWORK LAYERS (Modified for LoRA compatibility)
# ============================================================================

class DownsamplingBlock(nn.Module):
    def __init__(self, c_in, c_out, kernel_size=4, stride=2, 
                 padding=1, negative_slope=0.2, use_norm=True):
        super().__init__()
        # Keep conv as separate attribute so LoRA can replace it
        self.conv = nn.Conv2d(c_in, c_out, kernel_size, stride, padding, bias=(not use_norm))
        self.norm = nn.BatchNorm2d(c_out) if use_norm else nn.Identity()
        self.act = nn.LeakyReLU(negative_slope)
        
    def forward(self, x):
        return self.act(self.norm(self.conv(x)))


class UpsamplingBlock(nn.Module):
    def __init__(self, c_in, c_out, kernel_size=4, stride=2, 
                 padding=1, use_dropout=False, use_upsampling=False, mode='nearest'):
        super().__init__()
        if use_upsampling:
            self.conv = nn.Sequential(
                nn.Upsample(scale_factor=2, mode=mode),
                nn.Conv2d(c_in, c_out, 3, 1, padding, bias=False)
            )
        else:
            self.conv = nn.ConvTranspose2d(c_in, c_out, kernel_size, stride, padding, bias=False)
        self.norm = nn.BatchNorm2d(c_out)
        self.dropout = nn.Dropout(0.5) if use_dropout else nn.Identity()
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.dropout(self.norm(self.conv(x))))


class UnetEncoder(nn.Module):
    def __init__(self, c_in=3):
        super().__init__()
        self.enc1 = DownsamplingBlock(c_in, 64, use_norm=False)
        self.enc2 = DownsamplingBlock(64, 128)
        self.enc3 = DownsamplingBlock(128, 256)
        self.enc4 = DownsamplingBlock(256, 512)
        self.enc5 = DownsamplingBlock(512, 512)
        self.enc6 = DownsamplingBlock(512, 512)
        self.enc7 = DownsamplingBlock(512, 512)
        self.enc8 = DownsamplingBlock(512, 512)

    def forward(self, x):
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x3 = self.enc3(x2)
        x4 = self.enc4(x3)
        x5 = self.enc5(x4)
        x6 = self.enc6(x5)
        x7 = self.enc7(x6)
        x8 = self.enc8(x7)
        return [x8, x7, x6, x5, x4, x3, x2, x1]


class UnetDecoder(nn.Module):
    def __init__(self, use_upsampling=False, mode='nearest'):
        super().__init__()
        self.dec1 = UpsamplingBlock(512, 512, use_dropout=True, use_upsampling=use_upsampling, mode=mode)
        self.dec2 = UpsamplingBlock(1024, 512, use_dropout=True, use_upsampling=use_upsampling, mode=mode)
        self.dec3 = UpsamplingBlock(1024, 512, use_dropout=True, use_upsampling=use_upsampling, mode=mode)
        self.dec4 = UpsamplingBlock(1024, 512, use_upsampling=use_upsampling, mode=mode)
        self.dec5 = UpsamplingBlock(1024, 256, use_upsampling=use_upsampling, mode=mode)
        self.dec6 = UpsamplingBlock(512, 128, use_upsampling=use_upsampling, mode=mode)
        self.dec7 = UpsamplingBlock(256, 64, use_upsampling=use_upsampling, mode=mode)
        self.dec8 = UpsamplingBlock(128, 64, use_upsampling=use_upsampling, mode=mode)

    def forward(self, x):
        x9 = torch.cat([x[1], self.dec1(x[0])], 1)
        x10 = torch.cat([x[2], self.dec2(x9)], 1)
        x11 = torch.cat([x[3], self.dec3(x10)], 1)
        x12 = torch.cat([x[4], self.dec4(x11)], 1)
        x13 = torch.cat([x[5], self.dec5(x12)], 1)
        x14 = torch.cat([x[6], self.dec6(x13)], 1)
        x15 = torch.cat([x[7], self.dec7(x14)], 1)
        return self.dec8(x15)


class UnetGenerator(nn.Module):
    def __init__(self, c_in=3, c_out=3, use_upsampling=False, mode='nearest'):
        super().__init__()
        self.encoder = UnetEncoder(c_in=c_in)
        self.decoder = UnetDecoder(use_upsampling=use_upsampling, mode=mode)
        self.head = nn.Sequential(
            nn.Conv2d(64, c_out, 3, 1, 1, bias=True),
            nn.Tanh()
        )
    
    def forward(self, x):
        return self.head(self.decoder(self.encoder(x)))


print("Base networks defined!")
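The shapes inside `UnetGenerator` follow from the standard conv output-size formulas. The sketch below applies them with the block defaults (kernel 4, stride 2, padding 1) and the channel widths copied from the class definitions above. It shows the 1x1 bottleneck after `enc8`, which is why the base-model test in Section 11 runs in eval mode, and where the decoder input widths come from (upsampled features concatenated with the matching encoder skip).

```python
def conv_out(n, k=4, s=2, p=1):
    """Output size of nn.Conv2d (dilation=1)."""
    return (n + 2 * p - k) // s + 1

def conv_transpose_out(n, k=4, s=2, p=1):
    """Output size of nn.ConvTranspose2d (dilation=1, output_padding=0)."""
    return (n - 1) * s - 2 * p + k

# Spatial sizes through enc1..enc8 for a 256x256 input.
enc_sizes = [256]
for _ in range(8):
    enc_sizes.append(conv_out(enc_sizes[-1]))

# Spatial sizes through dec1..dec8, starting from the 1x1 bottleneck.
dec_sizes = [enc_sizes[-1]]
for _ in range(8):
    dec_sizes.append(conv_transpose_out(dec_sizes[-1]))

# Channel bookkeeping for the skip connections in UnetDecoder.forward.
enc_out = [64, 128, 256, 512, 512, 512, 512, 512]   # enc1..enc8
dec_out = [512, 512, 512, 512, 256, 128, 64, 64]    # dec1..dec8
skips = enc_out[-2::-1]                             # x7, x6, ..., x1
dec_in = [enc_out[-1]] + [d + s for d, s in zip(dec_out[:-1], skips)]

print("encoder sizes:", enc_sizes)
print("decoder sizes:", dec_sizes)
print("decoder input channels:", dec_in)
```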

## 7. LoRA-enabled Generator

In [None]:
# ============================================================================
# LoRA-ENABLED GENERATOR (SIMPLIFIED - Just wrap conv layers)
# ============================================================================

class LoRAUnetGenerator(nn.Module):
    """
    UNet Generator with LoRA adapters.
    
    SIMPLE APPROACH: Replace conv layers in-place with LoRA-wrapped versions.
    This way we can use the original forward() method without changes.
    """
    def __init__(
        self, 
        base_generator: UnetGenerator,
        rank: int = 16,
        alpha: float = 32,
        dropout: float = 0.1,
        target_layers: List[str] = None,
    ):
        super().__init__()
        self.base = base_generator
        self.rank = rank
        self.alpha = alpha
        
        # Freeze all base model parameters first
        for param in self.base.parameters():
            param.requires_grad = False
        
        # Default: add LoRA to encoder conv layers only (simpler, more stable)
        if target_layers is None:
            target_layers = ['enc']
        
        # Track LoRA modules for parameter collection
        self.lora_modules = nn.ModuleList()
        
        # Add LoRA adapters by replacing conv layers in-place
        self._add_lora_adapters(rank, alpha, dropout, target_layers)
        
        # Count parameters (self.base now also contains the injected LoRA modules)
        total_params = sum(p.numel() for p in self.base.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"\nTotal parameters (base + LoRA): {total_params:,}")
        print(f"Trainable LoRA parameters: {trainable_params:,}")
        print(f"Trainable ratio: {100*trainable_params/total_params:.2f}%")
    
    def _add_lora_adapters(self, rank, alpha, dropout, target_layers):
        """Replace conv layers with LoRA-wrapped versions."""
        
        # Process encoder
        if 'enc' in target_layers or 'encoder' in target_layers:
            for name in ['enc1', 'enc2', 'enc3', 'enc4', 'enc5', 'enc6', 'enc7', 'enc8']:
                block = getattr(self.base.encoder, name)
                # Create LoRA wrapper for the conv
                lora_conv = LoRAConv2d(block.conv, rank=rank, alpha=alpha, dropout=dropout)
                # Replace the conv in the block
                block.conv = lora_conv
                # Track for parameter collection
                self.lora_modules.append(lora_conv)
                print(f"  Added LoRA to encoder.{name}: {lora_conv.conv.in_channels} -> {lora_conv.conv.out_channels}")
        
        # Process decoder (optional - can be unstable)
        if 'dec' in target_layers or 'decoder' in target_layers:
            for name in ['dec1', 'dec2', 'dec3', 'dec4', 'dec5', 'dec6', 'dec7', 'dec8']:
                block = getattr(self.base.decoder, name)
                conv_layer = block.conv
                
                # Handle ConvTranspose2d
                if isinstance(conv_layer, nn.ConvTranspose2d):
                    lora_conv = LoRAConv2d(conv_layer, rank=rank, alpha=alpha, dropout=dropout)
                    block.conv = lora_conv
                    self.lora_modules.append(lora_conv)
                    print(f"  Added LoRA to decoder.{name}: {lora_conv.conv.in_channels} -> {lora_conv.conv.out_channels}")
                # Handle Sequential (upsampling mode)
                elif isinstance(conv_layer, nn.Sequential):
                    for i, layer in enumerate(conv_layer):
                        if isinstance(layer, nn.Conv2d):
                            lora_conv = LoRAConv2d(layer, rank=rank, alpha=alpha, dropout=dropout)
                            conv_layer[i] = lora_conv
                            self.lora_modules.append(lora_conv)
                            print(f"  Added LoRA to decoder.{name}[{i}]: {lora_conv.conv.in_channels} -> {lora_conv.conv.out_channels}")
                            break
    
    def forward(self, x):
        # Simply use the base model's forward - LoRA is already embedded!
        return self.base(x)
    
    def get_lora_parameters(self):
        """Return only LoRA parameters for optimizer."""
        params = []
        for lora in self.lora_modules:
            params.append(lora.lora_A.weight)
            params.append(lora.lora_B.weight)
        return params
    
    def save_lora_weights(self, path):
        """Save only LoRA weights (small file)."""
        lora_state = {
            'lora_modules': [
                {
                    'lora_A': lora.lora_A.state_dict(),
                    'lora_B': lora.lora_B.state_dict(),
                }
                for lora in self.lora_modules
            ],
            'config': {'rank': self.rank, 'alpha': self.alpha}
        }
        torch.save(lora_state, path)
        print(f"LoRA weights saved to {path}")
    
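    def load_lora_weights(self, path):
        """Load LoRA weights saved by save_lora_weights (same target layers required)."""
        lora_state = torch.load(path, map_location='cpu', weights_only=False)
        saved = lora_state['lora_modules']
        if len(saved) != len(self.lora_modules):
            raise ValueError(f"Checkpoint has {len(saved)} LoRA modules, model has {len(self.lora_modules)}")
        for lora, state in zip(self.lora_modules, saved):
            lora.lora_A.load_state_dict(state['lora_A'])
            lora.lora_B.load_state_dict(state['lora_B'])
        print(f"LoRA weights loaded from {path}")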


print("LoRA Generator defined!")
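A sketch of a name-keyed alternative to the list-based format in `save_lora_weights`: filter the state dict for adapter entries, save those, and load them into a freshly built model with `strict=False`. `TinyLoRAConv` and `lora_state_dict` are hypothetical toy helpers, not part of the notebook. Keying by module path avoids depending on the order of `lora_modules`; applied to `generator`, the filter should also return duplicate entries under `lora_modules.*`, because each adapter is registered both there and inside `base`. Such files (like those from `save_lora_weights`) contain no BatchNorm running statistics, so those must stay frozen during training for a reloaded model to match.

```python
import io
import torch
import torch.nn as nn

class TinyLoRAConv(nn.Module):
    """Toy stand-in for LoRAConv2d: frozen 1x1 conv plus a rank-r 1x1 adapter."""
    def __init__(self, c=4, rank=2):
        super().__init__()
        self.conv = nn.Conv2d(c, c, 1)
        self.conv.requires_grad_(False)
        self.lora_A = nn.Conv2d(c, rank, 1, bias=False)
        self.lora_B = nn.Conv2d(rank, c, 1, bias=False)
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x):
        return self.conv(x) + self.lora_B(self.lora_A(x))

def lora_state_dict(model):
    """Keep only adapter tensors, keyed by their full module path."""
    return {k: v for k, v in model.state_dict().items() if "lora_" in k}

torch.manual_seed(0)
trained = nn.Sequential(TinyLoRAConv(), TinyLoRAConv())
with torch.no_grad():                       # pretend training updated the adapters
    for block in trained:
        block.lora_B.weight.normal_()

buf = io.BytesIO()                          # stands in for a .pth file
torch.save(lora_state_dict(trained), buf)
buf.seek(0)

fresh = nn.Sequential(TinyLoRAConv(), TinyLoRAConv())
# Base weights come from elsewhere (in the notebook: the pretrained checkpoint).
base_only = {k: v for k, v in trained.state_dict().items() if "lora_" not in k}
fresh.load_state_dict(base_only, strict=False)
result = fresh.load_state_dict(torch.load(buf), strict=False)

x = torch.randn(1, 4, 8, 8)
print(result.missing_keys)                  # only the frozen conv weights/biases
print(torch.allclose(trained(x), fresh(x)))
```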

## 8. Discriminator (unchanged)

In [None]:
# ============================================================================
# DISCRIMINATOR (Standard PatchGAN)
# ============================================================================

class PatchDiscriminator(nn.Module):
    def __init__(self, c_in=6, c_hid=64, n_layers=3):
        super().__init__()
        layers = [DownsamplingBlock(c_in, c_hid, use_norm=False)]
        
        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            layers.append(DownsamplingBlock(c_hid * nf_mult_prev, c_hid * nf_mult))
        
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        layers.append(DownsamplingBlock(c_hid * nf_mult_prev, c_hid * nf_mult, stride=1))
        layers.append(nn.Conv2d(c_hid * nf_mult, 1, kernel_size=4, stride=1, padding=1))
        
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)


print("Discriminator defined!")
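For 256x256 inputs the discriminator returns a grid of logits rather than a single score, and each logit judges one patch of the (SAR, optical) pair. The arithmetic below, assuming the default `n_layers=3` (three stride-2 blocks, one stride-1 block, and the final stride-1 conv, all with kernel 4 and padding 1), gives the grid size and the receptive field of each logit.

```python
def conv_out(n, k, s, p):
    """Output size of nn.Conv2d (dilation=1)."""
    return (n + 2 * p - k) // s + 1

# (kernel, stride, padding) of each layer in PatchDiscriminator(n_layers=3)
LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]

size, rf, jump = 256, 1, 1
for k, s, p in LAYERS:
    size = conv_out(size, k, s, p)
    rf += (k - 1) * jump   # receptive field grows by (k-1) steps of the current jump
    jump *= s              # distance between adjacent outputs, in input pixels

print(f"output map: {size}x{size}, receptive field: {rf}x{rf} px")
```

This is the familiar 70x70 PatchGAN; `BCEWithLogitsLoss` in the training loop averages over all 30x30 logits.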

## 9. Dataset Class

In [None]:
# ============================================================================
# DATASET CLASS
# ============================================================================

class QXSLABDataset(Dataset):
    def __init__(self, root_dir, sar_folder="sar_256_oc_0.2", opt_folder="opt_256_oc_0.2",
                 split=None, split_ratio=(0.8, 0.1, 0.1), augment=False, seed=42):
        self.root_dir = Path(root_dir)
        self.sar_dir = self.root_dir / sar_folder
        self.opt_dir = self.root_dir / opt_folder
        self.augment = augment
        
        self.transform = v2.Compose([
            v2.ToImage(),
            v2.Resize((256, 256)),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5], std=[0.5]),
        ])
        
        # Find valid pairs
        self.image_pairs = self._find_valid_pairs()
        
        if split:
            self.image_pairs = self._apply_split(split, split_ratio, seed)
            print(f"[{split}] {len(self.image_pairs)} pairs (augment={augment})")
    
    def _find_valid_pairs(self):
        sar_files = {f.stem: f for f in self.sar_dir.glob("*.png")}
        opt_files = {f.stem: f for f in self.opt_dir.glob("*.png")}
        
        pairs = [(sar_files[n], opt_files[n]) for n in sar_files if n in opt_files]
        try:
            pairs.sort(key=lambda x: int(x[0].stem))
        except ValueError:  # non-numeric file names: fall back to string order
            pairs.sort(key=lambda x: x[0].stem)
        
        print(f"Found {len(pairs)} valid pairs")
        return pairs
    
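    def _apply_split(self, split, split_ratio, seed):
        # Local RNG: same shuffle as random.seed(seed) + random.shuffle,
        # without resetting the global random state used for augmentation
        rng = random.Random(seed)
        indices = list(range(len(self.image_pairs)))
        rng.shuffle(indices)
        
        n = len(indices)
        train_end = int(n * split_ratio[0])
        val_end = train_end + int(n * split_ratio[1])
        
        if split == 'train':
            indices = indices[:train_end]
        elif split == 'val':
            indices = indices[train_end:val_end]
        elif split == 'test':
            indices = indices[val_end:]
        else:
            raise ValueError(f"Unknown split '{split}' (expected 'train', 'val' or 'test')")
        
        return [self.image_pairs[i] for i in indices]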
    
    def _apply_augmentation(self, sar_img, opt_img):
        if random.random() > 0.5:
            sar_img = TF.hflip(sar_img)
            opt_img = TF.hflip(opt_img)
        if random.random() > 0.5:
            sar_img = TF.vflip(sar_img)
            opt_img = TF.vflip(opt_img)
        angle = random.choice([0, 90, 180, 270])
        if angle != 0:
            sar_img = TF.rotate(sar_img, angle)
            opt_img = TF.rotate(opt_img, angle)
        return sar_img, opt_img
    
    def __len__(self):
        return len(self.image_pairs)
    
    def __getitem__(self, idx):
        sar_path, opt_path = self.image_pairs[idx]
        sar_img = Image.open(sar_path).convert('RGB')
        opt_img = Image.open(opt_path).convert('RGB')
        
        if self.augment:
            sar_img, opt_img = self._apply_augmentation(sar_img, opt_img)
        
        return self.transform(sar_img), self.transform(opt_img)


print("Dataset class defined!")
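Each `QXSLABDataset` instance rebuilds the sorted pair list and reshuffles it with the same seed, so the train/val/test subsets are disjoint only because that shuffle is reproducible. The sketch below mirrors the slicing in `_apply_split` with a local `random.Random` (the helper `split_indices` and the 1003-item size are made up for illustration) and checks that property; rounding leftovers end up in the test split.

```python
import random

def split_indices(n, split_ratio=(0.8, 0.1, 0.1), seed=42):
    """Mirror of QXSLABDataset._apply_split's slicing, returning all three splits."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    train_end = int(n * split_ratio[0])
    val_end = train_end + int(n * split_ratio[1])
    return indices[:train_end], indices[train_end:val_end], indices[val_end:]

train, val, test = split_indices(1003)
print(len(train), len(val), len(test))
print(set(train).isdisjoint(val) and set(train).isdisjoint(test) and set(val).isdisjoint(test))
print(split_indices(1003) == (train, val, test))   # same seed -> same split
```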

## 10. Create Datasets and DataLoaders

In [None]:
# Create datasets
print("Creating datasets...\n")

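train_dataset = QXSLABDataset(
    root_dir=DATASET_ROOT,
    sar_folder=SAR_FOLDER,
    opt_folder=OPT_FOLDER,
    split='train',
    split_ratio=CONFIG['split_ratio'],
    augment=True,
    seed=CONFIG['seed']
)

val_dataset = QXSLABDataset(
    root_dir=DATASET_ROOT,
    sar_folder=SAR_FOLDER,
    opt_folder=OPT_FOLDER,
    split='val',
    split_ratio=CONFIG['split_ratio'],
    augment=False,
    seed=CONFIG['seed']
)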

train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], 
                          shuffle=True, num_workers=CONFIG['num_workers'], pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'],
                        shuffle=False, num_workers=CONFIG['num_workers'], pin_memory=True)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

## 11. Initialize Models

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}\n")

# ============================================================================
# LOAD PRETRAINED WEIGHTS WITH KEY REMAPPING
# ============================================================================

def remap_state_dict(old_state_dict):
    """
    Remap keys from old format (conv_block) to new format (conv, norm, act).
    Old: encoder.enc1.conv_block.0.weight -> New: encoder.enc1.conv.weight
    """
    new_state_dict = {}
    
    for key, value in old_state_dict.items():
        new_key = key
        
        # Encoder/decoder remapping: conv_block.0 -> conv, conv_block.1 -> norm
        if ('encoder' in key or 'decoder' in key) and 'conv_block' in key:
            if '.conv_block.0.' in key:
                new_key = key.replace('.conv_block.0.', '.conv.')
            elif '.conv_block.1.' in key:
                new_key = key.replace('.conv_block.1.', '.norm.')
        
        new_state_dict[new_key] = value
    
    return new_state_dict


# Create base generator
print("Creating base generator...")
base_gen = UnetGenerator(c_in=CONFIG['c_in'], c_out=CONFIG['c_out'],
                         use_upsampling=CONFIG['use_upsampling'], mode=CONFIG['mode'])

# Load and remap pretrained weights
print(f"Loading pretrained weights from: {PRETRAINED_CHECKPOINT}")
old_state_dict = torch.load(PRETRAINED_CHECKPOINT, map_location=device, weights_only=False)
new_state_dict = remap_state_dict(old_state_dict)

# Load with strict=False to handle any remaining mismatches
missing, unexpected = base_gen.load_state_dict(new_state_dict, strict=False)
if missing:
    print(f"  Missing keys ({len(missing)}): {missing[:5]}")
if unexpected:
    print(f"  Unexpected keys ({len(unexpected)}): {unexpected[:5]}")
print("Pretrained weights loaded!\n")

# Test base model first (eval mode: in train mode, BatchNorm fails at the 1x1 bottleneck with batch size 1)
print("Testing base model...")
test_input = torch.randn(1, 3, 256, 256).to(device)
base_gen.to(device)
base_gen.eval()  # IMPORTANT: eval mode for BatchNorm
with torch.no_grad():
    test_output = base_gen(test_input)
print(f"  Input: {test_input.shape} -> Output: {test_output.shape}")
print("Base model works!\n")

# Now wrap with LoRA
print("Adding LoRA adapters...")
generator = LoRAUnetGenerator(
    base_gen,
    rank=LORA_CONFIG['rank'],
    alpha=LORA_CONFIG['alpha'],
    dropout=LORA_CONFIG['dropout'],
    target_layers=['enc'],  # Only encoder for stability
).to(device)

# Test LoRA model
print("\nTesting LoRA model...")
generator.eval()  # eval mode for test
with torch.no_grad():
    test_output_lora = generator(test_input)
print(f"  Input: {test_input.shape} -> Output: {test_output_lora.shape}")
print("LoRA model works!\n")

# Create discriminator
print("Creating discriminator...")
discriminator = PatchDiscriminator(c_in=6).to(device)
disc_params = sum(p.numel() for p in discriminator.parameters())
print(f"Discriminator parameters: {disc_params:,}")
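Because the checkpoint is loaded with `strict=False`, any layer whose keys fail to remap keeps its random initialization, and a bare count of missing keys does not say which layers are affected. A small helper (hypothetical, not part of the notebook) can group the reported keys by module; the example keys below are made up for illustration.

```python
from collections import Counter

def report_load(missing, unexpected, max_show=3):
    """Summarize a load_state_dict(strict=False) result; return True if nothing mismatched."""
    for label, keys in (("missing", missing), ("unexpected", unexpected)):
        if keys:
            # Group parameter/buffer names by their owning module path.
            modules = Counter(k.rsplit(".", 1)[0] for k in keys)
            print(f"{len(keys)} {label} keys in {len(modules)} modules, "
                  f"e.g. {list(modules)[:max_show]}")
    return not missing and not unexpected

# Hypothetical result: the BatchNorm of enc2 was not remapped.
ok = report_load(
    missing=["encoder.enc2.norm.weight", "encoder.enc2.norm.bias",
             "encoder.enc2.norm.running_mean", "encoder.enc2.norm.running_var"],
    unexpected=["encoder.enc2.conv_block.1.weight"],
)
print("clean load:", ok)
```

In the cell above it would be called as `report_load(missing, unexpected)` right after `base_gen.load_state_dict(...)`.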

In [None]:
# Optimizers - ONLY train LoRA params for generator
lora_params = generator.get_lora_parameters()
print(f"LoRA parameters to train: {sum(p.numel() for p in lora_params):,}")

optimizer_G = torch.optim.Adam(lora_params, lr=CONFIG['lr'], betas=(CONFIG['beta1'], CONFIG['beta2']))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=CONFIG['lr'], betas=(CONFIG['beta1'], CONFIG['beta2']))

# Loss functions
criterion_GAN = nn.BCEWithLogitsLoss()
criterion_L1 = nn.L1Loss()

print("Optimizers created!")

## 12. Training Functions

In [None]:
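def train_epoch(gen, disc, loader, opt_G, opt_D, device, epoch):
    gen.train()
    # The base generator is frozen, but BatchNorm running statistics would still
    # update in train mode and are NOT stored in the LoRA-only checkpoints.
    # Keep them in eval mode so saved LoRA weights reproduce the trained model.
    for m in gen.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()
    disc.train()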
    
    total_D, total_G, total_L1 = 0, 0, 0
    pbar = tqdm(loader, desc=f"Epoch {epoch}")
    
    for sar, opt in pbar:
        sar, opt = sar.to(device), opt.to(device)
        
        # Generate fake
        fake = gen(sar)
        
        # ----- Train Discriminator -----
        opt_D.zero_grad()
        
        real_pair = torch.cat([sar, opt], 1)
        fake_pair = torch.cat([sar, fake.detach()], 1)
        
        pred_real = disc(real_pair)
        pred_fake = disc(fake_pair)
        
        loss_D = 0.5 * (
            criterion_GAN(pred_real, torch.ones_like(pred_real)) +
            criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
        )
        loss_D.backward()
        opt_D.step()
        
        # ----- Train Generator (LoRA only) -----
        opt_G.zero_grad()
        
        fake_pair = torch.cat([sar, fake], 1)
        pred_fake = disc(fake_pair)
        
        loss_G_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
        loss_G_L1 = criterion_L1(fake, opt)
        loss_G = loss_G_GAN + CONFIG['lambda_L1'] * loss_G_L1
        
        loss_G.backward()
        opt_G.step()
        
        total_D += loss_D.item()
        total_G += loss_G.item()
        total_L1 += loss_G_L1.item()
        
        pbar.set_postfix({'D': f"{loss_D.item():.3f}", 'G': f"{loss_G.item():.3f}", 'L1': f"{loss_G_L1.item():.3f}"})
    
    n = len(loader)
    return {'D': total_D/n, 'G': total_G/n, 'L1': total_L1/n}


def validate(gen, disc, loader, device):
    gen.eval()
    disc.eval()
    
    total_D, total_G, total_L1 = 0, 0, 0
    
    with torch.no_grad():
        for sar, opt in loader:
            sar, opt = sar.to(device), opt.to(device)
            fake = gen(sar)
            
            real_pair = torch.cat([sar, opt], 1)
            fake_pair = torch.cat([sar, fake], 1)
            
            pred_real = disc(real_pair)
            pred_fake = disc(fake_pair)
            
            loss_D = 0.5 * (
                criterion_GAN(pred_real, torch.ones_like(pred_real)) +
                criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
            )
            loss_G_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
            loss_G_L1 = criterion_L1(fake, opt)
            loss_G = loss_G_GAN + CONFIG['lambda_L1'] * loss_G_L1
            
            total_D += loss_D.item()
            total_G += loss_G.item()
            total_L1 += loss_G_L1.item()
    
    n = len(loader)
    return {'D': total_D/n, 'G': total_G/n, 'L1': total_L1/n}


def visualize(gen, loader, device, num=4):
    gen.eval()
    sar, opt = next(iter(loader))
    sar, opt = sar[:num].to(device), opt[:num].to(device)
    
    with torch.no_grad():
        fake = gen(sar)
    
    sar = (sar * 0.5 + 0.5).cpu()
    opt = (opt * 0.5 + 0.5).cpu()
    fake = (fake * 0.5 + 0.5).clamp(0, 1).cpu()
    
    fig, axes = plt.subplots(num, 3, figsize=(12, 3*num))
    for i in range(num):
        axes[i,0].imshow(sar[i].permute(1,2,0))
        axes[i,0].set_title('SAR Input')
        axes[i,0].axis('off')
        axes[i,1].imshow(fake[i].permute(1,2,0))
        axes[i,1].set_title('Generated')
        axes[i,1].axis('off')
        axes[i,2].imshow(opt[i].permute(1,2,0))
        axes[i,2].set_title('Ground Truth')
        axes[i,2].axis('off')
    plt.tight_layout()
    plt.show()


print("Training functions ready!")
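Written out, the per-batch losses in `train_epoch` are the usual pix2pix objectives. Here x is the SAR input, y the optical target, D returns patch logits, sg[·] marks the `.detach()` (stop-gradient) used in the discriminator step, λ = `CONFIG['lambda_L1']` = 100, BCE is `BCEWithLogitsLoss` averaged over the patch grid, and N is the number of output elements, since `L1Loss` takes a mean. Only the LoRA parameters are updated from L_G, because `optimizer_G` receives nothing else.

```latex
\begin{aligned}
\mathcal{L}_D &= \tfrac{1}{2}\Big[\mathrm{BCE}\big(D(x, y), 1\big) + \mathrm{BCE}\big(D(x, \operatorname{sg}[G(x)]), 0\big)\Big] \\
\mathcal{L}_G &= \mathrm{BCE}\big(D(x, G(x)), 1\big) + \lambda \cdot \frac{1}{N}\sum_{i=1}^{N} \big|y_i - G(x)_i\big|
\end{aligned}
```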

## 13. Training Loop

In [None]:
os.makedirs(OUTPUT_DIR, exist_ok=True)

history = {'train_D':[], 'train_G':[], 'train_L1':[], 'val_D':[], 'val_G':[], 'val_L1':[]}
best_val_l1 = float('inf')

print("="*60)
print("STARTING LoRA FINE-TUNING")
print("="*60)
print(f"LoRA Rank: {LORA_CONFIG['rank']}")
print(f"LoRA Alpha: {LORA_CONFIG['alpha']}")
print(f"Epochs: {CONFIG['num_epochs']}")
print(f"Learning rate: {CONFIG['lr']}")
print("="*60)

In [None]:
for epoch in range(1, CONFIG['num_epochs'] + 1):
    # Train
    train_loss = train_epoch(generator, discriminator, train_loader, 
                             optimizer_G, optimizer_D, device, epoch)
    
    # Validate
    val_loss = validate(generator, discriminator, val_loader, device)
    
    # Store history
    history['train_D'].append(train_loss['D'])
    history['train_G'].append(train_loss['G'])
    history['train_L1'].append(train_loss['L1'])
    history['val_D'].append(val_loss['D'])
    history['val_G'].append(val_loss['G'])
    history['val_L1'].append(val_loss['L1'])
    
    print(f"Epoch {epoch}: Train[D:{train_loss['D']:.4f} G:{train_loss['G']:.4f} L1:{train_loss['L1']:.4f}] "
          f"Val[D:{val_loss['D']:.4f} G:{val_loss['G']:.4f} L1:{val_loss['L1']:.4f}]")
    
    # Save best model (based on val L1)
    if val_loss['L1'] < best_val_l1:
        best_val_l1 = val_loss['L1']
        generator.save_lora_weights(f"{OUTPUT_DIR}/lora_weights_best.pth")
        torch.save(discriminator.state_dict(), f"{OUTPUT_DIR}/disc_best.pth")
        print(f"  -> New best! Val L1: {best_val_l1:.4f}")
    
    # Periodic checkpoint
    if epoch % CONFIG['save_freq'] == 0:
        generator.save_lora_weights(f"{OUTPUT_DIR}/lora_weights_epoch{epoch}.pth")
        print(f"  -> Checkpoint saved")
    
    # Visualize
    if epoch % 10 == 0 or epoch == 1:
        visualize(generator, val_loader, device)

# Save final
generator.save_lora_weights(f"{OUTPUT_DIR}/lora_weights_final.pth")

print("\n" + "="*60)
print("TRAINING COMPLETE!")
print(f"Best Val L1: {best_val_l1:.4f}")
print("="*60)

## 14. Plot Training History

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
epochs = range(1, len(history['train_D'])+1)

axes[0].plot(epochs, history['train_D'], label='Train')
axes[0].plot(epochs, history['val_D'], label='Val')
axes[0].set_title('Discriminator Loss')
axes[0].legend()
axes[0].grid(True)

axes[1].plot(epochs, history['train_G'], label='Train')
axes[1].plot(epochs, history['val_G'], label='Val')
axes[1].set_title('Generator Loss')
axes[1].legend()
axes[1].grid(True)

axes[2].plot(epochs, history['train_L1'], label='Train')
axes[2].plot(epochs, history['val_L1'], label='Val')
axes[2].set_title('L1 Loss')
axes[2].legend()
axes[2].grid(True)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/training_history.png", dpi=150)
plt.show()

## 15. Load Best Model and Visualize

In [None]:
# Load best LoRA weights
generator.load_lora_weights(f"{OUTPUT_DIR}/lora_weights_best.pth")
print("Best model loaded!\n")

print("Final Results on Validation Set:")
visualize(generator, val_loader, device, num=6)

## 16. Export Merged Model (Optional)

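Merge LoRA weights into the base model for deployment (no LoRA overhead at inference). The merge is approximate for the strided encoder convolutions, so compare merged and LoRA outputs before relying on it.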

In [None]:
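# Option 1: Keep just the LoRA weights (small file, well under 1 MB for rank 16, encoder only)
print("LoRA weights saved at:")
!ls -lh {OUTPUT_DIR}/lora_weights_*.pth

# Option 2: Export a merged model (full size, ~200 MB, no LoRA wrappers needed at inference).
# Fold each encoder adapter into its conv (modifies `generator` in place; approximate, see merge_weights).
# Uncomment to use:
# for name in [f'enc{i}' for i in range(1, 9)]:
#     block = getattr(generator.base.encoder, name)
#     if isinstance(block.conv, LoRAConv2d):
#         block.conv = block.conv.merge_weights()
# torch.save(generator.base.state_dict(), f"{OUTPUT_DIR}/merged_generator.pth")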

## 17. Download Checkpoints

In [None]:
# List all saved files
print("Saved files:")
!ls -lh {OUTPUT_DIR}/

# Note: LoRA weights are small (well under 1 MB at rank 16), so they are easy to download.

In [None]:
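# Mount Google Drive to save
from google.colab import drive
drive.mount('/content/drive')

# Copy to Drive (create the target folder first; cp fails if it does not exist)
!mkdir -p /content/drive/MyDrive/SAR2Optical_LoRA
!cp -r {OUTPUT_DIR}/* /content/drive/MyDrive/SAR2Optical_LoRA/
print("Files copied to Google Drive!")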

---

## Summary

### What LoRA Does:
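- Freezes the pretrained generator weights (~54M parameters)
- Only trains small adapter layers (about 88K parameters with rank 16 on the encoder)
- Reduces the risk of catastrophic forgetting of learned features
- Fewer trainable parameters: smaller optimizer state and typically less overfitting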

### Files Saved:
- `lora_weights_best.pth` - Best LoRA weights (well under 1 MB)
- `lora_weights_final.pth` - Final LoRA weights
- `disc_best.pth` - Discriminator weights

### For Inference:
1. Load base generator with pretrained weights
2. Wrap with LoRAUnetGenerator
3. Load LoRA weights
4. Run inference

Or merge LoRA into the base model for deployment without LoRA overhead (approximate for the strided encoder convolutions).
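For the inference step, single images need the same preprocessing as the dataset: scale uint8 pixels to [0, 1], normalize with mean 0.5 / std 0.5 to [-1, 1], and invert that mapping on the Tanh output. A minimal sketch of that round trip; the helper names are hypothetical, and the image is assumed to be already RGB (the dataset converts SAR images with `convert('RGB')`) and already 256x256 (the dataset transform resizes).

```python
import torch

def to_model_input(img: torch.Tensor) -> torch.Tensor:
    """uint8 HxWx3 image -> float 1x3xHxW tensor in [-1, 1]."""
    x = img.permute(2, 0, 1).float() / 255.0   # like v2.ToDtype(torch.float32, scale=True)
    x = (x - 0.5) / 0.5                         # like v2.Normalize(mean=[0.5], std=[0.5])
    return x.unsqueeze(0)

def to_uint8_image(y: torch.Tensor) -> torch.Tensor:
    """1x3xHxW generator output in [-1, 1] -> uint8 HxWx3 image."""
    y = (y.squeeze(0) * 0.5 + 0.5).clamp(0, 1)
    return (y * 255).round().to(torch.uint8).permute(1, 2, 0)

# Round trip on a random image (no model involved).
img = torch.randint(0, 256, (256, 256, 3), dtype=torch.uint8)
x = to_model_input(img)
print(x.shape, torch.equal(to_uint8_image(x), img))

# With the notebook's objects (not run here):
# generator.eval()
# with torch.no_grad():
#     out = generator(to_model_input(img).to(device))
# optical = to_uint8_image(out.cpu())
```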