# In [None]:  (Jupyter notebook cell marker left over from export; commented out so the file parses)
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
from sklearn.metrics import silhouette_score, davies_bouldin_score
from collections import defaultdict
import random
import json

from models import (
    ViTEncoder, QuantumEnhancer, QuPIDModel,
    nt_xent_loss, quantum_contrastive_loss, combined_contrastive_loss
)


# Global experiment configuration. It is read directly by the module-level
# helpers (setup_directories, get_device, main) and handed to Trainer as its
# `config` mapping.
CONFIG = {
    # Selects the branch taken in main(): 'two_stage' or 'end_to_end'.
    'training_mode': 'two_stage',
    # Output locations: checkpoints are written by Trainer._save_checkpoint,
    # the loss/metric history by Trainer.save_history.
    'checkpoint_dir': './checkpoints',
    'results_dir': './results',
    # NOTE(review): 'image_size' is not read anywhere in this file; the
    # transform builders take their own `image_size` argument (default 224).
    'image_size': 224,
    # ViTEncoder constructor arguments (see main()).
    'vit_model': 'vit_large_patch16_224',
    # NOTE(review): 'vit_dim' is not referenced in this file — presumably the
    # backbone feature width; confirm against models.ViTEncoder.
    'vit_dim': 1024,
    'projection_dim': 1024,
    'use_pretrained': True,
    'freeze_vit_backbone': True,
    # QuantumEnhancer constructor arguments (see main()).
    'n_qubits': 10,
    'n_qlayers': 3,
    # Stage 1 (Trainer.train_stage1_vit): epochs and AdamW learning rate.
    'epochs_stage1': 100,
    'lr_stage1': 1e-4,
    # Stage 2 (Trainer.train_stage2_quantum): epochs and Adam learning rate.
    'epochs_stage2': 50,
    'lr_stage2': 1e-3,
    # End-to-end mode (Trainer.train_end_to_end): one learning rate per
    # optimizer parameter group (ViT projector / quantum enhancer).
    'epochs_e2e': 100,
    'lr_e2e_vit': 1e-4,
    'lr_e2e_quantum': 1e-3,
    # NOTE(review): 'batch_size' and 'num_workers' are not read in this file;
    # presumably consumed by whoever builds the DataLoaders.
    'batch_size': 32,
    # Used by the AdamW optimizers (stage 1 and end-to-end); stage 2 uses Adam
    # without weight decay.
    'weight_decay': 0.05,
    # Max gradient norm for clip_grad_norm_ (stage 1 and end-to-end only).
    'grad_clip': 1.0,
    # Temperature passed to every contrastive loss.
    'temperature': 0.07,
    # Run validation after each epoch when a val_loader is supplied.
    'eval_every_epoch': True,
    # Periodic checkpoint interval, in epochs.
    'save_every_n_epochs': 10,
    # Cut-offs passed to compute_retrieval_metrics. Best-checkpoint selection
    # reads 'MAP@10', so 10 must stay in this list for it to work.
    'k_values': [5, 10],
    'random_seed': 42,
    'num_workers': 4,
    # Resolved once at import time.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
}


def set_seed(seed):
    """Seed every random number generator used here and pin cuDNN behaviour.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    then switches cuDNN to deterministic mode with autotuning disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def setup_directories():
    """Create the checkpoint and results directories named in CONFIG.

    Directories that already exist are left untouched.
    """
    targets = [CONFIG[key] for key in ('checkpoint_dir', 'results_dir')]
    for target in targets:
        os.makedirs(target, exist_ok=True)


def get_device():
    """Return the ``torch.device`` named by ``CONFIG['device']``."""
    device_name = CONFIG['device']
    return torch.device(device_name)


def get_ssl_transforms(image_size=224):
    """Build the randomised training-time image pipeline.

    The pipeline resizes to a square, applies a small rotation (up to 5
    degrees), a small translation (up to 5% per axis) and mild
    brightness/contrast jitter, then converts to a tensor and normalises each
    channel with fixed mean/std values (the commonly used ImageNet statistics).

    Args:
        image_size: Side length of the square output image.

    Returns:
        A ``transforms.Compose`` instance.
    """
    target_size = (image_size, image_size)
    augmentations = [
        transforms.Resize(target_size),
        transforms.RandomRotation(5),
        transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
    ]
    tensor_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(augmentations + tensor_steps)


def get_eval_transforms(image_size=224):
    """Build the deterministic evaluation-time image pipeline.

    Resizes to a square, converts to a tensor and normalises each channel with
    the same fixed mean/std values as the training pipeline; no random
    augmentation is applied.

    Args:
        image_size: Side length of the square output image.

    Returns:
        A ``transforms.Compose`` instance.
    """
    target_size = (image_size, image_size)
    steps = []
    steps.append(transforms.Resize(target_size))
    steps.append(transforms.ToTensor())
    steps.append(
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    )
    return transforms.Compose(steps)


def compute_retrieval_metrics(embeddings, labels, k_values=(5, 10)):
    """Score label-based nearest-neighbour retrieval over a set of embeddings.

    Every sample with a non-negative label is used in turn as a query against
    all *other* valid samples, ranked by cosine similarity. A retrieved sample
    counts as relevant when it shares the query's label. Queries with no other
    sample of the same label are skipped.

    Args:
        embeddings: 2-D array-like or tensor, one row per sample.
        labels: 1-D array-like or tensor of labels; negative labels mark
            samples to ignore.
        k_values: Iterable of positive cut-offs.

    Returns:
        Dict with ``P@k``, ``R@k``, ``MAP@k`` and ``NDCG@k`` for each k, plus
        ``MRR`` (reciprocal rank of the first hit within the largest k, 0 when
        there is none). All values are 0 when fewer than 10 valid samples
        exist.

    Raises:
        ValueError: If any cut-off in ``k_values`` is smaller than 1.
    """
    # detach() so tensors that still require grad can be converted.
    if torch.is_tensor(embeddings):
        embeddings = embeddings.detach().cpu().numpy()
    if torch.is_tensor(labels):
        labels = labels.detach().cpu().numpy()
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)

    k_values = list(k_values)
    if any(k < 1 for k in k_values):
        raise ValueError(f"k_values must contain positive integers, got {k_values}")

    def _mean(values):
        return np.mean(values) if values else 0

    valid_mask = labels >= 0
    if valid_mask.sum() < 10:
        # Too few samples for a meaningful ranking: return the full key set
        # (all zeros) so callers see the same schema on every path.
        empty = {}
        for k in k_values:
            for prefix in ('P', 'R', 'MAP', 'NDCG'):
                empty[f'{prefix}@{k}'] = 0
        empty['MRR'] = 0
        return empty

    embeddings = embeddings[valid_mask]
    labels = labels[valid_mask]
    n_samples = len(labels)

    norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-8
    embeddings_norm = embeddings / norms
    sim_matrix = np.dot(embeddings_norm, embeddings_norm.T)
    np.fill_diagonal(sim_matrix, -np.inf)

    max_k = max(k_values, default=0)
    # discounts[r - 1] is the DCG discount for (1-based) rank r.
    discounts = 1.0 / np.log2(np.arange(2, max_k + 2))

    # One (precisions, recalls, aps, ndcgs) accumulator per cut-off.
    per_k = {k: ([], [], [], []) for k in k_values}
    reciprocal_ranks = []

    for i in range(n_samples):
        relevant_mask = labels == labels[i]
        relevant_mask[i] = False
        n_relevant = int(relevant_mask.sum())
        if n_relevant == 0:
            continue

        # Rank once per query and reuse it for every k. The query itself is
        # removed explicitly so it can never be counted as a hit, even when
        # k is at least the number of samples.
        ranking = np.argsort(sim_matrix[i])[::-1]
        ranking = ranking[ranking != i][:max_k]
        hit_ranks = np.flatnonzero(relevant_mask[ranking]) + 1  # 1-based

        # MRR is taken at the largest cut-off so it does not depend on the
        # order of k_values.
        reciprocal_ranks.append(1.0 / hit_ranks[0] if hit_ranks.size else 0)

        for k, (precisions, recalls, aps, ndcgs) in per_k.items():
            ranks_k = hit_ranks[hit_ranks <= k]
            n_hits = ranks_k.size
            ideal_hits = min(n_relevant, k)  # >= 1 here

            precisions.append(n_hits / k)
            recalls.append(n_hits / n_relevant)
            # AP: precision at each hit position, normalised by the best
            # achievable number of hits within k.
            precision_at_hits = np.arange(1, n_hits + 1) / ranks_k
            aps.append(float(precision_at_hits.sum()) / ideal_hits)

            dcg = discounts[ranks_k - 1].sum()
            idcg = discounts[:ideal_hits].sum()
            ndcgs.append(dcg / idcg if idcg > 0 else 0)

    metrics = {}
    for k, (precisions, recalls, aps, ndcgs) in per_k.items():
        metrics[f'P@{k}'] = _mean(precisions)
        metrics[f'R@{k}'] = _mean(recalls)
        metrics[f'MAP@{k}'] = _mean(aps)
        metrics[f'NDCG@{k}'] = _mean(ndcgs)
    metrics['MRR'] = _mean(reciprocal_ranks)
    return metrics


def compute_clustering_metrics(embeddings, labels):
    """Measure how well embeddings cluster according to their labels.

    Only samples with a non-negative label are considered.

    Args:
        embeddings: 2-D array-like or tensor, one row per sample.
        labels: 1-D array-like or tensor of labels; negative labels mark
            samples to ignore.

    Returns:
        Dict with:
            ``intra_cluster_distance``: mean Euclidean distance of points to
                their own label centroid (labels with a single point are
                excluded).
            ``inter_cluster_distance``: mean Euclidean distance between all
                pairs of label centroids.
            ``silhouette_score``: scikit-learn silhouette score, or 0 when it
                cannot be computed.
            ``davies_bouldin``: scikit-learn Davies-Bouldin index, or ``inf``
                when it cannot be computed.
        With fewer than 10 valid samples the placeholder values
        (0, 0, 0, inf) are returned.
    """
    # detach() so tensors that still require grad can be converted.
    if torch.is_tensor(embeddings):
        embeddings = embeddings.detach().cpu().numpy()
    if torch.is_tensor(labels):
        labels = labels.detach().cpu().numpy()
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)

    valid_mask = labels >= 0
    if valid_mask.sum() < 10:
        # Same key set as the normal path so callers see one schema.
        return {
            'intra_cluster_distance': 0,
            'inter_cluster_distance': 0,
            'silhouette_score': 0,
            'davies_bouldin': float('inf'),
        }
    embeddings = embeddings[valid_mask]
    labels = labels[valid_mask]

    centroids = []
    intra_distances = []
    for label in np.unique(labels):
        members = embeddings[labels == label]
        centroid = members.mean(axis=0)
        centroids.append(centroid)
        # A singleton sits exactly on its centroid; leave it out of the mean.
        if len(members) > 1:
            intra_distances.extend(
                np.linalg.norm(members - centroid, axis=1).tolist()
            )
    intra_cluster_dist = np.mean(intra_distances) if intra_distances else 0

    inter_distances = [
        np.linalg.norm(centroids[i] - centroids[j])
        for i in range(len(centroids))
        for j in range(i + 1, len(centroids))
    ]
    inter_cluster_dist = np.mean(inter_distances) if inter_distances else 0

    # scikit-learn raises ValueError when the labelling is degenerate (for
    # example a single label); fall back to placeholder values in that case
    # only, rather than swallowing every exception.
    try:
        silhouette = silhouette_score(embeddings, labels)
    except ValueError:
        silhouette = 0
    try:
        db_index = davies_bouldin_score(embeddings, labels)
    except ValueError:
        db_index = float('inf')

    return {
        'intra_cluster_distance': intra_cluster_dist,
        'inter_cluster_distance': inter_cluster_dist,
        'silhouette_score': silhouette,
        'davies_bouldin': db_index
    }


class Trainer:
    """Runs the contrastive training loops and tracks loss/validation history.

    Training loaders must yield ``(view1, view2, _)`` batches. Validation
    loaders must yield batches whose first element is the image tensor and
    whose second element is the labels.

    Attributes:
        config: Mapping of hyper-parameters and paths (see the keys read by
            each method).
        device: Device that models and batches are moved to.
        history: ``defaultdict(list)`` of per-epoch losses and validation
            scores, written out by ``save_history``.
    """

    # Validation metric used to pick the "best" checkpoint in every mode.
    _SELECTION_METRIC = 'MAP@10'

    def __init__(self, config, device):
        self.config = config
        self.device = device
        self.history = defaultdict(list)

    def train_stage1_vit(self, vit_encoder, train_loader, val_loader=None):
        """Train the ViT encoder's projector with the NT-Xent loss.

        Only ``vit_encoder.projector`` parameters are optimised (AdamW with
        cosine annealing).

        Returns:
            The trained ``vit_encoder`` (moved to ``self.device``).
        """
        vit_encoder = vit_encoder.to(self.device)
        vit_encoder.train()
        projector_params = list(vit_encoder.projector.parameters())
        optimizer = torch.optim.AdamW(
            projector_params,
            lr=self.config['lr_stage1'],
            weight_decay=self.config['weight_decay']
        )
        n_epochs = self.config['epochs_stage1']
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
        best_val_metric = 0
        for epoch in range(n_epochs):
            vit_encoder.train()
            epoch_loss, n_batches = 0.0, 0
            pbar = tqdm(train_loader, desc=f"Stage1 Epoch {epoch+1}/{n_epochs}")
            for view1, view2, _ in pbar:
                view1 = view1.to(self.device)
                view2 = view2.to(self.device)
                z1 = vit_encoder(view1)
                z2 = vit_encoder(view2)
                loss = nt_xent_loss(z1, z2, self.config['temperature'])
                # Clip only what the optimizer updates: gradients on other
                # parameters are never zeroed and would inflate the norm.
                loss_value = self._optimize(optimizer, loss, projector_params)
                epoch_loss += loss_value
                n_batches += 1
                pbar.set_postfix({'loss': f'{loss_value:.4f}'})
            scheduler.step()
            self.history['stage1_loss'].append(epoch_loss / max(n_batches, 1))
            if self._should_validate(val_loader):
                val_metrics = self._evaluate_vit(vit_encoder, val_loader)
                best_val_metric = self._track_validation(
                    'stage1_val_map', val_metrics, best_val_metric,
                    vit_encoder, 'best_vit_encoder.pt'
                )
            if (epoch + 1) % self.config['save_every_n_epochs'] == 0:
                self._save_checkpoint(vit_encoder, f'vit_encoder_epoch{epoch+1}.pt')
        self._save_checkpoint(vit_encoder, 'final_vit_encoder.pt')
        return vit_encoder

    def train_stage2_quantum(self, quantum_enhancer, vit_encoder, train_loader, val_loader=None):
        """Train the quantum enhancer on top of a fixed ViT encoder.

        The encoder is kept in eval mode and run under ``no_grad``; only
        ``quantum_enhancer`` is optimised (Adam, no scheduler, no clipping).

        Returns:
            The trained ``quantum_enhancer`` (moved to ``self.device``).
        """
        vit_encoder = vit_encoder.to(self.device)
        quantum_enhancer = quantum_enhancer.to(self.device)
        vit_encoder.eval()
        quantum_enhancer.train()
        optimizer = torch.optim.Adam(
            quantum_enhancer.parameters(),
            lr=self.config['lr_stage2']
        )
        n_epochs = self.config['epochs_stage2']
        best_val_metric = 0
        for epoch in range(n_epochs):
            # Validation switches the enhancer to eval mode; switch it back.
            quantum_enhancer.train()
            epoch_loss, n_batches = 0.0, 0
            pbar = tqdm(train_loader, desc=f"Stage2 Epoch {epoch+1}/{n_epochs}")
            for view1, view2, _ in pbar:
                view1 = view1.to(self.device)
                view2 = view2.to(self.device)
                with torch.no_grad():
                    z1 = vit_encoder(view1)
                    z2 = vit_encoder(view2)
                q1 = quantum_enhancer(z1)
                q2 = quantum_enhancer(z2)
                loss = quantum_contrastive_loss(q1, q2, self.config['temperature'])
                loss_value = self._optimize(optimizer, loss)
                epoch_loss += loss_value
                n_batches += 1
                pbar.set_postfix({'loss': f'{loss_value:.4f}'})
            self.history['stage2_loss'].append(epoch_loss / max(n_batches, 1))
            if self._should_validate(val_loader):
                val_metrics = self._evaluate_quantum(quantum_enhancer, vit_encoder, val_loader)
                best_val_metric = self._track_validation(
                    'stage2_val_map', val_metrics, best_val_metric,
                    quantum_enhancer, 'best_quantum_enhancer.pt'
                )
            if (epoch + 1) % self.config['save_every_n_epochs'] == 0:
                self._save_checkpoint(quantum_enhancer, f'quantum_enhancer_epoch{epoch+1}.pt')
        self._save_checkpoint(quantum_enhancer, 'final_quantum_enhancer.pt')
        return quantum_enhancer

    def train_end_to_end(self, qupid_model, train_loader, val_loader=None):
        """Jointly train the ViT projector and the quantum enhancer.

        Uses AdamW with one learning rate per parameter group, cosine
        annealing, and the combined contrastive loss (quantum weight 0.5).

        Returns:
            The trained ``qupid_model`` (moved to ``self.device``).
        """
        qupid_model = qupid_model.to(self.device)
        qupid_model.train()
        vit_params = list(qupid_model.vit_encoder.projector.parameters())
        quantum_params = list(qupid_model.quantum_enhancer.parameters())
        optimizer = torch.optim.AdamW([
            {'params': vit_params, 'lr': self.config['lr_e2e_vit']},
            {'params': quantum_params, 'lr': self.config['lr_e2e_quantum']}
        ], weight_decay=self.config['weight_decay'])
        # Clip only what the optimizer updates (see train_stage1_vit).
        optimized_params = vit_params + quantum_params
        n_epochs = self.config['epochs_e2e']
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
        best_val_metric = 0
        for epoch in range(n_epochs):
            qupid_model.train()
            epoch_loss, n_batches = 0.0, 0
            pbar = tqdm(train_loader, desc=f"E2E Epoch {epoch+1}/{n_epochs}")
            for view1, view2, _ in pbar:
                view1 = view1.to(self.device)
                view2 = view2.to(self.device)
                z1 = qupid_model.get_vit_embedding(view1)
                z2 = qupid_model.get_vit_embedding(view2)
                q1 = qupid_model.quantum_enhancer(z1)
                q2 = qupid_model.quantum_enhancer(z2)
                loss, _, _ = combined_contrastive_loss(
                    z1, z2, q1, q2,
                    temperature=self.config['temperature'],
                    quantum_weight=0.5
                )
                loss_value = self._optimize(optimizer, loss, optimized_params)
                epoch_loss += loss_value
                n_batches += 1
                pbar.set_postfix({'loss': f'{loss_value:.4f}'})
            scheduler.step()
            self.history['e2e_loss'].append(epoch_loss / max(n_batches, 1))
            if self._should_validate(val_loader):
                val_metrics = self._evaluate_e2e(qupid_model, val_loader)
                best_val_metric = self._track_validation(
                    'e2e_val_map', val_metrics, best_val_metric,
                    qupid_model, 'best_qupid_e2e.pt'
                )
            if (epoch + 1) % self.config['save_every_n_epochs'] == 0:
                self._save_checkpoint(qupid_model, f'qupid_e2e_epoch{epoch+1}.pt')
        self._save_checkpoint(qupid_model, 'final_qupid_e2e.pt')
        return qupid_model

    def _optimize(self, optimizer, loss, clip_params=None):
        """Run one optimizer step for ``loss`` and return it as a float.

        When ``clip_params`` is given, their gradient norm is clipped to
        ``config['grad_clip']`` before the step.
        """
        optimizer.zero_grad()
        loss.backward()
        if clip_params is not None:
            torch.nn.utils.clip_grad_norm_(clip_params, self.config['grad_clip'])
        optimizer.step()
        return loss.item()

    def _should_validate(self, val_loader):
        """Return True when per-epoch validation is enabled and possible."""
        # `is not None` rather than truthiness: bool(DataLoader) calls
        # __len__, which raises for datasets without a length.
        return val_loader is not None and bool(self.config['eval_every_epoch'])

    def _track_validation(self, history_key, val_metrics, best_so_far, model, best_filename):
        """Record the selection metric and checkpoint ``model`` on improvement.

        Returns:
            The updated best score.
        """
        score = float(val_metrics.get(self._SELECTION_METRIC, 0))
        self.history[history_key].append(score)
        if score > best_so_far:
            self._save_checkpoint(model, best_filename)
            return score
        return best_so_far

    def _embed_and_score(self, embed_fn, dataloader):
        """Embed every batch with ``embed_fn`` and compute retrieval metrics.

        Returns:
            The dict from ``compute_retrieval_metrics``, or ``{}`` when the
            loader yields no batches.
        """
        embeddings = []
        labels = []
        with torch.no_grad():
            for batch in dataloader:
                images, batch_labels = batch[0], batch[1]
                images = images.to(self.device)
                embeddings.append(embed_fn(images).cpu())
                if torch.is_tensor(batch_labels):
                    # reshape(-1) also copes with 0-d and (B, 1) label tensors.
                    labels.extend(batch_labels.reshape(-1).tolist())
                elif isinstance(batch_labels, list):
                    labels.extend(batch_labels)
                else:
                    labels.append(batch_labels)
        if not embeddings:
            return {}
        stacked = torch.cat(embeddings, dim=0).numpy()
        return compute_retrieval_metrics(stacked, np.array(labels), self.config['k_values'])

    def _evaluate_vit(self, vit_encoder, dataloader):
        """Retrieval metrics for the ViT encoder's embeddings."""
        vit_encoder.eval()
        return self._embed_and_score(vit_encoder, dataloader)

    def _evaluate_quantum(self, quantum_enhancer, vit_encoder, dataloader):
        """Retrieval metrics for enhancer outputs computed on ViT embeddings."""
        vit_encoder.eval()
        quantum_enhancer.eval()
        return self._embed_and_score(
            lambda images: quantum_enhancer(vit_encoder(images)), dataloader
        )

    def _evaluate_e2e(self, qupid_model, dataloader):
        """Retrieval metrics for the combined model's output embeddings."""
        qupid_model.eval()
        return self._embed_and_score(qupid_model, dataloader)

    def _save_checkpoint(self, model, filename):
        """Save ``model.state_dict()`` under ``config['checkpoint_dir']``."""
        directory = self.config['checkpoint_dir']
        # Create the directory so the trainer also works when the caller has
        # not prepared it beforehand.
        os.makedirs(directory, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(directory, filename))

    def save_history(self, filename='training_history.json'):
        """Write the training history as JSON under ``config['results_dir']``."""
        directory = self.config['results_dir']
        os.makedirs(directory, exist_ok=True)
        path = os.path.join(directory, filename)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(dict(self.history), f, indent=2)


def main(train_loader, val_loader=None):
    """Build the models, train them in the configured mode and save history.

    Args:
        train_loader: Iterable of training batches, passed to the Trainer.
        val_loader: Optional iterable of validation batches.

    Returns:
        The assembled ``QuPIDModel``.

    Raises:
        ValueError: If ``CONFIG['training_mode']`` is neither ``'two_stage'``
            nor ``'end_to_end'``.
    """
    # Validate first: an unknown mode would otherwise skip both branches and
    # fail only at the final return with an UnboundLocalError.
    mode = CONFIG['training_mode']
    if mode not in ('two_stage', 'end_to_end'):
        raise ValueError(
            f"Unknown training_mode {mode!r}; expected 'two_stage' or 'end_to_end'"
        )

    set_seed(CONFIG['random_seed'])
    setup_directories()
    device = get_device()

    vit_encoder = ViTEncoder(
        model_name=CONFIG['vit_model'],
        embedding_dim=CONFIG['projection_dim'],
        pretrained=CONFIG['use_pretrained'],
        freeze_backbone=CONFIG['freeze_vit_backbone']
    )
    quantum_enhancer = QuantumEnhancer(
        input_dim=CONFIG['projection_dim'],
        n_qubits=CONFIG['n_qubits'],
        n_qlayers=CONFIG['n_qlayers']
    )

    trainer = Trainer(CONFIG, device)

    if mode == 'two_stage':
        # Stage 1 trains the encoder; stage 2 trains the enhancer on top of it.
        vit_encoder = trainer.train_stage1_vit(vit_encoder, train_loader, val_loader)
        quantum_enhancer = trainer.train_stage2_quantum(
            quantum_enhancer, vit_encoder, train_loader, val_loader
        )
        qupid_model = QuPIDModel(vit_encoder, quantum_enhancer)
    else:
        qupid_model = QuPIDModel(vit_encoder, quantum_enhancer)
        qupid_model = trainer.train_end_to_end(qupid_model, train_loader, val_loader)

    trainer.save_history()
    return qupid_model


if __name__ == "__main__":
    # No-op: main() requires a caller-supplied train_loader, so nothing is run
    # when this file is executed directly.
    pass