# USAGE: The CONFIG dict right after the imports holds the checkpoint path ("BEST_MODEL_PATH"); change it there if needed.
# The test folder path (NEW_TEST_FOLDER_PATH) is set in the next code cell; change it to point at your test data.

# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_b4, EfficientNet_B4_Weights
import os
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm # Use tqdm for notebooks/Colab
import collections # Although not directly used in the snippet, good to include if it was in the original context
import math
from sklearn.metrics import f1_score # Import F1 score utility

# --- 1. Setup and Configuration (Essential for Model Definition and Paths) ---
# Central settings for the evaluation script. In the code visible here,
# "OUTPUT_PATH" is only used to create a directory, and "BEST_MODEL_PATH" /
# "EMBEDDING_DIM" are read when the checkpoint is loaded; "BASE_PATH" is not
# read anywhere in this file section.
CONFIG = {
    "BASE_PATH": "/kaggle/input/comsys/Comys_Hackathon5/Task_B", # Dataset root; kept for reference, not required for evaluation
    "OUTPUT_PATH": "/kaggle/working/data", # Scratch directory created below; nothing is written to it by the evaluation itself
    "BEST_MODEL_PATH": "/kaggle/input/ckpt-for-comsystaska/best_embedding_model_TaskB.pth", # Checkpoint (state dict) to evaluate
    "EMBEDDING_DIM": 512, # Crucial: Must match the embedding_dim used during training
}

# Create the output directory up front; exist_ok=True makes re-running the cell harmless.
# NOTE: this is a module-level side effect and requires the parent path to be writable.
os.makedirs(CONFIG["OUTPUT_PATH"], exist_ok=True) # Good practice even if just evaluating

# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Using device: {DEVICE} ---")

# --- 2. Required Model Architecture (Copy these classes here) ---
class EmbeddingNet(nn.Module):
    """EfficientNet-B4 feature extractor whose classifier head is replaced by a
    single linear projection into the embedding space.

    Args:
        embedding_dim: Width of the embedding vector produced per input image.
            It has to equal the value used at training time, otherwise the
            checkpoint's projection weights will not fit.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        pretrained = EfficientNet_B4_Weights.DEFAULT
        net = efficientnet_b4(weights=pretrained)

        # Index 1 of the stock classifier head is the layer whose input width
        # we need; the whole head is then swapped for one Linear layer. This
        # must mirror the head used during training so state-dict keys line up.
        head_input_width = net.classifier[1].in_features
        net.classifier = nn.Linear(head_input_width, embedding_dim)

        self.backbone = net
        # Preprocessing pipeline bundled with the chosen pretrained weights.
        self.transforms = pretrained.transforms()

    def forward(self, x):
        """Run ``x`` through the backbone and return its output as the embedding."""
        return self.backbone(x)

# NOTE: ArcMarginProduct is NOT needed for inference/evaluation, as it's part of the classification head.
# The EmbeddingNet itself is what generates the embeddings for similarity comparison.
# class ArcMarginProduct(nn.Module):
#     # ... (definition from your original code) ...
#     pass

# --- 3. Required Data Preparation Functions (Copy these functions here) ---
# This function is crucial for organizing your test data for evaluation.
def prepare_evaluation_sets(data_path):
    """
    Scan the data directory once and collect reference and query image paths.

    Expected layout:
    data_path/
    ├── ClassA/
    │   ├── img1.jpg (reference)
    │   └── distortion/
    │       └── distorted_img1.jpg (query)
    └── ClassB/
        ├── img1.jpg (reference)
        └── distortion/
            └── distorted_img1.jpg (query)

    Only regular files ending in .png/.jpg/.jpeg (case-insensitive) are
    collected. All listings are sorted so the result does not depend on the
    filesystem's directory ordering.

    Args:
        data_path: Root directory containing one sub-directory per class.

    Returns:
        A tuple ``(reference_gallery_paths, query_set, person_classes)``:
          - reference_gallery_paths: dict mapping class name -> list of image
            paths found directly inside the class directory (classes without
            any such image are omitted).
          - query_set: list of ``(image_path, class_name)`` tuples for images
            inside each class's ``distortion`` sub-directory.
          - person_classes: sorted list of all class directory names.
        If ``data_path`` is not an existing directory, a warning is printed
        and ``({}, [], [])`` is returned.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg')

    def _list_images(folder):
        """Return sorted full paths of regular image files directly inside ``folder``."""
        found = []
        for entry in sorted(os.listdir(folder)):
            full_path = os.path.join(folder, entry)
            # isfile() keeps out directories that merely have an image-like name.
            if entry.lower().endswith(image_suffixes) and os.path.isfile(full_path):
                found.append(full_path)
        return found

    # isdir() is already False for a non-existent path, so one check covers both cases.
    if not os.path.isdir(data_path):
        print(f"Warning: Data path does not exist or is not a directory: {data_path}")
        return {}, [], []

    person_classes = sorted(
        d for d in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, d))
    )

    reference_gallery_paths = {}
    query_set = []
    for class_name in person_classes:
        class_path = os.path.join(data_path, class_name)

        # Clean images directly under the class folder form the reference gallery.
        clean_images = _list_images(class_path)
        if clean_images:
            reference_gallery_paths[class_name] = clean_images

        # Distorted images become queries labelled with their parent class.
        distortion_path = os.path.join(class_path, 'distortion')
        if os.path.isdir(distortion_path):
            query_set.extend(
                (img_path, class_name) for img_path in _list_images(distortion_path)
            )

    return reference_gallery_paths, query_set, person_classes

# --- 4. Required Evaluation Function (Copy this function here) ---
def _embed_image(model, image_path, transform, device):
    """Open one image, preprocess it, and return the model output for a batch of one."""
    # The context manager releases the file handle deterministically, even when
    # the conversion raises; convert() returns an independent in-memory image.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert("RGB")
    return model(transform(rgb_image).unsqueeze(0).to(device))


def evaluate(model, ref_gallery_paths, query_set, person_classes, transform, device):
    """
    Evaluate the model by nearest-reference matching on pre-collected file paths.

    For every class, the embeddings of its reference images are averaged into a
    single gallery vector. Each query image is then assigned the class whose
    gallery vector has the highest cosine similarity with the query embedding.

    Args:
        model: Embedding network; it is moved to ``device`` and put in eval mode.
        ref_gallery_paths: Mapping of class name -> list of reference image paths.
        query_set: Iterable of ``(image_path, true_class_name)`` pairs.
        person_classes: Sequence of class names; a name's position is used as
            its integer label for the metrics.
        transform: Callable turning an RGB PIL image into a tensor for the model.
        device: Torch device on which inference runs.

    Returns:
        ``(accuracy, macro_f1)`` where accuracy is a percentage (0-100) and
        macro_f1 is the macro-averaged F1 score, both computed over the queries
        that were processed successfully. ``(0.0, 0.0)`` is returned when there
        are no queries, no usable reference embeddings, or no successful
        predictions.

    Images that fail to load or embed are reported and skipped rather than
    aborting the evaluation.
    """
    print("\n--- Evaluating Model ---")
    model.to(device)
    model.eval() # Set model to evaluation mode

    if not query_set:
        print("No query images found to evaluate. Check your data structure or prepare_evaluation_sets.")
        return 0.0, 0.0

    class_to_idx = {name: i for i, name in enumerate(person_classes)}

    print("Creating reference embedding gallery...")
    avg_reference_embeddings = {}  # insertion-ordered: class name -> mean embedding
    with torch.no_grad():
        for class_name, img_paths in tqdm(ref_gallery_paths.items(), desc="Processing reference images"):
            embeddings = []
            for p in img_paths:
                try:
                    embeddings.append(_embed_image(model, p, transform, device))
                except Exception as e:
                    # Best effort: one unreadable reference must not abort the run.
                    print(f"Error processing reference image {p}: {e}")
            if embeddings:
                # Assumes each embedding is (1, D), so cat -> (N, D) and the mean
                # over dim 0 yields one (D,) vector per class.
                avg_reference_embeddings[class_name] = torch.mean(torch.cat(embeddings), dim=0)
            else:
                print(f"Warning: No valid reference images for class {class_name}. Skipping.")

    if not avg_reference_embeddings:
        print("No valid reference embeddings created. Cannot evaluate.")
        return 0.0, 0.0

    ref_labels = list(avg_reference_embeddings)
    ref_embeds = torch.stack(list(avg_reference_embeddings.values()))

    y_true, y_pred = [], []
    skipped_queries = 0

    print("Matching query images against gallery...")
    with torch.no_grad():
        # The gallery does not change between queries, so normalize it once.
        normalized_refs = F.normalize(ref_embeds)

        for query_path, true_label in tqdm(query_set, desc="Processing query images"):
            try:
                query_embedding = _embed_image(model, query_path, transform, device)

                # With both sides L2-normalized, the dot product is the cosine similarity.
                similarities = torch.matmul(F.normalize(query_embedding), normalized_refs.T)

                # Higher similarity is better, hence argmax.
                best_match_idx = torch.argmax(similarities, dim=1).item()

                # Resolve BOTH labels before touching the result lists so a failed
                # lookup can never leave y_true and y_pred with different lengths.
                true_idx = class_to_idx[true_label]
                pred_idx = class_to_idx[ref_labels[best_match_idx]]
            except Exception as e:
                print(f"Error processing query image {query_path}: {e}")
                skipped_queries += 1
                continue

            y_true.append(true_idx)
            y_pred.append(pred_idx)

    if skipped_queries:
        # Make it visible that the metrics below cover only part of the query set.
        print(f"Warning: {skipped_queries} of {len(query_set)} query images were skipped due to errors.")

    if not y_true: # If no successful queries were processed
        print("No successful query predictions made. Cannot calculate metrics.")
        return 0.0, 0.0

    accuracy = (np.sum(np.array(y_true) == np.array(y_pred)) / len(y_true)) * 100
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    print("Evaluation Complete:")
    print(f"  - Top-1 Accuracy: {accuracy:.2f}%")
    print(f"  - Macro Avg F1-Score: {macro_f1:.4f}")
    return accuracy, macro_f1

# ENTER TEST DATASET PATH IN THE NEXT CELL.

# In [None]:

# --- Main Execution Block for Standalone Evaluation ---
# Flow: validate the test folder -> collect reference/query paths -> load the
# checkpoint into a fresh EmbeddingNet -> run evaluate(). Fatal problems end
# the run with a non-zero exit status via SystemExit instead of the
# interactive-only exit() helper (which also reports success, status 0).
if __name__ == '__main__':
    print("\n--- Starting Standalone Evaluation on a NEW, Separate Test Folder ---")

    # Define the path to your new test folder
    # !!! IMPORTANT: REPLACE THIS WITH THE ACTUAL PATH TO YOUR NEW TEST FOLDER !!!
    NEW_TEST_FOLDER_PATH = "/kaggle/input/comsys/Comys_Hackathon5/Task_B/val" # Example: /path/to/your/new_test_dataset

    if not os.path.exists(NEW_TEST_FOLDER_PATH):
        print(f"Error: The specified NEW_TEST_FOLDER_PATH does not exist: {NEW_TEST_FOLDER_PATH}")
        print("Please update NEW_TEST_FOLDER_PATH to your actual test data location and ensure it's accessible.")
        raise SystemExit(1) # Abort with a failure status if the test path is invalid

    # Prepare evaluation sets for the new test folder
    print(f"\n--- Pre-calculating file paths for the NEW Test Set: {NEW_TEST_FOLDER_PATH} ---")
    new_test_gallery_paths, new_test_query_set, new_test_person_classes = prepare_evaluation_sets(NEW_TEST_FOLDER_PATH)
    print(f"Found {len(new_test_gallery_paths)} reference classes and {len(new_test_query_set)} query images in the NEW test set.")

    if not new_test_query_set:
        print("No query images found in the NEW test folder to evaluate. Ensure 'distortion' subfolders exist if intended for queries.")
    elif not new_test_gallery_paths:
        print("No reference images found in the NEW test folder to create a gallery. Evaluation cannot proceed.")
    else:
        # Load the BEST saved model for evaluation on the new test set
        print("\n--- Loading best model for evaluation ---")

        # Instantiate a fresh model to load weights into.
        eval_feature_extractor = EmbeddingNet(embedding_dim=CONFIG["EMBEDDING_DIM"])

        # Load the state dictionary from the provided path
        try:
            # map_location ensures it loads correctly whether on CPU or GPU.
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code; only load checkpoints from a trusted source (or
            # pass weights_only=True on PyTorch versions that support it).
            eval_feature_extractor.load_state_dict(torch.load(CONFIG["BEST_MODEL_PATH"], map_location=DEVICE))
            print(f"Successfully loaded model from {CONFIG['BEST_MODEL_PATH']}")
        except FileNotFoundError:
            print(f"Error: Best model file not found at {CONFIG['BEST_MODEL_PATH']}. Please ensure the path is correct and the file exists.")
            raise SystemExit(1)
        except Exception as e:
            # Typically a key/shape mismatch between the checkpoint and the model definition.
            print(f"An error occurred while loading the model: {e}")
            raise SystemExit(1)

        eval_feature_extractor.to(DEVICE) # Move the model to the chosen device

        # Get the transformation used by the EfficientNet model
        # This is crucial: EfficientNet expects specific input preprocessing.
        data_transform = EfficientNet_B4_Weights.DEFAULT.transforms()

        # Perform the evaluation using the loaded model
        print(f"\n--- Evaluating on NEW Test Folder: {NEW_TEST_FOLDER_PATH} ---")
        evaluate(eval_feature_extractor, new_test_gallery_paths, new_test_query_set, new_test_person_classes, data_transform, DEVICE)

# Runs at module level (outside the __main__ guard), so it is printed whenever
# this cell/module finishes executing without an early abort.
print("\n--- Standalone Evaluation Process Complete ---")