# In [None]:
import os
import json
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import albumentations as A
from transformers import Dinov2PreTrainedModel, Dinov2Model
import csv
# Module-level default device: the first CUDA GPU when one is available,
# otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Data processing
# Per-channel normalisation statistics, written on the 0-255 scale and divided
# by 255 so they are expressed on the 0-1 scale.
# NOTE(review): the "ADE" prefix suggests an ADE20K-style preprocessing recipe;
# confirm these match the statistics the checkpoint was trained with.
ADE_MEAN = np.array([123.675, 116.280, 103.530]) / 255
ADE_STD = np.array([58.395, 57.120, 57.375]) / 255
# Evaluation-time preprocessing pipeline: resize to 224x224, then normalise
# with the statistics above. Called as val_transform(image=<ndarray>)["image"].
val_transform = A.Compose([
    A.Resize(width=224, height=224),
    A.Normalize(mean=ADE_MEAN, std=ADE_STD),
])

# Dinov2
class Dinov2FeatureExtractor(Dinov2PreTrainedModel):
    """DINOv2 backbone followed by an MLP head mapping the CLS token to 256-d.

    At construction time the whole backbone is frozen, then the last two
    transformer blocks and the final layernorm are re-enabled for training,
    and finally the projection head is frozen.
    """

    def __init__(self, config):
        super().__init__(config)
        hidden = config.hidden_size

        self.dinov2 = Dinov2Model(config)
        self.projection_head = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Linear(hidden, 256),
        )

        # Start from a fully frozen backbone ...
        self.dinov2.requires_grad_(False)
        # ... and re-enable only its tail.
        self._unfreeze_dinov2_layers(2)

        # NOTE(review): the head is frozen here, while the fallback message in
        # _unfreeze_dinov2_layers says "Only train the projection head" —
        # confirm that freezing the head is intended.
        self.projection_head.requires_grad_(False)

    def _unfreeze_dinov2_layers(self, unfreeze_layers):
        """Mark the last `unfreeze_layers` encoder blocks and the final layernorm trainable.

        Any exception raised while doing so is reported on stdout and
        swallowed, leaving the parameters in whatever state they had reached.
        """
        try:
            blocks = self.dinov2.encoder.layer
            total_blocks = len(blocks)
            # Index of the first block that becomes trainable.
            layers_to_unfreeze = max(0, total_blocks - unfreeze_layers)

            print(f"Unfreeze the last {unfreeze_layers} Transformer blocks ({layers_to_unfreeze}-{total_blocks-1})")

            for i in range(layers_to_unfreeze, total_blocks):
                blocks[i].requires_grad_(True)
                print(f"Unfreeze block {i}")

            self.dinov2.layernorm.requires_grad_(True)
            print("Unfreeze layernorm layer")

        except Exception as e:
            print(f"Error occurred during unfreezing: {e}")
            print("Only train the projection head")

    def forward(self, pixel_values, output_hidden_states=False, output_attentions=False,return_attentions=False):
        """Encode `pixel_values` and project the CLS token.

        Returns a 4-tuple ``(features, last_hidden_state, hidden_states,
        attentions)`` when `return_attentions` is true, otherwise a dict with
        the keys ``features``, ``last_hidden_state``, ``hidden_states`` and
        ``attentions``.
        """
        backbone_out = self.dinov2(
            pixel_values,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )

        # Token 0 is used as the global (CLS) representation: [batch_size, hidden_size]
        cls_token = backbone_out.last_hidden_state[:, 0, :]
        features = self.projection_head(cls_token)  # [batch_size, 256]

        if not return_attentions:
            return {
                'features': features,
                'last_hidden_state': backbone_out.last_hidden_state,
                'hidden_states': backbone_out.hidden_states,
                'attentions': backbone_out.attentions
            }

        return features, backbone_out.last_hidden_state, backbone_out.hidden_states, backbone_out.attentions

# Query-Guided Attention Module
class QueryGuidedAttention(nn.Module):
    """Cross-attention in which a query token attends over target tokens."""

    def __init__(self, hidden_size=768, num_heads=8, dropout=0.1):
        super().__init__()
        # batch_first=True: tensors are laid out as (batch, sequence, feature).
        self.multihead_attn = nn.MultiheadAttention(
            hidden_size,
            num_heads,
            dropout=dropout,
            batch_first=True,
        )

    def forward(self, query_global, target_spatial):
        """
        Attend from the query token(s) to the target tokens.

        query_global: [batch_size, 1, hidden_size] sub-element features
        target_spatial: [batch_size, seq_len, hidden_size] image features

        Returns the attended context and the attention weights produced by
        the underlying multi-head attention layer.
        """
        # The target tokens serve as both keys and values.
        attended = self.multihead_attn(
            query_global,
            target_spatial,
            target_spatial,
            need_weights=True,
        )
        context, attn_weights = attended
        return context, attn_weights


class L2Norm(nn.Module):
    """Parameter-free layer that rescales its input to unit L2 norm along dim 1."""

    def forward(self, x):
        unit_length = F.normalize(x, dim=1, p=2)
        return unit_length

class AttentionFeatureExtractor(Dinov2PreTrainedModel):
    """DINOv2 encoder with a query-guided cross-attention head.

    The query image is summarised by its CLS token; the target image
    contributes its patch tokens. The CLS token attends over those patch
    tokens, and three 256-d L2-normalised embeddings are produced: one for
    the query, one for the attended target context, and an alignment
    embedding computed from the query token.
    """

    def __init__(self, config):
        super().__init__(config)
        hidden = config.hidden_size

        self.dinov2 = Dinov2Model(config)
        self.cross_attention = QueryGuidedAttention(hidden_size=hidden, num_heads=8, dropout=0.1)

        self.query_projection = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.GELU(),
            L2Norm(),
        )
        self.target_projection = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.GELU(),
            L2Norm(),
        )
        self.attention_align = nn.Sequential(
            nn.Linear(hidden, 256),
            L2Norm(),
        )

        # Freeze the backbone, then re-enable only its last blocks.
        self.dinov2.requires_grad_(False)
        self._unfreeze_dinov2_layers(2)

        # Every head added on top of the backbone stays trainable.
        trainable_heads = (
            self.query_projection,
            self.target_projection,
            self.attention_align,
            self.cross_attention,
        )
        for head in trainable_heads:
            head.requires_grad_(True)

    def _unfreeze_dinov2_layers(self, unfreeze_layers):
        """Mark the last `unfreeze_layers` encoder blocks and the final layernorm trainable.

        Any exception raised while doing so is reported on stdout and
        swallowed, leaving the parameters in whatever state they had reached.
        """
        try:
            blocks = self.dinov2.encoder.layer
            total_blocks = len(blocks)
            # Index of the first block that becomes trainable.
            layers_to_unfreeze = max(0, total_blocks - unfreeze_layers)

            print(f"Unfreeze the last {unfreeze_layers} Transformer blocks ({layers_to_unfreeze}-{total_blocks-1})")

            for i in range(layers_to_unfreeze, total_blocks):
                blocks[i].requires_grad_(True)
                print(f"Unfreeze block {i}")

            self.dinov2.layernorm.requires_grad_(True)
            print("Unfreeze layernorm layer")

        except Exception as e:
            print(f"Error occurred during unfreezing: {e}")
            print("Only train the projection head")

    def forward(self, query_images, target_images, output_hidden_states=False, output_attentions=False, is_train=True, return_attentions=False):
        """Embed a query image, optionally conditioned on a target image.

        When `is_train` is truthy and `target_images` is not None, returns
        ``(query_feat, target_feat, align_feat)``, with the cross-attention
        weights appended as a fourth element if `return_attentions` is true.
        Otherwise only the projected query embedding is returned.
        """
        query_outputs = self.dinov2(
            query_images,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        # Sub-element takes the CLS token as feature; the sequence axis is
        # kept because the attention layer expects it.
        query_global = query_outputs.last_hidden_state[:, :1, :]
        query_cls = query_global.squeeze(1)

        # Query-only path: no target to attend over.
        if not is_train or target_images is None:
            return self.query_projection(query_cls)

        # Target image takes the patch tokens (everything after token 0) as features.
        target_spatial = self.dinov2(target_images).last_hidden_state[:, 1:, :]

        # query: query_global, key/value: target_spatial
        context, attn_weights = self.cross_attention(query_global, target_spatial)

        query_feat = self.query_projection(query_cls)
        target_feat = self.target_projection(context.squeeze(1))
        align_feat = self.attention_align(query_cls)

        if return_attentions:
            return query_feat, target_feat, align_feat, attn_weights
        return query_feat, target_feat, align_feat
        
# Model loading
def load_model(model_config_path, weights_path, device):
    """Build an AttentionFeatureExtractor and try to load fine-tuned weights into it.

    The model is first created with ``from_pretrained(model_config_path)``.
    The checkpoint at `weights_path` is then loaded and its
    ``'feature_extractor_state_dict'`` entry applied. If anything in that
    step fails, the error is printed and the model is returned with the
    weights it already had. The returned model has been moved to `device`.
    """
    model = AttentionFeatureExtractor.from_pretrained(model_config_path)

    # Load weights (best effort: a failure here is reported, not raised).
    try:
        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the installed torch version's defaults — only load trusted checkpoints.
        checkpoint = torch.load(weights_path, map_location=device)
        model.load_state_dict(checkpoint['feature_extractor_state_dict'])
        print(f"Successfully loaded feature extractor weights from {weights_path}")
        print(f"Model trained for {checkpoint.get('epoch', 'unknown')} epochs")
    except Exception as e:
        print(f"Error loading model weights: {e}")
        print("Using model with default weights")

    # Module.to() moves the parameters in place and returns the module itself.
    return model.to(device)

# Fine reranking candidate images using cross-attention
def perform_cross_attention_reranking(feature_extractor, query_image, candidate_paths, top_k=5):
    """Score candidate images against a query with the cross-attention model.

    Parameters:
    - feature_extractor: model called as ``feature_extractor(query, candidate)``
      and expected to return ``(query_feat, candidate_feat, align_feat)``.
    - query_image: query image (anything ``np.array`` can turn into an HxWx3 array).
    - candidate_paths: iterable of image file paths to score.
    - top_k: number of best-scoring candidates to return.

    Returns:
    - A list of ``(candidate_path, similarity)`` tuples sorted by decreasing
      cosine similarity, truncated to `top_k`. A candidate that cannot be
      processed is kept with the sentinel score -10.0 (below any cosine
      similarity), so it sorts last but still counts towards the result length.
    """
    # Put inputs on the device the model actually lives on; fall back to the
    # module-level default only when the model exposes no parameters.
    try:
        model_device = next(feature_extractor.parameters()).device
    except (AttributeError, StopIteration):
        model_device = device

    # Query image preprocessing (done once, reused for every candidate).
    query_tensor = _image_to_model_input(query_image, model_device)

    # Inference mode only needs to be set once, not once per candidate.
    feature_extractor.eval()

    similarities = []  # (candidate_path, similarity) for every candidate
    for candidate_path in candidate_paths:
        try:
            # Close the file handle as soon as the pixels have been decoded.
            with Image.open(candidate_path) as candidate_file:
                candidate_img = candidate_file.convert('RGB')
            candidate_tensor = _image_to_model_input(candidate_img, model_device)

            # Feature extraction with cross-attention
            with torch.no_grad():
                query_feat, candidate_feat, _ = feature_extractor(query_tensor, candidate_tensor)
                sim = F.cosine_similarity(query_feat, candidate_feat).item()
            similarities.append((candidate_path, sim))

        except Exception as e:
            # Best effort: one unreadable candidate must not abort the query.
            print(f"Error processing candidate {candidate_path}: {e}")
            similarities.append((candidate_path, -10.0))

    # Rank candidates by similarity
    similarities.sort(key=lambda x: x[1], reverse=True)
    return similarities[:top_k]


def _image_to_model_input(image, target_device):
    """Apply val_transform to `image` and return a 1xCxHxW float tensor on `target_device`."""
    transformed = val_transform(image=np.array(image))["image"]
    # HWC -> CHW, then add the batch axis.
    return torch.tensor(transformed).permute(2, 0, 1).float().unsqueeze(0).to(target_device)

# Perform reranking from first retrieval results
def perform_reranking_from_first_results(model_config_path, weights_path, query_file_path,
                                        first_retrieval_path, top_k=5, output_json_path=None):
    """Rerank the candidates of a first retrieval stage for every listed query.

    Parameters:
    - model_config_path: passed to load_model to build the model.
    - weights_path: checkpoint passed to load_model.
    - query_file_path: text file with one query per line, in the form
      ``<path>.png <label>`` (a bare ``<path>.png`` is accepted too).
    - first_retrieval_path: JSON file mapping each query path to a list of
      objects that carry an ``"image_path"`` key.
    - top_k: number of reranked candidates kept per query.
    - output_json_path: when given and at least one query was processed, the
      results are written there as JSON once, after all queries.

    Returns:
    - A dict mapping query path -> list of ``(candidate_path, similarity)``.
      An empty dict is returned when either input file cannot be read.
      Queries that are missing on disk, have no candidates, or fail during
      reranking are skipped with a printed message.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    feature_extractor = load_model(model_config_path, weights_path, device)

    # Load first retrieval results
    try:
        with open(first_retrieval_path, 'r') as f:
            first_results = json.load(f)
        print(f"Loaded first retrieval results from {first_retrieval_path}")
    except Exception as e:
        print(f"Error loading first retrieval results: {e}")
        return {}

    # Get query image paths. Only the file read is guarded here, so that
    # later failures are not misreported as a query-file problem.
    try:
        with open(query_file_path, 'r') as f:
            query_image_paths = [_parse_query_line(line) for line in f if line.strip()]
    except Exception as e:
        print(f"Error reading query file: {e}")
        return {}
    print(f"Found {len(query_image_paths)} query images in {query_file_path}")

    reranking_results = {}
    for query_path in query_image_paths:
        try:
            if not os.path.exists(query_path):
                print(f"Warning: Query image not found: {query_path}")
                continue

            # Get candidate images from first retrieval
            candidate_paths = [result["image_path"] for result in first_results.get(query_path, [])]
            if not candidate_paths:
                print(f"Warning: No candidates found for query {query_path}")
                continue

            # Load query image; the handle is released once pixels are decoded.
            with Image.open(query_path) as query_file:
                query_image = query_file.convert('RGB')

            results = perform_cross_attention_reranking(
                feature_extractor, query_image, candidate_paths, top_k=top_k
            )

            reranking_results[query_path] = results
            print(f"Processed query: {query_path}, found {len(results)} results")

        except Exception as e:
            # Best effort: a failing query must not abort the whole run.
            print(f"Error processing query {query_path}: {e}")

    # Write the results once, after every query has been processed, instead
    # of rewriting the whole file on each iteration.
    if output_json_path and reranking_results:
        try:
            _save_reranking_results(reranking_results, output_json_path)
        except Exception as e:
            # The in-memory results are still valid; report and return them.
            print(f"Error saving reranking results: {e}")

    return reranking_results


def _parse_query_line(line):
    """Extract the image path from one line of the query list file."""
    entry = line.strip()
    head, separator, _ = entry.partition('.png ')
    if separator:
        # "<path>.png <label>": drop the label.
        return head + ".png"
    if entry.endswith('.png'):
        # Bare path without a label: use it as is (avoids "x.png.png").
        return entry
    # No ".png" found: keep the historical behaviour of appending it.
    return entry + ".png"


def _save_reranking_results(results, output_json_path):
    """Write `results` to `output_json_path` as JSON, creating the parent directory if needed."""
    output_dir = os.path.dirname(output_json_path)
    # dirname is '' for a bare file name, and os.makedirs('') raises.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_json_path, 'w') as json_file:
        json.dump(results, json_file)
    print(f"Saved retrieval results to {output_json_path}")

def load_ground_truth(correct_result_file, class_path):
    """
    Load ground truth data from CSV files.

    Parameters:
    - correct_result_file: CSV whose header row names one category per column;
      every following row lists, column by column, items that are correct
      results for that column's category. Blank cells are ignored.
    - class_path: CSV with the query file name in column 0 and its category in
      column 1. It is only read to report how many queries have a category.

    Returns:
    - Ground truth dictionary: keys are the header names of
      `correct_result_file`, values are sets of the stripped, non-empty cells
      found in that column. An empty dict is returned when
      `correct_result_file` has no header row.
    """
    # Get query image categories (used for the summary message only).
    query_classes = {}
    with open(class_path, 'r', newline='') as file:
        for row in csv.reader(file):
            if len(row) >= 2:
                query_classes[row[0]] = row[1]
    print(f"Loaded classes for {len(query_classes)} queries")

    ground_truth = {}

    # Get ground truth results
    with open(correct_result_file, 'r', newline='') as correct_file:
        correct_reader = csv.reader(correct_file)

        # The header row provides the dictionary keys.
        filenames = next(correct_reader, None)
        if filenames is None:
            print("Error: Correct results file is empty!")
            return {}

        for row in correct_reader:
            # zip stops at the shorter of the two, so cells beyond the header
            # width are ignored instead of raising IndexError.
            for column_name, value in zip(filenames, row):
                item = value.strip()
                if item:
                    ground_truth.setdefault(column_name, set()).add(item)

    print(f"Loaded ground truth data for {len(ground_truth)} classes")
    return ground_truth

def calculate_ap(retrieved, relevant, top_k=None):
    """Return the average precision of a ranked list.

    Parameters:
    - retrieved: ranked sequence of item ids, best first.
    - relevant: collection of item ids considered correct.
    - top_k: when given, only the first `top_k` retrieved items are scored and
      the result is normalised by ``min(len(relevant), top_k)``; otherwise it
      is normalised by the number of distinct relevant items.

    Returns:
    - A float in [0, 1]. 0.0 when `relevant` is empty, when nothing relevant
      is retrieved, or when `top_k` is not positive.
    """
    if not relevant:
        return 0.0

    if top_k is not None:
        if top_k <= 0:
            # Nothing can be scored; also avoids a negative denominator.
            return 0.0
        retrieved = retrieved[:top_k]

    relevant_set = set(relevant)
    already_credited = set()
    hits = 0
    precision_sum = 0.0

    for rank, item in enumerate(retrieved, start=1):
        # A relevant item counts only the first time it appears; otherwise a
        # duplicated hit would push the score above 1.
        if item in relevant_set and item not in already_credited:
            already_credited.add(item)
            hits += 1
            precision_sum += hits / rank

    if not hits:
        return 0.0

    if top_k is None:
        denominator = len(relevant_set)
    else:
        denominator = min(len(relevant_set), top_k)
    return precision_sum / denominator

def evaluate_retrieval_results(retrieval_results, ground_truth, top_k=5,
                               class_path='/Dataset/Test_element.csv',
                               output_csv_path='/Output/Test_map.csv'):
    """Compute AP / precision / recall / F1 at `top_k` for every query.

    Parameters:
    - retrieval_results: dict mapping query path -> ranked list of
      ``(result_path, score)`` pairs.
    - ground_truth: dict mapping category -> collection of correct item ids
      (file names without extension).
    - top_k: cut-off applied to each ranked list.
    - class_path: CSV with the query file name in column 0 and its category
      in column 1.
    - output_csv_path: CSV that receives one row per evaluated query:
      ``query_id, category, ap, precision, recall, f1``.

    Returns:
    - dict with the keys ``map``, ``precision``, ``recall`` and ``f1`` holding
      the averages over all evaluated queries (0 when none was evaluated).
      Queries without a known category or without ground-truth items are
      skipped with a warning.
    """
    all_precision = []
    all_recall = []
    all_f1 = []
    all_ap = []

    # Category lookup keyed by the query file name without its extension.
    query_to_class = {}
    with open(class_path, 'r', newline='') as file:
        for row in csv.reader(file):
            if len(row) >= 2:
                query_to_class[os.path.splitext(row[0])[0]] = row[1]

    print(f"Loaded class mappings for {len(query_to_class)} queries")

    # os.path.dirname is '' for a bare file name, and os.makedirs('') raises.
    output_dir = os.path.dirname(output_csv_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Evaluate each query and write to CSV
    with open(output_csv_path, 'w', newline='') as file:
        writer = csv.writer(file)

        for query_path, results in retrieval_results.items():
            base_name = os.path.basename(query_path)
            # Historical id (text before the first comma). It keeps the file
            # extension, so it usually misses the extension-less keys above;
            # it is tried first only for backward compatibility.
            query_id = base_name.split(',')[0]
            if query_id not in query_to_class:
                stem = os.path.splitext(base_name)[0]
                if stem in query_to_class:
                    query_id = stem
            query_class = query_to_class.get(query_id)

            if query_class is None:
                print(f"Warning: No class found for query {query_id}")
                continue

            # Get relevant items from ground truth
            relevant_items = ground_truth.get(query_class, set())

            if not relevant_items:
                print(f"Warning: No ground truth items found for class {query_class}")
                continue

            # Result ids are file names cut at the first dot.
            retrieved_ids = [os.path.basename(path).split('.')[0] for path, _ in results[:top_k]]

            correct_predictions = sum(1 for item in retrieved_ids if item in relevant_items)

            # Precision is measured against the nominal cut-off, not against
            # the number of results actually returned.
            precision = correct_predictions / top_k if top_k > 0 else 0
            recall = correct_predictions / len(relevant_items) if len(relevant_items) > 0 else 0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
            ap = calculate_ap(retrieved_ids, list(relevant_items), top_k=top_k)

            all_precision.append(precision)
            all_recall.append(recall)
            all_f1.append(f1)
            all_ap.append(ap)

            print(f"  AP: {ap:.4f}")
            print(f"  Precision@{top_k}: {precision:.4f}")
            print(f"  Recall@{top_k}: {recall:.4f}")
            print(f"  F1-score@{top_k}: {f1:.4f}")
            writer.writerow([query_id, query_class, ap, precision, recall, f1])

    avg_precision = np.mean(all_precision) if all_precision else 0
    avg_recall = np.mean(all_recall) if all_recall else 0
    avg_f1 = np.mean(all_f1) if all_f1 else 0
    map_score = np.mean(all_ap) if all_ap else 0

    print(f"\nAverage Metrics (Top-{top_k}):")
    print(f"  MAP: {map_score:.4f}")
    print(f"  Precision: {avg_precision:.4f}")
    print(f"  Recall: {avg_recall:.4f}")
    print(f"  F1-score: {avg_f1:.4f}")

    return {
        'map': map_score,
        'precision': avg_precision,
        'recall': avg_recall,
        'f1': avg_f1
    }

if __name__ == "__main__":
    # Ranking depth shared by the reranking step and the metric computation.
    top_k = 10

    # Model inputs
    model_config_path = "/Weight_Path/dinov2-pytorch-base-v1" # Pre-trained model path
    # NOTE(review): unlike the other paths this one is relative (no leading "/") — confirm it is intended.
    weights_path = "Weight_Path/dinov2_query_epoch_100.pth" # Fine-tuned weights path

    # Retrieval inputs / outputs
    query_image_path = "/home/mayunjiao/MYJ/dataset/element_all/test.txt"  # Query image list file
    first_retrieval_path = "/Initial_Retrieval/Output/first_retrieval_results.json" # First retrieval results file
    rerank_output_path = "/Reranking/Output/reranking_results.json"

    # Evaluation inputs
    correct_result_file = "/Dataset/TestAll_index.csv"
    class_path = "/Dataset/Test_element.csv"

    # Step 1: rerank the first-stage candidates of every query.
    reranking_results = perform_reranking_from_first_results(
        model_config_path,
        weights_path,
        query_image_path,
        first_retrieval_path,
        top_k=top_k,
        output_json_path=rerank_output_path,
    )

    # Step 2: load the reference answers.
    ground_truth = load_ground_truth(correct_result_file, class_path)

    # Step 3: score the reranked lists, but only when both sides are non-empty.
    if reranking_results and ground_truth:
        evaluate_retrieval_results(reranking_results, ground_truth, top_k=top_k)