In [2]:
"""
Validation suite for TaxonomyExpander model with comprehensive test dataset
and metric calculations.
"""

import json
from typing import Dict, List, Tuple, Set
from collections import defaultdict
import numpy as np

# Comprehensive validation dataset with known taxonomic relationships
# Ground-truth fixture for the validator.
#
# Structure: category name -> list of test cases. Each test case is a dict with
#   "parent":   name of the parent taxon passed to the model,
#   "rank":     child rank to expand to; the validator calls the model method
#               named get_<rank> (e.g. "phylum" -> get_phylum),
#   "expected": set of child taxon names counted as correct answers.
# "expected" values are sets so they can be compared with the predicted set
# using set intersection/difference.
VALIDATION_DATASET = {
    # Domain -> Kingdom
    "domain_kingdom": [
        {"parent": "Eukarya", "rank": "kingdom", "expected": {"Animalia", "Plantae", "Fungi", "Protista"}},
        {"parent": "Bacteria", "rank": "kingdom", "expected": {"Bacteria"}},
        {"parent": "Archaea", "rank": "kingdom", "expected": {"Archaea"}},
    ],
    
    # Kingdom -> Phylum (Animalia, Plantae)
    "kingdom_phylum": [
        {"parent": "Animalia", "rank": "phylum", "expected": {
            "Chordata", "Arthropoda", "Mollusca", "Annelida", "Echinodermata",
            "Cnidaria", "Platyhelminthes", "Nematoda", "Porifera"
        }},
        {"parent": "Plantae", "rank": "phylum", "expected": {
            "Tracheophyta", "Bryophyta", "Marchantiophyta"
        }},
    ],
    
    # Phylum -> Class
    "phylum_class": [
        {"parent": "Chordata", "rank": "class", "expected": {
            "Mammalia", "Aves", "Reptilia", "Amphibia", "Actinopterygii", "Chondrichthyes"
        }},
        {"parent": "Arthropoda", "rank": "class", "expected": {
            "Insecta", "Arachnida", "Crustacea", "Myriapoda"
        }},
    ],
    
    # Class -> Order
    "class_order": [
        {"parent": "Mammalia", "rank": "order", "expected": {
            "Primates", "Carnivora", "Rodentia", "Chiroptera", "Cetacea",
            "Artiodactyla", "Perissodactyla", "Proboscidea"
        }},
        {"parent": "Aves", "rank": "order", "expected": {
            "Passeriformes", "Falconiformes", "Strigiformes", "Psittaciformes", "Columbiformes"
        }},
        {"parent": "Insecta", "rank": "order", "expected": {
            "Coleoptera", "Lepidoptera", "Hymenoptera", "Diptera", "Hemiptera"
        }},
    ],
    
    # Order -> Family
    "order_family": [
        {"parent": "Primates", "rank": "family", "expected": {
            "Hominidae", "Cercopithecidae", "Hylobatidae", "Lemuridae", "Callitrichidae"
        }},
        {"parent": "Carnivora", "rank": "family", "expected": {
            "Felidae", "Canidae", "Ursidae", "Mustelidae", "Phocidae"
        }},
        {"parent": "Rodentia", "rank": "family", "expected": {
            "Muridae", "Sciuridae", "Cricetidae", "Castoridae"
        }},
    ],
    
    # Family -> Genus
    "family_genus": [
        {"parent": "Hominidae", "rank": "genus", "expected": {
            "Homo", "Pan", "Gorilla", "Pongo"
        }},
        {"parent": "Felidae", "rank": "genus", "expected": {
            "Panthera", "Felis", "Lynx", "Puma", "Acinonyx"
        }},
        {"parent": "Canidae", "rank": "genus", "expected": {
            "Canis", "Vulpes", "Lycaon"
        }},
    ],
    
    # Genus -> Species
    "genus_species": [
        {"parent": "Homo", "rank": "species", "expected": {
            "Homo sapiens"
        }},
        {"parent": "Canis", "rank": "species", "expected": {
            "Canis lupus", "Canis familiaris", "Canis latrans"
        }},
        {"parent": "Panthera", "rank": "species", "expected": {
            "Panthera leo", "Panthera tigris", "Panthera pardus", "Panthera onca"
        }},
    ],
}


class TaxonomyValidator:
    """Validates taxonomy model predictions against ground truth.

    The model is queried through methods named ``get_<rank>`` (for example
    ``get_phylum``), each called with the parent taxon name and expected to
    return an iterable of child names. Per-query results are stored in
    ``self.results`` as dicts with the keys ``parent``, ``rank``,
    ``predicted``, ``expected``, ``tp``, ``fp``, ``fn`` and ``tn``.
    """

    def __init__(self, model):
        """Store the model under test and start with an empty result list."""
        self.model = model
        self.results = []

    @staticmethod
    def _ratio(numerator: float, denominator: float) -> float:
        """Return numerator / denominator, or 0.0 when the denominator is not positive."""
        return numerator / denominator if denominator > 0 else 0.0

    @classmethod
    def _f1(cls, precision: float, recall: float) -> float:
        """Return the harmonic mean of precision and recall (0.0 if both are 0)."""
        return cls._ratio(2 * precision * recall, precision + recall)

    def validate_single(self, parent: str, rank: str, expected: Set[str]) -> Dict:
        """Validate a single query and return its confusion counts.

        Args:
            parent: Parent taxon name passed to the model.
            rank: Child rank; selects the model method ``get_<rank>``.
            expected: Ground-truth child names. Any iterable is accepted and
                converted to a set.

        Returns:
            A result dict with ``parent``, ``rank``, the ``predicted`` and
            ``expected`` sets, and the counts ``tp``, ``fp``, ``fn``, ``tn``.
            If the model has no ``get_<rank>`` method, or the query raises,
            the prediction is treated as empty (every expected item is a
            false negative).
        """
        # Coerce so the set algebra below works for lists/tuples as well.
        expected = set(expected)
        predicted: Set[str] = set()

        method_name = f"get_{rank}"
        if hasattr(self.model, method_name):
            method = getattr(self.model, method_name)
            try:
                predicted = set(method(parent))
            except Exception as e:
                # Best effort: a failing query is reported and scored as
                # "nothing predicted" instead of aborting the whole run.
                print(f"Error querying {parent} -> {rank}: {e}")
                predicted = set()

        return {
            "parent": parent,
            "rank": rank,
            "predicted": predicted,
            "expected": expected,
            "tp": len(predicted & expected),  # True Positives
            "fp": len(predicted - expected),  # False Positives
            "fn": len(expected - predicted),  # False Negatives
            "tn": 0  # TN not well-defined in this context
        }

    def run_validation(self) -> List[Dict]:
        """Run full validation on the dataset and print per-query results.

        Results from any previous run are discarded first, so calling this
        method repeatedly does not duplicate queries in the metrics.

        Returns:
            The list of per-query result dicts (also kept in ``self.results``).
        """
        self.results = []

        print("Starting validation...")
        print("=" * 80)

        for category, test_cases in VALIDATION_DATASET.items():
            print(f"\nTesting: {category}")
            print("-" * 80)

            for test_case in test_cases:
                result = self.validate_single(
                    test_case["parent"],
                    test_case["rank"],
                    test_case["expected"]
                )
                self.results.append(result)

                # Print result
                print(f"\nParent: {result['parent']} -> {result['rank']}")
                print(f"Expected: {len(result['expected'])} items")
                print(f"Predicted: {len(result['predicted'])} items")
                print(f"TP: {result['tp']}, FP: {result['fp']}, FN: {result['fn']}")

                if result['predicted']:
                    print(f"Sample predictions: {list(result['predicted'])[:3]}")

        return self.results

    def calculate_metrics(self) -> Dict:
        """Calculate all evaluation metrics from the stored results.

        Returns:
            An empty dict when there are no results. Otherwise a dict with
            micro-averaged metrics (computed from the summed counts),
            macro-averaged metrics (mean of the per-query values),
            ``accuracy``, ``total_queries`` and the summed ``total_tp``,
            ``total_fp``, ``total_fn``. All ratio values are plain floats.
        """
        if not self.results:
            return {}

        # Aggregate counts
        total_tp = sum(r['tp'] for r in self.results)
        total_fp = sum(r['fp'] for r in self.results)
        total_fn = sum(r['fn'] for r in self.results)

        # Micro metrics (aggregate all predictions)
        precision_micro = self._ratio(total_tp, total_tp + total_fp)
        recall_micro = self._ratio(total_tp, total_tp + total_fn)
        f1_micro = self._f1(precision_micro, recall_micro)

        # Macro metrics (average per query)
        precisions = [self._ratio(r['tp'], r['tp'] + r['fp']) for r in self.results]
        recalls = [self._ratio(r['tp'], r['tp'] + r['fn']) for r in self.results]
        f1s = [self._f1(p, rec) for p, rec in zip(precisions, recalls)]

        # Convert numpy scalars to plain floats so every metric has one type.
        precision_macro = float(np.mean(precisions))
        recall_macro = float(np.mean(recalls))
        f1_macro = float(np.mean(f1s))

        # "Accuracy" here is true positives divided by the larger of the total
        # predicted and total expected item counts; true negatives are not
        # defined for this task, so this is not the classical accuracy.
        total_predictions = sum(len(r['predicted']) for r in self.results)
        total_expected = sum(len(r['expected']) for r in self.results)
        accuracy = self._ratio(total_tp, max(total_predictions, total_expected))

        return {
            "accuracy": accuracy,
            "precision_macro": precision_macro,
            "recall_macro": recall_macro,
            "f1_macro": f1_macro,
            "precision_micro": precision_micro,
            "recall_micro": recall_micro,
            "f1_micro": f1_micro,
            "total_queries": len(self.results),
            "total_tp": total_tp,
            "total_fp": total_fp,
            "total_fn": total_fn,
        }

    def print_metrics(self):
        """Print a formatted metrics report and return the metrics dict.

        When no results have been collected, a short notice is printed and an
        empty dict is returned instead of failing on missing keys.
        """
        metrics = self.calculate_metrics()

        print("\n" + "=" * 80)
        print("VALIDATION METRICS")
        print("=" * 80)

        if not metrics:
            print("\nNo validation results available; run run_validation() first.")
            print("=" * 80)
            return metrics

        print(f"\nTotal Queries: {metrics['total_queries']}")
        print(f"Total True Positives: {metrics['total_tp']}")
        print(f"Total False Positives: {metrics['total_fp']}")
        print(f"Total False Negatives: {metrics['total_fn']}")
        print(f"\n{'Metric':<20} {'Value':<10}")
        print("-" * 80)
        print(f"{'Accuracy':<20} {metrics['accuracy']:.4f}")
        print(f"\n{'Macro Metrics:':<20}")
        print(f"{'  Precision':<20} {metrics['precision_macro']:.4f}")
        print(f"{'  Recall':<20} {metrics['recall_macro']:.4f}")
        print(f"{'  F1-Score':<20} {metrics['f1_macro']:.4f}")
        print(f"\n{'Micro Metrics:':<20}")
        print(f"{'  Precision':<20} {metrics['precision_micro']:.4f}")
        print(f"{'  Recall':<20} {metrics['recall_micro']:.4f}")
        print(f"{'  F1-Score':<20} {metrics['f1_micro']:.4f}")
        print("=" * 80)

        return metrics


# Usage example
if __name__ == "__main__":
    # Project-local model; imported only when this file is run as a script
    from taxonomy_expander import TaxonomyExpander

    # Wire the model under test into a validator
    model = TaxonomyExpander()
    validator = TaxonomyValidator(model)

    # Evaluate every test case, then print and capture the aggregate report
    validator.run_validation()
    metrics = validator.print_metrics()

    # Persist the aggregate metrics plus per-query counts; the predicted and
    # expected sets are stored as sizes only
    with open('validation_results.json', 'w') as out_file:
        per_query = [
            dict(
                parent=entry['parent'],
                rank=entry['rank'],
                tp=entry['tp'],
                fp=entry['fp'],
                fn=entry['fn'],
                predicted_count=len(entry['predicted']),
                expected_count=len(entry['expected']),
            )
            for entry in validator.results
        ]
        json.dump(
            dict(metrics=metrics, detailed_results=per_query),
            out_file,
            indent=2,
        )

    print("\nResults saved to validation_results.json")

Starting validation...

Testing: domain_kingdom
--------------------------------------------------------------------------------

Parent: Eukarya -> kingdom
Expected: 4 items
Predicted: 4 items
TP: 4, FP: 0, FN: 0
Sample predictions: ['Fungi', 'Plantae', 'Animalia']

Parent: Bacteria -> kingdom
Expected: 1 items
Predicted: 1 items
TP: 1, FP: 0, FN: 0
Sample predictions: ['Bacteria']

Parent: Archaea -> kingdom
Expected: 1 items
Predicted: 1 items
TP: 1, FP: 0, FN: 0
Sample predictions: ['Archaea']

Testing: kingdom_phylum
--------------------------------------------------------------------------------

Parent: Animalia -> phylum
Expected: 9 items
Predicted: 10 items
TP: 1, FP: 9, FN: 8
Sample predictions: ['Monoblastozoa', 'Porifera', 'Vendozoa']

Parent: Plantae -> phylum
Expected: 3 items
Predicted: 5 items
TP: 0, FP: 5, FN: 3
Sample predictions: ['sporae dispersae', 'Euthallophyta', 'Hepatophyta']

Testing: phylum_class
---------------------------------------------------------------