# ReID Experiments Evaluation

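This notebook evaluates saved ReID models using the experiment definitions in `config/experiments.yaml`. The input resolution for each model is taken from `config/train_experiments.yaml`.
The checkpoint is chosen per model: `final_model.pth`, `best_model.pth`, or a specific epoch checkpoint. For every model the notebook reports mAP, CMC (Rank-1/5/10/20), inference speed, and a per-query error analysis.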

## 1. Setup

In [None]:

try:
    import google.colab
    IN_COLAB = True
    from google.colab import drive
    drive.mount('/content/drive')
    %cd /content/master-thesis-reid
    !pip install -q torchreid matplotlib seaborn scipy
except:
    IN_COLAB = False
    print("Running locally")

In [None]:
import sys
import os
import torch
import torch.nn.functional as F
from pathlib import Path
import yaml
import json
from tqdm import tqdm
import numpy as np
import pandas as pd
import time
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import wilcoxon

# Resolve the repository root so the project packages (utils, models) import.
if IN_COLAB:
    project_root = Path('/content/master-thesis-reid')
elif Path.cwd().name == 'notebooks':
    # Launched from <repo>/notebooks: the repository root is one level up.
    project_root = Path.cwd().parent
else:
    project_root = Path.cwd()

sys.path.insert(0, str(project_root))

print(f"Project root: {project_root}")
print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")

In [None]:
from utils.data_loader import get_dataloaders_from_config
from models.person import *
from models.vehicle import *
import torchreid
from torchreid import metrics

## 2. Load Configuration

In [None]:
config_dir = project_root / 'config'

# experiments.yaml defines what to evaluate; train_experiments.yaml is read
# for the per-model training settings (e.g. input resolution).
with open(config_dir / 'experiments.yaml') as f:
    exp_config = yaml.safe_load(f)
print("Loaded experiments.yaml")

with open(config_dir / 'train_experiments.yaml') as f:
    train_config = yaml.safe_load(f)
print("Loaded train_experiments.yaml")

methodology = exp_config['global_methodology']
MODEL_ROOT, DATA_ROOT, RESULTS_DIR = (
    methodology[key] for key in ('model_root', 'data_root', 'results_dir')
)

for output_dir in (RESULTS_DIR, os.path.join(RESULTS_DIR, 'visualizations')):
    os.makedirs(output_dir, exist_ok=True)

print(f"\nModel root: {MODEL_ROOT}")
print(f"Data root: {DATA_ROOT}")
print(f"Results dir: {RESULTS_DIR}")

## 3. Helper Functions

In [None]:
def get_model_type(model_name):
    """Classify an architecture name as a 'person' or 'vehicle' ReID model.

    Only the known vehicle architectures map to 'vehicle'; every other name,
    including unknown ones, is treated as 'person'. Matching is case-sensitive.
    """
    if model_name in ('resnet50_vehicle', 'aaver', 'rptm', 'vat'):
        return 'vehicle'
    return 'person'

def build_model(model_name, num_classes, loss='softmax', model_type='person', pretrained=True):
    """Build a ReID model by architecture name.

    Args:
        model_name: Architecture identifier, e.g. 'osnet_x1_0' or 'rptm'.
        num_classes: Number of identity classes for the classifier head.
        loss: Loss name forwarded to the constructor as ``loss``.
        model_type: 'person' or 'vehicle'; selects which set of
            architectures is searched.
        pretrained: Forwarded as ``pretrained`` to every constructor except
            autoreid_plus, which is called without it. Passing False avoids
            fetching backbone weights when a checkpoint overwrites them anyway.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``model_type`` or ``model_name`` is not recognised.
    """
    # Constructors are expected to come from the star imports of models.person /
    # models.vehicle. Lambdas defer the name lookup, so only the requested
    # architecture has to exist.
    person_builders = {
        'osnet_x1_0': lambda: osnet_x1_0(num_classes=num_classes, loss=loss, pretrained=pretrained),
        'pcb_p6': lambda: pcb_p6(num_classes=num_classes, loss=loss, pretrained=pretrained),
        'pcb_p4': lambda: pcb_p4(num_classes=num_classes, loss=loss, pretrained=pretrained),
        'hacnn': lambda: hacnn(num_classes=num_classes, loss=loss, pretrained=pretrained),
        'transreid_base': lambda: transreid_base(num_classes=num_classes, loss=loss, pretrained=pretrained),
        'transreid': lambda: transreid_base(num_classes=num_classes, loss=loss, pretrained=pretrained),
        'autoreid_plus': lambda: autoreid_plus(num_classes=num_classes, loss=loss),
    }
    vehicle_builders = {
        'resnet50_vehicle': lambda: resnet50_vehicle(num_classes=num_classes, loss=loss, pretrained=pretrained),
        'aaver': lambda: aaver(num_classes=num_classes, loss=loss, pretrained=pretrained),
        'rptm': lambda: rptm(num_classes=num_classes, loss=loss, pretrained=pretrained),
        'vat': lambda: vat(num_classes=num_classes, loss=loss, pretrained=pretrained),
    }

    if model_type == 'person':
        builders = person_builders
    elif model_type == 'vehicle':
        builders = vehicle_builders
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    builder = builders.get(model_name)
    if builder is None:
        raise ValueError(f"Unknown {model_type} ReID model: {model_name}")
    return builder()

def extract_features(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and collect L2-normalised features.

    Args:
        model: torch module. If its forward returns a tuple, the first element
            is used as the feature tensor.
        dataloader: iterable of (imgs, pids, camids, img_paths) batches. pids
            and camids are expected to be tensors, because they are joined
            with ``torch.cat``.
        device: target device ('cpu', 'cuda', 'cuda:0' or a torch.device).

    Returns:
        Tuple (features, pids, camids, img_paths, times):
        - features: numpy array with one L2-normalised row per image.
        - pids, camids: numpy arrays.
        - img_paths: flat list of paths.
        - times: list of per-batch forward-pass durations in seconds. The
          host-to-device copy is excluded.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    # Compare the device *type* so 'cuda:0' and torch.device('cuda') also
    # get synchronised timings, not only the literal string 'cuda'.
    sync_cuda = torch.device(device).type == 'cuda'

    features_list, pids_list, camids_list, img_paths_list, times = [], [], [], [], []

    model.eval()
    with torch.no_grad():
        for imgs, pids, camids, img_paths in tqdm(dataloader, desc='Extracting', leave=False):
            imgs = imgs.to(device)

            # CUDA kernels run asynchronously. Synchronise on both sides so the
            # interval covers the actual forward computation.
            if sync_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()

            outputs = model(imgs)

            if sync_cuda:
                torch.cuda.synchronize()
            times.append(time.perf_counter() - start)

            feats = outputs[0] if isinstance(outputs, tuple) else outputs
            feats = F.normalize(feats, p=2, dim=1)
            features_list.append(feats.cpu())
            pids_list.append(pids)
            camids_list.append(camids)
            img_paths_list.extend(img_paths)

    if not features_list:
        raise ValueError("extract_features: dataloader yielded no batches")

    return (
        torch.cat(features_list).numpy(),
        torch.cat(pids_list).numpy(),
        torch.cat(camids_list).numpy(),
        img_paths_list,
        times
    )

## 4. Configure Experiments to Run


In [None]:

experiments = exp_config.get('experiment_evaluations', {})

# An experiment runs unless its YAML entry explicitly sets `enabled: false`.
enabled_experiments = {
    exp_key: exp_settings
    for exp_key, exp_settings in experiments.items()
    if exp_settings.get('enabled', True)
}

print(f"Total experiments to run: {len(enabled_experiments)}\n")
for exp_name, exp_cfg in enabled_experiments.items():
    listed_models = exp_cfg['models']
    print(f"\n{exp_name}:")
    print(f"  Dataset: {exp_cfg['dataset']}, K={exp_cfg['k_shot']}, Data: {exp_cfg['data_type']}")
    print(f"  Models ({len(listed_models)}):")
    for model_entry in listed_models:
        print(f"    - {model_entry['name']}: {model_entry['model']} ({model_entry['loss']}, {model_entry['checkpoint']})")

## 5. Run Evaluations

In [None]:
import traceback

print("=" * 80)
print("RUNNING EXPERIMENTS")
print("=" * 80)


def _analyze_errors(distmat, qp, gp, qc, gc, q_paths, top_k=10):
    """Per-query ranking error analysis, using the same gallery filter as the metrics.

    For each query, gallery images that share both its identity and its camera
    are dropped before ranking. This is the rule torchreid's ``evaluate_rank``
    applies when ``use_metric_cuhk03=False``, so the rank-1 error count here
    agrees with the reported Rank-1.

    Args:
        distmat: (num_query, num_gallery) distance matrix; smaller means more similar.
        qp, qc: query identity and camera arrays.
        gp, gc: gallery identity and camera arrays.
        q_paths: query image paths, indexed like ``qp``.
        top_k: cut-off used when counting queries with no correct match.

    Returns:
        (hard_queries, fp_count, fn_topk_count):
        - hard_queries: one JSON-friendly dict per query, sorted hardest first.
        - fp_count: number of queries whose rank-1 result is wrong.
        - fn_topk_count: number of queries with no correct match in the top ``top_k``.
        Queries with no valid match at all count toward both counters.
    """
    order = np.argsort(distmat, axis=1)
    hard_queries = []
    fp_count = 0
    fn_topk_count = 0

    for q_idx in range(len(qp)):
        q_pid, q_cam = qp[q_idx], qc[q_idx]
        ranked = order[q_idx]
        # Same-identity images from the same camera are not valid matches.
        valid = ~((gp[ranked] == q_pid) & (gc[ranked] == q_cam))
        ranked_pids = gp[ranked[valid]]

        matches = np.flatnonzero(ranked_pids == q_pid)
        first_correct_rank = int(matches[0]) if matches.size else len(ranked_pids)

        # Cast numpy scalars to int so the record can be written with json.
        hard_queries.append({
            'query_idx': q_idx,
            'pid': int(q_pid),
            'cam': int(q_cam),
            'first_correct_rank': first_correct_rank,
            'img_path': q_paths[q_idx],
        })

        if ranked_pids.size == 0 or ranked_pids[0] != q_pid:
            fp_count += 1
        if matches.size == 0 or first_correct_rank >= top_k:
            fn_topk_count += 1

    hard_queries.sort(key=lambda x: x['first_correct_rank'], reverse=True)
    return hard_queries, fp_count, fn_topk_count


all_results = []
all_error_data = {}  # "<experiment>_<model name>" -> per-query ranking data
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Evaluation always loads the preprocessed variant of each dataset.
dataset_suffix = '_preprocessed'

# enabled_experiments contains only entries whose `enabled` flag is true.
for exp_idx, (exp_name, exp_cfg) in enumerate(enabled_experiments.items(), 1):
    print(f"\n{'='*80}")
    print(f"EXPERIMENT {exp_idx}/{len(enabled_experiments)}: {exp_name}")
    print(f"{'='*80}")
    print(f"Dataset: {exp_cfg['dataset']}, K-shot: {exp_cfg['k_shot']}, Data type: {exp_cfg['data_type']}")
    print(f"Models to evaluate: {len(exp_cfg['models'])}")
    print(f"{'='*80}\n")

    experiment_results = []
    dataset_name_with_suffix = exp_cfg['dataset'] + dataset_suffix

    for model_idx, model_cfg in enumerate(exp_cfg['models'], 1):
        model_type = get_model_type(model_cfg['model'])

        print(f"\n[{model_idx}/{len(exp_cfg['models'])}] Evaluating: {model_cfg['name']}")
        print(f"  Model: {model_cfg['model']}, Loss: {model_cfg['loss']}, Checkpoint: {model_cfg['checkpoint']}")
        print("-" * 80)

        try:
            # Use the input resolution from the model's training config; default 256x128.
            if model_type == 'person':
                model_train_cfg = train_config['person_reid_experiments'].get(model_cfg['model'])
            else:
                model_train_cfg = train_config['vehicle_reid_experiments'].get(model_cfg['model'])

            if model_train_cfg:
                img_height = model_train_cfg['training'].get('img_height', 256)
                img_width = model_train_cfg['training'].get('img_width', 128)
            else:
                img_height, img_width = 256, 128

            print(f"  [1/7] Loading dataset...")

            dataloaders = get_dataloaders_from_config(
                root=DATA_ROOT,
                dataset_name=dataset_name_with_suffix,
                config_dir=str(project_root / 'config'),
                k_shot=exp_cfg['k_shot'],
                override_params={'height': img_height, 'width': img_width}
            )

            query_loader = dataloaders['query']
            gallery_loader = dataloaders['gallery']

            print(f"  Query: {len(query_loader.dataset)}, Gallery: {len(gallery_loader.dataset)}")

            print(f"  [2/7] Loading model...")

            # The checkpoint directory name encodes the training setup. A
            # `train_dataset` key marks a cross-domain evaluation.
            loss_suffix = model_cfg['loss']
            train_dataset = model_cfg.get('train_dataset', exp_cfg['dataset'])
            model_save_name = f"{model_cfg['model']}_{train_dataset}_{exp_cfg['data_type']}_k{exp_cfg['k_shot']}_l{loss_suffix}"
            model_path = os.path.join(MODEL_ROOT, model_save_name, model_cfg['checkpoint'])

            if 'train_dataset' in model_cfg:
                print(f"  Cross-domain: Trained on {train_dataset}, testing on {exp_cfg['dataset']}")

            if not os.path.exists(model_path):
                print(f"  ERROR: Model not found at {model_path}")
                continue

            checkpoint = torch.load(model_path, map_location='cpu')
            num_classes = checkpoint['num_classes']

            model = build_model(model_cfg['model'], num_classes, loss=model_cfg['loss'], model_type=model_type)
            model.load_state_dict(checkpoint['model_state_dict'])
            model = model.to(device)
            model.eval()

            print(f"  Loaded from epoch {checkpoint['epoch']}")

            print(f"  [3/7] Extracting features...")

            qf, qp, qc, q_paths, qt = extract_features(model, query_loader, device)
            gf, gp, gc, g_paths, _ = extract_features(model, gallery_loader, device)

            # Speed figures use only the query set's forward-pass times.
            total_time = sum(qt)
            fps = len(qf) / total_time if total_time > 0 else 0
            avg_time_ms = (total_time / len(qf)) * 1000 if len(qf) > 0 else 0

            print(f"  [4/7] Computing distances...")

            qf_tensor = torch.from_numpy(qf)
            gf_tensor = torch.from_numpy(gf)

            distmat = metrics.compute_distance_matrix(qf_tensor, gf_tensor, metric='euclidean').numpy()

            print(f"  [5/7] Re-ranking...")
            if exp_cfg.get('reranking', True):
                # Re-ranking also needs the query-query and gallery-gallery
                # distances, so compute them only on this path.
                distmat_qq = metrics.compute_distance_matrix(qf_tensor, qf_tensor, metric='euclidean').numpy()
                distmat_gg = metrics.compute_distance_matrix(gf_tensor, gf_tensor, metric='euclidean').numpy()
                distmat_reranked = torchreid.utils.re_ranking(distmat, distmat_qq, distmat_gg)
            else:
                distmat_reranked = distmat

            print(f"  [6/7] Evaluating metrics...")
            cmc, mAP = metrics.evaluate_rank(distmat_reranked, qp, gp, qc, gc, use_metric_cuhk03=False)

            print(f"  [7/7] Analyzing errors...")

            hard_queries, fp_count, fn_top10_count = _analyze_errors(
                distmat_reranked, qp, gp, qc, gc, q_paths, top_k=10
            )

            fp_rate = fp_count / len(qp)
            fn_rate = fn_top10_count / len(qp)

            print(f"\n  RESULTS:")
            print(f"    mAP={mAP:.2%}, R-1={cmc[0]:.2%}, R-5={cmc[4]:.2%}, R-10={cmc[9]:.2%}")
            print(f"    FPS={fps:.1f}, Avg time={avg_time_ms:.1f}ms")
            print(f"    False Positive Rate (Rank-1): {fp_rate:.2%} ({fp_count}/{len(qp)})")
            print(f"    False Negative Rate (Top-10): {fn_rate:.2%} ({fn_top10_count}/{len(qp)})")
            print(f"    Hardest query: Rank {hard_queries[0]['first_correct_rank']+1} (PID={hard_queries[0]['pid']}, Cam={hard_queries[0]['cam']})")

            results = {
                'experiment_name': exp_name,
                'model_name': model_cfg['name'],
                'model': model_cfg['model'],
                'model_type': model_type,
                'dataset': exp_cfg['dataset'],
                'data_type': exp_cfg['data_type'],
                'k_shot': exp_cfg['k_shot'],
                'checkpoint': model_cfg['checkpoint'],
                'loss': model_cfg['loss'],
                'mAP': float(mAP),
                'rank1': float(cmc[0]),
                'rank5': float(cmc[4]),
                'rank10': float(cmc[9]),
                'rank20': float(cmc[19]),
                'fps': float(fps),
                'avg_query_time_ms': float(avg_time_ms),
                'train_epoch': checkpoint['epoch'],
                'false_positive_rate': float(fp_rate),
                'false_negative_top10_rate': float(fn_rate),
                'hardest_query_rank': int(hard_queries[0]['first_correct_rank'])
            }

            experiment_results.append(results)
            all_results.append(results)

            error_key = f"{exp_name}_{model_cfg['name']}"
            all_error_data[error_key] = {
                'hard_queries': hard_queries[:50],
                'qp': qp,
                'qc': qc,
                'gp': gp,
                'gc': gc,
                'q_paths': q_paths,
                'g_paths': g_paths,
                'distmat': distmat_reranked,
                'model_name': model_cfg['name'],
                'experiment': exp_name
            }

        except Exception as e:
            print(f"\n  ERROR: {str(e)}")
            traceback.print_exc()
        finally:
            # Drop the model and checkpoint even when evaluation failed
            # part-way, so memory does not build up across configurations.
            model = None
            checkpoint = None
            torch.cuda.empty_cache()

    if experiment_results:
        print(f"\n{'='*80}")
        print(f"EXPERIMENT SUMMARY: {exp_name}")
        print(f"{'='*80}")
        print(f"{'Model':<20} {'mAP':<8} {'R-1':<8} {'R-5':<8} {'FP Rate':<10} {'FN Rate':<10}")
        print("-" * 80)
        for r in experiment_results:
            print(f"{r['model_name']:<20} {r['mAP']*100:<8.2f} {r['rank1']*100:<8.2f} "
                  f"{r['rank5']*100:<8.2f} {r['false_positive_rate']*100:<10.2f} {r['false_negative_top10_rate']*100:<10.2f}")
        print(f"{'='*80}\n")

print(f"\n{'='*80}")
print("ALL EXPERIMENTS COMPLETE")
print(f"{'='*80}")

## 6. Save Results

In [None]:
def _json_default(obj):
    """``json.dump`` fallback for numpy and torch values.

    Experiment records can contain numpy scalars (e.g. identity/camera values
    taken from numpy arrays) or tensors (e.g. a checkpoint epoch), which the
    json module cannot encode. Convert them to plain Python values.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


if all_results:
    df = pd.DataFrame(all_results)

    # One row per evaluated model, as JSON and as CSV.
    combined_path = os.path.join(RESULTS_DIR, 'experiment_results_complete.json')
    with open(combined_path, 'w') as f:
        json.dump(all_results, f, indent=2, default=_json_default)
    print(f"Saved: {combined_path}")

    csv_path = os.path.join(RESULTS_DIR, 'experiment_results_complete.csv')
    df.to_csv(csv_path, index=False)
    print(f"Saved: {csv_path}")

    # Persist only the compact part of the error data; the full distance
    # matrices and id arrays stay in memory.
    error_path = os.path.join(RESULTS_DIR, 'error_analysis_data.json')
    error_data_serializable = {
        key: {
            'hard_queries': data['hard_queries'],
            'model_name': data['model_name'],
            'experiment': data['experiment'],
        }
        for key, data in all_error_data.items()
    }
    with open(error_path, 'w') as f:
        json.dump(error_data_serializable, f, indent=2, default=_json_default)
    print(f"Saved: {error_path}")
else:
    print("No results to save")

## 7. Summary

In [None]:
# Each row: (domain label, dataset title, dataset key in results,
#            reference model for the K-shot sweep, reference model label).
SUMMARY_DOMAINS = [
    ('Person ReID', 'Market-1501', 'market1501', 'osnet_x1_0', 'OSNet'),
    ('Vehicle ReID', 'VeRi-776', 'veri776', 'rptm', 'RPTM'),
]


def _fmt_pct(value):
    """Format a fraction as a percentage, or 'n/a' when it is NaN (no rows averaged)."""
    return 'n/a' if pd.isna(value) else f"{value:.2%}"


if all_results:
    df = pd.DataFrame(all_results)
    by_dataset = {key: df[df['dataset'] == key] for _, _, key, _, _ in SUMMARY_DOMAINS}

    print("="*80)
    print("COMPREHENSIVE EXPERIMENT SUMMARY FOR THESIS")
    print("="*80)

    print("\n1. BEST PERFORMING MODELS")
    print("-"*80)
    for domain, title, key, _, _ in SUMMARY_DOMAINS:
        best = by_dataset[key].nlargest(3, 'mAP')
        print(f"\n{domain} ({title}):")
        for i, (_, row) in enumerate(best.iterrows(), 1):
            print(f"  {i}. {row['model_name']}: mAP={row['mAP']:.2%}, R-1={row['rank1']:.2%} (K={row['k_shot']}, {row['data_type']})")

    print("\n2. K-SHOT ANALYSIS SUMMARY")
    print("-"*80)
    for domain, _, key, ref_model, ref_label in SUMMARY_DOMAINS:
        subset = by_dataset[key]
        kshot = subset[subset['model'] == ref_model].sort_values('k_shot')
        if len(kshot) > 0:
            print(f"\n{domain} ({ref_label}):")
            for _, row in kshot.iterrows():
                print(f"  K={row['k_shot']:2d}: mAP={row['mAP']:.2%}, R-1={row['rank1']:.2%}")

    print("\n3. ERROR ANALYSIS SUMMARY")
    print("-"*80)
    error_sections = [
        ("\nAverage False Positive Rates (Rank-1 errors):", 'false_positive_rate'),
        ("\nAverage False Negative Rates (Top-10 misses):", 'false_negative_top10_rate'),
    ]
    for heading, column in error_sections:
        print(heading)
        for domain, _, key, _, _ in SUMMARY_DOMAINS:
            print(f"  {domain}: {_fmt_pct(by_dataset[key][column].mean())}")

    print("\n4. INFERENCE SPEED SUMMARY")
    print("-"*80)
    for domain, _, key, _, _ in SUMMARY_DOMAINS:
        print(f"\nFastest models ({domain}):")
        fastest = by_dataset[key].nlargest(3, 'fps')
        for i, (_, row) in enumerate(fastest.iterrows(), 1):
            print(f"  {i}. {row['model_name']}: {row['fps']:.1f} FPS ({row['avg_query_time_ms']:.1f}ms per query)")

    print("\n" + "="*80)
    print("Report complete. All visualizations saved to:", os.path.join(RESULTS_DIR, 'visualizations'))
    print("="*80)