# **IMPORTS**

In [None]:
%pip install -q torch torchvision transformers scikit-learn pytorch_lightning
%pip install -q tensorboard

In [None]:
import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image
from sklearn.model_selection import train_test_split
import torch.nn.functional as F

import torch.nn as nn
from torchvision import models

import pytorch_lightning as pylight
import torch.optim as optim

import glob
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

import os

# Colab-only: mount Google Drive so the dataset/log paths under /content/drive/MyDrive are reachable.
from google.colab import drive
drive.mount('/content/drive')

# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda')
print(f"device: {device}")

# **GLOBALS**

In [None]:
# ----------------------------------------- DATA
# Per-channel ImageNet mean / std used to normalize RGB inputs
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

BATCH_SIZE = 32
TEST_BATCH_SIZE = 32


# ----------------------------------------- Paths
# Selects which "generated_images_model_<N>" folder the synthetic images are read from
NUM_TRANSFORMER_MODEL = 3

DATASET_DIR = "/content/drive/MyDrive/CVUSA_subset/"
# Synthetic images live inside the dataset folder, one sub-folder per generator model
SYNT_DATASET_DIR = f"{DATASET_DIR}generated_images/generated_images_model_{NUM_TRANSFORMER_MODEL}/"
LOG_DIR = "/content/drive/MyDrive/JFLN_logs/"

# Checkpoints are stored next to the logs
CKPT_PATH = f"{LOG_DIR}checkpoints_1/"


# ----------------------------------------- TRAIN
LR_vgg = 1e-4           # learning rate for the VGG backbones
LR_fnn = 1e-3           # learning rate for the fully connected embedding heads
WEIGHT_DECAY = 1e-4

GRADIENT_CLIP = 0.8     # passed to the Trainer as gradient_clip_val

EPOCHS = 25

# None  -> train from scratch
# N     -> resume from the saved checkpoint of epoch N
# -1    -> resume from the last checkpoint
START_FROM_EPOCH = None

# **UTILS**

In [None]:
# Color-to-class mapping for the (binarized) RGB segmentation maps.
# Each key is an (R, G, B) tuple whose channels are either 0 or 255.
color_map = {
    (255, 255, 255): 0,   # white -> class 0
    (255, 0, 0): 1,       # red -> class 1
    (0, 0, 255): 2,       # blue -> class 2
    (0, 255, 0): 3,       # green -> class 3
    (255, 255, 0): 4,     # yellow -> class 4
    (0, 255, 255): 5,     # cyan -> class 5

    # colors that do not belong to any class above are folded into class 0
    # (white was previously listed a second time here; the duplicate key was redundant)
    (255, 0, 255): 0,     # magenta -> class 0
}


def rgb_to_label(rgb_image):
    """Turn an RGB segmentation image into a class-index tensor.

    Every channel value is first snapped to 0 (< 128) or 255 (>= 128), then each
    pixel whose snapped color is a key of ``color_map`` receives that class index.
    Pixels matching no key keep the initial value 0.

    Args:
        rgb_image: image convertible with ``np.array`` to an H x W x 3 array.

    Returns:
        ``torch.LongTensor`` of shape [1, H, W] holding the class indices.
    """
    pixels = np.array(rgb_image)

    # Binarize each channel so slightly-off colors collapse onto the palette
    snapped = np.where(pixels >= 128, 255, 0)

    height, width = snapped.shape[0], snapped.shape[1]
    labels = np.zeros((height, width), dtype=np.int64)

    # Paint every palette color with its class index
    for rgb, cls in color_map.items():
        is_color = np.all(snapped == np.array(rgb), axis=-1)
        labels[is_color] = cls

    # Add the leading channel dimension
    return torch.from_numpy(labels).unsqueeze(0)


# path = '/content/drive/MyDrive/CVUSA_subset/polarmap/segmap/output0000008.png'
# img = Image.open(path)
# img_np = np.array(img)
# print(img_np.shape)     # (128, 512, 3)

# res = rgb_to_label(img)
# print(res)
# print(res.shape)


def to_onehot(x, num_classes=6):
    """One-hot encode a class-index map, channels first.

    Args:
        x: tensor of shape [1, H, W] with integer class ids in ``0..num_classes-1``.
        num_classes: number of one-hot channels to produce.

    Returns:
        Float tensor of shape [num_classes, H, W].
    """
    class_ids = x.long().squeeze(0)                             # [H, W]
    encoded = F.one_hot(class_ids, num_classes=num_classes)     # [H, W, C]
    channels_first = encoded.permute(2, 0, 1)                   # [C, H, W]
    return channels_first.float()

In [8]:
def read_triplets_csv(csv_path):
    """Read a comma-separated file into a list of (ground, aerial, seg) path triplets.

    Each non-empty line must contain at least three comma-separated fields; the
    first three are taken, stripped of surrounding whitespace, as the ground,
    aerial and segmentation paths (in that order). Extra fields are ignored and
    blank lines are skipped.

    Args:
        csv_path: path of the CSV file to read.

    Returns:
        List of ``(ground_path, aerial_path, seg_path)`` string tuples.

    Raises:
        ValueError: if a non-empty line has fewer than three fields.
    """
    triplets = []
    with open(csv_path, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                # tolerate blank lines (e.g. a trailing newline at end of file)
                continue
            parts = [field.strip() for field in line.split(',')]
            if len(parts) < 3:
                raise ValueError(
                    f"{csv_path}:{line_no}: expected 3 comma-separated fields, got {len(parts)}"
                )
            triplets.append((parts[0], parts[1], parts[2]))
    return triplets


# train_triplets = read_triplets_csv("./CVUSA_subset/train.csv")
# print(train_triplets[:1])  # Print the first triplet for verification

In [None]:
# Triplet loss with soft margin
def WeightedSoftMarginTripletLoss(a, p, n, margin=1.0):
    """Soft-margin triplet loss: mean of log(1 + exp(d(a, p) - d(a, n) + margin)).

    Args:
        a: anchor embeddings, one row per sample.
        p: positive embeddings, same shape as ``a``.
        n: negative embeddings, same shape as ``a``.
        margin: constant added inside the exponent.

    Returns:
        Scalar tensor with the loss averaged over the batch.
    """
    pos_dist = F.pairwise_distance(a, p, p=2)
    neg_dist = F.pairwise_distance(a, n, p=2)
    # softplus(x) == log(1 + exp(x)) but stays finite for large x,
    # whereas torch.log1p(torch.exp(x)) overflows to inf.
    return F.softplus(pos_dist - neg_dist + margin).mean()

# **DATA**

In [None]:
# Dataset definition
class CVUSADataset(torch.utils.data.Dataset):
    """Dataset yielding (ground, aerial, synthetic aerial, synthetic segmap) tensors.

    For every triplet the real ground and aerial images are read from
    ``dataset_dir``; the two synthetic images are read from
    ``synt_dataset_dir + add_dir`` using the ground image file stem as id.

    Args:
        dataset_dir: prefix prepended (plain string concatenation) to the relative
            ground/aerial paths of each triplet.
        synt_dataset_dir: prefix of the folder holding the generated images.
        triplet_list: list of ``(ground_rel, aerial_rel, seg_rel)`` tuples; the
            third element is not used here.
        img_size: images are resized to ``img_size x img_size``.
        add_dir: sub-folder of ``synt_dataset_dir`` to read from.
        train: when True, random augmentations are applied to the RGB images.
    """

    def __init__(self, dataset_dir, synt_dataset_dir, triplet_list, img_size=224, add_dir = "", train=False):
        self.dataset_dir = dataset_dir
        self.synt_dataset_dir = synt_dataset_dir
        self.triplet_list = triplet_list
        self.img_size = img_size
        self.add_dir = add_dir

        def rgb_pipeline(augmentations):
            # Shared skeleton: resize -> (optional augmentations) -> tensor -> ImageNet normalization
            return transforms.Compose([
                transforms.Resize((img_size, img_size)),
                *augmentations,
                transforms.ToTensor(),
                transforms.Normalize(mean=MEAN, std=STD),
            ])

        if train:
            # Data augmentation is only applied to the training split
            ground_augmentations = [
                transforms.RandomPerspective(distortion_scale=0.15, p=0.4),    # random perspective
                transforms.RandomApply([transforms.ColorJitter(0.3, 0.2, 0.1, 0.05)], p=0.5),   # brightness, contrast, saturation, hue
                transforms.RandomGrayscale(p=0.1),    # convert image to grayscale
                transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),   # blur
            ]
            aerial_augmentations = [
                transforms.RandomApply([transforms.ColorJitter(0.3, 0.2, 0.1, 0.05)], p=0.5),   # brightness, contrast, saturation, hue
                transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),   # blur
            ]
        else:
            ground_augmentations = []
            aerial_augmentations = []

        # For ground view images
        self.ground_transform = rgb_pipeline(ground_augmentations)
        # For aerial images (real and synthetic)
        self.aerial_transform = rgb_pipeline(aerial_augmentations)

        # For aerial segmentation maps: nearest-neighbour resize keeps the palette
        # colors intact, then colors -> class ids -> 6-channel one-hot.
        # to_onehot is passed directly (its default is num_classes=6) instead of
        # through a lambda, so the transform stays picklable.
        self.segmentation_transform = transforms.Compose([
            transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.Lambda(rgb_to_label),
            transforms.Lambda(to_onehot)
        ])


    def __len__(self):
        return len(self.triplet_list)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(ground, aerial, synth_aerial, synth_seg)`` float32 tensors for ``idx``.

        Loading is attempted up to three times; the last failure is re-raised.
        """
        ground_rel, aerial_rel, _ = self.triplet_list[idx]

        # The synthetic files are keyed by the ground image file stem
        img_id_number = ground_rel.split("/")[-1].split(".")[0]
        synth_root = self.synt_dataset_dir + self.add_dir

        # NOTE(review): the directory names and file prefixes look crossed
        # ("aerial_predictions/seg_pred_*" is used as the RGB aerial image and
        # "segmap_predictions/aerial_pred_*" as the segmentation map). Paths are
        # kept exactly as they were — confirm against the generator's output layout.
        paths = (
            self.dataset_dir + ground_rel,
            self.dataset_dir + aerial_rel,
            (synth_root + "/aerial_predictions/seg_pred_" + img_id_number + ".png").strip(),
            (synth_root + "/segmap_predictions/aerial_pred_" + img_id_number + ".png").strip(),
        )

        max_retries = 3
        for attempt in range(max_retries):
            try:
                ground_img, aerial_img, synth_img, seg_synth_img = [self._load_rgb(p) for p in paths]
                # Break if successful
                break
            except Exception as e:
                print(f"Error loading images for index {idx}: {e}")
                if attempt == max_retries - 1:
                    # bare raise keeps the original traceback
                    raise

        # Apply transforms
        ground_tensor = self.ground_transform(ground_img).to(dtype=torch.float32)
        aerial_tensor = self.aerial_transform(aerial_img).to(dtype=torch.float32)
        synth_tensor = self.aerial_transform(synth_img).to(dtype=torch.float32)
        seg_synth_tensor = self.segmentation_transform(seg_synth_img).to(dtype=torch.float32)

        return ground_tensor, aerial_tensor, synth_tensor, seg_synth_tensor

In [None]:
# Triplet lists are read from the dataset root configured in GLOBALS, so changing
# DATASET_DIR moves the CSV files together with the images.
all_train_triplets = read_triplets_csv(DATASET_DIR + "train.csv")
# 85% / 15% training / validation split with a fixed seed for reproducibility
train_triplets, val_triplets = train_test_split(all_train_triplets, test_size=0.15, random_state=19)
# val.csv is used as the held-out test set
test_triplets = read_triplets_csv(DATASET_DIR + "val.csv")

# NOTE(review): the validation triplets come from train.csv but their synthetic
# images are looked up under "val_set" — confirm the generator used the same split.
train_dataset = CVUSADataset(DATASET_DIR, SYNT_DATASET_DIR, train_triplets, add_dir="train_set", train=True)
val_dataset = CVUSADataset(DATASET_DIR, SYNT_DATASET_DIR, val_triplets, add_dir="val_set")
test_dataset = CVUSADataset(DATASET_DIR, SYNT_DATASET_DIR, test_triplets, add_dir="test_set")

# Only the training loader is shuffled
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

# **NETWORK**

## model

In [None]:
class JointFeatureLearningNet(nn.Module):
    """Three VGG16 branches producing L2-normalized embeddings.

    Branches: ground images, aerial images (shared by real and synthetic aerial
    inputs) and 6-channel one-hot segmentation maps. From each branch the
    activations after the last three max-pool layers are reduced to 128 channels
    with a 1x1 convolution, pooled to 7x7, flattened, concatenated and projected
    to ``embed_dim`` by a per-branch linear head.

    Args:
        embed_dim: size of the output embeddings.
        freeze_vgg_layers: number of leading ``features`` layers of each
            pretrained backbone whose parameters are frozen.
    """

    # Indices inside torchvision's vgg16().features of the 3rd, 4th and 5th
    # max-pool layers, and the number of channels each of them outputs.
    _TAP_LAYERS = (16, 23, 30)
    _TAP_CHANNELS = (256, 512, 512)
    _REDUCED_CHANNELS = 128
    _POOLED_SIZE = 7

    def __init__(self, embed_dim=512, freeze_vgg_layers=9):
        super().__init__()

        self.vgg_ground = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        self.vgg_aerial = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        self.vgg_seg = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

        # Only `.features` is used in forward; drop the unused ImageNet classifier
        # heads so they do not take up memory or bloat every checkpoint.
        for vgg in (self.vgg_ground, self.vgg_aerial, self.vgg_seg):
            vgg.classifier = nn.Identity()

        # Freezing initial (pretrained) layers for finetuning
        for backbone in (self.vgg_ground.features, self.vgg_aerial.features, self.vgg_seg.features):
            for layer in backbone[:freeze_vgg_layers]:
                for p in layer.parameters():
                    p.requires_grad = False

        # Modify size of input for segmaps: 6 channels (one-hot).
        # Done AFTER freezing: this layer is randomly initialized and must stay trainable.
        self.vgg_seg.features[0] = nn.Conv2d(6, 64, kernel_size=3, padding=1)
        nn.init.kaiming_normal_(self.vgg_seg.features[0].weight, mode='fan_out', nonlinearity='relu')

        # One 1x1 reduction per tapped scale (input channels differ between scales);
        # each reduction is shared by the three branches.
        self.reduce = nn.ModuleList(
            [nn.Conv2d(c, self._REDUCED_CHANNELS, kernel_size=1) for c in self._TAP_CHANNELS]
        )
        for conv in self.reduce:
            nn.init.kaiming_normal_(conv.weight, nonlinearity='relu')

        # Embedding heads for the concatenated multi-scale features
        ms_dim = len(self._TAP_LAYERS) * self._REDUCED_CHANNELS * self._POOLED_SIZE * self._POOLED_SIZE
        self.fc_ground = nn.Linear(ms_dim, embed_dim)
        self.fc_aerial = nn.Linear(ms_dim, embed_dim)
        self.fc_seg    = nn.Linear(ms_dim, embed_dim)


    def _extract_multiscale(self, backbone, x):
        """Run ``backbone.features`` on ``x`` and return the flattened multi-scale features.

        The backbone activations flow through unchanged; at each tapped layer a
        side branch (1x1 conv -> ReLU -> adaptive average pool to 7x7) is computed
        and flattened. Returns a tensor of shape [B, 3 * 128 * 7 * 7].
        """
        feats = []
        tap = 0
        for idx, layer in enumerate(backbone.features):
            x = layer(x)
            if tap < len(self._TAP_LAYERS) and idx == self._TAP_LAYERS[tap]:
                # Side branch only: `x` itself must keep feeding the next VGG layer
                reduced = F.relu(self.reduce[tap](x))                            # B x 128 x H x W
                pooled = F.adaptive_avg_pool2d(reduced, self._POOLED_SIZE)       # B x 128 x 7 x 7
                feats.append(pooled.flatten(1))
                tap += 1
        return torch.cat(feats, dim=1)  # B x (3*128*7*7)


    def forward(self, ground_img, aerial_img, synth_img, seg_synth_img):
        """Return the L2-normalized embeddings (ground, aerial, synth aerial, synth seg)."""
        # Multi-scale features
        f_ground_ms = self._extract_multiscale(self.vgg_ground, ground_img)
        f_aerial_ms = self._extract_multiscale(self.vgg_aerial, aerial_img)
        f_synth_aerial_ms = self._extract_multiscale(self.vgg_aerial, synth_img)   # same backbone as real aerial
        f_synth_seg_ms = self._extract_multiscale(self.vgg_seg, seg_synth_img)

        # Embeddings
        f_ground = self.fc_ground(f_ground_ms)
        f_aerial = self.fc_aerial(f_aerial_ms)
        f_synth_aerial = self.fc_aerial(f_synth_aerial_ms)  # share head with real aerial
        f_synth_seg = self.fc_seg(f_synth_seg_ms)

        # L2 normalize embeddings
        f_ground  = F.normalize(f_ground, dim=1)
        f_aerial  = F.normalize(f_aerial, dim=1)
        f_synth_aerial = F.normalize(f_synth_aerial, dim=1)
        f_synth_seg  = F.normalize(f_synth_seg, dim=1)

        return f_ground, f_aerial, f_synth_aerial, f_synth_seg

## lightning wrapper

In [None]:
class LightningWrapper(pylight.LightningModule):
  """Lightning module training a JointFeatureLearningNet with three triplet losses.

  For every batch the hardest in-batch negative aerial embedding is selected for
  each ground embedding, then the loss is
  ``l1_weight * L1 + l2_weight * L2 + l3_weight * L3`` where the anchors are the
  ground (L1), synthetic aerial (L2) and synthetic segmentation (L3) embeddings,
  all against the same positive / negative real aerial embeddings.

  Args:
      device: kept as ``self.dvc`` for backward compatibility.
      model: network to train; a new ``JointFeatureLearningNet`` is built when None.
      l1_weight, l2_weight, l3_weight: weights of the three loss terms.
  """

  def __init__(self, device=device, model=None, l1_weight=10, l2_weight=1, l3_weight=1):
    super().__init__()
    self.dvc = device
    # Build the default network lazily: a default argument would be instantiated
    # once, at class-definition time, and shared by every wrapper.
    self.model = JointFeatureLearningNet() if model is None else model

    # Store the loss function itself; it is called with (anchor, positive, negative)
    self.criterion = WeightedSoftMarginTripletLoss
    self.l1_weight = l1_weight
    self.l2_weight = l2_weight
    self.l3_weight = l3_weight


  def forward(self, ground_view, candidate_aerial, synthetic_aerial, synthetic_seg_aerial):
    return self.model(ground_view, candidate_aerial, synthetic_aerial, synthetic_seg_aerial)


  def _shared_step(self, batch, stage):
    """Compute and log the weighted triplet loss for one batch.

    ``stage`` is "train", "val" or "test". The total is logged as "<stage>_loss".
    The three terms are logged as "loss 1..3" for training and as
    "<stage>_loss 1..3" otherwise, so that curves of different stages do not
    end up under the same tag.
    """
    ground, aerial, synth_aerial, synth_seg_aerial = (t.to(self.device) for t in batch)

    # Forward pass
    f_ground, f_aerial, f_synth_aerial, f_synth_seg = self.model(ground, aerial, synth_aerial, synth_seg_aerial)

    # Find closest (hard) negative in batch. Mining is index selection only,
    # so it is done without tracking gradients.
    with torch.no_grad():
      # Euclidean distance matrix between ground anchors and aerial candidates
      dists = torch.cdist(f_ground, f_aerial, p=2)
      # Mask out the diagonal (the matching pairs)
      dists.fill_diagonal_(float('inf'))
      # For each anchor i, get hardest negative index
      neg_idx = torch.argmin(dists, dim=1)  # shape [B]
    # negative aerial features
    f_aerial_neg = f_aerial[neg_idx]

    # Compute loss
    L1 = self.criterion(f_ground, f_aerial, f_aerial_neg)
    L2 = self.criterion(f_synth_aerial, f_aerial, f_aerial_neg)
    L3 = self.criterion(f_synth_seg, f_aerial, f_aerial_neg)
    loss = self.l1_weight * L1 + self.l2_weight * L2 + self.l3_weight * L3

    prefix = "" if stage == "train" else f"{stage}_"
    self.log(f"{prefix}loss 1", L1, prog_bar=True)
    self.log(f"{prefix}loss 2", L2, prog_bar=True)
    self.log(f"{prefix}loss 3", L3, prog_bar=True)
    self.log(f"{stage}_loss", loss, prog_bar=True)

    return loss


  def training_step(self, batch, batch_idx):
    return self._shared_step(batch, "train")


  def validation_step(self, batch, batch_idx):
    return self._shared_step(batch, "val")


  def test_step(self, batch, batch_idx):
    return self._shared_step(batch, "test")


  def configure_optimizers(self):
    """AdamW with per-module learning rates and a per-step OneCycleLR schedule."""
    module_lrs = [
        (self.model.vgg_ground, LR_vgg),
        (self.model.vgg_aerial, LR_vgg),
        (self.model.vgg_seg, LR_vgg),
        (self.model.fc_ground, LR_fnn),
        (self.model.fc_aerial, LR_fnn),
        (self.model.fc_seg, LR_fnn),
    ]

    param_groups, max_lrs, covered = [], [], set()
    for module, lr in module_lrs:
      params = list(module.parameters())
      covered.update(id(p) for p in params)
      param_groups.append({'params': params, 'lr': lr})
      max_lrs.append(lr)

    # Parameters outside the modules above (e.g. the 1x1 `reduce` convolution)
    # would otherwise never be updated; they are newly initialized, so they
    # get the head learning rate.
    remaining = [p for p in self.model.parameters() if id(p) not in covered]
    if remaining:
      param_groups.append({'params': remaining, 'lr': LR_fnn})
      max_lrs.append(LR_fnn)

    optimizer = optim.AdamW(param_groups, weight_decay=WEIGHT_DECAY)

    # The scheduler is stepped once per optimizer step, so the total must account
    # for gradient accumulation; the Trainer provides exactly that number.
    total_steps = int(self.trainer.estimated_stepping_batches)

    scheduler = {
        'scheduler': optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=max_lrs,  # one entry per optimizer group
            total_steps=total_steps,
            pct_start=0.3,
            div_factor=10,
            final_div_factor=100
        ),
        'interval': 'step'
    }

    return {'optimizer': optimizer, 'lr_scheduler': scheduler}

# **TRAIN**

In [None]:
# ___________________________________________________________________________________________ Resume point

if START_FROM_EPOCH is None:
    # train from scratch
    start_from_path = None
else:
    if START_FROM_EPOCH == -1:    # start from last checkpoint
        pattern = CKPT_PATH + "last.ckpt"
    else:                         # start from the checkpoint saved at the given epoch
        pattern = CKPT_PATH + f"best-checkpoint-epoch={START_FROM_EPOCH:02d}-*.ckpt"

    matching = sorted(glob.glob(pattern))   # sorted: deterministic pick if several files match
    if len(matching) == 0:
        raise FileNotFoundError(f"No checkpoint file matches: {pattern}")
    start_from_path = matching[0]


logger = TensorBoardLogger(
    save_dir=LOG_DIR,
    name="lightning_logs",
)

# ___________________________________________________________________________________________ Load model

model = LightningWrapper(device, model=JointFeatureLearningNet(), l1_weight=10, l2_weight=1, l3_weight=1)


# ___________________________________________________________________________________________ Callbacks

# Keep the 6 best checkpoints by validation loss (checked every 2 epochs) plus the last one.
checkpoint_callback = ModelCheckpoint(
    dirpath=CKPT_PATH,
    # ".2f" (two decimals); the previous "-2f" spec was a typo that printed six decimals
    filename="best-checkpoint-{epoch:02d}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    verbose=True,
    save_top_k=6,
    every_n_epochs=2,
    save_last=True
)

early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=7,          # Stop after patience epochs without improvement
    min_delta=0.0001,    # Minimum change to qualify as improvement
    mode="min",
    verbose=True
)


callbacks_list = [
    checkpoint_callback,
    early_stop_callback
]



# ___________________________________________________________________________________________ Training phase 1

# Initialize trainer
trainer = Trainer(
    logger=logger,
    enable_checkpointing=True,
    default_root_dir=LOG_DIR,
    callbacks=callbacks_list,
    enable_progress_bar=True,
    max_epochs=EPOCHS,
    gradient_clip_val=GRADIENT_CLIP,
    accumulate_grad_batches=2,  # Effective batch size = BATCH_SIZE * 2
    precision='16-mixed'
)

# Start training (resumes from start_from_path when it is not None)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=start_from_path)

# **TEST**

In [None]:
epoch_to_load = -1     # number of epoch of file to load. -1 to take last checkpoint

if epoch_to_load == -1:
    ckpt_to_load = CKPT_PATH + "last.ckpt"
else:
    pattern = os.path.join(CKPT_PATH, f"*epoch={epoch_to_load:02d}-*.ckpt")
    matching_files = sorted(glob.glob(pattern))   # sorted: deterministic pick if several files match
    if not matching_files:
        raise FileNotFoundError(f"No checkpoint file found with epoch {epoch_to_load}.")
    ckpt_to_load = matching_files[0]

print(f"Loading model checkpoint {ckpt_to_load} ...")

# The constructor arguments are passed explicitly to rebuild the wrapper
# around a fresh network before the checkpoint weights are loaded into it.
model = LightningWrapper.load_from_checkpoint(
    ckpt_to_load,
    device=device,
    model=JointFeatureLearningNet(),
    l1_weight=10,
    l2_weight=1,
    l3_weight=1
)
model.eval()

# ────────────────────────────────────────────────────────────────────────────────

testing_logger = TensorBoardLogger(
    save_dir=LOG_DIR,
    name="testing_lightning_logs",
)

trainer = Trainer(
    logger=testing_logger,
    accelerator="auto",
    devices=1,              # a single device, whether the accelerator is a GPU or the CPU
    precision="16-mixed",
    enable_progress_bar=True
)

# ────────────────────────────────────────────────────────────────────────────────

# `dataloaders` is the keyword accepted by Trainer.test in Lightning 2.x
# (the version family that understands precision="16-mixed").
test_results = trainer.test(
    model,
    dataloaders=test_loader
)

# ────────────────────────────────────────────────────────────────────────────────
print("\n=== Test Results ===")
for key, val in test_results[0].items():
    print(f"{key}: {val:.4f}")
