In [None]:
import json
import re
import gc
from typing import List, Dict, Set, Tuple, Any, Optional, Union
from collections import defaultdict, Counter
import time
import os

import numpy as np
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from rapidfuzz import fuzz
import spacy
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
from functools import lru_cache
from datetime import datetime

In [None]:
# Device index handed to the Hugging Face `pipeline(...)` calls further down:
# 0 when CUDA is available, -1 otherwise.
device = 0 if torch.cuda.is_available() else -1

# Load the small English spaCy pipeline. If loading raises OSError, download
# the model once and load it again.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    print("Downloading SpaCy model...")
    spacy.cli.download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")

# Shared SentenceTransformer encoder (`st_model`): switched to eval mode and
# moved to CUDA when available, otherwise kept on the CPU.
st_model = SentenceTransformer('nasa-impact/nasa-ibm-st.38m')
st_model.eval()
st_model.to('cuda' if torch.cuda.is_available() else 'cpu')

# Memory management
def optimize_memory():
    """Release cached CUDA memory (when CUDA is available), then run the garbage collector."""
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
    gc.collect()

In [None]:
class ScienceClassifier:
    """Science document classifier with model loading and result caching.

    Wraps three Hugging Face text-classification pipelines (research area,
    science keywords and division). ``get_instance`` hands out one shared
    instance so the models are only loaded once; ``classify`` memoises its
    result per publication in ``classification_cache``.
    """
    _instance = None

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it (and loading models) on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        # Model configurations: hub name plus, where the checkpoint emits
        # generic ``LABEL_<n>`` ids, the mapping from index to readable label.
        self.models = {
            'research_area': {
                'name': 'arminmehrabian/nasa-impact-nasa-smd-ibm-st-v2-classification-finetuned',
                'label_map': {
                    0: "Agriculture",
                    1: "Air Quality",
                    2: "Atmospheric/Ocean Indicators",
                    3: "Cryospheric Indicators",
                    4: "Droughts",
                    5: "Earthquakes",
                    6: "Ecosystems",
                    7: "Energy Production/Use",
                    8: "Environmental Impacts",
                    9: "Floods",
                    10: "Greenhouse Gases",
                    11: "Habitat Conversion/Fragmentation",
                    12: "Heat",
                    13: "Land Surface/Agriculture Indicators",
                    14: "Public Health",
                    15: "Severe Storms",
                    16: "Sun-Earth Interactions",
                    17: "Validation",
                    18: "Volcanic Eruptions",
                    19: "Water Quality",
                    20: "Wildfires"
                }
            },
            'science_keywords': {
                'name': 'nasa-impact/science-keyword-classification'
            },
            'division': {
                'name': 'nasa-impact/division-classifier',
                'label_map': {
                    0: 'Astrophysics',
                    1: 'Biological and Physical Sciences',
                    2: 'Earth Science',
                    3: 'Heliophysics',
                    4: 'Planetary Science'
                }
            }
        }
        self.classification_cache = {}
        self._load_models()

    def _load_models(self):
        """Build one text-classification pipeline per configured task.

        A task whose model fails to load is reported and left out of
        ``self.classifiers``; ``classify`` then skips that task.
        """
        self.classifiers = {}
        for task, config in self.models.items():
            try:
                print(f"Loading {task} model from {config['name']}...")
                with torch.no_grad():
                    tokenizer = AutoTokenizer.from_pretrained(config['name'])
                    model = AutoModelForSequenceClassification.from_pretrained(config['name'])
                    model.eval()

                self.classifiers[task] = {
                    'pipe': pipeline(
                        'text-classification',
                        model=model,
                        tokenizer=tokenizer,
                        device=device,
                        batch_size=32
                    ),
                    'config': config
                }
                print(f"Successfully loaded {task} model")
            except Exception as e:
                print(f"Failed to load {task} model: {str(e)}")

    @staticmethod
    def _as_text(value: Any) -> str:
        """Coerce an arbitrary metadata value to a string (``None`` becomes ``''``)."""
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    @classmethod
    def _extract_title(cls, publication: Dict) -> str:
        """Return the title as a string; a list title contributes its first element."""
        raw_title = publication.get('title')
        if isinstance(raw_title, list):
            return cls._as_text(raw_title[0]) if raw_title else ""
        if isinstance(raw_title, str):
            return raw_title
        return ""

    def _prepare_text(self, publication: Dict) -> str:
        """Combine title, abstract and keywords into one space-joined string.

        Missing or ``None`` fields contribute an empty string. ``keywords`` may
        be a list (items are joined with spaces) or a single string (used
        as-is rather than being split into characters).
        """
        title = self._extract_title(publication)
        abstract = self._as_text(publication.get('abstract'))

        raw_keywords = publication.get('keywords')
        if isinstance(raw_keywords, str):
            keywords = raw_keywords
        elif raw_keywords:
            keywords = ' '.join(self._as_text(kw) for kw in raw_keywords)
        else:
            keywords = ""

        return ' '.join([title, abstract, keywords])

    def _get_cache_key(self, publication: Dict) -> str:
        """Build the key under which a publication's classification is cached.

        Preference order: DOI, then title. When neither is present the key is
        derived from the prepared text, so unrelated publications without
        identifiers no longer share the single key ``"title:"``.
        """
        doi = publication.get('DOI')
        if doi:
            return f"doi:{doi}"

        title = self._extract_title(publication)
        if title:
            return f"title:{title}"

        # In-process cache only, so the built-in (per-process) hash is sufficient.
        return f"text:{hash(self._prepare_text(publication))}"

    @staticmethod
    def _resolve_label(raw_label: Any, label_map: Optional[Dict[int, str]]) -> Any:
        """Translate a ``LABEL_<n>`` id through ``label_map``.

        Labels that are not of that form, or whose index is not in the map,
        are returned unchanged.
        """
        if not label_map:
            return raw_label
        try:
            index = int(str(raw_label).replace('LABEL_', ''))
        except ValueError:
            return raw_label
        return label_map.get(index, raw_label)

    @torch.inference_mode()
    def classify(self, publication: Dict) -> Dict:
        """Classify a publication, returning cached results when available.

        Returns:
            Dict with ``research_areas`` (up to three ``{'label', 'score'}``
            entries), ``science_keywords`` (entries scoring above 0.35 whose
            label has at least four characters) and ``division`` (a single
            ``{'label', 'score'}`` dict or ``None``). A task that is not loaded
            or that raises keeps its empty default.
        """
        cache_key = self._get_cache_key(publication)
        if cache_key in self.classification_cache:
            return self.classification_cache[cache_key]

        text = self._prepare_text(publication)
        results = {
            'research_areas': [],
            'science_keywords': [],
            'division': None
        }

        # Research Area Classification
        if self.classifiers.get('research_area'):
            try:
                area_labels = self.models.get('research_area', {}).get('label_map')
                res_area = self.classifiers['research_area']['pipe'](
                    text, top_k=3, truncation=True, max_length=512
                )
                results['research_areas'] = [{
                    'label': self._resolve_label(pred['label'], area_labels),
                    'score': float(pred['score'])
                } for pred in res_area]
            except Exception as e:
                print(f"Research area classification failed: {str(e)}")

        # Science Keywords Classification
        if self.classifiers.get('science_keywords'):
            try:
                science_keywords = self.classifiers['science_keywords']['pipe'](
                    text, truncation=True, max_length=512, top_k=10
                )

                for pred in science_keywords:
                    if pred['score'] > 0.35 and len(pred['label']) >= 4:
                        results['science_keywords'].append({
                            'label': pred['label'],
                            'score': float(pred['score'])
                        })
            except Exception as e:
                print(f"Science keyword classification failed: {str(e)}")

        # Division Classification
        if self.classifiers.get('division'):
            try:
                division_labels = self.models.get('division', {}).get('label_map')
                division_result = self.classifiers['division']['pipe'](
                    text, top_k=1, truncation=True, max_length=512
                )

                if division_result:
                    division = division_result[0]
                    if 'score' in division:
                        results['division'] = {
                            # Apply the configured label map so generic ids
                            # come back as division names.
                            'label': self._resolve_label(division['label'], division_labels),
                            'score': float(division['score'])
                        }
            except Exception as e:
                print(f"Division classification failed: {str(e)}")

        # Cache results
        self.classification_cache[cache_key] = results
        return results


class ModelContextManager:
    """Context validation with model profiles.

    From a list of curated publications (each optionally tagged with a
    ``'model'``) this builds, per model, an embedding of the aggregated text,
    a small set of key terms, and a ranked list of TF-IDF terms.
    ``get_context_scores`` then scores a new publication against every model.
    """

    # Lower-case alphabetic tokens of at least four letters.
    _TERM_PATTERN = re.compile(r'\b[a-z]{4,}\b')
    # Number of key terms kept in each model profile.
    _PROFILE_TERM_LIMIT = 25
    # Number of top TF-IDF terms kept per model.
    _TFIDF_TERM_LIMIT = 30

    def __init__(self, curated_publications: List[Dict]):
        self.st_model = st_model
        self.profile_cache = {}
        self.model_profiles = self._build_model_profiles(curated_publications)
        self.corpus_term_frequencies = self._build_corpus_term_frequencies(curated_publications)
        self.model_tfidf_terms = self._calculate_model_tfidf_terms(curated_publications)

    @staticmethod
    def _prepare(publication: Dict) -> str:
        """Delegate text preparation to the shared ScienceClassifier instance."""
        return ScienceClassifier.get_instance()._prepare_text(publication)

    @classmethod
    def _tokenize(cls, text: str) -> List[str]:
        """Return the lower-cased 4+ letter alphabetic tokens of ``text``."""
        return cls._TERM_PATTERN.findall(text.lower())

    @staticmethod
    def _top_terms(terms: Set[str], limit: int = 25) -> Set[str]:
        """Keep the ``limit`` longest terms, breaking length ties alphabetically.

        The alphabetical tie-break makes the selection independent of set
        iteration order, which for strings changes between interpreter runs.
        """
        return set(sorted(terms, key=lambda term: (-len(term), term))[:limit])

    def _build_model_profiles(self, publications: List[Dict]) -> Dict[str, Dict]:
        """Build per-model profiles (embedding, key terms, text count).

        Publications without a truthy ``'model'`` entry are ignored.
        """
        model_texts = defaultdict(list)
        model_terms = defaultdict(set)

        # Collect texts and terms
        for pub in publications:
            model = pub.get('model')
            if model:
                prepared_text = self._prepare(pub)
                model_texts[model].append(prepared_text)
                model_terms[model].update(self._extract_key_terms(prepared_text))

        # Create embeddings
        final_profiles = {}
        for model, texts in model_texts.items():
            # At most the first 100 texts are concatenated into one string.
            # NOTE(review): the encoder may truncate long inputs, in which case
            # only the leading texts influence the embedding - confirm.
            aggregated_text = ' '.join(texts[:100])

            with torch.inference_mode():
                embedding = self.st_model.encode(aggregated_text, convert_to_tensor=True)

            final_profiles[model] = {
                'embedding': embedding.cpu().numpy(),
                'terms': self._top_terms(model_terms[model], self._PROFILE_TERM_LIMIT),
                'text_count': len(texts)
            }

        optimize_memory()
        return final_profiles

    def _build_corpus_term_frequencies(self, publications: List[Dict]) -> Dict[str, int]:
        """Count token occurrences across all publications."""
        corpus_terms = Counter()

        for pub in publications:
            corpus_terms.update(self._tokenize(self._prepare(pub)))

        return corpus_terms

    def _calculate_model_tfidf_terms(self, publications: List[Dict]) -> Dict[str, List[Tuple[str, float]]]:
        """Rank each model's terms by a TF-IDF style score.

        Returns:
            Mapping of model name to its top terms as ``(term, score)`` pairs,
            highest score first. Only terms seen at least 3 times for the
            model are kept.
        """
        model_term_counts = defaultdict(Counter)
        model_doc_counts = defaultdict(int)

        # Count terms by model
        for pub in publications:
            model = pub.get('model')
            if model:
                model_doc_counts[model] += 1
                model_term_counts[model].update(self._tokenize(self._prepare(pub)))

        # Calculate total document count
        total_docs = sum(model_doc_counts.values())

        # For every term, the number of models whose texts contain it.
        models_with_term = Counter()
        for term_counts in model_term_counts.values():
            models_with_term.update(term_counts.keys())

        model_tfidf_terms = {}
        for model, term_counts in model_term_counts.items():
            # Hoisted out of the per-term loop (it was recomputed for every term).
            total_terms = sum(term_counts.values())
            tfidf_scores = {}

            for term, count in term_counts.items():
                # Tokens are already >= 4 letters, so only the frequency
                # filter needs checking here.
                if count < 3:
                    continue

                # Term frequency in this model
                tf = count / total_terms

                # NOTE(review): the numerator counts publications while the
                # denominator counts models containing the term. Kept as-is so
                # existing rankings are unchanged - confirm this is intended.
                idf = np.log((total_docs + 1) / (models_with_term[term] + 1))

                tfidf_scores[term] = tf * idf

            # Sort by TF-IDF score and take top terms
            sorted_terms = sorted(tfidf_scores.items(), key=lambda x: x[1], reverse=True)
            model_tfidf_terms[model] = sorted_terms[:self._TFIDF_TERM_LIMIT]

        return model_tfidf_terms

    def _extract_key_terms(self, text: str) -> Set[str]:
        """Return the tokens that occur at least twice in ``text``."""
        word_counts = Counter(self._tokenize(text))
        return {word for word, count in word_counts.items() if count >= 2}

    def _get_pub_profile(self, prepared_text: str) -> Dict:
        """Return (and cache) the embedding and key terms for a prepared text.

        Caching is done through ``self.profile_cache`` only. The method is
        deliberately not wrapped in ``functools.lru_cache``: that would key on
        ``self``, keep the instance alive for the life of the cache, and
        duplicate what ``profile_cache`` already stores.
        """
        cached = self.profile_cache.get(prepared_text)
        if cached is not None:
            return cached

        with torch.inference_mode():
            embedding = self.st_model.encode(prepared_text)

        profile = {
            'embedding': embedding,
            'terms': self._extract_key_terms(prepared_text)
        }

        self.profile_cache[prepared_text] = profile
        return profile

    def get_model_specific_terms(self) -> Dict[str, List[str]]:
        """Return each model's TF-IDF terms (without scores), best first."""
        return {
            model: [term for term, _score in tfidf_terms]
            for model, tfidf_terms in self.model_tfidf_terms.items()
        }

    def get_context_scores(self, publication: Dict) -> Dict[str, float]:
        """Score a publication against every model profile.

        The score is ``0.5 * cosine similarity of embeddings + 0.3 * Jaccard
        overlap of key terms + 0.2 * TF-IDF term match`` (the last saturates
        at five matching terms). Values are returned as plain Python floats.
        """
        prepared_text = self._prepare(publication)
        pub_profile = self._get_pub_profile(prepared_text)
        pub_terms = pub_profile['terms']

        scores = {}
        for model, model_profile in self.model_profiles.items():
            # Basic term overlap score
            model_terms = model_profile['terms']

            if not pub_terms or not model_terms:
                term_overlap = 0.0
            else:
                intersection = len(pub_terms.intersection(model_terms))
                union = len(pub_terms.union(model_terms))
                term_overlap = intersection / union if union > 0 else 0.0

            # TF-IDF term match score
            tfidf_term_set = {term for term, _score in self.model_tfidf_terms.get(model, [])}
            tfidf_match_count = len(pub_terms.intersection(tfidf_term_set))
            tfidf_match_score = min(1.0, tfidf_match_count / 5) if tfidf_term_set else 0.0

            # Semantic similarity
            semantic_sim = cosine_similarity(
                [pub_profile['embedding']],
                [model_profile['embedding']]
            )[0][0]

            # Combined score (with TF-IDF term matches)
            scores[model] = float(
                0.5 * semantic_sim + 0.3 * term_overlap + 0.2 * tfidf_match_score
            )

        return scores


class RelevanceRanker:
    """Rank models by relevance to a query using a sequence-classification model.

    Each (query, model description) pair is scored by the
    ``nasa-impact/nasa-smd-ibm-ranker`` checkpoint; the score is the softmax
    probability of class index 1.
    """

    def __init__(self, model_descriptions: Dict[str, str]):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_descriptions = model_descriptions
        self.description_cache = {}
        self.result_cache = {}

        # Load model
        self.tokenizer = AutoTokenizer.from_pretrained("nasa-impact/nasa-smd-ibm-ranker")

        with torch.inference_mode():
            self.model = AutoModelForSequenceClassification.from_pretrained(
                "nasa-impact/nasa-smd-ibm-ranker"
            ).to(self.device)

            if self.device.type != 'cuda':
                self.model = self.model.float()

        self.model.eval()

    def _safe_prepare_model_text(self, model_id: str) -> str:
        """Return the model description cut to 380 characters with whitespace collapsed.

        Unknown ids and ``None`` descriptions yield an empty string. Results
        are cached per ``model_id``.
        """
        if model_id not in self.description_cache:
            desc = (self.model_descriptions.get(model_id) or '')[:380]
            self.description_cache[model_id] = re.sub(r'\s+', ' ', desc)
        return self.description_cache[model_id]

    @torch.inference_mode()
    def batch_rank(self, query: str, model_ids: List[str]) -> Dict[str, float]:
        """Score every model in ``model_ids`` against ``query``.

        Args:
            query: Free text; only the first 400 characters are used.
            model_ids: Models to score.

        Returns:
            Mapping of model id to relevance score. All scores are 0.0 when
            ranking fails. A ``RuntimeError`` raised while on CUDA triggers
            exactly one retry on the CPU.
        """
        if not model_ids:
            return {}

        query = (query or '')[:400]  # Truncate query

        # Key on the truncated query itself rather than on joined hash strings,
        # so distinct queries cannot collide. The id order is irrelevant because
        # the result is a dict keyed by model id.
        cache_key = (query, tuple(sorted(model_ids)))
        if cache_key in self.result_cache:
            return self.result_cache[cache_key]

        batch_texts = [self._safe_prepare_model_text(mid) for mid in model_ids]

        try:
            # Prepare inputs
            inputs = self.tokenizer(
                [query] * len(batch_texts),
                batch_texts,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt"
            ).to(self.device)

            outputs = self.model(**inputs)
            scores = F.softmax(outputs.logits, dim=1)[:, 1].cpu().numpy()
            result = dict(zip(model_ids, scores.tolist()))

            # Cache results
            self.result_cache[cache_key] = result
            return result

        except RuntimeError as e:
            if self.device.type != 'cuda':
                # Already on the CPU: retrying would recurse without bound.
                print(f"Ranker failed: {str(e)}")
                return {mid: 0.0 for mid in model_ids}

            print(f"Ranker fallback to CPU: {str(e)}")
            self.device = torch.device("cpu")
            self.model = self.model.to(self.device).float()
            return self.batch_rank(query, model_ids)

        except Exception as e:
            print(f"Ranker failed: {str(e)}")
            return {mid: 0.0 for mid in model_ids}

In [None]:
# Precompiled patterns used when normalizing text
HTML_TAG_PATTERN = re.compile(r'<[^>]+>')
HYPHEN_UNDERSCORE_PATTERN = re.compile(r'[-_]')
SPECIAL_CHAR_PATTERN = re.compile(r'[^a-zA-Z0-9]')

def preprocess_text(text: str) -> str:
    """Clean and normalize text for matching.

    HTML tags, hyphens/underscores and every remaining character outside
    ``[a-zA-Z0-9]`` are each replaced by a single space, then the result is
    lower-cased. Falsy input yields ``""``.
    """
    if text:
        cleaned = text
        for pattern in (HTML_TAG_PATTERN, HYPHEN_UNDERSCORE_PATTERN, SPECIAL_CHAR_PATTERN):
            cleaned = pattern.sub(' ', cleaned)
        return cleaned.lower()
    return ""

# DOI normalization cache: raw input string -> normalized DOI
_doi_cache = {}
def normalize_doi(doi: str) -> str:
    """Normalize a DOI string to a bare, lower-case identifier.

    Strips the URL scheme, a ``dx.doi.org/`` or ``doi.org/`` resolver prefix,
    a leading ``doi:`` label and surrounding slashes/whitespace, and replaces
    commas with dots. Falsy or non-string input yields ``""``.

    Results are cached under the *original* input string so repeated lookups
    of the same raw value hit the cache.
    """
    if not doi or not isinstance(doi, str):
        return ""

    cached = _doi_cache.get(doi)
    if cached is not None:
        return cached

    cleaned = doi.lower().replace("https://", "").replace("http://", "")
    # Remove the "dx." variant first so no stray "dx." prefix is left behind.
    cleaned = cleaned.replace("dx.doi.org/", "").replace("doi.org/", "")
    cleaned = cleaned.strip("/ \n\r\t")
    if cleaned.startswith("doi:"):
        cleaned = cleaned[len("doi:"):]
    cleaned = cleaned.replace(",", ".")
    result = cleaned.strip("/ \n\r\t")

    _doi_cache[doi] = result
    return result

# Fuzzy matching cache: (hash(text), lower-cased keyword, threshold) -> (matched, score)
_fuzzy_match_cache = {}
def fuzzy_keyword_match(text: str, keyword: str, threshold: float = 90.0) -> Tuple[bool, float]:
    """Check whether ``keyword`` occurs in ``text``, exactly or approximately.

    Args:
        text: Text to search.
        keyword: Keyword to look for (compared case-insensitively).
        threshold: Minimum fuzzy score (0-100) that counts as a match.

    Returns:
        ``(matched, score)``. A case-insensitive substring hit returns
        ``(True, 100.0)``. An empty or blank keyword, or empty text, returns
        ``(False, 0.0)`` instead of matching everything.
    """
    # An empty keyword is a substring of every string; treat it as "no match".
    if not text or not keyword or not keyword.strip():
        return (False, 0.0)

    keyword_lower = keyword.lower()
    cache_key = (hash(text), keyword_lower, threshold)
    if cache_key in _fuzzy_match_cache:
        return _fuzzy_match_cache[cache_key]

    text_lower = text.lower()

    # Check for exact match first
    if keyword_lower in text_lower:
        result = (True, 100.0)
        _fuzzy_match_cache[cache_key] = result
        return result

    # Use direct fuzzy matching for shorter texts
    # NOTE(review): this compares the whole text with the keyword, so the score
    # is presumably low unless the text is about as long as the keyword - confirm
    # whether a partial-match scorer was intended.
    if len(text) < 1000:
        score = fuzz.token_sort_ratio(text_lower, keyword_lower)
        result = (score >= threshold, score)
        _fuzzy_match_cache[cache_key] = result
        return result

    # For longer texts, use spaCy on the first 2000 characters
    doc = nlp(text[:2000])

    # Candidate spans: noun chunks, named entities and content tokens
    chunks = [chunk.text.lower() for chunk in doc.noun_chunks]
    entities = [ent.text.lower() for ent in doc.ents]
    tokens = [token.text.lower() for token in doc
              if token.is_alpha and not token.is_stop and len(token.text) > 3]

    all_spans = chunks[:50] + entities[:50] + tokens[:100]

    # Find best match among spans of at least three characters
    best_score = max(
        (fuzz.token_sort_ratio(span, keyword_lower) for span in all_spans if len(span) >= 3),
        default=0.0,
    )

    result = (best_score >= threshold, best_score)
    _fuzzy_match_cache[cache_key] = result
    return result

# Publication date extraction function
# Metadata fields consulted for a date, in priority order.
_PUBLICATION_DATE_FIELDS = (
    'published',
    'published-online',
    'published-print',
    'indexed',
    'created',
    'issued',
)

def _format_year_month(entry: Any, allow_year_only: bool = False) -> Optional[str]:
    """Format the first ``date-parts`` item of ``entry`` as ``yyyy-mm``, or return None."""
    if not isinstance(entry, dict):
        return None
    date_parts = entry.get('date-parts')
    if not isinstance(date_parts, (list, tuple)) or not date_parts:
        return None
    parts = date_parts[0]
    if not isinstance(parts, (list, tuple)):
        return None
    try:
        if len(parts) >= 2:
            return f"{int(parts[0]):04d}-{int(parts[1]):02d}"
        if allow_year_only and len(parts) == 1:
            return f"{int(parts[0]):04d}-01"  # Default to January if only year
    except (TypeError, ValueError):
        # Unusable parts (None, non-numeric text): let the caller try the next field.
        return None
    return None

def extract_publication_date(publication: Dict) -> Optional[str]:
    """Extract the publication date in ``yyyy-mm`` format.

    Fields are tried in this order: ``published``, ``published-online``,
    ``published-print``, ``indexed``, ``created``, ``issued``, then the plain
    ``year`` key. Only ``issued`` and ``year`` may supply a year without a
    month, in which case January is assumed. A malformed field is skipped and
    the next one is tried. Year and month may be ints or numeric strings.

    Returns:
        The formatted date, or ``None`` when no usable date is found.
    """
    if not isinstance(publication, dict):
        return None

    for field in _PUBLICATION_DATE_FIELDS:
        formatted = _format_year_month(
            publication.get(field), allow_year_only=(field == 'issued')
        )
        if formatted is not None:
            return formatted

    # Try publication year
    if 'year' in publication:
        try:
            return f"{int(publication['year']):04d}-01"  # Default to January
        except (TypeError, ValueError) as e:
            print(f"Error extracting publication date: {str(e)}")

    return None

# Publication characteristics cache
_characteristics_cache = {}
def extract_publication_characteristics(publication: Dict) -> Dict[str, Any]:
    """Extract simple text and metadata features from a publication, with caching.

    Always present: ``title_length``, ``abstract_length``, ``total_length``
    (word counts) and ``keyword_count``. When the abstract is between 11 and
    9999 characters, its first 2000 characters are run through spaCy to add
    part-of-speech ratios, entity counts and sentence statistics. ``year`` and
    ``author_count`` are added when those keys exist.

    Missing, ``None`` or empty title/abstract/keywords are treated as empty
    rather than raising.
    """
    # Hash the complete serialized record: hashing only a prefix lets records
    # that share their first characters reuse each other's cached features.
    try:
        cache_key = str(hash(json.dumps(publication, sort_keys=True, default=str)))
    except (TypeError, ValueError):
        cache_key = None  # not serializable: compute without caching

    if cache_key is not None and cache_key in _characteristics_cache:
        return _characteristics_cache[cache_key]

    characteristics = {}

    # Extract title (a list title contributes its first element)
    raw_title = publication.get('title')
    if isinstance(raw_title, list):
        raw_title = raw_title[0] if raw_title else ''
    title = raw_title if isinstance(raw_title, str) else ''

    # Extract abstract
    raw_abstract = publication.get('abstract')
    abstract = raw_abstract if isinstance(raw_abstract, str) else ''

    # Basic metrics
    characteristics['title_length'] = len(title.split())
    characteristics['abstract_length'] = len(abstract.split())
    characteristics['total_length'] = characteristics['title_length'] + characteristics['abstract_length']

    # Keywords metrics
    keywords = publication.get('keywords') or []
    characteristics['keyword_count'] = len(keywords)

    # Process with spaCy if abstract is present
    if abstract and len(abstract) > 10 and len(abstract) < 10000:
        doc = nlp(abstract[:2000])

        # Part-of-speech distributions
        pos_counts = Counter(token.pos_ for token in doc)
        doc_len = len(doc) or 1

        characteristics['noun_ratio'] = pos_counts.get('NOUN', 0) / doc_len
        characteristics['verb_ratio'] = pos_counts.get('VERB', 0) / doc_len
        characteristics['adj_ratio'] = pos_counts.get('ADJ', 0) / doc_len

        # Named entity analysis
        entity_counter = Counter(ent.label_ for ent in doc.ents)

        characteristics['entity_count'] = len(doc.ents)
        characteristics['org_count'] = entity_counter.get('ORG', 0)
        characteristics['person_count'] = entity_counter.get('PERSON', 0)
        characteristics['date_count'] = entity_counter.get('DATE', 0)

        # Readability metrics
        sentences = list(doc.sents)
        if sentences:
            sent_lengths = [len(sent) for sent in sentences]
            # Plain float so the value serializes like the other metrics.
            characteristics['avg_sentence_length'] = float(np.mean(sent_lengths))
            characteristics['sentence_count'] = len(sentences)
        else:
            characteristics['avg_sentence_length'] = 0
            characteristics['sentence_count'] = 0

    # Publication year
    if 'year' in publication:
        characteristics['year'] = publication.get('year')

    # Author metrics
    if 'authors' in publication:
        authors = publication.get('authors', [])
        if isinstance(authors, list):
            characteristics['author_count'] = len(authors)
        else:
            characteristics['author_count'] = 1 if authors else 0

    if cache_key is not None:
        _characteristics_cache[cache_key] = characteristics
    return characteristics

In [None]:
# Module-level caches used by the loader/embedding helpers in this cell.
_model_embeddings_cache = {}    # model name -> description embedding
_curated_models_cache = {}      # file path -> DOI-to-model mapping
_model_keywords_cache = {}      # file path -> loaded keywords JSON
_model_descriptions_cache = {}  # file path -> loaded descriptions JSON

def initialize_model_embeddings(model_descriptions: Dict[str, str]) -> Dict[str, np.ndarray]:
    """Create embeddings for model descriptions, encoding in batches of 32.

    Embeddings are cached per model name in ``_model_embeddings_cache``; only
    names not yet cached are encoded. The cache is keyed by name, so a changed
    description for an already-cached name is not re-encoded.

    Returns:
        Mapping of every model name in ``model_descriptions`` to its embedding.
    """
    # Decide per name instead of comparing dict sizes: two different model
    # sets with the same number of entries must not share a cached result.
    pending = [
        (model, text or '')
        for model, text in model_descriptions.items()
        if model not in _model_embeddings_cache
    ]

    batch_size = 32
    for start in range(0, len(pending), batch_size):
        models, texts = zip(*pending[start:start + batch_size])

        with torch.inference_mode():
            batch_embeddings = st_model.encode(list(texts), convert_to_tensor=False)

        for model, embedding in zip(models, batch_embeddings):
            _model_embeddings_cache[model] = embedding

    return {model: _model_embeddings_cache[model] for model in model_descriptions}

def load_curated_models(curated_path: str) -> Dict[str, str]:
    """Load the curated DOI -> model mapping from a JSON file.

    The file is expected to hold a list of objects with ``doi`` and ``model``
    keys. DOIs are normalized with ``normalize_doi``; entries that are not
    objects, lack either key, or normalize to an empty DOI are skipped.
    Successful loads are cached per path. On any error a message is printed
    and ``{}`` is returned (and not cached).
    """
    if curated_path in _curated_models_cache:
        return _curated_models_cache[curated_path]

    try:
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(curated_path, encoding='utf-8') as f:
            curated = json.load(f)

        mapping = {}
        for entry in curated:
            # Skip stray non-object items instead of discarding the whole file.
            if not isinstance(entry, dict):
                continue
            if 'doi' in entry and 'model' in entry:
                normalized_doi = normalize_doi(entry['doi'])
                if normalized_doi:
                    mapping[normalized_doi] = entry['model']

        _curated_models_cache[curated_path] = mapping
        return mapping
    except Exception as e:
        print(f"Error loading curated models from {curated_path}: {str(e)}")
        return {}

def _load_json_with_cache(path: str, cache: Dict[str, Any], label: str) -> Any:
    """Load JSON from ``path``, memoising successful loads in ``cache``.

    On any error a message naming ``label`` is printed and ``{}`` is returned
    (and not cached).
    """
    if path in cache:
        return cache[path]

    try:
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(path, encoding='utf-8') as f:
            data = json.load(f)
    except Exception as e:
        print(f"Error loading {label} from {path}: {str(e)}")
        return {}

    cache[path] = data
    return data

def load_model_keywords(keywords_path: str) -> Dict[str, List[str]]:
    """Load the model keywords JSON file (cached per path; ``{}`` on error)."""
    return _load_json_with_cache(keywords_path, _model_keywords_cache, "model keywords")

def load_model_descriptions(descriptions_path: str) -> Dict[str, str]:
    """Load the model descriptions JSON file (cached per path; ``{}`` on error)."""
    return _load_json_with_cache(descriptions_path, _model_descriptions_cache, "model descriptions")

In [None]:
# Per-model confidence thresholds. In analyze_threshold_performance below, a
# model counts as predicted for a publication when its confidence score is
# >= the value listed here (models not listed use that function's
# `overall_threshold` argument instead).
MODEL_THRESHOLDS = {
    'ECCO': 0.60,
    'RAPID': 0.65,
    'ISSM': 0.70,
    'CMS-Flux': 0.70,
    'CARDAMOM': 0.55,
    'MOMO-CHEM': 0.95
}

# Affinity caches: each starts as None and is filled on the first call of the
# matching get_*_affinities() function below, then reused.
_keyword_affinity_cache = None
_research_area_affinity_cache = None
_division_affinity_cache = None

def get_science_keyword_model_affinities():
    """Return the science-keyword affinities loaded from ``./data_driven_affinities.json``.

    The file is read once and the parsed JSON is cached in
    ``_keyword_affinity_cache``. If the file cannot be opened (missing,
    unreadable, a directory, ...) or does not contain valid JSON text, a
    message is printed and an empty dict is cached and returned.
    """
    global _keyword_affinity_cache

    if _keyword_affinity_cache is not None:
        return _keyword_affinity_cache

    try:
        with open('./data_driven_affinities.json', 'r', encoding='utf-8') as f:
            _keyword_affinity_cache = json.load(f)
    except (OSError, ValueError):
        # OSError covers every failure to open/read the file, not only a
        # missing one. ValueError covers json.JSONDecodeError and
        # UnicodeDecodeError.
        print("Error loading affinity data. Using default empty affinities.")
        _keyword_affinity_cache = {}

    return _keyword_affinity_cache

def get_research_area_model_affinities():
    """Return the research-area -> {model: affinity} table.

    The table is built on the first call, stored in
    ``_research_area_affinity_cache`` and the same dict object is returned on
    every later call.
    """
    global _research_area_affinity_cache

    if _research_area_affinity_cache is None:
        _research_area_affinity_cache = {
            "Atmospheric/Ocean Indicators": {'ECCO': 1.7, 'MOMO-CHEM': 1.4, 'CMS-Flux': 1.1, 'ISSM': 1.1},
            "Greenhouse Gases": {'CARDAMOM': 1.6, 'CMS-Flux': 1.8, 'MOMO-CHEM': 1.55, 'ECCO': 1.15},
            "Ecosystems": {'CARDAMOM': 1.6, 'CMS-Flux': 1.2, 'ECCO': 1.25},
            "Land Surface/Agriculture Indicators": {
                'CARDAMOM': 1.4, 'CMS-Flux': 1.3, 'ECCO': 1.1, 'ISSM': 1.4, 'RAPID': 1.4,
            },
            "Validation": {
                'CMS-Flux': 1.2, 'ECCO': 1.4, 'ISSM': 1.2, 'MOMO-CHEM': 1.15, 'RAPID': 1.25,
            },
            "Cryospheric Indicators": {'ECCO': 1.35, 'ISSM': 1.9},
            "Air Quality": {'CMS-Flux': 1.4, 'MOMO-CHEM': 1.9},
            "Floods": {'ISSM': 1.15, 'RAPID': 1.6},
            "Environmental Impacts": {'MOMO-CHEM': 1.25},
            "Severe Storms": {'ECCO': 1.2},
            "Earthquakes": {'ECCO': 1.05},
            "Droughts": {'CMS-Flux': 1.2, 'RAPID': 1.4},
        }

    return _research_area_affinity_cache

def get_division_model_affinities():
    """Return the division -> {model: affinity} table.

    The table is built on the first call, stored in
    ``_division_affinity_cache`` and the same dict object is returned on
    every later call.
    """
    global _division_affinity_cache

    if _division_affinity_cache is None:
        earth_science = {
            'ECCO': 1.3,
            'RAPID': 1.2,
            'CMS-Flux': 1.15,
            'MOMO-CHEM': 1.1,
            'ISSM': 1.15,
            'CARDAMOM': 1.15,
        }
        _division_affinity_cache = {
            'Earth Science': earth_science,
            'Biological and Physical Sciences': {'CARDAMOM': 1.2, 'CMS-Flux': 1.1},
            'Heliophysics': {'MOMO-CHEM': 1.15},
            'Planetary Science': {'ISSM': 1.1},
            'Astrophysics': {},
        }

    return _division_affinity_cache

def get_model_specific_thresholds():
    """Return a fresh shallow copy of ``MODEL_THRESHOLDS`` that callers may modify freely."""
    thresholds = dict(MODEL_THRESHOLDS)
    return thresholds

def analyze_threshold_performance(results: List[Dict], model_thresholds: Dict[str, float] = None, 
                                 overall_threshold: float = 0.4) -> Dict:
    """
    Score the stored confidence values against ground truth at given thresholds.

    Only results that have a non-empty 'models' list (the ground truth) and a
    'confidence_scores' mapping are evaluated. A model counts as predicted for
    a publication when its confidence is >= the model's threshold.

    Args:
        results: List of publication results with ground truth
        model_thresholds: Threshold per model name; a copy of MODEL_THRESHOLDS
            is used when omitted
        overall_threshold: Threshold applied to models missing from
            model_thresholds

    Returns:
        Dict with 'per_model' metrics, micro-averaged 'overall' metrics and
        the 'thresholds' that were applied, or {'error': ...} when no result
        carries ground truth
    """
    def _precision_recall_f1(tp, fp, fn):
        # Each ratio falls back to 0 when its denominator is empty.
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        return precision, recall, f1

    if model_thresholds is None:
        model_thresholds = MODEL_THRESHOLDS.copy()

    # Keep only the publications that can actually be evaluated.
    evaluable = [
        entry for entry in results
        if 'models' in entry and entry['models'] and 'confidence_scores' in entry
    ]
    if not evaluable:
        return {'error': 'No publications with ground truth for evaluation'}

    # Every model named either in the ground truth or in the scores.
    model_names = set()
    for entry in evaluable:
        model_names.update(entry.get('models', []))
        model_names.update(entry.get('confidence_scores', {}).keys())

    per_model = {}
    totals = {'tp': 0, 'fp': 0, 'fn': 0}

    for name in sorted(model_names):
        cutoff = model_thresholds.get(name, overall_threshold)

        # Tally the confusion counts for this model in a single pass.
        tp = fp = fn = 0
        for entry in evaluable:
            expected = name in entry.get('models', [])
            predicted = entry.get('confidence_scores', {}).get(name, 0) >= cutoff
            if expected and predicted:
                tp += 1
            elif predicted:
                fp += 1
            elif expected:
                fn += 1

        precision, recall, f1 = _precision_recall_f1(tp, fp, fn)
        per_model[name] = {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'tp': tp,
            'fp': fp,
            'fn': fn,
            'threshold': cutoff
        }

        totals['tp'] += tp
        totals['fp'] += fp
        totals['fn'] += fn

    # Micro-average: metrics computed from the summed counts.
    overall_precision, overall_recall, overall_f1 = _precision_recall_f1(
        totals['tp'], totals['fp'], totals['fn']
    )

    return {
        'per_model': per_model,
        'overall': {
            'precision': overall_precision,
            'recall': overall_recall,
            'f1': overall_f1,
            'tp': totals['tp'],
            'fp': totals['fp'],
            'fn': totals['fn']
        },
        'thresholds': {
            'model_specific': model_thresholds,
            'overall_default': overall_threshold
        }
    }

def _select_best_f1_threshold(scored_labels, thresholds, default_threshold=0.4):
    """Scan ``thresholds`` and return (threshold, f1, metrics) for the best F1.

    ``scored_labels`` is a list of (confidence, is_true_match) pairs. Ties are
    resolved in favour of the first (lowest) threshold tested. When
    ``thresholds`` is empty the default threshold is returned with f1 == -1
    and empty metrics.
    """
    best_f1 = -1
    best_threshold = default_threshold
    best_metrics = {}

    for threshold in thresholds:
        tp = fp = fn = 0
        for confidence, is_true_match in scored_labels:
            predicted = confidence >= threshold
            if is_true_match == 1:
                if predicted:
                    tp += 1
                else:
                    fn += 1
            elif predicted:
                fp += 1

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold
            best_metrics = {
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'tp': tp,
                'fp': fp,
                'fn': fn
            }

    return best_threshold, best_f1, best_metrics


def find_optimal_thresholds(results: List[Dict], threshold_range=None, step=0.05) -> Dict:
    """
    Find optimal thresholds for each model based on F1 score

    Candidate thresholds are min, min + step, ... up to and including max
    (never beyond it). Each candidate is rounded to 10 decimals so that float
    accumulation error (e.g. 0.1 + 4 * 0.05 == 0.30000000000000004) cannot make
    a confidence of exactly 0.3 fail the ``>=`` test at the 0.3 grid point.

    Args:
        results: List of publication results with ground truth
        threshold_range: Optional range of thresholds to test (min, max);
            defaults to (0.1, 0.95)
        step: Step size for threshold values (must be > 0)

    Returns:
        Dict containing optimal thresholds for each model and overall, or
        {'error': ...} when no result carries ground truth

    Raises:
        ValueError: if step is not positive or max is smaller than min
    """
    if threshold_range is None:
        threshold_range = (0.1, 0.95)

    low, high = threshold_range
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")
    if high < low:
        raise ValueError(f"threshold_range must be (min, max) with min <= max, got {threshold_range}")

    # Build the grid from integer multiples of step rather than np.arange:
    # arange with a float step accumulates representation error and, combined
    # with a "max + step" stop value, can emit thresholds above max.
    # The small epsilon keeps an exact multiple (e.g. 0.85 / 0.05) from being
    # truncated to one step short.
    n_steps = int((high - low) / step + 1e-9)
    thresholds = [round(low + i * step, 10) for i in range(n_steps + 1)]

    # Extract publications with ground truth
    publications_with_truth = [
        result for result in results
        if 'models' in result and result['models'] and 'confidence_scores' in result
    ]

    if not publications_with_truth:
        return {'error': 'No publications with ground truth for evaluation'}

    # Every model named either in the ground truth or in the scores
    all_models = set()
    for pub in publications_with_truth:
        all_models.update(pub.get('models', []))
        all_models.update(pub.get('confidence_scores', {}).keys())
    all_models = sorted(all_models)

    # Per-model search; the same (confidence, label) pairs also feed the
    # pooled "overall" search below.
    optimal_thresholds = {}
    all_data = []

    for model in all_models:
        model_data = [
            (
                pub.get('confidence_scores', {}).get(model, 0),
                1 if model in pub.get('models', []) else 0,
            )
            for pub in publications_with_truth
        ]
        all_data.extend(model_data)

        best_threshold, best_f1, best_metrics = _select_best_f1_threshold(model_data, thresholds)
        optimal_thresholds[model] = {
            'threshold': float(best_threshold),
            'f1': best_f1,
            'metrics': best_metrics
        }

    # Single threshold applied to every (publication, model) pair
    best_overall_threshold, best_overall_f1, best_overall_metrics = _select_best_f1_threshold(
        all_data, thresholds
    )

    # Flat {model: threshold} view for passing to the matching functions
    model_threshold_dict = {model: data['threshold'] for model, data in optimal_thresholds.items()}

    return {
        'per_model': optimal_thresholds,
        'overall': {
            'threshold': float(best_overall_threshold),
            'f1': best_overall_f1,
            'metrics': best_overall_metrics
        },
        'model_thresholds': model_threshold_dict
    }

# Semantic matching cache.
# Keys are (publication_text, ((model_name, effective_threshold), ...)) tuples,
# values are the {model: similarity} dicts computed for that key.
_semantic_match_cache = {}

@torch.inference_mode()
def semantic_match(publication_text: str, model_embeddings: Dict[str, np.ndarray], 
                  threshold: float = 0.5, model_thresholds: Dict[str, float] = None) -> Dict[str, float]:
    """Return models whose embedding is similar enough to the publication text.

    The text is encoded with the module-level ``st_model`` and compared to
    every entry of ``model_embeddings`` with cosine similarity. A model is
    kept when its similarity is >= its entry in ``model_thresholds``, or
    >= ``threshold`` when it has no entry.

    Args:
        publication_text: Text to embed (callers pass title + abstract)
        model_embeddings: Model name -> embedding vector
        threshold: Fallback cut-off for models without a specific threshold
        model_thresholds: Per-model cut-offs; defaults to
            get_model_specific_thresholds()

    Returns:
        Dict of model name -> similarity (as float) for the models that pass.
        An empty dict is returned when ``model_embeddings`` is empty. The
        returned dict is a fresh copy, so callers may mutate it freely.
    """
    if model_thresholds is None:
        model_thresholds = get_model_specific_thresholds()

    model_names = list(model_embeddings.keys())
    if not model_names:
        # Nothing to compare against; cosine_similarity would reject an empty matrix.
        return {}

    # The result depends on the text, on which models are compared and on the
    # cut-off each one is held to, so all three go into the key. Using the
    # text itself (not its hash) rules out collisions between publications.
    # NOTE: embedding *values* are not part of the key; they are assumed not to
    # change for a given model name while the cache is populated.
    effective_thresholds = tuple(
        (name, float(model_thresholds.get(name, threshold))) for name in model_names
    )
    cache_key = (publication_text, effective_thresholds)

    cached = _semantic_match_cache.get(cache_key)
    if cached is not None:
        return dict(cached)

    pub_embedding = st_model.encode(publication_text, convert_to_tensor=False)
    embeddings_array = np.array([model_embeddings[name] for name in model_names])

    # One row (the publication) against all model embeddings.
    similarities = cosine_similarity([pub_embedding], embeddings_array)[0]

    results = {
        name: float(sim)
        for (name, cutoff), sim in zip(effective_thresholds, similarities)
        if sim >= cutoff
    }

    _semantic_match_cache[cache_key] = results
    return dict(results)

In [None]:
# Per-publication result caches. Both are keyed by "doi:<normalized DOI>" when
# the publication has a DOI, otherwise by "title:<title>".
# _matched_publications_cache is read and filled by process_publication_batch.
_matched_publications_cache = {}
# _improved_matching_cache is read and filled by match_models_improved.
_improved_matching_cache = {}

def match_models_improved(
    publication: Dict, 
    curated_mapping: Dict[str, str],
    model_keywords: Dict[str, List[str]],
    model_embeddings: Dict[str, np.ndarray],
    science_classifier,
    context_manager=None,
    ranker=None,
    threshold: float = 0.65,
    research_area_affinities=None,
    division_affinities=None,
    model_thresholds=None,
    include_classifications: bool = True
) -> Dict:
    # Match publication to models
    # Try to use cache
    pub_key = None
    if 'DOI' in publication and publication['DOI']:
        pub_key = f"doi:{normalize_doi(publication['DOI'])}"
    elif 'title' in publication:
        if isinstance(publication['title'], list) and publication['title']:
            pub_key = f"title:{publication['title'][0]}"
        elif isinstance(publication['title'], str):
            pub_key = f"title:{publication['title']}"
    
    if pub_key and pub_key in _improved_matching_cache:
        return _improved_matching_cache[pub_key]
    
    # Load affinities if needed
    if research_area_affinities is None:
        research_area_affinities = get_research_area_model_affinities()
        
    if division_affinities is None:
        division_affinities = get_division_model_affinities()
    
    if model_thresholds is None:
        model_thresholds = get_model_specific_thresholds()
    
    # Initialize scores
    all_models = list(model_embeddings.keys())
    confidence_scores = {model: 0.0 for model in all_models}
    confidence_sources = {model: [] for model in all_models}
    
    # Step 1: Check curated mapping
    doi = normalize_doi(publication.get('DOI', ''))
    if doi and doi in curated_mapping:
        model = curated_mapping[doi]
        confidence_scores[model] = 1.0
        confidence_sources[model].append('curated_mapping')
    
    # Step 2: Extract text
    title = ""
    if 'title' in publication:
        if isinstance(publication['title'], list) and publication['title']:
            title = publication['title'][0]
        else:
            title = publication.get('title', '')
            
    abstract = publication.get('abstract', '')
    publication_text = f"{title} {abstract}"
    
    # Early exit if curated match and minimal text
    model = None
    if doi and doi in curated_mapping:
        model = curated_mapping[doi]
        
    if model and (len(publication_text) < 10):
        matched_models = [model for model, confidence in confidence_scores.items() 
                        if confidence >= threshold]
        
        result = {
            'matched_models': matched_models,
            'confidence_scores': confidence_scores,
            'confidence_sources': confidence_sources
        }
        
        if pub_key:
            _improved_matching_cache[pub_key] = result
            
        return result
    
    # Step 3: Check keywords
    keywords = publication.get('keywords', [])
    text_for_keyword_matching = preprocess_text(publication_text)
    
    keyword_match_counts = {model: 0 for model in all_models}
    
    for model, model_kw_list in model_keywords.items():
        # Skip if already high confidence
        if model in confidence_scores and confidence_scores[model] >= 0.9:
            continue
            
        for kw in model_kw_list:
            # Check explicit keywords first
            found_in_keywords = any(fuzzy_keyword_match(pub_kw, kw)[0] for pub_kw in keywords)
            
            if found_in_keywords:
                keyword_match_counts[model] += 1
            else:
                # Check text if not found in keywords
                match_found, _ = fuzzy_keyword_match(text_for_keyword_matching, kw)
                if match_found:
                    keyword_match_counts[model] += 1
    
    # Convert keyword matches to confidence
    for model, count in keyword_match_counts.items():
        if count > 0:
            kw_confidence = min(0.95, 1 / (1 + np.exp(-0.5 * (count - 1))))
            
            if kw_confidence > 0.2:
                confidence_scores[model] = max(confidence_scores[model], kw_confidence)
                confidence_sources[model].append('keyword_match')
    
    # Step 4: Semantic matching
    try:
        if not any(score >= 0.9 for score in confidence_scores.values()):
            semantic_matches = semantic_match(
                publication_text, 
                model_embeddings,
                threshold=0.4,
                model_thresholds=model_thresholds
            )
            
            for model, similarity in semantic_matches.items():
                if similarity > 0.4:
                    confidence_scores[model] = max(confidence_scores[model], similarity)
                    confidence_sources[model].append('semantic_match')
    except Exception as e:
        print(f"Semantic matching failed: {str(e)}")
    
    # Step 5: Science classification signals
    if not any(score >= 0.9 for score in confidence_scores.values()):
        try:
            science_results = science_classifier.classify(publication)
            
            # Apply science keyword affinities
            keyword_affinities = get_science_keyword_model_affinities()
            science_keywords = science_results.get('science_keywords', [])
            for kw_entry in science_keywords:
                keyword = kw_entry['label']
                keyword_score = kw_entry['score']
                
                if keyword in keyword_affinities and keyword_score >= 0.3:
                    for model, affinity in keyword_affinities[keyword].items():
                        confidence = keyword_score * (affinity - 1.0)
                        
                        if confidence > 0.1:
                            confidence_scores[model] = max(confidence_scores[model], confidence)
                            confidence_sources[model].append(f'science_keyword:{keyword}')
            
            # Apply research area affinities
            research_areas = science_results.get('research_areas', [])
            for area_entry in research_areas:
                area = area_entry['label']
                area_score = area_entry['score']
                
                if area in research_area_affinities and area_score >= 0.3:
                    for model, affinity in research_area_affinities[area].items():
                        confidence = area_score * (affinity - 1.0)
                        
                        if confidence > 0.1:
                            confidence_scores[model] = max(confidence_scores[model], confidence)
                            confidence_sources[model].append(f'research_area:{area}')
            
            # Apply division affinities
            division_entry = science_results.get('division')
            if division_entry:
                division = division_entry['label']
                division_score = division_entry['score']
                
                if division in division_affinities and division_score >= 0.5:
                    for model, affinity in division_affinities[division].items():
                        confidence = division_score * (affinity - 1.0)
                        
                        if confidence > 0.05:
                            confidence_scores[model] = max(confidence_scores[model], confidence)
                            confidence_sources[model].append(f'division:{division}')
        except Exception as e:
            print(f"Science classification failed: {str(e)}")
    
    # Step 6: Context validation
    if context_manager and not any(score >= 0.9 for score in confidence_scores.values()):
        try:
            context_scores = context_manager.get_context_scores(publication)
            for model, score in context_scores.items():
                if score > 0.4:
                    context_confidence = score * 0.9
                    confidence_scores[model] = max(confidence_scores[model], context_confidence)
                    confidence_sources[model].append('context_validation')
        except Exception as e:
            print(f"Context validation failed: {str(e)}")
    
    # Step 7: Relevance ranking
    if ranker and not any(score >= 0.9 for score in confidence_scores.values()):
        try:
            query = publication_text[:500]
            top_models = [model for model, score in confidence_scores.items() if score > 0.3]
            
            if top_models:
                rank_scores = ranker.batch_rank(query, top_models)
                
                for model, score in rank_scores.items():
                    if score > 0.3:
                        ranker_confidence = score * 0.98
                        confidence_scores[model] = max(confidence_scores[model], ranker_confidence)
                        confidence_sources[model].append('relevance_ranker')
        except Exception as e:
            print(f"Relevance ranking failed: {str(e)}")
    
    # Step 8: Apply hybrid boosts
    candidate_models = [model for model, score in confidence_scores.items() if score >= 0.3]
    for model in candidate_models:
        sources = confidence_sources[model]
        
        # Boost for keyword + semantic matches
        if 'keyword_match' in sources and 'semantic_match' in sources:
            current_score = confidence_scores[model]
            confidence_scores[model] = min(0.95, current_score * 1.05)
            if 'hybrid' not in sources:
                sources.append('hybrid')
        
        # Boost for science metadata consensus
        science_sources = [s for s in sources if s.startswith(('science_keyword:', 'research_area:', 'division:'))]
        if len(science_sources) >= 2:
            current_score = confidence_scores[model]
            confidence_scores[model] = min(0.95, current_score * 1.12)
            if 'science_consensus' not in sources:
                sources.append('science_consensus')
    
    # Filter by thresholds
    matched_models = []
    for model, confidence in confidence_scores.items():
        model_threshold = model_thresholds.get(model, threshold)
        if confidence >= model_threshold:
            matched_models.append(model)
    
    # Add classifications
    science_results = None
    if include_classifications:
        try:
            # Use existing science results if available
            if not science_results:
                science_results = science_classifier.classify(publication)
                
            result_with_classifications = {
                'matched_models': matched_models,
                'confidence_scores': confidence_scores,
                'confidence_sources': confidence_sources,
                # Include all classification results
                'classifications': {
                    'research_areas': science_results.get('research_areas', []),
                    'science_keywords': science_results.get('science_keywords', []),
                    'division': science_results.get('division')
                }
            }
            
            # Add context terms if available
            if context_manager:
                try:
                    prepared_text = science_classifier.get_instance()._prepare_text(publication)
                    pub_profile = context_manager._get_pub_profile(prepared_text)
                    result_with_classifications['context_terms'] = list(pub_profile['terms'])
                except Exception as e:
                    print(f"Error extracting context terms: {str(e)}")
            
            # Add relevance scores if available
            if ranker and publication_text:
                try:
                    query = publication_text[:500]
                    relevance_scores = ranker.batch_rank(query, all_models)
                    result_with_classifications['relevance_scores'] = relevance_scores
                except Exception as e:
                    print(f"Error getting relevance scores: {str(e)}")
            
            # Extract publication date 
            pub_date = extract_publication_date(publication)
            if pub_date:
                result_with_classifications['pubdate'] = pub_date
                
            # Cache result
            if pub_key:
                _improved_matching_cache[pub_key] = result_with_classifications
                
            return result_with_classifications
        except Exception as e:
            print(f"Error including classifications: {str(e)}")
    
    # Basic result without classifications
    result = {
        'matched_models': matched_models,
        'confidence_scores': confidence_scores,
        'confidence_sources': confidence_sources
    }
    
    # Extract publication date 
    pub_date = extract_publication_date(publication)
    if pub_date:
        result['pubdate'] = pub_date
    
    # Cache result
    if pub_key:
        _improved_matching_cache[pub_key] = result
        
    return result

def process_publication_batch(
    publications: List[Dict], 
    curated_mapping: Dict[str, str], 
    model_keywords: Dict[str, List[str]], 
    model_embeddings: Dict[str, np.ndarray], 
    science_classifier,
    context_manager,
    ranker=None,
    batch_size: int = 100,
    include_classifications: bool = True
) -> List[Dict]:
    """Run match_models_improved over publications in batches.

    Each output entry is a shallow copy of the input publication with every
    field of the match result merged in. Results are memoised in
    ``_matched_publications_cache`` by DOI (or title), so a repeated
    publication is not matched again and receives the same fields as the
    first occurrence. A publication whose matching raises is still returned,
    with empty 'matched_models', 'confidence_scores' and 'confidence_sources'.

    Args:
        publications: Publication records to process
        curated_mapping: Normalized DOI -> model name
        model_keywords: Model name -> keywords
        model_embeddings: Model name -> embedding
        science_classifier: Classifier passed through to match_models_improved
        context_manager: Context manager passed through (may be None)
        ranker: Optional ranker; when given, 'pub_characteristics' is added
        batch_size: Publications per batch (must be >= 1)
        include_classifications: Passed through to match_models_improved

    Returns:
        One result dict per input publication, in input order.

    Raises:
        ValueError: if batch_size is smaller than 1
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    results = []
    total_pubs = len(publications)
    num_batches = (total_pubs + batch_size - 1) // batch_size

    # Load affinity tables and thresholds once for the whole run
    research_area_affinities = get_research_area_model_affinities()
    division_affinities = get_division_model_affinities()
    model_thresholds = get_model_specific_thresholds()

    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, total_pubs)

        print(f"Processing batch {batch_idx+1}/{num_batches} (publications {start_idx+1}-{end_idx}/{total_pubs})...")

        batch_results = []

        for pub_idx, pub in enumerate(publications[start_idx:end_idx], start=start_idx + 1):
            if pub_idx % 20 == 0:
                print(f"  Processing publication {pub_idx}/{total_pubs}...")

            # Cache key: normalized DOI, falling back to the title
            pub_key = None
            if 'DOI' in pub and pub['DOI']:
                pub_key = f"doi:{normalize_doi(pub['DOI'])}"
            elif 'title' in pub:
                if isinstance(pub['title'], list) and pub['title']:
                    pub_key = f"title:{pub['title'][0]}"
                elif isinstance(pub['title'], str):
                    pub_key = f"title:{pub['title']}"

            # Check cache
            if pub_key and pub_key in _matched_publications_cache:
                pub_copy = pub.copy()
                pub_copy.update(_matched_publications_cache[pub_key])
                batch_results.append(pub_copy)
                continue

            # Match models
            try:
                match_result = match_models_improved(
                    pub, 
                    curated_mapping, 
                    model_keywords, 
                    model_embeddings, 
                    science_classifier,
                    context_manager,
                    ranker,
                    research_area_affinities=research_area_affinities,
                    division_affinities=division_affinities,
                    model_thresholds=model_thresholds,
                    include_classifications=include_classifications
                )

                # Merge all fields from the match result into a copy of the publication
                pub_copy = pub.copy()
                pub_copy.update(match_result)

                # Extract characteristics if needed
                if ranker is not None:
                    pub_copy['pub_characteristics'] = extract_publication_characteristics(pub)
            except Exception as e:
                print(f"Error processing publication {pub_idx}: {str(e)}")
                # Add publication without matches (failures are not cached)
                pub_copy = pub.copy()
                pub_copy['matched_models'] = []
                pub_copy['confidence_scores'] = {}
                pub_copy['confidence_sources'] = {}
            else:
                # Cache the complete match result so that a cache hit carries
                # the same fields (classifications, pubdate, ...) as a miss.
                if pub_key:
                    cache_value = dict(match_result)
                    if 'pub_characteristics' in pub_copy:
                        cache_value['pub_characteristics'] = pub_copy['pub_characteristics']
                    _matched_publications_cache[pub_key] = cache_value

            # Exactly one entry per publication, whichever path was taken
            batch_results.append(pub_copy)

        # Add batch results to overall results
        results.extend(batch_results)

        # Every fifth batch, drop the auxiliary caches and release memory
        if batch_idx % 5 == 4:
            _fuzzy_match_cache.clear()
            _semantic_match_cache.clear()

            # Clean up memory
            optimize_memory()

        print(f"Completed batch {batch_idx+1}/{num_batches}")

    return results

In [None]:
def _viz_safe_filename(name) -> str:
    """Return ``name`` as text with path separators / reserved filename characters replaced by '_'."""
    return re.sub(r'[\\/:*?"<>|]+', '_', str(name))


def _viz_to_jsonable(obj):
    """Recursively convert NumPy scalars and arrays (dict keys included) to built-in Python types."""
    if isinstance(obj, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key): _viz_to_jsonable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [_viz_to_jsonable(item) for item in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj


def _viz_confusion_counts(y_true, y_pred):
    """Count (true positives, false positives, false negatives) for paired 0/1 label lists."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    return tp, fp, fn


def _viz_prf(tp, fp, fn):
    """Return (precision, recall, f1) from raw counts; any undefined ratio is reported as 0."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    return precision, recall, f1


def _viz_label_bars(ax, bars, fmt='{:.2f}', texts=None):
    """Write a label just above each vertical bar: ``texts[i]`` if given, else the formatted height."""
    for idx, bar in enumerate(bars):
        height = bar.get_height()
        text = texts[idx] if texts is not None else fmt.format(height)
        ax.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                text, ha='center', va='bottom')


def _viz_named_bars(ax, names, values):
    """Draw one bar per name at integer positions and attach rotated name labels to fixed ticks."""
    positions = np.arange(len(names))
    bars = ax.bar(positions, values)
    # Fix the tick positions first so set_xticklabels does not warn about a non-fixed locator.
    ax.set_xticks(positions)
    ax.set_xticklabels([str(name) for name in names], rotation=45, ha='right')
    return bars


def _viz_tally(stats, entries, min_score, is_correct):
    """Add one observation to ``stats[label]`` for every entry whose score exceeds ``min_score``."""
    for entry in entries:
        label = entry.get('label', '')
        if label and entry.get('score', 0) > min_score:
            bucket = stats.setdefault(label, {'correct': 0, 'total': 0})
            bucket['total'] += 1
            if is_correct:
                bucket['correct'] += 1


def _viz_add_accuracy(stats):
    """Store ``correct / total`` under 'accuracy' for every bucket that has observations."""
    for bucket in stats.values():
        if bucket['total'] > 0:
            bucket['accuracy'] = bucket['correct'] / bucket['total']


def _viz_accuracy_panel(ax, stats, ylabel, title, empty_message, min_samples=5, limit=None):
    """Horizontal accuracy bars for buckets with at least ``min_samples`` observations, best first."""
    rows = sorted(
        ((name, bucket) for name, bucket in stats.items() if bucket['total'] >= min_samples),
        key=lambda row: row[1]['accuracy'],
        reverse=True
    )
    if limit is not None:
        rows = rows[:limit]

    if not rows:
        ax.text(0.5, 0.5, empty_message, ha='center', va='center', transform=ax.transAxes)
        return

    names = [name for name, _ in rows]
    bars = ax.barh(names, [bucket['accuracy'] for _, bucket in rows])
    ax.set_xlabel('Accuracy')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xlim(0, 1)

    # Annotate each bar with its sample count
    for bar, (_, bucket) in zip(bars, rows):
        ax.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2,
                f"n={bucket['total']}", ha='left', va='center')


def _viz_threshold_figure(analysis, title, path):
    """Plot precision/recall/F1 against threshold, mark current and optimal thresholds, save to ``path``."""
    fig, ax = plt.subplots(figsize=(10, 6))

    grid = sorted(analysis['metrics'])
    for key, line_style, label in (('precision', 'b-', 'Precision'),
                                   ('recall', 'g-', 'Recall'),
                                   ('f1', 'r-', 'F1 Score')):
        ax.plot(grid, [analysis['metrics'][t][key] for t in grid], line_style, label=label)

    current_threshold = analysis['current_threshold']
    optimal_threshold = analysis['optimal_threshold']
    ax.axvline(x=current_threshold, color='gray', linestyle='--',
               label=f'Current Threshold ({current_threshold:.2f})')
    ax.axvline(x=optimal_threshold, color='black', linestyle='-',
               label=f'Optimal Threshold ({optimal_threshold:.2f})')

    ax.set_xlabel('Confidence Threshold')
    ax.set_ylabel('Score')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.7)

    fig.tight_layout()
    fig.savefig(path, dpi=300)
    plt.close(fig)


def visualize_metrics(results: List[Dict], output_path_base: str = "./metrics_visualization"):
    """Evaluate matched models against ground truth, save charts and a JSON metrics dump.

    Only results that carry a non-empty 'models' entry (ground truth) and a
    'matched_models' entry (predictions) are evaluated. Optional per-result keys
    read here are 'confidence_scores', 'confidence_sources', 'pubdate' and
    'classifications'.

    Files written (all prefixed with ``output_path_base``):
        _model_performance.png, _confidence_analysis.png, _source_analysis.png
        (only if any confidence source exists), _temporal_analysis.png (only
        with more than one distinct publication date), _classification_accuracy.png,
        _threshold_<model>.png (one per model), _threshold_overall.png,
        _threshold_comparison.png, _summary.png and _complete.json.

    Args:
        results: Publication dicts as produced by the matching step.
        output_path_base: Path prefix for every output file.

    Returns:
        The metrics dict with keys 'model_performance', 'source_analysis',
        'classification_accuracy', 'confidence_analysis', 'temporal_analysis'
        and 'threshold_analysis'; ``None`` when no result has ground truth.
    """
    # Keep only publications with ground truth and predictions
    publications_with_truth = [
        pub for pub in results
        if 'models' in pub and pub['models'] and 'matched_models' in pub
    ]

    if not publications_with_truth:
        print("No publications with ground truth for evaluation")
        return

    # Build the label sets once instead of scanning the lists for every model
    truth_sets = [set(pub.get('models', [])) for pub in publications_with_truth]
    pred_sets = [set(pub.get('matched_models', [])) for pub in publications_with_truth]

    # Every retained publication has a non-empty 'models' entry, so this list is never empty
    all_models = sorted(set().union(*truth_sets, *pred_sets))
    model_labels = [str(model) for model in all_models]

    # Dictionary to store all metrics
    complete_metrics = {
        'model_performance': {},
        'source_analysis': {},
        'classification_accuracy': {},
        'confidence_analysis': {},
        'temporal_analysis': {},
        'threshold_analysis': {}
    }

    # ------------------------------------------------------------------
    # 1. Model Performance Metrics
    # ------------------------------------------------------------------
    model_metrics = {}
    per_model_scores = {'precision': [], 'recall': [], 'f1': []}
    truth_by_model = {}
    pred_by_model = {}

    for model in all_models:
        # One-vs-rest binary labels for this model across all publications
        y_true = [1 if model in truth else 0 for truth in truth_sets]
        y_pred = [1 if model in pred else 0 for pred in pred_sets]
        truth_by_model[model] = y_true
        pred_by_model[model] = y_pred

        precision = precision_score(y_true, y_pred, zero_division=0)
        recall = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)
        tp, fp, fn = _viz_confusion_counts(y_true, y_pred)

        model_metrics[model] = {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'true_positives': tp,
            'false_positives': fp,
            'false_negatives': fn,
            'match_count': sum(y_pred)
        }

        per_model_scores['precision'].append(precision)
        per_model_scores['recall'].append(recall)
        per_model_scores['f1'].append(f1)

    # Micro-average: pool every (publication, model) decision into one binary problem
    all_y_true = [label for model in all_models for label in truth_by_model[model]]
    all_y_pred = [label for model in all_models for label in pred_by_model[model]]

    micro_precision = precision_score(all_y_true, all_y_pred, zero_division=0)
    micro_recall = recall_score(all_y_true, all_y_pred, zero_division=0)
    micro_f1 = f1_score(all_y_true, all_y_pred, zero_division=0)

    complete_metrics['model_performance'] = {
        'per_model': model_metrics,
        'micro_average': {
            'precision': micro_precision,
            'recall': micro_recall,
            'f1': micro_f1
        }
    }

    micro_metrics = {
        'Precision': micro_precision,
        'Recall': micro_recall,
        'F1 Score': micro_f1
    }

    # Model performance figure: grouped per-model bars above the micro-average bars
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

    x = np.arange(len(all_models))
    width = 0.25
    grouped_bars = [
        ax1.bar(x - width, per_model_scores['precision'], width, label='Precision'),
        ax1.bar(x, per_model_scores['recall'], width, label='Recall'),
        ax1.bar(x + width, per_model_scores['f1'], width, label='F1 Score'),
    ]

    ax1.set_xlabel('Models')
    ax1.set_ylabel('Score')
    ax1.set_title('Precision, Recall, and F1 Score by Model')
    ax1.set_xticks(x)
    ax1.set_xticklabels(model_labels, rotation=45, ha='right')
    ax1.legend()
    ax1.grid(axis='y', linestyle='--', alpha=0.7)
    for bars in grouped_bars:
        _viz_label_bars(ax1, bars, fmt='{:.2f}')

    x2 = np.arange(len(micro_metrics))
    bars = ax2.bar(x2, list(micro_metrics.values()), width=0.4)
    ax2.set_xlabel('Metrics')
    ax2.set_ylabel('Score')
    ax2.set_title('Overall Micro-Average Metrics')
    ax2.set_xticks(x2)
    ax2.set_xticklabels(list(micro_metrics.keys()))
    ax2.grid(axis='y', linestyle='--', alpha=0.7)
    _viz_label_bars(ax2, bars, fmt='{:.3f}')

    fig.tight_layout()
    fig.savefig(f"{output_path_base}_model_performance.png", dpi=300)
    plt.close(fig)

    # ------------------------------------------------------------------
    # 2. Confidence Analysis
    # ------------------------------------------------------------------
    confidence_data = {}

    for model in all_models:
        # A model with no recorded score for a publication counts as confidence 0
        confidence_scores = [
            pub.get('confidence_scores', {}).get(model, 0)
            for pub in publications_with_truth
        ]
        # A prediction is "correct" when truth and prediction agree (both present or both absent)
        correct_predictions = [
            1 if t == p else 0
            for t, p in zip(truth_by_model[model], pred_by_model[model])
        ]

        confidence_data[model] = {
            'scores': confidence_scores,
            'correct_predictions': correct_predictions,
            'mean_confidence': np.mean(confidence_scores),
            'median_confidence': np.median(confidence_scores)
        }

    complete_metrics['confidence_analysis'] = confidence_data

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Mean confidence by model
    _viz_named_bars(ax1, all_models,
                    [confidence_data[model]['mean_confidence'] for model in all_models])
    ax1.set_xlabel('Models')
    ax1.set_ylabel('Mean Confidence Score')
    ax1.set_title('Mean Confidence Score by Model')
    ax1.grid(axis='y', linestyle='--', alpha=0.7)

    # Confidence distribution (boxplot). Tick labels are set explicitly rather than via
    # the boxplot ``labels`` argument, which newer Matplotlib releases deprecate.
    ax2.boxplot([confidence_data[model]['scores'] for model in all_models])
    ax2.set_xticks(np.arange(1, len(all_models) + 1))
    ax2.set_xticklabels(model_labels, rotation=45, ha='right')
    ax2.set_xlabel('Models')
    ax2.set_ylabel('Confidence Score Distribution')
    ax2.set_title('Confidence Score Distribution by Model')
    ax2.grid(axis='y', linestyle='--', alpha=0.7)

    fig.tight_layout()
    fig.savefig(f"{output_path_base}_confidence_analysis.png", dpi=300)
    plt.close(fig)

    # ------------------------------------------------------------------
    # 3. Source Analysis
    # ------------------------------------------------------------------
    source_counts = {}
    source_accuracy = {}

    for pub, truth, pred in zip(publications_with_truth, truth_sets, pred_sets):
        sources_by_model = pub.get('confidence_sources', {})
        for model in all_models:
            agrees = (model in truth) == (model in pred)

            for source in sources_by_model.get(model, []):
                if source not in source_counts:
                    source_counts[source] = 0
                    source_accuracy[source] = {'correct': 0, 'total': 0}

                source_counts[source] += 1
                source_accuracy[source]['total'] += 1
                if agrees:
                    source_accuracy[source]['correct'] += 1

    # Calculate accuracy rates
    for stats in source_accuracy.values():
        stats['accuracy'] = stats['correct'] / stats['total'] if stats['total'] > 0 else 0

    complete_metrics['source_analysis'] = {
        'counts': source_counts,
        'accuracy': source_accuracy
    }

    # Sources ranked by usage, most used first (stable for ties)
    ranked_sources = sorted(source_counts.items(), key=lambda item: item[1], reverse=True)

    if ranked_sources:
        # Limit to top 15 sources for better visualization
        shown = ranked_sources[:15]
        shown_names = [name for name, _ in shown]

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

        _viz_named_bars(ax1, shown_names, [count for _, count in shown])
        ax1.set_xlabel('Confidence Sources')
        ax1.set_ylabel('Count')
        ax1.set_title('Usage Count by Confidence Source')
        ax1.grid(axis='y', linestyle='--', alpha=0.7)

        _viz_named_bars(ax2, shown_names,
                        [source_accuracy[name]['accuracy'] for name in shown_names])
        ax2.set_xlabel('Confidence Sources')
        ax2.set_ylabel('Accuracy')
        ax2.set_title('Accuracy by Confidence Source')
        ax2.grid(axis='y', linestyle='--', alpha=0.7)

        fig.tight_layout()
        fig.savefig(f"{output_path_base}_source_analysis.png", dpi=300)
        plt.close(fig)

    # ------------------------------------------------------------------
    # 4. Temporal Analysis
    # ------------------------------------------------------------------
    date_metrics = {}

    for pub, truth, pred in zip(publications_with_truth, truth_sets, pred_sets):
        pub_date = pub.get('pubdate')
        # Skip missing, non-text or too-short dates; the value is used verbatim as the bucket key
        if not isinstance(pub_date, str) or len(pub_date) < 7:
            continue

        bucket = date_metrics.setdefault(pub_date, {
            'total': 0,
            'correct': 0,
            'precision': [],
            'recall': [],
            'f1': []
        })
        bucket['total'] += 1

        # Per-publication set-based precision / recall / F1
        precision, recall, f1 = _viz_prf(len(truth & pred), len(pred - truth), len(truth - pred))
        bucket['precision'].append(precision)
        bucket['recall'].append(recall)
        bucket['f1'].append(f1)

        # A publication is fully correct only when the predicted set equals the true set
        if truth == pred:
            bucket['correct'] += 1

    # Calculate average metrics by date
    for bucket in date_metrics.values():
        if bucket['total'] > 0:
            bucket['accuracy'] = bucket['correct'] / bucket['total']
            bucket['avg_precision'] = np.mean(bucket['precision'])
            bucket['avg_recall'] = np.mean(bucket['recall'])
            bucket['avg_f1'] = np.mean(bucket['f1'])

    complete_metrics['temporal_analysis'] = date_metrics

    # A trend line needs at least two dates
    sorted_dates = sorted(date_metrics.keys())
    if len(sorted_dates) > 1:
        fig, ax = plt.subplots(figsize=(12, 6))

        for key, marker, label in (('accuracy', 'o-', 'Accuracy'),
                                   ('avg_precision', 's-', 'Precision'),
                                   ('avg_recall', '^-', 'Recall'),
                                   ('avg_f1', 'D-', 'F1 Score')):
            ax.plot(sorted_dates, [date_metrics[d][key] for d in sorted_dates], marker, label=label)

        ax.set_xlabel('Publication Date')
        ax.set_ylabel('Score')
        ax.set_title('Performance Metrics by Publication Date')
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.legend()

        plt.xticks(rotation=45)
        fig.tight_layout()
        fig.savefig(f"{output_path_base}_temporal_analysis.png", dpi=300)
        plt.close(fig)

    # ------------------------------------------------------------------
    # 5. Classification Accuracy Analysis
    # ------------------------------------------------------------------
    research_area_stats = {}
    science_keyword_stats = {}
    division_stats = {}

    for pub, truth, pred in zip(publications_with_truth, truth_sets, pred_sets):
        classifications = pub.get('classifications', {})
        if not classifications:
            continue

        is_correct = (truth == pred)

        # Only labels scoring above the cut-off are treated as significant
        _viz_tally(research_area_stats, classifications.get('research_areas', []), 0.3, is_correct)
        _viz_tally(science_keyword_stats, classifications.get('science_keywords', []), 0.3, is_correct)

        # 'division' is a single dict rather than a list, with a stricter cut-off
        division = classifications.get('division', {})
        if division:
            _viz_tally(division_stats, [division], 0.5, is_correct)

    for stats in (research_area_stats, science_keyword_stats, division_stats):
        _viz_add_accuracy(stats)

    complete_metrics['classification_accuracy'] = {
        'research_areas': research_area_stats,
        'science_keywords': science_keyword_stats,
        'divisions': division_stats
    }

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 15))
    _viz_accuracy_panel(ax1, research_area_stats, 'Research Area',
                        'Match Accuracy by Research Area',
                        'Insufficient research area data', limit=15)
    _viz_accuracy_panel(ax2, science_keyword_stats, 'Science Keyword',
                        'Match Accuracy by Science Keyword',
                        'Insufficient science keyword data', limit=15)
    _viz_accuracy_panel(ax3, division_stats, 'Division',
                        'Match Accuracy by Division',
                        'Insufficient division data')

    fig.tight_layout()
    fig.savefig(f"{output_path_base}_classification_accuracy.png", dpi=300)
    plt.close(fig)

    # ------------------------------------------------------------------
    # 6. Threshold Analysis
    # ------------------------------------------------------------------
    # 0.10, 0.15, ..., 0.95 as plain rounded floats: np.arange yields values such as
    # 0.15000000000000002, which would leak into dict/JSON keys and make a score of
    # exactly 0.15 fail the nominal 0.15 threshold.
    thresholds = [round(float(t), 2) for t in np.arange(0.1, 1.0, 0.05)]

    threshold_analysis = {}

    for model in all_models:
        scores = confidence_data[model]['scores']
        true_labels = truth_by_model[model]

        threshold_metrics = {}
        for threshold in thresholds:
            predicted_labels = [1 if score >= threshold else 0 for score in scores]
            tp, fp, fn = _viz_confusion_counts(true_labels, predicted_labels)
            precision, recall, f1 = _viz_prf(tp, fp, fn)

            threshold_metrics[threshold] = {
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'tp': tp,
                'fp': fp,
                'fn': fn
            }

        # Best F1 wins; on ties the lowest threshold is kept (first maximum in ascending order)
        optimal_threshold = max(threshold_metrics, key=lambda t: threshold_metrics[t]['f1'])

        threshold_analysis[model] = {
            'metrics': threshold_metrics,
            'optimal_threshold': optimal_threshold,
            'optimal_f1': threshold_metrics[optimal_threshold]['f1'],
            # MODEL_THRESHOLDS is a notebook-level mapping defined outside this function
            'current_threshold': MODEL_THRESHOLDS.get(model, 0.4)
        }

    # Overall analysis pools the scores and labels of every (publication, model) pair
    all_confidence_scores = [s for model in all_models for s in confidence_data[model]['scores']]
    all_true_labels = all_y_true

    overall_threshold_metrics = {}
    for threshold in thresholds:
        predicted_labels = [1 if score >= threshold else 0 for score in all_confidence_scores]

        overall_threshold_metrics[threshold] = {
            'precision': precision_score(all_true_labels, predicted_labels, zero_division=0),
            'recall': recall_score(all_true_labels, predicted_labels, zero_division=0),
            'f1': f1_score(all_true_labels, predicted_labels, zero_division=0)
        }

    overall_optimal_threshold = max(
        overall_threshold_metrics, key=lambda t: overall_threshold_metrics[t]['f1']
    )

    threshold_analysis['overall'] = {
        'metrics': overall_threshold_metrics,
        'optimal_threshold': overall_optimal_threshold,
        'optimal_f1': overall_threshold_metrics[overall_optimal_threshold]['f1'],
        'current_threshold': 0.4  # Default overall threshold
    }

    complete_metrics['threshold_analysis'] = threshold_analysis

    # Per-model threshold curves; the model name is sanitised so that characters such
    # as '/' cannot redirect or break the output path.
    for model in all_models:
        _viz_threshold_figure(
            threshold_analysis[model],
            f'Threshold Analysis for {model}',
            f"{output_path_base}_threshold_{_viz_safe_filename(model)}.png"
        )

    # Overall threshold curves
    _viz_threshold_figure(
        threshold_analysis['overall'],
        'Overall Threshold Analysis',
        f"{output_path_base}_threshold_overall.png"
    )

    # Current vs. optimal threshold per model
    fig, ax = plt.subplots(figsize=(12, 8))

    x = np.arange(len(all_models))
    width = 0.35
    ax.bar(x - width / 2,
           [threshold_analysis[model]['current_threshold'] for model in all_models],
           width, label='Current Threshold')
    ax.bar(x + width / 2,
           [threshold_analysis[model]['optimal_threshold'] for model in all_models],
           width, label='Optimal Threshold')

    ax.set_xlabel('Model')
    ax.set_ylabel('Threshold Value')
    ax.set_title('Current vs. Optimal Thresholds by Model')
    ax.set_xticks(x)
    ax.set_xticklabels(model_labels, rotation=45, ha='right')
    ax.legend()
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    fig.tight_layout()
    fig.savefig(f"{output_path_base}_threshold_comparison.png", dpi=300)
    plt.close(fig)

    # ------------------------------------------------------------------
    # Summary figure (2 x 2)
    # ------------------------------------------------------------------
    fig, axs = plt.subplots(2, 2, figsize=(15, 12))

    # Top left: F1 score per model
    panel = axs[0, 0]
    bars = _viz_named_bars(panel, all_models, [model_metrics[model]['f1'] for model in all_models])
    panel.set_title('F1 Score by Model')
    panel.set_xlabel('Model')
    panel.set_ylabel('F1 Score')
    panel.grid(axis='y', linestyle='--', alpha=0.7)
    _viz_label_bars(panel, bars, fmt='{:.2f}')

    # Top right: accuracy of the five most used confidence sources
    panel = axs[0, 1]
    if ranked_sources:
        top_sources = ranked_sources[:5]
        top_source_names = [name for name, _ in top_sources]

        bars = _viz_named_bars(panel, top_source_names,
                               [source_accuracy[name]['accuracy'] for name in top_source_names])
        panel.set_title('Accuracy by Top 5 Confidence Sources')
        panel.set_xlabel('Source')
        panel.set_ylabel('Accuracy')
        panel.grid(axis='y', linestyle='--', alpha=0.7)
        _viz_label_bars(panel, bars, texts=[f'n={count}' for _, count in top_sources])
    else:
        panel.text(0.5, 0.5, 'No source data available',
                   ha='center', va='center', transform=panel.transAxes)

    # Bottom left: mean accuracy per classification type (buckets with >= 5 samples only)
    panel = axs[1, 0]
    if research_area_stats or science_keyword_stats or division_stats:
        classification_summary = {}
        for type_name, stats in (('Research Areas', research_area_stats),
                                 ('Science Keywords', science_keyword_stats),
                                 ('Divisions', division_stats)):
            values = [bucket['accuracy'] for bucket in stats.values() if bucket['total'] >= 5]
            if values:
                classification_summary[type_name] = np.mean(values)

        if classification_summary:
            bars = panel.bar(list(classification_summary.keys()),
                             list(classification_summary.values()))
            panel.set_title('Average Accuracy by Classification Type')
            panel.set_xlabel('Classification Type')
            panel.set_ylabel('Average Accuracy')
            panel.grid(axis='y', linestyle='--', alpha=0.7)
            _viz_label_bars(panel, bars, fmt='{:.2f}')
        else:
            panel.text(0.5, 0.5, 'Insufficient classification data',
                       ha='center', va='center', transform=panel.transAxes)
    else:
        panel.text(0.5, 0.5, 'No classification data available',
                   ha='center', va='center', transform=panel.transAxes)

    # Bottom right: micro-average performance metrics
    panel = axs[1, 1]
    bars = panel.bar(list(micro_metrics.keys()), list(micro_metrics.values()))
    panel.set_title('Overall Performance Metrics')
    panel.set_xlabel('Metric')
    panel.set_ylabel('Score')
    panel.grid(axis='y', linestyle='--', alpha=0.7)
    _viz_label_bars(panel, bars, fmt='{:.3f}')

    fig.tight_layout()
    fig.savefig(f"{output_path_base}_summary.png", dpi=300)
    plt.close(fig)

    # Convert before opening the file so a conversion error cannot leave a truncated JSON behind
    serializable_metrics = _viz_to_jsonable(complete_metrics)
    with open(f"{output_path_base}_complete.json", 'w') as f:
        json.dump(serializable_metrics, f, indent=2)

    print(f"Comprehensive metrics visualizations saved with base path: {output_path_base}")

    # Return all metrics for further analysis
    return complete_metrics

In [None]:
def main():
    """Run the end-to-end classification and evaluation pipeline.

    Steps, in order:
      1. Report CUDA availability and build the ScienceClassifier singleton.
      2. Load the curated mapping, model keywords and model descriptions
         from JSON files in the current directory.
      3. Build the relevance ranker, the model embeddings and the context
         manager.
      4. Process the labeled test publications and write the results to
         ``results.json``.
      5. Generate metric visualizations, print model-specific terms, search
         for optimal thresholds and compare them with the current ones.
      6. Print timing and (when CUDA is available) peak GPU memory.

    Returns:
        None. If the test data cannot be loaded, an error is printed and the
        function returns early without processing anything.
    """
    start_time = time.time()

    print("Starting optimized science publication classifier...")
    print(f"CUDA available: {torch.cuda.is_available()}")

    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")
        print(f"CUDA memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")

    # Initialize classifier (inference_mode disables autograd tracking
    # for anything executed while the objects are constructed)
    print("Initializing science classifier...")
    with torch.inference_mode():
        science_classifier = ScienceClassifier.get_instance()
    print("Science classifier initialized successfully.")

    # Load configuration files
    print("Loading configuration files...")
    curated_mapping = load_curated_models('./curated_publications.json')
    model_keywords = load_model_keywords('./model_keywords.json')
    model_descriptions = load_model_descriptions('./model_descriptions.json')

    # Initialize ranker
    print("Initializing relevance ranker...")
    with torch.inference_mode():
        ranker = RelevanceRanker(model_descriptions)

    # Initialize model embeddings
    print("Initializing model embeddings...")
    with torch.inference_mode():
        model_embeddings = initialize_model_embeddings(model_descriptions)

    # Clean memory
    optimize_memory()

    # Load test data; JSON is read as UTF-8 explicitly so the result does
    # not depend on the platform's default encoding.
    print("Loading test publications...")
    try:
        with open('./labeled_test_data_plusmomo.json', encoding='utf-8') as f:
            test_data = json.load(f)
        print(f"Loaded {len(test_data)} test publications.")
    except Exception as e:
        print(f"Error loading test data: {str(e)}")
        return

    # Initialize context manager from the full curated publications file
    # (the same path that load_curated_models was given above).
    print("Building context validation profiles...")
    with open('./curated_publications.json', encoding='utf-8') as f:
        full_curated = json.load(f)

    with torch.inference_mode():
        context_manager = ModelContextManager(full_curated)

    # Process publications
    print("\nProcessing publications...")
    results = process_publication_batch(
        test_data,
        curated_mapping,
        model_keywords,
        model_embeddings,
        science_classifier,
        context_manager,
        ranker=ranker
    )

    # Save results
    print("Saving results to results.json...")
    with open('results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    # Generate visualizations with threshold analysis. The returned metrics
    # are not needed here, so the return value is intentionally discarded.
    print("\nGenerating visualizations with threshold analysis...")
    visualize_metrics(results)

    # Show TF-IDF model-specific terms
    print("\nModel-specific terminology based on TF-IDF analysis:")
    model_specific_terms = context_manager.get_model_specific_terms()
    for model, terms in model_specific_terms.items():
        print(f"\n{model} distinctive terms:")
        print(", ".join(terms[:10]))  # Show top 10 terms

    # Find optimal thresholds
    print("\nFinding optimal thresholds...")
    optimal_thresholds = find_optimal_thresholds(results)

    print("\nOptimal model-specific thresholds:")
    for model, data in optimal_thresholds['per_model'].items():
        # 0.4 is the fallback shown for models without a MODEL_THRESHOLDS entry
        current = MODEL_THRESHOLDS.get(model, 0.4)
        print(f"{model}: {data['threshold']:.2f} (current: {current:.2f}, F1: {data['f1']:.2f})")

    print(f"\nOptimal overall threshold: {optimal_thresholds['overall']['threshold']:.2f}")
    print(f"Overall F1 score with optimal thresholds: {optimal_thresholds['overall']['f1']:.3f}")

    # Compare performance with current vs. optimal thresholds
    print("\nComparing performance with current vs. optimal thresholds:")
    current_performance = analyze_threshold_performance(results)
    optimal_performance = analyze_threshold_performance(
        results,
        model_thresholds=optimal_thresholds['model_thresholds'],
        overall_threshold=optimal_thresholds['overall']['threshold']
    )

    current_f1 = current_performance['overall']['f1']
    optimal_f1 = optimal_performance['overall']['f1']
    # The improvement is the absolute F1 difference scaled by 100.
    print(f"Current F1: {current_f1:.3f}, " +
          f"Optimal F1: {optimal_f1:.3f}, " +
          f"Improvement: {(optimal_f1 - current_f1) * 100:.1f}%")

    # Report completion
    total_time = time.time() - start_time
    print(f"\nProcessing completed in {total_time:.2f}s ({total_time/60:.2f} minutes)")

    # Guard against an empty test set, which would otherwise raise
    # ZeroDivisionError after all the work has already been done.
    publication_count = len(test_data)
    if publication_count:
        print(f"Average time per publication: {total_time/publication_count:.4f}s")
    else:
        print("Average time per publication: n/a (no publications processed)")

    if torch.cuda.is_available():
        print(f"Peak GPU memory usage: {torch.cuda.max_memory_allocated(0) / 1024**2:.2f} MB")

In [None]:
# Entry point: run the pipeline only when this code is executed directly
# (as a script or in a notebook kernel), not when it is imported as a module.
if __name__ == "__main__":
    main()