In [1]:
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from geopy.distance import geodesic as GD
from torch.utils.data import DataLoader, random_split

from config import cfg

from geo_clip import GeoCLIP, img_val_transform
from dataset.dataset import GeoCLIPDataModule

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
def distance_accuracy(targets, preds, dis=2500, gps_gallery=None):
    total = len(targets)
    correct = 0
    gd_avg = 0

    for i in range(total):
        gd = GD(gps_gallery[preds[i]], targets[i]).km
        gd_avg += gd
        if gd <= dis:
            correct += 1

    gd_avg /= total
    return correct / total, gd_avg

In [4]:
model = GeoCLIP(cfg)
state_dict = torch.load(cfg.MODEL.CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(state_dict['state_dict'])
model.to(device)
model.eval()

GeoCLIP(
  (image_encoder): ImageEncoder(
    (CLIP): CLIPModel(
      (text_model): CLIPTextTransformer(
        (embeddings): CLIPTextEmbeddings(
          (token_embedding): Embedding(49408, 768)
          (position_embedding): Embedding(77, 768)
        )
        (encoder): CLIPEncoder(
          (layers): ModuleList(
            (0-11): 12 x CLIPEncoderLayer(
              (self_attn): CLIPAttention(
                (k_proj): Linear(in_features=768, out_features=768, bias=True)
                (v_proj): Linear(in_features=768, out_features=768, bias=True)
                (q_proj): Linear(in_features=768, out_features=768, bias=True)
                (out_proj): Linear(in_features=768, out_features=768, bias=True)
              )
              (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
              (mlp): CLIPMLP(
                (activation_fn): QuickGELUActivation()
                (fc1): Linear(in_features=768, out_features=3072, bias=True)
            

In [5]:
dataset = GeoCLIPDataModule(dataset_file=cfg.DATA.EVAL_DATASET_FILE, transform=img_val_transform())

val_dataloader = DataLoader(dataset, pin_memory=True, batch_size=cfg.VALIDATION.BATCH_SIZE, shuffle=False, num_workers=cfg.VALIDATION.NUM_WORKERS, persistent_workers=True)

Loading image paths and coordinates: 4536it [00:00, 26360.81it/s]


In [6]:
gps_gallery = model.gps_gallery.to(device)

preds = []
targets = []

with torch.no_grad():
    for imgs, labels in tqdm(val_dataloader, desc="Evaluating"):
        labels = labels.numpy()
        imgs = imgs.numpy()

        imgs = torch.tensor(imgs, dtype=torch.float32, device=device)
        labels = torch.tensor(labels, dtype=torch.float32, device=device)

        logits_per_image = model(imgs, gps_gallery)
        probs = logits_per_image.softmax(dim=-1)

        output = torch.argmax(probs, dim=-1)

        preds.append(output.cpu().numpy())
        targets.append(labels.cpu().numpy())

preds = np.concatenate(preds, axis=0)
targets = np.concatenate(targets, axis=0)

Evaluating: 100%|██████████| 284/284 [35:35<00:00,  7.52s/it]


In [9]:
distance_thresholds = [2500, 750, 200, 25, 1] # km
accuracy_results = {}
for dis in distance_thresholds:
    acc, avg_distance_error = distance_accuracy(targets, preds, dis, gps_gallery)
    print(f"Accuracy at {dis} km: {round(acc*100, 2)}")
    accuracy_results[f'acc_{dis}_km'] = acc

Accuracy at 2500 km: 67.77
Accuracy at 750 km: 44.44
Accuracy at 200 km: 21.94
Accuracy at 25 km: 9.08
Accuracy at 1 km: 1.21
