# Person Re-Identification Training

## 1. Environment Setup

In [None]:
# Mount Google Drive at /content/drive. The data, model and results roots
# defined later in this notebook all live under /content/drive/MyDrive.
import google.colab
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%cd master-thesis-reid

In [None]:
!pip install -q -r requirements_colab.txt

In [None]:
!pip install torchreid

In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
import yaml
import json
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# Locations used throughout the notebook: the repository checkout and the
# Google Drive folders holding data, checkpoints and results.
project_root = Path('/content/master-thesis-reid')
DATA_ROOT = "/content/drive/MyDrive/reid_data"
MODEL_ROOT = "/content/drive/MyDrive/reid_models"
RESULTS_ROOT = "/content/drive/MyDrive/reid_results"

# Put the repository first on the import path so its packages can be imported.
sys.path.insert(0, str(project_root))

# Only the output folders are created here; the data root is not.
for _output_dir in (MODEL_ROOT, RESULTS_ROOT):
    os.makedirs(_output_dir, exist_ok=True)

for _label, _location in (
    ("Project root", project_root),
    ("Data root", DATA_ROOT),
    ("Model root", MODEL_ROOT),
    ("Results root", RESULTS_ROOT),
):
    print(f"{_label}: {_location}")

In [None]:
!pip install torchreid

In [None]:
from utils.data_loader import get_dataloaders_from_config
from models.person import *
import torchreid

# Report the PyTorch build and, when a CUDA device is visible, its name.
_cuda_visible = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_visible}")
if _cuda_visible:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 3. Load YAML Configurations

In [None]:
# Load the experiment and dataset definitions from <project_root>/config and
# print an overview of what they contain.
config_dir = project_root / 'config'

with open(config_dir / 'train_experiments.yaml', encoding='utf-8') as f:
    experiments_config = yaml.safe_load(f)
print("Loaded train_experiments.yaml")

with open(config_dir / 'datasets.yaml', encoding='utf-8') as f:
    dataset_config = yaml.safe_load(f)
print("Loaded datasets.yaml")

print("\nAvailable person ReID models:")
for model_name in experiments_config['person_reid_experiments']:
    print(f"  - {model_name}")

print("\nAvailable vehicle ReID models:")
for model_name in experiments_config['vehicle_reid_experiments']:
    print(f"  - {model_name}")

print("\nBatch experiment configurations:")
for model_name, batch_entry in experiments_config['batch_experiments'].items():
    # 'global_settings' is a settings entry, not a model.
    if model_name == 'global_settings':
        continue
    # An entry without an 'enabled' key counts as enabled, matching how the
    # training-configuration cell decides what to run. Listing it here keeps
    # this overview in step with what is actually executed.
    if not batch_entry.get('enabled', True):
        continue
    print(f"  - {model_name}: datasets={batch_entry['datasets']}, k_shots={batch_entry['k_shots']}")

## 4. Training Configuration

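**Edit this cell to configure your training.** Set `RUN_MODE` to `'batch'` to run every enabled entry under `batch_experiments` in `train_experiments.yaml`, or to `'manual'` to run the combinations of `MODELS`, `DATASETS`, `K_SHOTS`, `DATA_TYPES` and `LOSSES` listed in the cell: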

In [None]:
from itertools import product

# Run mode:
#   'batch'  - expand every enabled entry of experiments_config['batch_experiments']
#   'manual' - expand the MODELS / DATASETS / ... lists below with the fixed
#              hyper-parameters in MANUAL_TRAINING_SETTINGS
RUN_MODE = 'batch'

if RUN_MODE == 'manual':
    MODELS = ['pcb_p6']
    DATASETS = ['market1501']
    K_SHOTS = [16]
    DATA_TYPES = ['preprocessed']
    LOSSES = ['softmax']
else:
    # Unused in batch mode; defined so the names always exist.
    MODELS = []
    DATASETS = []
    K_SHOTS = []
    DATA_TYPES = []
    LOSSES = []

# Suffix appended to the dataset name to form its directory name.
DATASET_SUFFIXES = {
    'preprocessed': '_preprocessed',
    'augmented': '_augmented',
    'original': '',
}

# Hyper-parameters applied to every manual-mode run.
MANUAL_TRAINING_SETTINGS = {
    'num_epochs': 60,
    'batch_size': 32,
    'learning_rate': 0.00035,
    'weight_decay': 0.0005,
    'optimizer': 'adamw',
    'lr_scheduler': 'warmup_cosine',
    'warmup_epochs': 10,
    'lr_milestones': [40, 60],
    'lr_gamma': 0.1,
    'momentum': 0.9,
    'mixed_precision': True,
    'save_freq': 10,
    'img_height': 256,
    'img_width': 128,
}


def _infer_reid_type(model):
    """Return 'vehicle' if *model* is only listed as a vehicle experiment, else 'person'."""
    person_models = experiments_config.get('person_reid_experiments') or {}
    vehicle_models = experiments_config.get('vehicle_reid_experiments') or {}
    if model not in person_models and model in vehicle_models:
        return 'vehicle'
    return 'person'


def _build_run_config(model, dataset, data_type, k_shot, loss, reid_type,
                      training, data_root, model_root):
    """Assemble one run configuration and create its checkpoint directory.

    Args:
        model, dataset, data_type, k_shot, loss: the experiment coordinates.
        reid_type: 'person' or 'vehicle'; read later by the training cell.
        training: mapping of hyper-parameters. 'num_epochs', 'batch_size',
            'learning_rate', 'weight_decay', 'optimizer' and 'lr_scheduler'
            are required; the remaining keys fall back to defaults.
        data_root: directory that contains the dataset directories.
        model_root: directory under which the run's checkpoint dir is created.

    Returns:
        The configuration dict for this run.

    Raises:
        KeyError: if a required training key is missing or *data_type* has no
            entry in DATASET_SUFFIXES.
    """
    config = {
        'model': model,
        'dataset': dataset,
        'data_type': data_type,
        'k_shot': k_shot,
        'loss': loss,
        'reid_type': reid_type,

        'num_epochs': training['num_epochs'],
        'batch_size': training['batch_size'],
        'learning_rate': training['learning_rate'],
        'weight_decay': training['weight_decay'],
        'optimizer': training['optimizer'],
        'lr_scheduler': training['lr_scheduler'],
        'warmup_epochs': training.get('warmup_epochs', 10),
        # Copy so that runs never share one mutable milestone list.
        'lr_milestones': list(training.get('lr_milestones', [40, 60])),
        'lr_gamma': training.get('lr_gamma', 0.1),
        'momentum': training.get('momentum', 0.9),
        'mixed_precision': training.get('mixed_precision', True),
        'save_freq': training.get('save_freq', 10),

        'img_height': training.get('img_height', 256),
        'img_width': training.get('img_width', 128),

        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    }

    config['dataset_path'] = os.path.join(data_root, dataset + DATASET_SUFFIXES[data_type])

    run_name = f"{model}_{dataset}_{data_type}_k{k_shot}_l{loss}"
    config['model_save_dir'] = os.path.join(model_root, run_name)
    os.makedirs(config['model_save_dir'], exist_ok=True)
    return config


all_configs = []

if RUN_MODE == 'batch':
    batch_config = experiments_config['batch_experiments']

    for model_name, model_batch_config in batch_config.items():
        # 'global_settings' is a settings entry, not a model.
        if model_name == 'global_settings':
            continue

        if not model_batch_config.get('enabled', True):
            print(f"Skipping {model_name} (disabled)")
            continue

        # Person experiments take precedence when a name appears in both.
        if model_name in experiments_config['person_reid_experiments']:
            base_config = experiments_config['person_reid_experiments'][model_name]
            reid_type = 'person'
        elif model_name in experiments_config['vehicle_reid_experiments']:
            base_config = experiments_config['vehicle_reid_experiments'][model_name]
            reid_type = 'vehicle'
        else:
            print(f"Warning: {model_name} not found in experiments config")
            continue

        for dataset, k_shot, data_type, loss in product(
            model_batch_config['datasets'],
            model_batch_config['k_shots'],
            model_batch_config['data_types'],
            model_batch_config['losses'],
        ):
            all_configs.append(_build_run_config(
                model_name, dataset, data_type, k_shot, loss, reid_type,
                training=base_config['training'],
                data_root=experiments_config['paths']['data_root'],
                model_root=experiments_config['paths']['model_root'],
            ))

else:
    for model, dataset, k_shot, data_type, loss in product(
        MODELS, DATASETS, K_SHOTS, DATA_TYPES, LOSSES
    ):
        # The training cell reads CONFIG['reid_type'], so manual runs must
        # carry it as well.
        all_configs.append(_build_run_config(
            model, dataset, data_type, k_shot, loss, _infer_reid_type(model),
            training=MANUAL_TRAINING_SETTINGS,
            data_root=DATA_ROOT,
            model_root=MODEL_ROOT,
        ))


print(f"\nTotal number of configurations to run: {len(all_configs)}")
print("\nConfigurations:")
print("=" * 90)
print(f"{'#':<4} {'Model':<20} {'Dataset':<15} {'K-shot':<8} {'Data Type':<12} {'Loss':<10}")
print("=" * 90)
for i, config in enumerate(all_configs, 1):
    print(f"{i:<4} {config['model']:<20} {config['dataset']:<15} {config['k_shot']:<8} {config['data_type']:<12} {config['loss']:<10}")
print("=" * 90)

In [None]:
# Report, for every planned run, whether its dataset directory exists.
print("Checking dataset paths...")
for run_config in all_configs:
    dataset_dir = run_config['dataset_path']
    if os.path.exists(dataset_dir):
        print(f"{run_config['dataset']}: {dataset_dir}")
    else:
        print(f"WARNING: Dataset not found at {dataset_dir}")

# Printed unconditionally: a missing dataset only produces the warning above.
print("\nAll dataset checks complete!")

In [None]:
def build_model(model_name, num_classes, loss='softmax'):
    """
    Instantiate a ReID model by name.

    The constructors called here (osnet_x1_0, pcb_p6, ...) are not defined in
    this cell; they are expected to come from `models.person`.

    Accepted names:
    - 'osnet_x1_0'
    - 'pcb_p6', 'pcb_p4'
    - 'hacnn'
    - 'transreid_base' (aliases: 'transreid', 'transreid_base_vehicle')
    - 'autoreid_plus' (alias: 'autoreid'); built without the `pretrained` flag

    Args:
        model_name: Name of the model to build.
        num_classes: Number of identities for the classification head.
        loss: Loss name forwarded to the constructor. The combined names
            'softmax+triplet' and 'triplet+softmax' are forwarded as 'softmax'.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If model_name is not one of the accepted names.
    """
    # Combined losses are passed to the constructors as plain 'softmax'.
    head_loss = 'softmax' if loss in ('softmax+triplet', 'triplet+softmax') else loss

    if model_name == 'osnet_x1_0':
        return osnet_x1_0(num_classes=num_classes, loss=head_loss, pretrained=True)

    if model_name == 'pcb_p6':
        return pcb_p6(num_classes=num_classes, loss=head_loss, pretrained=True)

    if model_name == 'pcb_p4':
        return pcb_p4(num_classes=num_classes, loss=head_loss, pretrained=True)

    if model_name == 'hacnn':
        return hacnn(num_classes=num_classes, loss=head_loss, pretrained=True)

    if model_name in ('transreid_base', 'transreid', 'transreid_base_vehicle'):
        return transreid_base(num_classes=num_classes, loss=head_loss, pretrained=True)

    if model_name in ('autoreid_plus', 'autoreid'):
        return autoreid_plus(num_classes=num_classes, loss=head_loss)

    raise ValueError(f"Unknown model: {model_name}. Available models: "
                     "osnet_x1_0, pcb_p6, pcb_p4, hacnn, transreid_base, autoreid_plus")

In [None]:
def ensure_float32_for_triplet(outputs):
    """
    Cast model outputs to float32 before they are fed to the triplet loss.

    A tensor is returned as float32. A tuple or list is returned as a tuple in
    which every tensor element is cast and every other element is kept as is.
    Anything else is returned unchanged.
    """
    if torch.is_tensor(outputs):
        return outputs.float()

    if isinstance(outputs, (tuple, list)):
        return tuple(
            item.float() if torch.is_tensor(item) else item
            for item in outputs
        )

    return outputs

<h2>Model-Specific Configuration</h2>

In [None]:
# Triplet-loss weight (lambda_tri) per model name, written as
# (weight, model names) groups and flattened into the lookup table below.
# NOTE(review): build_model also accepts 'transreid' and
# 'transreid_base_vehicle'; they are not listed here and therefore get the
# 1.0 fallback rather than the 0.3 used for 'transreid_base' - confirm intended.
_TRIPLET_WEIGHT_GROUPS = (
    (1.0, ('osnet_x1_0', 'osnet_x0_75', 'osnet_x0_5', 'osnet_x0_25',
           'osnet_ibn_x1_0', 'pcb_p6', 'pcb_p4')),
    (0.5, ('hacnn', 'aaver', 'rptm')),
    (0.3, ('transreid_base', 'transreid_small')),
    (0.5, ('autoreid_plus', 'autoreid')),
    (0.3, ('vat',)),
    (1.0, ('resnet50_vehicle',)),
)

MODEL_TRIPLET_WEIGHTS = {
    name: weight
    for weight, names in _TRIPLET_WEIGHT_GROUPS
    for name in names
}


def get_triplet_weight(model_name):
    """Return the triplet-loss weight for model_name, or 1.0 if it is not listed."""
    try:
        return MODEL_TRIPLET_WEIGHTS[model_name]
    except KeyError:
        return 1.0


print("Model-specific triplet weights loaded")
for _family, _example in (("TransReID", 'transreid_base'), ("OSNet", 'osnet_x1_0')):
    print(f"{_family} lambda_tri: {get_triplet_weight(_example)}")



<h2>Helper functions</h2>

In [None]:
def compute_loss_single_output(outputs, pids, loss_type, criterion_ce,
                               criterion_triplet, lambda_tri):
    """
    Loss for a model whose forward pass returns one object (not a tuple/list).

    The same `outputs` object is passed to whichever criteria the loss type
    selects. The triplet criterion receives it cast to float32.

    Returns: (loss, logits, features)
        - 'softmax':  (CE loss, outputs, None)
        - 'triplet':  (triplet loss, None, outputs)
        - otherwise:  (CE + lambda_tri * triplet, outputs, outputs)
    """
    if loss_type == 'softmax':
        return criterion_ce(outputs, pids), outputs, None

    # Both remaining modes use the triplet criterion, which gets float32 input.
    triplet_input = ensure_float32_for_triplet(outputs)

    if loss_type == 'triplet':
        return criterion_triplet(triplet_input, pids), None, outputs

    # Any other loss type is treated as 'softmax+triplet'.
    ce_term = criterion_ce(outputs, pids)
    triplet_term = criterion_triplet(triplet_input, pids)
    return ce_term + lambda_tri * triplet_term, outputs, outputs
    
def compute_loss_deep_supervision(outputs, pids, loss_type, criterion_ce,
                                 criterion_triplet, lambda_tri):
    """
    Loss for models whose outputs are passed through
    torchreid.losses.DeepSupervision (the dispatcher sends PCB, HACNN,
    AutoReID+, RPTM and AAVER here).

    The first element of `outputs` is what gets reported as logits and/or
    features. The triplet criterion receives the outputs cast to float32.

    Returns: (loss, logits, features)
        - 'softmax':  (CE loss, outputs[0], None)
        - 'triplet':  (triplet loss, None, outputs[0])
        - otherwise:  (CE + lambda_tri * triplet, outputs[0], outputs[0])
    """
    deep_supervision = torchreid.losses.DeepSupervision

    if loss_type == 'softmax':
        ce_total = deep_supervision(criterion_ce, outputs, pids)
        return ce_total, outputs[0], None

    if loss_type == 'triplet':
        triplet_total = deep_supervision(
            criterion_triplet, ensure_float32_for_triplet(outputs), pids
        )
        return triplet_total, None, outputs[0]

    # Any other loss type is treated as 'softmax+triplet'.
    ce_total = deep_supervision(criterion_ce, outputs, pids)
    triplet_total = deep_supervision(
        criterion_triplet, ensure_float32_for_triplet(outputs), pids
    )
    head = outputs[0]
    return ce_total + lambda_tri * triplet_total, head, head
    
def compute_loss_vat(outputs, pids, loss_type, criterion_ce,
                    criterion_triplet, lambda_tri):
    """
    Loss for the VAT model (multi-task: id, color, type).

    Only the first output, taken as the identity logits, contributes to the
    loss. The remaining outputs are ignored here.

    Returns: (loss, logits, features)
        - 'softmax':  (CE loss, id_logits, None)
        - 'triplet':  (triplet loss, None, id_logits)
        - otherwise:  (CE + lambda_tri * triplet, id_logits, id_logits)
    """
    identity_out = outputs[0]

    if loss_type == 'softmax':
        return criterion_ce(identity_out, pids), identity_out, None

    # Both remaining modes use the triplet criterion, which gets float32 input.
    identity_out_f32 = ensure_float32_for_triplet(identity_out)

    if loss_type == 'triplet':
        return criterion_triplet(identity_out_f32, pids), None, identity_out

    # Any other loss type is treated as 'softmax+triplet'.
    ce_term = criterion_ce(identity_out, pids)
    triplet_term = criterion_triplet(identity_out_f32, pids)
    return ce_term + lambda_tri * triplet_term, identity_out, identity_out
    

def compute_loss_tuple_standard(outputs, pids, loss_type, criterion_ce,
                                criterion_triplet, lambda_tri):
    """
    Loss for models that return a pair, read here as (features, logits).

    Cross-entropy is computed on the logits, and the triplet loss on the
    features cast to float32.
    NOTE(review): this assumes element 0 is the feature tensor and element 1
    the logits - confirm against the model wrappers' output order.

    Returns: (loss, logits, features)
        - 'softmax':  (CE loss, logits, None)
        - 'triplet':  (triplet loss, None, features)
        - otherwise:  (CE + lambda_tri * triplet, logits, features)
    """
    features = outputs[0]
    logits = outputs[1]

    if loss_type == 'softmax':
        return criterion_ce(logits, pids), logits, None

    # Both remaining modes use the triplet criterion, which gets float32 input.
    features_for_triplet = ensure_float32_for_triplet(features)

    if loss_type == 'triplet':
        return criterion_triplet(features_for_triplet, pids), None, features

    # Any other loss type is treated as 'softmax+triplet'.
    ce_term = criterion_ce(logits, pids)
    triplet_term = criterion_triplet(features_for_triplet, pids)
    return ce_term + lambda_tri * triplet_term, logits, features


def compute_loss(outputs, pids, config, criterion_ce, criterion_triplet):
    """
    Unified loss computation for all models.

    Picks the loss helper that matches the shape of `outputs` and the model
    name, and calls it with the model's triplet weight.

    Dispatch order:
        1. outputs is not a tuple/list      -> compute_loss_single_output
        2. model is a deep-supervision one  -> compute_loss_deep_supervision
        3. model is 'vat'                   -> compute_loss_vat
        4. anything else                    -> compute_loss_tuple_standard

    Args:
        outputs: Model outputs (a single object, or a tuple/list).
        pids: Identity labels for the batch.
        config: Training configuration dict; 'model' and 'loss' are read.
        criterion_ce: Cross-entropy criterion.
        criterion_triplet: Triplet criterion.

    Returns:
        loss: Total loss
        logits: Logits for accuracy (or None)
        features: Features for triplet (or None)
    """
    # 'autoreid' is built by build_model as the same model as 'autoreid_plus',
    # so both names must take the deep-supervision path.
    deep_supervision_models = (
        'pcb_p6', 'pcb_p4', 'autoreid_plus', 'autoreid', 'hacnn', 'rptm', 'aaver',
    )

    model_name = config['model']
    loss_type = config['loss']
    lambda_tri = get_triplet_weight(model_name)

    if not isinstance(outputs, (tuple, list)):
        handler = compute_loss_single_output
    elif model_name in deep_supervision_models:
        handler = compute_loss_deep_supervision
    elif model_name == 'vat':
        handler = compute_loss_vat
    else:
        handler = compute_loss_tuple_standard

    return handler(
        outputs, pids, loss_type, criterion_ce, criterion_triplet, lambda_tri
    )
    
def compute_accuracy(logits, features, pids, loss_type):
    """
    Count correct predictions for one training batch.

    When the loss type contains 'softmax' and logits are given, a sample is
    correct if the arg-max class equals its label. When the loss type is
    exactly 'triplet' and features are given, a sample is correct if its
    nearest other sample in the batch (Euclidean distance) has the same label.
    In every other case nothing is counted.

    Returns: (correct, total)
    """
    batch_size = pids.size(0)

    if logits is not None and 'softmax' in loss_type:
        predicted = logits.max(dim=1).indices
        return predicted.eq(pids).sum().item(), batch_size

    if features is not None and loss_type == 'triplet':
        pairwise = torch.cdist(features, features, p=2)
        # Exclude each sample from being its own nearest neighbour.
        self_pairs = torch.eye(
            pairwise.size(0), dtype=torch.bool, device=pairwise.device
        )
        neighbour_idx = pairwise.masked_fill(self_pairs, float('inf')).argmin(dim=1)
        matches = (pids[neighbour_idx] == pids).sum().item()
        return matches, batch_size

    return 0, 0



In [None]:
# Imported here explicitly so this cell does not depend on these names leaking
# in through `from models.person import *`.
import time

import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torchreid import metrics

# One result dict per finished configuration. The list is reused if it already
# exists, so re-running this cell keeps results collected earlier.
all_results = globals().get('all_results', [])


def _save_checkpoint(filename, epoch, model, config, num_classes, **extra):
    """Write a checkpoint for *model* into config['model_save_dir']/filename."""
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'config': config,
        'num_classes': num_classes,
    }
    state.update(extra)
    torch.save(state, os.path.join(config['model_save_dir'], filename))


# Train, benchmark and evaluate every configuration in turn. A failure in one
# configuration is reported and the loop moves on to the next.
for config_idx, CONFIG in enumerate(all_configs, 1):
    # final_model.pth is written when training finishes, so its presence marks
    # the configuration as already trained.
    final_model_path = os.path.join(CONFIG['model_save_dir'], 'final_model.pth')
    if os.path.exists(final_model_path):
        print(f"\n[Skipping] {CONFIG['model']} (already trained)")
        continue

    print("\n" + "=" * 80)
    print(f"CONFIGURATION {config_idx}/{len(all_configs)}")
    print("=" * 80)
    print(f"Model: {CONFIG['model']}, Dataset: {CONFIG['dataset']}, "
          f"K-shot: {CONFIG['k_shot']}, Loss: {CONFIG['loss']}")

    if 'triplet' in CONFIG['loss']:
        lambda_tri = get_triplet_weight(CONFIG['model'])
        print(f"Triplet weight (lambda_tri): {lambda_tri}")

    print("=" * 80)

    try:
        device = CONFIG['device']
        on_cuda = str(device).startswith('cuda')
        # Mixed precision is CUDA-only here and honours the config switch.
        use_amp = on_cuda and bool(CONFIG.get('mixed_precision', True))

        print("\n[1/7] Loading dataset...")

        img_height, img_width = CONFIG['img_height'], CONFIG['img_width']

        dataset_suffix = {
            'preprocessed': '_preprocessed',
            'augmented': '_augmented',
            'original': ''
        }[CONFIG['data_type']]

        # Use the same data root that CONFIG['dataset_path'] was built from:
        # the YAML path in batch mode, DATA_ROOT in manual mode.
        if RUN_MODE == 'batch':
            data_root = experiments_config['paths']['data_root']
        else:
            data_root = DATA_ROOT

        dataloaders = get_dataloaders_from_config(
            root=data_root,
            dataset_name=CONFIG['dataset'] + dataset_suffix,
            config_dir=str(project_root / 'config'),
            model_name=CONFIG['model'],
            # Manual-mode configs may not carry 'reid_type'; this notebook
            # trains person ReID models, so fall back to 'person'.
            model_type=CONFIG.get('reid_type', 'person'),
            data_type=CONFIG['data_type'],
            k_shot=CONFIG['k_shot']
        )

        train_loader = dataloaders['train']
        query_loader = dataloaders['query']
        gallery_loader = dataloaders['gallery']

        num_train_pids = train_loader.dataset.num_pids
        print(f"Identities: {num_train_pids}, Train: {len(train_loader.dataset)}, "
              f"Query: {len(query_loader.dataset)}, Gallery: {len(gallery_loader.dataset)}")

        print("\n[2/7] Setting up data loader...")

        pid2label = train_loader.dataset.pid2label

        def collate_fn_with_mapping(batch):
            """Stack a batch and map raw person ids to labels via pid2label."""
            imgs = torch.stack([item[0] for item in batch])
            pids = torch.tensor([pid2label[item[1]] for item in batch], dtype=torch.long)
            camids = torch.tensor([item[2] for item in batch], dtype=torch.long)
            return imgs, pids, camids

        # Rebuild the train loader around the same dataset and sampler, but
        # with label mapping in the collate step and incomplete batches dropped.
        train_loader = torch.utils.data.DataLoader(
            train_loader.dataset,
            batch_size=train_loader.batch_size,
            sampler=train_loader.sampler,
            num_workers=train_loader.num_workers,
            pin_memory=train_loader.pin_memory,
            drop_last=True,
            collate_fn=collate_fn_with_mapping
        )

        print("\n[3/7] Building model...")

        model = build_model(CONFIG['model'], num_train_pids, loss=CONFIG['loss'])
        model = model.to(device)

        num_params = sum(p.numel() for p in model.parameters())
        print(f"Parameters: {num_params:,}")

        print("\n[4/7] Setting up training...")

        criterion_ce = torchreid.losses.CrossEntropyLoss(
            num_classes=num_train_pids,
            use_gpu=on_cuda,
            label_smooth=True
        )
        criterion_triplet = torchreid.losses.TripletLoss(margin=0.3)

        if CONFIG['optimizer'] == 'sgd':
            optimizer = optim.SGD(
                model.parameters(),
                lr=CONFIG['learning_rate'],
                momentum=CONFIG['momentum'],
                weight_decay=CONFIG['weight_decay']
            )
        else:
            optimizer = optim.AdamW(
                model.parameters(),
                lr=CONFIG['learning_rate'],
                weight_decay=CONFIG['weight_decay']
            )

        if CONFIG['lr_scheduler'] == 'multistep':
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=CONFIG['lr_milestones'],
                gamma=CONFIG['lr_gamma']
            )
        elif CONFIG['lr_scheduler'] == 'warmup_cosine':
            warmup = LinearLR(
                optimizer,
                start_factor=0.01,
                end_factor=1.0,
                total_iters=CONFIG['warmup_epochs']
            )
            cosine = CosineAnnealingLR(
                optimizer,
                # Keep the cosine period positive even if warm-up covers the
                # whole run.
                T_max=max(1, CONFIG['num_epochs'] - CONFIG['warmup_epochs'])
            )
            scheduler = SequentialLR(
                optimizer,
                schedulers=[warmup, cosine],
                milestones=[CONFIG['warmup_epochs']]
            )
        else:
            scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

        scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
        print(f"Optimizer: {CONFIG['optimizer']}, LR: {CONFIG['learning_rate']}, "
              f"Scheduler: {CONFIG['lr_scheduler']}, Epochs: {CONFIG['num_epochs']}")

        print("\n[5/7] Training model...")

        history = {'train_loss': [], 'train_acc': [], 'learning_rate': []}
        best_acc = 0.0

        for epoch in range(1, CONFIG['num_epochs'] + 1):
            model.train()
            running_loss = 0.0
            correct = 0
            total = 0

            pbar = tqdm(train_loader, desc=f'Epoch {epoch}/{CONFIG["num_epochs"]}')
            for batch_idx, (imgs, pids, camids) in enumerate(pbar):
                imgs = imgs.to(device)
                pids = pids.to(device)

                optimizer.zero_grad()

                with torch.amp.autocast('cuda', enabled=use_amp):
                    if CONFIG['model'] in ['transreid_base', 'transreid_small']:
                        # Camera ids are shifted to start at 0 and moved next
                        # to the model before being passed in.
                        camids_zero_based = (camids - 1).to(device)
                        outputs = model(imgs, cam_label=camids_zero_based)
                    else:
                        outputs = model(imgs)

                    loss, logits, features = compute_loss(
                        outputs, pids, CONFIG, criterion_ce, criterion_triplet
                    )

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                running_loss += loss.item()

                batch_correct, batch_total = compute_accuracy(
                    logits, features, pids, CONFIG['loss']
                )
                correct += batch_correct
                total += batch_total

                pbar_dict = {'loss': running_loss / (batch_idx + 1)}
                if total > 0:
                    pbar_dict['acc'] = 100. * correct / total
                pbar.set_postfix(pbar_dict)

            scheduler.step()

            epoch_loss = running_loss / len(train_loader)
            epoch_acc = 100. * correct / total if total > 0 else 0.0

            history['train_loss'].append(epoch_loss)
            history['train_acc'].append(epoch_acc)
            history['learning_rate'].append(optimizer.param_groups[0]['lr'])

            # Best model: highest training accuracy when accuracy is available,
            # otherwise lowest training loss seen so far.
            if total > 0:
                is_best = epoch_acc > best_acc
            else:
                is_best = epoch == 1 or epoch_loss < min(history['train_loss'][:-1])

            if is_best:
                best_acc = epoch_acc
                _save_checkpoint('best_model.pth', epoch, model, CONFIG, num_train_pids)

            # Periodic checkpoint.
            if epoch % CONFIG['save_freq'] == 0:
                _save_checkpoint(
                    f"checkpoint_epoch{epoch}.pth", epoch, model, CONFIG, num_train_pids
                )

        print(f"Training completed! Best accuracy: {best_acc:.2f}%")

        # Final model, with the training history attached.
        _save_checkpoint(
            'final_model.pth', CONFIG['num_epochs'], model, CONFIG, num_train_pids,
            history=history, best_acc=best_acc,
        )

        print("\n[6/7] Extracting features...")

        model.eval()

        def extract_features(dataloader):
            """Return L2-normalised features, person ids and camera ids as arrays."""
            features_list, pids_list, camids_list = [], [], []
            with torch.no_grad():
                for imgs, pids, camids, _ in tqdm(dataloader, desc='Extracting', leave=False):
                    imgs = imgs.to(device)
                    feats = model(imgs)

                    # Multi-output models: use the first output as the feature.
                    if isinstance(feats, (tuple, list)):
                        feats = feats[0]

                    feats = F.normalize(feats, p=2, dim=1)
                    features_list.append(feats.cpu())
                    pids_list.append(pids)
                    camids_list.append(camids)

            return (torch.cat(features_list).numpy(),
                    torch.cat(pids_list).numpy(),
                    torch.cat(camids_list).numpy())

        query_features, query_pids, query_camids = extract_features(query_loader)
        gallery_features, gallery_pids, gallery_camids = extract_features(gallery_loader)

        # Single-image inference benchmark: 10 warm-up passes, then 100 timed.
        sample_img = torch.randn(1, 3, img_height, img_width).to(device)

        for _ in range(10):
            with torch.no_grad():
                _ = model(sample_img)

        times = []
        for _ in range(100):
            # Synchronise so the timer covers the GPU work and nothing else.
            if on_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            with torch.no_grad():
                _ = model(sample_img)
            if on_cuda:
                torch.cuda.synchronize()
            times.append(time.perf_counter() - start)

        avg_time_ms = np.mean(times) * 1000
        fps = 1.0 / np.mean(times)
        print(f"Inference: {avg_time_ms:.2f} ms, {fps:.2f} FPS")

        print("\n[7/7] Evaluating model...")

        qf = torch.from_numpy(query_features)
        gf = torch.from_numpy(gallery_features)

        # Re-ranking needs the query-query and gallery-gallery distances too.
        distmat = metrics.compute_distance_matrix(qf, gf, metric='euclidean').numpy()
        distmat_qq = metrics.compute_distance_matrix(qf, qf, metric='euclidean').numpy()
        distmat_gg = metrics.compute_distance_matrix(gf, gf, metric='euclidean').numpy()
        distmat_reranked = torchreid.utils.re_ranking(distmat, distmat_qq, distmat_gg)

        cmc, mAP = metrics.evaluate_rank(
            distmat_reranked,
            query_pids,
            gallery_pids,
            query_camids,
            gallery_camids,
            use_metric_cuhk03=False
        )

        print(f"mAP: {mAP:.2%}, Rank-1: {cmc[0]:.2%}, "
              f"Rank-5: {cmc[4]:.2%}, Rank-10: {cmc[9]:.2%}")

        results = {
            'model': CONFIG['model'],
            'dataset': CONFIG['dataset'],
            'data_type': CONFIG['data_type'],
            'k_shot': CONFIG['k_shot'],
            'loss': CONFIG['loss'],
            'lambda_tri': get_triplet_weight(CONFIG['model']),
            'optimizer': CONFIG['optimizer'],
            'learning_rate': CONFIG['learning_rate'],
            'num_epochs': CONFIG['num_epochs'],
            'mAP': float(mAP),
            'rank1': float(cmc[0]),
            'rank5': float(cmc[4]),
            'rank10': float(cmc[9]),
            'rank20': float(cmc[19]),
            'best_train_acc': float(best_acc),
            'final_train_loss': float(history['train_loss'][-1]),
            'avg_query_time_ms': float(avg_time_ms),
            'fps': float(fps),
            'img_height': img_height,
            'img_width': img_width,
        }

        all_results.append(results)

        # The loss is part of the file name (same scheme as the checkpoint
        # directory) so runs that differ only by loss do not overwrite each
        # other's history and results.
        model_save_name = (f"{CONFIG['model']}_{CONFIG['dataset']}_"
                           f"{CONFIG['data_type']}_k{CONFIG['k_shot']}_l{CONFIG['loss']}")

        with open(os.path.join(RESULTS_ROOT, f"{model_save_name}_history.json"), 'w') as f:
            json.dump(history, f, indent=2)

        with open(os.path.join(RESULTS_ROOT, f"{model_save_name}_results.json"), 'w') as f:
            json.dump(results, f, indent=2)

        print(f"\nConfiguration {config_idx}/{len(all_configs)} completed successfully!")

        # Drop references to this run's objects before the next configuration.
        del model, optimizer, scheduler, train_loader, query_loader, gallery_loader
        torch.cuda.empty_cache()

    except Exception as e:
        # Batch boundary: report the failure in full and continue with the
        # next configuration instead of aborting the whole sweep.
        print(f"\nError in configuration {config_idx}: {str(e)}")
        import traceback
        traceback.print_exc()
        continue

print("\n" + "=" * 80)
print("ALL EXPERIMENTS COMPLETED!")
print("=" * 80)