# Imports

In [1]:
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision import transforms
from tqdm import tqdm

# Globals

In [2]:
# Root of the VeRi dataset. Set the VERI_ROOT environment variable to point the
# notebook at another location; the original local path stays as the default so
# existing setups keep working unchanged.
DATA_ROOT = Path(os.environ.get("VERI_ROOT", r"C:\Datasets\VeRi\VeRi"))
# Worker processes handed to every DataLoader (0 = load in the main process).
NUM_WORKERS = 0
# Default size of the embedding produced by the network head.
EMBED_DIM = 512
# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Sanity check: number of training images found under DATA_ROOT.
print(len(list((DATA_ROOT/"image_train").glob("*.jpg"))))

37778


# Utils

In [3]:
def compute_distance_matrix(qf, gf):
    """Pairwise cosine distance between query and gallery embeddings.

    Args:
        qf: float tensor of shape (num_query, D).
        gf: float tensor of shape (num_gallery, D).

    Returns:
        Tensor of shape (num_query, num_gallery) whose entry [i, j] is
        ``1 - cos(qf[i], gf[j])``.
    """
    # The dot product below equals the cosine similarity only for unit-length
    # rows, so normalize here instead of trusting the caller. For inputs that
    # are already L2-normalized this is a no-op up to float rounding.
    qf = F.normalize(qf, p=2, dim=1)
    gf = F.normalize(gf, p=2, dim=1)
    sim = qf @ gf.t()          # (num_query, num_gallery)
    return 1 - sim


In [4]:
@torch.no_grad()
def extract_features(model, loader):
    """Run ``model`` over ``loader`` and collect embeddings with their IDs.

    Each batch from ``loader`` must unpack as ``(imgs, _, pids, camids)`` and
    ``model(imgs)`` must return a pair whose first element is the embedding
    batch; the second element is ignored.

    Side effect: the model is switched to eval mode and left there.

    Returns:
        feats:  CPU tensor with the embeddings of all batches concatenated
                along dim 0.
        pids:   1-D tensor of identity IDs, in loader order.
        camids: 1-D tensor of camera IDs, in loader order.
    """
    model.eval()
    feat_chunks, pid_chunks, cam_chunks = [], [], []

    for imgs, _, p, c in tqdm(loader):
        imgs = imgs.to(DEVICE)
        f, _ = model(imgs)
        feat_chunks.append(f.cpu())
        # Keep the IDs as one tensor per batch instead of unpacking them into
        # a Python list of 0-d tensors; as_tensor also accepts plain lists.
        pid_chunks.append(torch.as_tensor(p).reshape(-1).cpu())
        cam_chunks.append(torch.as_tensor(c).reshape(-1).cpu())

    feats = torch.cat(feat_chunks, dim=0)
    pids = torch.cat(pid_chunks, dim=0)
    camids = torch.cat(cam_chunks, dim=0)

    return feats, pids, camids


In [5]:
class BatchHardTripletLoss(nn.Module):
    """
    Batch-hard triplet loss as described in
    'In Defense of the Triplet Loss for Person Re-Identification'.

    For every anchor in the batch the loss uses the farthest sample with the
    same label and the closest sample with a different label:

        relu(d(anchor, hardest_pos) - d(anchor, hardest_neg) + margin)

    averaged over all anchors.

    Inputs to forward:
      embeddings: (B, D) L2-normalized
      labels: (B,) long
    """
    def __init__(self, margin=0.3):
        super().__init__()
        self.margin = margin

    def forward(self, embeddings, labels):
        # Euclidean distance between every pair of samples: [B, B]
        pairwise = torch.cdist(embeddings, embeddings, p=2)

        # same_id[i, j] is True when i and j share a label; the diagonal is
        # always True, so an anchor counts itself as a positive at distance 0.
        col = labels.view(-1, 1)
        same_id = col.eq(col.t())

        # Farthest positive: non-positives are pushed below any real distance.
        farthest_pos = pairwise.masked_fill(~same_id, -1.0).max(dim=1).values

        # Closest negative: positives are pushed above any real distance, so
        # an anchor without negatives ends up with a zero loss term.
        closest_neg = pairwise.masked_fill(same_id, 1e9).min(dim=1).values

        per_anchor = F.relu(farthest_pos - closest_neg + self.margin)
        return per_anchor.mean()


In [None]:

def train_one_epoch(model, loader, optimizer, epoch=1,
                    use_triplet=True, alpha=1.0, beta=1.0):
    """Train ``model`` for one pass over ``loader``.

    The objective is ``alpha * CE + beta * triplet``. The two loss modules are
    read from the module-level globals ``criterion_ce`` and ``criterion_tri``,
    and batches are moved to the global ``DEVICE``.

    Args:
        model: network returning ``(features, logits)`` for an image batch.
        loader: yields ``(imgs, labels, _, _)`` batches.
        optimizer: optimizer stepping the model's parameters.
        epoch: epoch number, used only in the progress print.
        use_triplet: when False the triplet term is replaced by a constant 0.
        alpha: weight for CE loss.
        beta: weight for Triplet loss.

    Returns:
        ``(mean_ce, mean_triplet)`` averaged over the samples processed, or
        ``(0.0, 0.0)`` if the loader produced no batch.
    """
    model.train()
    total_ce, total_tri = 0.0, 0.0
    num_seen = 0

    for i, (imgs, labels, _, _) in enumerate(loader):
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        feats, logits = model(imgs)

        loss_ce = criterion_ce(logits, labels)
        if use_triplet:
            loss_tri = criterion_tri(feats, labels)
        else:
            loss_tri = torch.tensor(0.0, device=DEVICE)

        loss = alpha * loss_ce + beta * loss_tri
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so a short final batch does not
        # skew the epoch average.
        bs = imgs.size(0)
        num_seen += bs
        total_ce += loss_ce.item() * bs
        total_tri += loss_tri.item() * bs

        # Optional: progress print every 100 batches
        if (i + 1) % 100 == 0:
            print(f"Epoch {epoch}, batch {i+1}/{len(loader)}, "
                  f"CE={loss_ce.item():.4f}, TRI={loss_tri.item():.4f}")

    # Divide by the samples actually seen rather than len(loader.dataset):
    # the two differ with drop_last=True or a sampler that sub/over-samples.
    if num_seen == 0:
        return 0.0, 0.0
    return total_ce / num_seen, total_tri / num_seen

In [7]:
def evaluate(distmat, q_pids, g_pids, q_camids, g_camids):
    """Compute mAP and Rank-1 accuracy for a re-identification ranking.

    Args:
        distmat: (num_query, num_gallery) distances, smaller = more similar.
        q_pids, q_camids: 1-D tensors with the identity / camera of each query.
        g_pids, g_camids: 1-D tensors with the identity / camera of each
            gallery sample.

    For every query, gallery samples sharing both its identity and its camera
    are dropped before ranking. Queries left without any correct match are
    skipped and do not count towards either metric.

    Returns:
        ``(mAP, rank1)`` as Python floats in [0, 1]; ``(0.0, 0.0)`` when no
        query has a valid match.
    """
    num_q, num_g = distmat.size()
    indices = torch.argsort(distmat, dim=1)  # sort ascending (smaller dist = better)

    # matches[i, k] == 1 when the k-th ranked gallery sample has query i's identity
    matches = (g_pids[indices] == q_pids[:, None]).float()

    # 1-based rank positions, built once on the same device/dtype as `matches`
    # and sliced per query below.
    ranks = torch.arange(1, num_g + 1, device=matches.device, dtype=matches.dtype)

    all_AP = []
    all_rank1 = []

    for i in range(num_q):
        # Remove gallery samples from the same camera & same PID
        valid = ~((g_pids == q_pids[i]) & (g_camids == q_camids[i]))
        y_true = matches[i][valid[indices[i]]]

        if y_true.sum() == 0:
            continue

        # Precision at every rank, then averaged over the ranks holding a hit
        y_cum = torch.cumsum(y_true, dim=0)
        precision = y_cum / ranks[: y_true.numel()]

        AP = (precision * y_true).sum() / y_true.sum()
        all_AP.append(AP.item())

        # Rank-1 accuracy
        all_rank1.append(y_true[0].item())

    # Avoid a ZeroDivisionError when every query was skipped.
    if not all_AP:
        return 0.0, 0.0

    mAP = sum(all_AP) / len(all_AP)
    rank1 = sum(all_rank1) / len(all_rank1)

    return mAP, rank1


# Data

In [8]:
class VeRiDataset(Dataset):
    """Image-folder dataset whose file names encode identity and camera.

    Every ``*.jpg`` directly inside ``img_dir`` is expected to be named
    ``<pid>_c<camid>_...``: the first underscore-separated field is parsed as
    the integer identity, the second (minus its leading character) as the
    integer camera ID.

    Attributes:
        paths:     sorted list of image paths.
        pids:      identity per image, as parsed from the file name.
        camids:    camera ID per image.
        pid2label: maps each distinct pid to a contiguous index 0..N-1. The
                   mapping is built from this folder only, so labels from two
                   different dataset instances are not comparable.
        labels:    ``pid2label`` applied to ``pids``.
    """
    def __init__(self, img_dir, transform=None):
        self.img_dir = Path(img_dir)
        self.transform = transform

        self.paths = sorted(self.img_dir.glob("*.jpg"))
        self.pids = []
        self.camids = []

        for p in self.paths:
            pid_str, cam_str, *_ = p.stem.split("_")
            self.pids.append(int(pid_str))
            self.camids.append(int(cam_str[1:]))   # drop the leading 'c'

        unique_pids = sorted(set(self.pids))
        self.pid2label = {p: i for i, p in enumerate(unique_pids)}
        self.labels = [self.pid2label[p] for p in self.pids]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(image, label, pid, camid)`` for sample ``idx``."""
        # convert() loads the pixels and returns a new image, so the file can
        # be closed right away instead of waiting for garbage collection.
        with Image.open(self.paths[idx]) as raw:
            img = raw.convert("RGB")

        # Compare against None: a transform object may legitimately be falsy.
        if self.transform is not None:
            img = self.transform(img)

        return img, self.labels[idx], self.pids[idx], self.camids[idx]


In [9]:
# Loss modules read as module-level globals inside train_one_epoch:
# cross-entropy is applied to the classifier logits, the batch-hard triplet
# loss to the embeddings.
criterion_ce = nn.CrossEntropyLoss()
criterion_tri = BatchHardTripletLoss(margin=0.3)


In [10]:
# Preprocessing constants shared by the train and test pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]   # per-channel mean used for normalization
IMAGENET_STD = [0.229, 0.224, 0.225]    # per-channel std used for normalization
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

# Steps common to both pipelines: PIL image -> float tensor -> normalized.
_to_model_input = [
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]

# Training adds a random horizontal flip; evaluation is deterministic.
train_tf = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    *_to_model_input,
])

test_tf = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    *_to_model_input,
])

train_set = VeRiDataset(DATA_ROOT / "image_train", transform=train_tf)
query_set = VeRiDataset(DATA_ROOT / "image_query", transform=test_tf)
gallery_set = VeRiDataset(DATA_ROOT / "image_test", transform=test_tf)


def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the notebook-wide batch settings."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        # Pinned host memory only speeds up host-to-GPU copies; requesting it
        # on a CPU-only run just triggers a warning.
        pin_memory=DEVICE.type == "cuda",
    )


# NOTE(review): with plain shuffling a batch rarely holds two images of the
# same vehicle, so the batch-hard triplet loss mostly sees each anchor as its
# own only positive. An identity-balanced (P x K) sampler is probably wanted
# for the training loader - confirm.
train_loader = _make_loader(train_set, shuffle=True)
query_loader = _make_loader(query_set, shuffle=False)
gallery_loader = _make_loader(gallery_set, shuffle=False)


# Network

In [12]:
# One classifier output per distinct identity in the training set.
NUM_CLASSES = len(train_set.pid2label)
class ReIDNet(nn.Module):
    """ImageNet-pretrained ResNet backbone with an embedding head and classifier.

    Args:
        arch: "resnet50" selects ResNet-50; any other value selects ResNet-18.
        num_classes: number of classifier outputs.
        embed_dim: size of the embedding vector.

    ``forward`` returns ``(feat_norm, logits)``: the L2-normalized embedding
    and the class logits, which are computed from the batch-normalized
    embedding before L2 normalization.
    """
    def __init__(self, arch="resnet50", num_classes=NUM_CLASSES, embed_dim=EMBED_DIM):
        super().__init__()

        # Choose the constructor and its pretrained weights together.
        build, weights = (
            (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V1)
            if arch == "resnet50"
            else (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1)
        )
        backbone = build(weights=weights)

        # Every child module except the trailing FC layer (i.e. up to avgpool).
        trunk = list(backbone.children())
        self.features = nn.Sequential(*trunk[:-1])
        backbone_dim = backbone.fc.in_features

        # Linear embedding -> BatchNorm (bias frozen) -> linear classifier.
        self.embed = nn.Linear(backbone_dim, embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim)
        self.bn.bias.requires_grad_(False)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        pooled = self.features(x)                     # [B, C, 1, 1]
        flat = pooled.view(pooled.size(0), -1)        # [B, C]
        emb = self.bn(self.embed(flat))               # [B, D]
        logits = self.classifier(emb)
        return F.normalize(emb, p=2, dim=1), logits


# Training

In [14]:
# Train the ResNet-18 variant. The earlier ReIDNet("resnet50") instance was
# overwritten on the very next line before being used, so building it only
# cost time and memory; it has been dropped.
model = ReIDNet("resnet18").to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=5e-4)
# Not read by train_one_epoch (which uses criterion_ce / criterion_tri); kept
# so any other cell that refers to `criterion` keeps working.
criterion = nn.CrossEntropyLoss()


EPOCHS = 1

for epoch in range(1, EPOCHS + 1):
    ce_loss, tri_loss = train_one_epoch(
        model, train_loader, optimizer,
        epoch=epoch,
        use_triplet=True,
        alpha=1.0,
        beta=1.0
    )
    print(f"Epoch {epoch}: CE={ce_loss:.4f}, TRI={tri_loss:.4f}")


Epoch 1, batch 100/1181, CE=5.1328, TRI=0.0336
Epoch 1, batch 200/1181, CE=3.2683, TRI=0.0000
Epoch 1, batch 300/1181, CE=2.4595, TRI=0.0010
Epoch 1, batch 400/1181, CE=2.3437, TRI=0.0394
Epoch 1, batch 500/1181, CE=1.8796, TRI=0.0000
Epoch 1, batch 600/1181, CE=1.5260, TRI=0.0000
Epoch 1, batch 700/1181, CE=1.4202, TRI=0.0000
Epoch 1, batch 800/1181, CE=1.2068, TRI=0.0000
Epoch 1, batch 900/1181, CE=0.7579, TRI=0.0000
Epoch 1, batch 1000/1181, CE=0.7532, TRI=0.0018
Epoch 1, batch 1100/1181, CE=0.7632, TRI=0.0224
Epoch 1: CE=2.1562, TRI=0.0125


# Evaluation

In [15]:
# 1. Extract features: embeddings plus the identity and camera ID of every
#    query and gallery image (extract_features also puts the model in eval mode)
q_feats, q_pids, q_camids = extract_features(model, query_loader)
g_feats, g_pids, g_camids = extract_features(model, gallery_loader)

# 2. Compute distance matrix: one row per query, one column per gallery image
distmat = compute_distance_matrix(q_feats, g_feats)

# 3. Evaluate: mean average precision and Rank-1 accuracy, both in [0, 1]
mAP, rank1 = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)

# Report both metrics as percentages
print(f"mAP:   {mAP*100:.2f}%")
print(f"Rank-1: {rank1*100:.2f}%")


100%|██████████| 53/53 [00:10<00:00,  5.22it/s]
100%|██████████| 362/362 [00:58<00:00,  6.14it/s]


mAP:   45.64%
Rank-1: 80.21%
