In [None]:
# Automatically reload edited project modules (e.g. the dataloaders/models/utils
# packages imported below) before each cell runs, so code changes are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
import os
# The environment variables below are set before torch is imported further down
# so that CUDA, cuBLAS and the BLAS/OpenMP thread pools pick them up.
# Enumerate GPUs in PCI bus order and expose only GPU 0 to this process.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
# Cap the MKL, numexpr and OpenMP thread pools at 10 threads each.
os.environ["MKL_NUM_THREADS"] = "10" 
os.environ["NUMEXPR_NUM_THREADS"] = "10" 
os.environ["OMP_NUM_THREADS"] = "10" 
# cuBLAS workspace setting PyTorch asks for when deterministic algorithms are
# enabled on CUDA (see set_seeds below).
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"

# Load configuration from ./config.ini (resolved against the current working
# directory). ConfigParser.read silently skips a missing file; the later lookups
# would then fail with NoSectionError / KeyError.
import configparser
config = configparser.ConfigParser()
config.read("./config.ini")
import ast

import torch
import numpy as np
import random
from dataloaders.dataloader_robotcar import RobotCarDataset, transform_grd, transform_sat
from torch.utils.data import DataLoader, Subset
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.nn import functional as F
import math

from utils.utils import weighted_procrustes_2d, desc_l2norm
from models.loss import loss_bev_space, compute_infonce_loss
from models.model_robotcar_eval import CVM
from models.modules import DinoExtractor

import matplotlib.pyplot as plt


# Evaluation constants that are not taken from config.ini.
batch_size = 1                 # batch size for the DataLoaders and the metric grids
beta = 10.0                    # part of the experiment label (and hence the checkpoint path)
epoch = 98                     # checkpoint epoch to load
num_samples_matches4vis = 50   # number of top-scoring matches drawn in the figures

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

def set_seeds(seed):
    """Seed every RNG used by the notebook and request deterministic PyTorch kernels.

    Seeds PyTorch (CPU), NumPy, Python's ``random`` and all CUDA devices with
    ``seed``, turns off cuDNN autotuning, forces deterministic cuDNN kernels and
    enables ``torch.use_deterministic_algorithms`` in warn-only mode (ops without
    a deterministic implementation emit a warning instead of raising).

    Args:
        seed: Integer seed, read from the ``[RandomSeed]`` section of config.ini.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark, cudnn.deterministic = False, True
    torch.use_deterministic_algorithms(mode=True, warn_only=True)

set_seeds(config.getint("RandomSeed", "seed"))

# Every value below comes from `config`, which was populated from ./config.ini
# at the top of this cell.
dataset = config["Dataset"]["dataset"]

# Dataset-dependent hyperparameters (tuples are stored as Python literals).
dataset_root = config["RobotCar"]["local_dataset_root"]
raw_ground_image_size = ast.literal_eval(config.get("RobotCar", "raw_ground_image_size"))
cropped_ground_image_size = ast.literal_eval(config.get("RobotCar", "cropped_ground_image_size"))
ground_image_size = ast.literal_eval(config.get("RobotCar", "ground_image_size"))
satellite_image_size = ast.literal_eval(config.get("RobotCar", "satellite_image_size"))

grid_size_h = config.getfloat("RobotCar", "grid_size_h")
grid_size_v = config.getfloat("RobotCar", "grid_size_v")

learning_rate = config.getfloat("Training", "learning_rate")

grd_bev_res = config.getint("Model", "grd_bev_res")
grd_height_res = config.getint("Model", "grd_height_res")
sat_bev_res = config.getint("Model", "sat_bev_res")

num_keypoints = config.getint("Model", "num_keypoints")

num_samples_matches = config.getint("Matching", "num_samples_matches")

loss_grid_size = config.getfloat("Loss", "loss_grid_size")
num_virtual_point = config.getint("Loss", "num_virtual_point")

# The label names the checkpoint directory loaded below, so its format must stay
# exactly as it was when the model was trained.
label = (f"Robotcar_num_matches_{num_samples_matches}"
         f"_beta_{beta}_grd_bev_res_{grd_bev_res}_height_res_{grd_height_res}"
         f"_sat_res_{sat_bev_res}_loss_grid_{loss_grid_size}"
         f"_h_{int(grid_size_h)}_v_{grid_size_v}_lr_{learning_rate}")

print(f"Experiment label: {label}")


# RobotCar datasets for the three splits; all share the same ground/satellite transforms.
training_set = RobotCarDataset(
    root=dataset_root, split='train',
    transform=(transform_grd, transform_sat)
)

val_set = RobotCarDataset(
    root=dataset_root, split='val',
    transform=(transform_grd, transform_sat)
)

test_set = RobotCarDataset(
    root=dataset_root, split='test',
    transform=(transform_grd, transform_sat)
)

# DataLoaders (the evaluation cells below index test_set directly).
train_dataloader = DataLoader(training_set, batch_size=batch_size, shuffle=True, num_workers=0)
val_dataloader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4)
test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)

# Feature extractor shared by the ground and satellite branches. This notebook
# only runs inference, so put it in eval mode like the CVM model below.
shared_feature_extractor = DinoExtractor().to(device)
shared_feature_extractor.eval()

# CVM matching model, restored from the checkpoint of the experiment `label`.
CVM_model = CVM(device, grd_bev_res=grd_bev_res, grd_height_res=grd_height_res,
                sat_bev_res=sat_bev_res, num_keypoints=num_keypoints,
                embed_dim=1024, grid_size_h=grid_size_h, grid_size_v=grid_size_v)

checkpoint_root = '/home/ziminxia/Work/scitas_mount/checkpoints'
model_path = f'{checkpoint_root}/{label}/{epoch}/model.pt'
# map_location lets a checkpoint saved on a GPU load when `device` fell back to the CPU.
CVM_model.load_state_dict(torch.load(model_path, map_location=device))
CVM_model.to(device)
CVM_model.eval()

In [None]:

def create_metric_grid(grid_size, res, batch_size, device=None):
    """Build a batch of metric (x, y) coordinates for a regular BEV grid.

    Both axes span [-grid_size/2, grid_size/2]: ``res[0]`` points along x and
    ``res[1]`` points along y. Cells are flattened row-major ('ij' indexing), so
    cell ``i * res[1] + j`` holds ``(x[i], y[j])``.

    Args:
        grid_size: Side length of the grid, in the same metric units as the output.
        res: ``(n_x, n_y)`` number of samples along each axis.
        batch_size: Number of copies stacked along the leading dimension.
        device: Target device. Defaults to the notebook-level ``device`` global,
            which is what the original signature always used.

    Returns:
        float32 tensor of shape ``(batch_size, n_x * n_y, 2)`` on ``device``.
    """
    if device is None:
        device = globals()["device"]
    half = grid_size / 2
    xs = np.linspace(-half, half, res[0])
    ys = np.linspace(-half, half, res[1])
    grid_x, grid_y = np.meshgrid(xs, ys, indexing='ij')
    # (n_x * n_y, 2) array of (x, y) pairs, row-major over the grid.
    coords = np.stack((grid_x.ravel(), grid_y.ravel()), axis=-1)
    metric_coord = torch.from_numpy(coords).unsqueeze(0).to(device).float()
    return metric_coord.repeat(batch_size, 1, 1)

# Metric (x, y) coordinates of every ground / satellite BEV cell, repeated over the
# batch; these are the point sets fed to the weighted Procrustes solve below.
# NOTE(review): the ground grid has floor(grd_bev_res/2)+1 rows, which
# create_metric_grid spreads over [-grid_size_h/2, grid_size_h/2], while the 3-D
# lifting further down (`x`) places the same rows on [-grid_size_h/2, 0].
# Confirm which convention the checkpoint was trained with before changing either.
metric_coord_grd_B = create_metric_grid(grid_size_h, (int(np.floor(grd_bev_res/2))+1, grd_bev_res), batch_size)
metric_coord_sat_B = create_metric_grid(grid_size_h, (sat_bev_res, sat_bev_res), batch_size)

# Metric positions of the ground-BEV rows (x), columns (y) and height bins (z);
# used to lift matched ground-BEV cells to 3-D for projection into the image.
x = np.linspace(-grid_size_h / 2, 0, int(np.floor(grd_bev_res/2))+1)
y = np.linspace(-grid_size_h / 2, grid_size_h / 2, grd_bev_res)
z = np.linspace(-grid_size_v / 2, grid_size_v / 2, grd_height_res)

# Horizontal gap (px) between the ground image and the satellite panel on the canvas.
canvas_gap = 10

for idx in range(20, 40):
    print('idx', idx)
    with torch.no_grad():
        grd, sat, tgt, Rgt = test_set[idx]
        # Add a batch dimension of 1 and move every tensor to the compute device.
        grd, sat, tgt, Rgt = (item.unsqueeze(0).to(device) for item in (grd, sat, tgt, Rgt))

        B, _, sat_size, _ = sat.shape
        B, _, grd_H, grd_W = grd.shape

        grd_feature, sat_feature = shared_feature_extractor(grd), shared_feature_extractor(sat)

        (matching_score, sat_desc, grd_desc, sat_indices_topk, grd_indices_topk,
         matching_score_original, grd_scrs, sat_scrs, height_index) = CVM_model(grd_feature, sat_feature)
        _, num_kpts_sat, num_kpts_grd = matching_score.shape

        # ---- Sample matches ---------------------------------------------------
        # Flatten the (satellite keypoint, ground keypoint) score matrix so one
        # index identifies one candidate match.
        matches_row = matching_score.flatten(1)
        # Batch index for advanced indexing, on the same device as the gathered indices.
        batch_idx = torch.arange(B, device=device).view(B, 1).repeat(1, num_samples_matches)
        # Pose estimation uses matches drawn with probability proportional to their score...
        sampled_matching_idx = torch.multinomial(matches_row, num_samples_matches)
        # ...while the figure shows the highest-scoring matches.
        sampled_idx4vis = torch.argsort(matches_row, descending=True)[:, :num_samples_matches4vis]

        # Flat match index -> (row = satellite keypoint slot, column = ground keypoint slot).
        sampled_matching_row = torch.div(sampled_matching_idx, num_kpts_grd, rounding_mode='trunc')
        sampled_matching_col = sampled_matching_idx % num_kpts_grd
        sampled_matching_row4vis = torch.div(sampled_idx4vis, num_kpts_grd, rounding_mode='trunc')
        sampled_matching_col4vis = sampled_idx4vis % num_kpts_grd

        # Keypoint slot -> BEV cell index.
        sat_indices_sampled = torch.gather(sat_indices_topk.squeeze(1), 1, sampled_matching_row)
        grd_indices_sampled = torch.gather(grd_indices_topk.squeeze(1), 1, sampled_matching_col)
        sat_indices_sampled4vis = torch.gather(sat_indices_topk.squeeze(1), 1, sampled_matching_row4vis)
        grd_indices_sampled4vis = torch.gather(grd_indices_topk.squeeze(1), 1, sampled_matching_col4vis)

        # ---- Pose from weighted Procrustes ------------------------------------
        X = metric_coord_sat_B[batch_idx, sat_indices_sampled, :]
        Y = metric_coord_grd_B[batch_idx, grd_indices_sampled, :]
        weights = matches_row[batch_idx, sampled_matching_idx]
        R, t, ok_rank = weighted_procrustes_2d(X, Y, use_weights=True, use_mask=True, w=weights)

        if t is None:
            print('⚠️ Skipping batch: Singular transformation matrix')
            continue

        # Translation: metric units -> satellite-image pixels, then error vs. tgt.
        t = (t / grid_size_h) * sat_size
        translation_error = torch.norm(t - tgt, dim=-1).cpu().numpy()

        # Yaw (degrees) from the first column of each 2x2 rotation; the error is
        # wrapped so it lies in [0, 180].
        Rgt_np, R_np = Rgt.cpu().numpy(), R.cpu().numpy()
        for b in range(B):
            yaw = np.degrees(np.arctan2(R_np[b, 1, 0], R_np[b, 0, 0]))
            yaw_gt = np.degrees(np.arctan2(Rgt_np[b, 1, 0], Rgt_np[b, 0, 0]))
            diff = np.abs(yaw - yaw_gt)
            yaw_error = np.min([diff, 360 - diff])

        print('translation_error', translation_error)
        print('yaw_error', yaw_error)

        # ---- Visualisation ----------------------------------------------------
        # The satellite image is drawn as a grd_H x grd_H panel to the right of the
        # ground image, `canvas_gap` px away.
        sat_image_plot_L = grd_H
        grd_bev_plot_W = grd_H
        sat_panel_x0 = grd_W + canvas_gap
        # Satellite-image pixels -> satellite-panel pixels.
        sat_plot_scale = sat_image_plot_L / sat_size

        # Centres of the matched satellite BEV cells, in panel pixels.
        sat_grid_size = sat_image_plot_L / sat_bev_res
        sat_h = torch.div(sat_indices_sampled4vis, sat_bev_res, rounding_mode='trunc').cpu().numpy() * sat_grid_size + sat_grid_size / 2
        sat_w = (sat_indices_sampled4vis % sat_bev_res).cpu().numpy() * sat_grid_size + sat_grid_size / 2

        # Grid position and height bin of the matched ground BEV cells.
        grd_bev_grid_size = grd_bev_plot_W / grd_bev_res
        grd_bev_row_indices = torch.div(grd_indices_sampled4vis, grd_bev_res, rounding_mode='trunc').cpu().numpy()
        grd_bev_col_indices = (grd_indices_sampled4vis % grd_bev_res).cpu().numpy()
        sampled_height_indices = height_index.flatten()[grd_indices_sampled4vis].cpu().numpy()

        grd_bev_h = grd_bev_row_indices * grd_bev_grid_size + grd_bev_grid_size / 2
        grd_bev_w = grd_bev_col_indices * grd_bev_grid_size + grd_bev_grid_size / 2

        # Lift each matched ground cell to a 3-D point (X, Y, Z = depth) for the
        # pinhole projection below.
        grd_3D_sampled = np.stack(
            [y[grd_bev_col_indices[0]], -z[sampled_height_indices[0]], -x[grd_bev_row_indices[0]]],
            axis=1,
        )

        # Pinhole intrinsics rescaled to the network's ground-image size.
        # NOTE(review): fx is scaled by ground/raw width while cx is shifted for the
        # horizontal crop and scaled by ground/cropped width; if the image is cropped
        # before resizing, fx probably needs the cropped width as well — confirm.
        fx, fy = 964.828979 * ground_image_size[1] / raw_ground_image_size[1], 964.828979 * ground_image_size[0] / raw_ground_image_size[0]
        cx, cy = (643.788025 - (raw_ground_image_size[1] - cropped_ground_image_size[1]) / 2) * ground_image_size[1] / cropped_ground_image_size[1], 484.407990 * ground_image_size[0] / raw_ground_image_size[0]
        camera_k = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])

        projected_points = camera_k @ grd_3D_sampled.T
        depth = projected_points[2, :]
        # Normalised image coordinates (inside the frame when both are in [0, 1]).
        # Points with non-positive depth cannot be projected: mask them out instead
        # of emitting divide-by-zero warnings and inf/nan coordinates.
        with np.errstate(divide='ignore', invalid='ignore'):
            u = projected_points[0, :] / depth / ground_image_size[1]
            v = projected_points[1, :] / depth / ground_image_size[0]
            in_view = (depth > 0) & (u >= 0) & (u <= 1) & (v >= 0) & (v <= 1)

        grd_h = v * grd_H
        grd_w = u * grd_W

        keep = np.flatnonzero(in_view)
        grd_points = [(grd_w[i], grd_h[i]) for i in keep]
        grd_bev_points = [(grd_bev_w[0, i], grd_bev_h[0, i]) for i in keep]
        sat_points = [(sat_w[0, i], sat_h[0, i]) for i in keep]

        # Canvas: ground image | gap | satellite image resized to grd_H x grd_H.
        combined_image = np.ones((grd_H, sat_panel_x0 + sat_image_plot_L, 3))
        grd_to_show = grd[0].permute(1, 2, 0).cpu().numpy()
        sat_to_show = F.interpolate(sat, size=(grd_H, grd_H), mode='bicubic', align_corners=False)[0].permute(1, 2, 0).cpu().numpy()
        combined_image[:, :grd_W, :] = grd_to_show
        combined_image[:, sat_panel_x0:, :] = sat_to_show

        fig, ax = plt.subplots(figsize=(15, 60))
        ax.imshow(combined_image)
        for (x1, y1), (x2, y2) in zip(grd_points, sat_points):
            # Satellite points are panel-relative, so shift them onto the canvas.
            ax.plot([x1, x2 + sat_panel_x0], [y1, y2], marker='o', markersize=2, color='lime', linestyle='-', linewidth=0.8, zorder=0)

        # GT and predicted positions are satellite-pixel offsets from the image
        # centre; map them onto the panel (shifted past the ground image and gap,
        # and rescaled because the panel is grd_H px wide rather than sat_size).
        gt_u = sat_panel_x0 + sat_image_plot_L / 2 - tgt[0, 0, 1].item() * sat_plot_scale
        gt_v = sat_image_plot_L / 2 - tgt[0, 0, 0].item() * sat_plot_scale
        pred_u = sat_panel_x0 + sat_image_plot_L / 2 - t[0, 0, 1].item() * sat_plot_scale
        pred_v = sat_image_plot_L / 2 - t[0, 0, 0].item() * sat_plot_scale

        ax.quiver(gt_u, gt_v, np.cos((90 - yaw_gt) / 180 * np.pi), np.sin((90 - yaw_gt) / 180 * np.pi), facecolor='g', linewidths=0.3, scale=25, width=0.006)
        ax.scatter(gt_u, gt_v, s=300, marker='^', facecolor='g', label='GT', edgecolors='white', zorder=1)
        ax.quiver(pred_u, pred_v, np.cos((90 - yaw) / 180 * np.pi), np.sin((90 - yaw) / 180 * np.pi), facecolor='gold', linewidths=0.3, scale=25, width=0.006)
        ax.scatter(pred_u, pred_v, s=300, marker='*', facecolor='gold', label='Ours', edgecolors='white', zorder=1)
        plt.show()

In [None]:
# Re-run the weighted Procrustes solve on the matches sampled in the most recent
# iteration of the evaluation loop, for interactive inspection of R and t.
# t is printed exactly as returned by weighted_procrustes_2d, i.e. before the
# pixel rescaling the loop applies.
X = metric_coord_sat_B[batch_idx, sat_indices_sampled, :]
Y = metric_coord_grd_B[batch_idx, grd_indices_sampled, :]
weights = matches_row[batch_idx, sampled_matching_idx]

R, t, ok_rank = weighted_procrustes_2d(X, Y, use_weights=True, use_mask=True, w=weights)
print('R', R)
print('t', t)

In [None]:
# Baseline visualisation without the CVM model: match the descriptors returned by
# desc_l2norm(shared_feature_extractor(.)) directly and draw the
# num_samples_matches4vis highest-scoring satellite/ground cell pairs per sample.
patch_size = 14                  # pixel stride assumed for one feature cell (hard-coded originally)
patch_centre = patch_size // 2   # offset from a cell's corner to its centre

for idx in range(20, 40):
    print('idx', idx)
    with torch.no_grad():
        grd, sat, tgt, Rgt = test_set[idx]
        # Add a batch dimension of 1 and move every tensor to the compute device.
        grd, sat, tgt, Rgt = (item.unsqueeze(0).to(device) for item in (grd, sat, tgt, Rgt))

        B, _, sat_size, _ = sat.shape
        B, _, grd_H, grd_W = grd.shape

        grd_feature = desc_l2norm(shared_feature_extractor(grd))
        sat_feature = desc_l2norm(shared_feature_extractor(sat))
        _, _, grd_feature_H, grd_feature_W = grd_feature.size()
        # The satellite feature map is treated as square below.
        _, _, sat_feature_size, _ = sat_feature.size()
        num_grd_cells = grd_feature_H * grd_feature_W

        # (B, N_sat, N_grd) dot products between every satellite and ground cell.
        matching_score = torch.matmul(sat_feature.flatten(2).transpose(1, 2).contiguous(), grd_feature.flatten(2))
        matches_row = matching_score.flatten(1)

        # Flat index of the best pairs -> satellite cell / ground cell.
        top_indices = torch.argsort(matches_row, descending=True)[0, :num_samples_matches4vis]
        sat_indices = torch.div(top_indices, num_grd_cells, rounding_mode='trunc')
        grd_indices = top_indices % num_grd_cells

        # Cell index -> pixel coordinates of the cell centre.
        sat_h_indices = torch.div(sat_indices, sat_feature_size, rounding_mode='trunc').cpu().numpy() * patch_size + patch_centre
        sat_w_indices = (sat_indices % sat_feature_size).cpu().numpy() * patch_size + patch_centre
        grd_h_indices = torch.div(grd_indices, grd_feature_W, rounding_mode='trunc').cpu().numpy() * patch_size + patch_centre
        grd_w_indices = (grd_indices % grd_feature_W).cpu().numpy() * patch_size + patch_centre

        # The satellite image is shown resized to grd_H x grd_H, so rescale its points.
        grd_points = list(zip(grd_w_indices, grd_h_indices))
        sat_points = [(w / sat_size * grd_H, h / sat_size * grd_H) for w, h in zip(sat_w_indices, sat_h_indices)]

        # Canvas: ground image | 10 px gap | resized satellite image.
        combined_image = np.ones((grd_H, grd_W + grd_H + 10, 3))
        grd_to_show = grd[0].permute(1, 2, 0).cpu().numpy()
        sat_to_show = F.interpolate(sat, size=(grd_H, grd_H), mode='bicubic', align_corners=False)[0].permute(1, 2, 0).cpu().numpy()
        combined_image[:, :grd_W, :] = grd_to_show
        combined_image[:, 10 + grd_W:, :] = sat_to_show

        fig, ax = plt.subplots(figsize=(15, 60))
        ax.imshow(combined_image)
        for (x1, y1), (x2, y2) in zip(grd_points, sat_points):
            # Shift satellite points past the ground image and the gap.
            ax.plot([x1, x2 + 10 + grd_W], [y1, y2], marker='o', markersize=2, color='lime', linestyle='-', linewidth=0.8, zorder=0)
        plt.show()


In [None]:
# For one test sample, sweep ground feature cells along row `selected_grd_h` and
# draw the `num_top_matches` satellite cells with the highest matching score for each.
idx = 21
selected_grd_h = 25      # row of the ground feature cells whose matches are shown
num_top_matches = 50     # satellite cells drawn per ground cell
grd_w_step = 7           # visit every 7th ground feature column...
grd_w_limit = 85         # ...below this column (further clipped to the feature-map width)
patch_size = 14          # pixel stride assumed for one feature cell (hard-coded originally)
patch_centre = patch_size // 2

with torch.no_grad():
    grd, sat, tgt, Rgt = test_set[idx]
    # Optional sanity check: mirror the satellite image along its width.
    # sat = torch.flip(sat, dims=[2])

    # Add a batch dimension of 1 and move every tensor to the compute device.
    grd, sat, tgt, Rgt = (item.unsqueeze(0).to(device) for item in (grd, sat, tgt, Rgt))

    B, _, sat_size, _ = sat.shape
    B, _, grd_H, grd_W = grd.shape

    grd_feature = desc_l2norm(shared_feature_extractor(grd))
    sat_feature = desc_l2norm(shared_feature_extractor(sat))
    _, _, grd_feature_H, grd_feature_W = grd_feature.size()
    # The satellite feature map is treated as square below.
    _, _, sat_feature_size, _ = sat_feature.size()

    # (B, N_sat, N_grd) dot products between every satellite and ground cell.
    matching_score = torch.matmul(sat_feature.flatten(2).transpose(1, 2).contiguous(), grd_feature.flatten(2))
    print(matching_score.shape)

    if not 0 <= selected_grd_h < grd_feature_H:
        raise ValueError(
            f"selected_grd_h={selected_grd_h} is outside the ground feature map (height {grd_feature_H})"
        )

    # The canvas does not depend on the queried ground cell, so build it once:
    # ground image | 10 px gap | satellite image resized to grd_H x grd_H.
    grd_to_show = grd[0].permute(1, 2, 0).cpu().numpy()
    sat_to_show = F.interpolate(sat, size=(grd_H, grd_H), mode='bicubic', align_corners=False)[0].permute(1, 2, 0).cpu().numpy()
    combined_image = np.ones((grd_H, grd_W + grd_H + 10, 3))
    combined_image[:, :grd_W, :] = grd_to_show
    combined_image[:, 10 + grd_W:, :] = sat_to_show

    # A column index >= grd_feature_W would silently address a cell in the next
    # row of the flattened feature map, so clip the sweep to the map width.
    for selected_grd_w in range(0, min(grd_w_limit, grd_feature_W), grd_w_step):
        selected_matching_col = selected_grd_h * grd_feature_W + selected_grd_w
        sat_scores = matching_score[0, :, selected_matching_col]
        top_indices = torch.argsort(sat_scores, descending=True)[:num_top_matches]

        # Cell index -> pixel coordinates of the cell centre.
        sat_h_indices = torch.div(top_indices, sat_feature_size, rounding_mode='trunc').cpu().numpy() * patch_size + patch_centre
        sat_w_indices = (top_indices % sat_feature_size).cpu().numpy() * patch_size + patch_centre
        grd_h_indices = selected_grd_h * patch_size + patch_centre
        grd_w_indices = selected_grd_w * patch_size + patch_centre

        # Every line starts at the same ground cell; satellite points are rescaled
        # to the resized panel.
        grd_points = [(grd_w_indices, grd_h_indices)] * len(sat_h_indices)
        sat_points = [(w / sat_size * grd_H, h / sat_size * grd_H) for w, h in zip(sat_w_indices, sat_h_indices)]

        fig, ax = plt.subplots(figsize=(10, 30))
        ax.imshow(combined_image)
        for (x1, y1), (x2, y2) in zip(grd_points, sat_points):
            # Shift satellite points past the ground image and the gap.
            ax.plot([x1, x2 + 10 + grd_W], [y1, y2], marker='o', markersize=2, color='lime', linestyle='-', linewidth=0.8, zorder=0)

        plt.show()