# Vehicle Re-identification using Transformer and Contrastive Learning

In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from sklearn.metrics import average_precision_score
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image
import os

from dataset import VehicleReIDDataset
from vit import ViTEncoder
from loss import TripletLoss

In [None]:
# Pick the compute device: CUDA when available, otherwise CPU.
# The commented-out lines are alternatives (Apple MPS, or forcing CPU).
# DEVICE = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# DEVICE = torch.device('cpu')
print(DEVICE)

### Utility Functions

In [None]:
def get_hard_triplets(embeddings, labels):
    '''
    Selects hard positives and negatives based on Euclidean distances.

    For every anchor in the batch, the hard positive is the FURTHEST sample
    sharing the anchor's label and the hard negative is the CLOSEST sample
    with a different label (batch-hard mining). Anchors that have no other
    sample of the same label, or no sample of a different label, are skipped.

    Args:
        embeddings: 2-D tensor with one embedding per row.
        labels: 1-D tensor with one label per row of `embeddings`.

    Returns:
        List of (anchor_idx, positive_idx, negative_idx) tuples of plain ints
        indexing into `embeddings`.
    '''
    labels = labels.cpu().numpy()
    # Mining only selects indices, so no autograd graph is needed; moving the
    # distance matrix to NumPy once avoids a device sync per anchor.
    detached = embeddings.detach()
    pairwise_distances = torch.cdist(detached, detached).cpu().numpy()

    hard_triplets = []
    for anchor_idx, anchor_label in enumerate(labels):
        same_label = labels == anchor_label

        positive_indices = np.flatnonzero(same_label)
        positive_indices = positive_indices[positive_indices != anchor_idx]  # Exclude self
        negative_indices = np.flatnonzero(~same_label)
        if positive_indices.size == 0 or negative_indices.size == 0:
            continue

        anchor_distances = pairwise_distances[anchor_idx]

        # Hard positive: furthest sample with the same label
        positive_idx = positive_indices[np.argmax(anchor_distances[positive_indices])]

        # Hard negative: closest sample with a different label
        negative_idx = negative_indices[np.argmin(anchor_distances[negative_indices])]

        hard_triplets.append((anchor_idx, int(positive_idx), int(negative_idx)))

    return hard_triplets

In [None]:
def extract_embeddings(model, dataloader, device):
    '''
    Runs the model over every batch of the dataloader and collects the outputs.

    The model is switched to eval mode and gradients are disabled while
    iterating. Returns a tuple of (all embeddings concatenated into one CPU
    tensor, all labels as a NumPy array), in dataloader order.
    '''
    model.eval()
    embedding_chunks = []
    collected_labels = []

    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            batch_output = model(batch_images.to(device))
            embedding_chunks.append(batch_output.cpu())
            collected_labels += list(batch_labels.numpy())

    stacked_embeddings = torch.cat(embedding_chunks)
    return stacked_embeddings, np.array(collected_labels)

def retrieve_top_k(embedding, dataset_embeddings, dataset_labels, k=5):
    '''
    Returns the labels of the k dataset embeddings nearest to `embedding`
    (Euclidean distance, nearest first).
    '''
    query = embedding.unsqueeze(0)  # cdist expects a batch dimension
    ranking = torch.cdist(query, dataset_embeddings).argsort(dim=1)
    nearest = ranking[0, :k]
    return dataset_labels[nearest]

In [None]:
# Initializing the model (ViTEncoder is imported from the local `vit` module)
model = ViTEncoder(
    embed_dim=256,
    depth=4,
    n_heads=8,
    out_dim=512,
)

# Defining the image transformations: resize to 384x384, convert to a tensor,
# then normalize each channel.
# NOTE(review): the mean/std values are presumably the usual ImageNet channel
# statistics — confirm they suit this dataset.
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

### Training

In [None]:
def train(model, dataloader, optimizer, loss_fn, device, epochs=10):
    '''
    Trains the model with a triplet loss on triplets mined from each batch.

    Args:
        model: Network mapping a batch of images to a batch of embeddings.
        dataloader: Iterable yielding (images, labels) tensor pairs.
        optimizer: Optimizer over the model's parameters.
        loss_fn: Callable invoked as loss_fn(anchor, positive, negative).
        device: Device the batches are moved to.
        epochs: Number of passes over the dataloader.

    Returns:
        List with one mean loss per epoch. The mean is taken over the batches
        that actually produced triplets; an epoch in which no batch produced
        any triplet is recorded as NaN.
    '''
    model.train()
    train_loss_hist = []
    for epoch in range(epochs):
        loss_sum = 0.0
        batches_used = 0
        for images, labels in tqdm(dataloader):
            images, labels = images.to(device), labels.to(device)

            # Extract embeddings
            embeddings = model(images)

            # Hard triplet selection
            triplets = get_hard_triplets(embeddings, labels)
            if len(triplets) == 0:
                continue

            # Gather anchor/positive/negative rows with index tensors instead
            # of stacking rows one at a time; gradients flow the same way.
            triplet_idx = torch.as_tensor(
                np.asarray(triplets, dtype=np.int64), device=embeddings.device
            )
            anchor = embeddings[triplet_idx[:, 0]]
            positive = embeddings[triplet_idx[:, 1]]
            negative = embeddings[triplet_idx[:, 2]]

            # Compute loss
            loss = loss_fn(anchor, positive, negative)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            batches_used += 1

        # Average only over batches that contributed a loss; skipped batches
        # would otherwise bias the mean, and an empty dataloader would divide
        # by zero.
        epoch_loss = loss_sum / batches_used if batches_used else float('nan')
        train_loss_hist.append(epoch_loss)
        print(f'Epoch {epoch+1:3}/{epochs:3} | Loss: {epoch_loss:.4f}')
    return train_loss_hist

In [None]:
# Initialize dataset and dataloader
# NOTE(review): root_dir points at the 'test' directory, yet this dataloader
# is the one passed to train() — confirm this is the intended split.
dataset = VehicleReIDDataset(root_dir='data/prepared_VeRi_CARLA_dataset/test', transform=transform)
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# Move the model to the selected device before building the optimizer
model.to(DEVICE)

# Adam optimizer with weight decay, and the triplet loss with a margin of 0.3
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
loss_fn = TripletLoss(margin=0.3)

In [None]:
# Train model for 10 epochs and keep the per-epoch loss history for plotting
train_loss_hist = train(model, dataloader, optimizer, loss_fn, DEVICE, epochs=10)

In [None]:
# Plot training loss (one point per epoch; the x-axis is the zero-based epoch index)
plt.figure()
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.plot(train_loss_hist)
plt.show()

In [None]:
# Saving model
# Create the output directory first: torch.save does not create missing
# parent directories and would fail on a fresh checkout.
os.makedirs('trained_models', exist_ok=True)
torch.save(model.state_dict(), 'trained_models/model.pth')

### Inference

In [None]:
# Restore the saved weights (remapped onto DEVICE) and move the model there
model.load_state_dict(torch.load('trained_models/model.pth', map_location=DEVICE))
model.to(DEVICE)

# Switch to evaluation mode for inference
model.eval()

In [None]:
def infer(model, image_path, device=DEVICE, transform=transform):
    '''
    Computes the embedding of one image file.

    The image is loaded as RGB, passed through `transform`, given a batch
    dimension and run through the model without gradients. The result is
    returned as a flattened NumPy array.
    '''
    rgb_image = Image.open(image_path).convert('RGB')
    batch = transform(rgb_image).unsqueeze(0)  # add the batch dimension
    batch = batch.to(device)
    with torch.no_grad():
        output = model(batch)
    return output.cpu().numpy().flatten()

# Smoke test: embed a single gallery image (in a notebook the returned array
# is displayed as the cell output)
trial_image_path = './data/VeRi_CARLA_dataset/image_gallery/20220710050038_0_44.jpg'
infer(model, trial_image_path)

### Performance Analysis

In [None]:
# Reload the saved weights so the analysis below can be run without
# re-running the training cells
model.load_state_dict(torch.load('trained_models/model.pth', map_location=DEVICE))
model.to(DEVICE)

# Switch to evaluation mode
model.eval()

In [None]:
# Build the query and gallery datasets with the same transform as training
query_dataset = VehicleReIDDataset(root_dir='data/VeRi_CARLA_dataset/image_query', name='VeRi_CARLA', transform=transform)
gallery_dataset = VehicleReIDDataset(root_dir='data/VeRi_CARLA_dataset/image_gallery', name='VeRi_CARLA', transform=transform)

# shuffle=False keeps embeddings and labels in dataset order
query_dataloader = DataLoader(query_dataset, batch_size=1, shuffle=False)
gallery_dataloader = DataLoader(gallery_dataset, batch_size=1, shuffle=False)

# Embed every query and gallery image
query_embeddings, query_labels = extract_embeddings(model, query_dataloader, DEVICE)
gallery_embeddings, gallery_labels = extract_embeddings(model, gallery_dataloader, DEVICE)

In [None]:
# Compute cosine similarity between every query and every gallery embedding
# (rows: queries, columns: gallery images)
similarity = cosine_similarity(query_embeddings, gallery_embeddings)

In [None]:
def compute_cmc(similarity_matrix, query_ids, gallery_ids, max_rank=10):
    """Builds the CMC curve and derives Rank-1/5/10 accuracy (in percent).

    For each query, the gallery is ordered by descending similarity and the
    position of the first gallery item with the query's ID is located; every
    rank at or beyond that position counts as a hit. Rank-5 and Rank-10 are
    None when `max_rank` is too small to contain them.
    """
    num_queries = len(query_ids)
    hits_per_rank = np.zeros(max_rank)

    for q, query_id in enumerate(query_ids):
        # Gallery positions, most similar first
        order = np.flip(np.argsort(similarity_matrix[q]))
        ranked_ids = gallery_ids[order]

        match_positions = np.nonzero(ranked_ids == query_id)[0]
        if len(match_positions) == 0:
            continue

        # A hit at position p counts for rank p and every later rank
        hits_per_rank[match_positions[0]:] += 1

    hits_per_rank /= num_queries
    cmc_curve = hits_per_rank

    rank_1 = cmc_curve[0] * 100
    rank_5 = None
    if max_rank >= 5:
        rank_5 = cmc_curve[4] * 100
    rank_10 = None
    if max_rank >= 10:
        rank_10 = cmc_curve[9] * 100

    return cmc_curve, rank_1, rank_5, rank_10

# Evaluate CMC on the query/gallery similarity matrix
cmc_curve, rank_1, rank_5, rank_10 = compute_cmc(similarity, query_labels, gallery_labels)

# Compare against None explicitly: a genuine 0.0 accuracy is falsy and would
# otherwise be printed as a blank line instead of "0.00%".
print(f"Rank-1 Accuracy: {rank_1:.2f}%")
print(f"Rank-5 Accuracy: {rank_5:.2f}%" if rank_5 is not None else "")
print(f"Rank-10 Accuracy: {rank_10:.2f}%" if rank_10 is not None else "")

# Plot the CMC curve against 1-based ranks
ranks = np.arange(1, len(cmc_curve) + 1)
plt.figure(figsize=(8, 6))
plt.plot(ranks, cmc_curve, marker="o", linestyle="-", label="CMC Curve")
plt.xlabel("Rank")
plt.ylabel("Matching Accuracy")
plt.title("CMC Curve for Vehicle Re-Identification")
plt.legend()
plt.grid()
plt.show()

In [None]:
def compute_map(similarity_matrix, query_ids, gallery_ids):
    """Returns the mean Average Precision over all queries.

    Each query's gallery is ordered by descending similarity and scored with
    `average_precision_score`, using "gallery ID equals query ID" as the
    relevance label. Queries with no matching gallery item are left out of
    the mean.
    """
    ap_values = []

    for q, query_id in enumerate(query_ids):
        # Gallery positions, most similar first
        order = np.flip(np.argsort(similarity_matrix[q]))

        # 1 where the ranked gallery item is a correct match, 0 otherwise
        is_match = (gallery_ids[order] == query_id).astype(int)
        if not is_match.any():
            continue

        ranked_scores = similarity_matrix[q, order]
        ap_values.append(average_precision_score(is_match, ranked_scores))

    return np.mean(ap_values)

# Evaluate mAP on the same similarity matrix and report it as a percentage
map_score = compute_map(similarity, query_labels, gallery_labels)
print(f'Mean Average Precision (mAP): {map_score * 100:.2f}%')