In [None]:
class TaxonomyAnalyzer:
    """Fetch organism name and lineage for UniProt accessions and save them as CSV."""

    def __init__(self, max_retries: int = 3, retry_delay: int = 1, timeout: float = 30):
        """
        Store the retry/timeout settings and configure file logging.

        Args:
            max_retries: Number of attempts per protein (at least one is always made).
            retry_delay: Seconds to sleep between failed attempts.
            timeout: Seconds to wait for each HTTP request before giving up.
        """
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.timeout = timeout
        self.uniprot_base_url = "https://rest.uniprot.org/uniprotkb/"

        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            filename="Final_taxonomy_analysis.log"
        )

    def fetch_taxonomy_info(self, protein_ids: List[str], output_file: str) -> str:
        """
        Fetch taxonomy information for every protein ID and write it to a CSV file.

        Each row holds the protein ID, the organism's scientific name and the
        lineage joined with " > ". Proteins that cannot be fetched after all
        attempts are written as [protein_id, "Error", ""].

        Args:
            protein_ids: UniProt accessions to look up.
            output_file: Path of the CSV file to write.

        Returns:
            The path of the written CSV file (same as ``output_file``).
        """
        taxonomy_data = []
        error_counts = {"success": 0, "failed": 0}
        # Always try at least once so no protein is silently dropped from the output.
        attempts = max(1, self.max_retries)

        for protein_id in tqdm(protein_ids, desc="Fetching taxonomy info"):
            for attempt in range(attempts):
                try:
                    # Without a timeout a stalled connection would block the whole run.
                    response = requests.get(
                        f"{self.uniprot_base_url}{protein_id}.json",
                        timeout=self.timeout,
                    )
                    response.raise_for_status()
                    data = response.json()
                except (requests.exceptions.RequestException, ValueError) as e:
                    # ValueError covers JSON decoding errors that are not
                    # RequestException subclasses in older requests releases.
                    if attempt == attempts - 1:
                        logging.error(f"Failed to fetch {protein_id} after {attempts} attempts: {str(e)}")
                        taxonomy_data.append([protein_id, "Error", ""])
                        error_counts["failed"] += 1
                    else:
                        time.sleep(self.retry_delay)
                    continue

                taxonomy = data.get("organism") or {}
                scientific_name = taxonomy.get("scientificName", "N/A")
                lineage = taxonomy.get("lineage") or []
                taxonomy_data.append(
                    [protein_id, scientific_name, " > ".join(str(taxon) for taxon in lineage)]
                )
                error_counts["success"] += 1
                break

        # UTF-8 matches the default encoding pandas uses when reading the file back.
        with open(output_file, "w", newline="", encoding="utf-8") as file:
            writer = csv.writer(file)
            writer.writerow(["Protein ID", "Scientific Name", "Lineage"])
            writer.writerows(taxonomy_data)
        logging.info(
            f"Taxonomy lookup finished: {error_counts['success']} succeeded, "
            f"{error_counts['failed']} failed"
        )
        logging.info(f"Taxonomy data saved to {output_file}")

        return output_file

  

In [None]:
class PhyloTreeVisualizer:
    def create_newick_string(self, taxonomy_df):
        """Convert taxonomy data to Newick format"""
        lineage_counts = {}
        for _, row in taxonomy_df.iterrows():
            if isinstance(row['Lineage'], str):
                taxa = row['Lineage'].split(' > ')
                for i in range(len(taxa)):
                    lineage = ' > '.join(taxa[:i+1])
                    lineage_counts[lineage] = lineage_counts.get(lineage, 0) + 1

        def build_newick(taxa, parent=''):
            current = taxa[-1] if taxa else ''
            current_path = ' > '.join(taxa)
            count = lineage_counts.get(current_path, 1)
            
            children = []
            for lineage in lineage_counts.keys():
                if lineage.startswith(current_path + ' > '):
                    next_level = lineage.split(' > ')[len(taxa)]
                    if next_level not in children:
                        children.append(next_level)
            
            if children:
                child_strings = [build_newick(taxa + [child]) for child in children]
                return f"({','.join(child_strings)}){current}:{count}"
            else:
                return f"{current}:{count}"

        root_taxa = set()
        for lineage in lineage_counts.keys():
            root = lineage.split(' > ')[0]
            if root not in root_taxa:
                root_taxa.add(root)
        
        newick = f"({','.join(build_newick([taxa]) for taxa in root_taxa)});"
        return newick

    def create_phylogenetic_tree(self, taxonomy_file, output_file):
        """Create and save a phylogenetic tree visualization"""
        # Read taxonomy data
        df = pd.read_csv(taxonomy_file)
        
        # Create Newick string
        newick_str = self.create_newick_string(df)
        
        # Parse tree from Newick format
        handle = StringIO(newick_str)
        tree = Phylo.read(handle, "newick")

        # Check with ASCII representation
        print("ASCII representation of the tree:")
        Phylo.draw_ascii(tree)
        
        # Set up the plot with larger figure size and adjusted dimensions
        fig = plt.figure(figsize=(20, 30))  # Increased figure size
        
        # Draw the tree with customized parameters
        axes = fig.add_subplot(1, 1, 1)
        Phylo.draw(tree, 
                  axes=axes,
                  do_show=False,
                  branch_labels=lambda c: str(int(c.branch_length)) if c.branch_length else '')
        
        # Adjust the plot
        axes.set_title("Taxonomic Tree with Branch Lengths Showing Relative Abundance", pad=20, size=16)
        axes.set_xlabel("Relative Abundance", size=12)
        
        # Increase spacing between taxa
        axes.set_xticks(axes.get_xticks())
        axes.set_yticks(axes.get_yticks())
        
        # Adjust label sizes and spacing
        plt.setp(axes.get_xticklabels(), fontsize=10)
        plt.setp(axes.get_yticklabels(), fontsize=10, style='italic')
        
        # Adjust layout to prevent label cutoff
        plt.tight_layout()
        
        # Save the plot with high resolution
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close()
        
        return output_file





In [None]:
def visualize_phylogenetic_tree(taxonomy_file, output_file):
    """Render the taxonomy CSV as a tree image via PhyloTreeVisualizer and return its result."""
    return PhyloTreeVisualizer().create_phylogenetic_tree(taxonomy_file, output_file)

In [None]:
def main():
    """Fetch taxonomy data for ``family_sequences``, draw the tree, and report the output files."""
    # Step 1: taxonomy lookup, written to CSV
    taxonomy_csv = TaxonomyAnalyzer().fetch_taxonomy_info(
        family_sequences, "Final_taxonomy_info.csv"
    )

    # Step 2: tree image built from that CSV
    tree_image = visualize_phylogenetic_tree(taxonomy_csv, "Final_phylogenetic_tree.png")

    report_lines = (
        "\nFiles created:",
        f"Taxonomy data: {taxonomy_csv}",
        f"Phylogenetic tree: {tree_image}",
    )
    for report_line in report_lines:
        print(report_line)


if __name__ == "__main__":
    main()

In [None]:
import requests
import pandas as pd
import xml.etree.ElementTree as ET
from scipy.stats import fisher_exact
from wordcloud import WordCloud
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import random
import obonet
import networkx as nx
from statsmodels.stats.multitest import multipletests

# Step 1: Load Protein IDs
# TODO: This duplicates the ID loading already done above for the taxonomy task; here it is only wrapped in a function. Consider loading the IDs once for the whole notebook.
def load_protein_ids(psiblast_file, hmm_file, e_threshold=0.001):
    """
    Combine the protein IDs found by PSI-BLAST and by the HMM search.

    All 'uniprot_id' values of the PSI-BLAST CSV are kept; from the HMM CSV
    only rows whose 'E-value' is at most ``e_threshold`` contribute.

    Returns:
        list: The de-duplicated IDs (order is not defined).
    """
    psiblast_table = pd.read_csv(psiblast_file)
    hmm_table = pd.read_csv(hmm_file)

    # Boolean mask selecting the HMM hits that pass the E-value cut-off
    passes_cutoff = hmm_table['E-value'] <= e_threshold
    hmm_ids = set(hmm_table[passes_cutoff]['uniprot_id'])
    psiblast_ids = set(psiblast_table['uniprot_id'])

    combined_ids = psiblast_ids.union(hmm_ids)
    return list(combined_ids)


def fetch_go_annotations(protein_id, timeout=30):
    """
    Fetch and categorize GO annotations for a given protein ID from the UniProt API.

    The UniProt XML entry is downloaded and every ``dbReference`` of type 'GO'
    is sorted by the prefix of its 'term' property ('F:', 'P:' or 'C:').

    Args:
        protein_id (str): The UniProt ID of the protein
        timeout (float): Seconds to wait for the HTTP request.

    Returns:
        dict: ``{'categorized': {...}}`` where the inner dict maps
            'molecular_function', 'biological_process' and 'cellular_component'
            to lists of ``{'id': <GO id>, 'term': <term without prefix>}``.
            All three lists are empty when the request or XML parsing fails.
    """
    # Define the UniProt API URL for XML data
    url = f"https://rest.uniprot.org/uniprotkb/{protein_id}.xml"

    # Two-character term prefix -> result category
    prefix_to_category = {
        'F:': 'molecular_function',
        'P:': 'biological_process',
        'C:': 'cellular_component',
    }
    categorized_terms = {category: [] for category in prefix_to_category.values()}

    try:
        # A timeout keeps one stalled request from blocking the whole loop
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        root = ET.fromstring(response.content)
    except (requests.exceptions.RequestException, ET.ParseError) as e:
        # Best effort: report the problem and return the empty structure
        print(f"Error fetching GO annotations for {protein_id}: {e}")
        return {'categorized': categorized_terms}

    # Set up namespace for XML parsing
    namespaces = {'ns': 'http://uniprot.org/uniprot'}

    # Find all GO term references in the XML
    for db_ref in root.findall(".//ns:dbReference[@type='GO']", namespaces):
        go_id = db_ref.attrib.get('id')
        term = db_ref.find("ns:property[@type='term']", namespaces)
        if not go_id or term is None:
            continue

        term_value = term.attrib.get('value', '')
        category = prefix_to_category.get(term_value[:2])
        if category is not None:
            categorized_terms[category].append({
                'id': go_id,
                'term': term_value[2:],  # Remove the 'F:' / 'P:' / 'C:' prefix
            })

    return {'categorized': categorized_terms}

# -----------------------------------------------------------------------------
# Step 3: Fetch Random Proteins
# -----------------------------------------------------------------------------

def fetch_random_proteins(batch_size=100, total_proteins=500, timeout=(10, 300)):
    """
    Fetch a list of random reviewed UniProt protein IDs.

    The full list of reviewed accessions is downloaded, ``total_proteins`` of
    them are sampled without replacement, and the sample is split into batches.

    Args:
        batch_size (int): Maximum number of IDs per returned batch (must be >= 1).
        total_proteins (int): Number of IDs to sample (capped at the number available).
        timeout: Passed to ``requests.get``; a (connect, read) tuple in seconds.

    Returns:
        list[list[str]]: Batches of accessions; an empty list if the download fails.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Validate before starting the (large) download
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    url = "https://rest.uniprot.org/uniprotkb/stream?query=reviewed:true&format=list"
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"Error fetching random proteins: {e}")
        return []

    # Drop blank lines so that an empty string can never be sampled as an ID
    all_proteins = [line.strip() for line in response.text.splitlines() if line.strip()]
    selected_proteins = random.sample(all_proteins, min(total_proteins, len(all_proteins)))
    return [
        selected_proteins[start:start + batch_size]
        for start in range(0, len(selected_proteins), batch_size)
    ]

# -----------------------------------------------------------------------------
# Step 4: Load GO Ontology
# -----------------------------------------------------------------------------

def load_go_ontology():
    """Read go-basic.obo from the OBO Library URL with obonet and return the resulting graph."""
    obo_location = "http://purl.obolibrary.org/obo/go/go-basic.obo"
    return obonet.read_obo(obo_location)

# -----------------------------------------------------------------------------
# Step 5: Flatten Annotations for Analysis
# -----------------------------------------------------------------------------

def flatten_annotations(annotation_dict):
    """
    Flatten GO annotations into a list of GO term IDs.

    Two per-protein layouts are accepted:
      * a list of ``{"GO_ID": ..., "Term": ...}`` dicts, and
      * the ``{"categorized": {category: [{"id": ..., "term": ...}, ...]}}``
        dict returned by ``fetch_go_annotations``.

    Args:
        annotation_dict (dict): Maps protein ID -> annotations in either layout.

    Returns:
        list: GO IDs, one per annotation (duplicates are kept so that they can
        be counted); entries without an ID are skipped.
    """
    flat_terms = []
    for annotations in annotation_dict.values():
        if isinstance(annotations, dict):
            # Categorized layout: merge the per-category lists
            categorized = annotations.get("categorized", annotations)
            entries = [entry for bucket in categorized.values() for entry in bucket]
        else:
            entries = annotations

        for entry in entries:
            go_id = entry.get("GO_ID", entry.get("id"))
            if go_id is not None:
                flat_terms.append(go_id)
    return flat_terms

# -----------------------------------------------------------------------------
# Step 6: Enrichment Analysis
# -----------------------------------------------------------------------------

def calculate_enrichment(go_term, family_terms, background_terms):
    """
    Return the one-sided Fisher's exact test p-value for one GO term.

    The 2x2 table compares how often ``go_term`` occurs (first row) versus
    does not occur (second row) in the family list (first column) and in the
    background list (second column); the alternative is 'greater'.
    """
    hits_family = family_terms.count(go_term)
    hits_background = background_terms.count(go_term)

    with_term = [hits_family, hits_background]
    without_term = [
        len(family_terms) - hits_family,
        len(background_terms) - hits_background,
    ]

    _odds_ratio, p_value = fisher_exact([with_term, without_term], alternative='greater')
    return p_value

# -----------------------------------------------------------------------------
# Step 7: Visualize Enriched Terms
# -----------------------------------------------------------------------------

def plot_wordcloud(enrichment_results, annotations):
    """
    Generate and save a word cloud of enriched GO terms.

    Each enriched GO ID is shown under its term name, weighted by -log10(p).
    The image is written to "enriched_terms_wordcloud.png" and also displayed.

    Args:
        enrichment_results (dict): Maps GO ID -> p-value.
        annotations (dict): Maps protein ID -> annotations, either a list of
            ``{"GO_ID": ..., "Term": ...}`` dicts or the
            ``{"categorized": {...}}`` dict returned by ``fetch_go_annotations``.

    Returns:
        None. Nothing is drawn when no enriched GO ID has a known term name.
    """
    # Build the GO ID -> term name lookup, accepting both annotation layouts
    term_names = {}
    for protein_annotations in annotations.values():
        if isinstance(protein_annotations, dict):
            categorized = protein_annotations.get("categorized", protein_annotations)
            entries = [entry for bucket in categorized.values() for entry in bucket]
        else:
            entries = protein_annotations
        for entry in entries:
            go_id = entry.get("GO_ID", entry.get("id"))
            name = entry.get("Term", entry.get("term"))
            if go_id is not None and name is not None:
                term_names[go_id] = name

    # Clamp p-values so that p == 0 does not produce an infinite weight
    smallest_p = np.finfo(float).tiny
    enriched_with_names = {
        term_names[go_id]: -np.log10(max(p, smallest_p))
        for go_id, p in enrichment_results.items()
        if go_id in term_names
    }

    if not enriched_with_names:
        print("No enriched terms found. Word cloud will not be generated.")
        return

    wordcloud = WordCloud(
        width=1000, height=600, background_color="white", colormap="viridis"
    ).generate_from_frequencies(enriched_with_names)
    wordcloud.to_file("enriched_terms_wordcloud.png")
    print("Word cloud saved as enriched_terms_wordcloud.png")

    plt.figure(figsize=(12, 6))
    plt.imshow(wordcloud, interpolation="bilinear")
    plt.axis("off")
    plt.title("Word Cloud of Enriched GO Terms")
    plt.tight_layout()
    plt.show()

def plot_branch_enrichment(enrichment_results, go_graph):
    """
    Plot GO branch enrichment.

    Every enriched term contributes its p-value to each of its superterms
    ("branches"). Branches that collect at least 3 p-values are ranked by mean
    p-value and the 50 most significant (smallest mean) are drawn as a
    horizontal bar chart saved to "go_enrichment_branches.png".

    Args:
        enrichment_results (dict): Maps GO ID -> p-value.
        go_graph: GO ontology graph as loaded by ``load_go_ontology``.

    Returns:
        None. Nothing is drawn when no branch qualifies.
    """
    branch_scores = {}
    for go_id, p_value in enrichment_results.items():
        try:
            # NOTE(review): assumes obonet's child -> parent edge direction, in
            # which superterms are the networkx *descendants* — confirm.
            superterms = nx.descendants(go_graph, go_id)
        except nx.NetworkXError:
            # GO ID is not a node of the ontology graph
            continue
        for branch in superterms:
            branch_scores.setdefault(branch, []).append(p_value)

    significant_branches = {branch: np.mean(scores)
                            for branch, scores in branch_scores.items() if len(scores) >= 3}

    if not significant_branches:
        print("No GO branch has at least 3 enriched terms. Branch plot will not be generated.")
        return

    # Smallest mean p-value first: these are the most strongly enriched branches
    sorted_branches = sorted(significant_branches.items(), key=lambda x: x[1])[:50]
    branches, scores = zip(*sorted_branches)
    labels = [go_graph.nodes[branch].get('name', branch) for branch in branches]

    plt.figure(figsize=(14, 10))
    plt.barh(range(len(branches)), scores, color="steelblue")
    plt.yticks(range(len(branches)), labels, fontsize=8)
    plt.xlabel('Mean p-value')
    plt.title('Top 50 GO Branch Enrichments')
    # The legend must be added before saving, otherwise it is missing from the file
    plt.legend(["Top 50 Branch Enrichments"], loc="lower right")
    plt.tight_layout()
    plt.savefig("go_enrichment_branches.png", dpi=300, bbox_inches="tight")
    print("GO branch enrichment plot saved as go_enrichment_branches.png")

# -----------------------------------------------------------------------------
# Step 8: Write Summary and Results to File
# -----------------------------------------------------------------------------

def write_summary_and_results(enrichment_results, family_annotations, go_graph):
    """
    Write a summary and detailed results to "enrichment_results.txt".

    The file has three sections: a summary (number of enriched terms and the
    term with the smallest p-value), one line per enriched term, and the GO
    branches (superterms) that collect at least 3 enriched terms, sorted by
    mean p-value.

    Args:
        enrichment_results (dict): Maps GO ID -> p-value.
        family_annotations (dict): Maps protein ID -> annotations, either a
            list of ``{"GO_ID": ..., "Term": ...}`` dicts or the
            ``{"categorized": {...}}`` dict returned by ``fetch_go_annotations``.
        go_graph: GO ontology graph as loaded by ``load_go_ontology``.
    """
    # GO ID -> term name; the first name seen for an ID is kept
    term_names = {}
    for protein_annotations in family_annotations.values():
        if isinstance(protein_annotations, dict):
            categorized = protein_annotations.get("categorized", protein_annotations)
            entries = [entry for bucket in categorized.values() for entry in bucket]
        else:
            entries = protein_annotations
        for entry in entries:
            go_id = entry.get("GO_ID", entry.get("id"))
            name = entry.get("Term", entry.get("term"))
            if go_id is not None and name is not None:
                term_names.setdefault(go_id, name)

    with open("enrichment_results.txt", "w", encoding="utf-8") as f:
        # Write summary
        f.write("SUMMARY\n")
        f.write("========\n")
        f.write(f"Number of enriched GO terms: {len(enrichment_results)}\n")
        # The most enriched term is the one with the SMALLEST p-value
        top_term = min(enrichment_results, key=enrichment_results.get, default='None')
        f.write(f"Top enriched term: {top_term}\n")
        f.write("\n\n")

        # Write detailed results
        f.write("DETAILED RESULTS\n")
        f.write("================\n")
        for go_id, p_value in enrichment_results.items():
            term_name = term_names.get(go_id, go_id)
            f.write(f"{go_id}: {term_name} (p-value: {p_value:.2e})\n")

        # Write branch scores
        f.write("\n\nSIGNIFICANT GO BRANCHES\n")
        f.write("========================\n")
        branch_scores = {}
        for go_id, p_value in enrichment_results.items():
            try:
                # NOTE(review): assumes obonet's child -> parent edge direction, in
                # which superterms are the networkx *descendants* — confirm.
                superterms = nx.descendants(go_graph, go_id)
            except nx.NetworkXError:
                # GO ID is not a node of the ontology graph
                continue
            for branch in superterms:
                branch_scores.setdefault(branch, []).append(p_value)

        significant_branches = {branch: np.mean(scores)
                                for branch, scores in branch_scores.items() if len(scores) >= 3}
        for branch, score in sorted(significant_branches.items(), key=lambda x: x[1]):
            branch_name = go_graph.nodes[branch].get('name', branch)
            f.write(f"{branch}: {branch_name} (mean p-value: {score:.2e})\n")

# -----------------------------------------------------------------------------
# Main Script
# -----------------------------------------------------------------------------

def main():
    """
    Run the GO enrichment pipeline end to end.

    Steps: load the GO ontology, load the family protein IDs, fetch GO
    annotations for the family and for a random reviewed background sample,
    test every family GO term with a one-sided Fisher's exact test, apply
    Benjamini-Hochberg correction, then plot and write the significant terms.
    """
    # Local import: only this function needs Counter
    from collections import Counter

    print("Loading GO ontology...")
    go_graph = load_go_ontology()

    psiblast_file = "psiblast_parsed.csv"
    hmm_file = "hmmsearch_output.csv"
    protein_ids = load_protein_ids(psiblast_file, hmm_file)

    print("Fetching GO annotations...")
    family_annotations = {}
    for pid in tqdm(protein_ids, desc="Fetching GO annotations"):
        family_annotations[pid] = fetch_go_annotations(pid)
    print(f"Fetched GO annotations for {len(family_annotations)} family proteins")

    print("Fetching background annotations...")
    background_annotations = {}
    background_batches = fetch_random_proteins(batch_size=50, total_proteins=500)
    for batch in tqdm(background_batches, desc="Processing background proteins"):
        for pid in batch:
            background_annotations[pid] = fetch_go_annotations(pid)

    print("Calculating enrichment...")
    family_terms = flatten_annotations(family_annotations)
    background_terms = flatten_annotations(background_annotations)

    # Count every term once instead of calling list.count() for each term
    family_counts = Counter(family_terms)
    background_counts = Counter(background_terms)
    n_family = len(family_terms)
    n_background = len(background_terms)

    enrichment_results = {}
    pvalues = []
    terms = []

    for term, in_family in family_counts.items():
        in_background = background_counts[term]
        _, p_value = fisher_exact([
            [in_family, n_family - in_family],
            [in_background, n_background - in_background]
        ], alternative='greater')

        pvalues.append(p_value)
        terms.append(term)

    # Skip the correction when there is nothing to correct
    if pvalues:
        rejected, p_corrected, _, _ = multipletests(pvalues, method='fdr_bh')

        for term, p_value, significant in zip(terms, p_corrected, rejected):
            if significant:
                enrichment_results[term] = p_value

    if enrichment_results:
        print("Generating visualizations...")
        plot_wordcloud(enrichment_results, family_annotations)
        plot_branch_enrichment(enrichment_results, go_graph)
    else:
        print("No significantly enriched GO terms found. Skipping visualizations.")

    print("Writing results to file...")
    write_summary_and_results(enrichment_results, family_annotations, go_graph)


if __name__ == "__main__":
    main()

In [None]:
    '''
    # Proteins_to_GO terms for our family 
    print("Fetching GO annotations...")
    family_annotations = {}
    for pid in tqdm(protein_ids, desc="Fetching GO annotations"):
        family_annotations[pid] = fetch_go_annotations(pid)

    total_proteins_family = len(family_annotations)

    go_id_to_go_term = fetch_go_terms(protein_ids)

    
    # Proteins_to_GO terms for SwissProt
    # Note that we downloaded "uniprot_sprot.xml" locally but not onto the GitHub due to its size
    swissprot_annotations = parse_swissprot_go_terms("uniprot_sprot.xml", protein_ids) #go_counts_swissprot, num_proteins_swissprot
    
    total_proteins_swissprot = len(swissprot_annotations)
    


    # Now Map the GO terms to the proteins ; for the enrichment task, we need to know how many proteins have a certain GO term
    go_to_proteins_swissprot = reverse_protein_go_dict(swissprot_annotations)
    go_to_proteins_family = reverse_protein_go_dict(family_annotations)
   



    
    # Calculate GO enrichments for both with skipped proteins and without 
    _ = calculate_go_enrichment(go_to_proteins_family, go_to_proteins_swissprot, total_proteins_family, total_proteins_swissprot, go_id_to_go_term)
    
    '''

In [None]:
    
    '''
    def calculate_conservation_score(self, pos):
        """
        Calculate conservation score based on frequency of most common amino acid
        Ignores gaps in calculation
        """
        freqs = self.calculate_amino_acid_frequencies(pos)
        if not freqs:
            return 0
        return max(freqs.values())
    '''


    '''
    def calculate_group_conservation(self, pos):
        """
        Calculate conservation considering amino acid groups
        Basically the same as calculate_conservation_score, except that it is computed on amino acid groups rather than on single amino acids.
        """
        column = self.get_column(pos)
        groups = self.get_amino_acid_groups()
        
        # Assign each amino acid to its group
        aa_to_group = {}
        for group_name, aas in groups.items():
            for aa in aas:
                aa_to_group[aa] = group_name
        
        # Count group occurrences
        group_counts = Counter(aa_to_group.get(aa, 'other') 
                             for aa in column if aa != '-')
        
        if not group_counts:
            return 0
            
        return max(group_counts.values()) / sum(group_counts.values())
    '''

In [None]:
# FOR HMM: parse the hmmsearch text report into a CSV with one row per protein.

# File paths
input_file_path = "Model/Evaluation/Predictions/HMM-SEARCH/hmmsearch_output.txt"
output_file_path = "Model/Evaluation/Predictions/HMM-SEARCH/hmmsearch_output.csv"

# Initialize storage for parsed data
parsed_data = []

# Regular expressions to capture key information
header_regex = r">> ([^\s]+)"
domain_regex = (
    r"\s+(\d+) [!?]"           # domain number and significance flag
    r"\s+-?[\d\.]+"            # score (may be negative)
    r"\s+[\d\.]+"              # bias
    r"\s+[\de\.\+\-]+"         # first E-value column (skipped)
    r"\s+([\de\.\+\-]+)"       # second E-value column (captured)
    r"\s+\d+\s+\d+"            # HMM from / to
    r"\s+[\[\.][\.\]]"         # HMM boundary flag: '..', '[.', '.]' or '[]'
    r"\s+(\d+)\s+(\d+)"        # alignment start / end on the sequence (captured)
)


def _domain_columns(index):
    """Return the four CSV column names used for the ``index``-th domain (1-based)."""
    if index == 1:
        # The first domain keeps the unprefixed names that downstream code reads
        return ["E-value", "domain_start", "domain_end", "domain_length"]
    return [
        f"domain_{index}_E-value",
        f"domain_{index}_start",
        f"domain_{index}_end",
        f"domain_{index}_length",
    ]


with open(input_file_path, "r") as infile:
    current_protein = None

    for line in infile:
        # Match protein header line
        header_match = re.match(header_regex, line)
        if header_match:
            # If we already captured a protein, save its data
            if current_protein:
                parsed_data.append(current_protein)

            # Start a new protein record; identifiers look like "db|ACCESSION|NAME"
            protein_id = header_match.group(1)
            id_parts = protein_id.split("|")
            if len(id_parts) >= 3:
                uniprot_id, protein_name = id_parts[1], id_parts[2]
            else:
                # Unexpected identifier format: keep the raw ID instead of crashing
                uniprot_id = protein_name = protein_id
            current_protein = {
                "protein_name": protein_name,
                "uniprot_id": uniprot_id,
                "domains": []
            }
            continue

        # Match domain annotation (including both `!` and `?` lines)
        domain_match = re.match(domain_regex, line)
        if domain_match and current_protein:
            _, evalue, start, end = domain_match.groups()
            start, end, evalue = int(start), int(end), float(evalue)
            length = end - start + 1
            current_protein["domains"].append((evalue, start, end, length))

    # Handle the last protein record
    if current_protein:
        parsed_data.append(current_protein)

# Prepare fieldnames dynamically: one column group per domain of the protein
# with the most domains (no domain columns when nothing was parsed)
fieldnames = ["protein_name", "uniprot_id"]
max_domains = max((len(protein["domains"]) for protein in parsed_data), default=0)
for i in range(1, max_domains + 1):
    fieldnames.extend(_domain_columns(i))

# Write to CSV
with open(output_file_path, "w", newline="") as outfile:
    writer = csv.DictWriter(outfile, fieldnames=fieldnames)
    writer.writeheader()
    for protein in parsed_data:
        row = {
            "protein_name": protein["protein_name"],
            "uniprot_id": protein["uniprot_id"]
        }
        for i, domain in enumerate(protein["domains"], start=1):
            # Domain tuples are ordered exactly like the column names
            row.update(zip(_domain_columns(i), domain))
        writer.writerow(row)

