# Vehicle Re-identification using Transformer and Contrastive Learning

In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np

from dataset import VehicleReIDDataset
from vit import VisionTransformer
from loss import TripletLoss

In [None]:
# Pick the best available accelerator: CUDA first, then Apple's MPS backend,
# otherwise fall back to the CPU. Each device name is only used when its own
# availability check succeeds.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')

True

In [None]:
def get_hard_triplets(embeddings, labels):
    '''
    Mine batch-hard triplets based on Euclidean distances.

    For every anchor in the batch the hardest positive is the sample with the
    same label that lies *furthest* away, and the hardest negative is the
    sample with a different label that lies *closest*. These are the triplets
    most likely to violate the margin.

    Args:
        embeddings: 2-D tensor with one embedding per row.
        labels: 1-D tensor of identity labels, aligned with the rows of
            ``embeddings``.

    Returns:
        A list of ``(anchor_idx, positive_idx, negative_idx)`` tuples of plain
        Python ints. Anchors without any positive or without any negative in
        the batch are skipped, so the list may be empty.
    '''
    labels = labels.cpu().numpy()

    # Mining only selects indices, so no autograd graph is needed here; the
    # caller re-indexes ``embeddings`` to keep gradients flowing.
    with torch.no_grad():
        pairwise_distances = (
            torch.cdist(embeddings, embeddings).float().cpu().numpy()
        )

    hard_triplets = []
    for anchor_idx, anchor_label in enumerate(labels):
        same_label = labels == anchor_label

        positive_indices = np.flatnonzero(same_label)
        positive_indices = positive_indices[positive_indices != anchor_idx]  # Exclude self
        negative_indices = np.flatnonzero(~same_label)

        # A triplet needs both a positive and a negative; a batch that holds a
        # single identity (or a singleton identity) cannot provide one.
        if positive_indices.size == 0 or negative_indices.size == 0:
            continue

        anchor_distances = pairwise_distances[anchor_idx]

        # Hard positive: furthest sample with the same label
        positive_idx = int(
            positive_indices[anchor_distances[positive_indices].argmax()]
        )
        # Hard negative: closest sample with a different label
        negative_idx = int(
            negative_indices[anchor_distances[negative_indices].argmin()]
        )

        hard_triplets.append((anchor_idx, positive_idx, negative_idx))

    return hard_triplets

In [None]:
def train(model, dataloader, optimizer, loss_fn, device, epochs=10):
    '''
    Train ``model`` with a triplet loss on triplets mined inside each batch.

    Args:
        model: Network mapping a batch of images to a batch of embeddings.
        dataloader: Iterable yielding ``(images, labels)`` tensor batches.
        optimizer: Optimizer over the parameters of ``model``.
        loss_fn: Callable taking ``(anchor, positive, negative)`` embedding
            batches and returning a scalar loss tensor.
        device: Device the batches are moved to before the forward pass.
        epochs: Number of passes over ``dataloader``.

    The loss printed per epoch is averaged over the batches that actually
    produced triplets; batches without any triplet are skipped and do not
    dilute the average. If no batch contributed, ``nan`` is reported.
    '''
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        batches_used = 0
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            # Extract embeddings
            embeddings = model(images)

            # Triplet selection; nothing to learn from a batch without triplets
            triplets = get_hard_triplets(embeddings, labels)
            if not triplets:
                continue

            # Gather each triplet role with one indexing operation; indexing
            # ``embeddings`` directly keeps the result attached to the graph.
            anchor_idx, positive_idx, negative_idx = (
                [int(i) for i in column] for column in zip(*triplets)
            )
            anchor = embeddings[anchor_idx]
            positive = embeddings[positive_idx]
            negative = embeddings[negative_idx]

            # Compute loss
            loss = loss_fn(anchor, positive, negative)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            batches_used += 1

        # Guard against an empty loader or an epoch in which every batch was skipped
        mean_loss = epoch_loss / batches_used if batches_used else float('nan')
        print(f'Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}')

In [None]:
def extract_embeddings(model, dataloader, device):
    '''
    Run the model over every batch and collect embeddings with their labels.

    The model is switched to evaluation mode and gradients are disabled for
    the whole pass. Returns a tuple of the concatenated CPU embedding tensor
    and a NumPy array holding the labels in the same order.
    '''
    model.eval()
    embedding_batches, label_values = [], []

    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            batch_embeddings = model(batch_images.to(device))
            # Keep results on the CPU so the collection does not occupy
            # accelerator memory.
            embedding_batches.append(batch_embeddings.cpu())
            label_values.extend(batch_labels.numpy())

    stacked_embeddings = torch.cat(embedding_batches)
    return stacked_embeddings, np.array(label_values)

def retrieve_top_k(embedding, dataset_embeddings, dataset_labels, k=5):
    '''
    Retrieve the labels of the top-k most similar images using Euclidean distance.

    Args:
        embedding: Query embedding, either a 1-D vector of shape ``(D,)`` or a
            single-row batch of shape ``(1, D)`` (what a model returns for a
            batch of one image).
        dataset_embeddings: 2-D tensor of gallery embeddings, one per row.
        dataset_labels: Labels aligned with the rows of ``dataset_embeddings``;
            must support indexing with an integer NumPy array.
        k: Number of matches to return. Fewer are returned when the gallery
            holds fewer than ``k`` entries.

    Returns:
        The entries of ``dataset_labels`` for the nearest gallery embeddings,
        ordered from closest to furthest.

    Raises:
        ValueError: If ``embedding`` is neither ``(D,)`` nor ``(1, D)``.
    '''
    if embedding.dim() == 1:
        query = embedding.unsqueeze(0)
    elif embedding.dim() == 2 and embedding.shape[0] == 1:
        # Already a batch of one; adding another axis would make cdist return
        # a 3-D result and the sort below would run over the wrong axis.
        query = embedding
    else:
        raise ValueError(
            'embedding must have shape (D,) or (1, D), got '
            f'{tuple(embedding.shape)}'
        )

    distances = torch.cdist(query, dataset_embeddings).squeeze(0)
    top_k_indices = torch.argsort(distances)[:k]
    # Convert to NumPy so the lookup does not depend on how the label
    # container handles torch tensors as indices.
    return dataset_labels[top_k_indices.cpu().numpy()]

In [None]:
# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalise each channel with the given mean/std.
# NOTE(review): these mean/std values match the commonly used ImageNet
# statistics -- presumably chosen for that reason; confirm against the model.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Initialize dataset and dataloader.
# 'path/to/dataset' looks like a placeholder -- point root_dir at the real data.
dataset = VehicleReIDDataset(root_dir='path/to/dataset', transform=transform)
# Batches of 32, reshuffled on each pass over the data.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Load ViT model (assuming a ViT class is implemented)
model = VisionTransformer(embed_dim=512)  # Adjust based on your implementation
model.to(DEVICE)  # Move the model parameters to the selected device

# Define optimizer and loss function: Adam with weight decay, and a triplet
# loss with margin 0.3.
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
loss_fn = TripletLoss(margin=0.3)

In [None]:
# Train the model for 10 epochs on the selected device, using the dataloader,
# optimizer and loss function configured in the setup cell.
train(model, dataloader, optimizer, loss_fn, DEVICE, epochs=10)

In [None]:
# Build the gallery: embeddings and labels for every image the dataloader serves.
# extract_embeddings also switches the model to evaluation mode.
dataset_embeddings, dataset_labels = extract_embeddings(model, dataloader, DEVICE)

# Example inference: get top-k matches for a query image.
# The query is taken from the gallery itself, so it is expected to be its own
# closest match.
query_image, _ = dataset[0]  # Example query image
with torch.no_grad():  # Inference only; no autograd graph is needed
    # The model is fed a batch of one image. Drop that batch axis again so a
    # single embedding vector is handed to retrieve_top_k.
    # NOTE(review): assumes the model returns shape (1, D) here -- confirm.
    query_embedding = model(query_image.unsqueeze(0).to(DEVICE)).squeeze(0).cpu()
top_k_matches = retrieve_top_k(query_embedding, dataset_embeddings, dataset_labels, k=5)

print('Top-5 Retrieved Labels:', top_k_matches)