In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from lightning import LightningModule, Trainer, LightningDataModule
from collections import OrderedDict
from PIL import Image
from glob import glob
import joblib
import pandas as pd
import numpy as np
import faiss
from tqdm import tqdm
from torchvision import transforms
from fastervit import create_model
import timm
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from distill_transforms import get_transforms_val
from torch.cuda.amp import autocast
DIM = 1024

  if not hasattr(numpy, tp_name):
  if not hasattr(numpy, tp_name):
  "lr_options": generate_power_seq(LEARNING_RATE_CIFAR, 11),
  contrastive_task: Union[FeatureMapContrastiveTask] = FeatureMapContrastiveTask("01, 02, 11"),
  self.nce_loss = AmdimNCELoss(tclip)
  check_for_updates()


In [2]:
df = pd.read_csv('CVUSA/val.csv')
val_image_pairs_details = df.values.tolist()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
prob_rotate: float = 0.75          # rotates the sat image and ground images simultaneously
prob_flip: float = 0.5             # flipping the sat image and ground images simultaneously
student_sat_size: tuple = (224, 224)
student_street_size: tuple = (224, 224)

sat_transforms_val_student, street_transforms_val_student = get_transforms_val(student_sat_size, student_street_size)

In [4]:
class WeightDistillEvalSet(Dataset):
    def __init__(self, image_pairs_details):
        self.image_pairs_details = image_pairs_details

    def __len__(self):
        return len(self.image_pairs_details)

    def __getitem__(self, idx):
        sat_image_path = 'CVUSA/'+ self.image_pairs_details[idx][0]
        street_image_path = 'CVUSA/'+ self.image_pairs_details[idx][1]

        # Open and process the satellite image
        with Image.open(sat_image_path) as sat_image:
            sat_img = sat_transforms_val_student(image= np.asarray(sat_image))['image']
        # Open and process the street image
        with Image.open(street_image_path) as street_image:
            street_img = street_transforms_val_student(image= np.asarray(street_image))['image']
        
        return sat_img, street_img

# Usage
val_testset = WeightDistillEvalSet(val_image_pairs_details)

val_testloader = torch.utils.data.DataLoader(val_testset, batch_size= 32, num_workers = 6, shuffle=False, drop_last=False)

In [5]:
def contrastive_loss(logits, dim):
    neg_ce = torch.diag(F.log_softmax(logits, dim=dim))
    return -neg_ce.sum()

def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(similarity, dim=0)
    image_loss = contrastive_loss(similarity, dim=1)
    return (caption_loss + image_loss) / 2.0

In [6]:
class ContrastiveModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.strict_loading = False
        # self.teacher_model = TimmModel('convnext_base.fb_in22k_ft_in1k_384', pretrained=True, img_size= 384)
        # self.teacher_model.load_state_dict(torch.load("pretrained/cvusa/convnext_base.fb_in22k_ft_in1k_384/weights_e40_98.6830.pth", weights_only=True))
        # self.teacher_model.eval()
        # for param in self.teacher_model.parameters():
        #     param.requires_grad = False
            
        self.sat_model = create_model('faster_vit_0_224', pretrained=True, model_path="/tmp/faster_vit_0.pth.tar")
        self.sat_linear = nn.Linear(1000, 1024)
        
        self.street_model = create_model('faster_vit_0_224', pretrained=True, model_path="/tmp/faster_vit_0.pth.tar")
        self.street_linear = nn.Linear(1000, 1024)
        
        # self.logit_scale = nn.Parameter(torch.tensor(np.log(1 / 0.07), dtype=torch.float32))
        
        # self.bias = nn.Parameter(torch.tensor(-10, dtype=torch.float32))
        
        self.loss = nn.CosineEmbeddingLoss()
                
    def forward(self, teacher_sat_feat, teacher_street_feat, stud_sat_feat, stud_street_feat):
        student_sat_features = self.sat_linear(self.sat_model(stud_sat_feat))
        student_street_features = self.street_linear(self.street_model(stud_street_feat))
        
        with autocast():
            teacher_sat_features = self.teacher_model(teacher_sat_feat)
            teacher_street_features = self.teacher_model(teacher_street_feat)
            
        # loss_sat = self.loss(student_sat_features, teacher_sat_features)
        # loss_street = self.loss(student_street_features, teacher_street_features)
        
        loss_sat = self.loss(student_sat_features, teacher_sat_features, torch.ones(student_sat_features.shape[0], device= self.device))
        loss_street = self.loss(student_street_features, teacher_street_features, torch.ones(student_street_features.shape[0], device= self.device))
        
        # normalized features
        # student_sat_features = F.normalize(student_sat_features, dim=-1)
        # student_street_features = F.normalize(student_street_features, dim=-1)
        
        # teacher_sat_features = F.normalize(teacher_sat_features, dim=-1)
        # teacher_street_features = F.normalize(teacher_street_features, dim=-1)

        # # cosine similarity as logits
        # logit_scale =  self.logit_scale.exp()

        # student_logits = logit_scale * student_sat_features @ student_street_features.t() + self.bias
        # teacher_logits = teacher_sat_features @ teacher_street_features.t()
        
        # n = student_logits.size(0)
	
        # # -1 for off-diagonals and 1 for diagonals
        # labels = 2 * teacher_logits - 1
        
        # # pairwise sigmoid loss
        # c_loss = -torch.sum(F.logsigmoid(labels * student_logits)) / n
        
        return (loss_sat + loss_street) / 2

    
    def training_step(self, batch, batch_idx):
        teacher_sat_feat, teacher_street_feat, stud_sat_feat, stud_street_feat = batch
        loss = self(teacher_sat_feat, teacher_street_feat, stud_sat_feat, stud_street_feat)
    
        #n = logits.size(0)
	
        # -1 for off-diagonals and 1 for diagonals
        #labels = 2 * torch.eye(n, device=logits.device) - 1
        
        # pairwise sigmoid loss
        #loss= -torch.sum(F.logsigmoid(labels * logits)) / n

        # alpha = self.current_epoch / self.trainer.max_epochs
        
        # loss = (1 - alpha) * mse_loss + alpha * c_loss
        
        self.log("train_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        teacher_sat_feat, teacher_street_feat, stud_sat_feat, stud_street_feat = batch
        loss = self(teacher_sat_feat, teacher_street_feat, stud_sat_feat, stud_street_feat)

        #n = logits.size(0)
	
        # # -1 for off-diagonals and 1 for diagonals
        #labels = 2 * torch.eye(n, device=logits.device) - 1
        
        # # pairwise sigmoid loss
        #loss= -torch.sum(F.logsigmoid(labels * logits)) / n

        self.log("val_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True)
            
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr= 1e-4)
        scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs = 1, max_epochs = 512)
        return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}] #optimizer

    def on_save_checkpoint(self, checkpoint):
        # Exclude teacher model's weights from checkpoint
        checkpoint['state_dict'] = {k: v for k, v in checkpoint['state_dict'].items() if 'teacher_model' not in k}

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
distill_model = ContrastiveModel.load_from_checkpoint("./logs/KD_EMBEDDING_CD_ONLY_CVUSA/version_0/checkpoints/tinyvit.ckpt")
distill_model = distill_model.to(device)
distill_model.eval()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


ContrastiveModel(
  (sat_model): FasterViT(
    (patch_embed): PatchEmbed(
      (proj): Identity()
      (conv_down): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU()
      )
    )
    (levels): ModuleList(
      (0): FasterViTLayer(
        (blocks): ModuleList(
          (0): ConvBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (act1): GELU(approximate='none')
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm

In [8]:
# model = TimmModel('convnext_base.fb_in22k_ft_in1k_384', pretrained=True, img_size= 384)
# model.load_state_dict(torch.load("pretrained/cvusa/convnext_base.fb_in22k_ft_in1k_384/weights_e40_98.6830.pth", weights_only=True))
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# model = model.to(device)
# model.eval()

In [9]:
all_street_embed = []
all_sat_embed = []

with torch.no_grad():
    for batch_data in tqdm(val_testloader):
        sat_img, street_img = batch_data
        
        # Street Image Embeddings
        street_features = distill_model.street_linear(distill_model.street_model(street_img.to(device)))
        all_street_embed.append(street_features.detach().cpu())

        # Original Satellite Image Embeddings
        sat_features = distill_model.sat_linear(distill_model.sat_model(sat_img.to(device)))
        all_sat_embed.append(sat_features.detach().cpu())
        

    # Save the embeddings
    joblib.dump(torch.vstack(all_street_embed), 'MB1/cvusa_street_embed.joblib')
    joblib.dump(torch.vstack(all_sat_embed), 'MB1/cvusa_sat_embed.joblib')

100%|██████████| 278/278 [00:22<00:00, 12.55it/s]


In [10]:
cvusa_sat_features = joblib.load(open("MB1/cvusa_sat_embed.joblib", 'rb'))
cvusa_street_features = joblib.load(open("MB1/cvusa_street_embed.joblib", 'rb'))

In [11]:
print(cvusa_sat_features.shape)
print(cvusa_street_features.shape)

torch.Size([8883, 1024])
torch.Size([8883, 1024])


In [12]:
cvusa_sat_features = F.normalize(cvusa_sat_features, dim=-1)
cvusa_street_features = F.normalize(cvusa_street_features, dim=-1)

In [13]:
res = faiss.StandardGpuResources()

cvusa_index = faiss.IndexFlatIP(DIM)
cvusa_index = faiss.index_cpu_to_gpu(res, 1, cvusa_index)

cvusa_index.add(cvusa_sat_features)

In [14]:
cvusa_top_1_per = int(0.01*len(cvusa_street_features))

_, cvusa_ids = cvusa_index.search(x=cvusa_street_features, k=cvusa_top_1_per)

In [15]:
cvusa_matches = {1: 0, 5: 0, 10: 0, 50: 0, 100: 0, cvusa_top_1_per: 0}

for i in range(len(cvusa_street_features)):
    if i == cvusa_ids[i][0]:
        for key in list(cvusa_matches.keys()):
            cvusa_matches[key]+=1  
    elif i in cvusa_ids[i][0:5]:
        for key in list(cvusa_matches.keys())[1:]:
            cvusa_matches[key]+=1 
    elif i in cvusa_ids[i][0:10]:
        for key in list(cvusa_matches.keys())[2:]:
            cvusa_matches[key]+=1
    elif i in cvusa_ids[i][0:50]:
        for key in list(cvusa_matches.keys())[3:]:
            cvusa_matches[key]+=1
    elif i in cvusa_ids[i][0:100]:
        for key in list(cvusa_matches.keys())[4:]:
            cvusa_matches[key]+=1
    elif i in cvusa_ids[i][0:cvusa_top_1_per]:
        for key in list(cvusa_matches.keys())[5:]:
            cvusa_matches[key]+=1

for keys in cvusa_matches.keys():
    cvusa_matches[keys] = (cvusa_matches[keys] * 100) / len(cvusa_street_features)
    print(f"For CVUSA, Top {keys} Accuracy: {cvusa_matches[keys]}%")
print()

For CVUSA, Top 1 Accuracy: 97.43329956095914%
For CVUSA, Top 5 Accuracy: 99.58347405155916%
For CVUSA, Top 10 Accuracy: 99.76359338061465%
For CVUSA, Top 50 Accuracy: 99.85365304514241%
For CVUSA, Top 100 Accuracy: 99.86491050320838%
For CVUSA, Top 88 Accuracy: 99.86491050320838%



In [16]:
import time
import copy

def calculate_scores(query_features, reference_features, step_size=500, ranks=[1,5,10]):

    topk = copy.deepcopy(ranks)
    Q = len(query_features)
    R = len(reference_features)
    
    steps = Q // step_size + 1
    
    similarity = []
    
    for i in range(steps):
        
        start = step_size * i
        
        end = start + step_size
          
        sim_tmp = query_features[start:end] @ reference_features.T
        
        similarity.append(sim_tmp)
     
    # matrix Q x R
    similarity = torch.cat(similarity, dim=0)
    

    topk.append(R//100)
    
    results = np.zeros([len(topk)])
    
    
    bar = tqdm(range(Q))
    
    for i in bar:
        
        # similiarity value of gt reference
        gt_sim = similarity[i, i]
        
        # number of references with higher similiarity as gt
        higher_sim = similarity[i,:] > gt_sim
        
         
        ranking = higher_sim.sum()
        for j, k in enumerate(topk):
            if ranking < k:
                results[j] += 1.
                        
        
    results = results/ Q * 100.
 
    
    bar.close()
    
    # wait to close pbar
    time.sleep(0.1)
    
    string = []
    for i in range(len(topk)-1):
        
        string.append('Recall@{}: {:.4f}'.format(topk[i], results[i]))
        
    string.append('Recall@top1: {:.4f}'.format(results[-1]))            
        
    print(' - '.join(string)) 

    return results[0]

In [17]:
calculate_scores(cvusa_street_features, cvusa_sat_features)

100%|██████████| 8883/8883 [00:00<00:00, 24298.23it/s]


Recall@1: 97.9174 - Recall@5: 99.5835 - Recall@10: 99.7636 - Recall@top1: 99.8649


97.9173702577958

In [18]:
def validate(dist_array, top_k):
    accuracy = 0.0
    data_amount = 0.0
    for i in range(dist_array.shape[0]):
        gt_dist = dist_array[i,i]
        prediction = torch.sum(dist_array[:, i] < gt_dist) 
        if prediction < top_k:
            accuracy += 1.0
        data_amount += 1.0
    accuracy /= data_amount

    return accuracy

In [19]:
print('   compute accuracy')
dist_array = 2.0 - 2.0 * torch.matmul(cvusa_street_features, cvusa_sat_features.T)

top1_percent = int(dist_array.shape[0] * 0.01) + 1
val_accuracy = torch.zeros((1, top1_percent)).cuda()

print('start')

print('top1', ':', validate(dist_array, 1))
print('top5', ':', validate(dist_array, 5))
print('top10', ':', validate(dist_array, 10))
print('top1%', ':', validate(dist_array, top1_percent))


   compute accuracy
start
top1 : 0.9800742992232354
top5 : 0.9954970167736125
top10 : 0.9972982100641675
top1% : 0.9986491050320837
