In [1]:


import torch
from torch import nn
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from lightning import LightningModule, Trainer, LightningDataModule
from collections import OrderedDict
from PIL import Image
from glob import glob
import joblib
import pandas as pd
import numpy as np
import faiss
from tqdm import tqdm
from torchvision import transforms
from fastervit import create_model
import timm
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from distill_transforms import get_transforms_val
DIM = 1024

  if not hasattr(numpy, tp_name):
  if not hasattr(numpy, tp_name):
  "lr_options": generate_power_seq(LEARNING_RATE_CIFAR, 11),
  contrastive_task: Union[FeatureMapContrastiveTask] = FeatureMapContrastiveTask("01, 02, 11"),
  self.nce_loss = AmdimNCELoss(tclip)
  check_for_updates()


In [2]:
val_df = pd.read_csv('CVACT/val.csv')
test_df = pd.read_csv('CVACT/test.csv')

val_image_pairs_details = val_df.values.tolist()
test_image_pairs_details = test_df.values.tolist()

In [3]:
prob_rotate: float = 0.75          # rotates the sat image and ground images simultaneously
prob_flip: float = 0.5             # flipping the sat image and ground images simultaneously
student_sat_size: tuple = (224, 224)
student_street_size: tuple = (224, 224)

sat_transforms_val_student, street_transforms_val_student = get_transforms_val(student_sat_size, student_street_size)

In [4]:
class WeightDistillEvalSet(Dataset):
    def __init__(self, image_pairs_details):
        self.image_pairs_details = image_pairs_details

    def __len__(self):
        return len(self.image_pairs_details)

    def __getitem__(self, idx):
        sat_image_path = 'CVACT/'+ self.image_pairs_details[idx][0]
        street_image_path = 'CVACT/'+ self.image_pairs_details[idx][1]

        # Open and process the satellite image
        with Image.open(sat_image_path) as sat_image:
            sat_img = sat_transforms_val_student(image= np.asarray(sat_image))['image']
        # Open and process the street image
        with Image.open(street_image_path) as street_image:
            street_img = street_transforms_val_student(image= np.asarray(street_image))['image']
        
        return sat_img, street_img

# Usage
test_set = WeightDistillEvalSet(test_image_pairs_details)
val_set = WeightDistillEvalSet(val_image_pairs_details)

test_loader = torch.utils.data.DataLoader(test_set, batch_size= 64, num_workers = 8, shuffle=False, drop_last=False)
val_loader = torch.utils.data.DataLoader(val_set, batch_size= 64, num_workers = 8, shuffle=False, drop_last=False)

In [5]:
class ContrastiveModel(LightningModule):
    def __init__(self):
        super().__init__()

        # self.teacher_model = TimmModel('convnext_base.fb_in22k_ft_in1k_384', pretrained=True, img_size= 384)
        # self.teacher_model.load_state_dict(torch.load("pretrained/cvact/convnext_base.fb_in22k_ft_in1k_384/weights_e36_90.8149.pth", weights_only=True))
        # self.teacher_model.eval()
        # for param in self.teacher_model.parameters():
        #     param.requires_grad = False
            
        self.sat_model = create_model('faster_vit_0_224', pretrained=True, model_path="/tmp/faster_vit_0.pth.tar")
        self.sat_linear = nn.Linear(1000, 1024)
        
        self.street_model = create_model('faster_vit_0_224', pretrained=True, model_path="/tmp/faster_vit_0.pth.tar")
        self.street_linear = nn.Linear(1000, 1024)
        
        self.loss = nn.CosineEmbeddingLoss()

    def forward(self, teacher_sat_feat, teacher_street_feat, stud_sat_feat, stud_street_feat):
        student_sat_features = self.sat_linear(self.sat_model(stud_sat_feat))
        student_street_features = self.street_linear(self.street_model(stud_street_feat))
        
        # with autocast():
        #     teacher_sat_features = self.teacher_model(teacher_sat_feat)
        #     teacher_street_features = self.teacher_model(teacher_street_feat)
        
        loss_sat = self.loss(student_sat_features, teacher_sat_features, torch.ones(student_sat_features.shape[0], device= self.device))
        loss_street = self.loss(student_street_features, teacher_street_features, torch.ones(student_street_features.shape[0], device= self.device))
        
        return (loss_sat + loss_street) / 2

    
    def training_step(self, batch, batch_idx):
        teacher_sat_feat, teacher_street_feat, stud_sat_feat, stud_street_feat = batch
        loss = self(teacher_sat_feat, teacher_street_feat, stud_sat_feat, stud_street_feat)

        #n = logits.size(0)
	
        # -1 for off-diagonals and 1 for diagonals
        #labels = 2 * torch.eye(n, device=logits.device) - 1
        
        # pairwise sigmoid loss
        #loss= -torch.sum(F.logsigmoid(labels * logits)) / n
    
        self.log("train_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        teacher_sat_feat, teacher_street_feat, stud_sat_feat, stud_street_feat = batch
        loss = self(teacher_sat_feat, teacher_street_feat, stud_sat_feat, stud_street_feat)

        #n = logits.size(0)
	
        # # -1 for off-diagonals and 1 for diagonals
        #labels = 2 * torch.eye(n, device=logits.device) - 1
        
        # # pairwise sigmoid loss
        #loss= -torch.sum(F.logsigmoid(labels * logits)) / n

        self.log("val_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True)
            
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr= 1e-4)
        scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs = 1, max_epochs = 512)
        return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}] #optimizer

    def on_save_checkpoint(self, checkpoint):
        # Exclude teacher model's weights from checkpoint
        checkpoint['state_dict'] = {k: v for k, v in checkpoint['state_dict'].items() if 'teacher_model' not in k}

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
distill_model = ContrastiveModel.load_from_checkpoint("./logs/KD_EMBEDDING_CD_ONLY_CVACT/version_1/checkpoints/tinyvit.ckpt")
distill_model = distill_model.to(device)
distill_model.eval()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


ContrastiveModel(
  (sat_model): FasterViT(
    (patch_embed): PatchEmbed(
      (proj): Identity()
      (conv_down): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU()
      )
    )
    (levels): ModuleList(
      (0): FasterViTLayer(
        (blocks): ModuleList(
          (0): ConvBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (act1): GELU(approximate='none')
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm

In [7]:
# model = TimmModel('convnext_base.fb_in22k_ft_in1k_384', pretrained=True, img_size= 384)
# model.load_state_dict(torch.load("pretrained/cvact/convnext_base.fb_in22k_ft_in1k_384/weights_e36_90.8149.pth", weights_only=True))
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# model = model.to(device)
# model.eval()

In [8]:
all_street_embed = []
all_sat_embed = []

with torch.no_grad():
    for batch_data in tqdm(test_loader):
        sat_img, street_img = batch_data
        
        # Street Image Embeddings
        street_features = distill_model.street_linear(distill_model.street_model(street_img.to(device)))
        all_street_embed.append(street_features.detach().cpu())

        # Original Satellite Image Embeddings
        sat_features = distill_model.sat_linear(distill_model.sat_model(sat_img.to(device)))
        all_sat_embed.append(sat_features.detach().cpu())
        

    # Save the embeddings
    joblib.dump(torch.vstack(all_street_embed), 'MB1/cvact_test_street_embed.joblib')
    joblib.dump(torch.vstack(all_sat_embed), 'MB1/cvact_test_sat_embed.joblib')
    

    all_street_embed = []
    all_sat_embed = []

    for batch_data in tqdm(val_loader):
        sat_img, street_img = batch_data
        
        # Street Image Embeddings
        street_features = distill_model.street_linear(distill_model.street_model(street_img.to(device)))
        all_street_embed.append(street_features.detach().cpu())

        # Original Satellite Image Embeddings
        sat_features = distill_model.sat_linear(distill_model.sat_model(sat_img.to(device)))
        all_sat_embed.append(sat_features.detach().cpu())

    # Save the embeddings
    joblib.dump(torch.vstack(all_street_embed), 'MB1/cvact_val_street_embed.joblib')
    joblib.dump(torch.vstack(all_sat_embed), 'MB1/cvact_val_sat_embed.joblib')

100%|██████████| 1451/1451 [03:27<00:00,  7.01it/s]
100%|██████████| 139/139 [00:20<00:00,  6.62it/s]


In [9]:
val_sat_features = joblib.load(open("MB1/cvact_val_sat_embed.joblib", 'rb'))
val_street_features = joblib.load(open("MB1/cvact_val_street_embed.joblib", 'rb'))

test_sat_features = joblib.load(open("MB1/cvact_test_sat_embed.joblib", 'rb'))
test_street_features = joblib.load(open("MB1/cvact_test_street_embed.joblib", 'rb'))

In [10]:
print(val_sat_features.shape)
print(val_street_features.shape)
print()
print(test_sat_features.shape)
print(test_street_features.shape)

torch.Size([8883, 1024])
torch.Size([8883, 1024])

torch.Size([92801, 1024])
torch.Size([92801, 1024])


In [11]:
#Normalizing features
val_street_features = F.normalize(val_street_features, dim=-1)
test_street_features = F.normalize(test_street_features, dim=-1)

val_sat_features = F.normalize(val_sat_features, dim=-1)
test_sat_features = F.normalize(test_sat_features, dim=-1)

In [12]:
res = faiss.StandardGpuResources()

val_index = faiss.IndexFlatIP(DIM)
val_index = faiss.index_cpu_to_gpu(res, 1, val_index)

test_index = faiss.IndexFlatIP(DIM)
test_index = faiss.index_cpu_to_gpu(res, 1, test_index)

val_index.add(val_sat_features)
test_index.add(test_sat_features)

In [13]:
test_top_1_per = int(0.01*len(test_street_features))
val_top_1_per = int(0.01*len(val_street_features))

#Val table searches
_, val_ids = val_index.search(x=val_street_features, k=val_top_1_per)

#Test table searches
_, test_ids = test_index.search(x=test_street_features, k=test_top_1_per)

In [14]:
val_matches = {1: 0, 5: 0, 10: 0, 50: 0, 100: 0, val_top_1_per: 0}

for i in range(len(val_street_features)):
    if i == val_ids[i][0]:
        for key in list(val_matches.keys()):
            val_matches[key]+=1        
    elif i in val_ids[i][0:5]:
        for key in list(val_matches.keys())[1:]:
            val_matches[key]+=1 
    elif i in val_ids[i][0:10]:
        for key in list(val_matches.keys())[2:]:
            val_matches[key]+=1
    elif i in val_ids[i][0:50]:
        for key in list(val_matches.keys())[3:]:
            val_matches[key]+=1
    elif i in val_ids[i][0:100]:
        for key in list(val_matches.keys())[4:]:
            val_matches[key]+=1
    elif i in val_ids[i][0:val_top_1_per]:
        for key in list(val_matches.keys())[5:]:
            val_matches[key]+=1

for keys in val_matches.keys():
    val_matches[keys] = (val_matches[keys] * 100) / len(val_street_features)
    print(f"For CVACT Val, Top {keys} Accuracy: {val_matches[keys]}%")

For CVACT Val, Top 1 Accuracy: 89.98086232128786%
For CVACT Val, Top 5 Accuracy: 96.39761341889002%
For CVACT Val, Top 10 Accuracy: 97.21940785770573%
For CVACT Val, Top 50 Accuracy: 98.58156028368795%
For CVACT Val, Top 100 Accuracy: 98.82922436113925%
For CVACT Val, Top 88 Accuracy: 98.82922436113925%


In [15]:
test_matches = {1: 0, 5: 0, 10: 0, 50: 0, 100: 0, test_top_1_per: 0}

for i in range(len(test_street_features)):
    if i == test_ids[i][0]:
        for key in list(test_matches.keys()):
            test_matches[key]+=1        
    elif i in test_ids[i][0:5]:
        for key in list(test_matches.keys())[1:]:
            test_matches[key]+=1 
    elif i in test_ids[i][0:10]:
        for key in list(test_matches.keys())[2:]:
            test_matches[key]+=1
    elif i in test_ids[i][0:50]:
        for key in list(test_matches.keys())[3:]:
            test_matches[key]+=1
    elif i in test_ids[i][0:100]:
        for key in list(test_matches.keys())[4:]:
            test_matches[key]+=1
    elif i in test_ids[i][0:test_top_1_per]:
        for key in list(test_matches.keys())[5:]:
            test_matches[key]+=1

for keys in test_matches.keys():
    test_matches[keys] = (test_matches[keys] * 100) / len(test_street_features)
    print(f"For CVACT Test, Top {keys} Accuracy: {test_matches[keys]}%")

For CVACT Test, Top 1 Accuracy: 70.55419661426062%
For CVACT Test, Top 5 Accuracy: 91.63802114201356%
For CVACT Test, Top 10 Accuracy: 93.81687697330848%
For CVACT Test, Top 50 Accuracy: 96.66706177735153%
For CVACT Test, Top 100 Accuracy: 97.46015667934613%
For CVACT Test, Top 928 Accuracy: 98.75755649184815%


In [16]:
import time
import copy

def calculate_scores(query_features, reference_features, step_size=10000, ranks=[1,5,10]):

    topk = copy.deepcopy(ranks)
    Q = len(query_features)
    R = len(reference_features)
    
    steps = Q // step_size + 1
    
    similarity = []
    
    for i in range(steps):
        
        start = step_size * i
        
        end = start + step_size
          
        sim_tmp = query_features[start:end] @ reference_features.T
        
        similarity.append(sim_tmp)
     
    # matrix Q x R
    #similarity = torch.cat(similarity, dim=0)
    

    topk.append(R//100)
    
    results = np.zeros([len(topk)])
    
    
    bar = tqdm(range(Q))
    
    for i in bar:
        step_idx = i // step_size
        local_idx = i % step_size

        # Access the correct similarity chunk and row
        sim_chunk = similarity[step_idx]
        gt_sim = sim_chunk[local_idx, i]  # Ground truth similarity
        
        # number of references with higher similiarity as gt
        higher_sim = sim_chunk[local_idx, :] > gt_sim
        
         
        ranking = higher_sim.sum()
        for j, k in enumerate(topk):
            if ranking < k:
                results[j] += 1.
                        
        
    results = results/ Q * 100.
 
    
    bar.close()
    
    # wait to close pbar
    time.sleep(0.1)
    
    string = []
    for i in range(len(topk)-1):
        
        string.append('Recall@{}: {:.4f}'.format(topk[i], results[i]))
        
    string.append('Recall@top1: {:.4f}'.format(results[-1]))            
        
    print(' - '.join(string)) 

    return results[0]

In [17]:
calculate_scores(val_street_features, val_sat_features)

100%|██████████| 8883/8883 [00:00<00:00, 26178.70it/s]


Recall@1: 90.0371 - Recall@5: 96.3976 - Recall@10: 97.2194 - Recall@top1: 98.8292


90.03714961161769

In [18]:
calculate_scores(test_street_features, test_sat_features)

100%|██████████| 92801/92801 [00:07<00:00, 11643.66it/s]


Recall@1: 71.0736 - Recall@5: 91.6628 - Recall@10: 93.8255 - Recall@top1: 98.7576


71.07358756909947

In [19]:
def validate(dist_array, top_k):
    accuracy = 0.0
    data_amount = 0.0
    for i in range(dist_array.shape[0]):
        gt_dist = dist_array[i,i]
        prediction = torch.sum(dist_array[:, i] < gt_dist) 
        if prediction < top_k:
            accuracy += 1.0
        data_amount += 1.0
    accuracy /= data_amount

    return accuracy

In [20]:
print('   compute accuracy')
dist_array = 2.0 - 2.0 * torch.matmul(val_street_features, val_sat_features.T)

top1_percent = int(dist_array.shape[0] * 0.01) + 1
val_accuracy = torch.zeros((1, top1_percent)).cuda()

print('start')

print('top1', ':', validate(dist_array, 1))
print('top5', ':', validate(dist_array, 5))
print('top10', ':', validate(dist_array, 10))
print('top1%', ':', validate(dist_array, top1_percent))

   compute accuracy
start
top1 : 0.9005966452774964
top5 : 0.9636384104469211
top10 : 0.9737701227062929
top1% : 0.9882922436113926


In [21]:
def validate_batched(street_features, sat_features, top_k, batch_size=1024):
    accuracy = 0.0
    data_amount = 0.0
    num_samples = street_features.shape[0]
    
    for i in range(0, num_samples, batch_size):
        # Get a batch of street features
        street_batch = street_features[i:min(i+batch_size, num_samples)]
        
        # Compute distances for the current batch
        dist_batch = 2.0 - 2.0 * torch.matmul(street_batch, sat_features.T)
        
        # Iterate over the batch
        for j in range(dist_batch.shape[0]):
            global_index = i + j  # Get the correct global index
            
            # Ensure global_index is within bounds
            if global_index >= num_samples:
                break
            
            # Ground truth distance is the diagonal element in the full matrix
            gt_dist = dist_batch[j, global_index]

            # Calculate how many distances are smaller than the ground truth distance
            prediction = torch.sum(dist_batch[j, :] < gt_dist).item()

            if prediction < top_k:
                accuracy += 1.0
            data_amount += 1.0

    accuracy /= data_amount
    return accuracy

# Assuming test_street_features and test_sat_features are the feature arrays
batch_size = 1024
top1_percent = int(test_street_features.shape[0] * 0.01) + 1

print('   compute accuracy')
print('top1', ':', validate_batched(test_street_features, test_sat_features, 1, batch_size))
print('top5', ':', validate_batched(test_street_features, test_sat_features, 5, batch_size))
print('top10', ':', validate_batched(test_street_features, test_sat_features, 10, batch_size))
print('top1%', ':', validate_batched(test_street_features, test_sat_features, top1_percent, batch_size))


   compute accuracy
top1 : 0.7107358756909947
top5 : 0.9166280535770088
top10 : 0.9382549757006928
top1% : 0.9875863406644325
