In [1]:
import json
import pickle
import sys

import clip
import numpy as np
import pandas as pd
import torch
from PIL import Image
from tqdm import tqdm

from geoscreens.geo_utils import gcd_threshold_eval, vectorized_gc_distance

In [2]:
model_name = "ViT-B/32"
# pretrained_path = "/shared/gbiamby/im2gps_kb/lib/open_clip/logs/lr=1e-06_wd=0.1_agg=True_model=ViT-B32_batchsize=96_workers=4_date=2022-03-15-04-33-55/epoch_2.pt"
# model_name = "RN50x16"
# pretrained_path = "/shared/gbiamby/geo/models/clip/ft_with_binary_classifier/best.ckpt"
device = "cuda"
model, preprocess = clip.load(model_name, device=device)

In [3]:
if pretrained_path:
    if ".ckpt" in pretrained_path:
        state_dict = torch.load(pretrained_path)["model"]
    else:
        state_dict = torch.load(pretrained_path)["state_dict"]
    for key in list(state_dict.keys()):
        state_dict[key.replace("module.", "")] = state_dict.pop(key)
        state_dict[key.replace("model.", "")] = state_dict.pop(key)
    model.load_state_dict(state_dict, strict=False)

NameError: name 'pretrained_path' is not defined

In [4]:
def embed_image(images):
    images = images.to(device)
    with torch.no_grad():
        image_embedddings = model.encode_image(images)
        image_embedddings /= image_embedddings.norm(dim=-1, keepdim=True)
    return image_embedddings


def get_image_embeddings(image_paths, batch_size=100):
    image_embeddings = []
    images = []
    for i in tqdm(range(0, len(image_paths), batch_size)):
        batch_images = [Image.open(image_path) for image_path in image_paths[i : i + batch_size]]
        images_tensor = torch.vstack([preprocess(image).unsqueeze(0) for image in batch_images])
        image_embeddings.append(embed_image(images_tensor))
        images.extend(batch_images)
    image_embeddings = torch.vstack(image_embeddings)
    return images, image_embeddings

CLIP Nearest Neighbors

In [5]:
reference = json.load(open("/shared/g-luo/geoguessr/data/placing2014/placing2014_no_indoor.json"))
streetview = pd.read_csv("/shared/g-luo/geoguessr/data/streetview/val/val.csv")
streetview = streetview.to_dict(orient="records")

In [7]:
# Get flickr embeddings
flickr_embeddings = pickle.load(
    open(
        # "/shared/gbiamby/geo/models/clip_ft_contrastive_00/vit-b32/placing2014_reference.pkl", "rb"
        "/shared/gbiamby/geo/models/clip_ft/vit-b32_zeroshot/placing2014_reference_image.pkl", "rb"
    )
)

# Get streetview embeddings
folder = "/shared/g-luo/geoguessr/data/streetview/val/cutter"
streetview_image_paths = [f"{folder}/{s['IMG_ID']}" for s in streetview]
streetview_images, streetview_embeddings = get_image_embeddings(streetview_image_paths)

100%|██████████| 10/10 [01:19<00:00,  7.96s/it]


In [8]:
r_ids = list(flickr_embeddings.keys())
s_ids = [s["IMG_ID"] for s in streetview]
flickr_embeddings = torch.vstack([torch.from_numpy(r) for r in flickr_embeddings.values()])

In [9]:
reference = {r["hash"]: r for r in reference}
streetview = {s["IMG_ID"]: s for s in streetview}

In [10]:
streetview_embeddings, flickr_embeddings = streetview_embeddings.to(device), flickr_embeddings.to(
    device
)
sims = {}
for s in tqdm(range(streetview_embeddings.shape[0])):
    embed_sims = streetview_embeddings[s] @ flickr_embeddings.T
    values, idxs = torch.topk(embed_sims, dim=-1, k=1)
    s_id = s_ids[s]
    r_id = r_ids[idxs.item()]
    sims[s_id] = (values.item(), r_id)

100%|██████████| 1000/1000 [00:05<00:00, 198.28it/s]


In [11]:
# For each index, get the ground truth GPS
latitudes, longitudes = [], []
latitudes_gt, longitudes_gt = [], []
for s_id, (prob, r_id) in sims.items():
    s = streetview[s_id]
    r = reference[r_id]

    latitudes.append(r["LAT"])
    longitudes.append(r["LON"])

    latitudes_gt.append(s["LAT"])
    longitudes_gt.append(s["LON"])

In [12]:
latitudes, longitudes, latitudes_gt, longitudes_gt = (
    torch.Tensor(latitudes),
    torch.Tensor(longitudes),
    torch.Tensor(latitudes_gt),
    torch.Tensor(longitudes_gt),
)
distances = vectorized_gc_distance(latitudes, longitudes, latitudes_gt, longitudes_gt)
gcd_threshold_eval(distances)

{1: 0.0,
 25: 0.01600000075995922,
 200: 0.0949999988079071,
 750: 0.3089999854564667,
 2500: 0.6050000190734863}