In [34]:
import json
import re
import gc
from typing import List, Dict, Set, Tuple, Any, Optional, Union
from collections import defaultdict, Counter
import time
import os

import numpy as np
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from rapidfuzz import fuzz
import spacy
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import precision_score, recall_score, f1_score, matthews_corrcoef
import matplotlib.pyplot as plt
from functools import lru_cache
from datetime import datetime
import math

In [35]:
# Compute device for Hugging Face pipelines: GPU index 0 when CUDA is
# available, otherwise -1 (the pipeline convention for CPU).
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# spaCy English model; downloaded on first run if it is not installed.
try:
    nlp = spacy.load("en_core_web_lg")
except OSError:
    print("Downloading SpaCy model...")
    spacy.cli.download("en_core_web_lg")
    nlp = spacy.load("en_core_web_lg")

# Shared sentence-embedding model used by the classes and helpers below.
# (An alternative checkpoint, "nasa-impact/nasa-ibm-st.38m", was tried earlier.)
st_model = SentenceTransformer("nasa-impact/nasa-smd-ibm-st-v2")
st_model.eval()
st_model.to("cuda" if device == 0 else "cpu")


# Memory management
def optimize_memory():
    """Release cached CUDA memory (when a GPU is present), then run Python's garbage collector."""
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
    gc.collect()

In [36]:
class ScienceClassifier:
    """Science-document classifier for research area, science keywords and division.

    Obtain the shared instance with ``ScienceClassifier.get_instance()``. The
    first call loads one Hugging Face text-classification pipeline per task in
    ``self.models``. Results of ``classify`` are memoized per publication in
    ``classification_cache``.
    """

    _instance = None

    # Generic "LABEL_<n>" names are produced by models whose config lacks a
    # readable id2label mapping; they are translated through the task's label_map.
    _GENERIC_LABEL = re.compile(r"^LABEL_(\d+)$")

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it (and loading models) on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        # Per task: the Hugging Face repo id and, where known, index -> label names.
        self.models = {
            "research_area": {
                "name": "arminmehrabian/nasa-impact-nasa-smd-ibm-st-v2-classification-finetuned",
                "label_map": {
                    0: "Agriculture",
                    1: "Air Quality",
                    2: "Atmospheric/Ocean Indicators",
                    3: "Cryospheric Indicators",
                    4: "Droughts",
                    5: "Earthquakes",
                    6: "Ecosystems",
                    7: "Energy Production/Use",
                    8: "Environmental Impacts",
                    9: "Floods",
                    10: "Greenhouse Gases",
                    11: "Habitat Conversion/Fragmentation",
                    12: "Heat",
                    13: "Land Surface/Agriculture Indicators",
                    14: "Public Health",
                    15: "Severe Storms",
                    16: "Sun-Earth Interactions",
                    17: "Validation",
                    18: "Volcanic Eruptions",
                    19: "Water Quality",
                    20: "Wildfires",
                },
            },
            "science_keywords": {"name": "nasa-impact/science-keyword-classification"},
            "division": {
                "name": "nasa-impact/division-classifier",
                "label_map": {
                    0: "Astrophysics",
                    1: "Biological and Physical Sciences",
                    2: "Earth Science",
                    3: "Heliophysics",
                    4: "Planetary Science",
                },
            },
        }
        self.classification_cache = {}
        self._load_models()

    def _load_models(self):
        """Create a text-classification pipeline for every task in ``self.models``.

        A task whose model fails to load is reported and left out of
        ``self.classifiers``; ``classify`` then skips that task.
        """
        self.classifiers = {}
        for task, config in self.models.items():
            try:
                print(f"Loading {task} model from {config['name']}...")
                with torch.no_grad():
                    tokenizer = AutoTokenizer.from_pretrained(config["name"])
                    model = AutoModelForSequenceClassification.from_pretrained(
                        config["name"]
                    )
                    model.eval()

                self.classifiers[task] = {
                    "pipe": pipeline(
                        "text-classification",
                        model=model,
                        tokenizer=tokenizer,
                        device=device,
                        batch_size=32,
                    ),
                    "config": config,
                }
                print(f"Successfully loaded {task} model")
            except Exception as e:
                print(f"Failed to load {task} model: {str(e)}")

    @staticmethod
    def _extract_title(publication: Dict) -> str:
        """Return the title as a string.

        A list-valued title yields its first element. Missing, empty or
        non-string titles yield ``""``.
        """
        title = publication.get("title")
        if isinstance(title, list):
            title = title[0] if title else ""
        return title if isinstance(title, str) else ""

    def _prepare_text(self, publication: Dict) -> str:
        """Join title, abstract and keywords (space-separated) into classifier input.

        Missing or ``None`` fields contribute empty strings. A single string
        stored under ``keywords`` is treated as one keyword instead of being
        split into characters.
        """
        title = self._extract_title(publication)
        abstract = publication.get("abstract") or ""
        if not isinstance(abstract, str):
            abstract = str(abstract)
        keywords = publication.get("keywords") or []
        if isinstance(keywords, str):
            keywords = [keywords]
        keyword_text = " ".join(str(keyword) for keyword in keywords)
        return " ".join([title, abstract, keyword_text])

    def _get_cache_key(self, publication: Dict) -> str:
        """Return the memoization key for a publication.

        The key is the DOI when present, otherwise the title. If there is
        neither, the full classification text is used, so distinct untitled
        records do not all share a single ``"title:"`` slot.
        """
        doi = publication.get("DOI")
        if doi:
            return f"doi:{doi}"
        title = self._extract_title(publication)
        if title:
            return f"title:{title}"
        return f"text:{self._prepare_text(publication)}"

    def _map_label(self, task: str, label: str) -> str:
        """Translate a generic ``LABEL_<n>`` through the task's ``label_map``.

        Readable labels, tasks without a ``label_map``, and indices absent from
        the map are returned unchanged (no exception).
        """
        label_map = self.models.get(task, {}).get("label_map")
        match = self._GENERIC_LABEL.match(label) if label_map else None
        if match is None:
            return label
        return label_map.get(int(match.group(1)), label)

    @torch.inference_mode()
    def classify(self, publication: Dict) -> Dict:
        """Classify a publication, memoizing the result.

        Returns:
            A dict with three keys:

            * ``research_areas``: the top-3 ``{"label", "score"}`` entries.
            * ``science_keywords``: ``{"label", "score"}`` entries from the top
              10, kept when the score is above 0.35 and the label has at least
              4 characters.
            * ``division``: a ``{"label", "score"}`` dict, or ``None``.

            A task that is unavailable or raises keeps its default value (an
            empty list or ``None``).
        """
        cache_key = self._get_cache_key(publication)
        if cache_key in self.classification_cache:
            return self.classification_cache[cache_key]

        text = self._prepare_text(publication)
        results = {"research_areas": [], "science_keywords": [], "division": None}

        # Research Area Classification
        if self.classifiers.get("research_area"):
            try:
                res_area = self.classifiers["research_area"]["pipe"](
                    text, top_k=3, truncation=True, max_length=512
                )
                results["research_areas"] = [
                    {
                        "label": self._map_label("research_area", pred["label"]),
                        "score": float(pred["score"]),
                    }
                    for pred in res_area
                ]
            except Exception as e:
                print(f"Research area classification failed: {str(e)}")

        # Science Keywords Classification
        if self.classifiers.get("science_keywords"):
            try:
                science_keywords = self.classifiers["science_keywords"]["pipe"](
                    text, truncation=True, max_length=512, top_k=10
                )
                for pred in science_keywords:
                    if pred["score"] > 0.35 and len(pred["label"]) >= 4:
                        results["science_keywords"].append(
                            {"label": pred["label"], "score": float(pred["score"])}
                        )
            except Exception as e:
                print(f"Science keyword classification failed: {str(e)}")

        # Division Classification (labels mapped like research areas, so a
        # generic "LABEL_<n>" output becomes e.g. "Earth Science").
        if self.classifiers.get("division"):
            try:
                division_result = self.classifiers["division"]["pipe"](
                    text, top_k=1, truncation=True, max_length=512
                )
                if division_result:
                    division = division_result[0]
                    if "score" in division:
                        results["division"] = {
                            "label": self._map_label("division", division["label"]),
                            "score": float(division["score"]),
                        }
            except Exception as e:
                print(f"Division classification failed: {str(e)}")

        self.classification_cache[cache_key] = results
        return results


class ModelContextManager:
    """Scores how well a publication's context matches each curated model.

    The manager is built from curated publications that carry a ``"model"``
    field. Each model gets three things:

    * an embedding profile,
    * a set of characteristic terms,
    * a list of distinctive TF-IDF terms.

    ``get_context_scores`` blends semantic similarity, term overlap and TF-IDF
    term matches into one score per model.
    """

    # Lower-cased alphabetic words of 4+ letters: the unit for all term counts.
    _WORD_PATTERN = re.compile(r"\b[a-z]{4,}\b")

    def __init__(self, curated_publications: List[Dict]):
        self.st_model = st_model
        self.model_profiles = self._build_model_profiles(curated_publications)
        self.profile_cache = {}
        self.corpus_term_frequencies = self._build_corpus_term_frequencies(
            curated_publications
        )
        self.model_tfidf_terms = self._calculate_model_tfidf_terms(curated_publications)

    def _build_model_profiles(self, publications: List[Dict]) -> Dict[str, Dict]:
        """Build an embedding/term profile per model from its curated publications.

        Returns:
            ``{model: {"embedding": np.ndarray, "terms": set of up to 25
            terms, "text_count": int}}``.
        """
        model_texts = defaultdict(list)
        model_terms = defaultdict(set)

        for pub in publications:
            model = pub.get("model")
            if model:
                prepared_text = ScienceClassifier.get_instance()._prepare_text(pub)
                model_texts[model].append(prepared_text)
                model_terms[model].update(self._extract_key_terms(prepared_text))

        final_profiles = {}
        for model, texts in model_texts.items():
            # NOTE(review): sentence encoders typically truncate input to their
            # max sequence length, so this joined text's embedding is likely
            # dominated by the first publication(s); consider averaging
            # per-text embeddings instead.
            aggregated_text = " ".join(texts[:100])

            with torch.inference_mode():
                embedding = self.st_model.encode(
                    aggregated_text, convert_to_tensor=True
                )

            # Keep the 25 longest terms. Ties are broken alphabetically so the
            # selection does not depend on hash-randomized set iteration order.
            longest_terms = sorted(
                model_terms[model], key=lambda term: (-len(term), term)
            )[:25]
            final_profiles[model] = {
                "embedding": embedding.cpu().numpy(),
                "terms": set(longest_terms),
                "text_count": len(texts),
            }

        optimize_memory()
        return final_profiles

    def _build_corpus_term_frequencies(
        self, publications: List[Dict]
    ) -> Dict[str, int]:
        """Count every 4+ letter word across all publications (model or not)."""
        corpus_terms = Counter()
        for pub in publications:
            prepared_text = ScienceClassifier.get_instance()._prepare_text(pub)
            corpus_terms.update(self._WORD_PATTERN.findall(prepared_text.lower()))
        return corpus_terms

    def _calculate_model_tfidf_terms(
        self, publications: List[Dict]
    ) -> Dict[str, List[Tuple[str, float]]]:
        """Collect per-model word counts and rank each model's distinctive terms."""
        model_term_counts = defaultdict(Counter)
        for pub in publications:
            model = pub.get("model")
            if model:
                prepared_text = ScienceClassifier.get_instance()._prepare_text(pub)
                model_term_counts[model].update(
                    self._WORD_PATTERN.findall(prepared_text.lower())
                )
        return self._tfidf_from_counts(model_term_counts)

    @staticmethod
    def _tfidf_from_counts(
        model_term_counts: Dict[str, Counter], min_count: int = 3, top_n: int = 30
    ) -> Dict[str, List[Tuple[str, float]]]:
        """Rank each model's terms by TF-IDF, treating each model's pooled text as one document.

        Both the numerator (number of models) and the denominator (number of
        models using the term) count models. Mixing document counts with
        model counts would give every term a large positive IDF, which lets
        ubiquitous words dominate.

        Args:
            model_term_counts: model -> Counter of term occurrences.
            min_count: Minimum occurrences of a term within a model to be kept.
            top_n: Number of terms kept per model.

        Returns:
            model -> ``[(term, score), ...]``, sorted by descending score.
        """
        n_models = len(model_term_counts)
        # Number of models whose pooled text contains each term.
        model_df = Counter()
        for term_counts in model_term_counts.values():
            model_df.update(term_counts.keys())

        ranked = {}
        for model, term_counts in model_term_counts.items():
            total_terms = sum(term_counts.values())
            scores = {}
            for term, count in term_counts.items():
                if count < min_count or len(term) < 4:
                    continue
                tf = count / total_terms
                # Smoothed IDF (+1 on both sides); a term used by every model scores 0.
                idf = np.log((n_models + 1) / (model_df[term] + 1))
                scores[term] = tf * idf
            ranked[model] = sorted(
                scores.items(), key=lambda item: item[1], reverse=True
            )[:top_n]
        return ranked

    def _extract_key_terms(self, text: str) -> Set[str]:
        """Return the 4+ letter words (lower-cased) occurring at least twice in ``text``."""
        word_counts = Counter(self._WORD_PATTERN.findall(text.lower()))
        return {word for word, count in word_counts.items() if count >= 2}

    def _get_pub_profile(self, prepared_text: str) -> Dict:
        """Return the embedding and key terms of ``prepared_text``, memoized in ``profile_cache``.

        The profile is memoized per instance in ``profile_cache``.
        ``functools.lru_cache`` is deliberately not used: on a method it
        holds strong references to ``self`` in a class-wide cache.
        """
        cached = self.profile_cache.get(prepared_text)
        if cached is not None:
            return cached

        with torch.inference_mode():
            embedding = self.st_model.encode(prepared_text)

        profile = {
            "embedding": embedding,
            "terms": self._extract_key_terms(prepared_text),
        }
        self.profile_cache[prepared_text] = profile
        return profile

    def get_model_specific_terms(self) -> Dict[str, List[str]]:
        """Return each model's distinctive TF-IDF terms (scores dropped)."""
        return {
            model: [term for term, _score in tfidf_terms]
            for model, tfidf_terms in self.model_tfidf_terms.items()
        }

    def get_context_scores(self, publication: Dict) -> Dict[str, float]:
        """Score a publication against every model profile.

        Each score is ``0.5 * cosine similarity + 0.3 * Jaccard term overlap +
        0.2 * min(1, TF-IDF term matches / 5)``.
        """
        prepared_text = ScienceClassifier.get_instance()._prepare_text(publication)
        pub_profile = self._get_pub_profile(prepared_text)
        pub_terms = pub_profile["terms"]

        scores = {}
        for model, model_profile in self.model_profiles.items():
            model_terms = model_profile["terms"]

            # Jaccard overlap of key terms.
            if not pub_terms or not model_terms:
                term_overlap = 0.0
            else:
                intersection = len(pub_terms.intersection(model_terms))
                union = len(pub_terms.union(model_terms))
                term_overlap = intersection / union if union > 0 else 0.0

            # Matches against the model's distinctive TF-IDF terms (saturates at 5).
            tfidf_term_set = {term for term, _score in self.model_tfidf_terms.get(model, [])}
            tfidf_match_count = len(pub_terms.intersection(tfidf_term_set))
            tfidf_match_score = (
                min(1.0, tfidf_match_count / 5) if tfidf_term_set else 0.0
            )

            semantic_sim = cosine_similarity(
                [pub_profile["embedding"]], [model_profile["embedding"]]
            )[0][0]

            scores[model] = (
                0.5 * semantic_sim + 0.3 * term_overlap + 0.2 * tfidf_match_score
            )

        return scores


class RelevanceRanker:
    """Rank model descriptions by relevance to a query with a pairwise ranker.

    Each (query, model description) pair is scored by
    ``nasa-impact/nasa-smd-ibm-ranker``. Successful results are memoized in
    ``result_cache``.
    """

    def __init__(self, model_descriptions: Dict[str, str]):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_descriptions = model_descriptions
        self.description_cache = {}
        self.result_cache = {}

        # Load model
        self.tokenizer = AutoTokenizer.from_pretrained(
            "nasa-impact/nasa-smd-ibm-ranker"
        )

        with torch.inference_mode():
            self.model = AutoModelForSequenceClassification.from_pretrained(
                "nasa-impact/nasa-smd-ibm-ranker"
            ).to(self.device)

            if self.device.type != "cuda":
                self.model = self.model.float()

        self.model.eval()

    def _safe_prepare_model_text(self, model_id: str) -> str:
        """Return the model's description, cut to 380 chars with whitespace runs collapsed (memoized)."""
        if model_id not in self.description_cache:
            desc = self.model_descriptions.get(model_id, "")[:380]
            self.description_cache[model_id] = re.sub(r"\s+", " ", desc)
        return self.description_cache[model_id]

    def _score_pairs(self, query: str, texts: List[str]) -> List[float]:
        """Return one relevance probability per ``(query, text)`` pair.

        A head with two or more logits is read as the softmax probability of
        class 1. A single-logit head is passed through a sigmoid, instead of
        failing on ``[:, 1]``.
        """
        inputs = self.tokenizer(
            [query] * len(texts),
            texts,
            padding=True,  # equivalent to "longest" for HF tokenizers
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)
        logits = self.model(**inputs).logits
        if logits.shape[-1] == 1:
            probabilities = torch.sigmoid(logits[:, 0])
        else:
            probabilities = F.softmax(logits, dim=-1)[:, 1]
        return probabilities.cpu().numpy().tolist()

    @torch.inference_mode()
    def batch_rank(self, query: str, model_ids: List[str]) -> Dict[str, float]:
        """Score each model id's description against ``query``.

        A ``RuntimeError`` on a non-CPU device triggers a single retry after
        moving the model to CPU. A ``RuntimeError`` on CPU, or any other
        error, yields ``0.0`` for every id; that result is not cached.

        Returns:
            ``{model_id: score}``. The result is empty when ``model_ids`` is empty.
        """
        if not model_ids:
            return {}

        cache_key = f"{hash(query)}_{hash(tuple(sorted(model_ids)))}"
        if cache_key in self.result_cache:
            return self.result_cache[cache_key]

        truncated_query = query[:400]
        batch_texts = [self._safe_prepare_model_text(mid) for mid in model_ids]

        # At most two passes: the original device, then CPU after a RuntimeError.
        # Already being on CPU ends the loop, so it can no longer recurse forever.
        while True:
            try:
                scores = self._score_pairs(truncated_query, batch_texts)
            except RuntimeError as e:
                if self.device.type == "cpu":
                    print(f"Ranker failed: {str(e)}")
                    return {mid: 0.0 for mid in model_ids}
                print(f"Ranker fallback to CPU: {str(e)}")
                self.device = torch.device("cpu")
                self.model = self.model.to(self.device).float()
                continue
            except Exception as e:
                print(f"Ranker failed: {str(e)}")
                return {mid: 0.0 for mid in model_ids}

            result = dict(zip(model_ids, scores))
            self.result_cache[cache_key] = result
            return result

In [37]:
# Compiled regex patterns for efficiency
HTML_TAG_PATTERN = re.compile(r"<[^>]+>")
HYPHEN_UNDERSCORE_PATTERN = re.compile(r"[-_]")
SPECIAL_CHAR_PATTERN = re.compile(r"[^a-zA-Z0-9]")


def preprocess_text(text: str) -> str:
    """Clean and normalize text for matching.

    Each HTML tag, hyphen/underscore and every other non-alphanumeric
    character is replaced by one space (runs are not collapsed). The result is
    then lower-cased. Falsy input (``None`` or ``""``) yields ``""``.
    """
    if not text:
        return ""
    cleaned = text
    for pattern in (HTML_TAG_PATTERN, HYPHEN_UNDERSCORE_PATTERN, SPECIAL_CHAR_PATTERN):
        cleaned = pattern.sub(" ", cleaned)
    return cleaned.lower()


# DOI normalization cache, keyed by the raw (un-normalized) input string.
_doi_cache = {}


def normalize_doi(doi: str) -> str:
    """Normalize a DOI string for comparison.

    The input is lower-cased, and ``http(s)://`` and a ``dx.doi.org/`` or
    ``doi.org/`` resolver prefix are removed. Commas are mapped to dots, and
    surrounding slashes and whitespace are stripped. Results are memoized in
    ``_doi_cache``.

    Args:
        doi: Raw DOI or DOI URL; falsy values yield ``""``.

    Returns:
        The normalized DOI string.
    """
    if not doi:
        return ""

    cached = _doi_cache.get(doi)
    if cached is not None:
        return cached

    normalized = doi.lower().replace("https://", "").replace("http://", "")
    # Remove the longer resolver host first; stripping only "doi.org/" from
    # "dx.doi.org/10.x" would leave a stray "dx." prefix.
    for resolver in ("dx.doi.org/", "doi.org/"):
        normalized = normalized.replace(resolver, "")
    normalized = normalized.replace(",", ".")
    result = normalized.strip("/ \n\r\t")

    # Key on the raw input so the lookup above hits on repeated calls.
    _doi_cache[doi] = result
    return result


# Fuzzy matching cache: "<hash(text)>_<keyword>_<threshold>" -> (matched, score).
_fuzzy_match_cache = {}


def fuzzy_keyword_match(
    text: str, keyword: str, threshold: float = 90.0
) -> Tuple[bool, float]:
    """Decide whether ``keyword`` occurs in ``text`` exactly or approximately.

    Strategy:
      1. A case-insensitive substring hit returns ``(True, 100.0)``.
      2. For texts shorter than 1000 characters, ``fuzz.token_sort_ratio`` of
         the whole text against the keyword is used.
      3. For longer texts, the best ``token_sort_ratio`` is taken over spaCy
         noun chunks, entities and content tokens from the first 2000
         characters.

    Args:
        text: Text to search.
        keyword: Keyword to look for.
        threshold: Minimum score (0-100) counted as a match.

    Returns:
        ``(matched, score)``. An empty or ``None`` text or keyword gives
        ``(False, 0.0)``.
    """
    # An empty keyword is trivially a substring of every text; report no match
    # rather than a perfect one. None text would otherwise crash on .lower().
    if not text or not keyword:
        return (False, 0.0)

    cache_key = f"{hash(text)}_{keyword}_{threshold}"
    cached = _fuzzy_match_cache.get(cache_key)
    if cached is not None:
        return cached

    text_lower = text.lower()
    keyword_lower = keyword.lower()

    if keyword_lower in text_lower:
        result = (True, 100.0)
    elif len(text) < 1000:
        score = fuzz.token_sort_ratio(text_lower, keyword_lower)
        result = (score >= threshold, score)
    else:
        # Long text: compare the keyword against short candidate spans.
        doc = nlp(text[:2000])
        chunks = [chunk.text.lower() for chunk in doc.noun_chunks]
        entities = [ent.text.lower() for ent in doc.ents]
        tokens = [
            token.text.lower()
            for token in doc
            if token.is_alpha and not token.is_stop and len(token.text) > 3
        ]
        all_spans = chunks[:50] + entities[:50] + tokens[:100]

        best_score = 0.0
        for span in all_spans:
            if len(span) < 3:
                continue
            best_score = max(best_score, fuzz.token_sort_ratio(span, keyword_lower))
        result = (best_score >= threshold, best_score)

    _fuzzy_match_cache[cache_key] = result
    return result


# Publication date extraction function
def extract_publication_date(publication: Dict) -> Optional[str]:
    """Extract a publication date as ``"YYYY-MM"``.

    Date fields have the form ``{"date-parts": [[year, month, day]]}``. They
    are consulted in this order:

      1. A full year and month from ``published``, ``published-online``,
         ``published-print`` or ``issued``.
      2. A year only from those same fields; the month defaults to ``01``.
      3. A top-level ``year`` value (int or numeric string); the month is ``01``.
      4. As a last resort, a full year and month from ``created`` or
         ``indexed``.

    In Crossref metadata, ``created`` and ``indexed`` record when the DOI was
    registered or last indexed, not when the work appeared. They therefore
    rank below every publication-date source. Missing or non-numeric date
    parts (e.g. ``[[null]]``) make that field be skipped instead of aborting
    the whole lookup.

    Returns:
        ``"YYYY-MM"``, or ``None`` if no usable date is found.
    """
    if not isinstance(publication, dict):
        return None

    publication_fields = ("published", "published-online", "published-print", "issued")
    record_fields = ("created", "indexed")

    def _parts(field: str) -> List[int]:
        # Leading integer date parts of publication[field]; [] if unusable.
        entry = publication.get(field)
        if not isinstance(entry, dict):
            return []
        date_parts = entry.get("date-parts")
        if not isinstance(date_parts, (list, tuple)) or not date_parts:
            return []
        first = date_parts[0]
        if not isinstance(first, (list, tuple)):
            return []
        values = []
        for part in first:
            try:
                values.append(int(part))
            except (TypeError, ValueError):
                break
        return values

    for field in publication_fields:
        parts = _parts(field)
        if len(parts) >= 2:
            return f"{parts[0]:04d}-{parts[1]:02d}"

    for field in publication_fields:
        parts = _parts(field)
        if parts:
            return f"{parts[0]:04d}-01"  # Default to January if only year

    year = publication.get("year")
    if year is not None:
        try:
            return f"{int(year):04d}-01"  # Default to January
        except (TypeError, ValueError):
            pass

    for field in record_fields:
        parts = _parts(field)
        if len(parts) >= 2:
            return f"{parts[0]:04d}-{parts[1]:02d}"

    return None


# Publication characteristics cache, keyed by a hash of the publication's full JSON.
_characteristics_cache = {}


def extract_publication_characteristics(publication: Dict) -> Dict[str, Any]:
    """Compute simple textual and metadata features of a publication (memoized).

    Always present:

    * ``title_length``, ``abstract_length`` and ``total_length`` (word counts);
    * ``keyword_count``.

    When the abstract is longer than 10 and shorter than 10000 characters, its
    first 2000 characters are parsed with spaCy to add:

    * POS ratios: ``noun_ratio``, ``verb_ratio``, ``adj_ratio``;
    * entity counts: ``entity_count``, ``org_count``, ``person_count``,
      ``date_count``;
    * sentence stats: ``avg_sentence_length``, ``sentence_count``.

    ``year`` and ``author_count`` are added when those keys are present.
    """
    # Hash the whole serialized record; a truncated prefix could make distinct
    # publications share a key. default=str tolerates non-JSON values.
    cache_key = str(hash(json.dumps(publication, sort_keys=True, default=str)))
    if cache_key in _characteristics_cache:
        return _characteristics_cache[cache_key]

    characteristics = {}

    # Title may be a list (first element used), a string, or missing/None/[].
    title = publication.get("title")
    if isinstance(title, list):
        title = title[0] if title else ""
    if not isinstance(title, str):
        title = ""

    abstract = publication.get("abstract") or ""
    if not isinstance(abstract, str):
        abstract = str(abstract)

    # Basic metrics
    characteristics["title_length"] = len(title.split())
    characteristics["abstract_length"] = len(abstract.split())
    characteristics["total_length"] = (
        characteristics["title_length"] + characteristics["abstract_length"]
    )

    # Keywords metrics (a bare string counts as one keyword; None as none)
    keywords = publication.get("keywords") or []
    characteristics["keyword_count"] = 1 if isinstance(keywords, str) else len(keywords)

    # Process with spaCy if abstract is present
    if abstract and len(abstract) > 10 and len(abstract) < 10000:
        doc = nlp(abstract[:2000])

        # Part-of-speech distributions
        pos_counts = Counter([token.pos_ for token in doc])
        doc_len = len(doc) or 1

        characteristics["noun_ratio"] = pos_counts.get("NOUN", 0) / doc_len
        characteristics["verb_ratio"] = pos_counts.get("VERB", 0) / doc_len
        characteristics["adj_ratio"] = pos_counts.get("ADJ", 0) / doc_len

        # Named entity analysis
        entity_counter = Counter(ent.label_ for ent in doc.ents)
        characteristics["entity_count"] = len(doc.ents)
        characteristics["org_count"] = entity_counter.get("ORG", 0)
        characteristics["person_count"] = entity_counter.get("PERSON", 0)
        characteristics["date_count"] = entity_counter.get("DATE", 0)

        # Readability metrics
        sentences = list(doc.sents)
        if sentences:
            characteristics["avg_sentence_length"] = np.mean([len(sent) for sent in sentences])
            characteristics["sentence_count"] = len(sentences)
        else:
            characteristics["avg_sentence_length"] = 0
            characteristics["sentence_count"] = 0

    # Publication year
    if "year" in publication:
        characteristics["year"] = publication.get("year")

    # Author metrics
    if "authors" in publication:
        authors = publication.get("authors", [])
        if isinstance(authors, list):
            characteristics["author_count"] = len(authors)
        else:
            characteristics["author_count"] = 1 if authors else 0

    _characteristics_cache[cache_key] = characteristics
    return characteristics

In [38]:
_model_embeddings_cache = {}
_curated_models_cache = {}
_model_keywords_cache = {}
_model_descriptions_cache = {}


def initialize_model_embeddings(
    model_descriptions: Dict[str, str],
    encoder: Optional[Any] = None,
) -> Dict[str, np.ndarray]:
    """Embed model descriptions in batches of 32, reusing cached embeddings by model id.

    Only ids missing from ``_model_embeddings_cache`` are encoded. A cache
    whose size merely equals ``len(model_descriptions)`` is not treated as a
    hit, since it could hold a different set of ids.

    Args:
        model_descriptions: Mapping of model id -> description text.
        encoder: Object with a SentenceTransformer-style
            ``encode(texts, convert_to_tensor=False)``. Defaults to the shared
            ``st_model``.

    Returns:
        A mapping from every id in ``model_descriptions`` to its embedding.

    Note:
        The cache is keyed by model id only, so a changed description for an
        already-cached id is not re-embedded.
    """
    missing = [
        (model_id, text)
        for model_id, text in model_descriptions.items()
        if model_id not in _model_embeddings_cache
    ]

    if missing:
        if encoder is None:
            encoder = st_model
        batch_size = 32
        for start in range(0, len(missing), batch_size):
            batch = missing[start : start + batch_size]
            ids = [model_id for model_id, _ in batch]
            texts = [text for _, text in batch]
            with torch.inference_mode():
                vectors = encoder.encode(texts, convert_to_tensor=False)
            for model_id, vector in zip(ids, vectors):
                _model_embeddings_cache[model_id] = vector

    return {model_id: _model_embeddings_cache[model_id] for model_id in model_descriptions}


def load_curated_models(curated_path: str) -> Dict[str, str]:
    """Load a curated ``normalized DOI -> model`` mapping from a JSON list of records.

    Each record needs a DOI under ``"doi"`` or ``"DOI"`` (Crossref casing) and
    a ``"model"``. Non-dict entries are skipped, and DOIs are passed through
    ``normalize_doi``. Successful loads are memoized per path. On failure an
    error is printed and ``{}`` is returned (not cached).
    """
    if curated_path in _curated_models_cache:
        return _curated_models_cache[curated_path]

    try:
        with open(curated_path, encoding="utf-8") as f:
            curated = json.load(f)

        mapping = {}
        for entry in curated:
            # A bare string entry would otherwise hit a substring test below.
            if not isinstance(entry, dict):
                continue
            raw_doi = entry.get("doi") or entry.get("DOI")
            if raw_doi and "model" in entry:
                normalized_doi = normalize_doi(raw_doi)
                if normalized_doi:
                    mapping[normalized_doi] = entry["model"]

        _curated_models_cache[curated_path] = mapping
        return mapping
    except Exception as e:
        print(f"Error loading curated models from {curated_path}: {str(e)}")
        return {}


def _load_json_cached(path: str, cache: Dict[str, Any], what: str) -> Any:
    """Load JSON (UTF-8) from ``path`` once, memoizing successful loads in ``cache``.

    On any error, prints ``Error loading {what} from {path}: {error}`` and
    returns ``{}`` without caching, so a later call can retry.
    """
    if path in cache:
        return cache[path]
    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        print(f"Error loading {what} from {path}: {str(e)}")
        return {}
    cache[path] = data
    return data


def load_model_keywords(keywords_path: str) -> Dict[str, List[str]]:
    """Load the model -> keyword-list mapping (memoized per path; ``{}`` on error)."""
    return _load_json_cached(keywords_path, _model_keywords_cache, "model keywords")


def load_model_descriptions(descriptions_path: str) -> Dict[str, str]:
    """Load the model -> description mapping (memoized per path; ``{}`` on error)."""
    return _load_json_cached(
        descriptions_path, _model_descriptions_cache, "model descriptions"
    )

In [39]:
# Per-model decision thresholds, described as "optimized based on F1 score".
# NOTE(review): how these are applied is not visible in this section.
# Presumably a model is accepted when its combined score meets its threshold;
# confirm against the scoring code.
MODEL_THRESHOLDS = {
    "ECCO": 0.10,
    "RAPID": 0.50,
    "ISSM": 0.40,
    "CMS-Flux": 0.75,
    "CARDAMOM": 0.70,
    "MOMO-CHEM": 0.95,
}


# Memo slots for the affinity getters below; None means "not populated yet".
# derive_data_driven_affinities also writes the research-area and division slots.
_keyword_affinity_cache = None
_research_area_affinity_cache = None
_division_affinity_cache = None


def _affinities_from_counts(
    label_model_counts: Dict[str, Counter],
    model_proportions: Dict[str, float],
    cap: float,
    min_label_total: int = 3,
    min_pair_count: int = 2,
    min_affinity: float = 1.05,
) -> Dict[str, Dict[str, float]]:
    """Turn ``label -> Counter(model -> count)`` into capped affinity multipliers.

    Labels seen fewer than ``min_label_total`` times are skipped. For each
    remaining label, a model's affinity is ``P(model | label) / P(model)``,
    i.e. how over-represented the model is among publications with that
    label. Pairs seen fewer than ``min_pair_count`` times are skipped, and
    affinities not above ``min_affinity`` are dropped. The rest are capped at
    ``cap`` and rounded to 2 decimals. A qualifying label with no surviving
    model still gets an empty entry.
    """
    affinities = {}
    for label, counts in label_model_counts.items():
        label_total = sum(counts.values())
        if label_total < min_label_total:
            continue
        affinities[label] = {}
        for model, count in counts.items():
            if count < min_pair_count:
                continue
            # The 0.01 fallback only guards a model absent from the proportions.
            affinity = (count / label_total) / model_proportions.get(model, 0.01)
            if affinity > min_affinity:
                affinities[label][model] = round(min(cap, affinity), 2)
    return affinities


def derive_data_driven_affinities(curated_publications_path: str) -> Dict:
    """
    Derives data-driven affinities from statistical analysis of co-occurrence patterns
    in the curated dataset. For each research area, science keyword, and division,
    this function calculates normalized frequency distributions across models
    and transforms these distributions into affinity multipliers.

    Each publication with a ``"model"`` is classified with ``ScienceClassifier``.
    The labels counted per task are:

    * science keywords with score >= 0.35 (affinities capped at 1.35);
    * research areas with score >= 0.3 (capped at 1.7);
    * divisions with score >= 0.5 (capped at 1.3).

    Side effects:
        - The keyword affinities are written to ``./data_driven_affinities.json``.
        - ``_keyword_affinity_cache``, ``_research_area_affinity_cache`` and
          ``_division_affinity_cache`` are replaced, so the getters return the
          fresh values.

    Args:
        curated_publications_path: Path to the curated publications JSON file

    Returns:
        Science keyword -> {model: affinity}, or ``{}`` on any error.
    """
    global _keyword_affinity_cache, _research_area_affinity_cache, _division_affinity_cache

    print("Deriving data-driven affinities from curated dataset...")

    try:
        with open(curated_publications_path, "r", encoding="utf-8") as f:
            curated_publications = json.load(f)

        science_classifier = ScienceClassifier.get_instance()

        model_pub_counts = Counter()
        keyword_model_counts = defaultdict(Counter)
        area_model_counts = defaultdict(Counter)
        division_model_counts = defaultdict(Counter)

        for publication in curated_publications:
            model = publication.get("model")
            if not model:
                continue
            model_pub_counts[model] += 1

            classifications = science_classifier.classify(publication)

            for entry in classifications.get("science_keywords", []):
                keyword = entry.get("label")
                # Same threshold as the classifier applies.
                if keyword and entry.get("score", 0) >= 0.35:
                    keyword_model_counts[keyword][model] += 1

            for entry in classifications.get("research_areas", []):
                area = entry.get("label")
                if area and entry.get("score", 0) >= 0.3:
                    area_model_counts[area][model] += 1

            division_entry = classifications.get("division")
            if division_entry:
                division = division_entry.get("label")
                if division and division_entry.get("score", 0) >= 0.5:
                    division_model_counts[division][model] += 1

        # Base rate P(model) over all curated publications with a model.
        total_publications = sum(model_pub_counts.values())
        model_proportions = {
            model: count / total_publications
            for model, count in model_pub_counts.items()
        }

        keyword_affinities = _affinities_from_counts(
            keyword_model_counts, model_proportions, cap=1.35
        )
        research_area_affinities = _affinities_from_counts(
            area_model_counts, model_proportions, cap=1.7
        )
        division_affinities = _affinities_from_counts(
            division_model_counts, model_proportions, cap=1.3
        )

        try:
            with open("./data_driven_affinities.json", "w") as f:
                json.dump(keyword_affinities, f, indent=2)
            print(
                f"Saved data-driven keyword affinities with {len(keyword_affinities)} entries"
            )
        except Exception as e:
            print(f"Error saving data-driven affinities: {str(e)}")

        # Refresh all caches; the keyword cache would otherwise stay stale if
        # the getter had already been called.
        _keyword_affinity_cache = keyword_affinities
        _research_area_affinity_cache = research_area_affinities
        _division_affinity_cache = division_affinities

        print(
            f"Derived affinities for {len(keyword_affinities)} keywords, {len(research_area_affinities)} research areas, and {len(division_affinities)} divisions"
        )

        return keyword_affinities

    except Exception as e:
        print(f"Error deriving data-driven affinities: {str(e)}")
        return {}


def get_science_keyword_model_affinities():
    """Return the science-keyword -> {model: affinity} mapping, loading it lazily.

    On the first call the mapping is read from ``./data_driven_affinities.json``
    and stored in the module-level ``_keyword_affinity_cache``. Later calls
    return that cached object directly. If the file is missing or is not valid
    JSON, a message is printed and an empty dict is cached (and returned), so the
    file is not retried on later calls.
    """
    global _keyword_affinity_cache

    if _keyword_affinity_cache is None:
        try:
            with open("./data_driven_affinities.json", "r") as handle:
                loaded = json.load(handle)
        except (FileNotFoundError, json.JSONDecodeError):
            print("Error loading affinity data. Using default empty affinities.")
            loaded = {}
        _keyword_affinity_cache = loaded

    return _keyword_affinity_cache


def get_research_area_model_affinities():
    """Return the research-area -> {model: affinity} table.

    If ``_research_area_affinity_cache`` has already been populated (the
    data-driven affinity derivation assigns it), that object is returned
    unchanged. Otherwise the hand-tuned defaults below are installed into the
    cache and returned. The cached object itself is returned, not a copy.
    """
    global _research_area_affinity_cache

    if _research_area_affinity_cache is None:
        # Hand-tuned defaults. match_models_improved weights an area's score by
        # (affinity - 1.0), so only values above 1.0 contribute confidence.
        _research_area_affinity_cache = {
            "Atmospheric/Ocean Indicators": {
                "ECCO": 1.7,
                "MOMO-CHEM": 1.4,
                "CMS-Flux": 1.1,
                "ISSM": 1.1,
            },
            "Greenhouse Gases": {
                "CARDAMOM": 1.6,
                "CMS-Flux": 1.8,
                "MOMO-CHEM": 1.55,
                "ECCO": 1.15,
            },
            "Ecosystems": {"CARDAMOM": 1.6, "CMS-Flux": 1.2, "ECCO": 1.25},
            "Land Surface/Agriculture Indicators": {
                "CARDAMOM": 1.4,
                "CMS-Flux": 1.3,
                "ECCO": 1.1,
                "ISSM": 1.4,
                "RAPID": 1.4,
            },
            "Validation": {
                "CMS-Flux": 1.2,
                "ECCO": 1.4,
                "ISSM": 1.2,
                "MOMO-CHEM": 1.15,
                "RAPID": 1.25,
            },
            "Cryospheric Indicators": {"ECCO": 1.35, "ISSM": 1.9},
            "Air Quality": {"CMS-Flux": 1.4, "MOMO-CHEM": 1.9},
            "Floods": {"ISSM": 1.15, "RAPID": 1.6},
            "Environmental Impacts": {"MOMO-CHEM": 1.25},
            "Severe Storms": {"ECCO": 1.2},
            "Earthquakes": {"ECCO": 1.05},
            "Droughts": {"CMS-Flux": 1.2, "RAPID": 1.4},
        }

    return _research_area_affinity_cache


def get_division_model_affinities():
    """Return the science-division -> {model: affinity} table.

    If ``_division_affinity_cache`` has already been populated (the data-driven
    affinity derivation assigns it), that object is returned unchanged.
    Otherwise the hand-tuned defaults below are installed into the cache and
    returned. The cached object itself is returned, not a copy.
    """
    global _division_affinity_cache

    if _division_affinity_cache is None:
        # Hand-tuned defaults. match_models_improved weights a division's score by
        # (affinity - 1.0); an empty inner dict means no model gets a boost.
        _division_affinity_cache = {
            "Earth Science": {
                "ECCO": 1.3,
                "RAPID": 1.2,
                "CMS-Flux": 1.15,
                "MOMO-CHEM": 1.1,
                "ISSM": 1.15,
                "CARDAMOM": 1.15,
            },
            "Biological and Physical Sciences": {"CARDAMOM": 1.2, "CMS-Flux": 1.1},
            "Heliophysics": {"MOMO-CHEM": 1.15},
            "Planetary Science": {"ISSM": 1.1},
            "Astrophysics": {},
        }

    return _division_affinity_cache


def get_model_specific_thresholds():
    """Return a shallow copy of the module-level ``MODEL_THRESHOLDS`` table.

    Each caller gets its own copy, so adding or replacing keys in the result
    leaves ``MODEL_THRESHOLDS`` itself untouched.
    """
    thresholds = MODEL_THRESHOLDS.copy()
    return thresholds


def analyze_threshold_performance(
    results: List[Dict],
    model_thresholds: Optional[Dict[str, float]] = None,
    overall_threshold: float = 0.4,
) -> Dict:
    """
    Analyze model performance with custom thresholds

    Every model is scored against every usable publication:

    - Ground truth: a publication is a positive for a model when the model
      appears in the publication's ``"models"`` list.
    - Prediction: a publication is predicted positive when
      ``confidence_scores[model] >= threshold``; a missing score counts as 0.

    Per-model confusion counts are summed into the ``"overall"``
    (micro-averaged) metrics.

    Args:
        results: List of publication results with ground truth. Only entries with
            a non-empty ``"models"`` list and a ``"confidence_scores"`` key are used.
        model_thresholds: Custom threshold values by model (defaults to a copy of
            MODEL_THRESHOLDS)
        overall_threshold: Default threshold for models without specific threshold

    Returns:
        On success, a dict with three keys:

        - ``"per_model"``: per model, precision/recall/f1/mcc, tp/fp/fn/tn and
          the threshold applied.
        - ``"overall"``: the same metrics over the pooled counts.
        - ``"thresholds"``: the thresholds that were used.

        When no usable publications remain, ``{"error": ...}`` instead.
    """
    if model_thresholds is None:
        model_thresholds = MODEL_THRESHOLDS.copy()

    # Extract publications with ground truth
    publications_with_truth = [
        result
        for result in results
        if "models" in result and result["models"] and "confidence_scores" in result
    ]

    if not publications_with_truth:
        return {"error": "No publications with ground truth for evaluation"}

    # Every model that is either a ground-truth label or has a score somewhere
    all_models = set()
    for pub in publications_with_truth:
        all_models.update(pub.get("models", []))
        all_models.update(pub.get("confidence_scores", {}))
    all_models = sorted(all_models)

    def _mcc(tp, fp, fn, tn):
        # Matthews correlation from confusion counts. Returns 0.0 when any
        # marginal is empty, the same convention as sklearn's matthews_corrcoef.
        denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
        if denom == 0:
            return 0.0
        return (tp * tn - fp * fn) / math.sqrt(denom)

    def _prf(tp, fp, fn):
        # Precision, recall and F1, each 0.0 when its denominator is empty.
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = (
            2 * precision * recall / (precision + recall)
            if (precision + recall) > 0
            else 0.0
        )
        return precision, recall, f1

    model_metrics = {}
    overall_counts = {"tp": 0, "fp": 0, "fn": 0, "tn": 0}

    for model in all_models:
        # Apply model-specific threshold or fall back to overall
        threshold = model_thresholds.get(model, overall_threshold)

        counts = {"tp": 0, "fp": 0, "fn": 0, "tn": 0}
        for pub in publications_with_truth:
            is_true_match = model in pub.get("models", [])
            is_predicted = pub.get("confidence_scores", {}).get(model, 0) >= threshold
            if is_predicted:
                counts["tp" if is_true_match else "fp"] += 1
            else:
                counts["fn" if is_true_match else "tn"] += 1

        tp, fp, fn, tn = counts["tp"], counts["fp"], counts["fn"], counts["tn"]
        precision, recall, f1 = _prf(tp, fp, fn)

        model_metrics[model] = {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "mcc": _mcc(tp, fp, fn, tn),
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
            "threshold": threshold,
        }

        for key, value in counts.items():
            overall_counts[key] += value

    overall_precision, overall_recall, overall_f1 = _prf(
        overall_counts["tp"], overall_counts["fp"], overall_counts["fn"]
    )
    # MCC over the pooled (publication, model) predictions depends only on the
    # pooled confusion matrix, so the summed counts give it directly. No second
    # pass that re-applies the thresholds is needed.
    overall_mcc = _mcc(
        overall_counts["tp"],
        overall_counts["fp"],
        overall_counts["fn"],
        overall_counts["tn"],
    )

    return {
        "per_model": model_metrics,
        "overall": {
            "precision": overall_precision,
            "recall": overall_recall,
            "f1": overall_f1,
            "mcc": overall_mcc,
            "tp": overall_counts["tp"],
            "fp": overall_counts["fp"],
            "fn": overall_counts["fn"],
            "tn": overall_counts["tn"],
        },
        "thresholds": {
            "model_specific": model_thresholds,
            "overall_default": overall_threshold,
        },
    }


def find_optimal_thresholds(
    results: List[Dict], threshold_range=None, step=0.05
) -> Dict:
    """
    Find optimal thresholds for each model based on F1 score and MCC

    Sweeps a grid of confidence thresholds from ``threshold_range[0]`` up to
    ``threshold_range[1]``. The max is included when it lies on the step grid,
    and the grid never goes past it.

    For each model, and for all (publication, model) pairs pooled together, the
    sweep records the threshold that maximises F1 and the threshold that
    maximises Matthews correlation (MCC).

    - Ties keep the lowest threshold.
    - A prediction is positive when ``confidence >= threshold``.
    - A model missing from a publication's ``confidence_scores`` counts as
      confidence 0.

    Args:
        results: List of publication results with ground truth. Only entries with
            a non-empty ``"models"`` list and a ``"confidence_scores"`` key are used.
        threshold_range: Optional range of thresholds to test (min, max);
            defaults to (0.1, 0.95)
        step: Step size for threshold values; must be positive

    Returns:
        On success, a dict with four keys:

        - ``"per_model"``: per model, best F1/MCC thresholds, values and
          metrics.
        - ``"overall"``: the same, over the pooled pairs.
        - ``"model_thresholds_f1"``: model -> best-F1 threshold.
        - ``"model_thresholds_mcc"``: model -> best-MCC threshold.

        When no usable publications remain, ``{"error": ...}`` instead.

    Raises:
        ValueError: If ``step`` is not positive.
    """
    if threshold_range is None:
        threshold_range = (0.1, 0.95)
    if step <= 0:
        raise ValueError(f"step must be positive, got {step!r}")

    low, high = threshold_range[0], threshold_range[1]
    # Build the grid as low + i*step and round it. np.arange with a float step
    # carries floating-point noise (values like 0.30000000000000004 instead of 0.3).
    # That noise leaked into the returned thresholds and made `conf >= threshold`
    # reject scores lying exactly on a grid point (e.g. the 0.70 / 0.85 constants
    # the keyword matcher emits). It could also overshoot `high`. The epsilon
    # keeps `high` on the grid despite division rounding.
    n_points = max(int(math.floor((high - low) / step + 1e-9)) + 1, 0)
    thresholds = np.round(low + step * np.arange(n_points), 10)

    # Extract publications with ground truth
    publications_with_truth = [
        result
        for result in results
        if "models" in result and result["models"] and "confidence_scores" in result
    ]

    if not publications_with_truth:
        return {"error": "No publications with ground truth for evaluation"}

    # Every model that is either a ground-truth label or has a score somewhere
    all_models = set()
    for pub in publications_with_truth:
        all_models.update(pub.get("models", []))
        all_models.update(pub.get("confidence_scores", {}))
    all_models = sorted(all_models)

    def _mcc(tp, fp, fn, tn):
        # Matthews correlation from confusion counts. Returns 0.0 when any
        # marginal is empty, the same convention as sklearn's matthews_corrcoef.
        denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
        if denom == 0:
            return 0.0
        return (tp * tn - fp * fn) / math.sqrt(denom)

    def _metrics_at(pairs, threshold):
        # Confusion counts and derived scores for (confidence, is_true) pairs at one threshold.
        tp = fp = fn = tn = 0
        for confidence, is_true in pairs:
            if confidence >= threshold:
                if is_true:
                    tp += 1
                else:
                    fp += 1
            elif is_true:
                fn += 1
            else:
                tn += 1
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = (
            2 * precision * recall / (precision + recall)
            if (precision + recall) > 0
            else 0.0
        )
        return {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "mcc": _mcc(tp, fp, fn, tn),
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def _sweep(pairs):
        # Best-F1 and best-MCC thresholds over the grid (strict '>' keeps the lowest on ties).
        best = {
            "threshold_f1": 0.4,  # Default when the grid is empty
            "f1": -1,
            "metrics_f1": {},
            "threshold_mcc": 0.4,  # Default when the grid is empty
            "mcc": -2,  # Below MCC's [-1, 1] range, so any real value wins
            "metrics_mcc": {},
        }
        for threshold in thresholds:
            metrics = _metrics_at(pairs, threshold)
            if metrics["f1"] > best["f1"]:
                best["threshold_f1"] = float(threshold)
                best["f1"] = metrics["f1"]
                best["metrics_f1"] = dict(metrics)
            if metrics["mcc"] > best["mcc"]:
                best["threshold_mcc"] = float(threshold)
                best["mcc"] = metrics["mcc"]
                best["metrics_mcc"] = dict(metrics)
        return best

    optimal_thresholds = {}
    pooled_pairs = []
    for model in all_models:
        pairs = [
            (
                pub.get("confidence_scores", {}).get(model, 0),
                1 if model in pub.get("models", []) else 0,
            )
            for pub in publications_with_truth
        ]
        optimal_thresholds[model] = _sweep(pairs)
        pooled_pairs.extend(pairs)

    # Overall threshold: one cut-off applied to every (publication, model) pair
    overall = _sweep(pooled_pairs)

    return {
        "per_model": optimal_thresholds,
        "overall": overall,
        "model_thresholds_f1": {
            model: data["threshold_f1"] for model, data in optimal_thresholds.items()
        },
        "model_thresholds_mcc": {
            model: data["threshold_mcc"] for model, data in optimal_thresholds.items()
        },
    }


# Semantic matching cache
_semantic_match_cache = {}


@torch.inference_mode()
def semantic_match(
    publication_text: str,
    model_embeddings: Dict[str, np.ndarray],
    threshold: float = 0.5,
    model_thresholds: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Match a publication to models by cosine similarity of sentence embeddings.

    The text is encoded with the module-level ``st_model`` and compared against
    each model's reference embedding. A model is kept when its similarity
    reaches its own entry in ``model_thresholds``, falling back to ``threshold``
    for models without one.

    Results are memoized in ``_semantic_match_cache``. The key covers the text,
    the model names, each model's effective threshold and the embedding bytes.
    Changing any input that affects the answer therefore misses the cache.

    Args:
        publication_text: Text to encode (match_models_improved passes
            "title abstract").
        model_embeddings: Mapping of model name -> reference embedding vector.
        threshold: Fallback similarity cut-off for models without a specific one.
        model_thresholds: Per-model cut-offs; defaults to
            ``get_model_specific_thresholds()``.

    Returns:
        Dict of model name -> similarity (Python float) for models that pass
        their cut-off. A fresh dict is returned on every call, so callers may
        mutate it without corrupting the memo table. Empty when
        ``model_embeddings`` is empty.
    """
    model_names = list(model_embeddings.keys())
    if not model_names:
        # Nothing to compare against. This also skips the encoder and avoids
        # calling cosine_similarity on an empty matrix (which raises).
        return {}

    if model_thresholds is None:
        model_thresholds = get_model_specific_thresholds()

    # Effective cut-off for each model, aligned with model_names
    effective_thresholds = tuple(
        float(model_thresholds.get(model, threshold)) for model in model_names
    )
    embeddings_array = np.array([model_embeddings[model] for model in model_names])

    # The key must capture every input that changes the answer. Keying on the
    # text and `threshold` alone returned stale results when the same text was
    # scored with different per-model thresholds or different embeddings.
    cache_key = (
        publication_text,
        tuple(model_names),
        effective_thresholds,
        hash(embeddings_array.tobytes()),
    )
    cached = _semantic_match_cache.get(cache_key)
    if cached is not None:
        return dict(cached)

    pub_embedding = st_model.encode(publication_text, convert_to_tensor=False)
    similarities = cosine_similarity([pub_embedding], embeddings_array)[0]

    results = {
        model: float(sim)
        for model, sim, cutoff in zip(model_names, similarities, effective_thresholds)
        if sim >= cutoff
    }

    _semantic_match_cache[cache_key] = results
    return dict(results)

In [40]:
_matched_publications_cache = {}
_improved_matching_cache = {}

# Define model origin years
MODEL_ORIGIN_YEARS = {
    "ECCO": 1998,
    "CARDAMOM": 2014,
    "CMS-Flux": 2014,
    "ISSM": 2003,
    "MOMO-CHEM": 2014,
    "RAPID": 2010,
}


def match_models_improved(
    publication: Dict,
    curated_mapping: Dict[str, str],
    model_keywords: Dict[str, List[str]],
    model_embeddings: Dict[str, np.ndarray],
    science_classifier,
    context_manager=None,
    ranker=None,
    threshold: float = 0.45,
    research_area_affinities=None,
    division_affinities=None,
    keyword_affinities=None,
    model_thresholds=None,
    include_classifications: bool = True,
) -> Dict:
    # Match publication to models
    # Try to use cache
    pub_key = None
    if "DOI" in publication and publication["DOI"]:
        pub_key = f"doi:{normalize_doi(publication['DOI'])}"
    elif "title" in publication:
        if isinstance(publication["title"], list) and publication["title"]:
            pub_key = f"title:{publication['title'][0]}"
        elif isinstance(publication["title"], str):
            pub_key = f"title:{publication['title']}"

    if pub_key and pub_key in _improved_matching_cache:
        return _improved_matching_cache[pub_key]

    # Load affinities if needed
    if research_area_affinities is None:
        research_area_affinities = get_research_area_model_affinities()

    if division_affinities is None:
        division_affinities = get_division_model_affinities()

    if keyword_affinities is None:
        keyword_affinities = get_science_keyword_model_affinities()

    if model_thresholds is None:
        model_thresholds = get_model_specific_thresholds()

    # Initialize scores
    all_models = list(model_embeddings.keys())
    confidence_scores = {model: 0.0 for model in all_models}
    confidence_sources = {model: [] for model in all_models}

    # Step 1: Check curated mapping (known ground truth)
    doi = normalize_doi(publication.get("DOI", ""))
    if doi and doi in curated_mapping:
        model = curated_mapping[doi]
        confidence_scores[model] = 1.0
        confidence_sources[model].append("curated_mapping")

    # Step 2: Extract text
    title = ""
    if "title" in publication:
        if isinstance(publication["title"], list) and publication["title"]:
            title = publication["title"][0]
        else:
            title = publication.get("title", "")

    abstract = publication.get("abstract", "")
    publication_text = f"{title} {abstract}"

    # Early exit if curated match and minimal text
    model = None
    if doi and doi in curated_mapping:
        model = curated_mapping[doi]

    if model and (len(publication_text) < 10):
        matched_models = [
            model
            for model, confidence in confidence_scores.items()
            if confidence >= threshold
        ]

        result = {
            "matched_models": matched_models,
            "confidence_scores": confidence_scores,
            "confidence_sources": confidence_sources,
        }

        if pub_key:
            _improved_matching_cache[pub_key] = result

        return result

    # Step 3: Check model-specific keywords (most direct signal)
    keywords = publication.get("keywords", [])
    text_for_keyword_matching = preprocess_text(publication_text)

    keyword_match_counts = {model: 0 for model in all_models}
    keyword_direct_matches = {model: [] for model in all_models}

    for model, model_kw_list in model_keywords.items():
        # Skip if already high confidence
        if model in confidence_scores and confidence_scores[model] >= 0.95:
            continue

        for kw in model_kw_list:
            # Check explicit keywords first (more reliable)
            keyword_found = False
            for pub_kw in keywords:
                matched, score = fuzzy_keyword_match(pub_kw, kw)
                if matched:
                    keyword_match_counts[model] += 1
                    # Store the exact keyword match for reference
                    keyword_direct_matches[model].append(kw)
                    keyword_found = True
                    break

            # Check text if not found in keywords
            if not keyword_found:
                match_found, score = fuzzy_keyword_match(text_for_keyword_matching, kw)
                if match_found:
                    keyword_match_counts[model] += 1
                    keyword_direct_matches[model].append(kw)

    # Apply direct keyword boost with improved confidence scale
    for model, count in keyword_match_counts.items():
        if count > 0:
            # Apply more nuanced confidence scaling
            if count >= 3:  # Multiple strong keyword matches
                kw_confidence = 0.92
            elif count == 2:  # Two keyword matches
                kw_confidence = 0.85
            else:  # Single keyword match
                kw_confidence = 0.70

            # Check for exact model name match which is a stronger signal
            if any(kw.lower() == model.lower() for kw in keyword_direct_matches[model]):
                kw_confidence = min(0.95, kw_confidence + 0.10)

            if kw_confidence > confidence_scores[model]:
                confidence_scores[model] = kw_confidence
                confidence_sources[model].append("keyword_match")

    # Step 4: Semantic matching (useful for content similarity)
    try:
        if not any(score >= 0.95 for score in confidence_scores.values()):
            # Improved default semantic threshold for better precision
            semantic_matches = semantic_match(
                publication_text,
                model_embeddings,
                threshold=0.45,  # Slightly higher default threshold
                model_thresholds=model_thresholds,
            )

            for model, similarity in semantic_matches.items():
                if similarity > 0.45:
                    # Scale similarity to align with confidence scoring
                    scaled_confidence = min(0.90, similarity * 0.95)
                    confidence_scores[model] = max(
                        confidence_scores[model], scaled_confidence
                    )
                    confidence_sources[model].append("semantic_match")
    except Exception as e:
        print(f"Semantic matching failed: {str(e)}")

    # Step 5: Science classification signals (important metadata matching)
    if not any(score >= 0.95 for score in confidence_scores.values()):
        try:
            science_results = science_classifier.classify(publication)

            # Step 5a: Apply science keyword affinities
            science_keywords = science_results.get("science_keywords", [])

            # Track high-scoring keywords to apply boosting later
            model_keyword_matches = defaultdict(list)

            for kw_entry in science_keywords:
                keyword = kw_entry["label"]
                keyword_score = kw_entry["score"]

                # Only process high confidence keywords
                if keyword_score < 0.40:
                    continue

                # Use data-driven affinities for keywords
                if keyword in keyword_affinities:
                    for model, affinity in keyword_affinities[keyword].items():
                        # Store keyword match for boosting
                        model_keyword_matches[model].append((keyword, keyword_score))

                        # Calculate confidence based on affinity and keyword score
                        confidence = keyword_score * (affinity - 1.0)

                        if confidence > 0.1:
                            confidence_scores[model] = max(
                                confidence_scores[model], confidence
                            )
                            confidence_sources[model].append(
                                f"science_keyword:{keyword}"
                            )

            # Step 5b: Apply research area affinities
            research_areas = science_results.get("research_areas", [])
            model_area_matches = defaultdict(list)

            for area_entry in research_areas:
                area = area_entry["label"]
                area_score = area_entry["score"]

                # Only consider strong research area matches
                if area_score < 0.40:
                    continue

                if area in research_area_affinities:
                    for model, affinity in research_area_affinities[area].items():
                        # Store area match for boosting
                        model_area_matches[model].append((area, area_score))

                        # Apply improved confidence calculation
                        confidence = area_score * (affinity - 1.0)
                        confidence = min(
                            confidence, 0.85
                        )  # Cap to avoid overconfidence

                        if confidence > 0.12:
                            confidence_scores[model] = max(
                                confidence_scores[model], confidence
                            )
                            confidence_sources[model].append(f"research_area:{area}")

            # Step 5c: Apply division affinities
            division_entry = science_results.get("division")
            if division_entry:
                division = division_entry["label"]
                division_score = division_entry["score"]

                if division_score >= 0.55 and division in division_affinities:
                    for model, affinity in division_affinities[division].items():
                        confidence = division_score * (affinity - 1.0)

                        if confidence > 0.05:
                            confidence_scores[model] = max(
                                confidence_scores[model], confidence
                            )
                            confidence_sources[model].append(f"division:{division}")

            # Apply boosts for multiple science keyword matches
            for model, matches in model_keyword_matches.items():
                if len(matches) >= 3 and model in confidence_scores:
                    # Calculate average score
                    avg_score = sum(score for _, score in matches) / len(matches)
                    # Apply a boost for multiple keyword matches
                    boost = min(0.15, 0.05 * len(matches))
                    current_score = confidence_scores[model]
                    confidence_scores[model] = min(0.90, current_score + boost)
                    if "multiple_science_keywords" not in confidence_sources[model]:
                        confidence_sources[model].append("multiple_science_keywords")

            # Apply boosts for multiple research area matches
            for model, matches in model_area_matches.items():
                if len(matches) >= 2 and model in confidence_scores:
                    # Calculate average score
                    avg_score = sum(score for _, score in matches) / len(matches)
                    # Apply a boost for multiple area matches
                    boost = min(0.12, 0.06 * len(matches))
                    current_score = confidence_scores[model]
                    confidence_scores[model] = min(0.90, current_score + boost)
                    if "multiple_research_areas" not in confidence_sources[model]:
                        confidence_sources[model].append("multiple_research_areas")

        except Exception as e:
            print(f"Science classification failed: {str(e)}")

    # Step 6: Context validation (compare with known examples)
    if context_manager and not any(
        score >= 0.95 for score in confidence_scores.values()
    ):
        try:
            context_scores = context_manager.get_context_scores(publication)
            for model, score in context_scores.items():
                if score > 0.45:  # Slightly higher threshold for quality
                    # Apply a stronger weight to context validation
                    context_confidence = score * 0.95
                    confidence_scores[model] = max(
                        confidence_scores[model], context_confidence
                    )
                    confidence_sources[model].append("context_validation")
        except Exception as e:
            print(f"Context validation failed: {str(e)}")

    # Step 7: Relevance ranking (specific query matching)
    if ranker and not any(score >= 0.95 for score in confidence_scores.values()):
        try:
            # Create a more targeted query for ranking by focusing on title and first part of abstract
            query = title
            if abstract and len(abstract) > 50:
                query = f"{title} {abstract[:400]}"

            # Only rank models that have some confidence already
            candidate_models = [
                model for model, score in confidence_scores.items() if score > 0.25
            ]

            if candidate_models:
                rank_scores = ranker.batch_rank(query, candidate_models)

                for model, score in rank_scores.items():
                    if score > 0.35:  # Apply a stronger filter
                        # Scale ranker scores to be more in line with other confidence measures
                        ranker_confidence = score * 0.95
                        confidence_scores[model] = max(
                            confidence_scores[model], ranker_confidence
                        )
                        confidence_sources[model].append("relevance_ranker")
        except Exception as e:
            print(f"Relevance ranking failed: {str(e)}")

    # Step 8: Apply hybrid boosts for evidence consensus
    candidate_models = [
        model for model, score in confidence_scores.items() if score >= 0.25
    ]
    for model in candidate_models:
        sources = confidence_sources[model]

        # Significant boost for keyword + semantic matches (two strong signals)
        if "keyword_match" in sources and "semantic_match" in sources:
            current_score = confidence_scores[model]
            # Higher boost for this strong combination
            confidence_scores[model] = min(0.95, current_score * 1.08)
            if "hybrid_kw_semantic" not in sources:
                sources.append("hybrid_kw_semantic")

        # Boost for science metadata consensus (strong metadata alignment)
        science_sources = [
            s
            for s in sources
            if s.startswith(("science_keyword:", "research_area:", "division:"))
        ]
        if len(science_sources) >= 3:  # More sources required for stronger consensus
            current_score = confidence_scores[model]
            confidence_scores[model] = min(0.92, current_score * 1.15)
            if "science_consensus" not in sources:
                sources.append("science_consensus")

        # Special boost for context validation + semantic match (high precision combo)
        if "context_validation" in sources and "semantic_match" in sources:
            current_score = confidence_scores[model]
            confidence_scores[model] = min(0.95, current_score * 1.10)
            if "hybrid_context_semantic" not in sources:
                sources.append("hybrid_context_semantic")

        # Strong consensus boost when multiple independent signals agree
        independent_signals = len(
            set(
                [
                    "keyword_match",
                    "semantic_match",
                    "context_validation",
                    "relevance_ranker",
                    "science_consensus",
                ]
            )
            & set(sources)
        )
        if independent_signals >= 3:  # At least 3 independent signals
            current_score = confidence_scores[model]
            confidence_scores[model] = min(0.97, current_score * 1.12)
            if "strong_signal_consensus" not in sources:
                sources.append("strong_signal_consensus")

    # Step 9: Apply second-order filtering (if a model scores extremely high, increase its threshold)
    # This helps avoid false positives from borderline matches when we're very confident about another model
    high_confidence_models = [
        model for model, score in confidence_scores.items() if score >= 0.92
    ]
    if high_confidence_models:
        # Increase threshold for other models
        for model in all_models:
            if model not in high_confidence_models and confidence_scores[model] < 0.75:
                # Suppress low to medium confidence matches when we have very high confidence elsewhere
                confidence_scores[model] = confidence_scores[model] * 0.85

    # Step 10: Apply date verification - exclude models that were published after the publication date
    # Extract publication date and year
    pub_date = extract_publication_date(publication)
    pub_year = None
    if pub_date:
        try:
            pub_year = int(pub_date.split("-")[0])
        except (ValueError, IndexError):
            pass

    # If we have a publication year, filter out models that didn't exist yet
    if pub_year:
        for model in all_models:
            if model in MODEL_ORIGIN_YEARS:
                model_origin_year = MODEL_ORIGIN_YEARS[model]
                if pub_year < model_origin_year:
                    # Publication predates the model, so it cannot be about this model
                    confidence_scores[model] = 0.0
                    confidence_sources[model] = ["excluded_by_date_verification"]

    # Filter by thresholds
    matched_models = []
    for model, confidence in confidence_scores.items():
        model_threshold = model_thresholds.get(model, threshold)
        if confidence >= model_threshold:
            matched_models.append(model)

    # Add classifications
    science_results = None
    if include_classifications:
        try:
            # Use existing science results if available
            if not science_results:
                science_results = science_classifier.classify(publication)

            result_with_classifications = {
                "matched_models": matched_models,
                "confidence_scores": confidence_scores,
                "confidence_sources": confidence_sources,
                # Include all classification results
                "classifications": {
                    "research_areas": science_results.get("research_areas", []),
                    "science_keywords": science_results.get("science_keywords", []),
                    "division": science_results.get("division"),
                },
            }

            # Add context terms if available
            if context_manager:
                try:
                    prepared_text = science_classifier.get_instance()._prepare_text(
                        publication
                    )
                    pub_profile = context_manager._get_pub_profile(prepared_text)
                    result_with_classifications["context_terms"] = list(
                        pub_profile["terms"]
                    )
                except Exception as e:
                    print(f"Error extracting context terms: {str(e)}")

            # Add relevance scores if available
            if ranker and publication_text:
                try:
                    query = publication_text[:500]
                    relevance_scores = ranker.batch_rank(query, all_models)
                    result_with_classifications["relevance_scores"] = relevance_scores
                except Exception as e:
                    print(f"Error getting relevance scores: {str(e)}")

            # Extract publication date
            pub_date = extract_publication_date(publication)
            if pub_date:
                result_with_classifications["pubdate"] = pub_date

            # Cache result
            if pub_key:
                _improved_matching_cache[pub_key] = result_with_classifications

            return result_with_classifications
        except Exception as e:
            print(f"Error including classifications: {str(e)}")

    # Basic result without classifications
    result = {
        "matched_models": matched_models,
        "confidence_scores": confidence_scores,
        "confidence_sources": confidence_sources,
    }

    # Extract publication date
    pub_date = extract_publication_date(publication)
    if pub_date:
        result["pubdate"] = pub_date

    # Cache result
    if pub_key:
        _improved_matching_cache[pub_key] = result

    return result


def _publication_cache_key(pub: Dict) -> Optional[str]:
    """Build the key used for ``_matched_publications_cache`` lookups.

    Prefers the normalized DOI (``"doi:<normalized doi>"``). Otherwise falls
    back to the title (``"title:<title>"``), using the first element when the
    title is stored as a list. Returns ``None`` when no usable key exists.

    Blank or non-string titles are rejected on purpose: keying them as
    ``"title:"`` / ``"title:None"`` would make every untitled, DOI-less
    publication share one cache entry and silently reuse another
    publication's matches.
    """
    doi = pub.get("DOI")
    if doi:
        return f"doi:{normalize_doi(doi)}"

    title = pub.get("title")
    if isinstance(title, list):
        title = title[0] if title else None
    if isinstance(title, str) and title.strip():
        return f"title:{title}"
    return None


def process_publication_batch(
    publications: List[Dict],
    curated_mapping: Dict[str, str],
    model_keywords: Dict[str, List[str]],
    model_embeddings: Dict[str, np.ndarray],
    science_classifier,
    context_manager,
    ranker=None,
    batch_size: int = 100,
    include_classifications: bool = True,
) -> List[Dict]:
    """Match every publication to models, processing ``batch_size`` at a time.

    Each output element is a shallow copy of the input publication, updated
    with every key of the dict returned by ``match_models_improved`` (at least
    ``matched_models``, ``confidence_scores`` and ``confidence_sources``).
    When ``ranker`` is given, it also gets ``pub_characteristics``.

    Output order matches input order, and each input publication appears
    exactly once. A publication whose matching raises is emitted with empty
    ``matched_models`` / ``confidence_scores`` / ``confidence_sources`` and
    is not cached.

    Results are memoized in ``_matched_publications_cache``, keyed by DOI or
    title. The full match result is cached, so a cache hit carries the same
    fields (e.g. ``pubdate``, ``classifications``) as a fresh match.

    Args:
        publications: Publication dicts to match.
        curated_mapping, model_keywords, model_embeddings, science_classifier,
            context_manager, ranker: Forwarded unchanged to
            ``match_models_improved``.
        batch_size: Number of publications per batch; must be positive.
        include_classifications: Forwarded to ``match_models_improved``.

    Returns:
        One augmented publication dict per input publication.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    results = []
    total_pubs = len(publications)
    num_batches = (total_pubs + batch_size - 1) // batch_size

    # Affinity tables and thresholds are loaded once and shared by every call
    # to match_models_improved below.
    research_area_affinities = get_research_area_model_affinities()
    division_affinities = get_division_model_affinities()
    keyword_affinities = get_science_keyword_model_affinities()
    model_thresholds = get_model_specific_thresholds()

    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, total_pubs)

        print(
            f"Processing batch {batch_idx+1}/{num_batches} (publications {start_idx+1}-{end_idx}/{total_pubs})..."
        )

        batch_results = []

        for offset, pub in enumerate(publications[start_idx:end_idx]):
            pub_idx = start_idx + offset + 1
            if pub_idx % 20 == 0:
                print(f"  Processing publication {pub_idx}/{total_pubs}...")

            pub_key = _publication_cache_key(pub)

            # Cache hit: reuse the stored match result for this DOI/title.
            if pub_key and pub_key in _matched_publications_cache:
                pub_copy = pub.copy()
                pub_copy.update(_matched_publications_cache[pub_key])
                batch_results.append(pub_copy)
                continue

            pub_copy = pub.copy()
            try:
                match_result = match_models_improved(
                    pub,
                    curated_mapping,
                    model_keywords,
                    model_embeddings,
                    science_classifier,
                    context_manager,
                    ranker,
                    research_area_affinities=research_area_affinities,
                    division_affinities=division_affinities,
                    keyword_affinities=keyword_affinities,
                    model_thresholds=model_thresholds,
                    include_classifications=include_classifications,
                )
                pub_copy.update(match_result)

                if ranker is not None:
                    pub_copy["pub_characteristics"] = (
                        extract_publication_characteristics(pub)
                    )
            except Exception as e:
                print(f"Error processing publication {pub_idx}: {str(e)}")
                # Fall back to a publication with no matches; not cached so a
                # later run can retry it.
                pub_copy = pub.copy()
                pub_copy["matched_models"] = []
                pub_copy["confidence_scores"] = {}
                pub_copy["confidence_sources"] = {}
            else:
                # Cache the complete match result (not just a subset) so cache
                # hits are indistinguishable from fresh matches.
                if pub_key:
                    cache_value = dict(match_result)
                    if "pub_characteristics" in pub_copy:
                        cache_value["pub_characteristics"] = pub_copy[
                            "pub_characteristics"
                        ]
                    _matched_publications_cache[pub_key] = cache_value

            # Appended exactly once, whether matching succeeded or failed.
            batch_results.append(pub_copy)

        results.extend(batch_results)

        # Every 5th batch, drop the fuzzy/semantic match caches and free memory.
        if batch_idx % 5 == 4:
            _fuzzy_match_cache.clear()
            _semantic_match_cache.clear()
            optimize_memory()

        print(f"Completed batch {batch_idx+1}/{num_batches}")

    return results

In [41]:
def visualize_metrics(
    results: List[Dict], output_path_base: str = "./metrics_visualization"
):
    # Create comprehensive visualizations
    # results: List of prediction results with true models and predicted models
    # output_path_base: Base path for saving visualizations

    # Extract publications with ground truth
    publications_with_truth = [
        result
        for result in results
        if "models" in result and result["models"] and "matched_models" in result
    ]

    if not publications_with_truth:
        print("No publications with ground truth for evaluation")
        return

    # Get all unique models
    all_models = set()
    for pub in publications_with_truth:
        all_models.update(pub.get("models", []))
        all_models.update(pub.get("matched_models", []))
    all_models = sorted(list(all_models))

    # Dictionary to store all metrics
    complete_metrics = {
        "model_performance": {},
        "source_analysis": {},
        "classification_accuracy": {},
        "confidence_analysis": {},
        "temporal_analysis": {},
        "threshold_analysis": {},  # New for threshold analysis
        "confusion_matrix": {},  # New for confusion matrix
        "mcc_analysis": {},      # New for Matthews Correlation Coefficient
    }

    # 1. Model Performance Metrics
    model_metrics = {}
    metrics = {"precision": [], "recall": [], "f1": [], "mcc": []}  # Added MCC

    for model in all_models:
        # For each model, collect binary predictions
        y_true = []
        y_pred = []

        for pub in publications_with_truth:
            y_true.append(1 if model in pub.get("models", []) else 0)
            y_pred.append(1 if model in pub.get("matched_models", []) else 0)

        # Calculate metrics
        precision = precision_score(y_true, y_pred, zero_division=0)
        recall = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)
        mcc = matthews_corrcoef(y_true, y_pred)  # Calculate MCC

        model_metrics[model] = {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "mcc": mcc,  # Add MCC to metrics
            "true_positives": sum(
                1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1
            ),
            "false_positives": sum(
                1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1
            ),
            "false_negatives": sum(
                1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0
            ),
            "true_negatives": sum(
                1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0
            ),  # Added for MCC
            "match_count": sum(y_pred),
        }

        metrics["precision"].append(precision)
        metrics["recall"].append(recall)
        metrics["f1"].append(f1)
        metrics["mcc"].append(mcc)  # Add MCC to metrics list

    # Calculate micro-average metrics
    all_y_true = []
    all_y_pred = []

    for pub in publications_with_truth:
        true_labels = [
            1 if model in pub.get("models", []) else 0 for model in all_models
        ]
        pred_labels = [
            1 if model in pub.get("matched_models", []) else 0 for model in all_models
        ]

        all_y_true.extend(true_labels)
        all_y_pred.extend(pred_labels)

    micro_precision = precision_score(all_y_true, all_y_pred, zero_division=0)
    micro_recall = recall_score(all_y_true, all_y_pred, zero_division=0)
    micro_f1 = f1_score(all_y_true, all_y_pred, zero_division=0)
    micro_mcc = matthews_corrcoef(all_y_true, all_y_pred)  # Calculate overall MCC

    complete_metrics["model_performance"] = {
        "per_model": model_metrics,
        "micro_average": {
            "precision": micro_precision,
            "recall": micro_recall,
            "f1": micro_f1,
            "mcc": micro_mcc,  # Add MCC to micro_average
        },
    }

    # Create model performance visualization
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

    # Plot per-model metrics
    x = np.arange(len(all_models))
    width = 0.2  # Adjusted to fit 4 bars (added MCC)

    # Skip if no models
    if len(all_models) > 0:
        bars1 = ax1.bar(x - 1.5*width, metrics["precision"], width, label="Precision")
        bars2 = ax1.bar(x - 0.5*width, metrics["recall"], width, label="Recall")
        bars3 = ax1.bar(x + 0.5*width, metrics["f1"], width, label="F1 Score")
        bars4 = ax1.bar(x + 1.5*width, metrics["mcc"], width, label="MCC")  # Added MCC bars

        ax1.set_xlabel("Models")
        ax1.set_ylabel("Score")
        ax1.set_title("Precision, Recall, F1 Score and MCC by Model")
        ax1.set_xticks(x)
        ax1.set_xticklabels(all_models, rotation=45, ha="right")
        ax1.legend()
        ax1.grid(axis="y", linestyle="--", alpha=0.7)

        # Add values on bars
        def add_labels(bars):
            for bar in bars:
                height = bar.get_height()
                ax1.text(
                    bar.get_x() + bar.get_width() / 2.0,
                    height + 0.01,
                    f"{height:.2f}",
                    ha="center",
                    va="bottom",
                )

        add_labels(bars1)
        add_labels(bars2)
        add_labels(bars3)
        add_labels(bars4)  # Add labels to MCC bars

    # Plot micro-average metrics
    micro_metrics = {
        "Precision": micro_precision,
        "Recall": micro_recall,
        "F1 Score": micro_f1,
        "MCC": micro_mcc,  # Added MCC
    }

    x2 = np.arange(len(micro_metrics))
    bars = ax2.bar(x2, micro_metrics.values(), width=0.4)

    ax2.set_xlabel("Metrics")
    ax2.set_ylabel("Score")
    ax2.set_title("Overall Micro-Average Metrics")
    ax2.set_xticks(x2)
    ax2.set_xticklabels(micro_metrics.keys())
    ax2.grid(axis="y", linestyle="--", alpha=0.7)

    # Add values on top of the bars
    for bar in bars:
        height = bar.get_height()
        ax2.text(
            bar.get_x() + bar.get_width() / 2.0,
            height + 0.01,
            f"{height:.3f}",
            ha="center",
            va="bottom",
        )

    plt.tight_layout()
    plt.savefig(f"{output_path_base}_model_performance.png", dpi=300)
    plt.close()

    # Create a dedicated visualization for Matthews Correlation Coefficient
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    # Plot MCC by model
    mcc_values = [model_metrics[model]["mcc"] for model in all_models]
    bars = ax1.bar(all_models, mcc_values)
    
    ax1.set_xlabel("Models")
    ax1.set_ylabel("Matthews Correlation Coefficient")
    ax1.set_title("Matthews Correlation Coefficient by Model")
    ax1.set_xticklabels(all_models, rotation=45, ha="right")
    ax1.grid(axis="y", linestyle="--", alpha=0.7)
    ax1.axhline(y=micro_mcc, color='r', linestyle='--', label=f'Overall MCC: {micro_mcc:.3f}')
    ax1.legend()
    
    # Add value labels
    for bar in bars:
        height = bar.get_height()
        ax1.text(
            bar.get_x() + bar.get_width() / 2.0,
            height + 0.01 if height >= 0 else height - 0.08,
            f"{height:.3f}",
            ha="center",
            va="bottom" if height >= 0 else "top",
            color="black"
        )
    
    # Plot MCC vs F1 to show their relationship
    f1_values = [model_metrics[model]["f1"] for model in all_models]
    
    ax2.scatter(f1_values, mcc_values)
    
    # Add model labels to points
    for i, model in enumerate(all_models):
        ax2.annotate(
            model, 
            (f1_values[i], mcc_values[i]),
            xytext=(5, 5),
            textcoords="offset points",
            fontsize=8
        )
    
    # Add a diagonal line for reference
    min_val = min(min(f1_values), min(mcc_values))
    max_val = max(max(f1_values), max(mcc_values))
    ax2.plot([min_val, max_val], [min_val, max_val], 'k--', alpha=0.3)
    
    ax2.set_xlabel("F1 Score")
    ax2.set_ylabel("Matthews Correlation Coefficient")
    ax2.set_title("MCC vs F1 Score")
    ax2.grid(True, linestyle="--", alpha=0.7)
    
    plt.tight_layout()
    plt.savefig(f"{output_path_base}_mcc_analysis.png", dpi=300)
    plt.close()

    # 2. Confidence Analysis

    # Extract confidence scores
    confidence_data = {}

    for model in all_models:
        confidence_scores = []
        correct_predictions = []

        for pub in publications_with_truth:
            is_true_match = model in pub.get("models", [])
            is_predicted = model in pub.get("matched_models", [])
            confidence = pub.get("confidence_scores", {}).get(model, 0)

            confidence_scores.append(confidence)
            correct_predictions.append(1 if is_true_match == is_predicted else 0)

        confidence_data[model] = {
            "scores": confidence_scores,
            "correct_predictions": correct_predictions,
            "mean_confidence": np.mean(confidence_scores) if confidence_scores else 0,
            "median_confidence": np.median(confidence_scores)
            if confidence_scores
            else 0,
        }

    complete_metrics["confidence_analysis"] = confidence_data

    # Create confidence visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Plot mean confidence by model
    mean_confidence = [
        confidence_data[model]["mean_confidence"] for model in all_models
    ]

    if all_models:
        ax1.bar(all_models, mean_confidence)
        ax1.set_xlabel("Models")
        ax1.set_ylabel("Mean Confidence Score")
        ax1.set_title("Mean Confidence Score by Model")
        ax1.set_xticklabels(all_models, rotation=45, ha="right")
        ax1.grid(axis="y", linestyle="--", alpha=0.7)

    # Plot confidence distribution (boxplot)
    if all_models:
        confidence_values = [confidence_data[model]["scores"] for model in all_models]
        ax2.boxplot(confidence_values, labels=all_models)
        ax2.set_xlabel("Models")
        ax2.set_ylabel("Confidence Score Distribution")
        ax2.set_title("Confidence Score Distribution by Model")
        ax2.set_xticklabels(all_models, rotation=45, ha="right")
        ax2.grid(axis="y", linestyle="--", alpha=0.7)

    plt.tight_layout()
    plt.savefig(f"{output_path_base}_confidence_analysis.png", dpi=300)
    plt.close()

    # 3. Source Analysis

    # Analyze the sources used for matches
    source_counts = {}
    source_accuracy = {}

    for pub in publications_with_truth:
        for model in all_models:
            sources = pub.get("confidence_sources", {}).get(model, [])
            is_true_match = model in pub.get("models", [])
            is_predicted = model in pub.get("matched_models", [])

            for source in sources:
                if source not in source_counts:
                    source_counts[source] = 0
                    source_accuracy[source] = {"correct": 0, "total": 0}

                source_counts[source] += 1

                # Track accuracy
                source_accuracy[source]["total"] += 1
                if is_true_match == is_predicted:
                    source_accuracy[source]["correct"] += 1

    # Calculate accuracy rates
    for source in source_accuracy:
        if source_accuracy[source]["total"] > 0:
            source_accuracy[source]["accuracy"] = (
                source_accuracy[source]["correct"] / source_accuracy[source]["total"]
            )
        else:
            source_accuracy[source]["accuracy"] = 0

    complete_metrics["source_analysis"] = {
        "counts": source_counts,
        "accuracy": source_accuracy,
    }

    # Create source analysis visualization
    if source_counts:
        # Prepare data
        sources = list(source_counts.keys())
        counts = [source_counts[s] for s in sources]
        accuracies = [source_accuracy[s]["accuracy"] for s in sources]

        # Sort by count
        sorted_data = sorted(
            zip(sources, counts, accuracies), key=lambda x: x[1], reverse=True
        )
        sources = [x[0] for x in sorted_data]
        counts = [x[1] for x in sorted_data]
        accuracies = [x[2] for x in sorted_data]

        # Limit to top 15 sources for better visualization
        if len(sources) > 15:
            sources = sources[:15]
            counts = counts[:15]
            accuracies = accuracies[:15]

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

        # Plot counts
        ax1.bar(sources, counts)
        ax1.set_xlabel("Confidence Sources")
        ax1.set_ylabel("Count")
        ax1.set_title("Usage Count by Confidence Source")
        ax1.set_xticklabels(sources, rotation=45, ha="right")
        ax1.grid(axis="y", linestyle="--", alpha=0.7)

        # Plot accuracy
        ax2.bar(sources, accuracies)
        ax2.set_xlabel("Confidence Sources")
        ax2.set_ylabel("Accuracy")
        ax2.set_title("Accuracy by Confidence Source")
        ax2.set_xticklabels(sources, rotation=45, ha="right")
        ax2.grid(axis="y", linestyle="--", alpha=0.7)

        plt.tight_layout()
        plt.savefig(f"{output_path_base}_source_analysis.png", dpi=300)
        plt.close()

    # 4. Temporal Analysis

    # Analyze performance over time by publication date
    date_metrics = {}

    for pub in publications_with_truth:
        pub_date = pub.get("pubdate")
        if not pub_date or len(pub_date) < 7:  # Ensure proper format
            continue

        year_month = pub_date  # Already in yyyy-mm format

        if year_month not in date_metrics:
            date_metrics[year_month] = {
                "total": 0,
                "correct": 0,
                "precision": [],
                "recall": [],
                "f1": [],
                "mcc": [],  # Added MCC
            }

        # Count publications
        date_metrics[year_month]["total"] += 1

        # Calculate accuracy for this publication
        true_models = set(pub.get("models", []))
        pred_models = set(pub.get("matched_models", []))

        # True positives
        tp = len(true_models.intersection(pred_models))
        # False positives
        fp = len(pred_models - true_models)
        # False negatives
        fn = len(true_models - pred_models)
        # True negatives (for MCC) - need to consider all models not in either set
        potential_models = set(all_models) - true_models - pred_models
        tn = len(potential_models)

        # Calculate metrics
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = (
            2 * precision * recall / (precision + recall)
            if (precision + recall) > 0
            else 0
        )
        
        # Calculate MCC for this publication
        # First create binary arrays for all models
        pub_y_true = [1 if m in true_models else 0 for m in all_models]
        pub_y_pred = [1 if m in pred_models else 0 for m in all_models]
        
        # Calculate MCC if we have valid predictions (at least one true or one pred)
        if sum(pub_y_true) > 0 or sum(pub_y_pred) > 0:
            mcc = matthews_corrcoef(pub_y_true, pub_y_pred)
        else:
            mcc = 0

        date_metrics[year_month]["precision"].append(precision)
        date_metrics[year_month]["recall"].append(recall)
        date_metrics[year_month]["f1"].append(f1)
        date_metrics[year_month]["mcc"].append(mcc)  # Add MCC

        # Mark as correct if all matches are correct
        if true_models == pred_models:
            date_metrics[year_month]["correct"] += 1

    # Calculate average metrics by date
    for date in date_metrics:
        if date_metrics[date]["total"] > 0:
            date_metrics[date]["accuracy"] = (
                date_metrics[date]["correct"] / date_metrics[date]["total"]
            )
            date_metrics[date]["avg_precision"] = np.mean(
                date_metrics[date]["precision"]
            )
            date_metrics[date]["avg_recall"] = np.mean(date_metrics[date]["recall"])
            date_metrics[date]["avg_f1"] = np.mean(date_metrics[date]["f1"])
            date_metrics[date]["avg_mcc"] = np.mean(date_metrics[date]["mcc"])  # Add avg MCC

    complete_metrics["temporal_analysis"] = date_metrics

    # Create temporal analysis visualization with MCC
    if date_metrics:
        # Sort dates chronologically
        sorted_dates = sorted(date_metrics.keys())

        if len(sorted_dates) > 1:  # Only plot if we have multiple dates
            accuracies = [date_metrics[d]["accuracy"] for d in sorted_dates]
            precisions = [date_metrics[d]["avg_precision"] for d in sorted_dates]
            recalls = [date_metrics[d]["avg_recall"] for d in sorted_dates]
            f1_scores = [date_metrics[d]["avg_f1"] for d in sorted_dates]
            mcc_scores = [date_metrics[d]["avg_mcc"] for d in sorted_dates]  # Added MCC

            fig, ax = plt.subplots(figsize=(12, 6))

            ax.plot(sorted_dates, accuracies, "o-", label="Accuracy")
            ax.plot(sorted_dates, precisions, "s-", label="Precision") 
            ax.plot(sorted_dates, recalls, "^-", label="Recall")
            ax.plot(sorted_dates, f1_scores, "D-", label="F1 Score")
            ax.plot(sorted_dates, mcc_scores, "X-", label="MCC")  # Added MCC plot

            ax.set_xlabel("Publication Date")
            ax.set_ylabel("Score")
            ax.set_title("Performance Metrics by Publication Date")
            ax.grid(True, linestyle="--", alpha=0.7)
            ax.legend()

            plt.xticks(rotation=45)
            plt.tight_layout()
            plt.savefig(f"{output_path_base}_temporal_analysis.png", dpi=300)
            plt.close()

    # 5. Classification Accuracy Analysis

    # Analyze how classification results correlate with model matches
    research_area_stats = {}
    science_keyword_stats = {}
    division_stats = {}

    # Extract data
    for pub in publications_with_truth:
        # Check if we have classification data
        classifications = pub.get("classifications", {})
        if not classifications:
            continue

        # Get match accuracy for this publication
        true_models = set(pub.get("models", []))
        pred_models = set(pub.get("matched_models", []))
        is_correct = true_models == pred_models

        # Process research areas
        for area in classifications.get("research_areas", []):
            area_name = area.get("label", "")
            area_score = area.get("score", 0)

            if area_name and area_score > 0.3:  # Only consider significant areas
                if area_name not in research_area_stats:
                    research_area_stats[area_name] = {"correct": 0, "total": 0}

                research_area_stats[area_name]["total"] += 1
                if is_correct:
                    research_area_stats[area_name]["correct"] += 1

        # Process science keywords
        for keyword in classifications.get("science_keywords", []):
            kw_name = keyword.get("label", "")
            kw_score = keyword.get("score", 0)

            if kw_name and kw_score > 0.3:  # Only consider significant keywords
                if kw_name not in science_keyword_stats:
                    science_keyword_stats[kw_name] = {"correct": 0, "total": 0}

                science_keyword_stats[kw_name]["total"] += 1
                if is_correct:
                    science_keyword_stats[kw_name]["correct"] += 1

        # Process division
        division = classifications.get("division", {})
        if division:
            div_name = division.get("label", "")
            div_score = division.get("score", 0)

            if div_name and div_score > 0.5:  # Only consider significant divisions
                if div_name not in division_stats:
                    division_stats[div_name] = {"correct": 0, "total": 0}

                division_stats[div_name]["total"] += 1
                if is_correct:
                    division_stats[div_name]["correct"] += 1

    # Calculate accuracy rates
    for area in research_area_stats:
        if research_area_stats[area]["total"] > 0:
            research_area_stats[area]["accuracy"] = (
                research_area_stats[area]["correct"]
                / research_area_stats[area]["total"]
            )

    for kw in science_keyword_stats:
        if science_keyword_stats[kw]["total"] > 0:
            science_keyword_stats[kw]["accuracy"] = (
                science_keyword_stats[kw]["correct"]
                / science_keyword_stats[kw]["total"]
            )

    for div in division_stats:
        if division_stats[div]["total"] > 0:
            division_stats[div]["accuracy"] = (
                division_stats[div]["correct"] / division_stats[div]["total"]
            )

    complete_metrics["classification_accuracy"] = {
        "research_areas": research_area_stats,
        "science_keywords": science_keyword_stats,
        "divisions": division_stats,
    }

    # Create classification accuracy visualization
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 15))

    # Filter out classifications with too few samples
    def filter_and_sort(stats_dict, min_samples=5):
        filtered = {k: v for k, v in stats_dict.items() if v["total"] >= min_samples}
        return sorted(filtered.items(), key=lambda x: x[1]["accuracy"], reverse=True)

    # Plot research area accuracy
    sorted_areas = filter_and_sort(research_area_stats)
    if sorted_areas:
        area_names = [a[0] for a in sorted_areas]
        area_accuracies = [a[1]["accuracy"] for a in sorted_areas]
        area_counts = [a[1]["total"] for a in sorted_areas]

        # Limit to top 15 for better visualization
        if len(area_names) > 15:
            area_names = area_names[:15]
            area_accuracies = area_accuracies[:15]
            area_counts = area_counts[:15]

        bars = ax1.barh(area_names, area_accuracies)
        ax1.set_xlabel("Accuracy")
        ax1.set_ylabel("Research Area")
        ax1.set_title("Match Accuracy by Research Area")
        ax1.set_xlim(0, 1)

        # Add count labels to bars
        for i, bar in enumerate(bars):
            width = bar.get_width()
            ax1.text(
                width + 0.01,
                bar.get_y() + bar.get_height() / 2,
                f"n={area_counts[i]}",
                ha="left",
                va="center",
            )
    else:
        ax1.text(
            0.5,
            0.5,
            "Insufficient research area data",
            ha="center",
            va="center",
            transform=ax1.transAxes,
        )

    # Plot science keyword accuracy
    sorted_keywords = filter_and_sort(science_keyword_stats)
    if sorted_keywords:
        kw_names = [k[0] for k in sorted_keywords]
        kw_accuracies = [k[1]["accuracy"] for k in sorted_keywords]
        kw_counts = [k[1]["total"] for k in sorted_keywords]

        # Limit to top 15 for better visualization
        if len(kw_names) > 15:
            kw_names = kw_names[:15]
            kw_accuracies = kw_accuracies[:15]
            kw_counts = kw_counts[:15]

        bars = ax2.barh(kw_names, kw_accuracies)
        ax2.set_xlabel("Accuracy")
        ax2.set_ylabel("Science Keyword")
        ax2.set_title("Match Accuracy by Science Keyword")
        ax2.set_xlim(0, 1)

        # Add count labels to bars
        for i, bar in enumerate(bars):
            width = bar.get_width()
            ax2.text(
                width + 0.01,
                bar.get_y() + bar.get_height() / 2,
                f"n={kw_counts[i]}",
                ha="left",
                va="center",
            )
    else:
        ax2.text(
            0.5,
            0.5,
            "Insufficient science keyword data",
            ha="center",
            va="center",
            transform=ax2.transAxes,
        )

    # Plot division accuracy
    sorted_divisions = filter_and_sort(division_stats)
    if sorted_divisions:
        div_names = [d[0] for d in sorted_divisions]
        div_accuracies = [d[1]["accuracy"] for d in sorted_divisions]
        div_counts = [d[1]["total"] for d in sorted_divisions]

        bars = ax3.barh(div_names, div_accuracies)
        ax3.set_xlabel("Accuracy")
        ax3.set_ylabel("Division")
        ax3.set_title("Match Accuracy by Division")
        ax3.set_xlim(0, 1)

        # Add count labels to bars
        for i, bar in enumerate(bars):
            width = bar.get_width()
            ax3.text(
                width + 0.01,
                bar.get_y() + bar.get_height() / 2,
                f"n={div_counts[i]}",
                ha="left",
                va="center",
            )
    else:
        ax3.text(
            0.5,
            0.5,
            "Insufficient division data",
            ha="center",
            va="center",
            transform=ax3.transAxes,
        )

    plt.tight_layout()
    plt.savefig(f"{output_path_base}_classification_accuracy.png", dpi=300)
    plt.close()

    # 6. Threshold Analysis (New)

    # For each model, analyze different confidence thresholds
    threshold_analysis = {}

    # Generate possible threshold values to test
    thresholds = np.arange(0.1, 1.0, 0.05)

    for model in all_models:
        threshold_metrics = {
            t: {"precision": 0, "recall": 0, "f1": 0, "mcc": 0, "tp": 0, "fp": 0, "fn": 0, "tn": 0}  # Added MCC
            for t in thresholds
        }

        # Extract confidence scores and true labels for this model
        confidence_scores = []
        true_labels = []

        for pub in publications_with_truth:
            score = pub.get("confidence_scores", {}).get(model, 0)
            is_true_match = 1 if model in pub.get("models", []) else 0

            confidence_scores.append(score)
            true_labels.append(is_true_match)

        # Calculate metrics at each threshold
        for threshold in thresholds:
            # Generate predictions using this threshold
            predicted_labels = [
                1 if score >= threshold else 0 for score in confidence_scores
            ]

            # Calculate performance metrics
            tp = sum(
                1 for t, p in zip(true_labels, predicted_labels) if t == 1 and p == 1
            )
            fp = sum(
                1 for t, p in zip(true_labels, predicted_labels) if t == 0 and p == 1
            )
            fn = sum(
                1 for t, p in zip(true_labels, predicted_labels) if t == 1 and p == 0
            )
            tn = sum(
                1 for t, p in zip(true_labels, predicted_labels) if t == 0 and p == 0
            )  # Added for MCC

            # Precision, recall, and F1
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = (
                2 * precision * recall / (precision + recall)
                if (precision + recall) > 0
                else 0
            )
            
            # Calculate MCC
            mcc = matthews_corrcoef(true_labels, predicted_labels)  # Added MCC calculation

            threshold_metrics[threshold] = {
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "mcc": mcc,  # Added MCC
                "tp": tp,
                "fp": fp,
                "fn": fn,
                "tn": tn,  # Added for MCC
            }

        # Find optimal threshold based on F1 score
        f1_scores = [(t, metrics["f1"]) for t, metrics in threshold_metrics.items()]
        optimal_threshold_f1 = max(f1_scores, key=lambda x: x[1])[0] if f1_scores else 0.5
        
        # Find optimal threshold based on MCC
        mcc_scores = [(t, metrics["mcc"]) for t, metrics in threshold_metrics.items()]
        optimal_threshold_mcc = max(mcc_scores, key=lambda x: x[1])[0] if mcc_scores else 0.5

        threshold_analysis[model] = {
            "metrics": threshold_metrics,
            "optimal_threshold_f1": optimal_threshold_f1,
            "optimal_f1": threshold_metrics[optimal_threshold_f1]["f1"],
            "optimal_threshold_mcc": optimal_threshold_mcc,  # Added optimal MCC threshold
            "optimal_mcc": threshold_metrics[optimal_threshold_mcc]["mcc"],  # Added optimal MCC value
            "current_threshold": MODEL_THRESHOLDS.get(model, 0.4),
        }

    # Calculate overall threshold analysis
    overall_threshold_metrics = {
        t: {"precision": 0, "recall": 0, "f1": 0, "mcc": 0} for t in thresholds  # Added MCC
    }

    # Get all confidence scores and labels across all models
    all_confidence_scores = []
    all_true_labels = []

    for pub in publications_with_truth:
        for model in all_models:
            score = pub.get("confidence_scores", {}).get(model, 0)
            is_true_match = 1 if model in pub.get("models", []) else 0

            all_confidence_scores.append(score)
            all_true_labels.append(is_true_match)

    # Calculate metrics for each threshold
    for threshold in thresholds:
        predicted_labels = [
            1 if score >= threshold else 0 for score in all_confidence_scores
        ]

        precision = precision_score(all_true_labels, predicted_labels, zero_division=0)
        recall = recall_score(all_true_labels, predicted_labels, zero_division=0)
        f1 = f1_score(all_true_labels, predicted_labels, zero_division=0)
        mcc = matthews_corrcoef(all_true_labels, predicted_labels)  # Added MCC

        overall_threshold_metrics[threshold] = {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "mcc": mcc,  # Added MCC
        }

    # Find optimal overall thresholds (F1 and MCC)
    overall_f1_scores = [
        (t, metrics["f1"]) for t, metrics in overall_threshold_metrics.items()
    ]
    overall_optimal_threshold_f1 = (
        max(overall_f1_scores, key=lambda x: x[1])[0] if overall_f1_scores else 0.5
    )
    
    overall_mcc_scores = [
        (t, metrics["mcc"]) for t, metrics in overall_threshold_metrics.items()
    ]
    overall_optimal_threshold_mcc = (
        max(overall_mcc_scores, key=lambda x: x[1])[0] if overall_mcc_scores else 0.5
    )

    threshold_analysis["overall"] = {
        "metrics": overall_threshold_metrics,
        "optimal_threshold_f1": overall_optimal_threshold_f1,
        "optimal_f1": overall_threshold_metrics[overall_optimal_threshold_f1]["f1"],
        "optimal_threshold_mcc": overall_optimal_threshold_mcc,  # Added optimal MCC threshold
        "optimal_mcc": overall_threshold_metrics[overall_optimal_threshold_mcc]["mcc"],  # Added optimal MCC value
        "current_threshold": 0.4,  # Default overall threshold
    }

    complete_metrics["threshold_analysis"] = threshold_analysis

    # Create threshold analysis visualizations
    # 1. Model-specific threshold analysis with MCC
    for model in all_models:
        fig, ax = plt.subplots(figsize=(10, 6))

        model_thresholds = sorted(list(threshold_analysis[model]["metrics"].keys()))
        precision_values = [
            threshold_analysis[model]["metrics"][t]["precision"]
            for t in model_thresholds
        ]
        recall_values = [
            threshold_analysis[model]["metrics"][t]["recall"] for t in model_thresholds
        ]
        f1_values = [
            threshold_analysis[model]["metrics"][t]["f1"] for t in model_thresholds
        ]
        mcc_values = [
            threshold_analysis[model]["metrics"][t]["mcc"] for t in model_thresholds
        ]  # Added MCC values

        ax.plot(model_thresholds, precision_values, "b-", label="Precision")
        ax.plot(model_thresholds, recall_values, "g-", label="Recall")
        ax.plot(model_thresholds, f1_values, "r-", label="F1 Score")
        ax.plot(model_thresholds, mcc_values, "m-", label="MCC")  # Added MCC plot

        # Mark current and optimal thresholds
        current_threshold = threshold_analysis[model]["current_threshold"]
        optimal_threshold_f1 = threshold_analysis[model]["optimal_threshold_f1"]
        optimal_threshold_mcc = threshold_analysis[model]["optimal_threshold_mcc"]  # Added MCC threshold

        ax.axvline(
            x=current_threshold,
            color="gray",
            linestyle="--",
            label=f"Current Threshold ({current_threshold:.2f})",
        )
        ax.axvline(
            x=optimal_threshold_f1,
            color="red",
            linestyle="-.",
            label=f"Optimal F1 Threshold ({optimal_threshold_f1:.2f})",
        )
        ax.axvline(
            x=optimal_threshold_mcc,
            color="magenta",
            linestyle="-.",
            label=f"Optimal MCC Threshold ({optimal_threshold_mcc:.2f})",
        )  # Added MCC threshold line

        ax.set_xlabel("Confidence Threshold")
        ax.set_ylabel("Score")
        ax.set_title(f"Threshold Analysis for {model}")
        ax.legend()
        ax.grid(True, linestyle="--", alpha=0.7)

        plt.tight_layout()
        plt.savefig(f"{output_path_base}_threshold_{model}.png", dpi=300)
        plt.close()

    # 2. Overall threshold analysis with MCC
    fig, ax = plt.subplots(figsize=(10, 6))

    overall_thresholds = sorted(list(threshold_analysis["overall"]["metrics"].keys()))
    precision_values = [
        threshold_analysis["overall"]["metrics"][t]["precision"]
        for t in overall_thresholds
    ]
    recall_values = [
        threshold_analysis["overall"]["metrics"][t]["recall"]
        for t in overall_thresholds
    ]
    f1_values = [
        threshold_analysis["overall"]["metrics"][t]["f1"] for t in overall_thresholds
    ]
    mcc_values = [
        threshold_analysis["overall"]["metrics"][t]["mcc"] for t in overall_thresholds
    ]  # Added MCC values

    ax.plot(overall_thresholds, precision_values, "b-", label="Precision")
    ax.plot(overall_thresholds, recall_values, "g-", label="Recall")
    ax.plot(overall_thresholds, f1_values, "r-", label="F1 Score")
    ax.plot(overall_thresholds, mcc_values, "m-", label="MCC")  # Added MCC plot

    # Mark current and optimal thresholds
    current_threshold = threshold_analysis["overall"]["current_threshold"]
    optimal_threshold_f1 = threshold_analysis["overall"]["optimal_threshold_f1"]
    optimal_threshold_mcc = threshold_analysis["overall"]["optimal_threshold_mcc"]  # Added MCC threshold

    ax.axvline(
        x=current_threshold,
        color="gray",
        linestyle="--",
        label=f"Current Threshold ({current_threshold:.2f})",
    )
    ax.axvline(
        x=optimal_threshold_f1,
        color="red",
        linestyle="-.",
        label=f"Optimal F1 Threshold ({optimal_threshold_f1:.2f})",
    )
    ax.axvline(
        x=optimal_threshold_mcc,
        color="magenta",
        linestyle="-.",
        label=f"Optimal MCC Threshold ({optimal_threshold_mcc:.2f})",
    )  # Added MCC threshold line

    ax.set_xlabel("Confidence Threshold")
    ax.set_ylabel("Score")
    ax.set_title("Overall Threshold Analysis")
    ax.legend()
    ax.grid(True, linestyle="--", alpha=0.7)

    plt.tight_layout()
    plt.savefig(f"{output_path_base}_threshold_overall.png", dpi=300)
    plt.close()

    # 3. Comparative threshold analysis with MCC
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))

    # Create a bar chart comparing current vs. optimal thresholds (F1)
    model_names = all_models
    current_thresholds = [
        threshold_analysis[model]["current_threshold"] for model in model_names
    ]
    optimal_thresholds_f1 = [
        threshold_analysis[model]["optimal_threshold_f1"] for model in model_names
    ]
    optimal_thresholds_mcc = [
        threshold_analysis[model]["optimal_threshold_mcc"] for model in model_names
    ]  # Added MCC thresholds

    x = np.arange(len(model_names))
    width = 0.25  # Adjusted for 3 sets of bars

    bars1 = ax1.bar(x - width, current_thresholds, width, label="Current Threshold")
    bars2 = ax1.bar(x, optimal_thresholds_f1, width, label="Optimal F1 Threshold")
    bars3 = ax1.bar(x + width, optimal_thresholds_mcc, width, label="Optimal MCC Threshold")  # Added MCC bars

    ax1.set_xlabel("Model")
    ax1.set_ylabel("Threshold Value")
    ax1.set_title("Current vs. Optimal Thresholds by Model")
    ax1.set_xticks(x)
    ax1.set_xticklabels(model_names, rotation=45, ha="right")
    ax1.legend()
    ax1.grid(axis="y", linestyle="--", alpha=0.7)

    # Create a bar chart comparing F1 vs MCC scores at their optimal thresholds
    optimal_f1_scores = [threshold_analysis[model]["optimal_f1"] for model in model_names]
    optimal_mcc_scores = [threshold_analysis[model]["optimal_mcc"] for model in model_names]

    bars4 = ax2.bar(x - width/2, optimal_f1_scores, width, label="Optimal F1 Score")
    bars5 = ax2.bar(x + width/2, optimal_mcc_scores, width, label="Optimal MCC Score")

    ax2.set_xlabel("Model")
    ax2.set_ylabel("Score")
    ax2.set_title("Optimal F1 vs MCC Scores by Model")
    ax2.set_xticks(x)
    ax2.set_xticklabels(model_names, rotation=45, ha="right")
    ax2.legend()
    ax2.grid(axis="y", linestyle="--", alpha=0.7)

    plt.tight_layout()
    plt.savefig(f"{output_path_base}_threshold_comparison.png", dpi=300)
    plt.close()

    # 7. Confusion Matrix Analysis

    # Create a confusion matrix for all models
    # Initialize the confusion matrix with zeros
    num_models = len(all_models)
    confusion_matrix = np.zeros((num_models, num_models), dtype=int)

    # Fill the confusion matrix
    for pub in publications_with_truth:
        true_models = pub.get("models", [])
        pred_models = pub.get("matched_models", [])

        for i, true_model in enumerate(all_models):
            for j, pred_model in enumerate(all_models):
                # Check if the true model is present in the ground truth
                # and the predicted model is in the predictions
                if true_model in true_models and pred_model in pred_models:
                    confusion_matrix[i, j] += 1

    # Store the confusion matrix in the metrics dictionary
    complete_metrics["confusion_matrix"] = {
        "matrix": confusion_matrix.tolist(),
        "model_names": all_models,
    }

    # Visualize the confusion matrix
    fig, ax = plt.subplots(figsize=(12, 10))
    im = ax.imshow(confusion_matrix, interpolation="nearest", cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)

    # Set ticks and labels
    ax.set_xticks(np.arange(num_models))
    ax.set_yticks(np.arange(num_models))
    ax.set_xticklabels(all_models, rotation=45, ha="right")
    ax.set_yticklabels(all_models)

    # Loop over data dimensions and create text annotations
    for i in range(num_models):
        for j in range(num_models):
            text_color = (
                "white"
                if confusion_matrix[i, j] > confusion_matrix.max() / 2
                else "black"
            )
            ax.text(
                j, i, confusion_matrix[i, j], ha="center", va="center", color=text_color
            )

    ax.set_xlabel("Predicted Model")
    ax.set_ylabel("True Model")
    ax.set_title("Model Confusion Matrix")

    plt.tight_layout()
    plt.savefig(f"{output_path_base}_confusion_matrix.png", dpi=300)
    plt.close()

    # 8. TP/TN/FP/FN Analysis for each model

    # Calculate these metrics for each model
    confusion_stats = {}

    for i, model in enumerate(all_models):
        # True positives, false positives, etc.
        tp = model_metrics[model]["true_positives"]
        fp = model_metrics[model]["false_positives"]
        fn = model_metrics[model]["false_negatives"]
        tn = model_metrics[model]["true_negatives"]  # Added for MCC

        confusion_stats[model] = {"tp": tp, "fp": fp, "tn": tn, "fn": fn}

    complete_metrics["confusion_stats"] = confusion_stats

    # Visualize TP/TN/FP/FN for each model
    fig, ax = plt.subplots(figsize=(14, 8))

    # Prepare data for grouped bar chart
    bar_width = 0.2
    index = np.arange(len(all_models))

    # Extract metrics for each type
    tp_values = [confusion_stats[model]["tp"] for model in all_models]
    tn_values = [confusion_stats[model]["tn"] for model in all_models]
    fp_values = [confusion_stats[model]["fp"] for model in all_models]
    fn_values = [confusion_stats[model]["fn"] for model in all_models]

    # Create bars
    bars1 = ax.bar(
        index - 1.5 * bar_width,
        tp_values,
        bar_width,
        label="True Positives",
        color="green",
    )
    bars2 = ax.bar(
        index - 0.5 * bar_width,
        tn_values,
        bar_width,
        label="True Negatives",
        color="blue",
    )
    bars3 = ax.bar(
        index + 0.5 * bar_width,
        fp_values,
        bar_width,
        label="False Positives",
        color="red",
    )
    bars4 = ax.bar(
        index + 1.5 * bar_width,
        fn_values,
        bar_width,
        label="False Negatives",
        color="orange",
    )

    # Add some text for labels, title and custom axis
    ax.set_xlabel("Model")
    ax.set_ylabel("Count")
    ax.set_title("Confusion Matrix Statistics by Model")
    ax.set_xticks(index)
    ax.set_xticklabels(all_models, rotation=45, ha="right")
    ax.legend()

    plt.tight_layout()
    plt.savefig(f"{output_path_base}_confusion_stats.png", dpi=300)
    plt.close()

    # Create a summary visualization including MCC
    fig, axs = plt.subplots(2, 2, figsize=(15, 12))

    # Top left: Model F1 scores
    if all_models:
        f1_values = [model_metrics[model]["f1"] for model in all_models]
        mcc_values = [model_metrics[model]["mcc"] for model in all_models]  # Added MCC values
        
        # Create double bar chart
        x = np.arange(len(all_models))
        width = 0.35
        
        bars1 = axs[0, 0].bar(x - width/2, f1_values, width, label="F1 Score")
        bars2 = axs[0, 0].bar(x + width/2, mcc_values, width, label="MCC")
        
        axs[0, 0].set_title("F1 Score and MCC by Model")
        axs[0, 0].set_xlabel("Model")
        axs[0, 0].set_ylabel("Score")
        axs[0, 0].set_xticks(x)
        axs[0, 0].set_xticklabels(all_models, rotation=45, ha="right")
        axs[0, 0].grid(axis="y", linestyle="--", alpha=0.7)
        axs[0, 0].legend()

        # Add labels
        for bar in bars1:
            height = bar.get_height()
            axs[0, 0].text(
                bar.get_x() + bar.get_width() / 2.0,
                height + 0.01,
                f"{height:.2f}",
                ha="center",
                va="bottom",
            )
            
        for bar in bars2:
            height = bar.get_height()
            axs[0, 0].text(
                bar.get_x() + bar.get_width() / 2.0,
                height + 0.01,
                f"{height:.2f}",
                ha="center",
                va="bottom",
            )
    else:
        axs[0, 0].text(
            0.5,
            0.5,
            "No model data available",
            ha="center",
            va="center",
            transform=axs[0, 0].transAxes,
        )

    # Top right: Sources accuracy (top 5)
    if source_accuracy:
        # Get top 5 most used sources
        top_sources = sorted(source_counts.items(), key=lambda x: x[1], reverse=True)[
            :5
        ]
        top_source_names = [s[0] for s in top_sources]

        source_accuracies = [source_accuracy[s]["accuracy"] for s in top_source_names]
        source_counts_plot = [source_counts[s] for s in top_source_names]

        bars = axs[0, 1].bar(top_source_names, source_accuracies)
        axs[0, 1].set_title("Accuracy by Top 5 Confidence Sources")
        axs[0, 1].set_xlabel("Source")
        axs[0, 1].set_ylabel("Accuracy")
        axs[0, 1].set_xticklabels(top_source_names, rotation=45, ha="right")
        axs[0, 1].grid(axis="y", linestyle="--", alpha=0.7)

        # Add count labels
        for i, bar in enumerate(bars):
            height = bar.get_height()
            axs[0, 1].text(
                bar.get_x() + bar.get_width() / 2.0,
                height + 0.01,
                f"n={source_counts_plot[i]}",
                ha="center",
                va="bottom",
            )
    else:
        axs[0, 1].text(
            0.5,
            0.5,
            "No source data available",
            ha="center",
            va="center",
            transform=axs[0, 1].transAxes,
        )

    # Bottom left: Classification accuracy
    if research_area_stats or science_keyword_stats or division_stats:
        # Create a summary of classification performance
        classification_summary = {}

        # Average accuracy by type
        if research_area_stats:
            values = [
                s["accuracy"] for s in research_area_stats.values() if s["total"] >= 5
            ]
            if values:
                classification_summary["Research Areas"] = np.mean(values)

        if science_keyword_stats:
            values = [
                s["accuracy"] for s in science_keyword_stats.values() if s["total"] >= 5
            ]
            if values:
                classification_summary["Science Keywords"] = np.mean(values)

        if division_stats:
            values = [s["accuracy"] for s in division_stats.values() if s["total"] >= 5]
            if values:
                classification_summary["Divisions"] = np.mean(values)

        if classification_summary:
            names = list(classification_summary.keys())
            values = list(classification_summary.values())

            bars = axs[1, 0].bar(names, values)
            axs[1, 0].set_title("Average Accuracy by Classification Type")
            axs[1, 0].set_xlabel("Classification Type")
            axs[1, 0].set_ylabel("Average Accuracy")
            axs[1, 0].grid(axis="y", linestyle="--", alpha=0.7)

            # Add labels
            for bar in bars:
                height = bar.get_height()
                axs[1, 0].text(
                    bar.get_x() + bar.get_width() / 2.0,
                    height + 0.01,
                    f"{height:.2f}",
                    ha="center",
                    va="bottom",
                )
        else:
            axs[1, 0].text(
                0.5,
                0.5,
                "Insufficient classification data",
                ha="center",
                va="center",
                transform=axs[1, 0].transAxes,
            )
    else:
        axs[1, 0].text(
            0.5,
            0.5,
            "No classification data available",
            ha="center",
            va="center",
            transform=axs[1, 0].transAxes,
        )

    # Bottom right: Micro-average performance metrics including MCC
    micro_metrics = {
        "Precision": micro_precision,
        "Recall": micro_recall,
        "F1 Score": micro_f1,
        "MCC": micro_mcc,  # Added MCC
    }

    bars = axs[1, 1].bar(micro_metrics.keys(), micro_metrics.values())
    axs[1, 1].set_title("Overall Performance Metrics")
    axs[1, 1].set_xlabel("Metric")
    axs[1, 1].set_ylabel("Score")
    axs[1, 1].grid(axis="y", linestyle="--", alpha=0.7)

    # Add labels
    for bar in bars:
        height = bar.get_height()
        axs[1, 1].text(
            bar.get_x() + bar.get_width() / 2.0,
            height + 0.01,
            f"{height:.3f}",
            ha="center",
            va="bottom",
        )

    plt.tight_layout()
    plt.savefig(f"{output_path_base}_summary.png", dpi=300)
    plt.close()

    # Save all metrics to JSON
    with open(f"{output_path_base}_complete.json", "w") as f:
        # Convert NumPy values to Python types for JSON serialization
        def convert_numpy(obj):
            if isinstance(obj, np.integer):
                return int(obj)
            elif isinstance(obj, np.floating):
                return float(obj)
            elif isinstance(obj, np.ndarray):
                return obj.tolist()
            elif isinstance(obj, dict):
                return {k: convert_numpy(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [convert_numpy(i) for i in obj]
            else:
                return obj

        json.dump(convert_numpy(complete_metrics), f, indent=2)

    print(
        f"Comprehensive metrics visualizations saved with base path: {output_path_base}"
    )

    # Return all metrics for further analysis
    return complete_metrics

In [42]:
def main():
    """Run the end-to-end science publication classification pipeline.

    Steps, in order:
      1. Report CUDA availability (device name and currently allocated memory).
      2. Build the ``ScienceClassifier`` singleton and load the curated model
         mapping, model keywords and model descriptions from JSON files in the
         current working directory.
      3. Build the ``RelevanceRanker``, the model embeddings and the
         ``ModelContextManager``, and derive data-driven affinities from
         ``./curated_publications.json``.
      4. Classify the publications in ``./all_2025-2-19.json`` with
         ``process_publication_batch`` and write them to ``results.json``.
      5. Generate metric visualizations, print the TF-IDF model-specific
         terms, and compare performance at the current vs. optimal
         confidence thresholds.

    Returns:
        None. If the test publications cannot be loaded, the error is printed
        and the function returns early. Errors raised during threshold
        optimization are printed and do not abort the run.
    """
    start_time = time.time()

    print("Starting optimized science publication classifier...")
    print(f"CUDA available: {torch.cuda.is_available()}")

    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")
        print(f"CUDA memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")

    # Initialize classifier
    print("Initializing science classifier...")
    with torch.inference_mode():
        science_classifier = ScienceClassifier.get_instance()
    print("Science classifier initialized successfully.")

    # Load configuration files
    print("Loading configuration files...")
    curated_mapping = load_curated_models('./curated_publications.json')
    model_keywords = load_model_keywords('./model_keywords.json')
    model_descriptions = load_model_descriptions('./model_descriptions.json')

    # Initialize ranker
    print("Initializing relevance ranker...")
    with torch.inference_mode():
        ranker = RelevanceRanker(model_descriptions)

    # Initialize model embeddings
    print("Initializing model embeddings...")
    with torch.inference_mode():
        model_embeddings = initialize_model_embeddings(model_descriptions)

    # Clean memory
    optimize_memory()

    # Load test data. An explicit encoding keeps decoding independent of the
    # platform's locale codepage (e.g. cp1252 on Windows).
    print("Loading test publications...")
    try:
        with open('./all_2025-2-19.json', encoding="utf-8") as f:
            test_data = json.load(f)
        print(f"Loaded {len(test_data)} test publications.")
    except Exception as e:
        print(f"Error loading test data: {str(e)}")
        return

    # Initialize context manager
    print("Building context validation profiles...")
    with open('./curated_publications.json', encoding="utf-8") as f:
        full_curated = json.load(f)

    with torch.inference_mode():
        context_manager = ModelContextManager(full_curated)

    # Derive data-driven affinities from the curated dataset
    print("Deriving data-driven affinities from curated dataset...")
    derive_data_driven_affinities('./curated_publications.json')

    # Process publications
    print("\nProcessing publications...")
    results = process_publication_batch(
        test_data,
        curated_mapping,
        model_keywords,
        model_embeddings,
        science_classifier,
        context_manager,
        ranker=ranker
    )

    # Save results
    print("Saving results to results.json...")
    with open('results.json', 'w') as f:
        json.dump(results, f, indent=2)

    # Generate visualizations with threshold analysis
    print("\nGenerating visualizations with threshold analysis...")
    metrics = visualize_metrics(results)

    # Show TF-IDF model-specific terms
    print("\nModel-specific terminology based on TF-IDF analysis:")
    model_specific_terms = context_manager.get_model_specific_terms()
    for model, terms in model_specific_terms.items():
        print(f"\n{model} distinctive terms:")
        print(", ".join(terms[:10]))  # Show top 10 terms

    # Find optimal thresholds. This step is best-effort: failures are reported
    # but do not prevent the timing summary below from being printed.
    print("\nFinding optimal thresholds...")
    try:
        optimal_thresholds = find_optimal_thresholds(results)

        print("\nOptimal model-specific thresholds:")
        for model, data in optimal_thresholds['per_model'].items():
            current = MODEL_THRESHOLDS.get(model, 0.4)
            # Use threshold_f1 instead of threshold
            print(f"{model}: {data['threshold_f1']:.2f} (current: {current:.2f}, F1: {data['f1']:.2f})")

        print(f"\nOptimal overall threshold: {optimal_thresholds['overall']['threshold_f1']:.2f}")
        print(f"Overall F1 score with optimal thresholds: {optimal_thresholds['overall']['f1']:.3f}")

        # Compare performance with current vs. optimal thresholds
        print("\nComparing performance with current vs. optimal thresholds:")
        current_performance = analyze_threshold_performance(results)
        optimal_performance = analyze_threshold_performance(
            results,
            model_thresholds=optimal_thresholds['model_thresholds_f1'],
            overall_threshold=optimal_thresholds['overall']['threshold_f1']
        )

        current_f1 = current_performance['overall']['f1']
        optimal_f1 = optimal_performance['overall']['f1']
        # The F1 difference scaled by 100 is an absolute change in percentage
        # points, not a relative percent improvement, so label it as such.
        print(f"Current F1: {current_f1:.3f}, " +
            f"Optimal F1: {optimal_f1:.3f}, " +
            f"Improvement: {(optimal_f1 - current_f1) * 100:.1f} percentage points")
    except Exception as e:
        print(f"Error during threshold optimization: {str(e)}")

    # Report completion
    total_time = time.time() - start_time
    print(f"\nProcessing completed in {total_time:.2f}s ({total_time/60:.2f} minutes)")
    # Guard against an empty test set so the summary does not crash with
    # ZeroDivisionError after all the work has already been done.
    if len(test_data) > 0:
        print(f"Average time per publication: {total_time/len(test_data):.4f}s")
    else:
        print("Average time per publication: n/a (no test publications)")

    if torch.cuda.is_available():
        print(f"Peak GPU memory usage: {torch.cuda.max_memory_allocated(0) / 1024**2:.2f} MB")

In [43]:
# if __name__ == "__main__":
#    main()

In [44]:
# Evaluating model generalization to unseen data

def evaluate_generalization(
    results: List[Dict], test_split_ratio: float = 0.3, random_seed: int = 42
) -> Dict:
    """Compare classifier metrics on a random training/testing split of labelled results.

    Steps:
      1. keep records with a non-empty "models" list and a "confidence_scores" entry,
      2. shuffle them reproducibly and hold out ``test_split_ratio`` of them as the
         testing subset,
      3. score both subsets with calculate_set_metrics and report test-minus-train
         differences, confidence distributions and a qualitative gap assessment,
      4. save comparison figures via visualize_generalization_metrics.

    NOTE: both subsets are scored with the same fixed thresholds and nothing is
    fitted on the training subset, so the "gap" measures how stable the metrics
    are across random samples rather than overfitting in the strict sense.

    Args:
        results: prediction records.
        test_split_ratio: fraction of records held out for testing; must be
            strictly between 0 and 1.
        random_seed: seed for the shuffle so the split is reproducible.

    Returns:
        Dict with "training", "testing", "difference", "confidence_analysis" and
        "generalization_gap" entries, or {"error": <message>} when there is no
        labelled data or the split would leave either subset empty.
    """
    # Extract publications with ground truth
    publications_with_truth = [
        result
        for result in results
        if "models" in result and result["models"] and "confidence_scores" in result
    ]

    if not publications_with_truth:
        return {"error": "No publications with ground truth for evaluation"}

    if not 0.0 < test_split_ratio < 1.0:
        return {
            "error": f"test_split_ratio must be strictly between 0 and 1, got {test_split_ratio}"
        }

    # Local generator: yields the same permutation as np.random.seed(seed) followed
    # by np.random.shuffle, without resetting the notebook-wide global RNG.
    rng = np.random.RandomState(random_seed)
    rng.shuffle(publications_with_truth)

    # Split data into training and testing sets
    split_idx = int(len(publications_with_truth) * (1 - test_split_ratio))
    train_data = publications_with_truth[:split_idx]
    test_data = publications_with_truth[split_idx:]

    # An empty subset would produce all-zero metrics and a misleading gap verdict.
    if not train_data or not test_data:
        return {
            "error": f"Not enough publications ({len(publications_with_truth)}) "
            f"to create non-empty training and testing sets"
        }

    print(
        f"Split data into {len(train_data)} training and {len(test_data)} testing publications"
    )

    # Get all unique ground-truth models, in a stable order
    all_models = sorted(
        {model for pub in publications_with_truth for model in pub.get("models", [])}
    )

    # Calculate metrics for each set
    generalization_metrics = {
        "training": calculate_set_metrics(train_data, all_models),
        "testing": calculate_set_metrics(test_data, all_models),
        "difference": {},
    }

    # Test-minus-train differences per model
    for model in all_models:
        if (
            model in generalization_metrics["training"]["per_model"]
            and model in generalization_metrics["testing"]["per_model"]
        ):
            train_metrics = generalization_metrics["training"]["per_model"][model]
            test_metrics = generalization_metrics["testing"]["per_model"][model]

            generalization_metrics["difference"][model] = {
                "precision_diff": test_metrics["precision"]
                - train_metrics["precision"],
                "recall_diff": test_metrics["recall"] - train_metrics["recall"],
                "f1_diff": test_metrics["f1"] - train_metrics["f1"],
            }

    # Overall (micro-averaged) difference
    train_overall = generalization_metrics["training"]["overall"]
    test_overall = generalization_metrics["testing"]["overall"]

    generalization_metrics["difference"]["overall"] = {
        "precision_diff": test_overall["precision"] - train_overall["precision"],
        "recall_diff": test_overall["recall"] - train_overall["recall"],
        "f1_diff": test_overall["f1"] - train_overall["f1"],
    }

    # Confidence score distribution for correct and incorrect decisions
    generalization_metrics["confidence_analysis"] = analyze_confidence_distribution(
        train_data, test_data, all_models
    )

    # Generalization gap metrics
    generalization_gap = abs(train_overall["f1"] - test_overall["f1"])
    generalization_metrics["generalization_gap"] = {
        "f1_gap": generalization_gap,
        "relative_gap": generalization_gap / train_overall["f1"]
        if train_overall["f1"] > 0
        else float("inf"),
        "gap_assessment": assess_generalization_gap(generalization_gap),
    }

    # Generate visualizations
    visualize_generalization_metrics(generalization_metrics, all_models)

    return generalization_metrics


def calculate_set_metrics(
    data: List[Dict],
    all_models: List[str],
    model_thresholds: Optional[Dict[str, float]] = None,
    default_threshold: float = 0.4,
) -> Dict:
    """Score thresholded predictions against ground truth for one set of publications.

    For every (publication, model) pair, the model counts as *predicted* when its
    confidence (a missing score counts as 0) is >= its threshold, and as a *true
    match* when it appears in the publication's "models" list.

    Args:
        data: publication records with "models" and "confidence_scores" entries.
        all_models: models to evaluate.
        model_thresholds: per-model thresholds; when None the notebook-wide
            MODEL_THRESHOLDS mapping is used, so existing callers are unaffected.
        default_threshold: threshold for models absent from the threshold mapping.

    Returns:
        {"per_model": {model: {"precision", "recall", "f1", "true_positives",
        "false_positives", "false_negatives"}},
         "overall": {"precision", "recall", "f1"}}
        where "overall" is micro-averaged over all (publication, model) pairs.
        Ratios with a zero denominator are reported as 0 (sklearn zero_division=0).
    """
    thresholds = MODEL_THRESHOLDS if model_thresholds is None else model_thresholds

    def _precision_recall_f1(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
        # Same values as sklearn's binary precision/recall/f1 with zero_division=0.
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * tp) / (2 * tp + fp + fn) if tp else 0.0
        return precision, recall, f1

    # Single pass over the data: confusion counts per model.
    unique_models = list(dict.fromkeys(all_models))
    counts = {model: {"tp": 0, "fp": 0, "fn": 0} for model in unique_models}
    for pub in data:
        true_models = set(pub.get("models") or [])
        scores = pub.get("confidence_scores") or {}
        for model in unique_models:
            predicted = scores.get(model, 0) >= thresholds.get(model, default_threshold)
            actual = model in true_models
            if predicted and actual:
                counts[model]["tp"] += 1
            elif predicted:
                counts[model]["fp"] += 1
            elif actual:
                counts[model]["fn"] += 1

    model_metrics = {}
    for model in unique_models:
        c = counts[model]
        precision, recall, f1 = _precision_recall_f1(c["tp"], c["fp"], c["fn"])
        model_metrics[model] = {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "true_positives": c["tp"],
            "false_positives": c["fp"],
            "false_negatives": c["fn"],
        }

    # Micro-average: pool the counts of every (publication, model) pair.
    total_tp = sum(counts[model]["tp"] for model in all_models)
    total_fp = sum(counts[model]["fp"] for model in all_models)
    total_fn = sum(counts[model]["fn"] for model in all_models)
    micro_precision, micro_recall, micro_f1 = _precision_recall_f1(
        total_tp, total_fp, total_fn
    )

    return {
        "per_model": model_metrics,
        "overall": {
            "precision": micro_precision,
            "recall": micro_recall,
            "f1": micro_f1,
        },
    }


def analyze_confidence_distribution(
    train_data: List[Dict], test_data: List[Dict], all_models: List[str]
) -> Dict:
    """Summarise confidence scores of correct vs. incorrect threshold decisions.

    A decision for a (publication, model) pair is "correct" when
    ``confidence >= MODEL_THRESHOLDS.get(model, 0.4)`` agrees with whether the
    model appears in the publication's "models" list; missing scores count as 0.

    Returns:
        {"train": {...}, "test": {...}, "diff": {...}} where each split holds an
        "all_models" summary and a "per_model" summary (count/mean/median/std for
        "correct" and "incorrect"; statistics of an empty group are 0), and "diff"
        holds test-minus-train differences of those means and medians.
    """

    def summarize(scores: List[float]) -> Dict:
        # Count plus centre/spread statistics; 0 placeholders for an empty group.
        if not scores:
            return {"count": 0, "mean": 0, "median": 0, "std": 0}
        return {
            "count": len(scores),
            "mean": np.mean(scores),
            "median": np.median(scores),
            "std": np.std(scores),
        }

    outcome_buckets = ("correct", "incorrect")
    result = {"train": {}, "test": {}}

    for split_name, split_rows in (("train", train_data), ("test", test_data)):
        pooled = {bucket: [] for bucket in outcome_buckets}
        by_model = {name: {bucket: [] for bucket in outcome_buckets} for name in all_models}

        for pub in split_rows:
            true_models = set(pub.get("models", []))
            scores = pub.get("confidence_scores", {})
            for name in all_models:
                score = scores.get(name, 0)
                hit = score >= MODEL_THRESHOLDS.get(name, 0.4)
                bucket = "correct" if hit == (name in true_models) else "incorrect"
                pooled[bucket].append(score)
                by_model[name][bucket].append(score)

        result[split_name]["all_models"] = {
            bucket: summarize(pooled[bucket]) for bucket in outcome_buckets
        }
        result[split_name]["per_model"] = {
            name: {bucket: summarize(by_model[name][bucket]) for bucket in outcome_buckets}
            for name in all_models
        }

    pooled_train = result["train"]["all_models"]
    pooled_test = result["test"]["all_models"]
    pooled_specs = (
        ("correct_mean_diff", "correct", "mean"),
        ("incorrect_mean_diff", "incorrect", "mean"),
        ("correct_median_diff", "correct", "median"),
        ("incorrect_median_diff", "incorrect", "median"),
    )
    model_specs = (
        ("correct_mean_diff", "correct"),
        ("incorrect_mean_diff", "incorrect"),
    )

    result["diff"] = {
        "all_models": {
            key: pooled_test[bucket][stat] - pooled_train[bucket][stat]
            for key, bucket, stat in pooled_specs
        },
        "per_model": {
            name: {
                key: result["test"]["per_model"][name][bucket]["mean"]
                - result["train"]["per_model"][name][bucket]["mean"]
                for key, bucket in model_specs
            }
            for name in all_models
        },
    }

    return result


def assess_generalization_gap(gap: float) -> str:
    """Map an absolute F1 gap to a qualitative verdict.

    Bands are half-open: [0, 0.03) excellent, [0.03, 0.05) good, [0.05, 0.1) fair,
    [0.1, 0.2) poor, anything else (including NaN) very poor.
    """
    bands = (
        (0.03, "Excellent generalization - model performs consistently on unseen data"),
        (0.05, "Good generalization - minimal performance degradation on unseen data"),
        (0.1, "Fair generalization - noticeable performance drop on unseen data"),
        (0.2, "Poor generalization - significant performance drop on unseen data"),
    )
    for upper_bound, verdict in bands:
        if gap < upper_bound:
            return verdict
    return "Very poor generalization - model is likely overfitting"


def visualize_generalization_metrics(
    metrics: Dict, all_models: List[str], output_dir: str = "./metrics_visualization"
):
    """Save three PNG figures comparing training- and testing-subset performance.

    Files written to ``output_dir`` (created if it does not exist):
      * generalization_overall.png     - overall P/R/F1 and per-model F1, train vs test
      * generalization_gap.png         - per-model F1 difference (test - train) and mean
                                         confidence of correct/incorrect decisions
      * generalization_calibration.png - calibration curve and samples per bin

    Args:
        metrics: dict with "training", "testing", "difference" and
            "confidence_analysis" entries, as built by evaluate_generalization.
        all_models: models to plot; each must appear in both "per_model" entries
            and in metrics["difference"].
        output_dir: directory for the PNG files (the default is the original location).

    NOTE: the calibration panel passes the aggregated metrics["training"] /
    metrics["testing"] dicts to calculate_calibration. Those dicts hold per-model
    counts, not per-prediction confidences, so that panel can only be as
    informative as calculate_calibration can make it from aggregated counts.
    """
    os.makedirs(output_dir, exist_ok=True)
    width = 0.35

    # 1. Train vs Test overall performance and per-model F1
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    metric_keys = ["precision", "recall", "f1"]
    metrics_names = ["Precision", "Recall", "F1 Score"]
    train_values = [metrics["training"]["overall"][key] for key in metric_keys]
    test_values = [metrics["testing"]["overall"][key] for key in metric_keys]

    x = np.arange(len(metrics_names))
    ax1.bar(x - width / 2, train_values, width, label="Training")
    ax1.bar(x + width / 2, test_values, width, label="Testing")

    ax1.set_xlabel("Metrics")
    ax1.set_ylabel("Score")
    ax1.set_title("Overall Metrics: Training vs Testing")
    ax1.set_xticks(x)
    ax1.set_xticklabels(metrics_names)
    ax1.legend()
    ax1.grid(axis="y", linestyle="--", alpha=0.7)

    # Value labels above each bar
    for i, v in enumerate(train_values):
        ax1.text(i - width / 2, v + 0.01, f"{v:.3f}", ha="center", va="bottom")
    for i, v in enumerate(test_values):
        ax1.text(i + width / 2, v + 0.01, f"{v:.3f}", ha="center", va="bottom")

    train_f1 = [metrics["training"]["per_model"][model]["f1"] for model in all_models]
    test_f1 = [metrics["testing"]["per_model"][model]["f1"] for model in all_models]
    model_positions = np.arange(len(all_models))

    ax2.bar(model_positions - width / 2, train_f1, width, label="Training")
    ax2.bar(model_positions + width / 2, test_f1, width, label="Testing")

    ax2.set_xlabel("Models")
    ax2.set_ylabel("F1 Score")
    ax2.set_title("F1 Score by Model: Training vs Testing")
    ax2.set_xticks(model_positions)
    ax2.set_xticklabels(all_models, rotation=45, ha="right")
    ax2.legend()
    ax2.grid(axis="y", linestyle="--", alpha=0.7)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "generalization_overall.png"), dpi=300)
    plt.close(fig)

    # 2. Generalization gap analysis
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    f1_diffs = [metrics["difference"][model]["f1_diff"] for model in all_models]
    colors = ["green" if diff >= 0 else "red" for diff in f1_diffs]
    # Numeric positions + explicit ticks: set_xticklabels needs fixed tick locations.
    ax1.bar(model_positions, f1_diffs, color=colors)

    ax1.set_xlabel("Models")
    ax1.set_ylabel("F1 Score Difference (Test - Train)")
    ax1.set_title("F1 Score Difference by Model")
    ax1.set_xticks(model_positions)
    ax1.set_xticklabels(all_models, rotation=45, ha="right")
    ax1.axhline(y=0, color="black", linestyle="-", alpha=0.3)
    ax1.grid(axis="y", linestyle="--", alpha=0.7)

    # Value labels above positive bars, below negative ones
    for i, v in enumerate(f1_diffs):
        va = "bottom" if v >= 0 else "top"
        ax1.text(i, v + 0.01 if v >= 0 else v - 0.01, f"{v:.3f}", ha="center", va=va)

    confidence = metrics["confidence_analysis"]
    confidence_data = [
        confidence["train"]["all_models"]["correct"]["mean"],
        confidence["train"]["all_models"]["incorrect"]["mean"],
        confidence["test"]["all_models"]["correct"]["mean"],
        confidence["test"]["all_models"]["incorrect"]["mean"],
    ]
    labels = ["Train Correct", "Train Incorrect", "Test Correct", "Test Incorrect"]

    ax2.bar(labels, confidence_data)

    ax2.set_xlabel("Prediction Category")
    ax2.set_ylabel("Mean Confidence Score")
    ax2.set_title("Confidence Score Analysis")
    ax2.grid(axis="y", linestyle="--", alpha=0.7)

    for i, v in enumerate(confidence_data):
        ax2.text(i, v + 0.01, f"{v:.3f}", ha="center", va="bottom")

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "generalization_gap.png"), dpi=300)
    plt.close(fig)

    # 3. Calibration analysis (confidence vs accuracy)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    bins = np.linspace(0, 1, 11)  # 10 bins: 0-0.1, 0.1-0.2, ..., 0.9-1.0

    train_bin_accuracies, train_bin_counts = calculate_calibration(
        metrics["training"], all_models, bins
    )
    test_bin_accuracies, test_bin_counts = calculate_calibration(
        metrics["testing"], all_models, bins
    )

    bin_centers = (bins[:-1] + bins[1:]) / 2

    ax1.plot([0, 1], [0, 1], "k--", label="Perfect Calibration")
    ax1.plot(bin_centers, train_bin_accuracies, "o-", label="Training")
    ax1.plot(bin_centers, test_bin_accuracies, "s-", label="Testing")

    ax1.set_xlabel("Mean Confidence")
    ax1.set_ylabel("Accuracy")
    ax1.set_title("Calibration Curve (Confidence vs Accuracy)")
    ax1.legend()
    ax1.grid(True, linestyle="--", alpha=0.7)

    x = np.arange(len(bin_centers))
    ax2.bar(x - width / 2, train_bin_counts, width, label="Training")
    ax2.bar(x + width / 2, test_bin_counts, width, label="Testing")

    ax2.set_xlabel("Confidence Bin")
    ax2.set_ylabel("Count")
    ax2.set_title("Samples per Confidence Bin")
    ax2.set_xticks(x)
    # Label each bin by its edges (labelling by centre +0.1 produced e.g. "0.2-0.2").
    bin_labels = [f"{lo:.1f}-{hi:.1f}" for lo, hi in zip(bins[:-1], bins[1:])]
    ax2.set_xticklabels(bin_labels, rotation=45)
    ax2.legend()
    ax2.grid(axis="y", linestyle="--", alpha=0.7)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "generalization_calibration.png"), dpi=300)
    plt.close(fig)


def calculate_calibration(
    data: Union[Dict, List[Dict]], all_models: List[str], bins: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Bin (confidence, outcome) points; return per-bin mean outcome and per-bin counts.

    Two input forms are accepted:

    * A list of publication records (each with "models" and "confidence_scores").
      Every (publication, model) pair for ``model`` in ``all_models`` contributes
      one point. Its confidence is ``confidence_scores.get(model, 0)``. Its outcome
      is 1 when the model is in the record's ground-truth "models" list, else 0.
      Each bin's value is therefore the observed fraction of true matches, which
      is a reliability diagram.
    * A metrics dict with a "per_model" entry of true/false positive/negative
      counts (legacy behaviour). It holds no per-prediction confidences, so every
      model with TP+FP+FN >= 5 contributes a single point at a placeholder
      confidence of 0.5 with outcome TP / (TP + FP + FN). ``all_models`` is
      ignored in this form, as before.

    Args:
        data: publication records or an aggregated metrics dict (see above).
        all_models: models to include when ``data`` is a list of records.
        bins: increasing bin edges. Bins are [lo, hi) except the last, which is
            closed so that a confidence exactly equal to bins[-1] is counted.

    Returns:
        (bin_values, bin_counts), both of length len(bins) - 1; empty bins hold 0.
    """
    bins = np.asarray(bins, dtype=float)
    n_bins = len(bins) - 1
    bin_accuracies = np.zeros(n_bins)
    bin_counts = np.zeros(n_bins)

    if isinstance(data, dict):
        confidences, outcomes = _calibration_points_from_metrics(data)
    else:
        confidences, outcomes = _calibration_points_from_publications(data, all_models)

    confidences = np.asarray(confidences, dtype=float)
    outcomes = np.asarray(outcomes, dtype=float)

    for i in range(n_bins):
        if i == n_bins - 1:
            below_upper = confidences <= bins[i + 1]  # last bin includes its right edge
        else:
            below_upper = confidences < bins[i + 1]
        in_bin = (confidences >= bins[i]) & below_upper
        count = int(np.sum(in_bin))
        if count:
            bin_accuracies[i] = outcomes[in_bin].mean()
            bin_counts[i] = count

    return bin_accuracies, bin_counts


def _calibration_points_from_metrics(metrics: Dict) -> Tuple[List[float], List[float]]:
    # Legacy path: one placeholder-confidence point per model with >= 5 TP+FP+FN.
    confidences, outcomes = [], []
    for model_data in metrics["per_model"].values():
        tp = model_data["true_positives"]
        fp = model_data["false_positives"]
        fn = model_data["false_negatives"]
        support = tp + fp + fn
        if support < 5:  # skip models with insufficient data
            continue
        confidences.append(0.5)  # placeholder: no per-prediction confidences available
        outcomes.append(tp / support)
    return confidences, outcomes


def _calibration_points_from_publications(
    publications: List[Dict], all_models: List[str]
) -> Tuple[List[float], List[float]]:
    # One (confidence, is-true-match) point per (publication, model) pair.
    confidences, outcomes = [], []
    for pub in publications:
        true_models = set(pub.get("models") or [])
        scores = pub.get("confidence_scores") or {}
        for model in all_models:
            confidences.append(float(scores.get(model, 0)))
            outcomes.append(1.0 if model in true_models else 0.0)
    return confidences, outcomes

In [45]:
# Function to run cross-validation evaluation for more robust generalization assessment
def evaluate_with_cross_validation(
    results: List[Dict], n_folds: int = 15, random_seed: int = 42
) -> Dict:
    # Evaluate model generalization using k-fold cross-validation

    # results: List of prediction results with true models and confidence scores
    # n_folds: Number of folds for cross-validation
    # random_seed: Seed for reproducibility

    # Set random seed
    np.random.seed(random_seed)

    # Extract publications with ground truth
    publications_with_truth = [
        result
        for result in results
        if "models" in result and result["models"] and "confidence_scores" in result
    ]

    if not publications_with_truth:
        return {"error": "No publications with ground truth for evaluation"}

    # Get all unique models
    all_models = set()
    for pub in publications_with_truth:
        all_models.update(pub.get("models", []))
    all_models = sorted(list(all_models))

    # Shuffle the data
    np.random.shuffle(publications_with_truth)

    # Prepare folds
    fold_size = len(publications_with_truth) // n_folds
    folds = []

    for i in range(n_folds):
        start_idx = i * fold_size
        end_idx = (
            start_idx + fold_size if i < n_folds - 1 else len(publications_with_truth)
        )
        folds.append(publications_with_truth[start_idx:end_idx])

    print(f"Created {n_folds} folds with ~{fold_size} publications each")

    # Cross-validation results
    fold_metrics = []

    for i in range(n_folds):
        print(f"Processing fold {i+1}/{n_folds}...")

        # Create test and train sets
        test_fold = folds[i]
        train_folds = [fold for j, fold in enumerate(folds) if j != i]
        train_data = [item for fold in train_folds for item in fold]  # Flatten

        # Calculate metrics
        train_metrics = calculate_set_metrics(train_data, all_models)
        test_metrics = calculate_set_metrics(test_fold, all_models)

        # Calculate differences
        differences = {}
        for model in all_models:
            if (
                model in train_metrics["per_model"]
                and model in test_metrics["per_model"]
            ):
                train_model = train_metrics["per_model"][model]
                test_model = test_metrics["per_model"][model]

                differences[model] = {
                    "precision_diff": test_model["precision"]
                    - train_model["precision"],
                    "recall_diff": test_model["recall"] - train_model["recall"],
                    "f1_diff": test_model["f1"] - train_model["f1"],
                }

        # Overall difference
        train_overall = train_metrics["overall"]
        test_overall = test_metrics["overall"]

        differences["overall"] = {
            "precision_diff": test_overall["precision"] - train_overall["precision"],
            "recall_diff": test_overall["recall"] - train_overall["recall"],
            "f1_diff": test_overall["f1"] - train_overall["f1"],
        }

        # Add to fold metrics
        fold_metrics.append(
            {
                "train": train_metrics,
                "test": test_metrics,
                "diff": differences,
                "fold_idx": i,
            }
        )

    # Aggregate results across folds
    cv_results = {
        "per_fold": fold_metrics,
        "aggregated": {
            "train": {
                "overall": {
                    "precision": np.mean(
                        [fold["train"]["overall"]["precision"] for fold in fold_metrics]
                    ),
                    "recall": np.mean(
                        [fold["train"]["overall"]["recall"] for fold in fold_metrics]
                    ),
                    "f1": np.mean(
                        [fold["train"]["overall"]["f1"] for fold in fold_metrics]
                    ),
                },
                "per_model": {},
            },
            "test": {
                "overall": {
                    "precision": np.mean(
                        [fold["test"]["overall"]["precision"] for fold in fold_metrics]
                    ),
                    "recall": np.mean(
                        [fold["test"]["overall"]["recall"] for fold in fold_metrics]
                    ),
                    "f1": np.mean(
                        [fold["test"]["overall"]["f1"] for fold in fold_metrics]
                    ),
                },
                "per_model": {},
            },
            "diff": {
                "overall": {
                    "precision_diff": np.mean(
                        [
                            fold["diff"]["overall"]["precision_diff"]
                            for fold in fold_metrics
                        ]
                    ),
                    "recall_diff": np.mean(
                        [
                            fold["diff"]["overall"]["recall_diff"]
                            for fold in fold_metrics
                        ]
                    ),
                    "f1_diff": np.mean(
                        [fold["diff"]["overall"]["f1_diff"] for fold in fold_metrics]
                    ),
                },
                "per_model": {},
            },
        },
    }

    # Aggregate per-model metrics
    for model in all_models:
        # Check if model exists in all folds
        if all(model in fold["train"]["per_model"] for fold in fold_metrics) and all(
            model in fold["test"]["per_model"] for fold in fold_metrics
        ):
            cv_results["aggregated"]["train"]["per_model"][model] = {
                "precision": np.mean(
                    [
                        fold["train"]["per_model"][model]["precision"]
                        for fold in fold_metrics
                    ]
                ),
                "recall": np.mean(
                    [
                        fold["train"]["per_model"][model]["recall"]
                        for fold in fold_metrics
                    ]
                ),
                "f1": np.mean(
                    [fold["train"]["per_model"][model]["f1"] for fold in fold_metrics]
                ),
            }

            cv_results["aggregated"]["test"]["per_model"][model] = {
                "precision": np.mean(
                    [
                        fold["test"]["per_model"][model]["precision"]
                        for fold in fold_metrics
                    ]
                ),
                "recall": np.mean(
                    [
                        fold["test"]["per_model"][model]["recall"]
                        for fold in fold_metrics
                    ]
                ),
                "f1": np.mean(
                    [fold["test"]["per_model"][model]["f1"] for fold in fold_metrics]
                ),
            }

            cv_results["aggregated"]["diff"]["per_model"][model] = {
                "precision_diff": np.mean(
                    [fold["diff"][model]["precision_diff"] for fold in fold_metrics]
                ),
                "recall_diff": np.mean(
                    [fold["diff"][model]["recall_diff"] for fold in fold_metrics]
                ),
                "f1_diff": np.mean(
                    [fold["diff"][model]["f1_diff"] for fold in fold_metrics]
                ),
            }

    # Calculate generalization gap
    generalization_gap = abs(
        cv_results["aggregated"]["train"]["overall"]["f1"]
        - cv_results["aggregated"]["test"]["overall"]["f1"]
    )

    cv_results["generalization_gap"] = {
        "f1_gap": generalization_gap,
        "relative_gap": generalization_gap
        / cv_results["aggregated"]["train"]["overall"]["f1"]
        if cv_results["aggregated"]["train"]["overall"]["f1"] > 0
        else float("inf"),
        "gap_assessment": assess_generalization_gap(generalization_gap),
    }

    # Visualize cross-validation results
    visualize_cross_validation_results(cv_results, all_models, n_folds)

    return cv_results


def visualize_cross_validation_results(
    cv_results: Dict,
    all_models: List[str],
    n_folds: int,
    output_dir: str = "./metrics_visualization",
):
    """Save three PNG figures summarising a cross-validation run.

    Files written to ``output_dir`` (created if it does not exist):
      * cv_folds.png    - overall train/test F1 per fold and their difference
      * cv_models.png   - aggregated train vs test F1 per model, sorted by |gap|
      * cv_variance.png - variance of overall precision/recall/F1 across folds

    Args:
        cv_results: dict with "per_fold" and "aggregated" entries, as built by
            evaluate_with_cross_validation.
        all_models: candidate models for the per-model figure. Only models present
            in both aggregated train and test "per_model" entries are plotted.
        n_folds: number of folds, used for the x-axis of the per-fold figure.
            Assumed to equal len(cv_results["per_fold"]).
        output_dir: directory for the PNG files (the default is the original location).
    """
    os.makedirs(output_dir, exist_ok=True)
    per_fold = cv_results["per_fold"]
    aggregated = cv_results["aggregated"]
    width = 0.35

    # 1. Fold comparison for overall F1 scores
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    fold_indices = list(range(1, n_folds + 1))
    train_f1_by_fold = [fold["train"]["overall"]["f1"] for fold in per_fold]
    test_f1_by_fold = [fold["test"]["overall"]["f1"] for fold in per_fold]
    f1_diff_by_fold = [fold["diff"]["overall"]["f1_diff"] for fold in per_fold]

    ax1.bar(
        [i - width / 2 for i in fold_indices], train_f1_by_fold, width, label="Train F1"
    )
    ax1.bar(
        [i + width / 2 for i in fold_indices], test_f1_by_fold, width, label="Test F1"
    )

    ax1.set_xlabel("Fold")
    ax1.set_ylabel("F1 Score")
    ax1.set_title("F1 Score by Fold")
    ax1.set_xticks(fold_indices)
    ax1.grid(axis="y", linestyle="--", alpha=0.7)

    mean_train_f1 = np.mean(train_f1_by_fold)
    mean_test_f1 = np.mean(test_f1_by_fold)
    ax1.axhline(
        y=mean_train_f1,
        color="blue",
        linestyle="--",
        alpha=0.7,
        label=f"Mean Train F1: {mean_train_f1:.3f}",
    )
    ax1.axhline(
        y=mean_test_f1,
        color="orange",
        linestyle="--",
        alpha=0.7,
        label=f"Mean Test F1: {mean_test_f1:.3f}",
    )
    # Build the legend once, after every labelled artist has been added.
    ax1.legend()

    colors = ["green" if diff >= 0 else "red" for diff in f1_diff_by_fold]
    ax2.bar(fold_indices, f1_diff_by_fold, color=colors)

    ax2.set_xlabel("Fold")
    ax2.set_ylabel("F1 Difference (Test - Train)")
    ax2.set_title("F1 Score Difference by Fold")
    ax2.set_xticks(fold_indices)
    ax2.axhline(y=0, color="black", linestyle="-", alpha=0.3)
    ax2.grid(axis="y", linestyle="--", alpha=0.7)

    mean_diff = np.mean(f1_diff_by_fold)
    ax2.axhline(
        y=mean_diff,
        color="purple",
        linestyle="--",
        alpha=0.7,
        label=f"Mean Diff: {mean_diff:.3f}",
    )
    ax2.legend()

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "cv_folds.png"), dpi=300)
    plt.close(fig)

    # 2. Per-model generalization performance
    fig, ax = plt.subplots(figsize=(12, 6))

    models_with_data = [
        model
        for model in all_models
        if model in aggregated["train"]["per_model"]
        and model in aggregated["test"]["per_model"]
    ]

    if models_with_data:
        train_f1 = [aggregated["train"]["per_model"][model]["f1"] for model in models_with_data]
        test_f1 = [aggregated["test"]["per_model"][model]["f1"] for model in models_with_data]
        f1_diffs = [test - train for train, test in zip(train_f1, test_f1)]

        # Largest absolute generalization gap first
        sorted_data = sorted(
            zip(models_with_data, train_f1, test_f1, f1_diffs),
            key=lambda row: abs(row[3]),
            reverse=True,
        )
        models_sorted = [row[0] for row in sorted_data]
        train_f1_sorted = [row[1] for row in sorted_data]
        test_f1_sorted = [row[2] for row in sorted_data]

        x = np.arange(len(models_sorted))
        ax.bar(x - width / 2, train_f1_sorted, width, label="Train F1")
        ax.bar(x + width / 2, test_f1_sorted, width, label="Test F1")

        # Connecting lines to visualize the gap
        for i, (train, test) in enumerate(zip(train_f1_sorted, test_f1_sorted)):
            ax.plot([i - width / 2, i + width / 2], [train, test], "k-", alpha=0.3)

        ax.set_xlabel("Model")
        ax.set_ylabel("F1 Score")
        ax.set_title("Cross-Validation: Train vs Test F1 by Model")
        ax.set_xticks(x)
        ax.set_xticklabels(models_sorted, rotation=45, ha="right")
        ax.grid(axis="y", linestyle="--", alpha=0.7)

        # Overall average F1 lines
        avg_train_f1 = aggregated["train"]["overall"]["f1"]
        avg_test_f1 = aggregated["test"]["overall"]["f1"]
        ax.axhline(
            y=avg_train_f1,
            color="blue",
            linestyle="--",
            alpha=0.5,
            label=f"Avg Train F1: {avg_train_f1:.3f}",
        )
        ax.axhline(
            y=avg_test_f1,
            color="orange",
            linestyle="--",
            alpha=0.5,
            label=f"Avg Test F1: {avg_test_f1:.3f}",
        )
        ax.legend()
    else:
        ax.text(
            0.5,
            0.5,
            "Insufficient data for model comparison",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "cv_models.png"), dpi=300)
    plt.close(fig)

    # 3. Metrics variance across folds (population variance, np.var default)
    fig, ax = plt.subplots(figsize=(12, 6))

    metric_keys = ["precision", "recall", "f1"]
    metrics_names = ["Precision", "Recall", "F1 Score"]
    train_variances = [
        np.var([fold["train"]["overall"][key] for fold in per_fold]) for key in metric_keys
    ]
    test_variances = [
        np.var([fold["test"]["overall"][key] for fold in per_fold]) for key in metric_keys
    ]

    x = np.arange(len(metrics_names))
    ax.bar(x - width / 2, train_variances, width, label="Train Variance")
    ax.bar(x + width / 2, test_variances, width, label="Test Variance")

    ax.set_xlabel("Metric")
    ax.set_ylabel("Variance")
    ax.set_title("Metric Variance Across Folds")
    ax.set_xticks(x)
    ax.set_xticklabels(metrics_names)
    ax.legend()
    ax.grid(axis="y", linestyle="--", alpha=0.7)

    for i, v in enumerate(train_variances):
        ax.text(i - width / 2, v + 0.001, f"{v:.4f}", ha="center", va="bottom")
    for i, v in enumerate(test_variances):
        ax.text(i + width / 2, v + 0.001, f"{v:.4f}", ha="center", va="bottom")

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "cv_variance.png"), dpi=300)
    plt.close(fig)

In [46]:
# Additional evaluation function to analyze classifier robustness
def analyze_classifier_robustness(results: List[Dict], visualize: bool = True) -> Dict:
    # Analyze the robustness of the classifier by studying:
    # 1. Performance on boundary cases (confidence near the model threshold)
    # 2. Contribution of different evidence sources
    # 3. Consistency of classifications across models
    #
    # Args:
    #   results: per-publication result dicts produced by the pipeline. Only
    #       entries with a non-empty "models" list (ground truth) AND a
    #       "confidence_scores" key are evaluated.
    #   visualize: when True (default), also write the robustness plots via
    #       visualize_robustness_metrics. Plotting is best-effort: a failure is
    #       reported but does not discard the computed metrics.
    #
    # Returns:
    #   {"boundary_cases": ..., "evidence_sources": ...,
    #    "classification_consistency": ...}, or {"error": ...} when no
    #   publication has ground truth to evaluate against.

    # Extract publications with ground truth
    publications_with_truth = [
        result
        for result in results
        if result.get("models") and "confidence_scores" in result
    ]

    if not publications_with_truth:
        return {"error": "No publications with ground truth for evaluation"}

    # Evaluate every model that appears either as ground truth or as a prediction
    all_models = set()
    for pub in publications_with_truth:
        all_models.update(pub.get("models") or [])
        all_models.update(pub.get("matched_models") or [])
    all_models = sorted(all_models)

    robustness_metrics = {
        "boundary_cases": analyze_boundary_cases(publications_with_truth, all_models),
        "evidence_sources": analyze_evidence_sources(
            publications_with_truth, all_models
        ),
        "classification_consistency": analyze_classification_consistency(
            publications_with_truth, all_models
        ),
    }

    if visualize:
        try:
            visualize_robustness_metrics(robustness_metrics)
        except Exception as e:
            # A plotting problem (e.g. unwritable output directory) must not
            # throw away metrics that were already computed.
            print(f"Warning: could not generate robustness visualizations: {str(e)}")

    return robustness_metrics


def analyze_boundary_cases(
    publications: List[Dict],
    all_models: List[str],
    boundary_width: float = 0.1,
    thresholds: Optional[Dict[str, float]] = None,
    default_threshold: float = 0.4,
) -> Dict:
    # Compare accuracy on boundary cases (confidence strictly within
    # `boundary_width` of the model's decision threshold) with accuracy on all
    # other cases. A decision is "correct" when (confidence >= threshold)
    # agrees with whether the model is in the publication's ground truth.
    #
    # Args:
    #   publications: result dicts with ground-truth "models" and per-model
    #       "confidence_scores"; a missing/None score counts as 0.
    #   all_models: models evaluated for every publication.
    #   boundary_width: half-width of the band around the threshold that
    #       counts as a boundary case.
    #   thresholds: per-model thresholds; defaults to the global MODEL_THRESHOLDS.
    #   default_threshold: threshold for models absent from `thresholds`.
    #
    # Returns:
    #   {"per_model": {model: counts, accuracies, robustness_score},
    #    "overall": pooled accuracies, counts and robustness_score}.
    #   Accuracies are 0 when a group has no samples.
    #   robustness_score = boundary accuracy / max(0.001, non-boundary error
    #   rate). It is unbounded (up to 1000x the boundary accuracy when
    #   non-boundary accuracy is perfect), so compare it only across runs.
    if thresholds is None:
        thresholds = MODEL_THRESHOLDS

    def _accuracy(correct: int, count: int) -> float:
        return correct / count if count > 0 else 0

    def _robustness(boundary_acc: float, non_boundary_acc: float) -> float:
        # Floor the error rate so perfect non-boundary accuracy cannot divide by zero
        return boundary_acc / max(0.001, 1 - non_boundary_acc)

    boundary_metrics = {
        model: {
            "boundary_count": 0,
            "boundary_correct": 0,
            "non_boundary_count": 0,
            "non_boundary_correct": 0,
        }
        for model in all_models
    }

    # Tally every (publication, model) decision into its group
    for pub in publications:
        true_models = set(pub.get("models") or [])
        scores = pub.get("confidence_scores") or {}

        for model in all_models:
            confidence = scores.get(model) or 0
            threshold = thresholds.get(model, default_threshold)

            is_predicted = confidence >= threshold
            is_correct = is_predicted == (model in true_models)
            group = (
                "boundary"
                if abs(confidence - threshold) < boundary_width
                else "non_boundary"
            )

            stats = boundary_metrics[model]
            stats[f"{group}_count"] += 1
            if is_correct:
                stats[f"{group}_correct"] += 1

    # Per-model accuracies and robustness score (higher is better)
    for stats in boundary_metrics.values():
        stats["boundary_accuracy"] = _accuracy(
            stats["boundary_correct"], stats["boundary_count"]
        )
        stats["non_boundary_accuracy"] = _accuracy(
            stats["non_boundary_correct"], stats["non_boundary_count"]
        )
        stats["robustness_score"] = _robustness(
            stats["boundary_accuracy"], stats["non_boundary_accuracy"]
        )

    # Pooled counts over all models
    totals = {
        key: sum(stats[key] for stats in boundary_metrics.values())
        for key in (
            "boundary_correct",
            "boundary_count",
            "non_boundary_correct",
            "non_boundary_count",
        )
    }
    overall_boundary_accuracy = _accuracy(
        totals["boundary_correct"], totals["boundary_count"]
    )
    overall_non_boundary_accuracy = _accuracy(
        totals["non_boundary_correct"], totals["non_boundary_count"]
    )

    return {
        "per_model": boundary_metrics,
        "overall": {
            "boundary_accuracy": overall_boundary_accuracy,
            "non_boundary_accuracy": overall_non_boundary_accuracy,
            "total_boundary_count": totals["boundary_count"],
            "total_non_boundary_count": totals["non_boundary_count"],
            "robustness_score": _robustness(
                overall_boundary_accuracy, overall_non_boundary_accuracy
            ),
        },
    }


def analyze_evidence_sources(
    publications: List[Dict],
    all_models: List[str],
    min_count: int = 5,
    thresholds: Optional[Dict[str, float]] = None,
    default_threshold: float = 0.4,
) -> Dict:
    # Measure how often decisions backed by each evidence source are correct.
    # For each (publication, model), every source listed in
    # pub["confidence_sources"][model] is credited with that decision's outcome.
    # A decision is correct when (confidence >= threshold) agrees with whether
    # the model is in the publication's ground-truth "models".
    #
    # Args:
    #   publications: result dicts with "models", "confidence_scores" and
    #       "confidence_sources"; missing/None entries are treated as empty.
    #   all_models: models evaluated for every publication.
    #   min_count: sources used fewer times get no "accuracy"/"precision" and
    #       are excluded from "top_sources".
    #   thresholds: per-model thresholds; defaults to the global MODEL_THRESHOLDS.
    #   default_threshold: threshold for models absent from `thresholds`.
    #
    # Returns:
    #   {"per_source": {source: counts, accuracy, precision, per_model},
    #    "overall": {"unique_sources": int,
    #                "top_sources": up to 5 (source, accuracy, count) tuples
    #                ranked by accuracy * log(count)}}.
    if thresholds is None:
        thresholds = MODEL_THRESHOLDS

    def _new_source_entry() -> Dict:
        return {
            "count": 0,
            "correct": 0,
            "incorrect": 0,
            "true_positives": 0,
            "false_positives": 0,
            "per_model": {model: {"count": 0, "correct": 0} for model in all_models},
        }

    # Single pass: sources are registered the first time they are seen
    source_metrics: Dict[Any, Dict] = {}
    for pub in publications:
        true_models = set(pub.get("models") or [])
        scores = pub.get("confidence_scores") or {}
        sources_by_model = pub.get("confidence_sources") or {}

        for model in all_models:
            sources = sources_by_model.get(model) or []
            if not sources:
                continue

            confidence = scores.get(model) or 0
            is_predicted = confidence >= thresholds.get(model, default_threshold)
            is_true_match = model in true_models
            is_correct = is_predicted == is_true_match

            for source in sources:
                entry = source_metrics.get(source)
                if entry is None:
                    entry = source_metrics[source] = _new_source_entry()

                entry["count"] += 1
                entry["correct" if is_correct else "incorrect"] += 1
                if is_predicted and is_true_match:
                    entry["true_positives"] += 1
                elif is_predicted:
                    entry["false_positives"] += 1

                model_stats = entry["per_model"][model]
                model_stats["count"] += 1
                if is_correct:
                    model_stats["correct"] += 1

    # Derived rates, only for sources with enough samples to be meaningful
    for entry in source_metrics.values():
        if entry["count"] < min_count:
            continue

        entry["accuracy"] = entry["correct"] / entry["count"]
        tp = entry["true_positives"]
        fp = entry["false_positives"]
        entry["precision"] = tp / (tp + fp) if (tp + fp) > 0 else 0

        for model_stats in entry["per_model"].values():
            if model_stats["count"] > 0:
                model_stats["accuracy"] = model_stats["correct"] / model_stats["count"]

    # Every registered source has count >= 1, so log(count) is defined
    top_sources = sorted(
        [
            (source, entry["accuracy"], entry["count"])
            for source, entry in source_metrics.items()
            if entry["count"] >= min_count
        ],
        key=lambda item: item[1] * math.log(item[2]),
        reverse=True,
    )[:5]

    return {
        "per_source": source_metrics,
        "overall": {
            "unique_sources": len(source_metrics),
            "top_sources": top_sources,
        },
    }


def analyze_classification_consistency(
    publications: List[Dict],
    all_models: List[str],
    thresholds: Optional[Dict[str, float]] = None,
    default_threshold: float = 0.4,
    n_bins: int = 10,
) -> Dict:
    # Analyze consistency of classifications across different models:
    #   - per publication: Jaccard overlap between ground-truth "models" and
    #     predicted "matched_models";
    #   - per model: for publications where the model is a true match alongside
    #     other true models, the fraction of those other models also predicted;
    #   - confidence vs accuracy: decision accuracy within `n_bins` equal-width
    #     confidence bins over [0, 1] (scores outside [0, 1] go to the end bins).
    #
    # Args:
    #   publications: result dicts with "models", "matched_models" and
    #       "confidence_scores"; missing/None entries are treated as empty.
    #   all_models: models evaluated for every publication.
    #   thresholds: per-model thresholds; defaults to the global MODEL_THRESHOLDS.
    #   default_threshold: threshold for models absent from `thresholds`.
    #   n_bins: number of confidence bins.
    #
    # Returns:
    #   dict with "publication_consistency" (list), "model_consistency"
    #   ({model: {"consistency_score", "samples"}}), "confidence_vs_accuracy"
    #   ({"bins": centers, "accuracy", "counts"} numpy arrays) and
    #   "overall_consistency" (mean publication consistency, 0 if none).
    if thresholds is None:
        thresholds = MODEL_THRESHOLDS

    consistency_metrics = {
        "publication_consistency": [],
        "model_consistency": {
            model: {"consistency_score": 0, "samples": 0} for model in all_models
        },
        # Placeholder keeps key order; filled in below
        "confidence_vs_accuracy": {"bins": [], "accuracy": []},
    }

    bins = np.linspace(0, 1, n_bins + 1)
    bin_counts = np.zeros(n_bins)
    bin_correct = np.zeros(n_bins)

    for pub in publications:
        true_models = set(pub.get("models") or [])
        pred_models = set(pub.get("matched_models") or [])
        scores = pub.get("confidence_scores") or {}

        # Agreement (Jaccard) between truth and prediction for this publication
        total_models = len(true_models | pred_models)
        if total_models > 0:
            consistency_metrics["publication_consistency"].append(
                len(true_models & pred_models) / total_models
            )

        for model in all_models:
            confidence = scores.get(model) or 0
            is_true_match = model in true_models
            is_predicted = confidence >= thresholds.get(model, default_threshold)

            # Clamp into [0, n_bins - 1]. Without the lower clamp a negative
            # score gives a negative index, and NumPy wraps it into a top bin.
            bin_idx = min(n_bins - 1, max(0, int(confidence * n_bins)))
            bin_counts[bin_idx] += 1
            if is_predicted == is_true_match:
                bin_correct[bin_idx] += 1

            # Cross-model consistency only applies to true matches that share
            # the publication with at least one other true model.
            if not is_true_match:
                continue
            other_true_models = true_models - {model}
            if not other_true_models:
                continue

            overlap = len(other_true_models & pred_models)
            stats = consistency_metrics["model_consistency"][model]
            stats["consistency_score"] += overlap / len(other_true_models)
            stats["samples"] += 1

    # Average consistency for each model
    for stats in consistency_metrics["model_consistency"].values():
        if stats["samples"] > 0:
            stats["consistency_score"] /= stats["samples"]

    pub_consistency = consistency_metrics["publication_consistency"]
    consistency_metrics["overall_consistency"] = (
        sum(pub_consistency) / len(pub_consistency) if pub_consistency else 0
    )

    # Empty bins keep an accuracy of 0
    bin_accuracy = np.zeros(n_bins)
    np.divide(bin_correct, bin_counts, out=bin_accuracy, where=bin_counts > 0)

    consistency_metrics["confidence_vs_accuracy"] = {
        "bins": (bins[:-1] + bins[1:]) / 2,
        "accuracy": bin_accuracy,
        "counts": bin_counts,
    }

    return consistency_metrics


def visualize_robustness_metrics(
    metrics: Dict, output_dir: str = "./metrics_visualization"
):
    # Save the robustness figures for analyze_classifier_robustness output into
    # `output_dir`, which is created if missing:
    #   robustness_boundary.png    - boundary vs non-boundary accuracy and counts
    #   robustness_sources.png     - per-source accuracy/usage; only written when
    #                                at least one source has >= 20 samples
    #   robustness_consistency.png - confidence-bin accuracy and model consistency
    os.makedirs(output_dir, exist_ok=True)
    _plot_boundary_robustness(metrics["boundary_cases"], output_dir)
    _plot_source_robustness(metrics["evidence_sources"]["per_source"], output_dir)
    _plot_consistency_robustness(metrics["classification_consistency"], output_dir)


def _plot_boundary_robustness(boundary_metrics: Dict, output_dir: str) -> None:
    # Per-model boundary vs non-boundary accuracy (left) and boundary-case counts (right)
    per_model = boundary_metrics["per_model"]
    overall = boundary_metrics["overall"]
    models = sorted(per_model.keys())
    x = np.arange(len(models))
    width = 0.35

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    try:
        ax1.bar(
            x - width / 2,
            [per_model[m]["boundary_accuracy"] for m in models],
            width,
            label="Boundary",
        )
        ax1.bar(
            x + width / 2,
            [per_model[m]["non_boundary_accuracy"] for m in models],
            width,
            label="Non-Boundary",
        )
        ax1.axhline(
            y=overall["boundary_accuracy"],
            color="blue",
            linestyle="--",
            alpha=0.7,
            label=f'Overall Boundary: {overall["boundary_accuracy"]:.3f}',
        )
        ax1.axhline(
            y=overall["non_boundary_accuracy"],
            color="orange",
            linestyle="--",
            alpha=0.7,
            label=f'Overall Non-Boundary: {overall["non_boundary_accuracy"]:.3f}',
        )
        ax1.set_xlabel("Model")
        ax1.set_ylabel("Accuracy")
        ax1.set_title("Boundary vs Non-Boundary Accuracy by Model")
        ax1.set_xticks(x)
        ax1.set_xticklabels(models, rotation=45, ha="right")
        ax1.legend()
        ax1.grid(axis="y", linestyle="--", alpha=0.7)

        # Numeric positions with explicit ticks. set_xticklabels on a categorical
        # axis without set_xticks can emit a FixedFormatter warning.
        ax2.bar(x, [per_model[m]["boundary_count"] for m in models])
        ax2.set_xlabel("Model")
        ax2.set_ylabel("Count")
        ax2.set_title("Boundary Case Count by Model")
        ax2.set_xticks(x)
        ax2.set_xticklabels(models, rotation=45, ha="right")
        ax2.grid(axis="y", linestyle="--", alpha=0.7)

        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, "robustness_boundary.png"), dpi=300)
    finally:
        plt.close(fig)


def _plot_source_robustness(
    source_metrics: Dict, output_dir: str, min_samples: int = 20, top_n: int = 10
) -> None:
    # Accuracy (left) and usage count (right) for the `top_n` most accurate
    # evidence sources with >= `min_samples` uses; writes nothing if none qualify.
    ranked = sorted(
        [
            (source, stats["accuracy"], stats["count"])
            for source, stats in source_metrics.items()
            if stats.get("count", 0) >= min_samples and "accuracy" in stats
        ],
        key=lambda item: item[1],
        reverse=True,
    )[:top_n]
    if not ranked:
        return

    sources = [item[0] for item in ranked]
    accuracies = [item[1] for item in ranked]
    counts = [item[2] for item in ranked]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    try:
        bars = ax1.barh(sources, accuracies)
        ax1.set_xlabel("Accuracy")
        ax1.set_ylabel("Evidence Source")
        ax1.set_title("Accuracy by Evidence Source")
        ax1.set_xlim(0, 1)
        ax1.grid(axis="x", linestyle="--", alpha=0.7)

        # Sample-size label just to the right of each bar
        for bar, count in zip(bars, counts):
            ax1.text(
                bar.get_width() + 0.01,
                bar.get_y() + bar.get_height() / 2,
                f"n={count}",
                va="center",
            )

        ax2.barh(sources, counts)
        ax2.set_xlabel("Count")
        ax2.set_ylabel("Evidence Source")
        ax2.set_title("Usage Count by Evidence Source")
        ax2.grid(axis="x", linestyle="--", alpha=0.7)

        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, "robustness_sources.png"), dpi=300)
    finally:
        plt.close(fig)


def _plot_consistency_robustness(
    consistency_metrics: Dict, output_dir: str, min_samples: int = 10
) -> None:
    # Confidence-bin accuracy (left) and per-model consistency for models with
    # >= `min_samples` samples (right).
    confidence_data = consistency_metrics["confidence_vs_accuracy"]
    centers = np.asarray(confidence_data["bins"], dtype=float)
    accuracy = np.asarray(confidence_data["accuracy"], dtype=float)
    counts = np.asarray(confidence_data["counts"], dtype=float)
    # Empty bins have placeholder accuracy 0. Draw them as gaps (NaN) so they
    # are not mistaken for observed zero accuracy.
    observed = np.where(counts > 0, accuracy, np.nan)

    eligible = [
        (model, data["consistency_score"])
        for model, data in consistency_metrics["model_consistency"].items()
        if data["samples"] >= min_samples
    ]
    eligible.sort(key=lambda item: item[1], reverse=True)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    try:
        ax1.plot(centers, observed, "o-", label="Observed")
        ax1.plot([0, 1], [0, 1], "k--", alpha=0.5, label="Perfect Calibration")
        ax1.set_xlabel("Confidence Score Bin")
        ax1.set_ylabel("Accuracy")
        ax1.set_title("Confidence vs Accuracy (Calibration)")
        ax1.set_xlim(0, 1)
        ax1.set_ylim(0, 1)
        ax1.grid(True, linestyle="--", alpha=0.7)
        ax1.legend()

        for center, acc, count in zip(centers, accuracy, counts):
            if count > 0:
                ax1.annotate(
                    f"n={int(count)}",
                    (center, acc),
                    textcoords="offset points",
                    xytext=(0, 10),
                    ha="center",
                )

        if eligible:
            models = [model for model, _ in eligible]
            scores = [score for _, score in eligible]
            x = np.arange(len(models))

            ax2.bar(x, scores)
            ax2.set_xlabel("Model")
            ax2.set_ylabel("Consistency Score")
            ax2.set_title("Model Prediction Consistency")
            ax2.set_xticks(x)
            ax2.set_xticklabels(models, rotation=45, ha="right")
            ax2.grid(axis="y", linestyle="--", alpha=0.7)
            ax2.axhline(
                y=consistency_metrics["overall_consistency"],
                color="red",
                linestyle="--",
                alpha=0.7,
                label=f'Overall: {consistency_metrics["overall_consistency"]:.3f}',
            )
            ax2.legend()
        else:
            ax2.text(
                0.5,
                0.5,
                "Insufficient data for consistency analysis",
                ha="center",
                va="center",
                transform=ax2.transAxes,
            )

        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, "robustness_consistency.png"), dpi=300)
    finally:
        plt.close(fig)

In [47]:
# Run generalization evaluations after main processing is complete
def evaluate_generalization_metrics(results_path: str = "./results.json"):
    print("\nEvaluating model generalization to unseen data...")

    # Load results
    try:
        with open(results_path, "r") as f:
            results = json.load(f)
        print(f"Loaded {len(results)} publications from results")
    except Exception as e:
        print(f"Error loading results: {str(e)}")
        return

    # Run simple train/test split evaluation
    print("\n1. Running train/test split evaluation...")
    try:
        generalization_metrics = evaluate_generalization(results)
    except Exception as e:
        print(f"Error during train/test evaluation: {str(e)}")
        generalization_metrics = None

    # Run cross-validation evaluation
    print("\n2. Running cross-validation evaluation...")
    try:
        cv_metrics = evaluate_with_cross_validation(results, n_folds=5)
    except Exception as e:
        print(f"Error during cross-validation: {str(e)}")
        cv_metrics = None

    # Run robustness analysis
    print("\n3. Running robustness analysis...")
    try:
        robustness_metrics = analyze_classifier_robustness(results)
    except Exception as e:
        print(f"Error during robustness analysis: {str(e)}")
        robustness_metrics = None

    # Summarize the results
    print("\n=== Generalization Evaluation Summary ===")

    # Train/test summary
    if generalization_metrics:
        train_f1 = (
            generalization_metrics.get("training", {}).get("overall", {}).get("f1", 0)
        )
        test_f1 = (
            generalization_metrics.get("testing", {}).get("overall", {}).get("f1", 0)
        )
        gap = generalization_metrics.get("generalization_gap", {}).get("f1_gap", 0)
        assessment = generalization_metrics.get("generalization_gap", {}).get(
            "gap_assessment", "N/A"
        )

        print(f"Train/Test Split Results:")
        print(f"  - Training F1: {train_f1:.3f}")
        print(f"  - Testing F1: {test_f1:.3f}")
        print(f"  - Generalization Gap: {gap:.3f}")
        print(f"  - Assessment: {assessment}")
    else:
        print("Train/Test Split Results: Failed to generate metrics")

    # Cross-validation summary
    if cv_metrics:
        cv_train_f1 = (
            cv_metrics.get("aggregated", {})
            .get("train", {})
            .get("overall", {})
            .get("f1", 0)
        )
        cv_test_f1 = (
            cv_metrics.get("aggregated", {})
            .get("test", {})
            .get("overall", {})
            .get("f1", 0)
        )
        cv_gap = cv_metrics.get("generalization_gap", {}).get("f1_gap", 0)
        cv_assessment = cv_metrics.get("generalization_gap", {}).get(
            "gap_assessment", "N/A"
        )

        print(f"\nCross-Validation Results:")
        print(f"  - Average Training F1: {cv_train_f1:.3f}")
        print(f"  - Average Testing F1: {cv_test_f1:.3f}")
        print(f"  - Generalization Gap: {cv_gap:.3f}")
        print(f"  - Assessment: {cv_assessment}")
    else:
        print("\nCross-Validation Results: Failed to generate metrics")

    # Robustness summary
    if robustness_metrics:
        boundary_acc = (
            robustness_metrics.get("boundary_cases", {})
            .get("overall", {})
            .get("boundary_accuracy", 0)
        )
        non_boundary_acc = (
            robustness_metrics.get("boundary_cases", {})
            .get("overall", {})
            .get("non_boundary_accuracy", 0)
        )
        robustness_score = (
            robustness_metrics.get("boundary_cases", {})
            .get("overall", {})
            .get("robustness_score", 0)
        )

        print(f"\nRobustness Analysis:")
        print(f"  - Boundary Case Accuracy: {boundary_acc:.3f}")
        print(f"  - Non-Boundary Accuracy: {non_boundary_acc:.3f}")
        print(f"  - Robustness Score: {robustness_score:.3f}")

        # Top evidence sources
        if "top_sources" in robustness_metrics.get("evidence_sources", {}).get(
            "overall", {}
        ):
            top_sources = robustness_metrics["evidence_sources"]["overall"][
                "top_sources"
            ]
            if top_sources:
                print(
                    f"  - Top Evidence Source: {top_sources[0][0]} (Accuracy: {top_sources[0][1]:.3f}, Count: {top_sources[0][2]})"
                )

        # Consistency
        if "overall_consistency" in robustness_metrics.get(
            "classification_consistency", {}
        ):
            consistency = robustness_metrics["classification_consistency"][
                "overall_consistency"
            ]
            print(f"  - Classification Consistency: {consistency:.3f}")
    else:
        print("\nRobustness Analysis: Failed to generate metrics")

    print("\nGeneralization evaluation complete. Visualizations saved to:")
    print("  - /metrics_visualization/generalization_*.png")

    # Save full generalization metrics to JSON (with safe handling of numpy values)
    try:
        generalization_data = {
            "simple_split": generalization_metrics if generalization_metrics else {},
            "cross_validation": cv_metrics if cv_metrics else {},
            "robustness": robustness_metrics if robustness_metrics else {},
        }

        # Convert numpy types to Python native types
        def numpy_to_python(obj):
            if isinstance(obj, np.integer):
                return int(obj)
            elif isinstance(obj, np.floating):
                return float(obj)
            elif isinstance(obj, np.ndarray):
                return obj.tolist()
            elif isinstance(obj, dict):
                return {k: numpy_to_python(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [numpy_to_python(i) for i in obj]
            return obj

        with open("./generalization_metrics.json", "w") as f:
            json.dump(numpy_to_python(generalization_data), f, indent=2)
        print("Full generalization metrics saved to generalization_metrics.json")
    except Exception as e:
        print(f"Error saving generalization metrics: {str(e)}")

    return generalization_data

In [48]:
# # Modify the main function to include generalization evaluation
# def main_with_generalization():
#     """Extended main function that includes generalization evaluation"""
#     # Run standard pipeline
#     main()

#     # Run generalization evaluation
#     print("\n======================================")
#     print("Running generalization evaluation...")
#     print("======================================")
#     evaluate_generalization_metrics()

# Extended entry point: run the full classification pipeline on the labeled test set, then the generalization evaluation
def _attach_ground_truth_models(
    test_data: List[Dict], curated_publications: List[Dict]
) -> int:
    # Set pub["models"] = [curated model] on every test publication whose
    # normalized "DOI" matches a curated publication's normalized "doi".
    # Mutates `test_data` in place and returns how many publications were labeled.
    doi_to_model_map = {}
    for pub in curated_publications:
        if pub.get("doi") and pub.get("model"):
            normalized_doi = normalize_doi(pub["doi"])
            if normalized_doi:
                doi_to_model_map[normalized_doi] = pub["model"]

    print(f"Created mapping with {len(doi_to_model_map)} DOI to model pairs.")

    models_added = 0
    for pub in test_data:
        if pub.get("DOI"):
            normalized_doi = normalize_doi(pub["DOI"])
            if normalized_doi in doi_to_model_map:
                pub["models"] = [doi_to_model_map[normalized_doi]]
                models_added += 1
    return models_added


def _report_optimal_thresholds(results: List[Dict]) -> None:
    # Find optimal thresholds, print them next to the current MODEL_THRESHOLDS,
    # and compare overall F1 under current vs optimal thresholds. Errors are
    # printed with a traceback rather than aborting the pipeline.
    import traceback

    print("\nFinding optimal thresholds...")
    try:
        optimal_thresholds = find_optimal_thresholds(results)

        if "per_model" not in optimal_thresholds:
            print(
                "Warning: 'per_model' key missing in optimal_thresholds, using empty dict"
            )
            optimal_thresholds["per_model"] = {}

        print("\nOptimal model-specific thresholds:")
        for model, data in optimal_thresholds["per_model"].items():
            current = MODEL_THRESHOLDS.get(model, 0.4)
            print(
                f"{model}: {data['threshold_f1']:.2f} (current: {current:.2f}, F1: {data['f1']:.2f})"
            )

        print(
            f"\nOptimal overall threshold: {optimal_thresholds['overall']['threshold_f1']:.2f}"
        )
        print(
            f"Overall F1 score with optimal thresholds: {optimal_thresholds['overall']['f1']:.3f}"
        )

        print("\nComparing performance with current vs. optimal thresholds:")
        current_performance = analyze_threshold_performance(results)
        optimal_performance = analyze_threshold_performance(
            results,
            model_thresholds=optimal_thresholds.get("model_thresholds_f1", {}),
            overall_threshold=optimal_thresholds.get("overall", {}).get(
                "threshold_f1", 0.4
            ),
        )

        current_f1 = current_performance.get("overall", {}).get("f1", 0)
        optimal_f1 = optimal_performance.get("overall", {}).get("f1", 0)
        # F1 lies in [0, 1], so this scaled difference is in percentage points,
        # not a relative percent change.
        improvement = (optimal_f1 - current_f1) * 100
        print(
            f"Current F1: {current_f1:.3f}, "
            f"Optimal F1: {optimal_f1:.3f}, "
            f"Improvement: {improvement:.1f} percentage points"
        )
    except Exception as e:
        print(f"Error during threshold optimization: {str(e)}")
        traceback.print_exc()


def main_with_generalization():
    # Extended main function that includes generalization evaluation:
    #   1. Load ./labeled_test_data_plusmomo.json and attach ground-truth
    #      "models" from ./curated_publications.json (matched by DOI). If this
    #      preparation fails, fall back to main() and return.
    #   2. Initialize classifier, ranker, embeddings and context manager.
    #   3. Classify the test publications and save them to
    #      ./labeled_test_data_plusmomo_results.json.
    #   4. Visualize metrics, print TF-IDF model terms and optimal thresholds,
    #      then run the in-memory generalization evaluation.
    # Unexpected errors are printed with a traceback instead of being raised.
    import traceback

    try:
        print("Loading test data...")
        try:
            with open("./labeled_test_data_plusmomo.json") as f:
                test_data = json.load(f)
            print(f"Loaded {len(test_data)} publications from test data")

            # Curated publications provide the ground truth
            # (alternative source: ./labeled_test_data.json)
            print("Loading curated publications for ground truth...")
            with open("./curated_publications.json") as f:
                curated_publications = json.load(f)
            print(f"Loaded {len(curated_publications)} curated publications.")

            models_added = _attach_ground_truth_models(test_data, curated_publications)
            print(
                f"Added ground truth 'models' to {models_added} publications using curated mapping."
            )
            print("This will allow evaluation for these publications.")
        except Exception as e:
            print(f"Error preparing test data: {str(e)}")
            print("Falling back to default processing...")
            main()
            return

        start_time = time.time()

        print("Starting optimized science publication classifier...")
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"CUDA device: {torch.cuda.get_device_name(0)}")
            print(
                f"CUDA memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB"
            )

        print("Initializing science classifier...")
        with torch.inference_mode():
            science_classifier = ScienceClassifier.get_instance()
        print("Science classifier initialized successfully.")

        print("Loading configuration files...")
        curated_mapping = load_curated_models("./curated_publications.json")
        model_keywords = load_model_keywords("./model_keywords.json")
        model_descriptions = load_model_descriptions("./model_descriptions.json")

        print("Initializing relevance ranker...")
        with torch.inference_mode():
            ranker = RelevanceRanker(model_descriptions)

        print("Initializing model embeddings...")
        with torch.inference_mode():
            model_embeddings = initialize_model_embeddings(model_descriptions)

        optimize_memory()

        print("Building context validation profiles...")
        # The curated publications were loaded (and only read) above; reuse them
        # instead of re-reading the same file.
        with torch.inference_mode():
            context_manager = ModelContextManager(curated_publications)

        print("Deriving data-driven affinities from curated dataset...")
        derive_data_driven_affinities("./curated_publications.json")

        print("\nProcessing publications...")
        results = process_publication_batch(
            test_data,
            curated_mapping,
            model_keywords,
            model_embeddings,
            science_classifier,
            context_manager,
            ranker=ranker,
        )

        print("Saving results...")
        with open("./labeled_test_data_plusmomo_results.json", "w") as f:
            json.dump(results, f, indent=2)

        print("\nGenerating visualizations with threshold analysis...")
        visualize_metrics(results)

        print("\nModel-specific terminology based on TF-IDF analysis:")
        for model, terms in context_manager.get_model_specific_terms().items():
            print(f"\n{model} distinctive terms:")
            print(", ".join(terms[:10]))  # Show top 10 terms

        _report_optimal_thresholds(results)

        print("\n======================================")
        print("Running generalization evaluation...")
        print("======================================")
        try:
            # Use results from this run directly instead of loading from file
            evaluate_generalization_metrics_inline(results)
        except Exception as e:
            print(f"Error during generalization evaluation: {str(e)}")
            traceback.print_exc()

        total_time = time.time() - start_time
        print(
            f"\nProcessing completed in {total_time:.2f}s ({total_time/60:.2f} minutes)"
        )
        if test_data:
            print(f"Average time per publication: {total_time/len(test_data):.4f}s")

        if torch.cuda.is_available():
            print(
                f"Peak GPU memory usage: {torch.cuda.max_memory_allocated(0) / 1024**2:.2f} MB"
            )

        print(
            f"\nNote: Ground truth 'models' were added to {models_added}/{len(test_data)} publications."
        )
        print(
            f"Evaluation metrics are based on these {models_added} publications with ground truth."
        )

    except Exception as e:
        print(f"\nError in main_with_generalization: {str(e)}")
        traceback.print_exc()


def _gen_lookup(data, *keys, default=0):
    """Walk nested dicts along ``keys``; return ``default`` if any step is missing.

    Unlike chained ``.get(k, {})`` calls, this also tolerates an intermediate
    value that is ``None`` or otherwise not a dict, instead of raising
    AttributeError in the middle of the summary printout.
    """
    current = data
    for key in keys:
        if not isinstance(current, dict) or key not in current:
            return default
        current = current[key]
    return current


def _gen_format(value, spec=".3f"):
    """Format a metric with ``spec``; fall back to ``str(value)`` for non-numbers.

    A ``None``/string metric therefore prints as text rather than raising
    TypeError/ValueError and aborting the rest of the evaluation (including
    the JSON save).
    """
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return str(value)


def _gen_run_stage(description, stage):
    """Run one zero-argument evaluation ``stage``; on any exception print and return None.

    ``stage`` is a callable (e.g. a lambda) so that even a NameError for an
    undefined evaluation function is raised, and caught, inside this ``try``.
    """
    try:
        return stage()
    except Exception as e:
        print(f"Error during {description}: {str(e)}")
        return None


def _gen_json_key(key):
    """Coerce a dict key into a type ``json`` accepts (str/int/float/bool/None)."""
    if isinstance(key, np.generic):
        key = key.item()
    if key is None or isinstance(key, (str, int, float, bool)):
        return key
    # e.g. tuple keys: json would reject them outright, so stringify instead.
    return str(key)


def _gen_to_json_safe(obj):
    """Recursively convert numpy scalars/arrays and tuples/sets into JSON-native types."""
    if isinstance(obj, dict):
        return {_gen_json_key(k): _gen_to_json_safe(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [_gen_to_json_safe(item) for item in obj]
    if isinstance(obj, np.ndarray):
        # tolist() yields Python scalars, but object arrays may still nest
        # numpy values, so recurse on the result.
        return _gen_to_json_safe(obj.tolist())
    if isinstance(obj, np.generic):
        # Covers np.integer, np.floating and np.bool_.
        return obj.item()
    return obj


def _gen_save_json(data, output_path):
    """Best-effort save of ``data`` as indented JSON at ``output_path``.

    The payload is fully serialized before the file is opened, so a
    serialization failure no longer leaves a truncated file behind.
    """
    try:
        payload = json.dumps(_gen_to_json_safe(data), indent=2)
        with open(output_path, "w") as f:
            f.write(payload)
        print("Full generalization metrics saved to file")
    except Exception as e:
        print(f"Error saving generalization metrics: {str(e)}")


def _gen_print_split_summary(metrics):
    """Print the train/test split part of the summary (or a failure line)."""
    if not metrics:
        print("Train/Test Split Results: Failed to generate metrics")
        return
    train_f1 = _gen_lookup(metrics, "training", "overall", "f1")
    test_f1 = _gen_lookup(metrics, "testing", "overall", "f1")
    gap = _gen_lookup(metrics, "generalization_gap", "f1_gap")
    assessment = _gen_lookup(
        metrics, "generalization_gap", "gap_assessment", default="N/A"
    )

    print("Train/Test Split Results:")
    print(f"  - Training F1: {_gen_format(train_f1)}")
    print(f"  - Testing F1: {_gen_format(test_f1)}")
    print(f"  - Generalization Gap: {_gen_format(gap)}")
    print(f"  - Assessment: {assessment}")


def _gen_print_cv_summary(metrics):
    """Print the cross-validation part of the summary (or a failure line)."""
    if not metrics:
        print("\nCross-Validation Results: Failed to generate metrics")
        return
    cv_train_f1 = _gen_lookup(metrics, "aggregated", "train", "overall", "f1")
    cv_test_f1 = _gen_lookup(metrics, "aggregated", "test", "overall", "f1")
    cv_gap = _gen_lookup(metrics, "generalization_gap", "f1_gap")
    cv_assessment = _gen_lookup(
        metrics, "generalization_gap", "gap_assessment", default="N/A"
    )

    print("\nCross-Validation Results:")
    print(f"  - Average Training F1: {_gen_format(cv_train_f1)}")
    print(f"  - Average Testing F1: {_gen_format(cv_test_f1)}")
    print(f"  - Generalization Gap: {_gen_format(cv_gap)}")
    print(f"  - Assessment: {cv_assessment}")


def _gen_print_robustness_summary(metrics):
    """Print the robustness part of the summary (or a failure line)."""
    if not metrics:
        print("\nRobustness Analysis: Failed to generate metrics")
        return
    boundary_overall = _gen_lookup(metrics, "boundary_cases", "overall", default={})
    boundary_acc = _gen_lookup(boundary_overall, "boundary_accuracy")
    non_boundary_acc = _gen_lookup(boundary_overall, "non_boundary_accuracy")
    robustness_score = _gen_lookup(boundary_overall, "robustness_score")

    print("\nRobustness Analysis:")
    print(f"  - Boundary Case Accuracy: {_gen_format(boundary_acc)}")
    print(f"  - Non-Boundary Accuracy: {_gen_format(non_boundary_acc)}")
    print(f"  - Robustness Score: {_gen_format(robustness_score)}")

    # Entries are indexed as (source, accuracy, count); presumably sorted best
    # first by the producer -- TODO confirm in analyze_classifier_robustness.
    top_sources = _gen_lookup(
        metrics, "evidence_sources", "overall", "top_sources", default=None
    )
    if top_sources:
        best = top_sources[0]
        print(
            f"  - Top Evidence Source: {best[0]} "
            f"(Accuracy: {_gen_format(best[1])}, Count: {best[2]})"
        )

    consistency_block = _gen_lookup(
        metrics, "classification_consistency", default={}
    )
    if isinstance(consistency_block, dict) and "overall_consistency" in consistency_block:
        consistency = consistency_block["overall_consistency"]
        print(f"  - Classification Consistency: {_gen_format(consistency)}")


def evaluate_generalization_metrics_inline(
    results,
    n_folds=5,
    output_path="./generalization_metrics_labeled_test_data_plusmomo.json",
):
    """Evaluate generalization on in-memory results (no reload from file).

    Runs three independent stages, each of which may fail without stopping
    the others:

    1. ``evaluate_generalization(results)``: simple train/test split.
    2. ``evaluate_with_cross_validation(results, n_folds=n_folds)``.
    3. ``analyze_classifier_robustness(results)``.

    It then prints a summary and saves all metrics as JSON. numpy scalars,
    arrays, tuples and sets are converted to JSON-native values first.

    Args:
        results: Per-publication results produced by this run. Only ``len()``
            is used directly here; the structure is whatever the three stage
            functions expect (defined elsewhere in the notebook).
        n_folds: Number of cross-validation folds (default 5, as before).
        output_path: Where to write the JSON metrics file.

    Returns:
        dict with keys ``"simple_split"``, ``"cross_validation"`` and
        ``"robustness"``. Each value is the stage's raw return value, or
        ``{}`` if that stage failed or returned nothing. These values are not
        numpy-converted; only the saved file is.
    """
    print("\nEvaluating model generalization on in-memory results...")
    print(f"Using {len(results)} publications for evaluation")

    print("\n1. Running train/test split evaluation...")
    generalization_metrics = _gen_run_stage(
        "train/test evaluation", lambda: evaluate_generalization(results)
    )

    print("\n2. Running cross-validation evaluation...")
    cv_metrics = _gen_run_stage(
        "cross-validation",
        lambda: evaluate_with_cross_validation(results, n_folds=n_folds),
    )

    print("\n3. Running robustness analysis...")
    robustness_metrics = _gen_run_stage(
        "robustness analysis", lambda: analyze_classifier_robustness(results)
    )

    print("\n=== Generalization Evaluation Summary ===")
    _gen_print_split_summary(generalization_metrics)
    _gen_print_cv_summary(cv_metrics)
    _gen_print_robustness_summary(robustness_metrics)

    # NOTE(review): printed unconditionally. Presumably the stage functions
    # write these PNGs; this function does not verify that they exist.
    print("\nGeneralization evaluation complete. Visualizations saved to:")
    print("  - metrics_visualization_generalization_*.png")

    # Built outside any try block so it is always bound when returned.
    generalization_data = {
        "simple_split": generalization_metrics or {},
        "cross_validation": cv_metrics or {},
        "robustness": robustness_metrics or {},
    }
    _gen_save_json(generalization_data, output_path)
    return generalization_data

In [None]:
# Entry point: launch the full classification + generalization pipeline.
# Inside a Jupyter kernel __name__ is "__main__", so running this cell runs it;
# importing the code as a module does not.
if "__main__" == __name__:
    main_with_generalization()

Loading test data...
Loaded 3398 publications from test data
Loading curated publications for ground truth...
Loaded 2344 curated publications.
Created mapping with 2260 DOI to model pairs.
Added ground truth 'models' to 1548 publications using curated mapping.
This will allow evaluation for these publications.
Starting optimized science publication classifier...
CUDA available: True
CUDA device: NVIDIA GeForce RTX 4070 Ti SUPER
CUDA memory allocated: 486.60 MB
Initializing science classifier...
Loading research_area model from arminmehrabian/nasa-impact-nasa-smd-ibm-st-v2-classification-finetuned...


Device set to use cuda:0


Successfully loaded research_area model
Loading science_keywords model from nasa-impact/science-keyword-classification...


Device set to use cuda:0


Successfully loaded science_keywords model
Loading division model from nasa-impact/division-classifier...


Device set to use cuda:0


Successfully loaded division model
Science classifier initialized successfully.
Loading configuration files...
Initializing relevance ranker...
Initializing model embeddings...
Building context validation profiles...
Deriving data-driven affinities from curated dataset...
Deriving data-driven affinities from curated dataset...
Saved data-driven keyword affinities with 186 entries
Derived affinities for 186 keywords, 15 research areas, and 4 divisions

Processing publications...
Processing batch 1/34 (publications 1-100/3398)...
  Processing publication 20/3398...
  Processing publication 40/3398...
  Processing publication 60/3398...
  Processing publication 80/3398...
  Processing publication 100/3398...
Completed batch 1/34
Processing batch 2/34 (publications 101-200/3398)...
  Processing publication 120/3398...
  Processing publication 140/3398...
  Processing publication 160/3398...
  Processing publication 180/3398...
  Processing publication 200/3398...
Completed batch 2/34
Proce