# Ground-to-aerial localization task

## IMPORTS

In [None]:
import os
import cv2
import math
import torch
import random
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from tqdm import tqdm
from PIL import Image
from torchvision import transforms as T, models
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torchvision.transforms.functional import to_pil_image


## GLOBALS

In [None]:
# When True, the training loop writes a checkpoint at the end of every epoch.
SAVING = False
# When True, training resumes from the checkpoint file at `checkpoint_path`
# instead of starting at epoch 0.
LOADING = False

In [None]:
# Query CUDA availability once and reuse the answer for both the device
# selection and the diagnostic printout.
cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if cuda_ok else "cpu")
print(device)

# Environment report, handy when comparing runs across machines.
print("PyTorch version:", torch.__version__)
print("CUDA available:", cuda_ok)
print("CUDA version:", torch.version.cuda)
print("Device:", device)

cuda
PyTorch version: 2.7.1+cu118
CUDA available: True
CUDA version: 11.8
Device: cuda


In [None]:
# Dataset root folder; the image file names listed in the CSVs are later
# joined onto it.
path = "C:/Users/andre/Desktop/vision_dataset/"
csv_path_train = os.path.join(path, "CVUSA_subset/training_csv.csv")
csv_path_validation = os.path.join(path, "CVUSA_subset/validation_csv.csv")
csv_path_test = os.path.join(path, "CVUSA_subset/test_csv.csv")

# Column names imposed on every split file, in file order.
_SPLIT_COLUMNS = ['ground', 'sintetic', 'segmentation', 'bingmap']


def _read_split(csv_file):
    """Read one split CSV, replacing its header row with the fixed column names."""
    return pd.read_csv(
        csv_file,
        sep=",",
        names=list(_SPLIT_COLUMNS),
        encoding='utf-8',
        header=0,
    )


train_df = _read_split(csv_path_train)
val_df = _read_split(csv_path_validation)
test_df = _read_split(csv_path_test)

## UTILS

In [None]:
def recall_at_k(queries, database, K=1):
    """Compute Recall@K for index-aligned query/database embeddings.

    Query ``i`` counts as a hit when database index ``i`` is among the ``K``
    database rows closest to it, using the Euclidean distance between
    L2-normalised rows.

    Args:
        queries: 2-D tensor with one query embedding per row.
        database: 2-D tensor with one database embedding per row; row ``i``
            is the ground-truth match of query ``i``.
        K: number of nearest neighbours to retrieve. It is capped at the
            number of database rows so a small database does not raise.

    Returns:
        A ``(recall, top_k)`` tuple: ``recall`` is a float in ``[0, 1]``
        (``0.0`` when there are no queries) and ``top_k`` is an integer
        tensor of shape ``(num_queries, min(K, num_database))`` holding the
        retrieved database indices, nearest first.
    """
    num_queries = queries.size(0)
    k = min(K, database.size(0))

    # Nothing to rank: avoid a division by zero further down.
    if num_queries == 0:
        empty = torch.empty((0, k), dtype=torch.long, device=queries.device)
        return 0.0, empty

    queries = F.normalize(queries, dim=1)
    database = F.normalize(database, dim=1)

    dists = torch.cdist(queries, database, p=2)
    _, top_k = torch.topk(dists, k=k, largest=False)

    # Row i is correct when its own index appears among its retrieved indices;
    # comparing against a column of row indices checks every query at once.
    targets = torch.arange(num_queries, device=top_k.device).unsqueeze(1)
    hits = (top_k == targets).any(dim=1)

    return hits.sum().item() / num_queries, top_k


# Per-channel statistics that show_top_k_images passes to denormalize() to
# undo the input normalisation before plotting. The same values are given to
# T.Normalize in the preprocessing transform, so keep the two in sync.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def denormalize(tensor, mean, std):
    """Undo a per-channel ``(x - mean) / std`` normalisation.

    Works on a single image ``(C, H, W)`` as well as on a batch
    ``(N, C, H, W)``: the statistics are broadcast over the channel axis,
    i.e. the third dimension from the end.

    Args:
        tensor: floating-point image tensor; it is not modified.
        mean: per-channel means (one value per channel).
        std: per-channel standard deviations (one value per channel).

    Returns:
        A new tensor equal to ``tensor * std + mean`` clamped to ``[0, 1]``.
    """
    # Shape (C, 1, 1) lines the statistics up with the channel axis for both
    # 3-D and 4-D inputs.
    channel_shape = (-1, 1, 1)
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(channel_shape)
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(channel_shape)
    return (tensor * std_t + mean_t).clamp(0, 1)


def show_top_k_images(queries, database, query_images, database_images, top_k, num_to_show=5):
    """Plot the first queries next to their retrieved database images.

    One figure is drawn per query: the query image on the left, followed by
    its retrieved images in rank order. The retrieved image whose database
    index equals the query index (the ground-truth match) gets a "[]" marker
    in its title.

    Args:
        queries: unused; kept so existing calls keep working.
        database: unused; kept so existing calls keep working.
        query_images: normalised query images, indexable by query index.
        database_images: normalised database images, indexable by database index.
        top_k: integer tensor of shape (num_queries, K) with retrieved
            database indices per query.
        num_to_show: maximum number of queries to plot; fewer are plotted
            when fewer queries are available.
    """
    num_ranks = top_k.size(1)
    # Never index past the available queries.
    rows_to_show = min(num_to_show, top_k.size(0), len(query_images))

    for i in range(rows_to_show):
        # squeeze=False always yields a 2-D axes array, even for one column.
        fig, axs = plt.subplots(1, num_ranks + 1, figsize=(16, 4), squeeze=False)
        axs = axs[0]

        query_img = denormalize(query_images[i].cpu(), mean, std)
        axs[0].imshow(to_pil_image(query_img))
        axs[0].set_title("Query")
        axs[0].axis("off")

        for j in range(num_ranks):
            db_idx = top_k[i][j].item()
            db_img = denormalize(database_images[db_idx].cpu(), mean, std)
            axs[j+1].imshow(to_pil_image(db_img))
            axs[j+1].axis("off")
            if db_idx == i:
                # Ground-truth match for this query.
                axs[j+1].set_title(f"Top {j+1} []")
            else:
                axs[j+1].set_title(f"Top {j+1}")

        plt.tight_layout()
        plt.show()


## DATA

In [None]:
# Preprocessing applied to every image the dataset loads: convert to a
# tensor, resize to 256x256, then normalise per channel.
transform = T.Compose([
    T.ToTensor(),
    T.Resize((256, 256)),
    # Reuse the module-level statistics so that normalisation here and
    # denormalize() in the plotting helper cannot drift apart.
    T.Normalize(mean=mean, std=std),
])


class CVUSATripletDataset(Dataset):
    """Dataset yielding one ground/aerial training tuple per dataframe row.

    Each item is ``(ground, synthetic, segmented, candidate_pos,
    candidate_neg)``: the first four images come from the 'ground',
    'sintetic', 'segmentation' and 'bingmap' columns of row ``idx``, and
    ``candidate_neg`` is the 'bingmap' image of a different, randomly chosen
    row.

    Args:
        dataframe: table with the four columns named above, holding image
            file names.
        root_dir: folder the file names are joined onto.
        transform: optional callable applied to every loaded RGB image.
    """

    def __init__(self, dataframe, root_dir, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.root_dir = root_dir
        self.transform = transform

        # Resolve every file name against root_dir once (not against a
        # notebook-level global), so the dataset is self-contained.
        self.sv_images = self._resolve_column('ground')
        self.sat_images = self._resolve_column('bingmap')
        self.segmented = self._resolve_column('segmentation')
        self.syntetic = self._resolve_column('sintetic')

    def _resolve_column(self, column):
        """Return the file names of ``column`` joined onto ``root_dir``."""
        return [os.path.join(self.root_dir, name) for name in self.df[column].to_list()]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Map a negative index to its positive equivalent so the negative
        # sample below can never coincide with the positive one.
        if idx < 0:
            idx += len(self.df)

        ground = self._read_rgb(self.sv_images[idx])
        synthetic = self._read_rgb(self.syntetic[idx])
        segmented = self._read_rgb(self.segmented[idx])
        candidate_pos = self._read_rgb(self.sat_images[idx])

        neg_idx = self._sample_negative_index(idx)
        candidate_neg = self._read_rgb(self.sat_images[neg_idx])

        return ground, synthetic, segmented, candidate_pos, candidate_neg

    def _sample_negative_index(self, idx):
        """Draw a row index uniformly from all rows except ``idx``."""
        num_rows = len(self.df)
        if num_rows < 2:
            raise IndexError("need at least two samples to draw a negative candidate")
        # Draw from n-1 slots and skip over idx: uniform over the other rows
        # without building a list of all indices on every call.
        neg_idx = random.randrange(num_rows - 1)
        if neg_idx >= idx:
            neg_idx += 1
        return neg_idx

    def _read_rgb(self, img_path):
        """Read ``img_path`` as RGB and apply the transform, if any."""
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform:
            img = self.transform(img)
        return img

    def _load_image(self, rel_path):
        """Load the image at ``rel_path`` (relative to ``root_dir``)."""
        return self._read_rgb(os.path.join(self.root_dir, rel_path))



In [None]:
# Mini-batch size shared by all three loaders.
batch_size = 8

# One dataset per split, all rooted at `path` and using the same transform.
train_set, val_set, test_set = (
    CVUSATripletDataset(split_df, path, transform)
    for split_df in (train_df, val_df, test_df)
)

# Only the training split is shuffled.
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

## NETWORK

In [None]:
class JointFeatureLearningNet(nn.Module):
    """Two VGG16 feature extractors plus a 1x1 fusion convolution.

    ``ground_vgg`` encodes the ground image; ``shared_vgg`` encodes the
    aerial, segmented, positive and negative inputs with shared weights.
    The ground, aerial, segmented and positive feature maps are concatenated
    along the channel axis and reduced to 512 channels by ``fusion``.
    """

    def __init__(self):
        super().__init__()
        self.ground_vgg = self._make_encoder()
        self.shared_vgg = self._make_encoder()

        # NOTE(review): padding=1 on a 1x1 convolution surrounds the output
        # with a ring that contains only the bias term; confirm this is
        # intended rather than a leftover from a 3x3 kernel.
        self.fusion = nn.Conv2d(
            in_channels=4 * 512,
            out_channels=512,
            kernel_size=1,
            padding=1,
        )

    def _make_encoder(self):
        """Return the convolutional part of an ImageNet-pretrained VGG16."""
        pretrained = models.VGG16_Weights.IMAGENET1K_V1
        backbone = models.vgg16(weights=pretrained)
        return backbone.features

    def forward(self, ground, aerial, segmented, positive,negative):
        """Encode all five inputs and fuse four of them.

        Returns:
            ``(joint_feat, f_g, f_a, f_cp, f_cn)``: the fused feature map
            followed by the ground, aerial, positive and negative feature
            maps (the segmented feature map is only used inside the fusion).
        """
        f_g = self.ground_vgg(ground)
        f_a, f_s, f_cp, f_cn = (
            self.shared_vgg(view) for view in (aerial, segmented, positive, negative)
        )

        # NOTE(review): the fused feature includes the positive candidate's
        # feature map; confirm this is intended, since the fused feature is
        # later ranked against the positive embeddings.
        stacked = torch.cat((f_g, f_a, f_s, f_cp), dim=1)

        return self.fusion(stacked), f_g, f_a, f_cp, f_cn

class FeatureFusionNet(nn.Module):
    """Projection head: global average pooling followed by a 512 -> 256 linear layer."""

    def __init__(self):
        super().__init__()

        self.pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        # Flatten sits at index 0 and the linear layer at index 1 of the
        # Sequential, which fixes the parameter names in the state dict.
        flatten = nn.Flatten()
        project = nn.Linear(in_features=512, out_features=256)
        self.fusion_fc = nn.Sequential(flatten, project)

    def forward(self,x):
        """Pool ``x`` (N, 512, H, W) to one vector per sample and project it to 256 dims."""
        pooled = self.pool(x)
        return self.fusion_fc(pooled)


class GeoLocalizationNet(nn.Module):
    """Full model: joint feature extraction followed by a shared projection head.

    Every feature map produced by ``joint_net`` (fused, ground, aerial,
    positive, negative) is passed through the same ``fusion_net`` head.
    """

    def __init__(self):
        super().__init__()
        self.joint_net = JointFeatureLearningNet()
        self.fusion_net = FeatureFusionNet()

    def get_embedding(self, x):
        """Global-average-pool a feature map (N, C, H, W) into (N, C)."""
        pooled = F.adaptive_avg_pool2d(x, output_size=(1, 1))
        return pooled.view(pooled.size(0), -1)

    def forward(self, ground, synthetic, segmented, positive,negative):
        """Return ``(fused, ground, synthetic, positive, negative)`` embeddings."""
        joint_feat, f_g, f_a, f_cp, f_cn = self.joint_net(
            ground, synthetic, segmented, positive, negative
        )

        # The same head projects every stream.
        head = self.fusion_net
        f_cp, f_cn, f_a, f_g = head(f_cp), head(f_cn), head(f_a), head(f_g)

        return head(joint_feat), f_g, f_a, f_cp, f_cn

In [None]:
def count_parameters(model):
    """Return ``(total, trainable)`` parameter counts of ``model``.

    ``total`` sums the element count of every parameter; ``trainable`` only
    of those with ``requires_grad`` set.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size
    return total_params, trainable_params

# Build the network once just to report its size; the training cell creates
# its own instance and moves it to `device`.
model = GeoLocalizationNet()
total_params, trainable_params = count_parameters(model)
print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")

Total parameters: 30609792
Trainable parameters: 30609792


### Loss function

In [None]:
class WeightedSoftMarginTripletLoss(nn.Module):
    """Soft-margin triplet loss with an auxiliary synthetic/positive term.

    Per sample::

        loss = lambda1 * log(1 + exp(alpha * (d(fg, fa_pos) - d(fg, fa_neg))))
             + lambda2 * log(1 + exp(alpha * d(fa_syn, fa_pos)))

    where ``d`` is the Euclidean distance between matching rows.

    Args:
        alpha: scale applied to the distances inside the exponentials.
        lambda1: weight of the triplet term.
        lambda2: weight of the auxiliary term.
        reduction: ``'mean'`` averages over the batch; any other value
            returns the per-sample losses unreduced.
    """

    def __init__(self, alpha=0.1, lambda1=1.0, lambda2=1.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.reduction = reduction

    def forward(self, fg, fa_pos, fa_neg, fa_syn):
        """Compute the loss for batches of embeddings with one row per sample."""
        dp = F.pairwise_distance(fg, fa_pos, p=2)
        dn = F.pairwise_distance(fg, fa_neg, p=2)
        ds = F.pairwise_distance(fa_syn, fa_pos, p=2)

        # softplus(x) is log(1 + exp(x)) evaluated without overflowing for
        # large x, where the naive exp() would return inf.
        triplet_term = F.softplus(self.alpha * (dp - dn))
        aux_term = F.softplus(self.alpha * ds)

        loss = self.lambda1 * triplet_term + self.lambda2 * aux_term

        if self.reduction == 'mean':
            return loss.mean()
        return loss

## TRAIN

In [None]:
# === Training ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GeoLocalizationNet().to(device)
criterion = WeightedSoftMarginTripletLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


checkpoint_path = 'C:/Users/andre/Desktop/vision_dataset/checkpoint_parte2_alpha10/checkpoint8.pth'

if LOADING:
    # Resume: restore model and optimizer state and continue from the epoch
    # stored in the checkpoint.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    print(f"Checkpoint loaded. Epoch: {start_epoch}")
else:
    start_epoch = 0

i = 0  # number of optimisation steps taken in this run
n_epochs = 30
for epoch in range(start_epoch, n_epochs):
    model.train()
    total_loss = 0.0
    train_samples = 0
    for ground, synthetic, segmented, pos, neg in tqdm(train_loader):
        ground, synthetic = ground.to(device), synthetic.to(device)
        segmented, pos, neg = segmented.to(device), pos.to(device), neg.to(device)
        optimizer.zero_grad()
        output, anchor, synthetic_emb, positive, negative = model(ground, synthetic, segmented, pos, neg)
        # The negative embedding is used as a constant: no gradient flows
        # back through the negative branch.
        negative = negative.detach()
        i = i + 1
        loss = criterion(output, positive, negative, synthetic_emb)

        loss.backward()
        optimizer.step()

        # The criterion already returns the batch mean, so weight it by the
        # batch size to obtain a true per-sample average over the epoch.
        n_batch = ground.shape[0]
        total_loss += loss.item() * n_batch
        train_samples += n_batch

    avg_loss = total_loss / max(train_samples, 1)

    print(f"Epoch {epoch+1}/{n_epochs} - Train Loss: {avg_loss:.4f}")


    # === Validation ===

    model.eval()
    embeddings = []
    positive_embeddings = []
    query_imgs_list = []
    database_imgs_list = []
    val_loss = 0.0
    val_samples = 0

    with torch.no_grad():
        for ground, synthetic, segmented, pos, neg in tqdm(val_loader):
            ground, synthetic = ground.to(device), synthetic.to(device)
            segmented, pos, neg = segmented.to(device), pos.to(device), neg.to(device)

            fused_feat, f_g, f_a, f_cp, f_cn = model(ground, synthetic, segmented, pos, neg)

            loss = criterion(fused_feat, f_cp, f_cn, f_a)
            n_batch = ground.shape[0]
            val_loss += loss.item() * n_batch
            val_samples += n_batch

            embeddings.append(fused_feat.cpu())
            positive_embeddings.append(f_cp.cpu())

            # The images themselves are kept (on the CPU) only for the
            # qualitative plot below.
            query_imgs_list.append(ground.cpu())
            database_imgs_list.append(pos.cpu())

    avg_val_loss = val_loss / max(val_samples, 1)
    print(f"Validation Loss: {avg_val_loss:.4f}")

    all_queries = torch.cat(embeddings, dim=0)
    all_database = torch.cat(positive_embeddings, dim=0)
    query_images = torch.cat(query_imgs_list, dim=0)
    database_images = torch.cat(database_imgs_list, dim=0)

    recall_1, top_k1 = recall_at_k(all_queries, all_database, K=1)
    recall_5, top_k5 = recall_at_k(all_queries, all_database, K=5)
    recall_10, top_k10 = recall_at_k(all_queries, all_database, K=10)

    print(f"Recall@1: {recall_1:.4f}, Recall@5: {recall_5:.4f} , Recall@10: {recall_10:.4f}")

    show_top_k_images(all_queries, all_database, query_images, database_images, top_k5, num_to_show=5)

    if SAVING:
        # 'epoch' stores the number of completed epochs, which is the value
        # the resume branch above uses as start_epoch.
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': avg_loss,
            'val_loss': avg_val_loss,
            'recall_1': recall_1,
            'recall_5': recall_5,
            'recall_10': recall_10
        }
        torch.save(checkpoint, f'C:/Users/andre/Desktop/vision_dataset/checkpoint_parte2_alpha10/checkpoint{epoch}.pth')

Output hidden; open in https://colab.research.google.com to view.

## TEST


In [None]:
# Load the weights of the checkpoint chosen for testing into a fresh model.
checkpoint_path = 'C:/Users/andre/Desktop/vision_dataset/checkpoint_parte2_alpha10/Geolocalization_best_epoch.pth'
model = GeoLocalizationNet().to(device)

# Only the model weights are restored here; the optimizer state is not needed
# for evaluation.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# The stored epoch is read only to report which checkpoint was loaded.
start_epoch = checkpoint['epoch']
print(f"Checkpoint loaded. Epoch: {start_epoch}")


Checkpoint loaded. Epoch: 21


In [None]:
# Evaluate on the test split: collect fused query embeddings and positive
# (database) embeddings, then report Recall@K and plot a few retrievals.
model.eval()
embeddings, positive_embeddings = [], []
query_imgs_list, database_imgs_list = [], []

with torch.no_grad():
    for batch in tqdm(test_loader):
        # Move the five image tensors of the batch to the compute device.
        ground, synthetic, segmented, pos, neg = (t.to(device) for t in batch)

        fused_feat, f_g, f_a, f_cp, f_cn = model(ground, synthetic, segmented, pos, neg)

        # Embeddings feed the recall computation ...
        embeddings.append(fused_feat.cpu())
        positive_embeddings.append(f_cp.cpu())

        # ... while the images are kept only for the qualitative plot.
        query_imgs_list.append(ground.cpu())
        database_imgs_list.append(pos.cpu())

all_queries, all_database, query_images, database_images = (
    torch.cat(chunks, dim=0)
    for chunks in (embeddings, positive_embeddings, query_imgs_list, database_imgs_list)
)

recall_1, top_k1 = recall_at_k(all_queries, all_database, K=1)
recall_5, top_k5 = recall_at_k(all_queries, all_database, K=5)
recall_10, top_k10 = recall_at_k(all_queries, all_database, K=10)

print(f"Recall@1: {recall_1:.4f}, Recall@5: {recall_5:.4f} , Recall@10: {recall_10:.4f}")

show_top_k_images(all_queries, all_database, query_images, database_images, top_k5, num_to_show=5)

Output hidden; open in https://colab.research.google.com to view.