In [None]:
# Step 1: Imports and Dataset Setup
import os
from glob import glob
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Paths (adjust if needed)
DATA_DIR = "data/Market-1501"
TRAIN_DIR, TEST_DIR, QUERY_DIR = (
    os.path.join(DATA_DIR, split_name)
    for split_name in ("bounding_box_train", "bounding_box_test", "query")
)

# Preprocessing applied to every image: resize to 256x128 (height x width),
# convert to a float tensor, then normalise each channel with the mean/std below.
_resize_step = transforms.Resize((256, 128))
_normalize_step = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = transforms.Compose([_resize_step, transforms.ToTensor(), _normalize_step])


In [None]:
class Market1501Dataset(Dataset):
    """Dataset over a flat directory of ``<personID>_c<cam>...jpg`` images.

    Args:
        img_dir: Directory that is globbed (non-recursively) for ``*.jpg`` files.
        transform: Optional callable applied to the RGB PIL image in ``__getitem__``.
        relabel: When True (default) person IDs are remapped to contiguous
            0-based indices computed from the IDs found in *this* directory,
            which is what a classification head needs. The mapping is
            per-instance, so labels from two instances (e.g. query and gallery)
            are NOT comparable; pass ``relabel=False`` to keep the raw person
            IDs parsed from the filenames when labels must match across splits.

    Attributes:
        img_paths: Sorted list of image paths.
        raw_labels: ``(person_id, cam_id)`` tuples parsed from the filenames.
        id_map: Mapping raw person ID -> label returned by ``__getitem__``.
        labels: ``(label, cam_id)`` tuples aligned with ``img_paths``.

    Raises:
        ValueError: If a filename does not follow the expected naming scheme.
    """

    def __init__(self, img_dir, transform=None, relabel=True):
        self.img_paths = sorted(glob(os.path.join(img_dir, '*.jpg')))
        self.transform = transform

        # Parse raw person / camera IDs once per file
        self.raw_labels = [
            self._parse_filename(os.path.basename(path)) for path in self.img_paths
        ]

        # Either normalize person IDs to 0-based indexing or keep them as-is
        unique_ids = sorted({pid for pid, _ in self.raw_labels})
        if relabel:
            self.id_map = {pid: idx for idx, pid in enumerate(unique_ids)}
        else:
            self.id_map = {pid: pid for pid in unique_ids}

        # Final labels
        self.labels = [(self.id_map[pid], cam_id) for pid, cam_id in self.raw_labels]

    @staticmethod
    def _parse_filename(filename):
        """Return ``(person_id, cam_id)`` parsed from a single filename."""
        try:
            person_id = int(filename.split('_')[0])
            # Camera ID is the single digit following the first 'c' in the name
            cam_id = int(filename[filename.index('c') + 1])
        except (ValueError, IndexError) as err:
            raise ValueError(
                f"Cannot parse person/camera ID from filename: {filename}"
            ) from err
        return person_id, cam_id

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # Open inside a context manager so the file handle is released as soon
        # as the pixel data has been copied by convert().
        with Image.open(self.img_paths[idx]) as raw_img:
            img = raw_img.convert('RGB')
        if self.transform:
            img = self.transform(img)
        person_id, cam_id = self.labels[idx]
        return img, person_id, cam_id


In [None]:
# Step 3: Load Dataset and Preview
train_dataset = Market1501Dataset(TRAIN_DIR, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)

# Pull one shuffled batch and report what came back
images, labels, cams = next(iter(train_loader))
for caption, value in (
    ("Batch of images:", images.shape),
    ("Person IDs:", labels[:10]),
    ("Camera IDs:", cams[:10]),
):
    print(caption, value)

# Show the first six images of the batch, undoing the per-channel normalisation
channel_std = np.array([0.229, 0.224, 0.225])
channel_mean = np.array([0.485, 0.456, 0.406])

fig, axes = plt.subplots(1, 6, figsize=(12, 4))
for col, ax in enumerate(axes):
    hwc = images[col].permute(1, 2, 0).numpy()       # CHW -> HWC for matplotlib
    ax.imshow(np.clip(hwc * channel_std + channel_mean, 0, 1))
    ax.set_title(f"ID: {labels[col].item()}")
    ax.axis('off')
fig.tight_layout()
plt.show()


In [None]:
# Verify normalized IDs
# Read the labels straight from the dataset's label list. Iterating the dataset
# itself would open, decode and transform every image just to discard it.
all_ids = [person_id for person_id, _ in train_dataset.labels]
unique_ids = sorted(set(all_ids))

print("Total unique person IDs:", len(unique_ids))
print("First 10 IDs:", unique_ids[:10])
print("Last 10 IDs:", unique_ids[-10:])
print("Are IDs normalized and continuous?", unique_ids == list(range(len(unique_ids))))


In [None]:
# Step 4: Re-ID Model with ResNet-50 backbone
import torch.nn as nn
from torchvision import models
from torchvision.models import ResNet50_Weights

class ReIDModel(nn.Module):
    """ResNet-50 backbone followed by an embedding layer and an ID classifier.

    Args:
        embedding_dim: Output size of the linear embedding layer.
        num_classes: Number of logits produced by the classifier layer.

    ``forward`` returns ``(embeddings, logits)``.
    """

    def __init__(self, embedding_dim=512, num_classes=751):
        super().__init__()

        # Backbone: torchvision ResNet-50 (IMAGENET1K_V1 weights) whose final
        # fully connected layer is swapped for an identity.
        backbone = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        backbone.fc = nn.Identity()
        self.base = backbone

        # Projection 2048 -> embedding_dim
        self.embedding = nn.Linear(2048, embedding_dim)

        # Identity classifier on top of the embedding
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        projected = self.embedding(self.base(x))      # (batch_size, embedding_dim)
        return projected, self.classifier(projected)  # logits feed the cross-entropy loss


In [None]:
# Step 5: Check model output
# Size the classifier from the identities actually present in the training set
# instead of hard-coding the count, so logits always cover every label the
# dataset can emit.
model = ReIDModel(embedding_dim=512, num_classes=len(train_dataset.id_map))
model.eval()

# One forward pass on a CPU batch to confirm the output shapes
sample_imgs, sample_labels, _ = next(iter(train_loader))
with torch.no_grad():
    embs, logits = model(sample_imgs)

print("Input image shape:", sample_imgs.shape)
print("Embedding shape:", embs.shape)
print("Logits shape (for classification):", logits.shape)


In [None]:
# Step 6: Losses, optimizer, scheduler
import torch.nn.functional as F

# Device setup comes first: the PyTorch optimizer docs ask for the model to be
# moved to its device before an optimizer is constructed over its parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Loss functions
cross_entropy_loss = nn.CrossEntropyLoss()
triplet_loss_fn = nn.TripletMarginLoss(margin=0.3)

# Optimizer (Adam works well)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)

# Scheduler (optional): multiplies the learning rate by 0.1 every 40 scheduler steps
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)


In [None]:
# Step 7 (updated): Training loop with average loss metrics
import time

EPOCHS = 10
use_triplet = True  # Toggle this as needed

model.train()

for epoch in range(EPOCHS):
    start = time.time()

    # Running sums, weighted by batch size so the epoch average is per sample
    total_loss = 0.0
    total_ce = 0.0
    total_tri = 0.0
    total_samples = 0

    for images, labels, _ in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        batch_size = images.size(0)

        # Forward pass
        embeddings, logits = model(images)

        # CrossEntropy loss
        ce_loss = cross_entropy_loss(logits, labels)

        # Triplet loss (stays at zero when disabled or when no triplet can be formed)
        tri_loss = torch.zeros((), device=device)
        if use_triplet:
            # Copy the labels to the host once; calling .item() per element in
            # the nested loops below would force a device sync on every access.
            batch_labels = labels.tolist()

            label_to_indices = {}
            for idx, label in enumerate(batch_labels):
                label_to_indices.setdefault(label, []).append(idx)

            # For every identity with at least two samples, use its first two
            # samples as (anchor, positive) and every other-identity sample in
            # the batch as a negative.
            # NOTE(review): with a plain shuffled loader, batches containing two
            # images of the same identity may be rare, leaving this term at zero
            # most of the time — confirm whether an identity-balanced sampler is intended.
            triplets = [
                (indices[0], indices[1], n_idx)
                for anchor_label, indices in label_to_indices.items()
                if len(indices) >= 2
                for n_idx, neg_label in enumerate(batch_labels)
                if neg_label != anchor_label
            ]

            if triplets:
                a, p, n = zip(*triplets)
                tri_loss = triplet_loss_fn(
                    embeddings[list(a)], embeddings[list(p)], embeddings[list(n)]
                )

        # Combine losses
        loss = ce_loss + tri_loss

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate scaled losses
        total_loss += loss.item() * batch_size
        total_ce += ce_loss.item() * batch_size
        total_tri += tri_loss.item() * batch_size
        total_samples += batch_size

    scheduler.step()

    # Guard the averages so an epoch that yielded no batches reports zeros
    # instead of raising ZeroDivisionError.
    denom = max(total_samples, 1)
    print(f"[Epoch {epoch+1}/{EPOCHS}] "
          f"Avg Loss: {total_loss / denom:.4f} | "
          f"Avg CE: {total_ce / denom:.4f} | "
          f"Avg Triplet: {total_tri / denom:.4f} | "
          f"Time: {time.time() - start:.1f}s")


In [None]:
# Step 8–10: Extract features, match, visualize top-5 results
from torchvision.utils import make_grid
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Re-load datasets (query + gallery)
query_dir = os.path.join(DATA_DIR, "query")
gallery_dir = os.path.join(DATA_DIR, "bounding_box_test")

query_dataset = Market1501Dataset(query_dir, transform=transform)
gallery_dataset = Market1501Dataset(gallery_dir, transform=transform)

# Queries are shuffled so a random one is picked; the gallery keeps dataset
# order so embedding row i corresponds to gallery_dataset.img_paths[i].
query_loader = DataLoader(query_dataset, batch_size=1, shuffle=True)
gallery_loader = DataLoader(gallery_dataset, batch_size=64, shuffle=False)

# Put model in eval mode
model.eval()

# Select a single query image
query_img, query_label, _ = next(iter(query_loader))
query_img = query_img.to(device)

# Embed the query and the whole gallery; L2-normalise so that a dot product
# equals cosine similarity.
with torch.no_grad():
    query_emb, _ = model(query_img)
    query_emb = F.normalize(query_emb, p=2, dim=1)

    gallery_embs = []
    for imgs, _, _ in gallery_loader:
        emb, _ = model(imgs.to(device))
        gallery_embs.append(F.normalize(emb, p=2, dim=1))
gallery_embs = torch.cat(gallery_embs)

# Cosine similarity (higher = more similar), shape: (num_gallery,)
similarities = torch.mm(query_emb, gallery_embs.T).squeeze(0)
top_k = min(5, similarities.numel())  # topk raises if k exceeds the gallery size
top5_indices = torch.topk(similarities, k=top_k).indices.cpu().numpy()

# Show top-5 matches
channel_mean = np.array([0.485, 0.456, 0.406])
channel_std = np.array([0.229, 0.224, 0.225])

plt.figure(figsize=(15, 3))
plt.subplot(1, 6, 1)
# Undo the normalisation with each channel's own statistics and clip to the
# [0, 1] range imshow expects for float images.
query_disp = query_img[0].permute(1, 2, 0).cpu().numpy() * channel_std + channel_mean
plt.imshow(np.clip(query_disp, 0, 1))
plt.title("Query")
plt.axis('off')

for i, idx in enumerate(top5_indices):
    path = gallery_dataset.img_paths[idx]
    img = Image.open(path).convert("RGB")
    plt.subplot(1, 6, i+2)
    plt.imshow(img)
    plt.title(f"Top {i+1}")
    plt.axis('off')

plt.tight_layout()
plt.show()


In [None]:
# Match multiple query images (Top-5 each)
num_queries_to_show = 5  # You can change this to 10, 20, etc.

model.eval()
shown = 0
query_loader_vis = DataLoader(query_dataset, batch_size=1, shuffle=True)

# Gallery embeddings do not depend on the query, so compute them once up front
# instead of once per displayed query.
gallery_embs = []
with torch.no_grad():
    for imgs, _, _ in gallery_loader:
        emb, _ = model(imgs.to(device))
        gallery_embs.append(F.normalize(emb, p=2, dim=1))
gallery_embs = torch.cat(gallery_embs)
gallery_paths = gallery_dataset.img_paths
assert gallery_embs.size(0) == len(gallery_paths), "gallery_loader must iterate gallery_dataset in order"

# Each dataset instance remaps person IDs to its own 0-based range, so the
# labels yielded by the query and gallery loaders are not comparable. Compare
# and display the raw person IDs parsed from the filenames instead.
gallery_labels = [pid for pid, _ in gallery_dataset.raw_labels]
gallery_cams = [cam for _, cam in gallery_dataset.raw_labels]
query_raw_pid = {label: pid for pid, label in query_dataset.id_map.items()}

channel_mean = np.array([0.485, 0.456, 0.406])
channel_std = np.array([0.229, 0.224, 0.225])

for query_img, query_label, query_cam in query_loader_vis:
    if shown >= num_queries_to_show:
        break

    query_img = query_img.to(device)
    query_label = query_raw_pid[query_label.item()]  # raw person ID
    query_cam = query_cam.item()

    with torch.no_grad():
        query_emb, _ = model(query_img)
        query_emb = F.normalize(query_emb, p=2, dim=1)

    # Drop gallery images of the same person seen by the same camera
    filtered_indices = [
        i for i, (pid, cam) in enumerate(zip(gallery_labels, gallery_cams))
        if pid != query_label or cam != query_cam
    ]

    filtered_gallery = gallery_embs[filtered_indices]
    filtered_paths = [gallery_paths[i] for i in filtered_indices]
    filtered_pids = [gallery_labels[i] for i in filtered_indices]

    # Cosine similarity (higher = more similar)
    similarities = torch.mm(query_emb, filtered_gallery.T).squeeze(0)
    top_k = min(5, len(filtered_indices))  # topk raises if k exceeds the candidate count
    top5_indices = torch.topk(similarities, k=top_k).indices.cpu().numpy()
    top5_paths = [filtered_paths[i] for i in top5_indices]
    top5_labels = [filtered_pids[i] for i in top5_indices]

    # Visualization
    plt.figure(figsize=(15, 3))
    plt.subplot(1, 6, 1)
    img = query_img[0].permute(1, 2, 0).cpu().numpy()
    img = img * channel_std + channel_mean  # unnormalize with per-channel statistics
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    plt.title(f"Query\nID {query_label}")
    plt.axis('off')

    for i, (path, label) in enumerate(zip(top5_paths, top5_labels)):
        img = Image.open(path).convert("RGB")
        plt.subplot(1, 6, i+2)
        plt.imshow(img)
        plt.title(f"Top {i+1}\nID {label}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

    shown += 1


In [None]:
#  Load Underground ReID validation set with .jpeg support and custom ID parsing
from torch.utils.data import Dataset, DataLoader
from glob import glob
from PIL import Image
import os

# Custom dataset class for underground_reid
class UndergroundReIDDataset(Dataset):
    """Dataset over a flat directory of ``<personID>_*.jpg`` / ``*.jpeg`` images.

    Args:
        img_dir: Directory globbed (non-recursively) for ``*.jpg`` and ``*.jpeg``.
        transform: Optional callable applied to the RGB PIL image in ``__getitem__``.
        relabel: When True (default) person IDs are remapped to contiguous
            0-based labels computed from this directory only, so labels from
            two instances (e.g. probe and gallery) are NOT comparable. Pass
            ``relabel=False`` to get the raw person IDs from the filenames.

    Attributes:
        img_paths: Sorted list of image paths.
        raw_labels: ``(path, person_id, cam_id)`` tuples; ``cam_id`` is always -1.
        unique_ids: Sorted raw person IDs found in the directory.
        id2label: Mapping raw person ID -> label returned by ``__getitem__``.

    Raises:
        ValueError: If the text before the first ``_`` of a filename is not an integer.
    """

    def __init__(self, img_dir, transform=None, relabel=True):
        self.img_paths = sorted(
            glob(os.path.join(img_dir, "*.jpg")) +
            glob(os.path.join(img_dir, "*.jpeg"))
        )
        self.transform = transform
        self.raw_labels = []

        for path in self.img_paths:
            filename = os.path.basename(path)
            try:
                person_id = int(filename.split('_')[0])
            except ValueError as err:
                # Catch only the int() failure; a bare except would also
                # swallow KeyboardInterrupt and unrelated errors.
                raise ValueError(f"Cannot extract person ID from filename: {filename}") from err
            cam_id = -1  # Placeholder since underground_reid doesn't provide camera info
            self.raw_labels.append((path, person_id, cam_id))

        # Normalize IDs (or keep them raw when relabel is False)
        self.unique_ids = sorted({pid for _, pid, _ in self.raw_labels})
        if relabel:
            self.id2label = {pid: idx for idx, pid in enumerate(self.unique_ids)}
        else:
            self.id2label = {pid: pid for pid in self.unique_ids}

    def __len__(self):
        return len(self.raw_labels)

    def __getitem__(self, idx):
        path, pid, cam = self.raw_labels[idx]
        # Context manager releases the file handle once convert() has copied the pixels
        with Image.open(path) as raw_img:
            img = raw_img.convert("RGB")
        if self.transform:
            img = self.transform(img)
        label = self.id2label[pid]
        return img, label, cam

# 🔗 Define folder paths
_underground_root = os.path.join("data", "underground_reid")
underground_query_dir = os.path.join(_underground_root, "probe")
underground_gallery_dir = os.path.join(_underground_root, "gallery")

# Build the probe and gallery datasets with the shared preprocessing pipeline
underground_query_dataset = UndergroundReIDDataset(underground_query_dir, transform=transform)
underground_gallery_dataset = UndergroundReIDDataset(underground_gallery_dir, transform=transform)

# Neither loader shuffles: probes come one at a time, gallery images in batches of 64
underground_query_loader = DataLoader(underground_query_dataset, shuffle=False, batch_size=1)
underground_gallery_loader = DataLoader(underground_gallery_dataset, shuffle=False, batch_size=64)

# Rebind the generic loader names to the underground loaders
# (this replaces whatever query_loader / gallery_loader referred to before)
query_loader, gallery_loader = underground_query_loader, underground_gallery_loader

# Report how many images were found in each folder
print(f" Underground ReID loaded: {len(underground_query_dataset)} queries, {len(underground_gallery_dataset)} gallery images.")


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import torch

# Helper to unnormalize images for display
def unnormalize(img_tensor):
    """Undo the per-channel normalisation of an image tensor for display.

    Args:
        img_tensor: Tensor whose channel axis is third from last, e.g.
            ``(3, H, W)`` or ``(N, 3, H, W)``.

    Returns:
        ``img_tensor * std + mean`` with the statistics broadcast over the
        spatial axes. Values are not clamped.
    """
    # Build the statistics on the input's device so CUDA tensors work too
    mean = torch.tensor([0.485, 0.456, 0.406], device=img_tensor.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=img_tensor.device).view(3, 1, 1)
    return img_tensor * std + mean

# Pick a query image (the probe loader is not shuffled, so this is the first probe)
query_img, query_label, _ = next(iter(underground_query_loader))
query_img = query_img.to(device)
query_label = query_label.item()

# Extract query embedding
model.eval()
with torch.no_grad():
    query_emb, _ = model(query_img)
    query_emb = F.normalize(query_emb, p=2, dim=1)

# Extract all gallery embeddings. Only embeddings and labels are kept; the
# few images that get displayed are re-read from the dataset afterwards
# rather than holding every gallery image tensor in memory.
# NOTE: query_label and gallery_labels are remapped per dataset instance, so
# they are not directly comparable with each other.
gallery_embs = []
gallery_labels = []

with torch.no_grad():
    for imgs, labels, _ in underground_gallery_loader:
        emb, _ = model(imgs.to(device))
        gallery_embs.append(F.normalize(emb, p=2, dim=1))
        gallery_labels.extend(labels.cpu().numpy())

gallery_embs = torch.cat(gallery_embs)

# Compute cosine similarity (embeddings are L2-normalised)
similarities = torch.mm(query_emb, gallery_embs.T).squeeze(0)
top_k = min(5, similarities.numel())  # topk raises if k exceeds the gallery size
top5_indices = torch.topk(similarities, k=top_k).indices.cpu().numpy()

# Display results
plt.figure(figsize=(12, 4))

# Query (clamped to [0, 1], the range imshow expects for float images)
plt.subplot(1, 6, 1)
query_disp = unnormalize(query_img.squeeze(0).cpu()).clamp(0, 1)
plt.imshow(query_disp.permute(1, 2, 0).numpy())
plt.title("Query")
plt.axis('off')

# Top 5 gallery
for i, idx in enumerate(top5_indices):
    plt.subplot(1, 6, i + 2)
    # The gallery loader iterates the dataset in order, so the embedding row
    # index is also the dataset index.
    gal_img, _, _ = underground_gallery_dataset[int(idx)]
    gal_disp = unnormalize(gal_img).clamp(0, 1)
    plt.imshow(gal_disp.permute(1, 2, 0).numpy())
    plt.title(f"Top {i+1}")
    plt.axis('off')

plt.tight_layout()
plt.show()
