<h1>SSL training loop</h1>

In [None]:
# Colab-only cell: mount Google Drive at /content/drive.
# CONFIG['data_root'] and MODEL_ROOT defined further down both point into this mount,
# so this cell has to run before any data is loaded or any checkpoint is saved.
import google.colab
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%cd master-thesis-reid

In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.cluster import DBSCAN
from tqdm import tqdm
import json
from pathlib import Path
from PIL import Image

import torchreid
from torchreid import metrics

# Location of the project checkout; the `models` and `utils` packages imported
# below are resolved from this directory.
project_root = Path('/content/master-thesis-reid')

# Only prepend once, so re-running this cell does not pile up duplicate entries
# in sys.path.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from models.person import transreid_base
from utils.data_loader import get_dataloaders_from_config

# Environment report: PyTorch version, CUDA availability and, if present, the GPU name.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
!pip install -q torchreid scikit-learn

In [None]:
# Experiment configuration for the SSL (pseudo-label) training loop.
CONFIG = {
    'dataset': 'market1501_preprocessed',
    # NOTE(review): the data-loading cell passes k_shot=99 directly, so this value
    # is not what the loaders are built with - confirm whether it is still needed.
    'k_shot': 16,
    'data_root': '/content/drive/MyDrive/reid_data',
    'config_dir': str(project_root / 'config'),

    # Number of cluster -> retrain rounds. The two per-iteration lists below are
    # indexed with (iteration - 1) and therefore need one entry per iteration.
    'num_iterations': 1,
    'epochs_per_iteration': [40],
    # DBSCAN neighbourhood radius; multiplied by dbscan_eps_decay after each iteration.
    # NOTE(review): the training loop clusters L2-normalised features with Euclidean
    # distance, which is bounded by 2.0, so 1.5 is a very wide radius - confirm.
    'dbscan_eps': 1.5,
    'dbscan_eps_decay': 1.0,
    'dbscan_min_samples': 8,
    # Per-iteration factor applied to BASE_LR.
    'lr_multiplier': [0.1],
    # Classifier-only warm-up epochs used in the first iteration.
    'warmup_epochs': 10,
}

# Fail fast instead of hitting an IndexError after a full iteration of training.
for _per_iter_key in ('epochs_per_iteration', 'lr_multiplier'):
    if len(CONFIG[_per_iter_key]) < CONFIG['num_iterations']:
        raise ValueError(
            f"CONFIG['{_per_iter_key}'] has {len(CONFIG[_per_iter_key])} entries "
            f"but num_iterations={CONFIG['num_iterations']}"
        )

BATCH_SIZE = 64
BASE_LR = 0.00035
WEIGHT_DECAY = 0.0005
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Directory (on the mounted Drive) where checkpoints are written.
MODEL_ROOT = '/content/drive/MyDrive/reid_models/ssl+reranking/'

print("SSL Configuration:")
print(json.dumps(CONFIG, indent=2))

<h2>Load Data</h2>

In [None]:
print("\n[1/3] Loading FULL dataset (all training data)...")

# Build the train / query / gallery loaders from the project's dataset config.
# k_shot=99 is the value this notebook uses to request the full dataset.
dataloaders = get_dataloaders_from_config(
    root=CONFIG['data_root'],
    dataset_name=CONFIG['dataset'],
    config_dir=str(project_root.joinpath('config')),
    k_shot=99,  # Full dataset
    override_params=dict(
        height=256,
        width=128,
        batch_size_train=BATCH_SIZE,
        batch_size_test=128,
    ),
)

# Pull the three loaders out of the returned mapping in one go.
train_loader_full, query_loader, gallery_loader = (
    dataloaders[split] for split in ('train', 'query', 'gallery')
)

print(f"  Training samples: {len(train_loader_full.dataset)}")
print(f"  Query: {len(query_loader.dataset)}, Gallery: {len(gallery_loader.dataset)}")


In [None]:
print("\n[2/3] Collecting all training data...")

# The SSL loop looks records up in `all_data` by *dataset position* (its features
# come from a shuffle=False loader over train_loader_full.dataset). Iterating
# train_loader_full itself would record samples in that loader's own order, and
# would miss samples if it uses a sampler or drops the last batch, so the metadata
# is collected from the dataset in index order instead.
ordered_train_loader = torch.utils.data.DataLoader(
    train_loader_full.dataset,
    batch_size=128,
    shuffle=False,
    num_workers=4,
    drop_last=False,
)

# One record per training sample: image path, ground-truth person id, camera id.
all_data = []
for imgs, pids, camids, img_paths in tqdm(ordered_train_loader, desc='Loading data'):
    all_data.extend(
        {'img_path': path, 'pid': pid.item(), 'camid': camid.item()}
        for path, pid, camid in zip(img_paths, pids, camids)
    )

# Every dataset item must have exactly one metadata record.
assert len(all_data) == len(train_loader_full.dataset), (
    f"collected {len(all_data)} records for {len(train_loader_full.dataset)} samples"
)

print(f"  Total samples: {len(all_data)}")
print(f"  Ground-truth IDs: {len(set([d['pid'] for d in all_data]))} (for reference only)")


<h2>Helper functions</h2>

In [None]:
# SSL (pseudo-label) training loop.
#
# For every SSL iteration this cell
#   1. extracts L2-normalised features for the whole training dataset,
#   2. clusters them with DBSCAN to obtain pseudo-labels,
#   3. builds a dataset from the samples that were not marked as noise,
#   4. rebuilds the model with one output class per cluster, trains it,
#      evaluates every 10 epochs and checkpoints the best mAP.
#
# `model` and `evaluate_model` must already exist when this cell runs; neither is
# defined in the cells shown above it. `evaluate_model` is expected to return a
# dict with the keys 'mAP', 'rank1', 'rank5' and 'rank10'.
from sklearn.metrics import pairwise_distances


class SSLDataset(torch.utils.data.Dataset):
    """Image dataset that serves pseudo-labels in place of person IDs.

    Args:
        data_info: list of dicts with the keys 'img_path', 'pseudo_pid' and 'camid'.
        transform: callable applied to the loaded RGB PIL image. When None, a
            fallback pipeline (resize to 256x128, ToTensor, Normalize) is used.
    """

    def __init__(self, data_info, transform=None):
        self.data_info = data_info
        self.transform = transform

        # Fallback transform if none was provided
        if self.transform is None:
            from torchvision import transforms
            self.transform = transforms.Compose([
                transforms.Resize((256, 128)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

    def __len__(self):
        return len(self.data_info)

    def __getitem__(self, idx):
        """Return ``(image, pseudo_pid, camid, img_path)`` for sample ``idx``."""
        info = self.data_info[idx]
        img = Image.open(info['img_path']).convert('RGB')
        img = self.transform(img)
        return img, info['pseudo_pid'], info['camid'], info['img_path']


def _train_one_epoch(model, loader, optimizer, criterion, device, desc):
    """Train ``model`` for one pass over ``loader``.

    Shared by the warm-up and the main training phase. ``loader`` must yield
    ``(imgs, labels, camids, paths)`` batches and must not be empty.

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)`` for the epoch.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc=desc)
    for imgs, pids, _camids, _paths in pbar:
        imgs, pids = imgs.to(device), pids.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)

        # When the model returns a tuple, its first element is used as the logits.
        logits = outputs[0] if isinstance(outputs, tuple) else outputs

        loss = criterion(logits, pids)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = logits.max(1)
        total += pids.size(0)
        correct += predicted.eq(pids).sum().item()

        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100.*correct/total:.2f}%'
        })

    return running_loss / len(loader), 100. * correct / total


# torch.save does not create missing directories.
os.makedirs(MODEL_ROOT, exist_ok=True)

ssl_results_history = []
best_mAP = 0.0
best_iteration = 0
best_epoch = 0

for ssl_iter in range(1, CONFIG['num_iterations'] + 1):
    print(f"\n{'='*80}")
    print(f"SSL ITERATION {ssl_iter}/{CONFIG['num_iterations']}")
    print(f"{'='*80}")

    print(f"\n[Step 1/4] Extracting features from all training data...")

    model.eval()
    all_features = []
    # Metadata is collected from the very same batches as the features, so row i
    # of `all_features` always describes sample_meta[i], regardless of how any
    # other loader orders or drops samples.
    sample_meta = []

    # NOTE(review): the training dataset is reused with whatever transform it
    # carries; if that includes random augmentation the features are not
    # deterministic - confirm.
    temp_dataset = train_loader_full.dataset
    temp_loader = torch.utils.data.DataLoader(
        temp_dataset,
        batch_size=128,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    with torch.no_grad():
        for imgs, pids, camids, img_paths in tqdm(temp_loader, desc='Extracting features'):
            imgs = imgs.to(device)

            outputs = model(imgs)

            # For tuple outputs the second element is taken as the feature
            # vector, falling back to the first if there is only one.
            if isinstance(outputs, tuple):
                features = outputs[1] if len(outputs) > 1 else outputs[0]
            else:
                features = outputs

            features = nn.functional.normalize(features, p=2, dim=1)
            all_features.append(features.cpu())

            sample_meta.extend(
                {'img_path': path, 'pid': pid.item(), 'camid': camid.item()}
                for path, pid, camid in zip(img_paths, pids, camids)
            )

    all_features = torch.cat(all_features, dim=0).numpy()
    feature_indices = list(range(len(sample_meta)))
    assert len(sample_meta) == all_features.shape[0], "feature/metadata count mismatch"

    print(f"  Extracted features shape: {all_features.shape}")
    print(f"  Number of metadata records: {len(sample_meta)}")
    print(f"  Feature indices range: {min(feature_indices)} to {max(feature_indices)}")

    print(f"\n[Step 2/4] Generating pseudo-labels with DBSCAN...")

    # eps shrinks (or stays constant) geometrically with the iteration number.
    current_eps = CONFIG['dbscan_eps'] * (CONFIG['dbscan_eps_decay'] ** (ssl_iter - 1))
    print(f"  DBSCAN eps={current_eps:.3f}, min_samples={CONFIG['dbscan_min_samples']}")

    print("  [1/2] Computing distance matrix...")
    distmat = pairwise_distances(all_features, metric='euclidean')

    print("  [2/2] Running DBSCAN clustering...")
    clusterer = DBSCAN(
        eps=current_eps,
        min_samples=CONFIG['dbscan_min_samples'],
        metric='precomputed',
        n_jobs=-1
    )
    cluster_labels = clusterer.fit_predict(distmat)

    # Label -1 marks noise and is not counted as a cluster.
    n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
    n_noise = list(cluster_labels).count(-1)

    print(f"\n  Clustering results:")
    print(f"    Clusters found: {n_clusters}")
    print(f"    Noise samples: {n_noise}/{len(cluster_labels)} ({100*n_noise/len(cluster_labels):.1f}%)")

    if n_clusters < 50:
        print(f"    WARNING: Very few clusters! Consider increasing eps to {current_eps*1.2:.3f}")
        if n_clusters == 0:
            # Every sample is noise: there is nothing to train on, and a model
            # with zero classes cannot be built.
            print(f"    ERROR: No clusters found (all samples are noise)! Stopping SSL")
            break
        if n_clusters == 1:
            print(f"    ERROR: Only 1 cluster! Stopping SSL - features are too similar")
            print(f"    This usually means the model hasn't learned discriminative features yet")
            break
    elif n_clusters > 2000:
        print(f"    WARNING: Too many clusters! Consider decreasing eps to {current_eps*0.8:.3f}")

    cluster_sizes = [(cluster_labels == cluster_id).sum() for cluster_id in range(n_clusters)]

    if cluster_sizes:
        print(f"    Largest cluster: {max(cluster_sizes)} samples")
        print(f"    Mean cluster size: {np.mean(cluster_sizes):.1f}")
        print(f"    Median cluster size: {np.median(cluster_sizes):.1f}")

    print(f"\n[Step 3/4] Creating SSL training dataset...")

    valid_mask = cluster_labels >= 0
    valid_feature_indices = np.where(valid_mask)[0]

    print(f"  Valid samples (non-noise): {len(valid_feature_indices)}/{len(cluster_labels)}")

    # Map cluster ids to contiguous class labels 0..K-1.
    unique_clusters = sorted(set(cluster_labels[valid_mask]))
    cluster_to_label = {cluster_id: new_label for new_label, cluster_id in enumerate(unique_clusters)}

    ssl_dataset_info = []
    for feat_idx in valid_feature_indices:
        meta = sample_meta[feat_idx]
        ssl_dataset_info.append({
            'img_path': meta['img_path'],
            'pseudo_pid': cluster_to_label[cluster_labels[feat_idx]],
            'camid': meta['camid'],
            'gt_pid': meta['pid'],
        })

    num_pseudo_classes = len(unique_clusters)

    print(f"\n  SSL Dataset created:")
    print(f"    Total samples: {len(ssl_dataset_info)}")
    print(f"    Pseudo-label classes: {num_pseudo_classes}")
    print(f"    Filtered noise: {n_noise}")

    # The training loader uses drop_last=True, so fewer samples than one batch
    # would yield zero batches and a division by zero in the epoch statistics.
    if len(ssl_dataset_info) < BATCH_SIZE:
        print(f"    ERROR: Only {len(ssl_dataset_info)} pseudo-labelled samples "
              f"(less than one batch of {BATCH_SIZE})! Stopping SSL")
        break

    print(f"\n[Step 4/4] Setting up SSL training...")

    print(f"  Rebuilding model with {num_pseudo_classes} pseudo-classes...")

    # NOTE(review): in the first iteration the new model is created with
    # pretrained=True and nothing is copied from the `model` that produced the
    # features above - confirm that discarding it is intended.
    new_model = transreid_base(num_classes=num_pseudo_classes, loss='softmax', pretrained=(ssl_iter == 1))

    if ssl_iter > 1:
        print(f"  Transferring backbone weights from iteration {ssl_iter-1}...")
        old_state = model.state_dict()
        new_state = new_model.state_dict()

        # Copy every tensor whose name does not look like a head layer and whose
        # shape is unchanged.
        # NOTE(review): 'fc' is matched as a substring, so any backbone layer
        # whose name contains 'fc' is also skipped here (and treated as a
        # classifier parameter below) - verify against the model's parameter names.
        transferred = 0
        for name, param in old_state.items():
            if 'classifier' not in name and 'bottleneck' not in name and 'fc' not in name:
                if name in new_state and param.shape == new_state[name].shape:
                    new_state[name] = param
                    transferred += 1

        new_model.load_state_dict(new_state, strict=False)
        print(f"    Transferred {transferred} layers")

    model = new_model.to(device)

    print(f"  Creating SSL dataloader...")

    ssl_dataset = SSLDataset(
        ssl_dataset_info,
        transform=getattr(train_loader_full.dataset, 'transform', None)
    )
    ssl_loader = torch.utils.data.DataLoader(
        ssl_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        drop_last=True
    )

    print(f"    Batches per epoch: {len(ssl_loader)}")

    lr = BASE_LR * CONFIG['lr_multiplier'][ssl_iter - 1]
    num_epochs = CONFIG['epochs_per_iteration'][ssl_iter - 1]
    warmup_epochs = CONFIG['warmup_epochs'] if ssl_iter == 1 else 5

    print(f"\n  Training configuration:")
    print(f"    Epochs: {num_epochs}")
    print(f"    Base LR: {lr:.6f}")
    print(f"    Warmup epochs: {warmup_epochs}")

    # Split parameters by name into head ("classifier") and backbone groups.
    backbone_params = []
    classifier_params = []

    for name, param in model.named_parameters():
        if 'classifier' in name or 'bottleneck' in name or 'fc' in name:
            classifier_params.append(param)
        else:
            backbone_params.append(param)

    optimizer = torch.optim.Adam([
        {'params': backbone_params, 'lr': lr},
        {'params': classifier_params, 'lr': lr * 10}  # Higher LR for classifier
    ], weight_decay=WEIGHT_DECAY)

    # Stepped once per main-training epoch; warm-up epochs do not advance it.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=lr * 0.01
    )

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # Warmup phase (classifier only)
    if warmup_epochs > 0:
        print(f"\n  [Warm-up] Training classifier for {warmup_epochs} epochs...")
        print(f"    Backbone: frozen")
        print(f"    Classifier LR: {lr * 10:.6f}")

        # Freeze backbone (gradients only; the whole model is still put in
        # train() mode by the epoch helper).
        for param in backbone_params:
            param.requires_grad = False

        # Separate optimizer so the warm-up does not touch the main optimizer's state.
        warmup_optimizer = torch.optim.Adam(classifier_params, lr=lr*10, weight_decay=WEIGHT_DECAY)

        for warmup_epoch in range(1, warmup_epochs + 1):
            _, epoch_acc = _train_one_epoch(
                model, ssl_loader, warmup_optimizer, criterion, device,
                desc=f'Warmup {warmup_epoch}'
            )
            print(f"    Warmup {warmup_epoch}: Acc={epoch_acc:.2f}%")

        # Unfreeze backbone
        for param in backbone_params:
            param.requires_grad = True

        print(f"  Warmup complete!\n")

    # Main SSL Training
    print(f"  [Main Training] {num_epochs} epochs...")

    for epoch in range(1, num_epochs + 1):
        epoch_loss, epoch_acc = _train_one_epoch(
            model, ssl_loader, optimizer, criterion, device,
            desc=f'SSL {ssl_iter}.{epoch}'
        )
        scheduler.step()

        print(f"  SSL {ssl_iter}.{epoch}: Loss={epoch_loss:.4f}, Acc={epoch_acc:.2f}%")

        # Periodic evaluation on query/gallery with re-ranking enabled.
        if epoch % 10 == 0:
            print(f"\n  [Evaluation] SSL Iter {ssl_iter}, Epoch {epoch}:\n")

            eval_results = evaluate_model(
                model, query_loader, gallery_loader, device,
                use_reranking=True
            )

            print(f"    mAP: {eval_results['mAP']:.2%}")
            print(f"    Rank-1: {eval_results['rank1']:.2%}")
            print(f"    Rank-5: {eval_results['rank5']:.2%}")
            print(f"    Rank-10: {eval_results['rank10']:.2%}")

            # Save best model
            if eval_results['mAP'] > best_mAP:
                best_mAP = eval_results['mAP']
                best_iteration = ssl_iter
                best_epoch = epoch

                checkpoint_path = os.path.join(MODEL_ROOT, "transreid_pure_ssl_best.pth")
                torch.save({
                    'iteration': ssl_iter,
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'num_classes': num_pseudo_classes,
                    'mAP': eval_results['mAP'],
                    'rank1': eval_results['rank1'],
                    'config': CONFIG,
                }, checkpoint_path)

                print(f"\n    NEW BEST MODEL SAVED!")
                print(f"    mAP: {best_mAP:.2%}, Iteration: {best_iteration}, Epoch: {best_epoch}\n")

    print(f"\n{'='*60}")
    print(f"SSL ITERATION {ssl_iter} COMPLETE")
    print(f"{'='*60}\n")

    # End-of-iteration evaluation, recorded in the history and saved with the
    # per-iteration checkpoint.
    final_results = evaluate_model(model, query_loader, gallery_loader, device, use_reranking=True)
    ssl_results_history.append({
        'iteration': ssl_iter,
        'mAP': final_results['mAP'],
        'rank1': final_results['rank1'],
        'rank5': final_results['rank5'],
        'rank10': final_results['rank10'],
    })

    print(f"  Final results:")
    print(f"    mAP: {final_results['mAP']:.2%}")
    print(f"    Rank-1: {final_results['rank1']:.2%}")

    iter_checkpoint_path = os.path.join(MODEL_ROOT, f"transreid_pure_ssl_iter{ssl_iter}.pth")
    torch.save({
        'iteration': ssl_iter,
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'num_classes': num_pseudo_classes,
        'results': final_results,
        'config': CONFIG,
    }, iter_checkpoint_path)
    print(f"  Saved: {iter_checkpoint_path}\n")