In [None]:
"""
NCBI COG Database Search Tool - T6SS Components

This script queries the NIH Database of Clusters of Orthologous Genes (COGs)
to retrieve information about Type VI Secretion System (T6SS) components.

For each T6SS gene, it retrieves:
- COG identifier number (e.g., COG3515)
- PDB entries (structural data)
- List of orthologous genes across all organisms

After collection, it filters to only keep species that have ALL T6SS components.

API Documentation: https://www.ncbi.nlm.nih.gov/research/cog/webservices/
"""

import requests
import time
import csv
from typing import Optional, Dict, List
from dataclasses import dataclass, field
from collections import defaultdict


# =============================================================================
# T6SS GENE TO COG MAPPING
# =============================================================================

# Gene name -> COG identifier for every T6SS component queried by this script.
# Insertion order matters: it is the order in which COGs are fetched and the
# row order of the generated output table.
T6SS_COG_MAPPING = {
    'tssA': 'COG3515',
    'tssB': 'COG3516',
    'tssC': 'COG3517',
    'tssD': 'COG3157',
    'tssE': 'COG3518',
    'tssF': 'COG3519',
    'tssG': 'COG3520',
    'tssH': 'COG0542',
    'tssI': 'COG3501',
    'tssJ': 'COG3521',
    'tssK': 'COG3522',
    'tssL': 'COG3455',
    'tssM': 'COG3523',
    'PAAR': 'COG4104',
}

# Gene names only, in mapping order (iterating a dict yields its keys).
T6SS_GENES = list(T6SS_COG_MAPPING)


# =============================================================================
# DATA CLASSES
# =============================================================================

@dataclass
class OrthologousGene:
    """One member gene of a COG, as returned by the COG API's ``/cog/`` endpoint.

    Instances are built in ``COGSearcher.get_orthologous_genes`` from a single
    result entry.
    """
    # Gene/locus tag taken from the entry's 'gene_tag' field ('' if absent).
    gene_tag: str
    # Organism name taken from the entry's organism 'genome_name' ('' if absent);
    # used downstream as the per-species grouping key.
    organism_name: str
    # Score taken from the entry's 'bitscore' field (0.0 if absent); the
    # highest-scoring hit per species and gene is the one that is kept.
    bitscore: float


@dataclass
class COGResult:
    """Everything collected for one T6SS gene: COG metadata plus its orthologs."""
    # T6SS gene name, i.e. a key of T6SS_COG_MAPPING (e.g. 'tssA').
    gene_name: str
    # COG identifier that was queried (e.g. 'COG3515').
    cog_id: str
    # Descriptive name from the COG definition's 'name' field.
    cog_name: str
    # PDB identifiers from the COG definition's 'pdbs' field; may be empty.
    pdb_entries: List[str]
    # Member genes of the COG; filled in after construction by the searcher.
    orthologous_genes: List[OrthologousGene] = field(default_factory=list)


# =============================================================================
# COG DATABASE SEARCHER CLASS
# =============================================================================

class COGSearcher:
    """Client for the NCBI COG web API, used to look up COGs by identifier.

    One ``requests.Session`` is shared by all calls, and a fixed pause follows
    every successful request so the server is not hammered.
    """

    BASE_URL = "https://www.ncbi.nlm.nih.gov/research/cog/api"

    def __init__(self, delay_between_requests: float = 0.3):
        """Set up the HTTP session.

        Args:
            delay_between_requests: Seconds to sleep after each successful
                API request.
        """
        self.session = requests.Session()
        self.session.headers.update({
            'Accept': 'application/json',
            'User-Agent': 'T6SSCOGSearch/1.0 (Academic Research)'
        })
        self.delay = delay_between_requests

    def _make_request(self, endpoint: str, params: Optional[dict] = None) -> dict:
        """Issue a GET request against the COG API and return the decoded JSON.

        Args:
            endpoint: Path appended to ``BASE_URL`` (e.g. ``'/cog/'``).
            params: Optional query parameters. ``format=json`` is always
                added; the caller's dict is left untouched.

        Returns:
            The decoded JSON payload.

        Raises:
            ConnectionError: If the request fails, the server answers with an
                HTTP error status, or the body is not valid JSON.
        """
        url = f"{self.BASE_URL}{endpoint}"
        # Copy first: adding 'format' must not leak into the caller's dict.
        query = dict(params) if params else {}
        query['format'] = 'json'

        try:
            response = self.session.get(url, params=query, timeout=30)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            raise ConnectionError(f"API request failed: {e}") from e

        time.sleep(self.delay)

        try:
            return response.json()
        except ValueError as e:
            # Every JSON decode error raised by requests derives from
            # ValueError; report it through the same exception type as
            # transport failures so callers have one thing to catch.
            raise ConnectionError(f"API returned an invalid JSON response: {e}") from e

    def get_cog_definition(self, cog_id: str) -> Optional[dict]:
        """Fetch the definition record of one COG from ``/cogdef/``.

        Args:
            cog_id: COG identifier, e.g. ``'COG3515'``.

        Returns:
            The first entry of the response's ``results`` list, or ``None``
            when that list is missing or empty.

        Raises:
            ConnectionError: Propagated from ``_make_request``.
        """
        data = self._make_request('/cogdef/', {'cog': cog_id})
        results = data.get('results') or []
        return results[0] if results else None

    def get_orthologous_genes(self, cog_id: str, verbose: bool = True) -> List[OrthologousGene]:
        """Collect every member gene of a COG by walking the paginated ``/cog/`` endpoint.

        Pages are requested starting at 1 until a page has no ``results`` or
        the response carries no ``next`` link.

        Args:
            cog_id: COG identifier, e.g. ``'COG3515'``.
            verbose: If True, print a progress line per page and a final count.

        Returns:
            One ``OrthologousGene`` per result entry, in API order. Missing or
            null fields fall back to ``''`` (tag, organism) or ``0.0`` (score).

        Raises:
            ConnectionError: Propagated from ``_make_request``.
        """
        # NOTE(review): assumes each entry carries 'gene_tag', 'bitscore' and a
        # nested 'organism' object with 'genome_name' - confirm against the
        # API documentation.
        orthologs = []
        page = 1

        while True:
            if verbose:
                # '\r' keeps the progress message on a single console line.
                print(f"      Fetching orthologs page {page} for {cog_id}...", end='\r')

            data = self._make_request('/cog/', {'cog': cog_id, 'page': page})
            results = data.get('results') or []
            if not results:
                break

            for entry in results:
                # ``.get(key, default)`` does not protect against an explicit
                # JSON null, so normalise None values by hand.
                organism = entry.get('organism') or {}
                bitscore = entry.get('bitscore')
                orthologs.append(OrthologousGene(
                    gene_tag=entry.get('gene_tag') or '',
                    organism_name=organism.get('genome_name') or '',
                    bitscore=0.0 if bitscore is None else bitscore,
                ))

            if not data.get('next'):
                break
            page += 1

        if verbose:
            # Trailing spaces overwrite what is left of the progress line.
            print(f"      Retrieved {len(orthologs)} orthologs for {cog_id}" + " " * 20)

        return orthologs

    def search_t6ss_cogs(self, verbose: bool = True) -> List[COGResult]:
        """Fetch definition and orthologs for every entry of ``T6SS_COG_MAPPING``.

        COGs whose definition cannot be found are skipped (with a warning when
        ``verbose``), so the returned list may be shorter than the mapping.

        Args:
            verbose: If True, print progress for each COG.

        Returns:
            One ``COGResult`` per COG that was found, in mapping order.

        Raises:
            ConnectionError: Propagated from ``_make_request``; a failure on
                any COG aborts the whole search.
        """
        results = []

        for gene_name, cog_id in T6SS_COG_MAPPING.items():
            if verbose:
                print(f"\n{'='*60}")
                print(f"Fetching {gene_name} ({cog_id})")
                print('='*60)

            cogdef = self.get_cog_definition(cog_id)
            if not cogdef:
                if verbose:
                    print(f"  WARNING: COG {cog_id} not found in database!")
                continue

            result = COGResult(
                gene_name=gene_name,
                cog_id=cog_id,
                cog_name=cogdef.get('name') or '',
                pdb_entries=cogdef.get('pdbs') or [],
            )

            if verbose:
                print(f"  COG Name: {result.cog_name[:60]}...")
                print(f"  PDB entries: {result.pdb_entries if result.pdb_entries else 'None'}")
                print("  Fetching orthologous genes...")

            result.orthologous_genes = self.get_orthologous_genes(cog_id, verbose=verbose)
            results.append(result)

        return results


# =============================================================================
# FILTERING AND OUTPUT FUNCTIONS
# =============================================================================

def get_species_with_all_genes(results: List[COGResult]) -> Dict[str, Dict[str, str]]:
    """Return the species that carry an ortholog for every gene in ``T6SS_GENES``.

    For each species and gene, only the gene tag of the highest-bitscore
    ortholog is kept; on equal scores the hit seen first wins.

    Orthologs without an organism name are ignored: they cannot be attributed
    to a species, and pooling them under ``''`` would fabricate a "species"
    able to pass the completeness filter. A bitscore of ``None`` ranks below
    every real score instead of breaking the comparison.

    Args:
        results: COG search results, one per T6SS gene that was found.

    Returns:
        ``{species: {gene_name: best_gene_tag}}`` restricted to species that
        have all genes listed in ``T6SS_GENES``.
    """
    # species -> gene name -> (bitscore, gene_tag) of the best hit seen so far
    best_hits: Dict[str, Dict[str, tuple]] = {}

    for result in results:
        for ortholog in result.orthologous_genes:
            species = ortholog.organism_name
            if not species:
                continue

            score = ortholog.bitscore
            if score is None:
                score = float('-inf')

            per_gene = best_hits.setdefault(species, {})
            current = per_gene.get(result.gene_name)
            # Strict '>' keeps the earliest hit on ties.
            if current is None or score > current[0]:
                per_gene[result.gene_name] = (score, ortholog.gene_tag)

    required_set = set(T6SS_GENES)
    return {
        species: {gene: tag for gene, (_, tag) in per_gene.items()}
        for species, per_gene in best_hits.items()
        if required_set.issubset(per_gene)
    }


def create_output_dataframe(results: List[COGResult], complete_species: Dict[str, Dict[str, str]]) -> Dict[str, List]:
    """Assemble the column-oriented output table (a dict of equal-length lists).

    The table has one row per gene in ``T6SS_GENES``. Its columns are
    ``'T6SS_gene'``, ``'PDB_entries'`` (identifiers joined with ``';'``),
    ``'COG_ID'`` and then one column per species, sorted by name, holding
    that species' gene tag. Genes without a search result still get a row:
    the COG ID comes from ``T6SS_COG_MAPPING`` and the other cells are empty.

    Args:
        results: COG search results.
        complete_species: ``{species: {gene_name: gene_tag}}`` for the species
            to include as columns.

    Returns:
        Mapping of column name to list of cell values.
    """
    ordered_species = sorted(complete_species)

    table = {'T6SS_gene': [], 'PDB_entries': [], 'COG_ID': []}
    table.update((name, []) for name in ordered_species)

    by_gene = {entry.gene_name: entry for entry in results}

    for gene in T6SS_GENES:
        hit = by_gene.get(gene)
        found = hit is not None

        table['T6SS_gene'].append(gene)

        if found and hit.pdb_entries:
            table['PDB_entries'].append(';'.join(hit.pdb_entries))
        else:
            table['PDB_entries'].append('')

        table['COG_ID'].append(hit.cog_id if found else T6SS_COG_MAPPING[gene])

        for name in ordered_species:
            # Genes that were never found leave every species cell blank.
            cell = complete_species.get(name, {}).get(gene, '') if found else ''
            table[name].append(cell)

    return table


def save_to_csv(data: Dict[str, List], filename: str = 't6ss_complete_species.csv'):
    """Write a column-oriented table to ``filename`` as UTF-8 CSV.

    The header starts with ``T6SS_gene``, ``PDB_entries`` and ``COG_ID``,
    followed by every other key of ``data`` in its existing order. One row is
    written per element of ``data['T6SS_gene']``. A confirmation line is
    printed once the file has been written.

    Args:
        data: Mapping of column name to list of cell values.
        filename: Destination path; an existing file is overwritten.
    """
    leading = ['T6SS_gene', 'PDB_entries', 'COG_ID']
    header = leading + [key for key in data if key not in leading]

    with open(filename, 'w', newline='', encoding='utf-8') as handle:
        out = csv.writer(handle)
        out.writerow(header)
        for row_index, _ in enumerate(data['T6SS_gene']):
            out.writerow([data[name][row_index] for name in header])

    print(f"\nSaved to {filename}")


# =============================================================================
# MAIN SEARCH FUNCTION
# =============================================================================

def run_t6ss_search(save_csv: bool = True, csv_filename: str = 't6ss_complete_species.csv', verbose: bool = True):
    """
    Run the whole pipeline: query the COG API, filter species, build the table.

    Args:
        save_csv: If True, write the table to CSV. Nothing is written when no
            species has the complete gene set.
        csv_filename: Output CSV filename.
        verbose: If True, print progress messages and a final summary.

    Returns:
        Tuple of (results, complete_species, data_dict)
    """
    rule = "=" * 70

    if verbose:
        print(rule)
        print("T6SS COG Database Search")
        print(rule)

    results = COGSearcher().search_t6ss_cogs(verbose=verbose)

    if verbose:
        print("\n" + rule)
        print("FILTERING SPECIES")
        print(rule)

    # Every organism seen in at least one COG, before completeness filtering.
    all_species = {
        member.organism_name
        for cog in results
        for member in cog.orthologous_genes
    }

    if verbose:
        print(f"Total unique species found across all COGs: {len(all_species)}")

    complete_species = get_species_with_all_genes(results)

    if verbose:
        print(f"Species with ALL {len(T6SS_GENES)} T6SS genes: {len(complete_species)}")

    data = create_output_dataframe(results, complete_species)

    if verbose:
        print("\n" + rule)
        print("SUMMARY")
        print(rule)
        print(f"\nCOGs found: {len(results)}")
        for cog in results:
            if cog.pdb_entries:
                pdb_str = f"PDB: {', '.join(cog.pdb_entries)}"
            else:
                pdb_str = "No PDB"
            print(f"  • {cog.gene_name} -> {cog.cog_id}: {cog.cog_name[:40]}... ({pdb_str})")
        print(f"\nSpecies with complete T6SS: {len(complete_species)}")

    # An empty species set would yield a table without species columns; skip it.
    if save_csv and complete_species:
        save_to_csv(data, csv_filename)

    return results, complete_species, data


# Entry point: run the full search with default settings when the module is
# executed directly. The three return values are bound at module level so they
# remain available for inspection afterwards.
if __name__ == "__main__":
    results, complete_species, data = run_t6ss_search()

T6SS COG Database Search

Fetching tssA (COG3515)
  COG Name: Type VI secretion system cap component TssA/VasJ/EvfE, conta...
  PDB entries: ['6RIU']
  Fetching orthologous genes...
      Retrieved 488 orthologs for COG3515                    

Fetching tssB (COG3516)
  COG Name: Type VI secretion system sheath component TssB/VipA/Hcp2...
  PDB entries: ['4UQZ']
  Fetching orthologous genes...
      Retrieved 489 orthologs for COG3516                    

Fetching tssC (COG3517)
  COG Name: Type VI secretion system sheath component TssC, TssC/VipB/Ev...
  PDB entries: ['5MYU']
  Fetching orthologous genes...
      Retrieved 570 orthologs for COG3517                    

Fetching tssD (COG3157)
  COG Name: Type VI secretion system tube protein TssD/Hcp...
  PDB entries: ['1Y12']
  Fetching orthologous genes...
      Retrieved 743 orthologs for COG3157                    

Fetching tssE (COG3518)
  COG Name: Type VI secretion system baseplate component TssE...
  PDB entries: None
  Fetch