In [None]:
import os
import xml.etree.ElementTree as ET
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import numpy as np
from collections import defaultdict
import random
from sklearn.metrics import average_precision_score
import math

class VeRiDataset(Dataset):
    """VeRi vehicle re-identification dataset.

    Each entry of ``self.data`` is a dict with the keys ``img_path``,
    ``vehicle_id``, ``camera_id``, ``color_id``, ``type_id`` and ``img_name``
    (all label values are kept as the raw strings read from the label XML).
    Raw vehicle / camera IDs are mapped to contiguous integer indices through
    ``vehicle_id_to_idx`` / ``camera_id_to_idx``, which are built from the
    samples of *this* split only.
    """

    def __init__(self, data_dir, split='train', transform=None, triplet_sampling=False):
        """
        Initialize the VeRi dataset.

        Args:
            data_dir (str): Root directory of the VeRi dataset
            split (str): 'train', 'query', or 'test'
            transform: PyTorch transforms for image preprocessing
            triplet_sampling (bool): Whether to enable triplet sampling for training

        Raises:
            KeyError: If ``split`` is not one of the supported names.
        """
        self.data_dir = data_dir
        self.split = split
        self.transform = transform
        self.triplet_sampling = triplet_sampling
        self.image_dir = {
            'train': 'image_train',
            'query': 'image_query',
            'test': 'image_test'
        }[split]
        self.data = self._load_data()
        self.vehicle_id_to_idx = {vid: idx for idx, vid in enumerate(sorted(set(item['vehicle_id'] for item in self.data)))}
        self.camera_id_to_idx = {cid: idx for idx, cid in enumerate(sorted(set(item['camera_id'] for item in self.data)))}

        # Triplets are only produced for the training split.
        if triplet_sampling and split == 'train':
            self._build_triplet_index()

    @staticmethod
    def _read_image_names(path):
        """Return the first whitespace-separated token of every non-blank line.

        The file is read as UTF-8 first; if that fails to decode it is re-read
        as latin-1 (which accepts any byte sequence).
        """
        for encoding in ('utf-8', 'latin-1'):
            try:
                with open(path, 'r', encoding=encoding) as f:
                    return [line.split()[0] for line in f if line.strip()]
            except UnicodeDecodeError:
                continue
        return []  # not reachable in practice: latin-1 never fails to decode

    @staticmethod
    def _parse_label_root(label_path):
        """Parse the label XML and return its root element.

        Strategies are tried in order, each one only after the previous one
        failed with the listed error:
        1. UTF-8 text (falls through on UnicodeDecodeError / ParseError)
        2. latin-1 text (falls through on ParseError)
        3. ``ET.parse`` on the path (falls through on ValueError)
        4. raw bytes decoded as UTF-8 with undecodable bytes dropped
        Any other exception propagates to the caller.
        """
        try:
            with open(label_path, 'r', encoding='utf-8') as f:
                return ET.fromstring(f.read())
        except (UnicodeDecodeError, ET.ParseError):
            pass
        try:
            with open(label_path, 'r', encoding='latin-1') as f:
                return ET.fromstring(f.read())
        except ET.ParseError:
            pass
        try:
            return ET.parse(label_path).getroot()
        except ValueError:
            pass
        with open(label_path, 'rb') as f:
            return ET.fromstring(f.read().decode('utf-8', errors='ignore'))

    def _load_data(self):
        """Load image file names and labels based on the split.

        Images listed in the name file are kept only if they have a label
        entry and the image file exists on disk; a warning is printed for
        every image that is dropped.
        """
        name_file = {
            'train': 'name_train.txt',
            'query': 'name_query.txt',
            'test': 'name_test.txt'
        }[self.split]
        image_files = self._read_image_names(os.path.join(self.data_dir, name_file))

        # Query and test images share the same label file.
        label_file = 'train_label.xml' if self.split == 'train' else 'test_label.xml'
        root = self._parse_label_root(os.path.join(self.data_dir, label_file))

        label_dict = {}
        items_element = root.find('Items')
        if items_element is not None:
            for item in items_element.findall('Item'):
                img_name = item.get('imageName')
                if img_name:
                    label_dict[img_name] = {
                        'vehicle_id': item.get('vehicleID', '0'),
                        'camera_id': item.get('cameraID', 'c001'),
                        'color_id': item.get('colorID', '1'),
                        'type_id': item.get('typeID', '1')
                    }

        data = []
        for img_file in image_files:
            labels = label_dict.get(img_file)
            if labels is None:
                print(f"Warning: No label found for image: {img_file}")
                continue
            img_path = os.path.join(self.data_dir, self.image_dir, img_file)
            if not os.path.exists(img_path):
                print(f"Warning: Image file not found: {img_path}")
                continue
            data.append({
                'img_path': img_path,
                'vehicle_id': labels['vehicle_id'],
                'camera_id': labels['camera_id'],
                'color_id': labels['color_id'],
                'type_id': labels['type_id'],
                'img_name': img_file
            })

        print(f"Loaded {len(data)} samples for {self.split} split")
        return data

    def _build_triplet_index(self):
        """Build index for efficient triplet sampling."""
        self.vehicle_to_images = defaultdict(list)
        for idx, item in enumerate(self.data):
            self.vehicle_to_images[item['vehicle_id']].append(idx)
        self.vehicle_ids = list(self.vehicle_to_images.keys())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.triplet_sampling and self.split == 'train':
            return self._get_triplet(idx)
        else:
            return self._get_single_item(idx)

    @staticmethod
    def _load_image(img_path, role=''):
        """Open ``img_path`` as an RGB PIL image.

        The file is opened in a ``with`` block so the underlying file handle is
        closed as soon as the pixel data has been converted. On any failure the
        error is printed and a black 224x224 placeholder is returned, so one
        unreadable file does not abort a whole epoch.

        Args:
            img_path (str): Path of the image file.
            role (str): Optional word (with trailing space) inserted into the
                error message, e.g. ``'anchor '``.
        """
        try:
            with Image.open(img_path) as img:
                return img.convert('RGB')
        except Exception as e:
            print(f"Error loading {role}image {img_path}: {e}")
            return Image.new('RGB', (224, 224), color=(0, 0, 0))

    def _sample_negative_index(self, anchor_vehicle_id):
        """Return the index of a random sample whose vehicle differs from the anchor's.

        The vehicle is drawn uniformly from all other vehicles by rejection
        sampling, which avoids rebuilding a list of every other ID per sample.

        Raises:
            ValueError: If no vehicle other than the anchor's exists. (An
                IndexError must not escape ``__getitem__`` for this reason,
                because sequence iteration interprets it as "end of data".)
        """
        ids = self.vehicle_ids
        if not ids or (len(ids) == 1 and ids[0] == anchor_vehicle_id):
            raise ValueError(
                "Triplet sampling needs at least two distinct vehicle IDs to draw a negative sample"
            )
        while True:
            neg_vehicle_id = random.choice(ids)
            if neg_vehicle_id != anchor_vehicle_id:
                break
        return random.choice(self.vehicle_to_images[neg_vehicle_id])

    def _get_single_item(self, idx):
        """Return one sample as a dict with keys img, vehicle_id, camera_id_idx, img_name."""
        item = self.data[idx]
        img = self._load_image(item['img_path'])

        if self.transform:
            img = self.transform(img)

        vehicle_idx = self.vehicle_id_to_idx[item['vehicle_id']]
        camera_idx = self.camera_id_to_idx[item['camera_id']]

        return {
            'img': img,
            'vehicle_id': vehicle_idx,
            'camera_id_idx': camera_idx,
            'img_name': item.get('img_name', None)
        }

    def _get_triplet(self, idx):
        """Return an (anchor, positive, negative) triplet anchored at ``idx``.

        The positive is another image of the anchor's vehicle (the anchor
        itself when the vehicle has only one image); the negative is an image
        of a different, randomly chosen vehicle.
        """
        anchor_item = self.data[idx]
        anchor_vehicle_id = anchor_item['vehicle_id']

        pos_candidates = [i for i in self.vehicle_to_images.get(anchor_vehicle_id, []) if i != idx]
        pos_idx = random.choice(pos_candidates) if pos_candidates else idx
        neg_idx = self._sample_negative_index(anchor_vehicle_id)

        pos_item = self.data[pos_idx]
        neg_item = self.data[neg_idx]

        anchor_img = self._load_image(anchor_item['img_path'], 'anchor ')
        pos_img = self._load_image(pos_item['img_path'], 'positive ')
        neg_img = self._load_image(neg_item['img_path'], 'negative ')

        if self.transform:
            anchor_img = self.transform(anchor_img)
            pos_img = self.transform(pos_img)
            neg_img = self.transform(neg_img)

        return {
            'anchor': anchor_img,
            'positive': pos_img,
            'negative': neg_img,
            'anchor_vehicle_id': self.vehicle_id_to_idx[anchor_vehicle_id],
            'anchor_camera_id_idx': self.camera_id_to_idx[anchor_item['camera_id']],
            'positive_vehicle_id': self.vehicle_id_to_idx[pos_item['vehicle_id']],
            'positive_camera_id_idx': self.camera_id_to_idx[pos_item['camera_id']],
            'negative_vehicle_id': self.vehicle_id_to_idx[neg_item['vehicle_id']],
            'negative_camera_id_idx': self.camera_id_to_idx[neg_item['camera_id']],
        }

class PartBasedModel(nn.Module):
    """Embedding head that fuses a global descriptor with a part-based one.

    The supplied backbone must expose the attributes ``conv1``, ``bn1``,
    ``relu``, ``maxpool`` and ``layer1`` .. ``layer4``; its ``fc`` attribute is
    replaced by ``nn.Identity`` (the backbone object is modified in place).

    Args:
        backbone: Feature extractor providing the attributes listed above.
        num_parts (int): Number of horizontal stripes the feature map is
            pooled into for the part branch.
        feature_dim (int): Channel count of the backbone feature map and size
            of the output embedding.
    """

    def __init__(self, backbone, num_parts=6, feature_dim=2048):
        super().__init__()
        self.backbone = backbone
        self.num_parts = num_parts
        self.feature_dim = feature_dim

        # The classification layer of the backbone is not used.
        self.backbone.fc = nn.Identity()

        # Part branch: pool to (num_parts, 1), flatten, BN, project to feature_dim.
        stacked_part_dim = feature_dim * num_parts
        self.part_pool = nn.AdaptiveAvgPool2d((num_parts, 1))
        self.part_bn = nn.BatchNorm1d(stacked_part_dim)
        self.part_fc = nn.Linear(stacked_part_dim, feature_dim)

        # Global branch: pool to (1, 1), flatten, BN, project to feature_dim.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.global_bn = nn.BatchNorm1d(feature_dim)
        self.global_fc = nn.Linear(feature_dim, feature_dim)

    def forward(self, x):
        """Return ``(combined_feat, global_feat, part_feat)`` for input batch ``x``.

        ``combined_feat`` is the L2-normalised sum of the two branch outputs;
        ``global_feat`` and ``part_feat`` are returned un-normalised.
        """
        bb = self.backbone
        feature_map = x
        for stage in (bb.conv1, bb.bn1, bb.relu, bb.maxpool,
                      bb.layer1, bb.layer2, bb.layer3, bb.layer4):
            feature_map = stage(feature_map)

        pooled_global = torch.flatten(self.global_pool(feature_map), 1)
        global_feat = self.global_fc(self.global_bn(pooled_global))

        pooled_parts = torch.flatten(self.part_pool(feature_map), 1)
        part_feat = self.part_fc(self.part_bn(pooled_parts))

        combined_feat = F.normalize(global_feat + part_feat, p=2, dim=1)
        return combined_feat, global_feat, part_feat

class VehicleReIDModel(nn.Module):
    """ResNet-50 based re-identification model with an identity classifier.

    Args:
        num_vehicles (int): Number of identity classes for the classifier.
        feature_dim (int): Embedding size passed to ``PartBasedModel``.
        num_parts (int): Number of part stripes passed to ``PartBasedModel``.
    """

    def __init__(self, num_vehicles, feature_dim=2048, num_parts=6):
        super(VehicleReIDModel, self).__init__()
        # `pretrained=True` is deprecated in newer torchvision releases in
        # favour of the `weights=` enum; IMAGENET1K_V1 is the checkpoint that
        # `pretrained=True` selects. Fall back to the legacy flag when the enum
        # is not available (older torchvision).
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is not None:
            backbone = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.resnet50(pretrained=True)
        self.part_model = PartBasedModel(backbone, num_parts, feature_dim)
        self.classifier = nn.Linear(feature_dim, num_vehicles)
        self.feature_dim = feature_dim

    def forward(self, x):
        """Embed a batch of images.

        Returns:
            In training mode a tuple ``(combined_feat, logits)``; in eval mode
            only ``combined_feat``. Callers must therefore unpack according to
            ``self.training``.
        """
        combined_feat, global_feat, part_feat = self.part_model(x)
        if self.training:
            logits = self.classifier(combined_feat)
            return combined_feat, logits
        else:
            return combined_feat

class TripletLoss(nn.Module):
    """Margin-based triplet loss on Euclidean distances.

    Args:
        margin (float): Minimum gap required between the anchor-negative and
            the anchor-positive distance before a triplet stops contributing.
    """

    def __init__(self, margin=0.3):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the batch mean of ``relu(d(a, p) - d(a, n) + margin)``."""
        dist_ap = F.pairwise_distance(anchor, positive, p=2)
        dist_an = F.pairwise_distance(anchor, negative, p=2)
        hinge = F.relu(dist_ap - dist_an + self.margin)
        return torch.mean(hinge)

class SpatioTemporalFilter:
    """Re-weights query-to-gallery distances using camera information.

    Args:
        camera_distances_file (str | None): Optional path to a whitespace
            separated matrix of camera distances. When given it is loaded into
            ``self.camera_distances``; ``filter_results`` does not read it.
        same_camera_penalty (float): Factor applied to the distance of every
            gallery entry captured by the same camera as the query.
    """

    def __init__(self, camera_distances_file=None, same_camera_penalty=1.5):
        self.camera_distances = None
        self.same_camera_penalty = same_camera_penalty
        if camera_distances_file:
            self._load_camera_distances(camera_distances_file)

    def _load_camera_distances(self, file_path):
        """Load camera distances from camera_Dist.txt into a numpy array.

        Blank lines are skipped: an empty row would otherwise make the rows
        ragged and break the conversion to a 2-D array.
        """
        with open(file_path, 'r') as f:
            rows = [list(map(float, line.split())) for line in f if line.strip()]
        self.camera_distances = np.array(rows)

    def filter_results(self, query_camera_idx, gallery_camera_idxs, distances):
        """Filter distances based on camera indices, penalizing same-camera matches.

        Args:
            query_camera_idx: Camera index of the query.
            gallery_camera_idxs: Camera index of every gallery entry.
            distances: Distances from the query to every gallery entry.

        Returns:
            A new array (the input is not modified) in which the entries whose
            gallery camera equals the query camera are multiplied by
            ``same_camera_penalty``.
        """
        filtered_distances = np.array(distances, copy=True)
        same_camera = np.asarray(gallery_camera_idxs) == query_camera_idx
        filtered_distances[same_camera] = filtered_distances[same_camera] * self.same_camera_penalty
        return filtered_distances

class GraphReRanking:
    """Computes a query-to-gallery distance matrix and (eventually) re-ranks it.

    ``k1``, ``k2`` and ``lambda_value`` are stored for a future re-ranking
    step; the current ``_k_reciprocal_rerank`` does not read them.
    """

    def __init__(self, k1=20, k2=6, lambda_value=0.3):
        self.k1, self.k2, self.lambda_value = k1, k2, lambda_value

    def re_rank(self, query_features, gallery_features):
        """Return the (currently unchanged) distance matrix between queries and gallery."""
        initial_dist = self._compute_distance_matrix(query_features, gallery_features)
        return self._k_reciprocal_rerank(initial_dist)

    def _compute_distance_matrix(self, query_feat, gallery_feat):
        """Euclidean distances between L2-normalised rows, as a numpy array."""
        unit_query, unit_gallery = (
            F.normalize(feat, p=2, dim=1) for feat in (query_feat, gallery_feat)
        )
        pairwise = torch.cdist(unit_query, unit_gallery, p=2)
        return pairwise.cpu().numpy()

    def _k_reciprocal_rerank(self, dist_matrix):
        # Placeholder: the input distances are handed back untouched.
        # A k-reciprocal encoding could be implemented here later.
        return dist_matrix

class EvaluationMetrics:
    """Retrieval metrics for re-identification.

    Both metrics follow the usual re-ID protocol for a query with identity
    ``q`` seen by camera ``c``:

    * gallery entries with identity ``q`` from a *different* camera are the
      correct matches;
    * gallery entries with identity ``q`` from the *same* camera are "junk":
      they are removed from the ranking instead of being counted as wrong;
    * queries without any correct match are skipped.
    """

    @staticmethod
    def _match_and_keep_masks(query_label, query_camera, gallery_labels, gallery_cameras):
        """Return ``(is_match, keep)`` boolean masks over the gallery for one query."""
        same_id = gallery_labels == query_label
        same_camera = gallery_cameras == query_camera
        is_match = same_id & ~same_camera
        keep = ~(same_id & same_camera)  # drop junk entries only
        return is_match, keep

    @staticmethod
    def compute_mAP(query_labels, gallery_labels, query_cameras, gallery_cameras, dist_matrix):
        """Compute mean Average Precision using a precomputed distance matrix.

        Args:
            query_labels, query_cameras: One identity / camera per query.
            gallery_labels, gallery_cameras: One identity / camera per gallery entry.
            dist_matrix: Distances with one row per query and one column per
                gallery entry (smaller means more similar).

        Returns:
            Mean of the per-query average precision, or 0.0 when no query has
            a correct match in the gallery.
        """
        query_labels = np.asarray(query_labels)
        query_cameras = np.asarray(query_cameras)
        gallery_labels = np.asarray(gallery_labels)
        gallery_cameras = np.asarray(gallery_cameras)
        dist_matrix = np.asarray(dist_matrix)

        average_precisions = []
        for i in range(len(query_labels)):
            is_match, keep = EvaluationMetrics._match_and_keep_masks(
                query_labels[i], query_cameras[i], gallery_labels, gallery_cameras
            )
            if not is_match.any():
                continue
            # Negated distances act as similarity scores; the scorer ranks
            # them itself, so no explicit sort is needed here.
            average_precisions.append(
                average_precision_score(is_match[keep], -dist_matrix[i][keep])
            )
        return np.mean(average_precisions) if average_precisions else 0.0

    @staticmethod
    def compute_CMC(query_labels, gallery_labels, query_cameras, gallery_cameras, dist_matrix, ranks=(1, 5, 10)):
        """Compute CMC scores using a precomputed distance matrix.

        Args:
            query_labels, query_cameras: One identity / camera per query.
            gallery_labels, gallery_cameras: One identity / camera per gallery entry.
            dist_matrix: Distances with one row per query and one column per
                gallery entry (smaller means more similar).
            ranks: Rank cut-offs to report.

        Returns:
            Dict mapping each rank ``k`` to the fraction of valid queries that
            have a correct match among their ``k`` nearest gallery entries
            (0.0 for every rank when there is no valid query).
        """
        query_labels = np.asarray(query_labels)
        query_cameras = np.asarray(query_cameras)
        gallery_labels = np.asarray(gallery_labels)
        gallery_cameras = np.asarray(gallery_cameras)
        dist_matrix = np.asarray(dist_matrix)

        cmc_scores = {rank: 0 for rank in ranks}
        valid_queries = 0
        for i in range(len(query_labels)):
            is_match, keep = EvaluationMetrics._match_and_keep_masks(
                query_labels[i], query_cameras[i], gallery_labels, gallery_cameras
            )
            if not is_match.any():
                continue
            valid_queries += 1
            ranked_hits = is_match[keep][np.argsort(dist_matrix[i][keep])]
            for rank in ranks:
                if ranked_hits[:rank].any():
                    cmc_scores[rank] += 1

        denominator = valid_queries if valid_queries > 0 else 1
        for rank in ranks:
            cmc_scores[rank] /= denominator
        return cmc_scores

def get_transforms(is_training=True):
    """Build the image preprocessing pipeline.

    Training: resize to 256x256, random 224x224 crop, horizontal flip, colour
    jitter, tensor conversion, normalisation, random erasing.
    Evaluation: resize to 224x224, tensor conversion, normalisation.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    if not is_training:
        eval_steps = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ]
        return transforms.Compose(eval_steps)

    train_steps = [
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        normalize,
        # Erasing operates on tensors, so it comes after ToTensor/Normalize.
        transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3)),
    ]
    return transforms.Compose(train_steps)

def train_model(model, train_loader, val_loader, device, num_epochs=100):
    """Train ``model`` with Adam and a step learning-rate schedule.

    The loss depends on the batch layout:

    * batches containing an ``'anchor'`` key are treated as triplets and are
      optimised with the triplet loss only (the logits returned by the model
      are discarded in this branch);
    * all other batches must provide ``'img'`` and ``'vehicle_id'`` and are
      optimised with cross-entropy on the logits.

    Args:
        model: Network returning ``(features, logits)`` in training mode.
        train_loader: Iterable of batch dicts (see above).
        val_loader: Accepted for interface compatibility; not used here.
        device: Device the model and the batches are moved to.
        num_epochs (int): Number of passes over ``train_loader``.
    """
    triplet_criterion = TripletLoss(margin=0.3)
    ce_criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)

    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for step, batch in enumerate(train_loader):
            optimizer.zero_grad()

            if 'anchor' not in batch:
                # Classification batch.
                images = batch['img'].to(device)
                targets = batch['vehicle_id']
                if isinstance(targets, torch.Tensor):
                    targets = targets.to(device)
                else:
                    targets = torch.tensor(targets, dtype=torch.long, device=device)
                _, logits = model(images)
                loss = ce_criterion(logits, targets)
            else:
                # Triplet batch: three forward passes, one per role.
                anchor, positive, negative = tuple(
                    batch[role].to(device) for role in ('anchor', 'positive', 'negative')
                )
                anchor_emb, _ = model(anchor)
                positive_emb, _ = model(positive)
                negative_emb, _ = model(negative)
                loss = triplet_criterion(anchor_emb, positive_emb, negative_emb)

            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            running_loss += step_loss
            if step % 100 == 0:
                print(f'Epoch {epoch}, Batch {step}, Loss: {step_loss:.4f}')

        scheduler.step()
        num_batches = len(train_loader)
        avg_loss = running_loss / num_batches if num_batches > 0 else 0.0
        print(f'Epoch {epoch} completed, Average Loss: {avg_loss:.4f}')

def _collect_embeddings(model, loader, device):
    """Run ``model`` over ``loader`` and gather features, labels and camera indices.

    Returns:
        ``(features, labels, cameras)`` where ``features`` is a list of CPU
        tensors (one per batch) and the other two are flat Python lists.
    """
    def as_list(values):
        # Default collation yields tensors; keep plain sequences untouched.
        return values.cpu().tolist() if isinstance(values, torch.Tensor) else values

    features, labels, cameras = [], [], []
    for batch in loader:
        features.append(model(batch['img'].to(device)).cpu())
        labels.extend(as_list(batch['vehicle_id']))
        cameras.extend(as_list(batch['camera_id_idx']))
    return features, labels, cameras


def evaluate_model(model, query_loader, gallery_loader, device, data_dir):
    """Evaluate the model using mAP and CMC metrics with spatio-temporal filtering and re-ranking.

    Args:
        model: Network returning an embedding tensor in eval mode and exposing
            ``feature_dim``.
        query_loader, gallery_loader: Loaders yielding dicts with the keys
            ``'img'``, ``'vehicle_id'`` and ``'camera_id_idx'``.
        device: Device used for inference.
        data_dir (str): Dataset root; ``camera_Dist.txt`` is loaded from it
            when present.

    Returns:
        ``(mAP, cmc_scores)``; both are also printed.
    """
    model.eval()
    model.to(device)

    with torch.no_grad():
        query_features, query_labels, query_cameras = _collect_embeddings(model, query_loader, device)
        gallery_features, gallery_labels, gallery_cameras = _collect_embeddings(model, gallery_loader, device)

    def stack(chunks):
        # An empty loader still has to produce a (0, feature_dim) matrix.
        return torch.cat(chunks, dim=0) if len(chunks) > 0 else torch.empty((0, model.feature_dim))

    query_features = stack(query_features)
    gallery_features = stack(gallery_features)

    # Convert to numpy arrays
    query_labels_np = np.array(query_labels)
    gallery_labels_np = np.array(gallery_labels)
    query_cameras_np = np.array(query_cameras)
    gallery_cameras_np = np.array(gallery_cameras)

    # Initialize reranker and spatio-temporal filter. The camera distance file
    # is optional: the filter only compares camera indices, so a missing file
    # must not abort the evaluation.
    reranker = GraphReRanking()
    camera_dist_path = os.path.join(data_dir, 'camera_Dist.txt')
    st_filter = SpatioTemporalFilter(
        camera_distances_file=camera_dist_path if os.path.isfile(camera_dist_path) else None
    )

    # Compute re-ranked distance matrix
    re_ranked_dist = reranker.re_rank(query_features, gallery_features)

    # Apply spatio-temporal filtering, one query row at a time
    filtered_dist_matrix = np.zeros_like(re_ranked_dist)
    for i, query_camera_idx in enumerate(query_cameras_np):
        filtered_dist_matrix[i] = st_filter.filter_results(
            query_camera_idx, gallery_cameras_np, re_ranked_dist[i]
        )

    # Compute metrics
    mAP = EvaluationMetrics.compute_mAP(
        query_labels_np, gallery_labels_np, query_cameras_np, gallery_cameras_np, filtered_dist_matrix
    )
    cmc_scores = EvaluationMetrics.compute_CMC(
        query_labels_np, gallery_labels_np, query_cameras_np, gallery_cameras_np, filtered_dist_matrix
    )

    print(f'mAP: {mAP:.4f}')
    print(f'CMC Rank-1: {cmc_scores.get(1, 0):.4f}')
    print(f'CMC Rank-5: {cmc_scores.get(5, 0):.4f}')
    print(f'CMC Rank-10: {cmc_scores.get(10, 0):.4f}')

    return mAP, cmc_scores

In [None]:
# Configuration
data_dir = '/kaggle/input/veri-vehicle-re-identification-dataset/VeRi'
batch_size = 48
num_epochs = 10
seed = 42
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the pipeline draws from (triplet sampling uses `random`,
# augmentation and weight init use torch) so a fresh run is repeatable.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Data transforms
train_transform = get_transforms(is_training=True)
test_transform = get_transforms(is_training=False)

# Datasets
train_dataset = VeRiDataset(data_dir, split='train', transform=train_transform,
                           triplet_sampling=True)
query_dataset = VeRiDataset(data_dir, split='query', transform=test_transform)
test_dataset = VeRiDataset(data_dir, split='test', transform=test_transform)

# Each dataset builds its ID -> index maps from its own samples only, but the
# evaluation compares query indices with gallery indices. Give both evaluation
# splits one shared mapping so equal indices always mean equal raw IDs, even
# when one split contains a vehicle or camera the other lacks.
for mapping_attr, raw_key in (('vehicle_id_to_idx', 'vehicle_id'), ('camera_id_to_idx', 'camera_id')):
    raw_values = sorted({
        sample[raw_key]
        for eval_dataset in (query_dataset, test_dataset)
        for sample in eval_dataset.data
    })
    shared_mapping = {raw: index for index, raw in enumerate(raw_values)}
    setattr(query_dataset, mapping_attr, shared_mapping)
    setattr(test_dataset, mapping_attr, shared_mapping)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
query_loader = DataLoader(query_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Get number of unique vehicles for classification
all_vehicle_ids = {item['vehicle_id'] for item in train_dataset.data}
num_vehicles = len(all_vehicle_ids)

# Model
model = VehicleReIDModel(num_vehicles=num_vehicles, feature_dim=2048, num_parts=6)

print(f"Training samples: {len(train_dataset)}")
print(f"Query samples: {len(query_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Number of vehicles: {num_vehicles}")
print(f"Device: {device}")

In [None]:
# Training: run the optimisation loop (the query loader fills the val_loader slot)
print("Starting training...")
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=query_loader,
    device=device,
    num_epochs=num_epochs,
)

In [None]:
# Evaluation: query images are matched against the test split as gallery
print("Starting evaluation...")
mAP, cmc_scores = evaluate_model(
    model=model,
    query_loader=query_loader,
    gallery_loader=test_loader,
    device=device,
    data_dir=data_dir,
)

In [None]:
# Save model: persist only the learned weights (state dict), not the module object
trained_weights = model.state_dict()
torch.save(trained_weights, 'veri_reid_model.pth')
print("Model saved as 'veri_reid_model.pth'")