In [1]:
import os
import glob
import xml.etree.ElementTree as ET


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.ops import RoIAlign
from torchvision import transforms
from ultralytics import YOLO

import random
import numpy as np

In [2]:
from PIL import Image


class ResizePad:
    def __init__(self,size=(256,128),fill=0):
        self.target_h, self.target_w = size
        self.fill = fill
    def __call__(self,img):
        orig_w, orig_h = img.size
        scale = min(self.target_w/orig_w, self.target_h/orig_h)
        new_w, new_h = int(orig_w * scale),int(orig_h*scale)

        img = img.resize((new_w,new_h), Image.BILINEAR)

        new_img = Image.new("RGB",(self.target_w,self.target_h),(self.fill,)*3)
        paste_x = (self.target_w-new_w)//2
        paste_y = (self.target_h-new_h)//2
        new_img.paste(img,(paste_x,paste_y))

        return new_img

In [3]:
class FolderGroupedBatchDataset:
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform or transforms.Compose([
            transforms.Resize((256, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])
        self.pid_to_paths = defaultdict(list)
        for pid in os.listdir(root_dir):
            pid_folder = os.path.join(root_dir, pid)
            if os.path.isdir(pid_folder):
                for fname in os.listdir(pid_folder):
                    if fname.endswith('.png'):
                        self.pid_to_paths[pid].append(os.path.join(pid_folder, fname))
        self.pids = list(self.pid_to_paths.keys())

    def sample(self, P, K):
        selected_pids = random.sample(self.pids, min(P, len(self.pids)))
        images, labels = [], []
        for pid in selected_pids:
            paths = self.pid_to_paths[pid]
            chosen = random.choices(paths, k=K) if len(paths) < K else random.sample(paths, K)
            for path in chosen:
                img = Image.open(path).convert("RGB")
                img = self.transform(img)
                images.append(img)
                labels.append(pid)
        return torch.stack(images), labels

In [4]:
class FolderGroupedBatchTrainingDataset:
    """
    Groups images by folder (person ID) for training only. Does not inherit from PyTorch Dataset
    because it's not accessed by index but by a custom sampling method.
    """
    def __init__(self, root_dir, transform=None):
        self.transform = transform or transforms.Compose([
            ResizePad((256, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])

        self.pid_to_imgs = defaultdict(list)

        for pid in os.listdir(root_dir):
            folder = os.path.join(root_dir, pid)
            if not os.path.isdir(folder): continue

            for img_path in glob.glob(os.path.join(folder, '*.png')):
                self.pid_to_imgs[pid].append(img_path)

        self.pids = [pid for pid, imgs in self.pid_to_imgs.items() if len(imgs) >= 2]

    def sample(self, P, K):
        """
        Sample a batch of P identities with K images each.
        Returns: images (tensor list), labels (list of pids)
        """
        assert len(self.pids) >= P, "Not enough unique IDs to sample."

        batch_pids = random.sample(self.pids, P)
        images = []
        labels = []
        for pid in batch_pids:
            img_paths = random.sample(self.pid_to_imgs[pid], min(K, len(self.pid_to_imgs[pid])))
            for path in img_paths:
                img = Image.open(path).convert('RGB')
                img = self.transform(img)
                images.append(img)
                labels.append(pid)
        return images, labels

In [5]:
# Not used anymore.
def batch_hard_triplet_loss(embeddings, labels, margin=1.0,device=torch.device('cpu')):
    embeddings = F.normalize(embeddings, dim=1)
    labels = torch.tensor([hash(x) for x in labels])  # hash to int
    pdist = 1 - torch.matmul(embeddings, embeddings.T)
    mask_pos = labels.unsqueeze(1) == labels.unsqueeze(0)
    mask_neg = ~mask_pos
    mask_pos = mask_pos.float().to(device)
    mask_neg = mask_neg.float().to(device)

    hardest_pos = (pdist * mask_pos.float().to(device)).max(dim=1)[0]
    hardest_neg = (pdist + mask_pos.float().to(device) * 10).min(dim=1)[0]
    loss = F.relu(hardest_pos - hardest_neg + margin)
    return loss.mean()

In [6]:
import torch
import torch.nn.functional as F

def pairwise_distances(embeddings):
    # Compute cosine distance matrix
    normed = F.normalize(embeddings, p=2, dim=1)
    sim_matrix = torch.matmul(normed, normed.T)
    dist_matrix = 1 - sim_matrix  # cosine distance
    return dist_matrix

def combined_triplet_loss(embeddings, labels, margin=1.0, alpha=0.5, device=torch.device('cpu')):
    """
    Args:
        embeddings: Tensor [N, D]
        labels: List[str] or Tensor of IDs (can be string)
        margin: Triplet margin
        alpha: Weight for hard vs mean loss. alpha=0.5 → 50% hard, 50% mean
    """
    # Convert string labels to integer indices
    if isinstance(labels, list) and isinstance(labels[0], str):
        label_to_index = {label: idx for idx, label in enumerate(sorted(set(labels)))}
        labels = [label_to_index[label] for label in labels]
        labels = torch.tensor(labels, device=embeddings.device)
    elif isinstance(labels, list):
        labels = torch.tensor(labels, device=embeddings.device)
    else:
        labels = labels.to(embeddings.device)

    pairwise_dist = pairwise_distances(embeddings)
    N = embeddings.size(0)

    loss_hard = 0.0
    loss_mean = 0.0
    valid_triplets = 0

    for i in range(N):
        anchor_label = labels[i]
        dists = pairwise_dist[i]

        is_pos = (labels == anchor_label) & (torch.arange(N, device=embeddings.device) != i)
        is_neg = labels != anchor_label

        if torch.sum(is_pos) == 0 or torch.sum(is_neg) == 0:
            continue  # skip if no valid pairs

        # Hardest positive and negative
        hardest_pos = dists[is_pos].max()
        hardest_neg = dists[is_neg].min()
        hard_loss = F.relu(hardest_pos - hardest_neg + margin)

        # Mean-based variant
        mean_pos = dists[is_pos].mean()
        mean_neg = dists[is_neg].mean()
        mean_loss = F.relu(mean_pos - mean_neg + margin)

        # Combine
        loss = alpha * hard_loss + (1 - alpha) * mean_loss
        loss_hard += loss
        valid_triplets += 1

    if valid_triplets == 0:
        return torch.tensor(0.0, requires_grad=True, device=embeddings.device)

    return loss_hard / valid_triplets

In [7]:
import torch
import torch.nn as nn

class VisionAttentionLayer(nn.Module):
    """
    A standard Multi-Head Self-Attention layer for vision tasks.
    This layer is a core component of Vision Transformers (ViT).

    Args:
        dim (int): The embedding dimension of the input tokens.
        heads (int): The number of attention heads.
        dim_head (int, optional): The dimension of each attention head.
                                  Defaults to dim // heads.
        dropout (float, optional): Dropout rate. Defaults to 0.0.
    """
    def __init__(self, dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        # The scale factor is a crucial detail for stabilizing training.
        # It's the inverse square root of the head dimension.
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x input shape: (batch_size, num_patches, dim)

        # 1. Project input to Q, K, V
        # Shape: (batch_size, num_patches, inner_dim * 3)
        qkv = self.to_qkv(x).chunk(3, dim=-1)

        # 2. Reshape Q, K, V for multi-head attention
        # Change shape to: (batch_size, heads, num_patches, dim_head)
        q, k, v = map(
            lambda t: t.reshape(t.shape[0], t.shape[1], self.heads, -1).permute(0, 2, 1, 3),
            qkv
        )

        # 3. Calculate scaled dot-product attention scores
        # (q @ k.transpose) -> (b, h, n, d) @ (b, h, d, n) -> (b, h, n, n)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # 4. Apply softmax to get attention weights
        attn_weights = self.softmax(dots)
        attn_weights = self.dropout(attn_weights)

        # 5. Apply attention weights to V (values)
        # (attn_weights @ v) -> (b, h, n, n) @ (b, h, n, d) -> (b, h, n, d)
        attended_values = torch.matmul(attn_weights, v)

        # 6. Concatenate heads and project output
        # First, reshape to (b, n, h*d) where h*d = inner_dim
        out = attended_values.permute(0, 2, 1, 3).reshape(x.shape[0], x.shape[1], -1)

        # Finally, project back to the original embedding dimension `dim`
        return self.to_out(out)

In [None]:
class ReIDAtten_v2(nn.Module):
    '''
    ReID Atten v2
    Reduced backbone of YOLOv11 
    Uses Attention Layer for head.
    157,024 parameters. 
    '''
    def __init__(self, yolo_weights='yolo11n.pt', emb_dim=128):
        super().__init__()

        yolo_model = YOLO(yolo_weights)
        self.backbone = nn.Sequential(*yolo_model.model.model[:5])
        
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.backbone_output_dim = self._get_feat_dim()
        # Caveat : dim = dim_head = heads
        self.attn = VisionAttentionLayer(
            dim=self.backbone_output_dim, 
            heads=4, 
            dim_head=self.backbone_output_dim // 4)
        self.embed = nn.Linear(self.backbone_output_dim, emb_dim)

    def _get_feat_dim(self):
        x = torch.zeros((1, 3, 256, 128))
        with torch.no_grad():
            x = self.backbone(x)
            return x.shape[1]  # fix here
    def forward(self, x):
        x = self.backbone(x)          # (B, C, H, W)


        flat = x.flatten(2).transpose(1, 2)  # (B, H*W, C)
        # print("input to atten:", flat.shape)
        att = self.attn(flat)              # (B, H*W, C)
        # print(att.shape)
        att = att.mean(dim=1) 
        # print(att.shape)            # (B, C)
        embed = self.embed(att)             # (B, 128)
        return nn.functional.normalize(embed, dim=1)

In [17]:
model = ReIDAtten_v2()
pytorch_total_params = sum(p.numel() for p in model.parameters())
print(pytorch_total_params)

dummy = torch.randn(2,3,256,128)
out = model(dummy)  # shape: (2,128,16,8)
print(out.shape)

157024
torch.Size([2, 128])


In [18]:
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict
import numpy as np
import torch.nn.functional as F

# Validation
@torch.no_grad()
def validate_similarity(model, dataset, device, P=5, K=5, inter_K=10):
    model.eval()
    total_intra, total_inter, count_intra, count_inter = 0.0, 0.0, 0, 0

    selected_pids = random.sample(dataset.pids, min(P, len(dataset.pids)))
    for pid in selected_pids:
        paths = dataset.pid_to_paths[pid]
        if len(paths) < 2:
            continue
        chosen = random.sample(paths, min(K, len(paths)))
        imgs = torch.stack([dataset.transform(Image.open(p).convert("RGB")) for p in chosen]).to(device)
        embs = F.normalize(model(imgs), dim=1)
        for i in range(len(embs)):
            for j in range(i + 1, len(embs)):
                total_intra += F.cosine_similarity(embs[i].unsqueeze(0), embs[j].unsqueeze(0)).item()
                count_intra += 1

        # Inter
        inter_paths = []
        for _ in range(inter_K):
            other_pid = random.choice([x for x in dataset.pids if x != pid])
            other_img = random.choice(dataset.pid_to_paths[other_pid])
            inter_paths.append(other_img)
        inter_imgs = torch.stack([dataset.transform(Image.open(p).convert("RGB")) for p in inter_paths]).to(device)
        inter_embs = F.normalize(model(inter_imgs), dim=1)
        for anchor in embs:
            for inter in inter_embs:
                total_inter += F.cosine_similarity(anchor.unsqueeze(0), inter.unsqueeze(0)).item()
                count_inter += 1

    avg_intra = total_intra / count_intra if count_intra else 0.0
    avg_inter = total_inter / count_inter if count_inter else 0.0
    return avg_intra, avg_inter

In [19]:
import os
# real_epoch=300
# base_dir = os.getcwd()
# model.load_state_dict(torch.load(os.path.join(base_dir, f"saved_model2/ReIDPooling_{real_epoch}.pth")))

In [20]:
# model.train()
# imgs, labels = train_dataset.sample(P=16, K=4)
# imgs = torch.stack(imgs).to(device)
# imgs = imgs.to(device)
# embeddings = model(imgs)
# print(embeddings.shape)
# print(len(labels))

In [21]:
# model = YOLOv11ReID().to(device)
# pytorch_total_params = sum(p.numel() for p in model.parameters())
# print(pytorch_total_params)


In [None]:
from torch.nn import TripletMarginLoss

device ='cuda' if torch.cuda.is_available() else 'cpu'
print(device)
model = ReIDAtten_v2().to(device)
optimizer = optim.Adam(model.parameters(),lr=1e-4)
real_epoch=0

cuda


In [27]:
import csv


train_dataset = FolderGroupedBatchTrainingDataset('../data/dataset1')
val_val_dataset = FolderGroupedBatchDataset('../data/valid_dataset1')
tr_val_dataset = FolderGroupedBatchDataset('../data/dataset1')

csv_log = f"ReID_csv/attenv2_{real_epoch}.csv"
with open(csv_log, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(["epoch", "loss", "train_intra", "train_inter", "val_intra", "val_inter"])

for epoch in range(50000):
    model.train()
    imgs, labels = train_dataset.sample(P=16, K=4)
    imgs = torch.stack(imgs).to(device)
    imgs = imgs.to(device)
    embeddings = model(imgs)
    loss = combined_triplet_loss(embeddings, labels,device=device)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1) % 20==0:
        print(f"[Epoch {epoch+1}] Loss: {loss.item():.4f}")



        train_intra, train_inter = validate_similarity(model, tr_val_dataset, device)
        val_intra, val_inter = validate_similarity(model, val_val_dataset, device)
        print(f"train_intra : {train_intra}, train_inter : {train_inter}")
        print(f"valid_intra : {val_intra}, valid_inter : {val_inter}")

        with open(csv_log, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([real_epoch+1, loss.item(), train_intra, train_inter, val_intra, val_inter])
    if (epoch+1) % 500 == 0:
        torch.save(model.state_dict(), 'ReID_attenv2/'+f"ReIDAttenv2_{real_epoch+1}.pth")
 
    real_epoch += 1


[Epoch 20] Loss: 0.8899
train_intra : 0.6923840015381575, train_inter : 0.20703741423226893
valid_intra : 0.704823547154665, valid_inter : 0.2772107113599777
[Epoch 40] Loss: 0.8145
train_intra : 0.697523667961359, train_inter : 0.19942975732684134
valid_intra : 0.6922737288475037, valid_inter : 0.25784724313020707
[Epoch 60] Loss: 0.7349
train_intra : 0.5942222779989242, train_inter : 0.17763086554408072
valid_intra : 0.6334875009208918, valid_inter : 0.2336116529479623
[Epoch 80] Loss: 0.9897
train_intra : 0.8592259383201599, train_inter : 0.1007606630846858
valid_intra : 0.6218002918362617, valid_inter : 0.18793019834533334
[Epoch 100] Loss: 0.7522
train_intra : 0.6271457250788808, train_inter : 0.13475736662745474
valid_intra : 0.7589260786771774, valid_inter : 0.21939791368693112
[Epoch 120] Loss: 0.8368
train_intra : 0.6749433308839798, train_inter : 0.16169139859452844
valid_intra : 0.5226301111280918, valid_inter : 0.2913842646703124
[Epoch 140] Loss: 0.9913
train_intra : 0.838

KeyboardInterrupt: 

In [None]:
from tqdm import tqdm

@torch.no_grad()
def extract_features_ids(model, dataloader, device):
    model.eval()
    all_features = []
    all_ids = []

    for images, ids in tqdm(dataloader, desc="Extracting features"):
        images = images.to(device)
        emb = model(images)  # [B, 128]
        emb = F.normalize(emb, dim=1)  # cosine normalization

        all_features.append(emb.cpu())
        all_ids.extend(ids)  # assume ids are strings or ints

    features = torch.cat(all_features, dim=0)
    return features, all_ids

In [None]:
import matplotlib.pyplot as plt
def compute_similarity_gaps(features, ids,show=False):
    sim_matrix = cosine_similarity(features.numpy())  # [N, N]
    N = len(ids)
    gaps = []

    for i in range(N):
        query_id = ids[i]
        sim_scores = sim_matrix[i]

        # exclude self
        sim_scores[i] = -np.inf

        # positive scores (same ID, not self)
        pos_mask = np.array([j != i and ids[j] == query_id for j in range(N)])
        neg_mask = np.array([ids[j] != query_id for j in range(N)])

        if not np.any(pos_mask) or not np.any(neg_mask):
            continue  # skip if no pos/neg samples

        best_pos = np.max(sim_scores[pos_mask])
        best_neg = np.max(sim_scores[neg_mask])

        gap = best_pos - best_neg
        gaps.append(gap)

    gaps = np.array(gaps)
    print(f"Avg similarity gap (pos - hardest neg): {np.mean(gaps):.4f}")
    print(f"% queries where positive > negative: {(gaps > 0).mean()*100:.2f}%")
    if(show):
        # Optional: visualize
        plt.hist(gaps, bins=40, color='blue', edgecolor='black')
        plt.title("Distribution of (best positive - hardest negative) similarity gaps")
        plt.xlabel("Similarity Gap")
        plt.ylabel("Number of queries")
        plt.grid(True)
        plt.show()

In [None]:
intra, inter = compute_intra_inter_similarity(features, ids)

Avg similarity (same ID):     0.2792
Avg similarity (different ID): 0.0508
