In [5]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from DELG_Class import DELG ,DELGBackbone,GeM, AttentionModule
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm
import heapq

# Load the DELG model

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the DELG network (global + local branches enabled, no pretrained
# weights requested here) and the linear head, then move both to `device`.
model = DELG(pretrained=False, use_global=True, use_local=True)
model = model.to(device)
classifier = torch.nn.Linear(2048, 18)
classifier = classifier.to(device)

# Restore the fine-tuned weights for both modules from a single checkpoint file.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed PyTorch supports it).
checkpoint = torch.load("delg_global_finetune_local_pretrained.pth", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference mode for layers such as dropout / batch-norm
classifier.load_state_dict(checkpoint['classifier_state_dict'])
classifier.eval()

# Prepare image dataset (all gallery/query images)

In [6]:
import os

def scan_dataset(root_dir, extensions=('.jpg', '.jpeg', '.png')):
    """
    Collect image files from a one-folder-per-class directory tree.

    root_dir: 'data/train' or 'data/val' or 'data/test'; every immediate
        sub-directory is treated as one class, anything else at that level
        is ignored.
    extensions: a suffix or an iterable of suffixes (matched
        case-insensitively) that mark a file as an image.
    Returns: (image_paths, labels) - two parallel lists with the path of each
        image and the name of the class folder it was found in. Classes and
        files are visited in sorted order, so the result is deterministic.
    Raises: whatever os.listdir raises when root_dir cannot be listed.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    image_paths = []
    labels = []

    for class_name in sorted(os.listdir(root_dir)):
        class_folder = os.path.join(root_dir, class_name)
        if not os.path.isdir(class_folder):
            continue
        for fname in sorted(os.listdir(class_folder)):
            if not fname.lower().endswith(suffixes):
                continue
            candidate = os.path.join(class_folder, fname)
            # Skip directories (or dangling links) that merely look like
            # images by name; they would fail later when opened.
            if not os.path.isfile(candidate):
                continue
            image_paths.append(candidate)
            labels.append(class_name)

    return image_paths, labels


In [7]:
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
import torch

class LandmarkDataset(Dataset):
    """
    Map-style dataset that loads images from disk by path.

    With labels, each item is (image, label) where label is a torch.long
    tensor; without labels, each item is (image, image_path) so callers can
    tell which file a result belongs to.
    """

    def __init__(self, image_paths, labels=None, transform=None):
        """
        image_paths: sequence of image file paths.
        labels: optional sequence parallel to image_paths. Each entry is
            passed to torch.tensor(..., dtype=torch.long), so it must be an
            integer class index; class-name strings (such as the ones
            scan_dataset returns) must be mapped to integers first.
            Use None for unlabeled sets.
        transform: optional callable applied to the RGB PIL image.
        """
        self.image_paths = image_paths
        self.labels = labels  # None for test set
        self.transform = transform

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image `idx`; return (image, label) or (image, image_path)."""
        img_path = self.image_paths[idx]
        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it deterministically once convert() has produced an
        # independent in-memory RGB copy.
        with Image.open(img_path) as source:
            img = source.convert("RGB")
        if self.transform:
            img = self.transform(img)

        if self.labels is not None:
            return img, torch.tensor(self.labels[idx], dtype=torch.long)
        return img, img_path  # return path for query images

In [None]:
# Deterministic evaluation preprocessing: resize, centre-crop to 224, convert
# to a tensor, then normalise each channel.
# NOTE(review): mean/std look like the usual ImageNet statistics - confirm they
# match what the model was trained with.
val_transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Point this at your own image root. Class names are discarded here because
# retrieval only needs the paths; with labels=None the dataset yields
# (image, path) pairs.
all_image_paths, _ = scan_dataset(root_dir="data/all_images")
dataset = LandmarkDataset(image_paths=all_image_paths, labels=None, transform=val_transform)
# shuffle=False keeps the batches in the same order as all_image_paths.
loader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
)


# Extract Global Descriptors

In [None]:
# Per-batch descriptor tensors and the matching file paths, in loader order.
global_features, image_paths = [], []

# Gradients are not needed for feature extraction.
with torch.no_grad():
    for imgs, paths in tqdm(loader, desc="Extracting global descriptors"):
        imgs = imgs.to(device)
        # Global branch output, L2-normalised along dim 1 (one row per image).
        # NOTE(review): original author annotated this as [B, 2048] - confirm against DELG.
        feats = F.normalize(model(imgs)['global'], p=2, dim=1)
        # Keep results on the CPU so the whole gallery need not fit in GPU memory.
        global_features.append(feats.cpu())
        image_paths.extend(paths)

# Stack the batches into one matrix whose row i belongs to image_paths[i].
global_features = torch.cat(global_features, dim=0)


# Top-k candidate retrieval

In [None]:
def get_topk_candidates(query_feat, all_feats, all_paths, k=10):
    """
    Return the k gallery entries most similar to a query descriptor.

    query_feat: 1-D descriptor tensor of length D.
    all_feats: 2-D tensor with one D-dimensional descriptor per row.
    all_paths: sequence of paths, where all_paths[i] belongs to all_feats[i].
    k: maximum number of candidates to return; clamped to the number of rows
        in all_feats so a small gallery does not raise.
    Returns: (paths, scores) - the selected paths ordered by descending cosine
        similarity, and a 1-D tensor with the matching similarity values.
    """
    # Broadcast the single query against every gallery row -> one score per row.
    sims = F.cosine_similarity(query_feat.unsqueeze(0), all_feats)
    # torch.topk raises when k exceeds the number of scores, so clamp it.
    k = min(k, sims.numel())
    best = torch.topk(sims, k=k)
    # Plain Python ints index any sequence type used for all_paths.
    return [all_paths[i] for i in best.indices.tolist()], best.values

# Demo query: use the first gallery image.
# Note: the query row is itself part of global_features, so expect the query
# image to show up among its own candidates (normally as the best match).
query_path, query_feat = image_paths[0], global_features[0]
topk_paths, topk_scores = get_topk_candidates(
    query_feat=query_feat,
    all_feats=global_features,
    all_paths=image_paths,
    k=10,
)
print("Query image:", query_path)
print("Top-k candidates:", topk_paths)

# Extract local descriptors for top-k + query

In [None]:
def extract_local_desc(image_paths):
    """
    Run the model's local branch over the given images.

    Relies on the notebook globals `model`, `device` and `val_transform`.

    image_paths: list of image file paths to process.
    Returns: (local_descs, attention_maps) - the 'descriptors' and 'attention'
        outputs of the local branch for every image, concatenated along
        dim 0 in input order and moved to the CPU.
        NOTE(review): the original author annotated these as [B, H*W, 2048]
        and [B, 1, H, W] respectively - confirm against the DELG class.
    """
    batches = DataLoader(
        LandmarkDataset(image_paths, labels=None, transform=val_transform),
        batch_size=16,
        shuffle=False,  # keep outputs aligned with image_paths
        num_workers=2,
    )

    desc_chunks, att_chunks = [], []
    with torch.no_grad():
        # The dataset also yields the paths; they are not needed here.
        for batch, _ in batches:
            local_out = model(batch.to(device))['local']
            desc_chunks.append(local_out['descriptors'].cpu())
            att_chunks.append(local_out['attention'].cpu())

    return torch.cat(desc_chunks, dim=0), torch.cat(att_chunks, dim=0)

# Local features for the query; the single path is wrapped in a list because
# the extractor expects a collection of paths.
query_local, query_att = extract_local_desc(image_paths=[query_path])
# Local features for every short-listed candidate, in topk_paths order.
topk_local, topk_att = extract_local_desc(image_paths=topk_paths)


# Re-ranking with local descriptors (simplified; no geometric verification)

In [None]:
# Simplified re-ranking: compute cosine similarity between local descriptors
def rerank(query_desc, topk_descs, topk_paths, top_n=5):
    """
    Re-order candidates by how well their local descriptors match the query's.

    The score of a candidate is: for every query descriptor take its best
    cosine similarity against any descriptor of the candidate, then average
    those best values over all query descriptors.

    query_desc: tensor whose last dimension is the descriptor dimension D;
        all leading dimensions are flattened into a list of descriptors.
    topk_descs: iterable of per-candidate tensors with the same layout
        (iterating a batched tensor yields one candidate at a time).
    topk_paths: sequence where topk_paths[i] belongs to the i-th candidate.
    top_n: maximum number of paths to return.
    Returns: list of at most top_n paths, best score first.
    """
    # reshape (not view) so non-contiguous descriptors are accepted too.
    # Normalising once up front turns cosine similarity into a plain matmul
    # and avoids materialising a [Q, K, D] broadcast tensor per candidate.
    query_unit = F.normalize(query_desc.reshape(-1, query_desc.size(-1)), p=2, dim=1)

    scores = []
    for i, desc in enumerate(topk_descs):
        cand_unit = F.normalize(desc.reshape(-1, desc.size(-1)), p=2, dim=1)
        sim_matrix = query_unit @ cand_unit.t()  # [Q, K] pairwise cosine similarities
        score = sim_matrix.max(dim=1).values.mean().item()  # best match per query point
        scores.append((score, topk_paths[i]))

    # Sort on the score only, so ties keep their incoming order.
    scores.sort(key=lambda pair: pair[0], reverse=True)
    return [path for _, path in scores[:top_n]]

# query_local holds a single image, so index 0 selects its descriptors;
# keep the five best candidates after local re-ranking.
topn_paths = rerank(
    query_desc=query_local[0],
    topk_descs=topk_local,
    topk_paths=topk_paths,
    top_n=5,
)
print("Top-n after re-ranking:", topn_paths)