In [None]:
!pip install Bio



In [1]:
import pandas as pd
import requests
import xml.etree.ElementTree as ET
from datetime import datetime
import time
import json
from pathlib import Path
from typing import List, Dict, Set, Optional
from collections import defaultdict
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import gzip
import pickle
from functools import lru_cache
import queue

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class ClinVarParallelAPI:
    def __init__(self, email: str, api_key: str = None, cache_dir: str = "clinvar_cache_parallel"):
        self.email = email
        self.api_key = api_key
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)

        # Create subdirectories for different cache types
        self.xml_cache_dir = self.cache_dir / "xml"
        self.xml_cache_dir.mkdir(exist_ok=True)
        self.parsed_cache_dir = self.cache_dir / "parsed"
        self.parsed_cache_dir.mkdir(exist_ok=True)
        self.stats_cache = self.cache_dir / "stats.json"

        self.eutils_base = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

        self.requests_per_second = 10 if api_key else 3
        self.rate_limiter = threading.Semaphore(self.requests_per_second)
        self.rate_limit_lock = threading.Lock()
        self.last_request_times = []

        print(f"[INIT] Parallel ClinVar API initialized")
        print(f"[INIT] Email: {self.email}")
        print(f"[INIT] API Key: {'Configured' if self.api_key else 'Not provided'}")
        print(f"[INIT] Rate limit: {self.requests_per_second} requests/second")

    def search_brca1_variants(self, limit: int = None) -> List[str]:
        """
        Search for ALL BRCA1 variants in ClinVar
        """
        print("\n" + "="*80)
        print("SEARCHING FOR ALL BRCA1 VARIANTS IN CLINVAR")
        print("="*80)

        # Check if cached variant list
        variants_cache = self.cache_dir / "brca1_variant_ids.json"
        if variants_cache.exists() and not limit:
            with open(variants_cache, 'r') as f:
                cached_ids = json.load(f)
                print(f"[CACHE] Using cached list of {len(cached_ids)} BRCA1 variants")
                return cached_ids

        variation_ids = []

        try:
            search_url = f"{self.eutils_base}/esearch.fcgi"
            search_query = 'BRCA1[gene]'

            # Get count first
            count_params = {
                'db': 'clinvar',
                'term': search_query,
                'retmode': 'json',
                'retmax': 0,
                'email': self.email
            }

            if self.api_key:
                count_params['api_key'] = self.api_key

            response = requests.get(search_url, params=count_params, timeout=30)
            response.raise_for_status()
            data = response.json()

            total_count = int(data.get('esearchresult', {}).get('count', 0))
            print(f"[SEARCH] Found {total_count} BRCA1 records in ClinVar")

            if limit:
                retmax = min(limit, total_count)
            else:
                retmax = total_count

            batch_size = 500
            for start in range(0, retmax, batch_size):
                params = {
                    'db': 'clinvar',
                    'term': search_query,
                    'retmode': 'json',
                    'retstart': start,
                    'retmax': min(batch_size, retmax - start),
                    'email': self.email
                }

                if self.api_key:
                    params['api_key'] = self.api_key

                response = requests.get(search_url, params=params, timeout=30)
                response.raise_for_status()
                data = response.json()

                ids = data.get('esearchresult', {}).get('idlist', [])
                variation_ids.extend(ids)

                print(f"[SEARCH] Retrieved {len(variation_ids)}/{retmax} IDs")

                if limit and len(variation_ids) >= limit:
                    break

                time.sleep(0.1)

            if not limit and len(variation_ids) == total_count:
                with open(variants_cache, 'w') as f:
                    json.dump(variation_ids, f)
                print(f"[CACHE] Saved complete variant list")

            return variation_ids[:limit] if limit else variation_ids

        except Exception as e:
            print(f"[ERROR] Search failed: {e}")
            return []

    def _rate_limit_wait(self):
        """
        Thread-safe rate limiting
        """
        with self.rate_limit_lock:
            current_time = time.time()
            # Remove requests older than 1 second
            self.last_request_times = [t for t in self.last_request_times if current_time - t < 1.0]

            # If made too many requests, wait
            if len(self.last_request_times) >= self.requests_per_second:
                sleep_time = 1.0 - (current_time - self.last_request_times[0])
                if sleep_time > 0:
                    time.sleep(sleep_time)

            # Record this request
            self.last_request_times.append(time.time())

    def fetch_single_variant_parallel(self, variation_id: str, retry_count: int = 3) -> Optional[Dict]:
        """
        Fetch and parse a single variant, with parallel execution
        """
        # Check parsed cache first
        parsed_cache_file = self.parsed_cache_dir / f"{variation_id}.pkl"
        if parsed_cache_file.exists():
            try:
                with open(parsed_cache_file, 'rb') as f:
                    return pickle.load(f)
            except:
                pass

        # Check XML cache
        xml_cache_file = self.xml_cache_dir / f"{variation_id}.xml.gz"
        if xml_cache_file.exists():
            try:
                with gzip.open(xml_cache_file, 'rt', encoding='utf-8') as f:
                    xml_content = f.read()
                    # Parse and cache result
                    result = self.parse_vcv_xml(xml_content, variation_id)
                    if result and result.get('submissions'):
                        with open(parsed_cache_file, 'wb') as f:
                            pickle.dump(result, f)
                    return result
            except:
                pass

        # Fetch from API
        for attempt in range(retry_count):
            try:
                self._rate_limit_wait()

                fetch_url = f"{self.eutils_base}/efetch.fcgi"
                params = {
                    'db': 'clinvar',
                    'rettype': 'vcv',
                    'id': variation_id,
                    'is_variationid': '',
                    'email': self.email
                }

                if self.api_key:
                    params['api_key'] = self.api_key

                response = requests.get(fetch_url, params=params, timeout=30)

                if response.status_code == 200 and 'VariationArchive' in response.text:
                    xml_content = response.text

                    # Cache XML
                    with gzip.open(xml_cache_file, 'wt', encoding='utf-8') as f:
                        f.write(xml_content)

                    result = self.parse_vcv_xml(xml_content, variation_id)
                    if result and result.get('submissions'):
                        with open(parsed_cache_file, 'wb') as f:
                            pickle.dump(result, f)

                    return result
                elif response.status_code == 429:  # Rate limited
                    time.sleep(2 ** attempt)
                    continue

            except requests.exceptions.Timeout:
                if attempt < retry_count - 1:
                    time.sleep(1)
                    continue
            except Exception as e:
                if attempt == retry_count - 1:
                    print(f"[ERROR] Failed to fetch {variation_id} after {retry_count} attempts: {e}")

        return None

    def process_variants_parallel(self, variation_ids: List[str], max_workers: int = 10) -> List[Dict]:
        """
        Both fetching and parsing in parallel
        """
        print(f"\n[PARALLEL] Processing {len(variation_ids)} variants with {max_workers} workers")

        results = []
        processed_count = 0
        failed_count = 0
        start_time = time.time()

        # Use ThreadPoolExecutor for parallel processing
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_id = {
                executor.submit(self.fetch_single_variant_parallel, var_id): var_id
                for var_id in variation_ids
            }

            for future in as_completed(future_to_id):
                var_id = future_to_id[future]
                try:
                    result = future.result(timeout=60)
                    if result and result.get('submissions'):
                        results.append(result)
                        processed_count += 1
                    else:
                        failed_count += 1

                    # Progress update
                    total_done = processed_count + failed_count
                    if total_done % 100 == 0:
                        elapsed = time.time() - start_time
                        rate = total_done / elapsed
                        remaining = (len(variation_ids) - total_done) / rate
                        print(f"[PARALLEL] Progress: {total_done}/{len(variation_ids)} "
                              f"({100*total_done/len(variation_ids):.1f}%) - "
                              f"Rate: {rate:.1f} variants/sec - "
                              f"ETA: {remaining/60:.1f} minutes")

                except Exception as e:
                    print(f"[ERROR] Failed to process {var_id}: {e}")
                    failed_count += 1

        elapsed = time.time() - start_time
        print(f"[PARALLEL] Completed in {elapsed:.1f} seconds ({elapsed/60:.1f} minutes)")
        print(f"[PARALLEL] Successfully processed: {processed_count}")
        print(f"[PARALLEL] Failed: {failed_count}")
        print(f"[PARALLEL] Average rate: {len(variation_ids)/elapsed:.1f} variants/second")

        return results

    def parse_vcv_xml(self, xml_content: str, variation_id: str) -> Dict:
        """
        Parse VCV XML and extract relevant information
        """
        try:
            root = ET.fromstring(xml_content)

            result = {
                'variation_id': variation_id,
                'name': '',
                'rsid': '',
                'submissions': [],
                'classification_timeline': [],
                'all_citations': {}
            }

            # Get variant basic info
            var_archive = root.find('.//VariationArchive')
            if var_archive is None:
                var_archive = root

            result['vcv_accession'] = var_archive.get('Accession', '')

            # Get variant name - try multiple locations
            # First try the VariationName attribute
            var_name = var_archive.get('VariationName', '').strip()
            if var_name:
                result['name'] = var_name
            else:
                # Try HGVS - get the actual text, not whitespace
                for hgvs_elem in root.findall('.//SimpleAllele/HGVSlist/HGVS'):
                    if hgvs_elem.text and hgvs_elem.text.strip():
                        result['name'] = hgvs_elem.text.strip()
                        break

                if not result['name']:
                    # Try CanonicalSPDI
                    spdi_elem = root.find('.//CanonicalSPDI')
                    if spdi_elem is not None and spdi_elem.text and spdi_elem.text.strip():
                        result['name'] = spdi_elem.text.strip()

                    if not result['name']:
                        # Try ProteinChange
                        protein_elem = root.find('.//ProteinChange')
                        if protein_elem is not None and protein_elem.text and protein_elem.text.strip():
                            result['name'] = protein_elem.text.strip()

            # Get rsID - check multiple locations
            xref_elem = root.find('.//SimpleAllele/XRefList/XRef[@DB="dbSNP"]')
            if xref_elem is not None:
                rs_id = xref_elem.get('ID', '')
                if rs_id:
                    if rs_id.startswith('rs'):
                        result['rsid'] = rs_id
                    else:
                        result['rsid'] = f"rs{rs_id}"

            # Parse ALL ClinicalAssertion elements
            assertions = root.findall('.//ClinicalAssertion')

            for assertion in assertions:
                submission = self.parse_clinical_assertion(assertion)
                if submission:
                    result['submissions'].append(submission)
                    # Collect citations
                    for cit in submission.get('citations', []):
                        if cit.get('pmid'):
                            result['all_citations'][cit['pmid']] = cit

            # Sort submissions by date
            result['submissions'].sort(key=lambda x: x.get('date_last_evaluated', ''))

            # Build classification timeline
            # Check if include conflicts (from setting stored during process call)
            include_conflicts = getattr(self, 'include_conflicts_setting', True)
            result['classification_timeline'] = self.build_classification_timeline(
                result['submissions'],
                include_conflicts=include_conflicts
            )

            return result

        except Exception as e:
            return {}

    def parse_clinical_assertion(self, assertion_elem) -> Dict:
        """
        Parse ClinicalAssertion
        """
        submission = {}

        try:
            # SCV Accession
            scv_elem = assertion_elem.find('.//ClinVarAccession[@Type="SCV"]')
            if scv_elem is not None:
                submission['scv_accession'] = f"{scv_elem.get('Accession')}.{scv_elem.get('Version')}"
                submission['submitter'] = scv_elem.get('SubmitterName', '') or scv_elem.get('OrgAbbreviation', '')
                submission['date_updated'] = scv_elem.get('DateUpdated', '')

            # Classification - check multiple locations
            classification_found = False

            # Classification/GermlineClassification
            class_elem = assertion_elem.find('.//Classification')
            if class_elem is not None:
                submission['date_last_evaluated'] = class_elem.get('DateLastEvaluated', '')
                germ_class = class_elem.find('./GermlineClassification')
                if germ_class is not None and germ_class.text:
                    submission['classification'] = germ_class.text.strip()
                    classification_found = True
                else:
                    som_class = class_elem.find('./SomaticClinicalImpact')
                    if som_class is not None and som_class.text:
                        submission['classification'] = som_class.text.strip()
                        classification_found = True

                review_elem = class_elem.find('./ReviewStatus')
                if review_elem is not None and review_elem.text:
                    submission['review_status'] = review_elem.text.strip()

            # ClinicalSignificance/Description
            if not classification_found:
                clin_sig_elem = assertion_elem.find('.//ClinicalSignificance')
                if clin_sig_elem is not None:
                    desc_elem = clin_sig_elem.find('./Description')
                    if desc_elem is not None and desc_elem.text:
                        submission['classification'] = desc_elem.text.strip()
                        classification_found = True
                    if not submission.get('date_last_evaluated'):
                        submission['date_last_evaluated'] = clin_sig_elem.get('DateLastEvaluated', '')
                    if not submission.get('review_status'):
                        review_elem = clin_sig_elem.find('./ReviewStatus')
                        if review_elem is not None and review_elem.text:
                            submission['review_status'] = review_elem.text.strip()

            # Interpretation
            if not classification_found:
                interp_elem = assertion_elem.find('.//Interpretation')
                if interp_elem is not None:
                    for child in ['Description', 'Classification']:
                        elem = interp_elem.find(f'./{child}')
                        if elem is not None and elem.text:
                            submission['classification'] = elem.text.strip()
                            classification_found = True
                            break
                    if not submission.get('date_last_evaluated'):
                        submission['date_last_evaluated'] = interp_elem.get('DateLastEvaluated', '')

            # Additional submitter info
            sub_id_elem = assertion_elem.find('.//ClinVarSubmissionID')
            if sub_id_elem is not None:
                if not submission.get('submitter'):
                    submission['submitter'] = sub_id_elem.get('submitter', '').replace(';', ',')

            # Condition
            trait_elem = assertion_elem.find('.//TraitSet[@Type="Disease"]/Trait/Name/ElementValue[@Type="Preferred"]')
            if trait_elem is None:
                trait_elem = assertion_elem.find('.//TraitSet/Trait/Name/ElementValue')
            if trait_elem is not None and trait_elem.text:
                submission['condition'] = trait_elem.text

            # Citations
            citations = []
            for citation in assertion_elem.findall('.//Citation'):
                pmid_elem = citation.find('./ID[@Source="PubMed"]')
                if pmid_elem is not None and pmid_elem.text:
                    cit = {'pmid': pmid_elem.text}
                    title_elem = citation.find('./Title')
                    if title_elem is not None:
                        cit['title'] = title_elem.text
                    citations.append(cit)

            submission['citations'] = citations
            submission['citation_count'] = len(citations)

            # Date fallback
            if not submission.get('date_last_evaluated'):
                submission['date_last_evaluated'] = submission.get('date_updated', '')

            if submission.get('scv_accession'):
                return submission

        except:
            pass

        return None

    def build_classification_timeline(self, submissions: List[Dict], include_conflicts: bool = True) -> List[Dict]:
        """
        Build timeline of classification changes
        Args:
            submissions: List of submission dictionaries
            include_conflicts: If True, include conflicting interpretations between labs
        """
        timeline = []

        if len(submissions) < 2:
            return timeline

        # Track changes within same submitter over time
        by_submitter = defaultdict(list)
        for sub in submissions:
            submitter = sub.get('submitter', 'Unknown')
            if submitter and submitter != 'Unknown':
                by_submitter[submitter].append(sub)

        # Find temporal changes within each submitter
        for submitter, subs in by_submitter.items():
            if len(subs) < 2:
                continue

            # Sort by date
            subs.sort(key=lambda x: x.get('date_last_evaluated', ''))

            for i in range(1, len(subs)):
                prev_sub = subs[i-1]
                curr_sub = subs[i]

                # Both must have dates for temporal comparison
                if not prev_sub.get('date_last_evaluated') or not curr_sub.get('date_last_evaluated'):
                    continue

                prev_class = prev_sub.get('classification', '').strip()
                curr_class = curr_sub.get('classification', '').strip()

                # Skip if no classification data
                if not prev_class or not curr_class:
                    continue

                # Normalize for comparison
                if prev_class.lower() == curr_class.lower():
                    continue

                # A real temporal change
                has_citations = (len(prev_sub.get('citations', [])) > 0 or
                               len(curr_sub.get('citations', [])) > 0)

                timeline.append({
                    'submitter': submitter,
                    'date_old': prev_sub.get('date_last_evaluated', ''),
                    'classification_old': prev_sub.get('classification', ''),
                    'date_new': curr_sub.get('date_last_evaluated', ''),
                    'classification_new': curr_sub.get('classification', ''),
                    'scv_old': prev_sub.get('scv_accession', ''),
                    'scv_new': curr_sub.get('scv_accession', ''),
                    'citations_old': prev_sub.get('citations', []),
                    'citations_new': curr_sub.get('citations', []),
                    'has_citations': has_citations,
                    'change_type': self.classify_change(
                        prev_sub.get('classification', ''),
                        curr_sub.get('classification', '')
                    ),
                    'is_conflict': False
                })

        # Add conflicting interpretations (current disagreements between labs)
        if include_conflicts and len(by_submitter) > 1:
            # Get the most recent submission from each submitter
            latest_by_submitter = {}
            for submitter, subs in by_submitter.items():
                if subs and submitter and submitter != 'Unknown':
                    # Get the most recent submission with a classification
                    for sub in reversed(sorted(subs, key=lambda x: x.get('date_last_evaluated', ''))):
                        if sub.get('classification'):
                            latest_by_submitter[submitter] = sub
                            break

            # Check for conflicts between submitters
            if len(latest_by_submitter) > 1:
                classifications = {}
                for submitter, sub in latest_by_submitter.items():
                    class_norm = self.normalize_classification_category(sub.get('classification', ''))
                    if class_norm and class_norm != 'not_provided':
                        if class_norm not in classifications:
                            classifications[class_norm] = []
                        classifications[class_norm].append((submitter, sub))

                # Only add ONE conflict entry per variant if there are disagreements
                if len(classifications) > 1:
                    # Find the most significant conflict
                    has_pathogenic = any('pathogenic' in c for c in classifications.keys())
                    has_benign = any('benign' in c for c in classifications.keys())
                    has_uncertain = 'uncertain_significance' in classifications.keys()

                    # Only add if it's a significant conflict
                    if (has_pathogenic and has_benign) or (has_pathogenic and has_uncertain) or (has_benign and has_uncertain):
                        # Get representatives from each category
                        conflict_summary = []
                        for class_type in sorted(classifications.keys()):
                            if classifications[class_type]:
                                submitter, sub = classifications[class_type][0]
                                conflict_summary.append(f"{submitter}: {sub.get('classification')}")

                        # Add single conflict entry
                        timeline.append({
                            'submitter': 'CONFLICT_BETWEEN_LABS',
                            'date_old': '',
                            'classification_old': 'Conflicting interpretations',
                            'date_new': datetime.now().strftime('%Y-%m-%d'),
                            'classification_new': ' | '.join(conflict_summary[:3]),  # Limit to 3 examples
                            'scv_old': '',
                            'scv_new': '',
                            'citations_old': [],
                            'citations_new': [],
                            'has_citations': False,
                            'change_type': 'conflicting_interpretations',
                            'is_conflict': True
                        })

        return timeline

    def normalize_classification_category(self, classification: str) -> str:
        """
        Normalize classification to major categories for conflict detection
        """
        class_lower = classification.lower().strip()

        if 'pathogenic' in class_lower and 'likely' in class_lower:
            return 'likely_pathogenic'
        elif 'pathogenic' in class_lower:
            return 'pathogenic'
        elif 'benign' in class_lower and 'likely' in class_lower:
            return 'likely_benign'
        elif 'benign' in class_lower:
            return 'benign'
        elif 'uncertain' in class_lower or 'vus' in class_lower:
            return 'uncertain_significance'
        elif 'not provided' in class_lower:
            return 'not_provided'
        else:
            return class_lower

    def classify_change(self, old_class: str, new_class: str) -> str:
        """
        Classify the type of classification change
        """
        old_lower = old_class.lower()
        new_lower = new_class.lower()

        pathogenic = ['pathogenic', 'likely pathogenic']
        uncertain = ['uncertain significance', 'variant of uncertain significance', 'vus']
        benign = ['benign', 'likely benign']

        old_category = None
        new_category = None

        for category, terms in [('pathogenic', pathogenic),
                                ('uncertain', uncertain),
                                ('benign', benign)]:
            if any(term in old_lower for term in terms):
                old_category = category
            if any(term in new_lower for term in terms):
                new_category = category

        if old_category == 'uncertain' and new_category == 'pathogenic':
            return 'upgraded_to_pathogenic'
        elif old_category == 'uncertain' and new_category == 'benign':
            return 'downgraded_to_benign'
        elif old_category == 'pathogenic' and new_category == 'uncertain':
            return 'downgraded_to_uncertain'
        elif old_category == 'benign' and new_category == 'uncertain':
            return 'upgraded_to_uncertain'
        elif old_category == 'pathogenic' and new_category == 'benign':
            return 'major_downgrade'
        elif old_category == 'benign' and new_category == 'pathogenic':
            return 'major_upgrade'
        else:
            return 'other_change'

    def get_statistics(self, limit: int = 1000) -> Dict:
        """
        Get quick statistics about BRCA1 variants
        """
        # Check cached stats
        if self.stats_cache.exists():
            with open(self.stats_cache, 'r') as f:
                stats = json.load(f)
                print("[STATS] Using cached statistics")
                return stats

        print(f"\n[STATS] Analyzing first {limit} variants for statistics...")

        variant_ids = self.search_brca1_variants(limit=limit)

        results = self.process_variants_parallel(variant_ids, max_workers=10)

        stats = {
            'total_variants_analyzed': len(results),
            'variants_with_multiple_submissions': 0,
            'variants_with_changes': 0,
            'variants_with_changes_and_citations': 0,
            'total_submissions': 0,
            'unique_submitters': set(),
            'classification_distribution': defaultdict(int),
            'change_type_distribution': defaultdict(int)
        }

        for result in results:
            num_subs = len(result['submissions'])
            stats['total_submissions'] += num_subs

            if num_subs > 1:
                stats['variants_with_multiple_submissions'] += 1

            for sub in result['submissions']:
                if sub.get('submitter'):
                    stats['unique_submitters'].add(sub['submitter'])
                if sub.get('classification'):
                    stats['classification_distribution'][sub['classification']] += 1

            if result['classification_timeline']:
                stats['variants_with_changes'] += 1

                # Check if any change has citations
                has_citations = any(
                    len(change.get('citations_old', [])) > 0 or
                    len(change.get('citations_new', [])) > 0
                    for change in result['classification_timeline']
                )
                if has_citations:
                    stats['variants_with_changes_and_citations'] += 1

                for change in result['classification_timeline']:
                    stats['change_type_distribution'][change['change_type']] += 1

        stats['unique_submitters'] = len(stats['unique_submitters'])
        stats['classification_distribution'] = dict(stats['classification_distribution'])
        stats['change_type_distribution'] = dict(stats['change_type_distribution'])


        with open(self.stats_cache, 'w') as f:
            json.dump(stats, f, indent=2)

        return stats

    def process_all_brca1_variants(self, limit: int = None, max_workers: int = 10,
                                   require_citations: bool = True, include_conflicts: bool = True):
        """
        Main processing method with TRUE parallel processing
        Args:
            limit: Number of variants to process
            max_workers: Number of parallel workers
            require_citations: If True, only include changes with citations; if False, include all changes
            include_conflicts: If True, include conflicting interpretations between labs
        """
        print("\n" + "="*80)
        print("TRULY PARALLEL PROCESSING - ALL BRCA1 VARIANTS")
        print("="*80)

        start_time = time.time()

        self.include_conflicts_setting = include_conflicts

        variation_ids = self.search_brca1_variants(limit=limit)

        print(f"\n[PROCESS] Processing {len(variation_ids)} variants...")
        print(f"[OPTIMIZE] Using {max_workers} parallel workers for API fetching AND parsing")
        print(f"[FILTER] Require citations: {require_citations}")
        print(f"[FILTER] Include conflicts: {include_conflicts}")

        all_results = self.process_variants_parallel(variation_ids, max_workers=max_workers)

        # Filter for variants with changes
        variants_with_changes = []
        variants_with_changes_no_citations = []
        all_pmids = set()

        for result in all_results:
            if result.get('classification_timeline'):
                # Filter out conflicts if not wanted
                if not include_conflicts:
                    timeline_filtered = [
                        change for change in result['classification_timeline']
                        if not change.get('is_conflict', False)
                    ]
                else:
                    timeline_filtered = result['classification_timeline']

                # Skip if no changes after filtering
                if not timeline_filtered:
                    continue

                # Separate changes with and without citations
                changes_with_citations = [
                    change for change in timeline_filtered
                    if change.get('has_citations', False)
                ]

                changes_without_citations = [
                    change for change in timeline_filtered
                    if not change.get('has_citations', False)
                ]

                # Collect PMIDs
                for sub in result['submissions']:
                    for cit in sub.get('citations', []):
                        if cit.get('pmid'):
                            all_pmids.add(cit['pmid'])

                variant_record = {
                    'variation_id': result['variation_id'],
                    'vcv_accession': result.get('vcv_accession', ''),
                    'rsid': result.get('rsid', ''),
                    'gene': 'BRCA1',
                    'variant_name': result.get('name', '').strip(),
                    'total_submissions': len(result['submissions']),
                    'all_submissions': [
                        {
                            'scv_accession': sub.get('scv_accession', ''),
                            'submitter': sub.get('submitter', ''),
                            'date': sub.get('date_last_evaluated', ''),
                            'classification': sub.get('classification', ''),
                            'review_status': sub.get('review_status', ''),
                            'citation_count': sub.get('citation_count', 0)
                        } for sub in result['submissions']
                    ]
                }

                if require_citations and changes_with_citations:
                    variant_record['classification_changes'] = changes_with_citations
                    variants_with_changes.append(variant_record)
                elif not require_citations and timeline_filtered:
                    variant_record['classification_changes'] = timeline_filtered
                    variants_with_changes.append(variant_record)

                if changes_without_citations:
                    variants_with_changes_no_citations.append(variant_record)

        processing_time = time.time() - start_time
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

        output_data = {
            'metadata': {
                'gene': 'BRCA1',
                'extraction_date': datetime.now().isoformat(),
                'processing_time_seconds': processing_time,
                'processing_time_minutes': processing_time / 60,
                'total_variants_processed': len(all_results),
                'variants_with_changes': len(variants_with_changes),
                'variants_with_changes_no_citations': len(variants_with_changes_no_citations),
                'unique_citations': len(all_pmids),
                'processing_rate': len(all_results) / processing_time if processing_time > 0 else 0,
                'parallel_workers': max_workers,
                'citation_filter_applied': require_citations,
                'conflicts_included': include_conflicts
            },
            'variants': variants_with_changes
        }

        output_file = f"BRCA1_changes_parallel_{timestamp}.json"
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False, default=str)

        print(f"\n[SAVE] Results saved to: {output_file}")

        self.print_summary(output_data)

        return output_data

    def print_summary(self, data: Dict):
        """
        Print analysis summary
        """
        print("\n" + "="*80)
        print("ANALYSIS COMPLETE")
        print("="*80)

        metadata = data['metadata']
        variants = data['variants']

        print(f"\n PERFORMANCE:")
        print(f"  Processing time: {metadata['processing_time_seconds']:.1f} seconds")
        print(f"  Processing time: {metadata['processing_time_minutes']:.1f} minutes")
        print(f"  Variants processed: {metadata['total_variants_processed']}")
        print(f"  Processing rate: {metadata['processing_rate']:.1f} variants/second")
        print(f"  Parallel workers used: {metadata['parallel_workers']}")

        print(f"\n RESULTS:")
        print(f"  Variants with changes: {len(variants)}")
        print(f"  Percentage with changes: {100*len(variants)/metadata['total_variants_processed']:.1f}%")

        if variants:
            # Count change types
            change_types = defaultdict(int)
            for var in variants:
                for change in var['classification_changes']:
                    change_types[change['change_type']] += 1

            print(f"\n  Change type distribution:")
            for change_type, count in sorted(change_types.items(), key=lambda x: x[1], reverse=True):
                print(f"    {change_type}: {count}")

            # Sample changes
            print(f"\n  Sample classification changes (first 3):")
            for var in variants[:3]:
                print(f"\n    Variant ID: {var['variation_id']}")
                print(f"    Name: {var['variant_name'] or 'N/A'}")
                print(f"    RS: {var['rsid'] or 'N/A'}")
                for change in var['classification_changes'][:1]:
                    print(f"    Change by {change['submitter']}:")
                    print(f"      {change['classification_old']} → {change['classification_new']}")
                    print(f"      ({change['date_old']} to {change['date_new']})")



def main():


    EMAIL = "yl8889@nyu.edu"
    API_KEY = "92c067f1875f6abf0dfbf5d0d57758beff09"

    # Options: 100 (test), None (all ~15,000)
    LIMIT_VARIANTS = None

    # Number of parallel workers for API fetching
    MAX_WORKERS = 15

    # Citation filter
    # True = Only changes WITH citations (~10-50 variants expected)
    # False = ALL classification changes (~1000-2000 variants expected)
    REQUIRE_CITATIONS = False

    # Include conflicting interpretations between labs
    # True = Include variants where labs currently disagree
    # False = Only temporal changes (same lab changing over time)
    INCLUDE_CONFLICTS = False  # Set to False to exclude conflicts

    # Show statistics before processing
    SHOW_STATISTICS = True

    # Statistics sample size (if SHOW_STATISTICS is True)
    STATS_SAMPLE_SIZE = 1000


    print("\n CONFIGURATION:")
    print(f"  Variants to process: {LIMIT_VARIANTS if LIMIT_VARIANTS else 'ALL (~15,000)'}")
    print(f"  Parallel workers: {MAX_WORKERS}")
    print(f"  Require citations: {REQUIRE_CITATIONS}")
    print(f"  Include conflicts between labs: {INCLUDE_CONFLICTS}")
    print(f"  Show statistics: {SHOW_STATISTICS}")
    print("="*80)

    api = ClinVarParallelAPI(email=EMAIL, api_key=API_KEY)

    if SHOW_STATISTICS:
        print(f"\n Generating statistics from {STATS_SAMPLE_SIZE} variants...")
        stats = api.get_statistics(limit=STATS_SAMPLE_SIZE)
        print(f"\n BRCA1 VARIANT STATISTICS (based on {STATS_SAMPLE_SIZE} variants):")
        print(f"  Variants analyzed: {stats['total_variants_analyzed']}")
        print(f"  Variants with multiple submissions: {stats['variants_with_multiple_submissions']}")
        print(f"  Variants with changes: {stats['variants_with_changes']}")
        print(f"  Variants with changes + citations: {stats['variants_with_changes_and_citations']}")

        if stats['total_variants_analyzed'] > 0:
            print(f"\n  Expected yields:")
            print(f"    With citations filter: ~{100*stats['variants_with_changes_and_citations']/stats['total_variants_analyzed']:.2f}% of variants")
            print(f"    Without citations filter: ~{100*stats['variants_with_changes']/stats['total_variants_analyzed']:.2f}% of variants")

            if LIMIT_VARIANTS:
                est_with_citations = int(LIMIT_VARIANTS * stats['variants_with_changes_and_citations'] / stats['total_variants_analyzed'])
                est_without_citations = int(LIMIT_VARIANTS * stats['variants_with_changes'] / stats['total_variants_analyzed'])
                print(f"\n  Estimated results for {LIMIT_VARIANTS} variants:")
                print(f"    With citations: ~{est_with_citations} variants")
                print(f"    Without citations: ~{est_without_citations} variants")

        print("\n" + "-"*80)

    print(f"\n Starting processing...")
    print(f"  Processing {LIMIT_VARIANTS if LIMIT_VARIANTS else 'ALL'} variants")
    print(f"  Using {MAX_WORKERS} parallel workers")
    print(f"  Citation filter: {'ENABLED (strict)' if REQUIRE_CITATIONS else 'DISABLED (all changes)'}")
    print(f"  Conflict detection: {'ENABLED' if INCLUDE_CONFLICTS else 'DISABLED (temporal changes only)'}")
    print("\n" + "-"*80)

    start_time = time.time()

    try:
        result_data = api.process_all_brca1_variants(
            limit=LIMIT_VARIANTS,
            max_workers=MAX_WORKERS,
            require_citations=REQUIRE_CITATIONS,
            include_conflicts=INCLUDE_CONFLICTS
        )

        elapsed = time.time() - start_time
        print("\n" + "="*80)
        print("PROCESSING COMPLETE")
        print("="*80)
        print(f"\n FINAL RESULTS:")
        print(f"  Total time: {elapsed:.1f} seconds ({elapsed/60:.1f} minutes)")
        print(f"  Variants processed: {result_data['metadata']['total_variants_processed']}")
        print(f"  Variants with changes: {len(result_data['variants'])}")
        print(f"  Processing rate: {result_data['metadata']['processing_rate']:.1f} variants/second")

        if len(result_data['variants']) > 0:
            print(f"\n  Change percentage: {100*len(result_data['variants'])/result_data['metadata']['total_variants_processed']:.2f}%")

            # Count change types
            change_types = defaultdict(int)
            temporal_changes = 0
            conflict_changes = 0
            total_changes = 0

            for var in result_data['variants']:
                for change in var.get('classification_changes', []):
                    change_types[change.get('change_type', 'unknown')] += 1
                    total_changes += 1
                    if change.get('is_conflict', False):
                        conflict_changes += 1
                    else:
                        temporal_changes += 1

            if change_types:
                print(f"\n  Total classification events: {total_changes}")
                print(f"    Temporal changes (same lab): {temporal_changes}")
                print(f"    Conflicts (between labs): {conflict_changes}")
                print(f"\n  Change type distribution:")
                for change_type, count in sorted(change_types.items(), key=lambda x: x[1], reverse=True)[:5]:
                    print(f"    {change_type}: {count} ({100*count/total_changes:.1f}%)")

        # Save filename from timestamp in metadata
        output_file = f"BRCA1_changes_parallel_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        print(f"\n Results saved to: {output_file}")

    except Exception as e:
        print(f"\n Error during processing: {e}")
        import traceback
        traceback.print_exc()

    print("\n" + "="*80)


if __name__ == "__main__":
    main()


 CONFIGURATION:
  Variants to process: ALL (~15,000)
  Parallel workers: 15
  Require citations: False
  Include conflicts between labs: False
  Show statistics: True
[INIT] Parallel ClinVar API initialized
[INIT] Email: yl8889@nyu.edu
[INIT] API Key: Configured
[INIT] Rate limit: 10 requests/second

 Generating statistics from 1000 variants...

[STATS] Analyzing first 1000 variants for statistics...

SEARCHING FOR ALL BRCA1 VARIANTS IN CLINVAR
[SEARCH] Found 15452 BRCA1 records in ClinVar
[SEARCH] Retrieved 500/1000 IDs
[SEARCH] Retrieved 1000/1000 IDs

[PARALLEL] Processing 1000 variants with 10 workers
[PARALLEL] Progress: 100/1000 (10.0%) - Rate: 8.1 variants/sec - ETA: 1.9 minutes
[PARALLEL] Progress: 200/1000 (20.0%) - Rate: 8.5 variants/sec - ETA: 1.6 minutes
[PARALLEL] Progress: 300/1000 (30.0%) - Rate: 8.5 variants/sec - ETA: 1.4 minutes
[PARALLEL] Progress: 400/1000 (40.0%) - Rate: 8.6 variants/sec - ETA: 1.2 minutes
[PARALLEL] Progress: 500/1000 (50.0%) - Rate: 8.7 variants