# Pittsburgh250k Good Sample

이 노트북은 `pitts250k_train.mat`를 읽고, 이미지 경로를 실제 폴더 구조에 맞춰 해석한 뒤 샘플 이미지를 확인합니다.

In [None]:
from pathlib import Path
import numpy as np
import scipy.io as sio
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
# Notebook location: Netvlad_vanila/goodsample.ipynb
import os

NB_DIR = Path.cwd()
if NB_DIR.name != "Netvlad_vanila":
    # Running from another directory: fall back to the project checkout.
    # NETVLAD_VANILA_DIR overrides the author's default location so the
    # notebook also works on other machines without editing this cell.
    NB_DIR = Path(
        os.environ.get(
            "NETVLAD_VANILA_DIR",
            "/home/lairpeteryksong/workspace/toy_Netvlad/Netvlad_vanila",
        )
    )

DATA_ROOT = NB_DIR / "data" / "Pittsburgh250k"
MAT_PATH = DATA_ROOT / "netvlad_v100_datasets" / "datasets" / "pitts250k_train.mat"

# Explicit checks rather than `assert`, which is skipped when Python runs with -O.
if not DATA_ROOT.exists():
    raise FileNotFoundError(f"Missing data root: {DATA_ROOT}")
if not MAT_PATH.exists():
    raise FileNotFoundError(f"Missing mat file: {MAT_PATH}")

DATA_ROOT, MAT_PATH

In [None]:
# Load the training-split definition and unpack the fields used by later cells.
mat = sio.loadmat(MAT_PATH, squeeze_me=True, struct_as_record=False)
db_struct = mat["dbStruct"]

# Relative image file names for the database and query sets.
db_image_fns, q_image_fns = (
    np.array(getattr(db_struct, _field)).tolist()
    for _field in ("dbImageFns", "qImageFns")
)

# UTM coordinates for database / query images.
utm_db, utm_q = np.array(db_struct.utmDb), np.array(db_struct.utmQ)

# Distance thresholds stored alongside the split.
pos_dist_thr = float(db_struct.posDistThr)
nontriv_pos_dist_sq_thr = float(db_struct.nonTrivPosDistSqThr)

for _label, _value in (
    ("num db images:", len(db_image_fns)),
    ("num query images:", len(q_image_fns)),
    ("utmDb shape:", utm_db.shape),
    ("utmQ shape:", utm_q.shape),
    ("posDistThr:", pos_dist_thr),
    ("nonTrivPosDistSqThr:", nontriv_pos_dist_sq_thr),
):
    print(_label, _value)
print("db sample:", db_image_fns[0])
print("q sample:", q_image_fns[0])

In [None]:
# Base directories tried, in order, when resolving database image paths.
db_roots = [DATA_ROOT.joinpath("images"), DATA_ROOT]

# Base directories tried, in order, when resolving query image paths
# (covers both flat and doubly nested "queries_real" extractions).
q_roots = [
    DATA_ROOT.joinpath(*_subdirs)
    for _subdirs in (("queries",), ("queries_real",), ("queries_real", "queries_real"))
]

def resolve_rel_path(rel_path: str, roots):
    """Resolve a dataset-relative image path against candidate root directories.

    Two layouts are tried, in order:
      1. ``root / rel_path`` for each root.
      2. A nested layout where the first path component is repeated,
         ``root / a / a / rest...`` (e.g. ``000/x.jpg`` -> ``root/000/000/x.jpg``),
         as observed in the current data extraction.

    Args:
        rel_path: Relative path as stored in the ``.mat`` file.
        roots: Iterable of base directories (``Path`` or ``str``), tried in order.

    Returns:
        The first candidate that is an existing regular file, or ``None``.
    """
    rel = Path(rel_path)
    # Iterated twice below; materialize so a generator argument is not exhausted.
    roots = list(roots)

    # Layout 1: plain root / rel.
    for root in roots:
        candidate = Path(root, rel)
        if candidate.is_file():
            return candidate

    # Layout 2: first directory duplicated; keep every remaining component,
    # not just parts[1], so deeper relative paths also resolve.
    if len(rel.parts) >= 2:
        first = rel.parts[0]
        for root in roots:
            candidate = Path(root, first, *rel.parts)
            if candidate.is_file():
                return candidate

    return None

# Sanity-check path resolution on a small prefix of each split before resolving everything.
SAMPLE_N = 100
resolved_db = [resolve_rel_path(p, db_roots) for p in db_image_fns[:SAMPLE_N]]
resolved_q = [resolve_rel_path(p, q_roots) for p in q_image_fns[:SAMPLE_N]]

db_ok = sum(x is not None for x in resolved_db)
q_ok = sum(x is not None for x in resolved_q)

# Report against the actual sample size (may be < SAMPLE_N for small splits).
print(f"resolved db (first {len(resolved_db)}): {db_ok}/{len(resolved_db)}")
print(f"resolved q  (first {len(resolved_q)}): {q_ok}/{len(resolved_q)}")
print("example resolved db:", next((str(x) for x in resolved_db if x is not None), None))
print("example resolved q :", next((str(x) for x in resolved_q if x is not None), None))

In [None]:
# Show the first resolvable DB and query images from the sanity-check sample side by side.
db_img_path = next((x for x in resolved_db if x is not None), None)
q_img_path = next((x for x in resolved_q if x is not None), None)
if db_img_path is None or q_img_path is None:
    # A bare next() would raise an uninformative StopIteration here.
    raise FileNotFoundError(
        f"No resolvable sample image (db: {db_img_path}, query: {q_img_path}); "
        "check db_roots / q_roots."
    )

# Decode inside `with` so the file handles are released immediately.
with Image.open(db_img_path) as _im:
    db_img = _im.convert("RGB")
with Image.open(q_img_path) as _im:
    q_img = _im.convert("RGB")

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# Use an explicit "\n" escape: a backslash at the end of a line inside a string
# literal is a line continuation and would drop the separator from the title.
axes[0].imshow(db_img)
axes[0].set_title(f"DB sample\n{db_img_path.name}")
axes[0].axis("off")

axes[1].imshow(q_img)
axes[1].set_title(f"Query sample\n{q_img_path.name}")
axes[1].axis("off")

plt.tight_layout()

## PyTorch Dataset (Triplet)

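아래 셀들은 NetVLAD 학습 방식과 동일하게 `(query, positive, negative)` 삼중항을 반환하는 PyTorch Dataset을 만듭니다. positive는 쿼리와의 제곱 거리가 1 초과 `nonTrivPosDistSqThr` 이하인 DB 이미지 중에서, negative는 `posDistThr` 반경 밖의 DB 이미지 중에서 무작위로 선택합니다.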


In [None]:
import random
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from scipy.spatial import cKDTree


In [None]:
# Resolve every DB / query entry to a real file on disk (None where it could not be found).
resolved_db_all = [resolve_rel_path(fn, db_roots) for fn in db_image_fns]
resolved_q_all = [resolve_rel_path(fn, q_roots) for fn in q_image_fns]

# Indices of the entries that resolved successfully.
db_valid_ids = np.flatnonzero([path is not None for path in resolved_db_all]).astype(np.int64)
q_valid_ids = np.flatnonzero([path is not None for path in resolved_q_all]).astype(np.int64)

print('resolved db:', len(db_valid_ids), '/', len(db_image_fns))
print('resolved q :', len(q_valid_ids), '/', len(q_image_fns))


In [None]:
# Build per-query positive / non-trivial-positive sets from UTM coordinates with a KD-tree.
# Keep float64: UTM coordinates are large (eastings ~1e5-1e6 m, northings up to ~1e7 m),
# and with float32's 24-bit mantissa the representable spacing at those magnitudes is
# 0.0625-1 m, which would distort the d^2 > 1 cut-off and the radius boundaries.
db_xy = utm_db.T.astype(np.float64)  # (N_db, 2)
q_xy = utm_q.T.astype(np.float64)    # (N_q, 2)

tree = cKDTree(db_xy)
pos_r = float(pos_dist_thr)
nontriv_r = float(np.sqrt(nontriv_pos_dist_sq_thr))

# Criteria used here:
# - potential positive (excluded from negatives): d <= posDistThr
# - non-trivial positive (training positive):     1 < d^2 <= nonTrivPosDistSqThr
all_pos_ids = tree.query_ball_point(q_xy, r=pos_r)
all_nontriv_ids = tree.query_ball_point(q_xy, r=nontriv_r)

valid_db_set = set(db_valid_ids.tolist())
valid_q_set = set(q_valid_ids.tolist())

query_to_pos = {}
query_to_excluded = {}
for qid in q_valid_ids.tolist():
    # DB images too close to the query to serve as negatives.
    excluded = {i for i in all_pos_ids[qid] if i in valid_db_set}

    cand = np.asarray(all_nontriv_ids[qid], dtype=np.int64)
    if cand.size == 0:
        continue
    d2 = np.sum((db_xy[cand] - q_xy[qid]) ** 2, axis=1)
    # Re-check the squared radius explicitly (guards the float boundary of the sqrt radius).
    keep = (d2 > 1.0) & (d2 <= nontriv_pos_dist_sq_thr)
    nontriv = [i for i in cand[keep].tolist() if i in valid_db_set]

    # Only queries with at least one usable positive are trainable.
    if nontriv:
        query_to_pos[qid] = np.array(nontriv, dtype=np.int64)
        query_to_excluded[qid] = excluded

print('trainable queries (has non-trivial pos):', len(query_to_pos))


In [None]:
class Pitts250kTripletDataset(Dataset):
    """Dataset yielding (query, positive, negative) image triplets.

    Args:
        resolved_db_paths: Sequence indexed by DB id; each entry is an image
            path, or ``None`` when the image could not be resolved.
        resolved_q_paths: Same, indexed by query id.
        query_to_pos: Maps query id -> array of DB ids usable as positives.
            Its keys define which queries the dataset iterates over.
        query_to_excluded: Maps query id -> set of DB ids that must never be
            drawn as negatives (e.g. all DB images near the query).
        transform: Callable applied to each RGB PIL image. Defaults to a
            224x224 resize followed by ``ToTensor``.
    """

    def __init__(self, resolved_db_paths, resolved_q_paths, query_to_pos, query_to_excluded, transform=None):
        self.resolved_db_paths = resolved_db_paths
        self.resolved_q_paths = resolved_q_paths
        self.query_to_pos = query_to_pos
        self.query_to_excluded = query_to_excluded
        self.query_ids = sorted(query_to_pos.keys())

        # DB ids that have an image on disk; negatives are drawn from these.
        self.valid_db_ids = np.array([
            i for i, p in enumerate(resolved_db_paths) if p is not None
        ], dtype=np.int64)

        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
            ])
        self.transform = transform

    def __len__(self):
        return len(self.query_ids)

    def _sample_negative(self, qid):
        """Return a random valid DB id not in ``query_to_excluded[qid]``.

        Uses cheap rejection sampling first, then an exhaustive scan.

        Raises:
            RuntimeError: if no valid DB id qualifies as a negative.
        """
        excluded = self.query_to_excluded[qid]
        # np.random.choice raises ValueError on an empty array; report the real cause instead.
        if len(self.valid_db_ids) == 0:
            raise RuntimeError('No valid negative candidate found for query id: {}'.format(qid))

        for _ in range(100):
            neg_id = int(np.random.choice(self.valid_db_ids))
            if neg_id not in excluded:
                return neg_id

        # Rare: rejection sampling kept hitting excluded ids, so choose from the remainder.
        # Use numpy's RNG here too so a single seed controls all sampling.
        candidates = [i for i in self.valid_db_ids.tolist() if i not in excluded]
        if not candidates:
            raise RuntimeError('No valid negative candidate found for query id: {}'.format(qid))
        return int(np.random.choice(candidates))

    @staticmethod
    def _load_rgb(path):
        """Open an image file, convert it to RGB, and release the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return a dict with transformed query/positive/negative images and their ids."""
        qid = self.query_ids[idx]

        pos_candidates = self.query_to_pos[qid]
        pos_id = int(np.random.choice(pos_candidates))
        neg_id = self._sample_negative(qid)

        q = self.transform(self._load_rgb(self.resolved_q_paths[qid]))
        p = self.transform(self._load_rgb(self.resolved_db_paths[pos_id]))
        n = self.transform(self._load_rgb(self.resolved_db_paths[neg_id]))

        return {
            'query': q,
            'positive': p,
            'negative': n,
            'query_id': qid,
            'positive_id': pos_id,
            'negative_id': neg_id,
        }


In [None]:
# Assemble the triplet dataset from the mined positives/exclusions and pull one batch as a smoke test.
loader_kwargs = dict(batch_size=4, shuffle=True, num_workers=0)
dataset = Pitts250kTripletDataset(
    resolved_db_paths=resolved_db_all,
    resolved_q_paths=resolved_q_all,
    query_to_pos=query_to_pos,
    query_to_excluded=query_to_excluded,
)
loader = DataLoader(dataset, **loader_kwargs)

print('dataset size:', len(dataset))
batch = next(iter(loader))
for _key, _label in (
    ('query', 'query shape   :'),
    ('positive', 'positive shape:'),
    ('negative', 'negative shape:'),
):
    print(_label, batch[_key].shape)
