<a href="https://colab.research.google.com/github/amanmehra-23/RE-Id_RP/blob/main/RE_IDPray.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch_geometric openai-clip

Collecting openai-clip
  Downloading openai-clip-1.0.1.tar.gz (1.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m24.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from openai-clip)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: openai-clip
  Building wheel for openai-clip (setup.py) ... [?25l[?25hdone
  Created wheel for openai-clip: filename=openai_clip-1.0.1-py3-none-any.whl size=1368605 sha256=c00d8a57e85c450a7b91e20586c1a41d6eaf97644cc81db25719351c4cbf975a
  Stored in directory: /root/.cache/pip/wheels/0d/17/90/042948fd2e2a87f1dcf6db6d438cad015c49db0c53d1d9c7dc
Successfully built openai-clip
Installing collected packages: ftfy, openai-clip
Successfully insta

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from PIL import Image
import requests
from io import BytesIO

In [None]:
# --- Part 1: ResNet-50 Backbone ---
class ResNetBackbone(nn.Module):
    """ResNet-50 trunk that emits a spatial feature map instead of class logits.

    Args:
        pretrained: If true, torchvision's ImageNet-1k (V1) weights are loaded;
            otherwise the network is created without pretrained weights.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        trunk = list(models.resnet50(weights=weights).children())
        # Keep everything except the last two children (final pooling + FC head),
        # so the module outputs the convolutional feature map.
        self.features = nn.Sequential(*trunk[:-2])

    def forward(self, x):
        """Run the truncated network; a 224x224 input is expected to give (B, 2048, 7, 7)."""
        feature_map = self.features(x)
        return feature_map


In [None]:
# --- Part 2: Build Grid Graph ---
def build_grid_edge_index(grid_size):
    """
    Constructs edge indices for a 4-connected grid graph.

    Nodes are numbered row-major (``idx = row * W + col``). Every node is linked
    to its right and down neighbour, and each link is stored in both directions,
    so the resulting graph is undirected.

    Args:
        grid_size: Pair ``(H, W)`` giving the number of rows and columns.
    Returns:
        A ``torch.long`` tensor of shape ``(2, num_edges)``. For a grid without
        any neighbouring cells (e.g. ``(1, 1)``) the result is an empty tensor of
        shape ``(2, 0)``.
    """
    H, W = grid_size
    edges = []
    for i in range(H):
        for j in range(W):
            idx = i * W + j
            # Connect to right neighbor if exists
            if j + 1 < W:
                right_idx = idx + 1
                edges.append((idx, right_idx))
                edges.append((right_idx, idx))
            # Connect to down neighbor if exists
            if i + 1 < H:
                down_idx = idx + W
                edges.append((idx, down_idx))
                edges.append((down_idx, idx))

    if not edges:
        # torch.tensor([]) is 1-D, so .t() would leave it as shape (0,);
        # return an explicitly 2-row empty tensor to keep the documented layout.
        return torch.empty((2, 0), dtype=torch.long)

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    return edge_index  # Shape: (2, num_edges)

In [None]:
# --- Part 3: GNN Branch ---
class GNNBranch(nn.Module):
    """Graph branch: treats each spatial cell of a CNN feature map as a graph node.

    Nodes are connected as a 4-neighbour grid (see ``build_grid_edge_index``),
    passed through two GCN layers, mean-pooled per image and refined by a linear
    layer.

    Args:
        in_channels: Channel count of the incoming feature map (node feature size).
        hidden_channels: Width of the first GCN layer.
        out_channels: Size of the returned embedding.
        grid_size: ``(H, W)`` of the feature map the branch is expected to see;
            the edge index for this size is built once and reused.
    """

    def __init__(self, in_channels=2048, hidden_channels=512, out_channels=256, grid_size=(7,7)):
        super(GNNBranch, self).__init__()
        self.grid_size = grid_size
        # Non-persistent buffer: follows the module across .to(device) calls but
        # is not written to the state_dict, so checkpoint keys are unchanged.
        self.register_buffer("edge_index", build_grid_edge_index(grid_size), persistent=False)

        # Two GCN layers
        self.gcn1 = GCNConv(in_channels, hidden_channels)
        self.gcn2 = GCNConv(hidden_channels, out_channels)
        # Optional FC layer for further refinement
        self.fc = nn.Linear(out_channels, out_channels)

    def forward(self, x):
        """
        Args:
            x: CNN feature map of shape (B, C, H, W), with C equal to ``in_channels``.
        Returns:
            A tensor of shape (B, out_channels) representing the person embedding.
        """
        B, C, H, W = x.shape
        N = H * W  # Number of nodes per image (e.g., 49)

        # Reuse the precomputed grid when the spatial size matches; otherwise build
        # the grid for the actual size so the edges always describe this input.
        if (H, W) == tuple(self.grid_size):
            edge_index = self.edge_index
        else:
            edge_index = build_grid_edge_index((H, W))
        edge_index = edge_index.to(x.device)

        # (B, C, H, W) -> (B, N, C) -> (B*N, C). flatten/reshape (unlike view) also
        # accepts non-contiguous inputs such as channels_last tensors.
        node_feat = x.flatten(2).transpose(1, 2).reshape(B * N, C)

        # Process the whole batch as one graph made of B disjoint grids: image b
        # owns node ids [b*N, (b+1)*N), so its edges are the base edges shifted by b*N.
        offsets = torch.arange(B, device=x.device).view(B, 1, 1) * N
        batched_edges = (edge_index.unsqueeze(0) + offsets).permute(1, 0, 2).reshape(2, -1)

        h = F.relu(self.gcn1(node_feat, batched_edges))  # (B*N, hidden_channels)
        h = self.gcn2(h, batched_edges)  # (B*N, out_channels)

        # Global mean pooling: average over the N nodes of each image
        embeddings = h.view(B, N, h.size(-1)).mean(dim=1)  # (B, out_channels)
        embeddings = self.fc(embeddings)
        return embeddings  # (B, out_channels) e.g., (B, 256)


In [None]:
# --- Part 4: CLIP Branch ---
# For the CLIP branch, we use OpenAI's CLIP model.
# The `clip` module must be installed first: the install cell at the top of this notebook uses the
# `openai-clip` package; `pip install git+https://github.com/openai/CLIP.git` is an alternative.
import clip

class CLIPBranch(nn.Module):
    """Frozen CLIP image encoder followed by a trainable linear projection.

    The CLIP weights are used purely as a feature extractor: they are marked as
    not requiring gradients and the encoder is kept in eval mode even when the
    parent model is switched to training mode. Only ``proj`` is trainable.

    Args:
        device: Device passed to ``clip.load`` for the CLIP weights.
    """

    def __init__(self, device='cuda'):
        super(CLIPBranch, self).__init__()
        self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=device)
        # Freeze the encoder explicitly so optimizers and parameter counts reflect
        # that it is never updated (forward already runs it under no_grad).
        self.clip_model.requires_grad_(False)
        self.clip_model.eval()  # Set to eval mode
        self.proj = nn.Linear(512, 256)

    def train(self, mode=True):
        """Switch train/eval mode for the branch while pinning CLIP to eval mode.

        Without this override, calling ``.train()`` on a parent module would
        recursively put the frozen encoder back into training mode.
        """
        super().train(mode)
        self.clip_model.eval()
        return self

    def forward(self, x):
        """
        Args:
            x: Batch of image tensors accepted by ``clip_model.encode_image``.
        Returns:
            Projected CLIP embedding of shape (B, 256).
        """
        # Ensure x is on the correct device
        x = x.to(next(self.clip_model.parameters()).device)
        with torch.no_grad():
            clip_emb = self.clip_model.encode_image(x)  # (B, 512)
        # Convert to float32 to match the projection layer parameters
        clip_emb = clip_emb.float()
        clip_emb = self.proj(clip_emb)  # (B, 256)
        return clip_emb



In [None]:
# --- Part 5: Fusion Module ---
class FusionModule(nn.Module):
    """Fuses two equally sized embeddings with one linear layer over their concatenation.

    Args:
        emb_dim: Size of each input embedding and of the fused output.
    """

    def __init__(self, emb_dim=256):
        super().__init__()
        concat_dim = 2 * emb_dim
        self.fc = nn.Linear(concat_dim, emb_dim)

    def forward(self, gnn_emb, clip_emb):
        """Concatenate both (B, emb_dim) inputs along the feature axis and project to (B, emb_dim)."""
        joined = torch.cat((gnn_emb, clip_emb), dim=1)  # (B, 2*emb_dim)
        return self.fc(joined)

In [None]:
# --- Part 6: Combined Model ---
class ReIDMultimodalNet(nn.Module):
    """Two-branch re-identification network.

    A ResNet feature map is turned into an embedding by the GNN branch, the same
    input image is embedded by the CLIP branch, and the fusion module merges the
    two 256-D vectors into a single 256-D descriptor.

    Args:
        device: Device the input is moved to in ``forward`` and on which the CLIP
            branch loads its weights.
    """

    def __init__(self, device='cuda'):
        super().__init__()
        self.device = device
        self.backbone = ResNetBackbone(pretrained=True)
        self.gnn_branch = GNNBranch(
            in_channels=2048,
            hidden_channels=512,
            out_channels=256,
            grid_size=(7, 7),
        )
        self.clip_branch = CLIPBranch(device=device)
        self.fusion = FusionModule(emb_dim=256)

    def forward(self, x):
        """Map an image batch (B, 3, 224, 224) to fused embeddings (B, 256)."""
        x = x.to(self.device)
        # Graph path: backbone feature map (B, 2048, 7, 7) -> 256-D embedding.
        gnn_emb = self.gnn_branch(self.backbone(x))
        # CLIP path works on the image tensor itself -> 256-D embedding.
        clip_emb = self.clip_branch(x)
        # Merge both views of the image.
        return self.fusion(gnn_emb, clip_emb)

In [None]:
# --- Image preprocessing applied before the model ---
# Resize to 224x224, convert to a tensor, then normalize with the ImageNet
# channel statistics (CLIP's own preprocessing may differ slightly).
preprocess_pipeline = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        # Positional arguments are (mean, std): ImageNet mean, then ImageNet std.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

def load_image(image_path_or_url, timeout=30):
    """Load an image from a local path or an HTTP(S) URL and return it as RGB.

    Args:
        image_path_or_url: Filesystem path, or a URL starting with ``http://`` or
            ``https://``.
        timeout: Seconds to wait for the HTTP request before giving up; only
            used for URLs.
    Returns:
        A ``PIL.Image.Image`` in RGB mode.
    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeout.
    """
    # Match the full scheme so local paths that merely begin with "http"
    # (e.g. "http_cache/img.jpg") are still read from disk.
    if image_path_or_url.startswith(("http://", "https://")):
        response = requests.get(image_path_or_url, timeout=timeout)
        # Fail with a clear HTTP error instead of letting PIL choke on an error page.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_path_or_url

    # convert() returns a new image, so the underlying file can be closed right away.
    with Image.open(source) as image:
        return image.convert("RGB")

# --- Testing the Full Pipeline ---
# Smoke test: builds the model, pushes one downloaded image through it and prints
# the shape of the resulting embedding. The names bound here (device, model, ...)
# become notebook globals; `device` and `model` are re-bound by the training cell.
if __name__ == "__main__":
    # Prefer the GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ReIDMultimodalNet(device=device).to(device)
    model.eval()  # inference mode for this check

    # Load and preprocess an example image
    image_url = "https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg"
    image = load_image(image_url)
    # unsqueeze(0) adds the batch dimension expected by the model.
    input_tensor = preprocess_pipeline(image).unsqueeze(0)  # (1, 3, 224, 224)

    # No gradients are needed for a forward-only check.
    with torch.no_grad():
        fused_embedding = model(input_tensor)
    print("Fused multimodal embedding shape:", fused_embedding.shape)  # Expected: (1, 256)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 144MB/s]
100%|███████████████████████████████████████| 338M/338M [00:06<00:00, 54.4MiB/s]


Fused multimodal embedding shape: torch.Size([1, 256])


In [None]:
import kagglehub

# Download latest version of the "pengcw1/market-1501" dataset via kagglehub.
# `path` receives the value returned by dataset_download and is printed below.
# NOTE(review): a later cell re-binds `path` to a hard-coded sub-directory.
path = kagglehub.dataset_download("pengcw1/market-1501")

print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/market-1501


In [None]:
def supervised_contrastive_loss(embeddings, labels, temperature=0.07):
    """
    Computes the supervised contrastive loss as in Khosla et al. (2020).

    For every anchor, the other samples in the batch form a softmax over
    temperature-scaled cosine similarities; the loss is the negative mean
    log-probability assigned to the samples sharing the anchor's label.
    Anchors that have no positive in the batch contribute zero, and the result
    is averaged over all B anchors.

    The softmax is evaluated with ``logsumexp`` so that small temperatures
    (large logits) do not overflow.

    Args:
        embeddings: Tensor of shape (B, D) where B is the batch size and D is embedding dimension.
        labels: Tensor of shape (B,) with integer labels.
        temperature: A scaling factor for the logits.
    Returns:
        loss: A scalar representing the supervised contrastive loss. It is zero
            (but still attached to the autograd graph) when B < 2, because no
            pairs exist in that case.
    """
    device = embeddings.device
    batch_size = embeddings.shape[0]

    # With fewer than two samples there is nothing to contrast against.
    if batch_size < 2:
        return embeddings.sum() * 0.0

    # Normalize embeddings to have unit norm
    embeddings = F.normalize(embeddings, p=2, dim=1)

    # Compute cosine similarity matrix (B, B) and scale by temperature
    logits = torch.matmul(embeddings, embeddings.T) / temperature

    # positive_mask[i, j] is True if samples i and j have the same label and i != j
    labels = labels.contiguous().view(-1, 1).to(device)
    self_mask = torch.eye(batch_size, dtype=torch.bool, device=device)
    positive_mask = torch.eq(labels, labels.T) & ~self_mask

    # Exclude self-comparisons from the softmax denominator by sending them to -inf,
    # then take a numerically stable log-softmax over the remaining entries.
    logits = logits.masked_fill(self_mask, float("-inf"))
    log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)

    # Sum log-probabilities over positives only; torch.where (rather than a
    # multiplication) keeps the -inf diagonal from producing 0 * inf = NaN.
    pos_log_prob_sum = torch.where(positive_mask, log_prob, torch.zeros_like(log_prob)).sum(1)
    num_positives = positive_mask.sum(1)

    # For each sample, calculate mean log-likelihood over all positive pairs
    # (clamp avoids division by zero for anchors without positives -> 0).
    mean_log_prob_pos = pos_log_prob_sum / num_positives.clamp(min=1)

    # The loss is the negative average of these log-likelihoods
    loss = -mean_log_prob_pos.mean()
    return loss

In [None]:
import os
from torch.utils.data import Dataset
from PIL import Image

class Market1501Dataset(Dataset):
    """Image/person-ID pairs read from one Market-1501 image directory.

    Person IDs are parsed from the part of the file name before the first
    underscore. Files whose ID is not a positive integer are skipped. Files are
    indexed in sorted name order so the index -> sample mapping is the same on
    every run.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Directory with all images (e.g., bounding_box_train, bounding_box_test, or query).
            transform: Transformations to be applied on the images.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # os.listdir gives no ordering guarantee; sort for reproducible indexing.
        for file in sorted(os.listdir(root_dir)):
            if not file.endswith('.jpg'):
                continue
            # Extract the person ID from the file name.
            # Files are assumed to have a format like "0002_c1s1_000451_03.jpg".
            # However, files with junk/distractor images might have negative IDs like "-1_c..."
            id_str = file.split('_')[0]  # Get the first part
            # isdigit() is False for anything with a sign, so "-1" is rejected here too.
            if not id_str.isdigit():
                continue
            person_id = int(id_str)
            # Filter out non-positive IDs (e.g., "0000")
            if person_id <= 0:
                continue
            self.image_paths.append(os.path.join(root_dir, file))
            self.labels.append(person_id)

    def __len__(self):
        """Number of images that passed the ID filter."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, person_id)``; the image is RGB and transformed if a transform was given."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # convert() yields a new image, so the file handle can be released immediately.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label


In [None]:
from torch.utils.data import DataLoader

# Builds the train/test datasets and their DataLoaders, then prints their sizes.
# Use the preprocessing pipeline defined earlier
transform = preprocess_pipeline  # or define your own augmentations


# Hard-coded dataset root; this re-binds `path` regardless of any value it held before.
# NOTE(review): presumably matches the Kaggle mount layout — confirm when running elsewhere.
path = "/kaggle/input/market-1501/Market-1501-v15.09.15"
# Build dataset paths
train_dir = os.path.join(path, "bounding_box_train")
test_dir = os.path.join(path, "bounding_box_test")  # For evaluation

# Create dataset objects
train_dataset = Market1501Dataset(root_dir=train_dir, transform=transform)
test_dataset = Market1501Dataset(root_dir=test_dir, transform=transform)

# Create DataLoaders (training batches are shuffled, test batches keep dataset order)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of testing samples: {len(test_dataset)}")


Number of training samples: 12936
Number of testing samples: 13115




In [None]:
import torch.optim as optim
# Training cell: rebuilds the preprocessing pipeline and the training DataLoader,
# creates a fresh model and optimizes it with the supervised contrastive loss.
# Preprocessing (ensure it suits both the ResNet backbone and CLIP; adjust if needed)
# This re-binds `preprocess_pipeline` with the same steps as the earlier definition.
preprocess_pipeline = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet mean
        std=[0.229, 0.224, 0.225]    # ImageNet std
    )
])

# Assume 'path' is the dataset root from kagglehub (e.g., downloaded from "pengcw1/market-1501")
# Update the directory names as needed based on the actual dataset structure.
# `path` is not assigned in this cell; it must already exist from an earlier cell.
train_dir = os.path.join(path, "bounding_box_train")
# Build DataLoader for training set (re-binds train_dataset / train_loader)
train_dataset = Market1501Dataset(root_dir=train_dir, transform=preprocess_pipeline)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)

# ---------------------------
# 3. Model Initialization (Fused Multimodal ReID Model)
# ---------------------------
# Using ReIDMultimodalNet as defined previously (which outputs a fused 256-D embedding)
# A new model instance is created here, replacing any `model` bound by an earlier cell.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ReIDMultimodalNet(device=device).to(device)

# ---------------------------
# 4. Training Loop Using Supervised Contrastive Loss
# ---------------------------
# Adam receives every parameter returned by model.parameters().
optimizer = optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 50  # You may adjust as needed

# NOTE(review): batches are drawn by plain shuffling, so how many same-identity
# pairs each batch contains is left to chance — consider an identity-balanced
# sampler if the contrastive signal turns out to be weak; confirm empirically.
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-batch losses for this epoch
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        embeddings = model(images)  # Fused embedding, shape (B, 256)

        # Compute Supervised Contrastive Loss
        loss = supervised_contrastive_loss(embeddings, labels, temperature=0.07)

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Mean of the per-batch losses over the epoch.
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs} - Supervised Contrastive Loss: {avg_loss:.4f}")




Epoch 1/50 - Supervised Contrastive Loss: 0.1253
Epoch 2/50 - Supervised Contrastive Loss: 0.1169
Epoch 3/50 - Supervised Contrastive Loss: 0.1055
Epoch 4/50 - Supervised Contrastive Loss: 0.0843
Epoch 5/50 - Supervised Contrastive Loss: 0.0868
Epoch 6/50 - Supervised Contrastive Loss: 0.0831
Epoch 7/50 - Supervised Contrastive Loss: 0.0768
Epoch 8/50 - Supervised Contrastive Loss: 0.0808
Epoch 9/50 - Supervised Contrastive Loss: 0.0834
Epoch 10/50 - Supervised Contrastive Loss: 0.0724
Epoch 11/50 - Supervised Contrastive Loss: 0.0672
Epoch 12/50 - Supervised Contrastive Loss: 0.0677
Epoch 13/50 - Supervised Contrastive Loss: 0.0567
Epoch 14/50 - Supervised Contrastive Loss: 0.0584
Epoch 15/50 - Supervised Contrastive Loss: 0.0591
Epoch 16/50 - Supervised Contrastive Loss: 0.0531
Epoch 17/50 - Supervised Contrastive Loss: 0.0489
Epoch 18/50 - Supervised Contrastive Loss: 0.0448
Epoch 19/50 - Supervised Contrastive Loss: 0.0593
Epoch 20/50 - Supervised Contrastive Loss: 0.0655
Epoch 21/

In [None]:
# Save only the model's state_dict (not the module object) to a hard-coded path.
# NOTE(review): "/content" looks like the Colab working directory — confirm before running elsewhere.
torch.save(model.state_dict(), "/content/reid_multimodal_model.pth")
