In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from lightning import LightningModule, Trainer, LightningDataModule
from collections import OrderedDict
from PIL import Image
from glob import glob
import joblib
import pandas as pd
import numpy as np
import faiss
from tqdm import tqdm
from torchvision import transforms
from fastervit import create_model
import timm
from torch.cuda.amp import autocast
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from distill_transforms import get_transforms_val
DIM = 1024

  if not hasattr(numpy, tp_name):
  if not hasattr(numpy, tp_name):
  "lr_options": generate_power_seq(LEARNING_RATE_CIFAR, 11),
  contrastive_task: Union[FeatureMapContrastiveTask] = FeatureMapContrastiveTask("01, 02, 11"),
  self.nce_loss = AmdimNCELoss(tclip)
  check_for_updates()


In [14]:
prob_rotate: float = 0.75          # rotates the sat image and ground images simultaneously
prob_flip: float = 0.5             # flipping the sat image and ground images simultaneously
student_sat_size: tuple = (384, 384)
student_street_size: tuple = (384, 768)

sat_transforms_val_student, street_transforms_val_student = get_transforms_val(student_sat_size, student_street_size)

In [15]:
class WeightDistillEvalSet(Dataset):
    def __init__(self, img_type = 'reference', same_area = True):
        
        self.same_area = same_area
        self.img_type = img_type
        
        self.cities = []
        
        if same_area:
            self.cities = ['Chicago', 'NewYork', 'SanFrancisco', 'Seattle']
        else:
            self.cities = ['Chicago', 'SanFrancisco'] 
               
        # load sat list 
        sat_list = []
        for city in self.cities:
            df_tmp = pd.read_csv(f'VIGOR/splits/{city}/satellite_list.txt', header=None, sep='\s+')
            df_tmp = df_tmp.rename(columns={0: "sat"})
            df_tmp["path"] = df_tmp.apply(lambda x: f'VIGOR/{city}/satellite/{x.sat}', axis=1)
            sat_list.append(df_tmp)
        self.df_sat = pd.concat(sat_list, axis=0).reset_index(drop=True)
        
        # idx for complete train and test independent of mode = train or test
        sat2idx = dict(zip(self.df_sat.sat, self.df_sat.index))
        self.idx2sat = dict(zip(self.df_sat.index, self.df_sat.sat))
        self.idx2sat_path = dict(zip(self.df_sat.index, self.df_sat.path))
        
        
        # ground dependent on mode 'train' or 'test'
        ground_list = []
        for city in self.cities:
            
            if same_area:
                df_tmp = pd.read_csv(f'VIGOR/splits/{city}/cross_area_balanced_test.txt', header=None, sep='\s+')
            else:
                df_tmp = pd.read_csv(f'VIGOR/splits/{city}/pano_label_balanced.txt', header=None, sep='\s+')
  
            
            df_tmp = df_tmp.loc[:, [0, 1, 4, 7, 10]].rename(columns={0:  "ground",
                                                                     1:  "sat",
                                                                     4:  "sat_np1",
                                                                     7:  "sat_np2",
                                                                     10: "sat_np3"})
            
            df_tmp["path_ground"] = df_tmp.apply(lambda x: f'VIGOR/{city}/panorama/{x.ground}', axis=1)
            df_tmp["path_sat"] = df_tmp.apply(lambda x: f'VIGOR/{city}/satellite/{x.sat}', axis=1)
            
            df_tmp["path_sat_np1"] = df_tmp.apply(lambda x: f'VIGOR/{city}/satellite/{x.sat_np1}', axis=1)
            df_tmp["path_sat_np2"] = df_tmp.apply(lambda x: f'VIGOR/{city}/satellite/{x.sat_np2}', axis=1)
            df_tmp["path_sat_np3"] = df_tmp.apply(lambda x: f'VIGOR/{city}/satellite/{x.sat_np3}', axis=1)

            
            for sat_n in ["sat", "sat_np1", "sat_np2", "sat_np3"]:
                df_tmp[f'{sat_n}'] = df_tmp[f'{sat_n}'].map(sat2idx)
                
            ground_list.append(df_tmp) 
        self.df_ground = pd.concat(ground_list, axis=0).reset_index(drop=True)
        
        # idx for split train or test dependent on mode = train or test
        self.idx2ground = dict(zip(self.df_ground.index, self.df_ground.ground))
        self.idx2ground_path = dict(zip(self.df_ground.index, self.df_ground.path_ground))
        
        
        if self.img_type == "reference":
            # all sat images of cities in split
            self.images = self.df_sat["path"].values
            self.label = self.df_sat.index.values
            
        else: #img_type == "query":
            self.images = self.df_ground["path_ground"].values
            self.label = self.df_ground[["sat", "sat_np1", "sat_np2", "sat_np3"]].values

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.label[idx]

        # Open and process the image
        with Image.open(img_path).convert('RGB') as img:
            if self.img_type == 'reference':
                img = sat_transforms_val_student(image= np.asarray(img))['image']
            else:
                img = street_transforms_val_student(image= np.asarray(img))['image']
        
        return label, img

# Usage
query_set = WeightDistillEvalSet(img_type= 'query', same_area= False)
reference_set = WeightDistillEvalSet(img_type= 'reference', same_area= False)

query_loader = torch.utils.data.DataLoader(query_set, batch_size= 64, num_workers = 8, shuffle=False, drop_last=False)
reference_loader = torch.utils.data.DataLoader(reference_set, batch_size= 64, num_workers = 8, shuffle=False, drop_last=False)

In [16]:
# class ContrastiveModel(LightningModule):
#     def __init__(self):
#         super().__init__()

#         # self.teacher_model = TimmModel('convnext_base.fb_in22k_ft_in1k_384', pretrained=True, img_size= 384)
#         # self.teacher_model.load_state_dict(torch.load("pretrained/cvact/convnext_base.fb_in22k_ft_in1k_384/weights_e36_90.8149.pth", weights_only=True))
#         # self.teacher_model.eval()
#         # for param in self.teacher_model.parameters():
#         #     param.requires_grad = False
            
#         self.sat_model = create_model('faster_vit_0_224', pretrained=True, model_path="/tmp/faster_vit_0.pth.tar")
#         self.sat_linear = nn.Linear(1000, 1024)
        
#         self.street_model = create_model('faster_vit_0_224', pretrained=True, model_path="/tmp/faster_vit_0.pth.tar")
#         self.street_linear = nn.Linear(1000, 1024)
        
#         self.loss = nn.CosineEmbeddingLoss()

#     def forward(self, teacher_sat_feat, teacher_street_feat, stud_sat_feat, stud_street_feat):
#         student_sat_features = self.sat_linear(self.sat_model(stud_sat_feat))
#         student_street_features = self.street_linear(self.street_model(stud_street_feat))
        
#         # with autocast():
#         #     teacher_sat_features = self.teacher_model(teacher_sat_feat)
#         #     teacher_street_features = self.teacher_model(teacher_street_feat)
        
#         loss_sat = self.loss(student_sat_features, teacher_sat_features, torch.ones(student_sat_features.shape[0], device= self.device))
#         loss_street = self.loss(student_street_features, teacher_street_features, torch.ones(student_street_features.shape[0], device= self.device))
        
#         return (loss_sat + loss_street) / 2

    
#     def training_step(self, batch, batch_idx):
#         teacher_sat_feat, teacher_street_feat, stud_sat_feat, stud_street_feat = batch
#         loss = self(teacher_sat_feat, teacher_street_feat, stud_sat_feat, stud_street_feat)

#         #n = logits.size(0)
	
#         # -1 for off-diagonals and 1 for diagonals
#         #labels = 2 * torch.eye(n, device=logits.device) - 1
        
#         # pairwise sigmoid loss
#         #loss= -torch.sum(F.logsigmoid(labels * logits)) / n
    
#         self.log("train_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
#         return loss

#     def validation_step(self, batch, batch_idx):
#         teacher_sat_feat, teacher_street_feat, stud_sat_feat, stud_street_feat = batch
#         loss = self(teacher_sat_feat, teacher_street_feat, stud_sat_feat, stud_street_feat)

#         #n = logits.size(0)
	
#         # # -1 for off-diagonals and 1 for diagonals
#         #labels = 2 * torch.eye(n, device=logits.device) - 1
        
#         # # pairwise sigmoid loss
#         #loss= -torch.sum(F.logsigmoid(labels * logits)) / n

#         self.log("val_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True)
            
#     def configure_optimizers(self):
#         optimizer = torch.optim.AdamW(self.parameters(), lr= 1e-4)
#         scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs = 1, max_epochs = 512)
#         return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}] #optimizer

#     def on_save_checkpoint(self, checkpoint):
#         # Exclude teacher model's weights from checkpoint
#         checkpoint['state_dict'] = {k: v for k, v in checkpoint['state_dict'].items() if 'teacher_model' not in k}

In [17]:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# distill_model = ContrastiveModel.load_from_checkpoint("./logs/KD_EMBEDDING_CD_ONLY_VIGOR_CROSS/version_0/checkpoints/tinyvit.ckpt")
# distill_model = distill_model.to(device)
# distill_model.eval()

In [18]:
class TimmModel(nn.Module):

    def __init__(self, 
                 model_name,
                 pretrained=True,
                 img_size=383):
                 
        super(TimmModel, self).__init__()
        
        self.img_size = img_size
        
        if "vit" in model_name:
            # automatically change interpolate pos-encoding to img_size
            self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=0, img_size=img_size) 
        else:
            self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        
        
    def get_config(self,):
        data_config = timm.data.resolve_model_data_config(self.model)
        return data_config
    
    
    def set_grad_checkpointing(self, enable=True):
        self.model.set_grad_checkpointing(enable)

        
    def forward(self, img1, img2=None):
        
        if img2 is not None:
       
            image_features1 = self.model(img1)     
            image_features2 = self.model(img2)
            
            return image_features1, image_features2
              
        else:
            image_features = self.model(img1)
             
            return image_features

In [19]:
model = TimmModel('convnext_base.fb_in22k_ft_in1k_384', pretrained=True, img_size= 384)
model.load_state_dict(torch.load("pretrained/vigor_cross/convnext_base.fb_in22k_ft_in1k_384/weights_e40_0.6109.pth", weights_only=True))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

  return self.fget.__get__(instance, owner)()


TimmModel(
  (model): ConvNeXt(
    (stem): Sequential(
      (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((128,), eps=1e-06, elementwise_affine=True)
    )
    (stages): Sequential(
      (0): ConvNeXtStage(
        (downsample): Identity()
        (blocks): Sequential(
          (0): ConvNeXtBlock(
            (conv_dw): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
            (norm): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=128, out_features=512, bias=True)
              (act): GELU()
              (drop1): Dropout(p=0.0, inplace=False)
              (norm): Identity()
              (fc2): Linear(in_features=512, out_features=128, bias=True)
              (drop2): Dropout(p=0.0, inplace=False)
            )
            (shortcut): Identity()
            (drop_path): Identity()
          )
          (1): ConvNeXtBlock(
            (conv_dw):

In [20]:
def process_embeddings(sat_loader, street_loader, model, output_prefix, device):
    all_street_embed = []
    all_sat_embed = []
    all_sat_ids = []
    all_street_ids = []

    with autocast():
        with torch.no_grad():
            # Street Image Embeddings
            for (street_keys, street_images) in tqdm(street_loader):
                # street_features = model.street_linear(model.street_model(street_images.to(device)))
                street_features = model(street_images.to(device))
                all_street_embed.append(street_features.detach().cpu())
                
                all_street_ids.extend(street_keys.tolist())

            # Satellite Image Embeddings
            for (sat_keys, sat_images) in tqdm(sat_loader):
                # sat_features = model.sat_linear(model.sat_model(sat_images.to(device)))
                sat_features = model(sat_images.to(device))
                all_sat_embed.append(sat_features.detach().cpu())
                all_sat_ids.extend(sat_keys.tolist())

    # Save the embeddings
    joblib.dump(torch.vstack(all_street_embed), f'{output_prefix}_street_embed.joblib')
    joblib.dump(torch.vstack(all_sat_embed), f'{output_prefix}_sat_embed.joblib')
    joblib.dump(all_street_ids, f'{output_prefix}_all_street_ids.joblib')
    joblib.dump(all_sat_ids, f'{output_prefix}_all_sat_ids.joblib')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

process_embeddings(
    sat_loader= reference_loader,
    street_loader= query_loader,
    model=  model, # distill_model,
    output_prefix='MB1/vigor_cross_test',
    device=device
)

100%|██████████| 839/839 [07:57<00:00,  1.76it/s]
100%|██████████| 728/728 [03:31<00:00,  3.45it/s]


In [2]:
# Load the stored embeddings and IDs for s2d (satellite to drone)
query_features = joblib.load(open("MB1/vigor_cross_test_street_embed.joblib", 'rb'))
reference_features = joblib.load(open("MB1/vigor_cross_test_sat_embed.joblib", 'rb'))
query_ids = joblib.load(open('MB1/vigor_cross_test_all_street_ids.joblib', 'rb'))
reference_ids = joblib.load(open('MB1/vigor_cross_test_all_sat_ids.joblib', 'rb'))

In [3]:
#Normalizing features
query_features = F.normalize(query_features, dim=-1)
reference_features = F.normalize(reference_features, dim=-1)

In [4]:
import time
import copy

def calculate_scores(query_features, reference_features, query_labels, reference_labels, step_size=1000, ranks=[1,5,10]):

    topk = copy.deepcopy(ranks)
    Q = len(query_features)
    R = len(reference_features)
    
    steps = Q // step_size + 1
    
    query_labels_np = np.array(query_labels)
    reference_labels_np = np.array(reference_labels)
    
    #print(query_labels_np)
    
    ref2index = dict()
    for i, idx in enumerate(reference_labels_np):
        ref2index[idx] = i
    
    #print(ref2index)
    
    similarity = []
    
    for i in range(steps):
        
        start = step_size * i
        
        end = start + step_size
          
        sim_tmp = query_features[start:end] @ reference_features.T
        
        similarity.append(sim_tmp.cpu())
     
    # matrix Q x R
    similarity = torch.cat(similarity, dim=0)
    

    topk.append(R//100)
    
    results = np.zeros([len(topk)])
    
    hit_rate = 0.0
    
    bar = tqdm(range(Q))
    
    for i in bar:
        
        # similiarity value of gt reference
        gt_sim = similarity[i, ref2index[query_labels_np[i][0]]]
        
        # number of references with higher similiarity as gt
        higher_sim = similarity[i,:] > gt_sim
        
         
        ranking = higher_sim.sum()
        for j, k in enumerate(topk):
            if ranking < k:
                results[j] += 1.
                        
        # mask for semi pos
        mask = torch.ones(R)
        for near_pos in query_labels_np[i][1:]:
            mask[ref2index[near_pos]] = 0
        
        # calculate hit rate
        hit = (higher_sim * mask).sum()
        if hit < 1:
            hit_rate += 1.0
                
    
    results = results/ Q * 100.
    hit_rate = hit_rate / Q * 100
    
    bar.close()
    
    # wait to close pbar
    time.sleep(0.1)
    
    string = []
    for i in range(len(topk)-1):
        
        string.append('Recall@{}: {:.4f}'.format(topk[i], results[i]))
        
    string.append('Recall@top1: {:.4f}'.format(results[-1]))
    string.append('Hit_Rate: {:.4f}'.format(hit_rate))             
        
    print(' - '.join(string)) 

    return results[0]

In [5]:
calculate_scores(query_features, reference_features, query_ids, reference_ids)

100%|██████████| 53694/53694 [02:24<00:00, 370.93it/s] 


Recall@1: 61.7797 - Recall@5: 83.4879 - Recall@10: 88.0210 - Recall@top1: 98.1897 - Hit_Rate: 69.9166


61.779714679480016