In [None]:
!pip install requests pandas numpy matplotlib seaborn plotly networkx bio

Collecting bio
  Downloading bio-1.8.0-py3-none-any.whl.metadata (5.7 kB)
Collecting biopython>=1.80 (from bio)
  Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting gprofiler-official (from bio)
  Downloading gprofiler_official-1.0.0-py3-none-any.whl.metadata (11 kB)
Collecting mygene (from bio)
  Downloading mygene-3.2.2-py2.py3-none-any.whl.metadata (10 kB)
Collecting biothings-client>=0.2.6 (from mygene->bio)
  Downloading biothings_client-0.4.1-py3-none-any.whl.metadata (10 kB)
Downloading bio-1.8.0-py3-none-any.whl (321 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m321.1/321.1 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m55.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gprofiler_official-1.0.0-py3-none-any.whl (9.3

In [12]:
import requests
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, Counter
import json
import time
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from typing import List, Dict, Tuple, Optional, Set
import logging
import concurrent.futures
import os
import warnings
# Silence every Python warning for the rest of the session.
# NOTE(review): this is a blanket filter, so deprecation and dtype warnings
# raised by any imported library are hidden as well.
warnings.filterwarnings('ignore')

# Set up logging: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger named after the current module.
logger = logging.getLogger(__name__)

class EnhancedLitVarAnalyzer:

    def __init__(self, email: str = "yl8889@nyu.edu", api_key: str = None,
                 impact_factor_csv: str = None):
        """
        Configure endpoints, caches, tuning knobs and the pooled HTTP session.

        Args:
            email: Contact address sent along with PubMed E-utilities requests.
                NOTE(review): the default is a hard-coded personal address;
                callers should pass their own.
            api_key: Optional NCBI API key, added to PubMed requests when set.
            impact_factor_csv: Optional path to a CSV of journal CiteScores;
                handed to ``load_impact_factors``.

        Side effects:
            Creates ``./litvar_cache`` if it does not exist and prints the
            messages emitted by ``load_impact_factors``.
        """
        # Who is calling the NCBI services
        self.email, self.api_key = email, api_key

        # Remote services
        self.litvar_sensor = "https://www.ncbi.nlm.nih.gov/research/litvar2-api/sensor"
        self.pubmed_api = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

        # In-memory cache plus the directory backing the on-disk cache
        self.cache = dict()
        self.cache_dir = "./litvar_cache"
        os.makedirs(self.cache_dir, exist_ok=True)

        # Tuning knobs
        self.max_workers = 20            # thread-pool size hint
        self.batch_size = 100            # variants per batch
        self.use_disk_cache = True       # toggle for load_cached_data / save_to_cache
        self.cache_ttl = 7 * 24 * 60 * 60  # seven days, in seconds

        # Filled in by the processing methods
        self.all_variant_data = dict()

        # Journal name -> CiteScore lookup (falls back to defaults without a CSV)
        self.journal_impact_factors = self.load_impact_factors(impact_factor_csv)

        # One shared session so HTTPS connections are pooled and retried
        http_session = requests.Session()
        pooled_adapter = requests.adapters.HTTPAdapter(
            pool_connections=100,
            pool_maxsize=100,
            max_retries=3,
        )
        http_session.mount('https://', pooled_adapter)
        self.session = http_session

    def load_impact_factors(self, csv_path: str = None) -> Dict[str, float]:
        """
        Load journal CiteScores from a CSV file and index them by journal name.

        The CSV must contain a ``Title`` column (journal name) and a
        ``CiteScore`` column. Keys of the returned mapping are lower-cased,
        stripped journal titles. A few well-known journals additionally get
        their common abbreviations registered as aliases.

        Rows are skipped when the title is missing or blank (otherwise they
        would be registered under the bogus names ``'nan'`` / ``''``), or when
        the score is missing, non-numeric or not strictly positive.

        Args:
            csv_path: Path to the CiteScore CSV, or ``None``.

        Returns:
            ``{journal_name_lower: citescore}``. When no usable file is given
            the fixed defaults ``{'default': 2.0, 'high_impact': 10.0}`` are
            returned; when the file cannot be read or yields no rows,
            ``{'default': 2.0}`` is returned.
        """
        if not csv_path or not os.path.exists(csv_path):
            print("No impact factor CSV provided or file not found, using defaults")
            return {
                'default': 2.0,
                'high_impact': 10.0
            }

        print(f"Loading journal CiteScores from {csv_path}")

        # Extra lookup keys registered alongside the full title.
        name_aliases = {
            'ca-a cancer journal for clinicians': (
                'ca cancer j clin',
                'ca-a cancer j clin',
            ),
            'new england journal of medicine': (
                'nejm',
                'n engl j med',
                'the new england journal of medicine',
            ),
        }

        try:
            df = pd.read_csv(csv_path)

            impact_dict = {}
            if 'CiteScore' in df.columns and 'Title' in df.columns:
                for raw_title, raw_score in zip(df['Title'], df['CiteScore']):
                    # A missing title would stringify to 'nan' and later match
                    # unrelated journals by substring, so drop it here.
                    if pd.isna(raw_title):
                        continue
                    journal_key = str(raw_title).strip().lower()
                    if not journal_key:
                        continue

                    try:
                        citescore = float(raw_score)
                    except (ValueError, TypeError):
                        continue
                    if pd.isna(citescore) or citescore <= 0:
                        continue

                    impact_dict[journal_key] = citescore
                    for alias in name_aliases.get(journal_key, ()):
                        impact_dict[alias] = citescore

                print(f"   Loaded CiteScores for {len(impact_dict)} journals")

                sorted_items = sorted(impact_dict.items(), key=lambda x: x[1], reverse=True)
                print(f"   Highest CiteScores loaded:")
                for journal, score in sorted_items[:5]:
                    print(f"      {journal[:50]}: {score:.1f}")
            else:
                print(f"   Required columns (Title, CiteScore) not found in CSV")
                print(f"   Available columns: {list(df.columns)[:10]}")

            return impact_dict if impact_dict else {'default': 2.0}

        except Exception as e:
            # Best effort: an unreadable CSV must not abort the analysis.
            print(f"   Error loading CiteScores: {str(e)}")
            return {'default': 2.0}

    def load_json_classification_changes(self, filepath: str) -> pd.DataFrame:
        """
        Load and preprocess classification changes from JSON file.
        """
        print(f"\n LOADING CLASSIFICATION CHANGES FROM JSON")
        print(f"   File: {filepath}")

        with open(filepath, 'r') as f:
            data = json.load(f)

        variants_list = []

        for variant in data.get('variants', []):
            if 'classification_changes' in variant and variant['classification_changes']:
                for change in variant['classification_changes']:
                    # Extract citation PMIDs that caused the change
                    citation_pmids = []
                    for citation in change.get('citations_new', []):
                        if 'pmid' in citation:
                            citation_pmids.append(citation['pmid'])

                    variant_record = {
                        'VariationID': variant['variation_id'],
                        'VCV': variant['vcv_accession'],
                        'RS# (dbSNP)_new': variant.get('rsid', ''),
                        'Gene': variant.get('gene', 'BRCA1'),
                        'Submitter': change.get('submitter', ''),
                        'date_old': change.get('date_old', ''),
                        'ClinicalSignificance_old': change.get('classification_old', ''),
                        'ClinicalSignificance_old_norm': self.normalize_classification(
                            change.get('classification_old', '')
                        ),
                        'date_new': change.get('date_new', ''),
                        'ClinicalSignificance_new': change.get('classification_new', ''),
                        'ClinicalSignificance_new_norm': self.normalize_classification(
                            change.get('classification_new', '')
                        ),
                        'SCV_old': change.get('scv_old', ''),
                        'SCV_new': change.get('scv_new', ''),
                        'change_type': change.get('change_type', ''),
                        'citation_pmids': citation_pmids,
                        'citation_count': len(citation_pmids)
                    }
                    variants_list.append(variant_record)

        df = pd.DataFrame(variants_list)

        for col in ['date_old', 'date_new']:
            df[col] = pd.to_datetime(df[col], errors='coerce')

        # Remove rows with missing critical data
        df = df.dropna(subset=['VariationID', 'ClinicalSignificance_new', 'date_new'])

        print(f"   Loaded {len(df)} classification changes")
        print(f"   Found {df['citation_count'].sum()} total citations")

        return df

    def normalize_classification(self, classification: str) -> str:
        """
        Map a free-text clinical-significance label onto a fixed vocabulary.

        Matching is case-insensitive and substring based. Labels mentioning a
        conflict are tested first, because terms such as "Conflicting
        interpretations of pathogenicity" also contain the word "pathogenic"
        and would otherwise be reported as ``'Pathogenic'``.

        Args:
            classification: Raw label; may be empty, ``None`` or a non-string
                placeholder such as NaN.

        Returns:
            One of ``'Likely pathogenic'``, ``'Pathogenic'``,
            ``'Likely benign'``, ``'Benign'``, ``'Uncertain significance'``,
            ``'Conflicting interpretations'`` or ``'Unknown'``.
        """
        # Anything that is not text (None, NaN, numbers) carries no label.
        if not isinstance(classification, str):
            return 'Unknown'

        label = classification.lower().strip()
        if not label:
            return 'Unknown'

        if 'conflict' in label:
            return 'Conflicting interpretations'

        is_likely = 'likely' in label
        if 'pathogenic' in label:
            return 'Likely pathogenic' if is_likely else 'Pathogenic'
        if 'benign' in label:
            return 'Likely benign' if is_likely else 'Benign'
        if 'uncertain' in label or 'unknown' in label:
            return 'Uncertain significance'
        if 'multiple' in label:
            return 'Conflicting interpretations'
        return 'Unknown'

    def fetch_publication_details(self, pmid: str) -> Dict:
        """
        Fetch publication details from PubMed including journal and publication date.
        """
        cache_key = f"pubmed_{pmid}"
        if cache_key in self.cache:
            return self.cache[cache_key]

        result = {
            'pmid': pmid,
            'title': '',
            'journal': '',
            'publication_date': None,
            'impact_factor': 1.0,  # Default impact factor
            'authors': [],
            'abstract': ''
        }

        try:
            # Fetch from PubMed
            url = f"{self.pubmed_api}/esummary.fcgi"
            params = {
                'db': 'pubmed',
                'id': pmid,
                'retmode': 'json',
                'email': self.email
            }

            if self.api_key:
                params['api_key'] = self.api_key

            response = requests.get(url, params=params, timeout=10)

            if response.status_code == 200:
                data = response.json()
                if 'result' in data and pmid in data['result']:
                    pub_data = data['result'][pmid]

                    result['title'] = pub_data.get('title', '')
                    result['journal'] = pub_data.get('fulljournalname', pub_data.get('source', ''))

                    # Parse publication date
                    pub_date = pub_data.get('pubdate', pub_data.get('epubdate', ''))
                    if pub_date:
                        try:
                            result['publication_date'] = pd.to_datetime(pub_date, errors='coerce')
                        except:
                            pass

                    # Get authors
                    if 'authors' in pub_data:
                        result['authors'] = [author.get('name', '') for author in pub_data['authors']]

                    # Assign impact factor based on journal
                    impact_assigned = False
                    journal_lower = result['journal'].lower()

                    # Try exact match first
                    if journal_lower in self.journal_impact_factors:
                        result['impact_factor'] = self.journal_impact_factors[journal_lower]
                        impact_assigned = True
                    else:
                        # Try partial matches
                        for journal_key, impact in self.journal_impact_factors.items():
                            if journal_key in journal_lower or journal_lower in journal_key:
                                result['impact_factor'] = impact
                                impact_assigned = True
                                break

                        # Try matching key words
                        if not impact_assigned:
                            journal_words = set(journal_lower.split())
                            for journal_key, impact in self.journal_impact_factors.items():
                                key_words = set(journal_key.split())
                                if len(journal_words & key_words) >= 2:
                                    result['impact_factor'] = impact
                                    impact_assigned = True
                                    break

                    # If no match found, assign based on journal type
                    if not impact_assigned:
                        if 'nature' in journal_lower:
                            result['impact_factor'] = 15.0
                        elif 'science' in journal_lower:
                            result['impact_factor'] = 12.0
                        elif 'cell' in journal_lower:
                            result['impact_factor'] = 10.0
                        elif 'genetics' in journal_lower or 'genomics' in journal_lower:
                            result['impact_factor'] = 5.0
                        elif 'cancer' in journal_lower or 'oncol' in journal_lower:
                            result['impact_factor'] = 6.0
                        elif 'medicine' in journal_lower or 'medical' in journal_lower:
                            result['impact_factor'] = 4.0
                        elif 'clinical' in journal_lower:
                            result['impact_factor'] = 3.5
                        elif 'molecular' in journal_lower:
                            result['impact_factor'] = 4.5
                        elif 'journal' in journal_lower:
                            result['impact_factor'] = 2.5
                        else:
                            result['impact_factor'] = 2.0  # Default for unknown journals

        except Exception as e:
            print(f"         Error fetching PubMed data for {pmid}: {str(e)[:100]}")

        self.cache[cache_key] = result
        time.sleep(0.2)  # Rate limiting

        return result

    def fetch_litvar_publications_with_details(self, rsid: str, citation_pmids: List[str] = None) -> Dict:
        """
        Fetch publications from LitVar2 API for a given rsID.
        """
        print(f"      Fetching LitVar2 publications for rs{rsid}")

        if not rsid.startswith('rs'):
            rsid = f'rs{rsid}'

        result = {
            'rsid': rsid,
            'pmids': [],
            'publications': [],
            'citation_publications': [],  # Publications that caused classification changes
            'other_publications': []  # Other related publications
        }

        try:
            api_url = f"https://www.ncbi.nlm.nih.gov/research/litvar2-api/variant/get/litvar@{rsid}%23%23/publications"

            print(f"         Calling LitVar2 API: {api_url}")

            # Make GET request with proper headers
            headers = {
                'Accept': 'application/json'
            }

            response = requests.get(
                api_url,
                headers=headers,
                timeout=30
            )

            print(f"         Response status: {response.status_code}")

            if response.status_code == 200:
                data = response.json()
                if isinstance(data, dict) and 'pmids' in data:
                    pmids = data['pmids']
                    result['pmids'] = [str(p) for p in pmids]
                    print(f"         Retrieved {len(result['pmids'])} PMIDs from LitVar2")
                    print(f"         PMIDs: {result['pmids'][:10]}...")  # Show first 10
                elif isinstance(data, list):
                    # Sometimes it might return a list directly
                    result['pmids'] = [str(p) for p in data if p]
                    print(f"         Retrieved {len(result['pmids'])} PMIDs from LitVar2 (list format)")
                else:
                    print(f"         Unexpected response format: {type(data)}")
                    print(f"         Response preview: {str(data)[:200]}")
            else:
                print(f"         API returned status {response.status_code}")
                print(f"         Response: {response.text[:500]}")

                # Try alternative LitVar2 endpoint without the @ symbol
                alt_url = f"https://www.ncbi.nlm.nih.gov/research/litvar2-api/variant/get/litvar{rsid}%23%23/publications"
                print(f"         Trying alternative format: {alt_url}")

                alt_response = requests.get(alt_url, headers=headers, timeout=30)
                if alt_response.status_code == 200:
                    alt_data = alt_response.json()
                    if isinstance(alt_data, dict) and 'pmids' in alt_data:
                        result['pmids'] = [str(p) for p in alt_data['pmids']]
                        print(f"         Retrieved {len(result['pmids'])} PMIDs with alternative format")

            # Also check the sensor endpoint for comparison
            sensor_url = f"https://www.ncbi.nlm.nih.gov/research/litvar2-api/sensor/{rsid}"
            try:
                sensor_response = requests.get(sensor_url, timeout=10)
                if sensor_response.status_code == 200:
                    sensor_data = sensor_response.json()
                    expected_count = sensor_data.get('pmids_count', 0)
                    if expected_count > 0:
                        actual_count = len(result['pmids'])
                        if actual_count > 0:
                            print(f"         Sensor shows {expected_count} publications, retrieved {actual_count} PMIDs")
                            if actual_count < expected_count:
                                print(f"          {expected_count - actual_count} PMIDs may be missing")
                        else:
                            print(f"         Sensor shows {expected_count} publications but couldn't retrieve PMIDs")
            except Exception as e:
                print(f"         Sensor check failed: {str(e)[:50]}")

            fetched_count = 0
            citations_matched = 0

            if result['pmids']:
                print(f"         Processing {len(result['pmids'][:100])} PMIDs for publication details...")

                for pmid in result['pmids'][:100]:  # Limit to first 100 for performance
                    # Clean the PMID
                    clean_pmid = str(pmid).strip()

                    # Fetch publication details from PubMed
                    pub_details = self.fetch_publication_details(clean_pmid)

                    if pub_details and (pub_details.get('journal') or pub_details.get('title')):
                        fetched_count += 1

                        # Check if this publication caused a classification change
                        # by comparing with citation_pmids from the JSON input
                        if citation_pmids:
                            # Normalize PMIDs for comparison
                            clean_citations = [str(p).strip() for p in citation_pmids]

                            if clean_pmid in clean_citations:
                                pub_details['caused_change'] = True
                                result['citation_publications'].append(pub_details)
                                citations_matched += 1
                                print(f"         PMID {clean_pmid} caused classification change")
                            else:
                                pub_details['caused_change'] = False
                                result['other_publications'].append(pub_details)
                        else:
                            pub_details['caused_change'] = False
                            result['other_publications'].append(pub_details)

                        result['publications'].append(pub_details)

                print(f"         Fetched details for {fetched_count} publications")
                if citation_pmids:
                    print(f"         {citations_matched} caused classification changes (matched from JSON)")
                    if citations_matched == 0 and len(citation_pmids) > 0:
                        print(f"         No matches found. JSON PMIDs: {citation_pmids[:3]}")
                        print(f"         LitVar2 PMIDs (first 3): {result['pmids'][:3] if result['pmids'] else 'None'}")
            else:
                print(f"         No PMIDs retrieved for {rsid}")

                # If no PMIDs from LitVar2, still process the citation PMIDs from JSON
                if citation_pmids:
                    print(f"         Processing citation PMIDs from JSON directly...")
                    for pmid in citation_pmids:
                        clean_pmid = str(pmid).strip()
                        pub_details = self.fetch_publication_details(clean_pmid)
                        if pub_details and (pub_details.get('journal') or pub_details.get('title')):
                            pub_details['caused_change'] = True
                            pub_details['source'] = 'ClinVar_citation_only'
                            result['citation_publications'].append(pub_details)
                            result['publications'].append(pub_details)
                            fetched_count += 1
                            citations_matched += 1
                            print(f"         Added citation PMID {clean_pmid} from PubMed (not in LitVar2)")

                    print(f"         Processed {fetched_count} citation PMIDs from JSON")

        except Exception as e:
            print(f"         Error fetching LitVar2 data: {str(e)}")
            import traceback
            traceback.print_exc()

        return result

    def load_cached_data(self, cache_key: str):
        """Load data from disk cache if available and not expired."""
        if not self.use_disk_cache:
            return None

        cache_file = os.path.join(self.cache_dir, f"{cache_key}.json")
        if os.path.exists(cache_file):
            # Check if cache is expired
            file_age = time.time() - os.path.getmtime(cache_file)
            if file_age < self.cache_ttl:
                try:
                    with open(cache_file, 'r') as f:
                        return json.load(f)
                except:
                    pass
        return None

    def save_to_cache(self, cache_key: str, data):
        """Save data to disk cache."""
        if not self.use_disk_cache:
            return

        cache_file = os.path.join(self.cache_dir, f"{cache_key}.json")
        try:
            with open(cache_file, 'w') as f:
                json.dump(data, f)
        except:
            pass

    def fetch_litvar_batch(self, rsids: List[str]) -> Dict[str, List[str]]:
        """
        Fetch PMIDs for multiple rsIDs in batches.
        """
        results = {}

        # Check cache first
        uncached_rsids = []
        for rsid in rsids:
            cache_key = f"litvar2_{rsid}"
            cached = self.load_cached_data(cache_key)
            if cached:
                results[rsid] = cached
            else:
                uncached_rsids.append(rsid)

        if not uncached_rsids:
            print(f"         All {len(rsids)} variants found in cache")
            return results

        print(f"         Found {len(results)} cached, fetching {len(uncached_rsids)} from LitVar2...")

        chunk_size = 20
        total_chunks = (len(uncached_rsids) + chunk_size - 1) // chunk_size

        for chunk_idx, i in enumerate(range(0, len(uncached_rsids), chunk_size)):
            chunk = uncached_rsids[i:i+chunk_size]
            print(f"         Processing chunk {chunk_idx + 1}/{total_chunks} ({len(chunk)} variants)...")

            # Use ThreadPoolExecutor with limited workers
            with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                futures = {
                    executor.submit(self.fetch_single_litvar, rsid): rsid
                    for rsid in chunk
                }

                for future in concurrent.futures.as_completed(futures):
                    rsid = futures[future]
                    try:
                        pmids = future.result()
                        results[rsid] = pmids
                        # Cache the result
                        self.save_to_cache(f"litvar2_{rsid}", pmids)
                    except Exception as e:
                        print(f"            Failed for {rsid}: {str(e)[:50]}")
                        results[rsid] = []

            # avoid rate limiting
            if chunk_idx < total_chunks - 1:
                time.sleep(0.5)

        print(f"         Fetched data for {len(results)} variants total")
        return results

    def fetch_single_litvar(self, rsid: str) -> List[str]:
        """Fetch PMIDs for a single rsID from LitVar2 with retry logic."""
        if not rsid.startswith('rs'):
            rsid = f'rs{rsid}'

        api_url = f"https://www.ncbi.nlm.nih.gov/research/litvar2-api/variant/get/litvar@{rsid}%23%23/publications"

        max_retries = 3
        for attempt in range(max_retries):
            try:
                response = self.session.get(api_url, timeout=10)
                if response.status_code == 200:
                    data = response.json()
                    if isinstance(data, dict) and 'pmids' in data:
                        return [str(p) for p in data['pmids']]
                elif response.status_code == 429:  # Rate limited
                    print(f"            Rate limited, waiting...")
                    time.sleep(2 ** attempt)
                    continue
                return []
            except requests.exceptions.ConnectionError as e:
                if attempt < max_retries - 1:
                    print(f"            Connection error for {rsid}, retrying...")
                    time.sleep(1)
                    continue
                return []
            except Exception:
                return []
        return []

    def fetch_pubmed_batch(self, pmids: List[str]) -> Dict[str, Dict]:
        """
        Fetch details for multiple PMIDs in batch from PubMed.
        Much more efficient than individual requests.
        """
        results = {}

        # Check cache first
        uncached_pmids = []
        for pmid in pmids:
            cache_key = f"pubmed_{pmid}"
            if cache_key in self.cache:
                results[pmid] = self.cache[cache_key]
            else:
                cached = self.load_cached_data(cache_key)
                if cached:
                    results[pmid] = cached
                    self.cache[cache_key] = cached
                else:
                    uncached_pmids.append(pmid)

        if not uncached_pmids:
            return results

        # Batch fetch from PubMed
        print(f"         Fetching batch of {len(uncached_pmids)} publications from PubMed...")

        # PubMed allows up to 200 IDs per request
        chunk_size = 200
        for i in range(0, len(uncached_pmids), chunk_size):
            chunk = uncached_pmids[i:i+chunk_size]

            url = f"{self.pubmed_api}/esummary.fcgi"
            params = {
                'db': 'pubmed',
                'id': ','.join(chunk),
                'retmode': 'json',
                'email': self.email
            }

            if self.api_key:
                params['api_key'] = self.api_key

            try:
                response = self.session.get(url, params=params, timeout=30)
                if response.status_code == 200:
                    data = response.json()
                    if 'result' in data:
                        for pmid in chunk:
                            if pmid in data['result']:
                                pub_data = data['result'][pmid]
                                result = self.parse_pubmed_data(pub_data, pmid)
                                results[pmid] = result
                                # Cache the result
                                self.cache[f"pubmed_{pmid}"] = result
                                self.save_to_cache(f"pubmed_{pmid}", result)
            except Exception as e:
                print(f"         Batch PubMed fetch failed: {str(e)[:50]}")

        return results

    def parse_pubmed_data(self, pub_data: Dict, pmid: str) -> Dict:
        """Parse PubMed data into our format with STRICT journal matching."""
        result = {
            'pmid': pmid,
            'title': pub_data.get('title', ''),
            'journal': pub_data.get('fulljournalname', pub_data.get('source', '')),
            'publication_date': None,
            'impact_factor': 2.0,  # Default impact factor
            'authors': [],
            'abstract': ''
        }

        pub_date = pub_data.get('pubdate', pub_data.get('epubdate', ''))
        if pub_date:
            try:
                result['publication_date'] = pd.to_datetime(pub_date, errors='coerce')
            except:
                pass

        if 'authors' in pub_data:
            result['authors'] = [author.get('name', '') for author in pub_data['authors']]

        journal_name = result['journal']
        journal_lower = journal_name.lower().strip()

        # First try exact match
        if journal_lower in self.journal_impact_factors:
            result['impact_factor'] = self.journal_impact_factors[journal_lower]
            # debug
            if self.journal_impact_factors[journal_lower] > 600:
                print(f"         High CiteScore {self.journal_impact_factors[journal_lower]:.1f} for: {journal_name}")
        else:
            # For non-exact matches
            # Only match if it's a clear abbreviation or variant
            matched = False

            # Check for known abbreviations
            if 'ca cancer j clin' in journal_lower or 'ca-a cancer j clin' in journal_lower:
                if 'ca cancer j clin' in self.journal_impact_factors:
                    result['impact_factor'] = self.journal_impact_factors['ca cancer j clin']
                    matched = True
            elif 'n engl j med' in journal_lower or 'nejm' in journal_lower:
                if 'new england journal of medicine' in self.journal_impact_factors:
                    result['impact_factor'] = self.journal_impact_factors['new england journal of medicine']
                    matched = True

            # If still no match, use conservative defaults based on journal type
            if not matched:
                # DO NOT do partial string matching that could incorrectly assign high scores
                if 'nature' == journal_lower.split()[0] if journal_lower.split() else '':
                    result['impact_factor'] = 15.0
                elif 'science' == journal_lower:
                    result['impact_factor'] = 12.0
                elif 'cell' == journal_lower:
                    result['impact_factor'] = 10.0
                elif 'cancer' in journal_lower and 'ca' not in journal_lower[:5]:  # Avoid CA journal
                    result['impact_factor'] = 5.0
                elif 'genetics' in journal_lower or 'genomics' in journal_lower:
                    result['impact_factor'] = 4.0
                elif 'medicine' in journal_lower or 'medical' in journal_lower:
                    result['impact_factor'] = 3.5
                elif 'clinical' in journal_lower:
                    result['impact_factor'] = 3.0
                else:
                    result['impact_factor'] = 2.0  # Conservative default

        return result

    def process_all_variants_optimized(self, changes_df: pd.DataFrame) -> Dict:
        """
        Fetch and assemble publication data for every variant using batched lookups.

        Rows of ``changes_df`` are grouped by ``VariationID`` and variants without
        a usable rsID are dropped. PMIDs are fetched per batch of rsIDs through
        ``self.fetch_litvar_batch`` and publication details per batch of PMIDs
        through ``self.fetch_pubmed_batch``. Citation PMIDs that LitVar did not
        return are still attached to the variant when PubMed details exist for
        them (marked with ``source='ClinVar_citation_only'``).

        Citation PMIDs are de-duplicated per variant (first occurrence kept), so a
        paper cited by several change rows of the same variant is counted once.

        Args:
            changes_df: One row per classification change. Must provide the columns
                ``VariationID``, ``RS# (dbSNP)_new``, ``date_new``, ``date_old``,
                ``ClinicalSignificance_old_norm``, ``ClinicalSignificance_new_norm``
                and ``citation_pmids`` (an iterable of PMIDs per row).

        Returns:
            Dict keyed by variant ID (``str``). Each value holds the variant's
            dates, old/new classes, citation PMIDs, the publication lists
            (``publications``, ``citation_publications``, ``other_publications``)
            and their counts. The same dict is stored on ``self.all_variant_data``.
        """
        print(f"\n FETCHING PUBLICATIONS FOR ALL VARIANTS (OPTIMIZED)")

        all_variant_data = {}

        # Group by VariationID, flattening the per-row citation lists into one list
        grouped = changes_df.groupby('VariationID').agg({
            'RS# (dbSNP)_new': 'first',
            'date_new': 'max',
            'date_old': 'min',
            'ClinicalSignificance_old_norm': 'first',
            'ClinicalSignificance_new_norm': 'last',
            'citation_pmids': lambda x: [pmid for sublist in x for pmid in sublist]
        }).reset_index()

        print(f"   Found {len(grouped)} unique variants to process")

        # Filter valid rsIDs (missing values and the literal string 'None' are rejected)
        valid_variants = grouped[
            (grouped['RS# (dbSNP)_new'].notna()) &
            (grouped['RS# (dbSNP)_new'] != 'None')
        ].copy()

        print(f"   {len(valid_variants)} variants with valid rsIDs")

        # Batch fetch all LitVar data
        print(f"\n   BATCH FETCHING FROM LITVAR2...")
        all_rsids = valid_variants['RS# (dbSNP)_new'].tolist()

        batch_size = 100  # Process 100 variants at a time
        # NOTE(review): fetch_litvar_batch is assumed to return {rsid: [pmid, ...]} - confirm
        litvar_results = {}

        for batch_start in range(0, len(all_rsids), batch_size):
            batch_end = min(batch_start + batch_size, len(all_rsids))
            batch_rsids = all_rsids[batch_start:batch_end]

            print(f"\n   Processing variants {batch_start + 1}-{batch_end} of {len(all_rsids)}...")
            batch_results = self.fetch_litvar_batch(batch_rsids)
            litvar_results.update(batch_results)

            # Show progress
            if batch_end < len(all_rsids):
                print(f"   Progress: {batch_end}/{len(all_rsids)} variants ({100*batch_end/len(all_rsids):.1f}%)")

        print(f"\n   Completed LitVar fetching for {len(litvar_results)} variants")

        # Collect all unique PMIDs
        print(f"\n   Collecting unique PMIDs...")
        all_pmids = set()
        for pmids in litvar_results.values():
            all_pmids.update(pmids)

        # Add citation PMIDs
        for citation_pmids in valid_variants['citation_pmids']:
            if citation_pmids:
                all_pmids.update([str(p) for p in citation_pmids])

        print(f"   Found {len(all_pmids)} unique PMIDs total")

        # Batch fetch all PubMed data
        print(f"\n   BATCH FETCHING PUBLICATIONS FROM PUBMED...")
        # NOTE(review): fetch_pubmed_batch is assumed to return {pmid: details_dict} - confirm
        pubmed_results = {}

        # Process PubMed in batches
        pmid_list = list(all_pmids)
        pubmed_batch_size = 500  # PubMed can handle larger batches

        for batch_start in range(0, len(pmid_list), pubmed_batch_size):
            batch_end = min(batch_start + pubmed_batch_size, len(pmid_list))
            batch_pmids = pmid_list[batch_start:batch_end]

            print(f"   Fetching PubMed batch {batch_start + 1}-{batch_end} of {len(pmid_list)}...")
            batch_results = self.fetch_pubmed_batch(batch_pmids)
            pubmed_results.update(batch_results)

        print(f"   Fetched details for {len(pubmed_results)} publications")

        # Process each variant with the pre-fetched data
        print(f"\n   PROCESSING VARIANT DATA...")

        # Count positions explicitly: the frame keeps the index labels of `grouped`,
        # which are no longer consecutive once invalid rsIDs have been filtered out.
        for position, (_, row) in enumerate(valid_variants.iterrows(), 1):
            variant_id = str(row['VariationID'])
            rsid = str(row['RS# (dbSNP)_new'])

            # The flattened list repeats a PMID once per change row that cites it;
            # keep the first occurrence only so each paper is attached once.
            citation_pmids = []
            citation_keys = set()
            for pmid in (row['citation_pmids'] if row['citation_pmids'] else []):
                key = str(pmid)
                if key not in citation_keys:
                    citation_keys.add(key)
                    citation_pmids.append(pmid)

            litvar_pmids = litvar_results.get(rsid, [])
            litvar_keys = {str(p) for p in litvar_pmids}

            publications = []
            citation_publications = []
            other_publications = []

            # Process LitVar publications
            for pmid in litvar_pmids:
                if pmid in pubmed_results:
                    pub_details = pubmed_results[pmid].copy()

                    # Check if it caused a classification change
                    if str(pmid) in citation_keys:
                        pub_details['caused_change'] = True
                        citation_publications.append(pub_details)
                    else:
                        pub_details['caused_change'] = False
                        other_publications.append(pub_details)

                    publications.append(pub_details)

            # Add citation PMIDs not in LitVar
            for pmid in citation_pmids:
                pmid_str = str(pmid)
                if pmid_str not in litvar_keys and pmid_str in pubmed_results:
                    pub_details = pubmed_results[pmid_str].copy()
                    pub_details['caused_change'] = True
                    pub_details['source'] = 'ClinVar_citation_only'
                    citation_publications.append(pub_details)
                    publications.append(pub_details)

            all_variant_data[variant_id] = {
                'variant_id': variant_id,
                'rsid': rsid,
                'change_date': row['date_new'],
                'date_old': row['date_old'],
                'old_class': row['ClinicalSignificance_old_norm'],
                'new_class': row['ClinicalSignificance_new_norm'],
                'citation_pmids_from_json': citation_pmids,
                'publications': publications,
                'citation_publications': citation_publications,
                'other_publications': other_publications,
                'total_publications': len(publications),
                'total_litvar_pmids': len(litvar_pmids),
                'citations_matched': len(citation_publications)
            }

            if position % 500 == 0:
                print(f"      Processed {position}/{len(valid_variants)} variants...")

        stats = {
            'total_variants': len(valid_variants),
            'variants_with_pmids': sum(1 for v in all_variant_data.values() if v['total_litvar_pmids'] > 0),
            'perfect_matches': sum(1 for v in all_variant_data.values()
                                 if v['citation_pmids_from_json'] and
                                 v['citations_matched'] == len(v['citation_pmids_from_json'])),
            'partial_matches': sum(1 for v in all_variant_data.values()
                                 if v['citation_pmids_from_json'] and
                                 0 < v['citations_matched'] < len(v['citation_pmids_from_json'])),
            'no_matches': sum(1 for v in all_variant_data.values()
                            if v['citation_pmids_from_json'] and v['citations_matched'] == 0)
        }

        print(f"\n   PROCESSING SUMMARY:")
        print(f"      Total variants: {stats['total_variants']}")
        print(f"      Variants with LitVar PMIDs: {stats['variants_with_pmids']}")
        print(f"      Perfect citation matches: {stats['perfect_matches']}")
        print(f"      Partial citation matches: {stats['partial_matches']}")
        print(f"      No citation matches: {stats['no_matches']}")

        print(f"\n   Completed processing {len(all_variant_data)} variants")
        self.all_variant_data = all_variant_data
        return all_variant_data

    def calculate_variant_representativeness(self, variant_data: Dict) -> float:
        """
        Score how suitable a variant is for the showcase visualization.

        The score is the sum of six weighted components, listed here in the
        order they are added:
          * 25 per publication that caused a classification change,
          * 2 per publication, capped at 50 publications,
          * 15 times the classification change significance,
          * 3 per distinct (non-empty) journal,
          * 5 per publication with impact factor > 10, else 2 if > 5,
          * a flat 20 bonus when the variant has both LitVar PMIDs and at
            least one matched citation.
        """
        pubs = variant_data['publications']

        def impact_bonus(pub):
            # Publications without an impact factor count as 0 and earn nothing
            factor = pub.get('impact_factor', 0)
            if factor > 10:
                return 5
            if factor > 5:
                return 2
            return 0

        citation_points = len(variant_data['citation_publications']) * 25
        volume_points = min(variant_data['total_publications'], 50) * 2
        change_points = self.calculate_change_significance(
            variant_data['old_class'],
            variant_data['new_class']
        ) * 15
        diversity_points = len({p['journal'] for p in pubs if p.get('journal')}) * 3
        impact_points = sum(impact_bonus(p) for p in pubs)

        has_litvar = variant_data.get('total_litvar_pmids', 0) > 0
        has_citation_match = variant_data.get('citations_matched', 0) > 0
        overlap_points = 20 if (has_litvar and has_citation_match) else 0

        return (citation_points + volume_points + change_points
                + diversity_points + impact_points + overlap_points)

    def calculate_change_significance(self, old_class: str, new_class: str) -> float:
        """
        Return the distance between two classifications on a pathogenicity scale.

        Each known label maps to a position between 0 (Benign) and 4
        (Pathogenic); labels that are not on the scale are treated as 2, the
        same position as 'Uncertain significance'. The result is the absolute
        difference of the two positions, so the direction of the change is
        not reflected.
        """
        scale = dict((
            ('Benign', 0),
            ('Likely benign', 1),
            ('Uncertain significance', 2),
            ('Conflicting interpretations', 2.5),
            ('Likely pathogenic', 3),
            ('Pathogenic', 4),
            ('Unknown', 2),
        ))
        fallback_position = 2

        before, after = (
            scale.get(label, fallback_position) for label in (old_class, new_class)
        )
        return abs(after - before)

    def select_top_variants(self, all_variant_data: Dict, top_n: int = 10) -> Dict:
        """
        Pick the ``top_n`` variants with the highest representativeness score.

        Every variant is scored with ``calculate_variant_representativeness``;
        variants are then ranked from highest to lowest score (ties keep their
        original dict order) and a short summary is printed for each pick.

        Returns:
            Dict of variant ID -> variant data for the selected variants, in
            ranking order.
        """
        print(f"\n SELECTING TOP {top_n} REPRESENTATIVE VARIANTS")

        scored = [
            (self.calculate_variant_representativeness(info), vid, info)
            for vid, info in all_variant_data.items()
        ]
        # sorted() is stable, including with reverse=True, so equal scores
        # stay in insertion order
        ranking = sorted(scored, key=lambda entry: entry[0], reverse=True)

        chosen = {}
        for score, vid, info in ranking[:top_n]:
            chosen[vid] = info
            print(f"   Selected: Variant {vid} (Score: {score:.1f})")
            print(f"      Change: {info['old_class']} → {info['new_class']}")
            print(f"      Publications: {info['total_publications']}")
            print(f"      Citations causing change: {len(info['citation_publications'])}")

        return chosen

    def create_timeline_scatter_visualization(self, top_variants: Dict) -> go.Figure:
        """
        Create timeline scatter plot with journal CiteScores.

        One subplot row is drawn per variant:
          * Y-axis: journal CiteScore, log scaled over a fixed 0.5-1000 range
          * X-axis: publication date
          * Red dots for publications that caused classification changes
          * Blue dots for other publications
          * A dashed red line at the change date and a dotted orange line at
            the previous classification date, each with a small label on top

        Publications without a publication date are skipped. Scores that are
        zero or negative cannot be shown on a log axis and are plotted as 1.0;
        the publication dicts passed in are NOT modified.

        Args:
            top_variants: Mapping of variant ID -> variant data. Each value
                needs ``rsid``, ``old_class``, ``new_class`` and
                ``publications``; ``change_date`` and ``date_old`` are optional.

        Returns:
            The assembled figure. An empty figure is returned when
            ``top_variants`` is empty.
        """
        print(f"\n CREATING TIMELINE SCATTER VISUALIZATION")

        n_variants = len(top_variants)
        if n_variants == 0:
            # A zero-row subplot grid cannot be built (and 1/n_variants below
            # would divide by zero), so hand back an empty but valid figure.
            print("   No variants to plot - returning an empty figure")
            return go.Figure()

        # Fixed y-axis bounds shared by the axis range, the marker lines and labels
        y_min, y_max = 0.5, 1000
        default_score = 2.0
        hover_template = '%{text}<br>Date: %{x|%Y-%m}<extra></extra>'

        fig = make_subplots(
            rows=n_variants,
            cols=1,
            subplot_titles=[
                f"Variant {vid} (rs{data['rsid']}): {data['old_class']} → {data['new_class']}"
                for vid, data in top_variants.items()
            ],
            vertical_spacing=0.02,
            row_heights=[1/n_variants] * n_variants
        )

        # Trace names that already own a legend entry. Tracking this (rather
        # than only using the first row) keeps both entries in the legend even
        # if the first variant lacks one kind of publication.
        legend_shown = set()

        def plot_score(pub):
            """Score as plotted: non-positive values are lifted to 1.0 for the log axis."""
            score = pub.get('impact_factor', default_score)
            return 1.0 if score <= 0 else float(score)

        def add_publication_trace(pubs, row, name, marker, headline=""):
            """Add one marker trace for `pubs` to subplot `row` (no-op when empty)."""
            if not pubs:
                return

            hover_labels = [
                (
                    f"{headline}"
                    f"PMID: {p['pmid']}<br>"
                    f"Journal: {p.get('journal', 'Unknown')[:50]}<br>"
                    f"CiteScore: {plot_score(p):.1f}<br>"
                    f"Title: {p.get('title', '')[:100]}..."
                )
                for p in pubs
            ]

            fig.add_trace(
                go.Scatter(
                    x=[p['publication_date'] for p in pubs],
                    y=[plot_score(p) for p in pubs],
                    mode='markers',
                    name=name,
                    marker=marker,
                    text=hover_labels,
                    hovertemplate=hover_template,
                    showlegend=name not in legend_shown
                ),
                row=row, col=1
            )
            legend_shown.add(name)

        def add_date_marker(raw_date, row, color, dash, opacity, label):
            """Draw a full-height vertical line with a label at `raw_date`, if it is a usable date."""
            if not (raw_date and pd.notna(raw_date)):
                return
            when = pd.to_datetime(raw_date)
            if pd.isna(when):
                return

            axis_ref = f"y{row}" if row > 1 else "y"
            fig.add_shape(
                type="line",
                x0=when, x1=when,
                y0=y_min, y1=y_max,
                yref=axis_ref,
                line=dict(color=color, width=2, dash=dash),
                opacity=opacity,
                row=row, col=1
            )
            fig.add_annotation(
                x=when,
                # Annotation positions on a log axis are given as log10 of the
                # data value; passing 1000 directly would place the label at
                # 10**1000, far outside the visible range.
                y=np.log10(y_max),
                yref=axis_ref,
                text=label,
                showarrow=False,
                font=dict(size=9, color=color),
                yshift=10,
                row=row, col=1
            )

        for idx, (variant_id, data) in enumerate(top_variants.items(), 1):
            dated_pubs = [p for p in data['publications'] if p['publication_date'] is not None]
            citation_pubs = [p for p in dated_pubs if p.get('caused_change', False)]
            other_pubs = [p for p in dated_pubs if not p.get('caused_change', False)]

            # Blue dots first so the red "caused change" dots are drawn on top
            add_publication_trace(
                other_pubs, idx, 'Related Publications',
                marker=dict(
                    size=8,
                    color='blue',
                    opacity=0.6,
                    line=dict(width=1, color='darkblue')
                )
            )
            add_publication_trace(
                citation_pubs, idx, 'Publications Causing Change',
                marker=dict(
                    size=10,
                    color='red',
                    opacity=0.9,
                    line=dict(width=2, color='darkred')
                ),
                headline="CAUSED CLASSIFICATION CHANGE <br>"
            )

            # Vertical markers for the change date and the previous classification date
            add_date_marker(data.get('change_date'), idx, "red", "dash", 0.4, "Changed")
            add_date_marker(data.get('date_old'), idx, "orange", "dot", 0.3, "Previous")

            fig.update_yaxes(
                title_text="CiteScore (log scale)",
                title_font_size=10,
                type="log",
                row=idx, col=1,
                range=[np.log10(y_min), np.log10(y_max)],
                tickvals=[0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000],
                ticktext=['0.5', '1', '2', '5', '10', '20', '50', '100', '200', '500', '1000'],  # Labels
                tickfont=dict(size=9),
                showgrid=True,
                gridwidth=0.5,
                gridcolor='lightgray',
                minor=dict(showgrid=True, gridcolor='#f0f0f0', gridwidth=0.3)  # Minor gridlines
            )

            fig.update_xaxes(
                tickfont=dict(size=9),
                row=idx, col=1,
                showgrid=True,
                gridwidth=0.5,
                gridcolor='lightgray'
            )

        fig.update_layout(
            height=500 * n_variants,
            title_text="Publication Timeline with Journal CiteScores (Log Scale)<br>" +
                    "<sub>Red: citations that caused classification change | Blue: other related publications | " +
                    "Dashed lines: classification change dates</sub>",
            title_font_size=14,
            hovermode='closest',
            showlegend=True,
            legend=dict(
                orientation="h",
                yanchor="top",
                y=1.003,
                xanchor="center",
                x=0.5,
                font=dict(size=10),
                bgcolor="rgba(255,255,255,0.8)"
            ),
            margin=dict(t=60, b=30, l=60, r=30),
            plot_bgcolor='white',
            paper_bgcolor='white'
        )

        # Only the bottom subplot carries the x-axis title
        fig.update_xaxes(
            title_text="Publication Date",
            title_font_size=11,
            row=n_variants, col=1
        )

        return fig

    def generate_summary_report(self, all_variant_data: Dict, top_variants: Dict, output_dir: str):
        """
        Write a plain-text summary report and echo it to stdout.

        The report contains overall counts, one entry per selected variant
        (with up to three of the publications that caused its change) and the
        mean / median / max impact factor over all publications.

        Args:
            all_variant_data: Mapping of variant ID -> variant data for every
                processed variant.
            top_variants: Mapping of variant ID -> variant data for the
                variants highlighted in the report.
            output_dir: Directory that receives ``analysis_report.txt``; it is
                created if it does not exist yet.
        """
        report = []
        report.append("=" * 80)
        report.append("LITVAR ANALYSIS REPORT - JOURNAL IMPACT FACTOR TIMELINE")
        report.append("=" * 80)
        report.append(f"Analysis Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        report.append("")

        total_variants = len(all_variant_data)
        total_pubs = sum(d['total_publications'] for d in all_variant_data.values())
        total_citations = sum(len(d['citation_publications']) for d in all_variant_data.values())

        report.append("OVERALL STATISTICS:")
        report.append(f"  Total variants analyzed: {total_variants}")
        report.append(f"  Total publications found: {total_pubs}")
        report.append(f"  Publications causing classification changes: {total_citations}")
        report.append("")

        # Report the number of variants actually listed instead of a fixed "10"
        report.append(f"TOP {len(top_variants)} REPRESENTATIVE VARIANTS:")
        for i, (variant_id, data) in enumerate(top_variants.items(), 1):
            report.append(f"\n  {i}. Variant {variant_id} (rs{data['rsid']}):")
            report.append(f"     Classification: {data['old_class']} → {data['new_class']}")
            report.append(f"     Total publications: {data['total_publications']}")
            report.append(f"     Publications causing change: {len(data['citation_publications'])}")

            if data['citation_publications']:
                report.append("     Key publications causing change:")
                for pub in data['citation_publications'][:3]:
                    report.append(f"       - PMID {pub['pmid']}: {pub['journal']} (IF: {pub['impact_factor']:.1f})")

        report.append("")
        report.append("JOURNAL IMPACT FACTOR DISTRIBUTION:")
        all_impacts = [
            pub['impact_factor']
            for data in all_variant_data.values()
            for pub in data['publications']
        ]

        if all_impacts:
            report.append(f"  Average impact factor: {np.mean(all_impacts):.2f}")
            report.append(f"  Median impact factor: {np.median(all_impacts):.2f}")
            report.append(f"  Max impact factor: {np.max(all_impacts):.2f}")

        report_text = "\n".join(report)

        os.makedirs(output_dir, exist_ok=True)
        # The report contains non-ASCII characters ("→"), so the encoding is
        # pinned rather than left to the platform default.
        with open(os.path.join(output_dir, "analysis_report.txt"), 'w', encoding='utf-8') as f:
            f.write(report_text)

        print("\n" + report_text)

    def run_complete_analysis(self, json_file: str, output_dir: str = './litvar_results',
                              top_n: int = 10) -> Optional[Dict]:
        """
        Run complete analysis pipeline with JSON input.

        Steps, in order: load the classification changes, save them as CSV,
        fetch publications for all variants, select the most representative
        variants, write the timeline figure (HTML), write the ML feature table
        (CSV) and write the summary report.

        Args:
            json_file: Path handed to ``load_json_classification_changes``.
            output_dir: Directory for all generated files; created if missing.
            top_n: Number of representative variants to select and plot.

        Returns:
            ``None`` when the JSON yields no classification changes. Otherwise
            a dict with the keys ``all_variant_data``, ``top_variants``,
            ``ml_features`` and ``output_dir``.
        """
        print(f"\n{'='*80}")
        print("STARTING LITVAR ANALYSIS WITH JSON INPUT")
        print(f"{'='*80}")

        start_time = time.time()

        os.makedirs(output_dir, exist_ok=True)

        # Load classification changes from JSON
        changes_df = self.load_json_classification_changes(json_file)

        if changes_df.empty:
            print("No classification changes found in JSON")
            return None

        changes_df.to_csv(os.path.join(output_dir, "processed_changes.csv"), index=False)

        # Process all variants using batch processing
        all_variant_data = self.process_all_variants_optimized(changes_df)

        # Select the top_n most representative variants
        top_variants = self.select_top_variants(all_variant_data, top_n=top_n)

        # Create timeline scatter visualization
        timeline_fig = self.create_timeline_scatter_visualization(top_variants)
        timeline_fig.write_html(os.path.join(output_dir, "timeline_scatter_impact_factors.html"))
        print(f"   Saved timeline scatter plot visualization")

        # Generate ML features
        ml_features = self.generate_ml_features(all_variant_data)
        ml_features.to_csv(os.path.join(output_dir, "ml_features.csv"), index=False)
        print(f"   Saved ML features for {len(ml_features)} variants")

        # Generate summary report
        self.generate_summary_report(all_variant_data, top_variants, output_dir)

        elapsed_time = time.time() - start_time

        print(f"\n{'='*80}")
        print("ANALYSIS COMPLETE")
        print(f"{'='*80}")
        print(f"Total runtime: {elapsed_time:.1f} seconds ({elapsed_time/60:.1f} minutes)")
        print(f"Results saved to: {output_dir}")

        # Hand the in-memory results back so callers do not have to re-read the files
        return {
            'all_variant_data': all_variant_data,
            'top_variants': top_variants,
            'ml_features': ml_features,
            'output_dir': output_dir,
        }


    def generate_ml_features(self, all_variant_data: Dict) -> pd.DataFrame:
        """
        Generate ML features for model training.

        One row is produced per variant. Besides the identifying columns the
        row holds publication counts before / after the classification change,
        the number of publications in the 180 days before the change, impact
        factor statistics and the number of distinct journals.

        Every feature column is always present and defaults to 0, so the
        output schema does not depend on which variants happen to have
        publications (an empty input yields an empty frame with all columns).

        Dates are normalised with ``pd.to_datetime`` before comparison, so
        ``change_date`` and publication dates may be strings, ``datetime``
        objects or Timestamps; values that cannot be parsed are ignored.

        Args:
            all_variant_data: Mapping of variant ID -> variant data.

        Returns:
            DataFrame with one row per variant; remaining missing values are
            filled with 0.
        """
        print(f"\n GENERATING ML FEATURES")

        feature_columns = [
            'variant_id', 'rsid', 'old_class', 'new_class',
            'total_publications', 'citations_causing_change', 'change_significance',
            'pubs_before_change', 'pubs_after_change', 'recent_publication_surge',
            'mean_impact_factor', 'max_impact_factor', 'high_impact_count',
            'unique_journals',
        ]

        features_list = []

        for variant_id, data in all_variant_data.items():
            publications = data['publications'] or []

            feature_dict = {
                'variant_id': variant_id,
                'rsid': data['rsid'],
                'old_class': data['old_class'],
                'new_class': data['new_class'],
                'total_publications': data['total_publications'],
                'citations_causing_change': len(data['citation_publications']),
                'change_significance': self.calculate_change_significance(
                    data['old_class'], data['new_class']
                ),
                # Defaults keep the schema stable for variants without publications
                'pubs_before_change': 0,
                'pubs_after_change': 0,
                'recent_publication_surge': 0,
                'mean_impact_factor': 0,
                'max_impact_factor': 0,
                'high_impact_count': 0,
                'unique_journals': 0,
            }

            # Publication temporal features
            change_date = pd.to_datetime(data['change_date'], errors='coerce')
            pub_dates = [
                pd.to_datetime(p['publication_date'], errors='coerce')
                for p in publications if p['publication_date']
            ]
            pub_dates = [d for d in pub_dates if pd.notna(d)]

            if pub_dates and pd.notna(change_date):
                # Publications before and after change
                feature_dict['pubs_before_change'] = sum(1 for d in pub_dates if d < change_date)
                feature_dict['pubs_after_change'] = sum(1 for d in pub_dates if d >= change_date)

                # Recent publication surge (6 months before change)
                six_months_before = change_date - pd.Timedelta(days=180)
                feature_dict['recent_publication_surge'] = sum(
                    1 for d in pub_dates if six_months_before <= d < change_date
                )

            if publications:
                # Impact factor features
                impacts = [p['impact_factor'] for p in publications]
                feature_dict['mean_impact_factor'] = np.mean(impacts)
                feature_dict['max_impact_factor'] = np.max(impacts)
                feature_dict['high_impact_count'] = sum(1 for i in impacts if i > 10)

                # Journal diversity (empty journal names are not counted)
                feature_dict['unique_journals'] = len(
                    {p['journal'] for p in publications if p['journal']}
                )

            features_list.append(feature_dict)

        features_df = pd.DataFrame(features_list, columns=feature_columns)
        features_df = features_df.fillna(0)

        print(f"   Generated {len(features_df)} feature rows with {len(features_df.columns)} features")

        return features_df


def main(
    email: str = "yl8889@nyu.edu",
    api_key: str = "",
    json_file: str = "/content/BRCA1_changes_parallel_20250815_234141.json",
    impact_factor_csv: str = "/content/journal_ranking_data.csv",
    output_dir: str = "./litvar_analysis_results",
):
    """
    Build an EnhancedLitVarAnalyzer and run the complete analysis.

    All settings are parameters so the run can be pointed at other inputs
    without editing the function; calling ``main()`` with no arguments uses
    the original Colab paths and settings.

    Args:
        email: Contact e-mail passed to the analyzer.
        api_key: API key passed to the analyzer (empty by default).
        json_file: Path of the classification-change JSON to analyse.
        impact_factor_csv: Path of the journal ranking CSV.
        output_dir: Directory that receives the generated files.

    Returns:
        Whatever ``run_complete_analysis`` returns for this run.
    """
    print("Initializing Enhanced LitVar Analyzer...")
    analyzer = EnhancedLitVarAnalyzer(
        email=email,
        api_key=api_key,
        impact_factor_csv=impact_factor_csv
    )

    return analyzer.run_complete_analysis(json_file, output_dir)


# Run the analysis with the default settings when executed as a script / notebook cell
if __name__ == "__main__":
    main()

Initializing Enhanced LitVar Analyzer...
Loading journal CiteScores from /content/journal_ranking_data.csv
   Loaded CiteScores for 17145 journals
   Highest CiteScores loaded:
      ca-a cancer journal for clinicians: 642.9
      ca cancer j clin: 642.9
      ca-a cancer j clin: 642.9
      nature reviews molecular cell biology: 164.4
      new england journal of medicine: 134.4

STARTING LITVAR ANALYSIS WITH JSON INPUT

 LOADING CLASSIFICATION CHANGES FROM JSON
   File: /content/BRCA1_changes_parallel_20250815_234141.json
   Loaded 6318 classification changes
   Found 12460 total citations

 FETCHING PUBLICATIONS FOR ALL VARIANTS (OPTIMIZED)
   Found 3436 unique variants to process
   3436 variants with valid rsIDs

   BATCH FETCHING FROM LITVAR2...

   Processing variants 1-100 of 3436...
         Found 43 cached, fetching 57 from LitVar2...
         Processing chunk 1/3 (20 variants)...
         Processing chunk 2/3 (20 variants)...
         Processing chunk 3/3 (17 variants)...
  