<a href="https://colab.research.google.com/github/amanmehra-23/RE-Id_RP/blob/main/ReIdTest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install kagglehub openai-clip torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [None]:
import os
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms,models
from PIL import Image
import torch.nn as nn
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import requests
from io import BytesIO

In [None]:
# ---------------------------
# (Assumed) Definition of Market1501Dataset
# ---------------------------
class Market1501Dataset(torch.utils.data.Dataset):
    """Images and person-id labels from one Market-1501 split directory.

    Filenames are expected to look like ``0002_c1s1_000451_03.jpg``: the first
    ``_``-separated token is the person id, the second starts with ``c<camera>``.
    ``__getitem__`` returns ``(image, person_id)``.

    Attributes:
        image_paths (list[str]): Paths of the kept ``.jpg`` files, in sorted filename order.
        labels (list[int]): Person id per image, aligned with ``image_paths``.
        camera_ids (list[int]): Camera number per image parsed from the ``c<N>`` token
            (-1 when the token is absent), aligned with ``image_paths``. Not returned by
            ``__getitem__``; available for camera-aware evaluation.
    """

    def __init__(self, root_dir, transform=None, keep_distractors=False):
        """
        Args:
            root_dir (str): Directory with images (e.g., query, bounding_box_test).
            transform: Transformations applied on images.
            keep_distractors (bool): If True, also keep images whose id token is exactly
                ``-1`` (labelled -1). Default False skips them. Id ``0``/``0000`` and any
                other non-numeric id token are always skipped.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.camera_ids = []

        # os.listdir order is arbitrary; sorting makes the sample order (and therefore the
        # order of extracted embeddings) reproducible across machines and runs.
        for file in sorted(os.listdir(root_dir)):
            if not file.endswith('.jpg'):
                continue
            tokens = file.split('_')
            person_id = self._parse_person_id(tokens[0], keep_distractors)
            if person_id is None:
                continue
            self.image_paths.append(os.path.join(root_dir, file))
            self.labels.append(person_id)
            self.camera_ids.append(self._parse_camera_id(tokens[1] if len(tokens) > 1 else ''))

    @staticmethod
    def _parse_person_id(id_str, keep_distractors):
        """Return the integer person id for `id_str`, or None if the image should be skipped."""
        if keep_distractors and id_str == '-1':
            return -1
        # Non-digit tokens (including '-1') and id 0 are not usable identities.
        if not id_str.isdigit():
            return None
        person_id = int(id_str)
        return person_id if person_id > 0 else None

    @staticmethod
    def _parse_camera_id(token):
        """Return the camera number from a token like 'c1s1', or -1 if it does not start with c<digits>."""
        import re
        match = re.match(r'c(\d+)', token)
        return int(match.group(1)) if match else -1

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

In [None]:
# ---------------------------
# Define Preprocessing Pipeline
# ---------------------------
# Per-channel ImageNet statistics (RGB order) used by the pretrained ResNet-50 weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize each crop to 224x224, convert to a float tensor, then standardize every
# channel with the ImageNet statistics above.
# NOTE(review): the same tensor is also fed to the CLIP branch, whose own preprocessing
# uses different mean/std — kept as-is because the saved checkpoint was trained this way.
preprocess_pipeline = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [None]:
import kagglehub

# Fetch the latest version of the Market-1501 dataset; `path` is its local root directory.
path = kagglehub.dataset_download("pengcw1/market-1501")
print(f"Path to dataset files: {path}")

Path to dataset files: /kaggle/input/market-1501


In [None]:
# ---------------------------
# Set Dataset Paths
# ---------------------------
# Derive the split directories from `path`, the root returned by kagglehub.dataset_download
# in the download cell, instead of hard-coding a location that only exists in one environment.
DATASET_SUBDIR = "Market-1501-v15.09.15"  # folder inside the Kaggle archive; change if your layout differs
dataset_path = os.path.join(path, DATASET_SUBDIR)
query_dir = os.path.join(dataset_path, "query")
gallery_dir = os.path.join(dataset_path, "bounding_box_test")

# Fail here with a readable message rather than with a bare FileNotFoundError from
# os.listdir when the datasets are built.
for split_dir in (query_dir, gallery_dir):
    if not os.path.isdir(split_dir):
        raise FileNotFoundError(f"Expected Market-1501 split directory not found: {split_dir}")


In [None]:
# ---------------------------
# Create Dataset and DataLoader Objects
# ---------------------------
def _make_eval_loader(dataset):
    """Wrap `dataset` in an unshuffled loader so embeddings come out in dataset order."""
    # A small worker count is often useful to avoid the notebook freezing.
    return DataLoader(dataset, batch_size=32, shuffle=False, num_workers=2)


query_dataset, gallery_dataset = (
    Market1501Dataset(root_dir=split_root, transform=preprocess_pipeline)
    for split_root in (query_dir, gallery_dir)
)
query_loader = _make_eval_loader(query_dataset)
gallery_loader = _make_eval_loader(gallery_dataset)


In [None]:
class ResNetBackbone(nn.Module):
    """ResNet-50 trunk without its final pooling and fully-connected layers.

    Returns the last convolutional feature map, e.g. (B, 2048, 7, 7) for 224x224 inputs.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        trunk = models.resnet50(weights=weights)
        # Keep every child except the last two (global average pool and classifier).
        self.features = nn.Sequential(*list(trunk.children())[:-2])

    def forward(self, x):
        return self.features(x)

In [None]:
# --- Part 2: Build Grid Graph ---
def build_grid_edge_index(grid_size):
    """
    Build the edge index of a 4-connected grid graph.

    Nodes are numbered row-major (node = row * W + col). Every horizontally or
    vertically adjacent pair (a, b) contributes two directed edges, a -> b and b -> a.

    Args:
        grid_size: (H, W) number of rows and columns.
    Returns:
        LongTensor of shape (2, num_edges). Row 0 holds the first node of each pair and
        row 1 the second (PyG reads them as source / target). For H, W >= 1,
        num_edges = 2 * (H * (W - 1) + (H - 1) * W); a grid without neighbours
        (e.g. 1x1) gives an empty (2, 0) tensor.
    """
    H, W = grid_size
    edges = []
    for i in range(H):
        for j in range(W):
            idx = i * W + j
            if j + 1 < W:  # right neighbour exists
                right_idx = idx + 1
                edges.append((idx, right_idx))
                edges.append((right_idx, idx))
            if i + 1 < H:  # neighbour below exists
                down_idx = idx + W
                edges.append((idx, down_idx))
                edges.append((down_idx, idx))
    # reshape(-1, 2) keeps the empty case 2-D: torch.tensor([]) is 1-D, and transposing it
    # would give shape (0,) instead of the (2, 0) edge index GCN layers expect.
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
    return edge_index

In [None]:
# --- Part 3: GNN Branch ---
class GNNBranch(nn.Module):
    """Graph branch over a CNN feature map.

    Each spatial cell of the (B, C, H, W) feature map becomes a graph node carrying a
    C-dim feature. Nodes are linked as a 4-connected grid (``build_grid_edge_index``).
    Two GCN layers (ReLU in between) are applied, node features are mean-pooled per
    image, and a final linear layer refines the pooled vector.
    """

    def __init__(self, in_channels=2048, hidden_channels=512, out_channels=256, grid_size=(7,7)):
        super(GNNBranch, self).__init__()
        self.grid_size = grid_size
        # Edge index for the expected grid. Registered as a non-persistent buffer so it moves
        # with .to(device) but is NOT part of state_dict, so existing checkpoints still load.
        self.register_buffer("edge_index", build_grid_edge_index(grid_size), persistent=False)

        # Two GCN layers
        self.gcn1 = GCNConv(in_channels, hidden_channels)
        self.gcn2 = GCNConv(hidden_channels, out_channels)
        # Optional FC layer for further refinement
        self.fc = nn.Linear(out_channels, out_channels)

    def _edge_index_for(self, H, W, device):
        """Return the grid edge index for an H x W map on `device`.

        It is rebuilt when the map size differs from `self.grid_size` (e.g. non-224x224
        inputs); otherwise its node ids would not match the actual node count.
        """
        if (H, W) == tuple(self.grid_size):
            edge_index = self.edge_index
        else:
            edge_index = build_grid_edge_index((H, W))
        return edge_index.to(device)

    def forward(self, x):
        """
        Args:
            x: CNN feature map of shape (B, C, H, W) with C == in_channels.
        Returns:
            A tensor of shape (B, out_channels): one embedding per image.
        """
        B, C, H, W = x.shape
        N = H * W  # nodes per image
        # (B, C, H, W) -> (B*N, C): stack all images' nodes into one node matrix.
        # reshape (rather than view) also accepts non-contiguous inputs.
        nodes = x.reshape(B, C, N).permute(0, 2, 1).reshape(B * N, C)
        edge_index = self._edge_index_for(H, W, x.device)
        num_edges = edge_index.size(1)
        # Treat the batch as one disjoint graph: copy the edge list B times and shift copy b
        # by b*N so it only touches image b's nodes. With no edges between images this is
        # equivalent (up to float rounding) to running the GCN once per image, without a
        # Python loop over the batch.
        offsets = torch.arange(B, device=x.device).repeat_interleave(num_edges) * N
        batched_edge_index = edge_index.repeat(1, B) + offsets
        h = F.relu(self.gcn1(nodes, batched_edge_index))  # (B*N, hidden_channels)
        h = self.gcn2(h, batched_edge_index)  # (B*N, out_channels)
        # Global mean pooling over each image's N nodes.
        pooled = h.reshape(B, N, -1).mean(dim=1)  # (B, out_channels)
        return self.fc(pooled)


In [None]:
# --- Part 4: CLIP Branch ---
# For the CLIP branch, we use OpenAI's CLIP model.
# Ensure you have installed the clip package (e.g., pip install git+https://github.com/openai/CLIP.git)
import clip

class CLIPBranch(nn.Module):
    """Frozen CLIP ViT-B/32 image encoder followed by a trainable 512 -> 256 projection."""

    def __init__(self, device='cuda'):
        super().__init__()
        # clip.load returns (model, preprocess); the preprocess transform is stored but not
        # used by forward, which expects already-preprocessed image tensors.
        self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=device)
        self.clip_model.eval()
        self.proj = nn.Linear(512, 256)

    def forward(self, x):
        clip_device = next(self.clip_model.parameters()).device
        x = x.to(clip_device)
        # The encoder is used as a fixed feature extractor: no gradients flow into it.
        with torch.no_grad():
            image_features = self.clip_model.encode_image(x)  # (B, 512)
        # Cast to float32 so the dtype matches the projection layer's parameters.
        return self.proj(image_features.float())  # (B, 256)

In [None]:
# --- Part 5: Fusion Module ---
class FusionModule(nn.Module):
    """Late fusion: concatenate the GNN and CLIP embeddings, then project back to `emb_dim`."""

    def __init__(self, emb_dim=256):
        super().__init__()
        self.fc = nn.Linear(2 * emb_dim, emb_dim)

    def forward(self, gnn_emb, clip_emb):
        """
        Args:
            gnn_emb: (B, emb_dim) embedding from the GNN branch.
            clip_emb: (B, emb_dim) embedding from the CLIP branch.
        Returns:
            (B, emb_dim) fused embedding.
        """
        # GNN features occupy the first emb_dim columns, CLIP features the rest.
        joint = torch.cat((gnn_emb, clip_emb), dim=1)
        return self.fc(joint)

In [None]:
class ReIDMultimodalNet(nn.Module):
    """Person re-ID model fusing a ResNet-50 + GCN embedding with a CLIP image embedding.

    Only the CLIP branch is created directly on `device` (via clip.load); call
    `.to(device)` on the instance so the backbone, GNN and fusion layers follow.
    """

    def __init__(self, device='cuda'):
        super().__init__()
        self.device = device
        self.backbone = ResNetBackbone(pretrained=True)
        self.gnn_branch = GNNBranch(2048, 512, 256, (7, 7))
        self.clip_branch = CLIPBranch(device=device)
        self.fusion = FusionModule(emb_dim=256)

    def forward(self, x):
        """
        Args:
            x: Input image tensor of shape (B, 3, 224, 224).
        Returns:
            Fused multimodal embedding (B, 256).
        """
        images = x.to(self.device)
        # Structural path: CNN feature map -> grid-graph GCN -> 256-D embedding.
        gnn_emb = self.gnn_branch(self.backbone(images))
        # Semantic path: the same image tensor through CLIP -> 256-D embedding.
        # NOTE(review): these tensors are ImageNet-normalized, not CLIP-normalized.
        clip_emb = self.clip_branch(images)
        return self.fusion(gnn_emb, clip_emb)

In [None]:

# -------------------------------------------------
# 5. Define Helper Functions for Evaluation
# -------------------------------------------------
def extract_embeddings(model, data_loader, device):
    """
    Run `model` over every batch of `data_loader` and collect its outputs.

    Puts `model` in eval mode and disables autograd.

    Returns:
        (embeddings, labels): a CPU tensor with all batch outputs concatenated along
        dim 0, and a numpy array with the matching labels in loader order.
    """
    model.eval()
    embedding_chunks = []
    label_chunks = []
    with torch.no_grad():
        for batch_images, batch_labels in data_loader:
            embedding_chunks.append(model(batch_images.to(device)).cpu())
            label_chunks.append(batch_labels.numpy())
    stacked_embeddings = torch.cat(embedding_chunks, dim=0)
    return stacked_embeddings, np.concatenate(label_chunks)

def compute_distance_matrix(query_emb, gallery_emb, metric='euclidean'):
    """
    Pairwise distances between every query and every gallery embedding.

    Args:
        query_emb: Tensor (num_query, D)
        gallery_emb: Tensor (num_gallery, D)
        metric: 'euclidean' (L2 distance) or 'cosine' (1 - cosine similarity)
    Returns:
        Tensor (num_query, num_gallery); smaller means more similar.
    Raises:
        ValueError: for any other metric name.
    """
    if metric not in ('euclidean', 'cosine'):
        raise ValueError("Unsupported metric")
    if metric == 'cosine':
        unit_query = F.normalize(query_emb, p=2, dim=1)
        unit_gallery = F.normalize(gallery_emb, p=2, dim=1)
        return 1 - unit_query @ unit_gallery.t()
    return torch.cdist(query_emb, gallery_emb, p=2)

def evaluate_rank1_map(dist_matrix, query_labels, gallery_labels,
                       query_cams=None, gallery_cams=None):
    """
    Compute Rank-1 accuracy and mean Average Precision (mAP) from a distance matrix.

    A query is "valid" if at least one gallery item (after camera filtering) shares its
    label. Both metrics are averaged over valid queries only, so they use the same
    denominator.

    Args:
        dist_matrix: (num_query, num_gallery) torch tensor (any device) or numpy array;
            smaller means more similar.
        query_labels, gallery_labels: Sequences of identity labels.
        query_cams, gallery_cams: Optional camera ids, given together. When provided,
            gallery items with the SAME label AND SAME camera as the query are ignored
            for that query (cross-camera protocol). Default: no filtering.
    Returns:
        (rank1_accuracy, mAP) as floats; (0.0, 0.0) when there is no valid query.
    Raises:
        ValueError: if only one of query_cams / gallery_cams is given.
    """
    if (query_cams is None) != (gallery_cams is None):
        raise ValueError("query_cams and gallery_cams must be given together")
    if torch.is_tensor(dist_matrix):
        dist_matrix = dist_matrix.detach().cpu().numpy()
    dist_matrix = np.asarray(dist_matrix)
    query_labels = np.asarray(query_labels)
    gallery_labels = np.asarray(gallery_labels)
    use_cams = query_cams is not None
    if use_cams:
        query_cams = np.asarray(query_cams)
        gallery_cams = np.asarray(gallery_cams)

    rank1_hits = 0
    ap_list = []
    for i in range(dist_matrix.shape[0]):
        order = np.argsort(dist_matrix[i])
        if use_cams:
            same_view = ((gallery_labels[order] == query_labels[i])
                         & (gallery_cams[order] == query_cams[i]))
            order = order[~same_view]
        matches = gallery_labels[order] == query_labels[i]
        hit_ranks = np.flatnonzero(matches)  # 0-based ranks of the correct matches
        if hit_ranks.size == 0:
            continue  # no relevant gallery item: excluded from both metrics
        if hit_ranks[0] == 0:
            rank1_hits += 1
        # Precision at the k-th correct match (found at rank r) is k / (r + 1).
        precisions = np.arange(1, hit_ranks.size + 1) / (hit_ranks + 1)
        ap_list.append(precisions.mean())

    if not ap_list:
        return 0.0, 0.0
    return rank1_hits / len(ap_list), float(np.mean(ap_list))



In [None]:
# Build the model and restore the trained weights. `device` and `model` are created here
# so this cell (and the evaluation below) also works on a fresh kernel / Restart & Run All.
CHECKPOINT_PATH = "/content/reid_multimodal_model.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ReIDMultimodalNet(device=device).to(device)
# weights_only=True restricts unpickling to tensors and plain containers, which is all a
# state_dict needs, and avoids executing arbitrary pickled code from the file.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

ReIDMultimodalNet(
  (backbone): ResNetBackbone(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inpla

In [None]:
# -------------------------------------------------
# 6. Run the Evaluation
# -------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# Embed both splits with `model`, the trained network restored (in eval mode) by the previous cell.
query_embeddings, query_labels = extract_embeddings(model, query_loader, device)
gallery_embeddings, gallery_labels = extract_embeddings(model, gallery_loader, device)

DISTANCE_METRIC = 'cosine'  # 'euclidean' or 'cosine'
dist_matrix = compute_distance_matrix(query_embeddings, gallery_embeddings, metric=DISTANCE_METRIC)

# NOTE(review): same-identity gallery images from the query's own camera are not excluded
# and distractor (-1) images are dropped by the dataset, unlike the standard Market-1501
# protocol — these numbers are likely optimistic vs. published results; confirm before comparing.
rank1_accuracy, mAP = evaluate_rank1_map(dist_matrix, query_labels, gallery_labels)
for metric_template, metric_value in (("Rank-1 Accuracy: {:.2%}", rank1_accuracy), ("mAP: {:.2%}", mAP)):
    print(metric_template.format(metric_value))

Rank-1 Accuracy: 86.88%
mAP: 37.19%
