# Biological Data Project

Group members:

- Alberto Calabrese

- Marlon Helbing

- Lorenzo Baietti

"A protein domain is a conserved part of a given protein sequence and tertiary structure that can evolve, function, and exist independently of the rest of the protein chain. Each domain forms a compact three-dimensional structure and often can be independently stable and folded." (Wikipedia).

The project is about the characterization of a single domain. Each group is provided with a representative domain sequence and the corresponding Pfam identifier (see table below). The objective of the project is to build a sequence model starting from the assigned sequence and to provide a functional characterization of the entire domain family (homologous proteins).

## Input
A representative sequence of the domain family. Columns are: group, UniProt accession, organism, Pfam identifier, Pfam name, domain position in the corresponding UniProt protein, domain sequence.

```
UniProt : P54315 
PfamID : PF00151 
Domain Position : 18-353 
Organism : Homo sapiens (Human) 
Pfam Name : Lipase/vitellogenin 
Domain Sequence : KEVCYEDLGCFSDTEPWGGTAIRPLKILPWSPEKIGTRFLLYTNENPNNFQILLLSDPSTIEASNFQMDRKTRFIIHGFIDKGDESWVTDMCKKLFEVEEVNCICVDWKKGSQATYTQAANNVRVVGAQVAQMLDILLTEYSYPPSKVHLIGHSLGAHVAGEAGSKTPGLSRITGLDPVEASFESTPEEVRLDPSDADFVDVIHTDAAPLIPFLGFGTNQQMGHLDFFPNGGESMPGCKKNALSQIVDLDGIWAGTRDFVACNHLRSYKYYLESILNPDGFAAYPCTSYKSFESDKCFPCPDQGCPQMGHYADKFAGRTSEEQQKFFLNTGEASNF
```

## Domain model definition
The objective of the first part of the project is to build a PSSM and an HMM model representing the assigned domain. Both models will be generated starting from the assigned input sequence. The accuracy of the models will be evaluated against the Pfam annotations provided in the SwissProt database.

In [1]:
from Bio import AlignIO
from collections import Counter
import pandas as pd
from scipy.stats import entropy
import math
from Bio.Align import MultipleSeqAlignment
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq
from Bio import SeqIO
import sys

In [None]:
class ConservationAnalyzer:
    """
    Per-column conservation metrics for a multiple sequence alignment.

    The alignment is read once from a FASTA file. '-' is treated as the gap
    symbol; gaps are reported through the gap frequency and are left out of
    every residue-based statistic.
    """

    def __init__(self, alignment_file):
        """
        Initialize with an alignment file
            alignment_file (str): Path to the alignment file (FASTA format)
        """
        self.alignment = AlignIO.read(alignment_file, 'fasta')
        self.num_sequences = len(self.alignment)
        self.alignment_length = self.alignment.get_alignment_length()

    # ------------------------------------------------------------------
    # Column-level helpers: they work on an already extracted column so that
    # analyze_columns() only has to walk the alignment once per position.
    # ------------------------------------------------------------------
    @staticmethod
    def _gap_frequency(column):
        """Fraction of entries in `column` that are gaps."""
        return column.count('-') / len(column)

    @staticmethod
    def _residue_frequencies(column):
        """Relative frequency of each residue in `column`, gaps excluded ({} if only gaps)."""
        residues = [aa for aa in column if aa != '-']
        if not residues:
            return {}
        total = len(residues)
        return {aa: count / total for aa, count in Counter(residues).items()}

    @staticmethod
    def _conservation_score(freqs):
        """Frequency of the most common residue, or 0 when there are no residues."""
        return max(freqs.values()) if freqs else 0

    @staticmethod
    def _entropy(freqs):
        """Shannon entropy (bits) of a frequency table, or inf when the table is empty."""
        if not freqs:
            return float('inf')
        return -sum(p * math.log2(p) for p in freqs.values())

    def _group_lookup(self):
        """
        Residue -> group-name mapping derived from get_amino_acid_groups().

        Built on first use and cached on the instance, because the grouping is
        the same for every column.
        """
        lookup = getattr(self, '_aa_to_group', None)
        if lookup is None:
            lookup = {aa: group_name
                      for group_name, aas in self.get_amino_acid_groups().items()
                      for aa in aas}
            self._aa_to_group = lookup
        return lookup

    def _group_conservation(self, column):
        """Share of the most common residue group among the non-gap entries of `column`."""
        lookup = self._group_lookup()
        # Residues that belong to no defined group are pooled under 'other'
        group_counts = Counter(lookup.get(aa, 'other') for aa in column if aa != '-')
        if not group_counts:
            return 0
        return max(group_counts.values()) / sum(group_counts.values())

    # ------------------------------------------------------------------
    # Public per-position API
    # ------------------------------------------------------------------
    def get_column(self, pos):
        """Extract a column (0-based position) from the alignment as a list of characters"""
        return [record.seq[pos] for record in self.alignment]

    def calculate_gap_frequency(self, pos):
        """Calculate frequency of gaps in a column"""
        return self._gap_frequency(self.get_column(pos))

    def calculate_amino_acid_frequencies(self, pos):
        """
        Calculate frequencies of each amino acid in a column
        Gaps are not counted, so they do not distort the conservation scores;
        a column made only of gaps yields an empty dict
        """
        return self._residue_frequencies(self.get_column(pos))

    def calculate_conservation_score(self, pos):
        """
        Calculate conservation score based on frequency of most common amino acid
        Ignores gaps in calculation; returns 0 for a column made only of gaps
        """
        return self._conservation_score(self.calculate_amino_acid_frequencies(pos))

    def calculate_entropy(self, pos):
        """
        Calculate Shannon entropy (in bits) for a column
        Lower entropy means higher conservation; a column made only of gaps
        returns infinity
        """
        return self._entropy(self.calculate_amino_acid_frequencies(pos))

    def get_amino_acid_groups(self):
        """Define groups of similar amino acids 
           Based on : https://en.wikipedia.org/wiki/Conservative_replacement#:~:text=There%20are%2020%20naturally%20occurring,both%20small%2C%20negatively%20charged%20residues.
        """
        return {
            'aliphatic': set('GAVLI'),
            'hydroxyl': set('SCUTM'),
            'cyclic': set('P'),
            'aromatic': set('FYW'),
            'basic': set('HKR'),
            'acidic': set('DENQ')
        }

    def calculate_group_conservation(self, pos):
        """
        Calculate conservation considering amino acid groups
        Basically the same as calculate_conservation_score, just that it calculates based on the groups, not single amino acids !
        """
        return self._group_conservation(self.get_column(pos))

    # Disabled pairwise-similarity helpers, kept for reference only (see the TODO inside).
    """
    def find_similar_sequences(self, similarity_threshold):
        # TODO : I think using JalView for this is better : JalView --> Edit --> Remove Redundancy 
        similar_pairs = []
        
        for i in range(len(self.alignment)):
            for j in range(i + 1, len(self.alignment)):
                seq1 = str(self.alignment[i].seq)
                seq2 = str(self.alignment[j].seq)
                
                # Calculate similarity (ignoring gaps)
                matches = sum(a == b for a, b in zip(seq1, seq2) if a != '-' and b != '-')
                total = sum(1 for a, b in zip(seq1, seq2) if a != '-' and b != '-')
                
                if total > 0:
                    similarity = matches / total
                    if similarity >= similarity_threshold:
                        similar_pairs.append((
                            self.alignment[i].id,
                            self.alignment[j].id,
                            similarity
                        ))
    
        return similar_pairs


    def analyze_rows(self, similarity_threshold = 0.95):
        similar_pairs = self.find_similar_sequences(similarity_threshold)
        print(f"We have {len(similar_pairs)} many pairs with {similarity_threshold} or more identity (excluding gaps) of a total of {self.num_sequences} sequences")
    """

    # TODO : I took very strict values now such that the number of residues per sequence is below 100 (right now we have length 77) ; the PSSM creation with 
    # much higher length did not work, but maybe we should write an email and ask ; nevertheless, we can first try some evaluation based on that PSSM and see our scores
    def analyze_columns(self, gap_threshold=0.37, conservation_threshold=0.9):
        """
        Analyze all columns and return comprehensive metrics

        Args:
            gap_threshold (float): Columns with a gap frequency above this value are flagged
            conservation_threshold (float): Columns with a group conservation below this value are flagged

        Returns:
            DataFrame with one row per column (1-based 'position') holding
            'gap_frequency', 'single_conservation', 'entropy',
            'group_conservation' and the boolean 'suggested_remove'
        """
        data = []

        for i in range(self.alignment_length):
            # Extract the column once and derive every metric from it
            column = self.get_column(i)
            freqs = self._residue_frequencies(column)
            gap_freq = self._gap_frequency(column)
            group_cons = self._group_conservation(column)

            data.append({
                'position': i + 1,
                'gap_frequency': gap_freq,
                'single_conservation': self._conservation_score(freqs),
                'entropy': self._entropy(freqs),
                'group_conservation': group_cons,
                # Here we should look possibly for better ideas
                # A column is flagged when it is mostly gaps or when its residues are
                # spread over too many groups. Groups are used instead of single
                # amino acids because a per-residue criterion would remove more columns.
                'suggested_remove': (gap_freq > gap_threshold or
                                     group_cons < conservation_threshold)
            })

        return pd.DataFrame(data)

In [None]:
def remove_columns_from_alignment(input_file, output_file, columns_to_remove, format="fasta"):
    """
    Remove specified columns from a multiple sequence alignment and save to new file

    Args:
        input_file (str): Path to input alignment file
        output_file (str): Path where to save trimmed alignment
        columns_to_remove (list): Column indices to remove (0-based); the order
            does not matter and repeated indices are removed only once
        format (str): File format (default: "fasta")

    Returns:
        The trimmed alignment (it is also written to output_file)

    Raises:
        IndexError: if an index lies outside 0 .. alignment length - 1
    """
    # Read the alignment
    alignment = AlignIO.read(input_file, format)
    alignment_length = alignment.get_alignment_length()

    # Work with a set: deleting positions one by one would remove an extra,
    # unrelated column whenever an index is listed twice.
    drop = set(columns_to_remove)
    invalid = sorted(col for col in drop if not 0 <= col < alignment_length)
    if invalid:
        raise IndexError(
            f"Column indices out of range for an alignment of length {alignment_length}: {invalid}"
        )

    # Positions that survive, computed once and reused for every sequence
    keep = [col for col in range(alignment_length) if col not in drop]

    new_records = []
    for record in alignment:
        residues = str(record.seq)
        trimmed_seq = Seq(''.join(residues[col] for col in keep))
        # Only id, name and description are carried over to the new record
        new_records.append(SeqRecord(trimmed_seq,
                                     id=record.id,
                                     name=record.name,
                                     description=record.description))

    # Create new alignment
    # TODO : Maybe we have to add some variables here (i.e. how to do the MSA)!
    new_alignment = MultipleSeqAlignment(new_records)

    # Write to file
    AlignIO.write(new_alignment, output_file, format)

    return new_alignment

In [None]:
# Example usage:
if __name__ == "__main__":
    # Alignment that is analysed and then trimmed
    alignment_path = "clustal_rows_removed_100threshold.fa"

    # Initialize analyzer
    analyzer = ConservationAnalyzer(alignment_path)

    # Get comprehensive analysis
    analysis = analyzer.analyze_columns()
    # analysis_2 = analyzer.analyze_rows()

    # Print summary statistics
    print("\nAlignment Summary:")
    print(f"Number of sequences: {analyzer.num_sequences}")
    print(f"Alignment length: {analyzer.alignment_length}")

    # Number of columns flagged / not flagged for removal.
    # Counting via sum() works even when one of the two classes is absent,
    # where indexing value_counts() with True/False would raise a KeyError.
    counts = analysis['suggested_remove'].value_counts()
    counts_true = int(analysis['suggested_remove'].sum())   # To be removed
    counts_false = len(analysis) - counts_true              # To be kept
    total_columns = counts_true + counts_false
    removed_fraction = counts_true / total_columns if total_columns else 0.0

    print(f"With the current removal tactic, we would remove {removed_fraction:.2%} of columns ; we keep {counts_false} of {total_columns} columns")

    # Save detailed analysis to CSV
    analysis.to_csv("conservation_analysis.csv", index=False)

    # 'position' is 1-based, the removal function expects 0-based indices
    columns_to_remove = [
        int(position) - 1
        for position in analysis.loc[analysis['suggested_remove'], 'position']
    ]

    # Remove columns and save new alignment
    new_alignment = remove_columns_from_alignment(
        alignment_path,
        "trimmed_alignment.fasta",
        columns_to_remove
    )

    print(f"Original alignment length: {analyzer.alignment_length}")
    print(f"Number of columns removed: {len(columns_to_remove)}")
    print(f"New alignment length: {new_alignment.get_alignment_length()}")

## Models building

1. Retrieve homologous proteins starting from your input sequence performing a BLAST search
against UniProt or UniRef50 or UniRef90, or any other database

2. Generate a multiple sequence alignment (MSA) starting from retrieved hits using T-coffee or
ClustalOmega or MUSCLE

3. If necessary, edit the MSA with JalView (or with your custom script or CD-HIT) to remove
non-conserved positions (columns) and/or redundant information (rows)

4. Build a PSSM model starting from the MSA

5. Build a HMM model starting from the MSA

## Models evaluation
1. Generate predictions. Run HMM-SEARCH and PSI-BLAST with your models against
SwissProt.

    - Collect the list of retrieved hits

    - Collect matching positions of your models in the retrieved hits

2. Define your ground truth. Find all proteins in SwissProt annotated (and not annotated) with the assigned Pfam domain

    - Collect the list of proteins matching the assigned Pfam domain

    - Collect matching positions of the Pfam domain in the retrieved sequences. Domain positions are available here (large tsv file) or using the InterPro API or align the Pfam domain yourself against SwissProt (HMMSEARCH)

3. Compare your model with the assigned Pfam. Calculate the precision, recall, F-score, balanced accuracy, MCC

    - Comparison at the protein level. Measure the ability of your model to retrieve the same proteins matched by Pfam

    - Comparison at the residue level. Measure the ability of your model to match the same position matched by Pfam

4. Consider refining your models to improve their performance

## Domain family characterization
Once the family model is defined (previous step), you will look at functional (and structural) aspects/properties of the entire protein family. The objective is to provide insights about the main function of the family.

### Taxonomy

1. Collect the taxonomic lineage (tree branch) for each protein of the family_sequences dataset
from UniProt (entity/organism/lineage in the UniProt XML)

2. Plot the taxonomic tree of the family with nodes size proportional to their relative abundance 


In [2]:
# MARLON EDIT 
import pandas as pd 

# "family_sequences" is the union of the proteins found by HMMSEARCH (E-value <= 0.001)
# and of all proteins found by PSI-BLAST
e_threshold = 0.001
psiblast_df = pd.read_csv("psiblast_parsed.csv")
hmm_df = pd.read_csv("hmmsearch_output.csv")

# Keep only the HMM hits that pass the E-value cut-off
filtered_hmm_proteins = hmm_df.loc[hmm_df['E-value'] <= e_threshold, 'uniprot_id']

# Sets remove accessions reported more than once
psiblast_proteins = set(psiblast_df['uniprot_id'])
hmm_proteins = set(filtered_hmm_proteins)

family_sequences = list(psiblast_proteins | hmm_proteins)





In [13]:
# MARLON EDIT 
import requests
import csv
import pandas as pd
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import pdist
from collections import Counter
import numpy as np
from typing import List, Dict, Tuple
import time
from tqdm import tqdm
import logging
import os
from datetime import datetime


class TaxonomyAnalyzer:
    """Download organism name and taxonomic lineage for UniProt entries and store them as CSV."""

    def __init__(self, max_retries: int = 3, retry_delay: int = 1, timeout: float = 30):
        """
        Args:
            max_retries: Number of attempts made for each protein
            retry_delay: Seconds to wait between two attempts
            timeout: Seconds after which a single HTTP request is abandoned
        """
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.timeout = timeout
        self.uniprot_base_url = "https://rest.uniprot.org/uniprotkb/"

        # Progress and errors are written to a log file instead of the notebook output
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            filename="Final_taxonomy_analysis.log"
        )

    def fetch_taxonomy_info(self, protein_ids: List[str], output_file: str) -> str:
        """
        Fetch taxonomy information for every protein and write it to a CSV file.

        The CSV has the columns "Protein ID", "Scientific Name" and "Lineage"
        (taxa joined with " > "). A protein that could not be fetched after
        max_retries attempts is written as [id, "Error", ""].

        Args:
            protein_ids: UniProt accessions to look up
            output_file: Path of the CSV file to create

        Returns:
            output_file, unchanged
        """
        taxonomy_data = []
        error_counts = {"success": 0, "failed": 0}

        for protein_id in tqdm(protein_ids, desc="Fetching taxonomy info"):
            for attempt in range(self.max_retries):
                try:
                    # Without a timeout a stalled connection would block the whole run
                    response = requests.get(
                        f"{self.uniprot_base_url}{protein_id}.json",
                        timeout=self.timeout
                    )
                    response.raise_for_status()
                    data = response.json()

                    taxonomy = data.get("organism", {})
                    scientific_name = taxonomy.get("scientificName", "N/A")
                    lineage = taxonomy.get("lineage", [])
                    taxonomy_data.append([protein_id, scientific_name, " > ".join(lineage)])

                    error_counts["success"] += 1
                    break

                # ValueError covers a body that is not valid JSON
                except (requests.exceptions.RequestException, ValueError) as e:
                    if attempt == self.max_retries - 1:
                        logging.error(f"Failed to fetch {protein_id} after {self.max_retries} attempts: {str(e)}")
                        taxonomy_data.append([protein_id, "Error", ""])
                        error_counts["failed"] += 1
                    else:
                        time.sleep(self.retry_delay)

        # Explicit UTF-8 so the file can be read back regardless of the platform default
        with open(output_file, "w", newline="", encoding="utf-8") as file:
            writer = csv.writer(file)
            writer.writerow(["Protein ID", "Scientific Name", "Lineage"])
            writer.writerows(taxonomy_data)
        logging.info(f"Taxonomy data saved to {output_file}")
        logging.info(
            f"Taxonomy fetch summary: {error_counts['success']} succeeded, {error_counts['failed']} failed"
        )

        return output_file

  


In [25]:
#MARLON EDIT
from Bio import Phylo
import matplotlib.pyplot as plt
import pandas as pd
from io import StringIO

class PhyloTreeVisualizer:
    """Turn a table of taxonomic lineages into a Newick tree and draw it."""

    # Characters that may not appear in an unquoted Newick label (whitespace is checked separately)
    _NEWICK_RESERVED = "()[]':;,"

    @classmethod
    def _newick_label(cls, name):
        """Return `name` unchanged when it is a plain label, otherwise single-quoted."""
        if any(ch.isspace() or ch in cls._NEWICK_RESERVED for ch in name):
            # Newick convention: quote the label and double embedded single quotes
            return "'" + name.replace("'", "''") + "'"
        return name

    def create_newick_string(self, taxonomy_df):
        """
        Convert taxonomy data to Newick format

        Args:
            taxonomy_df: DataFrame with a 'Lineage' column whose values are taxa
                joined with ' > ' (root first). Values that are not strings
                (e.g. NaN for failed look-ups) are skipped.

        Returns:
            Newick string in which every node is written as `name:count`, where
            count is the number of rows whose lineage passes through that node.
            Roots and children appear in order of first occurrence in the table.
        """
        counts = {}     # lineage prefix (tuple of taxa) -> number of rows containing it
        children = {}   # lineage prefix -> child prefixes, in order of first appearance
        roots = []      # prefixes of length 1, in order of first appearance

        for lineage in taxonomy_df['Lineage']:
            if not isinstance(lineage, str):
                continue
            taxa = lineage.split(' > ')
            for depth in range(1, len(taxa) + 1):
                path = tuple(taxa[:depth])
                if path not in counts:
                    counts[path] = 0
                    children[path] = []
                    if depth == 1:
                        roots.append(path)
                    else:
                        # The parent prefix was registered in the previous iteration
                        children[path[:-1]].append(path)
                counts[path] += 1

        def build_newick(path):
            node = f"{self._newick_label(path[-1])}:{counts[path]}"
            if children[path]:
                subtree = ','.join(build_newick(child) for child in children[path])
                return f"({subtree}){node}"
            return node

        return f"({','.join(build_newick(root) for root in roots)});"

    def create_phylogenetic_tree(self, taxonomy_file, output_file):
        """
        Create and save a phylogenetic tree visualization

        Args:
            taxonomy_file: CSV file with a 'Lineage' column
            output_file: Path of the image to write

        Returns:
            output_file, unchanged
        """
        # Read taxonomy data
        df = pd.read_csv(taxonomy_file)

        # Build the Newick string and parse it into a tree
        newick_str = self.create_newick_string(df)
        tree = Phylo.read(StringIO(newick_str), "newick")

        # Check with ASCII representation
        print("ASCII representation of the tree:")
        Phylo.draw_ascii(tree)

        # Large figure so that the tree fits
        fig = plt.figure(figsize=(20, 30))
        try:
            axes = fig.add_subplot(1, 1, 1)
            # Branch labels show the count stored as branch length (nothing for 0/None)
            Phylo.draw(tree,
                       axes=axes,
                       do_show=False,
                       branch_labels=lambda c: str(int(c.branch_length)) if c.branch_length else '')

            # Titles
            axes.set_title("Taxonomic Tree with Branch Lengths Showing Relative Abundance", pad=20, size=16)
            axes.set_xlabel("Relative Abundance", size=12)

            # Re-apply the current tick positions before styling the tick labels
            axes.set_xticks(axes.get_xticks())
            axes.set_yticks(axes.get_yticks())

            # Adjust label sizes
            plt.setp(axes.get_xticklabels(), fontsize=10)
            plt.setp(axes.get_yticklabels(), fontsize=10, style='italic')

            # Adjust layout to prevent label cutoff, then save with high resolution
            fig.tight_layout()
            fig.savefig(output_file, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even when drawing or saving fails
            plt.close(fig)

        return output_file





In [22]:
#MARLON EDIT
def visualize_phylogenetic_tree(taxonomy_file, output_file):
    """Draw the tree for `taxonomy_file` with a PhyloTreeVisualizer and return its result."""
    return PhyloTreeVisualizer().create_phylogenetic_tree(taxonomy_file, output_file)

In [26]:
#MARLON EDIT
def main():
    """Fetch the taxonomy of `family_sequences`, draw the tree and report the created files."""
    # Step 1: download the lineages and store them as CSV
    taxonomy_file = TaxonomyAnalyzer().fetch_taxonomy_info(
        family_sequences, "Final_taxonomy_info.csv"
    )

    # Step 2: render the tree from that CSV
    tree_file = visualize_phylogenetic_tree(
        taxonomy_file, "Final_phylogenetic_tree.png"
    )

    # Step 3: tell the user where the results are
    print("\nFiles created:")
    print(f"Taxonomy data: {taxonomy_file}")
    print(f"Phylogenetic tree: {tree_file}")


if __name__ == "__main__":
    main()

Fetching taxonomy info: 100%|██████████| 57/57 [01:23<00:00,  1.47s/it]



Files created:
Taxonomy data: Final_taxonomy_info.csv
Phylogenetic tree: Final_phylogenetic_tree.png


In [6]:
import requests
import pandas as pd
from ete3 import Tree, TreeStyle, NodeStyle, TextFace
from tqdm import tqdm
import time

# TaxonomyAnalyzer Class for fetching taxonomy information
class TaxonomyAnalyzer:
    """Download organism name and taxonomic lineage for UniProt entries into a DataFrame."""

    def __init__(self, max_retries: int = 3, retry_delay: int = 1, timeout: float = 30):
        """
        Args:
            max_retries: Number of attempts made for each protein
            retry_delay: Seconds to wait between two attempts
            timeout: Seconds after which a single HTTP request is abandoned
        """
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.timeout = timeout
        self.uniprot_base_url = "https://rest.uniprot.org/uniprotkb/"

    def fetch_taxonomy_info(self, protein_ids: list, output_file: str):
        """
        Fetch the taxonomy of every protein, save it as CSV and return it.

        Args:
            protein_ids: UniProt accessions to look up
            output_file: Path of the CSV file to create

        Returns:
            DataFrame with the columns "Protein ID", "Scientific Name" and
            "Lineage" (taxa joined with " > "). A protein that could not be
            fetched after max_retries attempts has "Error" as name and an
            empty lineage.
        """
        taxonomy_data = []

        pbar = tqdm(protein_ids, desc="Fetching taxonomy data")

        for protein_id in pbar:
            pbar.set_description(f"Processing {protein_id}")

            for attempt in range(self.max_retries):
                try:
                    # Without a timeout a stalled connection would block the whole run
                    response = requests.get(
                        f"{self.uniprot_base_url}{protein_id}.json",
                        timeout=self.timeout
                    )
                    response.raise_for_status()
                    data = response.json()

                    taxonomy = data.get("organism", {})
                    scientific_name = taxonomy.get("scientificName", "N/A")
                    lineage = taxonomy.get("lineage", [])

                    taxonomy_data.append([protein_id, scientific_name, " > ".join(lineage)])
                    break

                # ValueError covers a body that is not valid JSON
                except (requests.exceptions.RequestException, ValueError) as e:
                    print(f"Error fetching data for {protein_id}: {e}")
                    if attempt == self.max_retries - 1:
                        # Give up on this protein but keep a placeholder row
                        taxonomy_data.append([protein_id, "Error", ""])
                    else:
                        time.sleep(self.retry_delay)

        taxonomy_df = pd.DataFrame(taxonomy_data, columns=["Protein ID", "Scientific Name", "Lineage"])
        taxonomy_df.to_csv(output_file, index=False)
        return taxonomy_df

# Load protein IDs from files
def load_protein_ids(psiblast_file, hmm_file, e_threshold=0.001):
    """
    Return the union of all PSI-BLAST hits and the HMM hits with E-value <= e_threshold.

    Both CSV files need a 'uniprot_id' column; the HMM file also needs an
    'E-value' column. The result is a list without duplicates, in no particular order.
    """
    psiblast_hits = pd.read_csv(psiblast_file)
    hmm_hits = pd.read_csv(hmm_file)

    passes_cutoff = hmm_hits['E-value'] <= e_threshold
    significant_hmm_ids = hmm_hits.loc[passes_cutoff, 'uniprot_id']

    return list(set(psiblast_hits['uniprot_id']) | set(significant_hmm_ids))

# Process taxonomy data
def process_taxonomy(data, correct_column_name):
    """
    Build a nested taxonomy tree and per-node counts from lineage strings.

    Args:
        data: DataFrame holding one lineage per row
        correct_column_name: Name of the column with the lineage strings
            (taxa joined with " > ", root first)

    Returns:
        (taxonomy_dict, frequency_counts) where taxonomy_dict maps each taxon
        to the dict of its children, and frequency_counts maps the full path
        of a node ("root > ... > taxon") to the number of rows passing through it.

    Rows without a usable lineage (NaN, or an empty string as stored for
    failed look-ups) are skipped, so they neither raise nor add an unnamed taxon.
    """
    taxonomy_dict = {}
    frequency_counts = {}

    for lineage_text in data[correct_column_name]:
        if not isinstance(lineage_text, str) or not lineage_text.strip():
            continue

        current = taxonomy_dict
        # Counting by full path keeps equally named taxa at different levels apart
        current_path = []

        for level in lineage_text.split(" > "):
            current_path.append(level)
            path_key = " > ".join(current_path)
            frequency_counts[path_key] = frequency_counts.get(path_key, 0) + 1

            # Descend, creating the child node on first visit
            current = current.setdefault(level, {})

    return taxonomy_dict, frequency_counts

# Create a Newick string for the taxonomy tree
def dict_to_newick(d, parent_abundance=None):
    """
    Serialise a nested taxonomy dict into the inner part of a Newick string.

    Every node is written as `name:size`; size is looked up by node name in
    parent_abundance and defaults to 1 when the mapping is missing, empty or
    has no entry for the name. Siblings are separated by commas and children
    are wrapped in parentheses in front of their parent.
    """
    pieces = []
    for name, subtree in d.items():
        if parent_abundance:
            size = parent_abundance.get(name, 1)
        else:
            size = 1

        inner = dict_to_newick(subtree, parent_abundance)
        if inner:
            pieces.append(f"({inner}){name}:{size},")
        else:
            pieces.append(f"{name}:{size},")

    # Every piece ends with a comma; drop the trailing one(s)
    return "".join(pieces).rstrip(",")



# Fetch taxonomy data
def main():
    """
    Fetch the taxonomy of the family proteins and render it as a tree image.

    Steps: load the protein IDs, download their lineages (saved to
    taxonomy_info.csv), build a Newick tree from the lineages and render it
    with ETE to phylogenetic_tree_freq.png. Every node is labelled with its
    name and the number of proteins below it, and its size grows with that number.
    """
    psiblast_file = "psiblast_parsed.csv"
    hmm_file = "hmmsearch_output.csv"
    protein_ids = load_protein_ids(psiblast_file, hmm_file)

    analyzer = TaxonomyAnalyzer()
    taxonomy_data = analyzer.fetch_taxonomy_info(protein_ids, "taxonomy_info.csv")

    print("Taxonomy file saved to: taxonomy_info.csv")

    # Column holding the lineage strings
    correct_column_name = "Lineage"

    # Nested dictionary of the taxonomy plus the number of proteins per node (keyed by full path)
    taxonomy_dict, frequency_counts = process_taxonomy(taxonomy_data, correct_column_name)

    # Number of proteins per complete lineage string. dict_to_newick looks sizes up by
    # taxon name, so these counts only apply when a whole lineage equals a single taxon
    # name; otherwise the branch length falls back to 1.
    # TODO : ask the professor whether branch lengths should carry abundance as well
    abundance_counts = taxonomy_data[correct_column_name].value_counts().to_dict()
    newick_tree = f"({dict_to_newick(taxonomy_dict, abundance_counts)});"

    # Plot using ETE Toolkit
    phylo_tree = Tree(newick_tree, format=1)
    tree_style = TreeStyle()
    tree_style.show_leaf_name = False

    # Node size: proteins below the node times scaling_factor, capped at max_size
    max_size = 50
    scaling_factor = 2
    for node in phylo_tree.traverse():
        # Rebuild the full path from the root to this node
        path = []
        current = node
        while current is not None:
            if current.name:  # Skip empty names
                path.insert(0, current.name)
            current = current.up

        path_key = " > ".join(path)
        count = frequency_counts.get(path_key, 1)

        # Use the per-path count for the size: the lineage-keyed abundance_counts has
        # no entry for individual taxon names, which left every node at the default size.
        nstyle = NodeStyle()
        nstyle["size"] = min(count * scaling_factor, max_size)
        node.set_style(nstyle)
        # Add label with name and count
        node.add_face(TextFace(f"{node.name} ({count})", fsize=10), column=0)

    # Improve tree spacing
    tree_style.branch_vertical_margin = 30

    # Save the tree to a high-resolution PNG file
    output_file = "phylogenetic_tree_freq.png"
    phylo_tree.render(output_file, w=3000, h=2000, tree_style=tree_style)

    print(f"Tree saved to: {output_file}")


if __name__ == "__main__":
    main()


Fetching taxonomy data:   0%|          | 0/57 [00:00<?, ?it/s]

Processing Q64425: 100%|██████████| 57/57 [00:14<00:00,  3.92it/s]


Taxonomy file saved to: taxonomy_info.csv
Tree saved to: phylogenetic_tree_freq.png


### Function

1. Collect GO annotations for each protein of the family_sequences dataset (entity/dbReference type="GO" in the UniProt XML)

2. Calculate the enrichment of each term in the dataset compared to GO annotations available in the SwissProt database (you can download the entire SwissProt XML here). You can use Fisher's exact test and verify that both the two-tailed and the right-tailed P-values (or left-tailed, depending on how you build the confusion matrix) are close to zero

3. Plot enriched terms in a word cloud 

4. Take into consideration the hierarchical structure of the GO ontology and report most significantly enriched branches, i.e. high level terms

5. Always report the full name of the terms and not only the GO ID

In [1]:
!pip install obonet
!pip install statsmodels
!pip install goatools



In [59]:
# ONTOLOGY MARLON 

import requests
import pandas as pd
import xml.etree.ElementTree as ET
from scipy.stats import fisher_exact
from wordcloud import WordCloud
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import random
import obonet
import networkx as nx
from statsmodels.stats.multitest import multipletests
from collections import defaultdict

# Step 1: Load Protein IDs
# TODO : we basically did this above already for taxonomy task and just here neatly written into a function, so we could maybe just do it once in the whole code later on
def load_protein_ids(psiblast_file, hmm_file, e_threshold=0.001):
    """
    Load protein IDs from PSI-BLAST and HMM search results.

    Every PSI-BLAST hit is kept; HMM hits are kept when their 'E-value' is at
    most e_threshold. Both CSV files must provide a 'uniprot_id' column.
    Returns the combined IDs as a list without duplicates (order not defined).
    """
    psiblast_ids = set(pd.read_csv(psiblast_file)['uniprot_id'])

    hmm_table = pd.read_csv(hmm_file)
    hmm_ids = set(hmm_table.loc[hmm_table['E-value'] <= e_threshold, 'uniprot_id'])

    return list(psiblast_ids | hmm_ids)

'''
def fetch_go_annotations(protein_id):
    """
    Fetch and categorize GO annotations for a given protein ID from the UniProt API.
    
    Args:
        protein_id (str): The UniProt ID of the protein
        
    Returns:
        dict: A dictionary containing:
            - Categorized GO terms separated by molecular function, biological process, 
              and cellular component (new format)
    """
    # Define the UniProt API URL for XML data
    url = f"https://rest.uniprot.org/uniprotkb/{protein_id}.xml"

    try:
        # Fetch the XML data from UniProt
        response = requests.get(url)
        response.raise_for_status()
        
        # Initialize our data structures
        go_terms = []  # Original format
        categorized_terms = {
            'molecular_function': [],
            'biological_process': [],
            'cellular_component': []
        }
        
        # Set up namespace for XML parsing
        namespaces = {'ns': 'http://uniprot.org/uniprot'}
        root = ET.fromstring(response.content)
        
        # Find all GO term references in the XML
        for db_ref in root.findall(".//ns:dbReference[@type='GO']", namespaces):
            go_id = db_ref.attrib.get('id')
            term = db_ref.find("ns:property[@type='term']", namespaces)

            go_term = term.get('value')
            
            if go_id and term is not None:
                # Store in original format
                term_value = term.attrib['value']
                
                # Categorize based on prefix
                if term_value.startswith('F:'):
                    categorized_terms['molecular_function'].append({
                        'id': go_id,
                        'term': term_value[2:]  # Remove 'F:' prefix
                    })
                elif term_value.startswith('P:'):
                    categorized_terms['biological_process'].append({
                        'id': go_id,
                        'term': term_value[2:]  # Remove 'P:' prefix
                    })
                elif term_value.startswith('C:'):
                    categorized_terms['cellular_component'].append({
                        'id': go_id,
                        'term': term_value[2:]  # Remove 'C:' prefix
                    })
        
        return {
            'categorized': categorized_terms  # New categorized format
}
        
    except requests.exceptions.RequestException as e:
        print(f"Error fetching GO annotations for {protein_id}: {e}")
        return {
            'categorized': {
                'molecular_function': [],
                'biological_process': [],
                'cellular_component': []
            }
        }
    '''

# STEP 1 

 # Define the UniProt API URL for XML data
def fetch_go_annotations(protein_id, timeout=30):
    """
    Fetch the GO identifiers annotated on a single UniProt entry.
    
    Args:
        protein_id (str): UniProt accession to look up
        timeout (float): Seconds to wait for the UniProt server before giving up
        
    Returns:
        List : List of the GO ids found for that protein (empty if the
               request fails or the response is not valid XML)
    """
    url = f"https://rest.uniprot.org/uniprotkb/{protein_id}.xml"
    namespaces = {'ns': 'http://uniprot.org/uniprot'}

    try:
        # Without a timeout a single unresponsive request would stall the whole
        # download loop; requests' Timeout is a RequestException, so it is
        # reported by the handler below like any other request failure.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        root = ET.fromstring(response.content)
    except (requests.exceptions.RequestException, ET.ParseError) as e:
        print(f"Error fetching GO annotations for {protein_id}: {e}")
        return []

    go_ids = []
    for db_ref in root.findall(".//ns:dbReference[@type='GO']", namespaces):
        go_id = db_ref.attrib.get('id')
        if go_id:
            go_ids.append(go_id)

    return go_ids



def fetch_go_terms(protein_ids, timeout=30):
    """
    Build a GO id -> GO term mapping from the UniProt entries of several proteins.

    Args:
        protein_ids (iterable): UniProt accessions to look up
        timeout (float): Seconds to wait for the UniProt server per request

    Returns:
        dict: GO id -> term string as stored in the UniProt XML `term` property.
              Proteins whose request or XML parsing fails are reported and skipped.
    """
    go_terms = {}
    namespaces = {'ns': 'http://uniprot.org/uniprot'}

    for protein_id in protein_ids:
        url = f"https://rest.uniprot.org/uniprotkb/{protein_id}.xml"

        try:
            # A timeout keeps one unresponsive request from blocking the loop.
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            root = ET.fromstring(response.content)
        except (requests.exceptions.RequestException, ET.ParseError) as e:
            print(f"Error fetching GO terms for {protein_id}: {e}")
            continue

        for db_ref in root.findall(".//ns:dbReference[@type='GO']", namespaces):
            go_id = db_ref.attrib.get('id')
            term = db_ref.find("ns:property[@type='term']", namespaces)
            if go_id and term is not None:
                go_terms[go_id] = term.get('value')

    return go_terms

# Debugging helper to understand what the SwissProt dump looks like.
# Here we see that the big .xml file has the same structure as the small ones
# we already analyzed; thus, we can use the same parsing structure, but this time
# we directly collect just the counts of GO terms, because that is all we need (splitting them into categories would just make our code slower)
def print_swissprot_file(swissprot_xml_path, length = 50):
    """
    Print the first `length` lines of the file, just to look at its structure.

    Args:
        swissprot_xml_path (str): Path to the file to preview
        length (int): Number of leading lines to print
    """
    # Explicit encoding so the preview does not depend on the platform default.
    with open(swissprot_xml_path, 'r', encoding='utf-8') as f:
        print(f"First {length} lines of the file:")
        for i, line in enumerate(f):
            if i >= length:
                break
            print(line.strip())



'''
def parse_swissprot_go_terms(swissprot_xml_path, family_proteins, skip_proteins):
    """
    Parse GO terms from SwissProt XML file, excluding proteins from our family.
    
    Args:
        swissprot_xml_path (str): Path to the SwissProt XML file
        family_proteins (set): Set of UniProt IDs in our protein family
        skip_proteins (bool): Whether to skip proteins in our family
    
    Returns:
        tuple: (go_term_counts dictionary, total proteins processed)
    """
    # Initialize counters
    go_term_counts = defaultdict(int)
    total_proteins = 0
    skipped_proteins = 0
    
    # Set up namespace for XML parsing
    namespaces = {'ns': 'http://uniprot.org/uniprot'}
    
    # Use iterparse for memory-efficient parsing
    context = ET.iterparse(swissprot_xml_path, events=('end',))
    
    print("Starting to parse SwissProt XML...")
    
    for event, elem in context:
        if elem.tag.endswith('entry'):
            # Get the UniProt ID for this protein
            accession = elem.find(".//ns:accession", namespaces)
            if accession is not None:
                uniprot_id = accession.text
                
                # Skip if this protein is in our family (we need this for the enrichment task to create the contigency table later on)
    
                if uniprot_id in family_proteins and skip_proteins:
                        skipped_proteins += 1
                else:
                    # Process GO terms for non-family proteins
                    for db_ref in elem.findall(".//ns:dbReference[@type='GO']", namespaces):
                        go_id = db_ref.attrib.get('id')
                        if go_id:
                            go_term_counts[go_id] += 1
                    total_proteins += 1
        

            
            # Clear the element to save memory
            elem.clear()
            
            # Print progress periodically
            if (total_proteins + skipped_proteins) % 10000 == 0:
                print(f"Processed {total_proteins} proteins "
                      f"(skipped {skipped_proteins} family proteins)...")
              #  break
        

    
    return go_term_counts, total_proteins
    '''

def parse_swissprot_go_terms(swissprot_xml_path, family_proteins):
    """
    Parse GO terms from SwissProt XML file for each protein outside the family.

    Args:
        swissprot_xml_path (str): Path to SwissProt XML file
        family_proteins (iterable): UniProt IDs in protein family (these entries are skipped)

    Returns:
        dict: protein ID -> list of GO IDs for that protein. Every non-family
              protein gets a key, even when it has no GO annotation, so that
              len(result) is the number of background proteins processed.
    """
    # Build the lookup once: callers may pass a list, and the membership test
    # below runs for every entry in the file.
    family = set(family_proteins)

    protein_to_go = defaultdict(list)
    total_proteins = 0
    skipped_proteins = 0

    namespaces = {'ns': 'http://uniprot.org/uniprot'}
    context = ET.iterparse(swissprot_xml_path, events=('end',))

    print("Starting to parse SwissProt XML...")

    for _, elem in context:
        if not elem.tag.endswith('entry'):
            continue

        # Only the first accession of the entry is used as its identifier.
        accession = elem.find(".//ns:accession", namespaces)
        if accession is not None:
            uniprot_id = accession.text

            if uniprot_id in family:
                skipped_proteins += 1
            else:
                # Touching the key registers the protein even without GO terms.
                go_ids = protein_to_go[uniprot_id]
                for db_ref in elem.findall(".//ns:dbReference[@type='GO']", namespaces):
                    go_id = db_ref.attrib.get('id')
                    if go_id:
                        go_ids.append(go_id)
                total_proteins += 1

            # Report progress only when a counter actually advanced.
            if (total_proteins + skipped_proteins) % 10000 == 0:
                print(f"Processed {total_proteins} proteins "
                      f"(skipped {skipped_proteins} family proteins)...")

        # Drop the entry's children to keep memory bounded while streaming.
        elem.clear()

    return protein_to_go

def calculate_go_enrichment(go_to_proteins_family, go_to_proteins_swissprot, total_proteins_family, total_proteins_swissprot, go_id_to_go_term):
    """
    Run Fisher's exact test for every GO id found in the family.

    For each GO id the 2x2 contingency table is

                          Protein in family    Protein not in family
        Has GO term            a                    b
        No GO term             c                    d

    where "not in family" is the SwissProt background given by the caller.
    The null hypothesis is that having the GO term is independent of being
    in the family; the 'greater' alternative asks whether the term is
    over-represented in the family, the 'two-sided' one whether its
    frequency differs in either direction.

    Args:
        go_to_proteins_family (dict): GO id -> family proteins carrying it
        go_to_proteins_swissprot (dict): GO id -> background proteins carrying it
        total_proteins_family (int): Number of family proteins
        total_proteins_swissprot (int): Number of background proteins
        go_id_to_go_term (dict): GO id -> readable term ('N/A' is used when missing)

    Returns:
        pandas.DataFrame: one row per tested GO id, sorted by two-sided
        p-value. The same table is written to "enrichment_results.csv".
        The columns are present even when no GO id was tested.
    """
    columns = [
        'GO_ID',
        'GO_Term',
        'Count_Prot_Dataset',
        'Count_Prot_SwissProt',
        'Count_Prot_SwissProt_Actual',
        'Percentage_Dataset',
        'Percentage_SwissProt',
        'Fold_Enrichment',
        'P_Value_Two_Tail',
        'P_Value_Greater',
    ]
    results = []

    for go_id, family_members in go_to_proteins_family.items():
        a = len(family_members)  # Proteins with this GO term in family
        b = len(go_to_proteins_swissprot.get(go_id, []))  # Proteins with GO term in rest of SwissProt
        c = total_proteins_family - a  # Proteins without GO term in family
        d = total_proteins_swissprot - b  # Proteins without GO term in rest of SwissProt

        # Inconsistent totals would give a negative cell, which Fisher's test
        # cannot handle; such GO ids are left out of the results.
        if min(a, b, c, d) < 0:
            continue

        contingency_table = [[a, b], [c, d]]

        # TODO : both p-values are reported for now; which tail to use depends
        # TODO : on how the contingency table is ordered (as asked in the task)
        _, pvalue_two_tail = fisher_exact(contingency_table, alternative='two-sided')
        _, pvalue_greater = fisher_exact(contingency_table, alternative='greater')

        # The test above uses the background WITHOUT the family, but the
        # proportion is taken over ALL of SwissProt (family included), which
        # also keeps the denominator non-zero whenever a > 0.
        # TODO : ask prof if correct
        my_proportion = a / total_proteins_family
        swissprot_proportion = (a + b) / (total_proteins_swissprot + total_proteins_family)

        results.append({
            'GO_ID': go_id,
            'GO_Term': go_id_to_go_term.get(go_id, 'N/A'),
            'Count_Prot_Dataset': a,
            'Count_Prot_SwissProt': b,
            'Count_Prot_SwissProt_Actual': a + b,
            'Percentage_Dataset': round(my_proportion * 100, 2),
            'Percentage_SwissProt': round(swissprot_proportion * 100, 10),
            'Fold_Enrichment': round(my_proportion / swissprot_proportion, 2),
            'P_Value_Two_Tail': pvalue_two_tail,
            'P_Value_Greater': pvalue_greater,
        })

    # Fixing the columns keeps the CSV header (and downstream column filters)
    # valid even when nothing was tested.
    df_results = pd.DataFrame(results, columns=columns)
    if not df_results.empty:
        df_results = df_results.sort_values('P_Value_Two_Tail')

    df_results.to_csv("enrichment_results.csv")

    return df_results

In [26]:
# Helper function to turn a Protein_ID : [GO_terms] into a GO_term : [Protein_IDs] dictionary

def reverse_protein_go_dict(protein_to_go):
    """
    Invert a protein -> [GO ids] mapping into GO id -> [proteins].

    Proteins are appended in the iteration order of the input mapping.
    """
    inverted = defaultdict(list)
    pairs = (
        (go_id, protein_id)
        for protein_id, go_ids in protein_to_go.items()
        for go_id in go_ids
    )
    for go_id, protein_id in pairs:
        inverted[go_id].append(protein_id)
    return inverted

In [53]:
# MARLON 
# Hierarchical Structure

import pandas as pd
import networkx as nx
from goatools import obo_parser
import matplotlib.pyplot as plt

def analyze_go_hierarchy():
    """
    Summarise which high-level GO branches collect the enriched terms.

    Reads 'go.obo' and "enrichment_results.csv", keeps the terms whose
    two-tailed and right-tailed p-values are both below 0.05, credits each
    of them to every ancestor in the GO DAG, and writes the 20 strongest
    branches (depth <= 3, at least 2 enriched terms below them) to
    'enriched_branches.csv'.
    """
    # The go.obo file was downloaded beforehand so it can be parsed locally
    dag = obo_parser.GODag('go.obo')

    table = pd.read_csv("enrichment_results.csv")
    two_tail_ok = table['P_Value_Two_Tail'] < 0.05
    greater_ok = table['P_Value_Greater'] < 0.05
    hits = table[two_tail_ok & greater_ok]

    branches = {}
    for _, hit in hits.iterrows():
        term_id = hit['GO_ID']
        if term_id not in dag:
            continue
        node = dag[term_id]
        p_two_tail = hit['P_Value_Two_Tail']

        # get_all_parents walks up to the root (get_parents would only give direct parents)
        for parent_id in node.get_all_parents():
            entry = branches.get(parent_id)
            if entry is None:
                parent = dag[parent_id]
                entry = {
                    'term_name': parent.name,
                    'enriched_children': [],
                    'total_significance': 0,
                    'depth': parent.depth,
                }
                branches[parent_id] = entry

            # TODO : correct ????
            # The enriched term sits somewhere below this ancestor, not necessarily as a direct child
            entry['enriched_children'].append({
                'id': term_id,
                'name': node.name,
                'p_value': p_two_tail
            })
            # Lower p-values give higher -log10 scores, summed over all enriched terms of the branch
            entry['total_significance'] += -np.log10(p_two_tail)

    # High-level branches (low depth; adjust threshold as needed) with at least 2 enriched terms,
    # strongest first
    ranked = sorted(
        (
            item for item in branches.items()
            if len(item[1]['enriched_children']) >= 2 and item[1]['depth'] <= 3
        ),
        key=lambda item: item[1]['total_significance'],
        reverse=True
    )

    # Top 20 branches, one row each
    rows = [
        {
            'GO_ID': branch_id,
            'Branch_Name': details['term_name'],
            'Hierarchy_Depth': details['depth'],
            'Number_Enriched_Terms': len(details['enriched_children']),
            'Total_Significance_Score': details['total_significance']
        }
        for branch_id, details in ranked[:20]
    ]

    pd.DataFrame(rows).to_csv('enriched_branches.csv', index=False)


In [54]:
# MARLON
analyze_go_hierarchy()

go.obo: fmt(1.2) rel(2024-11-03) 43,983 Terms


In [60]:
#MARLON 
def main():
    """
    Run the GO enrichment pipeline for the family and render a word cloud.

    Loads the family protein ids, fetches their GO annotations from UniProt,
    parses the SwissProt dump as background, runs the enrichment test and
    saves a word cloud of the significantly enriched terms weighted by fold
    enrichment. Nothing is drawn when no term is significantly enriched.
    """
    psiblast_file = "psiblast_parsed.csv"
    hmm_file = "hmmsearch_output.csv"
    protein_ids = load_protein_ids(psiblast_file, hmm_file)

    # Proteins_to_GO terms for our family
    print("Fetching GO annotations...")
    family_annotations = {}
    for pid in tqdm(protein_ids, desc="Fetching GO annotations"):
        family_annotations[pid] = fetch_go_annotations(pid)

    total_proteins_family = len(family_annotations)

    go_id_to_go_term = fetch_go_terms(protein_ids)

    # Proteins_to_GO terms for SwissProt (family proteins are skipped by the parser)
    swissprot_annotations = parse_swissprot_go_terms("uniprot_sprot.xml", protein_ids)

    total_proteins_swissprot = len(swissprot_annotations)

    # Now map the GO terms to the proteins; for the enrichment task, we need
    # to know how many proteins have a certain GO term
    go_to_proteins_swissprot = reverse_protein_go_dict(swissprot_annotations)
    go_to_proteins_family = reverse_protein_go_dict(family_annotations)

    # Calculate GO enrichment of the family against the rest of SwissProt.
    # The function also writes "enrichment_results.csv"; the returned frame is
    # used directly instead of reading that file back.
    df = calculate_go_enrichment(go_to_proteins_family, go_to_proteins_swissprot, total_proteins_family, total_proteins_swissprot, go_id_to_go_term)

    # Create word frequencies using the actual GO terms instead of IDs
    word_frequencies = {}
    if not df.empty:
        # Filter for significantly enriched terms
        enriched_terms = df[
            (df['P_Value_Two_Tail'] < 0.05) &
            (df['P_Value_Greater'] < 0.05)
        ]
        for _, row in enriched_terms.iterrows():
            go_id = row['GO_ID']
            if go_id in go_id_to_go_term:  # Make sure we have the term for this ID
                # Use fold enrichment as weight
                word_frequencies[go_id_to_go_term[go_id]] = row['Fold_Enrichment']

    # WordCloud cannot be generated from an empty frequency table.
    if not word_frequencies:
        print("No enriched terms found. Word cloud will not be generated.")
        return

    # Create the word cloud
    wordcloud = WordCloud(
        width=1200,
        height=800,
        background_color='white',
        prefer_horizontal=0.7,
        max_words=50,  # Limit to top 50 terms for better readability
        min_font_size=10,
        max_font_size=60
    ).generate_from_frequencies(word_frequencies)

    # Plot and save the word cloud
    plt.figure(figsize=(20, 12))
    plt.imshow(wordcloud, interpolation='bilinear')
    plt.axis('off')
    plt.title('GO Term Enrichment Word Cloud', fontsize=16, pad=20)
    plt.savefig('go_enrichment_wordcloud.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Print out the enriched terms for verification
    print("\nTop enriched GO terms:")
    sorted_terms = sorted(word_frequencies.items(), key=lambda x: x[1], reverse=True)
    for term, weight in sorted_terms[:10]:
        print(f"\nTerm: {term}")
        print(f"Weight in word cloud: {weight:.2f}")

    
if __name__ == "__main__":
    main()

Fetching GO annotations...


Fetching GO annotations: 100%|██████████| 57/57 [00:14<00:00,  3.88it/s]


Starting to parse SwissProt XML...
Processed 10000 proteins (skipped 0 family proteins)...
Processed 20000 proteins (skipped 0 family proteins)...
Processed 29985 proteins (skipped 15 family proteins)...
Processed 39985 proteins (skipped 15 family proteins)...
Processed 49985 proteins (skipped 15 family proteins)...
Processed 59985 proteins (skipped 15 family proteins)...
Processed 69985 proteins (skipped 15 family proteins)...
Processed 79985 proteins (skipped 15 family proteins)...
Processed 89974 proteins (skipped 26 family proteins)...
Processed 99974 proteins (skipped 26 family proteins)...
Processed 109974 proteins (skipped 26 family proteins)...
Processed 119974 proteins (skipped 26 family proteins)...
Processed 129974 proteins (skipped 26 family proteins)...
Processed 139974 proteins (skipped 26 family proteins)...
Processed 149974 proteins (skipped 26 family proteins)...
Processed 159974 proteins (skipped 26 family proteins)...
Processed 169974 proteins (skipped 26 family prot

In [48]:
import requests
import pandas as pd
import xml.etree.ElementTree as ET
from scipy.stats import fisher_exact
from wordcloud import WordCloud
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import random
import obonet
import networkx as nx
from statsmodels.stats.multitest import multipletests

# Step 1: Load Protein IDs
# TODO : we basically did this above already for taxonomy task and just here neatly written into a function, so we could maybe just do it once in the whole code later on
def load_protein_ids(psiblast_file, hmm_file, e_threshold=0.001):
    """Return the union of PSI-BLAST hits and HMM hits with E-value <= e_threshold.

    Both files are CSVs with a 'uniprot_id' column; the HMM file also has an
    'E-value' column. The threshold is applied to the HMM hits only.
    """
    psiblast_table = pd.read_csv(psiblast_file)
    hmm_table = pd.read_csv(hmm_file)

    passing_hmm_ids = hmm_table.loc[hmm_table['E-value'] <= e_threshold, 'uniprot_id']

    combined = set(psiblast_table['uniprot_id']) | set(passing_hmm_ids)
    return list(combined)


def fetch_go_annotations(protein_id, timeout=30):
    """
    Fetch and categorize GO annotations for a given protein ID from the UniProt API.
    
    Args:
        protein_id (str): The UniProt ID of the protein
        timeout (float): Seconds to wait for the UniProt server before giving up
        
    Returns:
        dict: A dictionary with a single key 'categorized' holding three lists
              ('molecular_function', 'biological_process', 'cellular_component').
              Each list item is {'id': GO id, 'term': term without its
              'F:'/'P:'/'C:' prefix}. All lists are empty when the request
              fails or the response is not valid XML.
    """
    # Prefix of the UniProt term string -> category it belongs to
    prefix_to_category = {
        'F:': 'molecular_function',
        'P:': 'biological_process',
        'C:': 'cellular_component',
    }
    categorized_terms = {category: [] for category in prefix_to_category.values()}

    # Define the UniProt API URL for XML data
    url = f"https://rest.uniprot.org/uniprotkb/{protein_id}.xml"

    try:
        # A timeout keeps one unresponsive request from stalling the caller's
        # loop; Timeout is a RequestException, so it is reported below.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        root = ET.fromstring(response.content)
    except (requests.exceptions.RequestException, ET.ParseError) as e:
        print(f"Error fetching GO annotations for {protein_id}: {e}")
        return {'categorized': categorized_terms}

    # Set up namespace for XML parsing
    namespaces = {'ns': 'http://uniprot.org/uniprot'}

    # Find all GO term references in the XML
    for db_ref in root.findall(".//ns:dbReference[@type='GO']", namespaces):
        go_id = db_ref.attrib.get('id')
        term = db_ref.find("ns:property[@type='term']", namespaces)

        if go_id and term is not None:
            term_value = term.attrib['value']

            # Terms with an unknown prefix are ignored
            category = prefix_to_category.get(term_value[:2])
            if category is not None:
                categorized_terms[category].append({
                    'id': go_id,
                    'term': term_value[2:]  # Remove the two-character prefix
                })

    return {
        'categorized': categorized_terms
    }

# -----------------------------------------------------------------------------
# Step 3: Fetch Random Proteins
# -----------------------------------------------------------------------------

def fetch_random_proteins(batch_size=100, total_proteins=500, timeout=60):
    """Fetch a list of random reviewed UniProt protein IDs.

    Downloads the accession list of all reviewed entries, samples up to
    `total_proteins` of them without replacement and returns them split into
    consecutive batches of at most `batch_size` ids. Returns an empty list
    when the request fails.
    """
    url = "https://rest.uniprot.org/uniprotkb/stream?query=reviewed:true&format=list"
    try:
        # The timeout bounds how long a silent server can block the call;
        # Timeout is a RequestException and is reported below.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"Error fetching random proteins: {e}")
        return []

    all_proteins = response.text.splitlines()
    sample_size = min(total_proteins, len(all_proteins))
    selected_proteins = random.sample(all_proteins, sample_size)
    return [
        selected_proteins[start:start + batch_size]
        for start in range(0, len(selected_proteins), batch_size)
    ]

# -----------------------------------------------------------------------------
# Step 4: Load GO Ontology
# -----------------------------------------------------------------------------

def load_go_ontology():
    """Read go-basic.obo from the OBO PURL with obonet and return the resulting graph."""
    return obonet.read_obo("http://purl.obolibrary.org/obo/go/go-basic.obo")

# -----------------------------------------------------------------------------
# Step 5: Flatten Annotations for Analysis
# -----------------------------------------------------------------------------

def flatten_annotations(annotation_dict):
    """Flatten GO annotations into a list of GO ids (one entry per annotation).

    Each value of `annotation_dict` may be either
    - a list of records with a "GO_ID" key, or
    - the dict returned by fetch_go_annotations, i.e.
      {'categorized': {category: [{'id': ..., 'term': ...}, ...]}}.
    """
    flat_terms = []
    for annotations in annotation_dict.values():
        if isinstance(annotations, dict):
            # Categorized layout: walk every category list and take the 'id' field
            for records in annotations.get('categorized', {}).values():
                flat_terms.extend(record['id'] for record in records)
        else:
            flat_terms.extend(a["GO_ID"] for a in annotations)
    return flat_terms

# -----------------------------------------------------------------------------
# Step 6: Enrichment Analysis
# -----------------------------------------------------------------------------

def calculate_enrichment(go_term, family_terms, background_terms):
    """Return the right-tailed Fisher exact p-value for one GO term.

    The 2x2 table compares how often `go_term` occurs in `family_terms`
    versus `background_terms` (occurrences, not distinct proteins).
    """
    in_family = family_terms.count(go_term)
    in_background = background_terms.count(go_term)

    with_term = [in_family, in_background]
    without_term = [len(family_terms) - in_family, len(background_terms) - in_background]

    result = fisher_exact([with_term, without_term], alternative='greater')
    return result[1]

# -----------------------------------------------------------------------------
# Step 7: Visualize Enriched Terms
# -----------------------------------------------------------------------------

def plot_wordcloud(enrichment_results, annotations):
    """Generate and save a word cloud of enriched GO terms.

    Args:
        enrichment_results (dict): GO id -> p-value of the enriched terms
        annotations (dict): protein id -> annotations, either a list of
            {"GO_ID", "Term"} records or the {'categorized': {...}} dict
            returned by fetch_go_annotations

    Each term is weighted by -log10(p-value). Nothing is drawn when no
    enriched GO id has a known term name.
    """
    def iter_records(entry):
        # Yield (go_id, term) pairs from either annotation layout.
        if isinstance(entry, dict):
            for records in entry.get('categorized', {}).values():
                for record in records:
                    yield record['id'], record['term']
        else:
            for record in entry:
                yield record["GO_ID"], record["Term"]

    term_names = {go_id: term for entry in annotations.values() for go_id, term in iter_records(entry)}

    # A p-value of exactly 0 would give an infinite weight; clamp it to the
    # smallest positive float so the weight stays finite.
    smallest = np.finfo(float).tiny
    enriched_with_names = {term_names[go_id]: -np.log10(max(p, smallest))
                           for go_id, p in enrichment_results.items() if go_id in term_names}

    if not enriched_with_names:
        print("No enriched terms found. Word cloud will not be generated.")
        return

    wordcloud = WordCloud(width=1000, height=600, background_color="white", colormap="viridis").generate_from_frequencies(enriched_with_names)
    wordcloud.to_file("enriched_terms_wordcloud.png")
    print("Word cloud saved as enriched_terms_wordcloud.png")

    plt.figure(figsize=(12, 6))
    plt.imshow(wordcloud, interpolation="bilinear")
    plt.axis("off")
    plt.title("Word Cloud of Enriched GO Terms")
    plt.tight_layout()
    plt.show()

def plot_branch_enrichment(enrichment_results, go_graph):
    """Plot GO branch enrichment.

    Args:
        enrichment_results (dict): GO id -> p-value of the enriched terms
        go_graph: GO graph as loaded by obonet (edges point from a term to
            its parents)

    Every enriched term contributes its p-value to all of its superterms.
    Branches collecting at least 3 enriched terms are ranked by mean p-value
    (smallest first) and the best 50 are saved as a horizontal bar chart.
    Nothing is drawn when no branch qualifies.
    """
    branch_scores = {}
    for go_id, p_value in enrichment_results.items():
        try:
            # obonet edges run child -> parent, so the superterms of a node
            # are its networkx *descendants* (nx.ancestors gives subterms).
            parents = nx.descendants(go_graph, go_id)
        except nx.NetworkXError:
            # GO id not present in the ontology graph
            continue
        for parent in parents:
            branch_scores.setdefault(parent, []).append(p_value)

    significant_branches = {branch: np.mean(scores)
                            for branch, scores in branch_scores.items() if len(scores) >= 3}

    if not significant_branches:
        print("No GO branch with at least 3 enriched terms. Branch plot will not be generated.")
        return

    # Limit to top 50 branches: the lower the mean p-value, the stronger the branch
    sorted_branches = sorted(significant_branches.items(), key=lambda x: x[1])[:50]
    branches, scores = zip(*sorted_branches)

    plt.figure(figsize=(14, 10))
    plt.barh(range(len(branches)), scores, color="steelblue", label="Top 50 Branch Enrichments")
    plt.yticks(range(len(branches)), [go_graph.nodes[branch].get('name', branch) for branch in branches], fontsize=8)
    plt.xlabel('Mean p-value')
    plt.title('Top 50 GO Branch Enrichments')
    # The legend has to exist before the figure is written to disk
    plt.legend(loc="lower right")
    plt.tight_layout()
    plt.savefig("go_enrichment_branches.png", dpi=300, bbox_inches="tight")
    print("GO branch enrichment plot saved as go_enrichment_branches.png")

# -----------------------------------------------------------------------------
# Step 8: Write Summary and Results to File
# -----------------------------------------------------------------------------

def write_summary_and_results(enrichment_results, family_annotations, go_graph):
    """Write a summary and detailed results to a text file.

    Args:
        enrichment_results (dict): GO id -> p-value of the enriched terms
        family_annotations (dict): protein id -> annotations, either a list
            of {"GO_ID", "Term"} records or the {'categorized': {...}} dict
            returned by fetch_go_annotations
        go_graph: GO graph as loaded by obonet (edges point from a term to
            its parents)

    The report is written to "enrichment_results.txt" and has three parts:
    a summary, one line per enriched term, and the GO branches that collect
    at least 3 enriched terms, ordered by mean p-value.
    """
    def iter_records(entry):
        # Yield (go_id, term) pairs from either annotation layout.
        if isinstance(entry, dict):
            for records in entry.get('categorized', {}).values():
                for record in records:
                    yield record['id'], record['term']
        else:
            for record in entry:
                yield record["GO_ID"], record["Term"]

    # First name seen for each GO id, built once instead of rescanning all
    # annotations for every enriched term
    term_names = {}
    for entry in family_annotations.values():
        for go_id, name in iter_records(entry):
            term_names.setdefault(go_id, name)

    branch_scores = {}
    for go_id, p_value in enrichment_results.items():
        try:
            # obonet edges run child -> parent, so the superterms of a node
            # are its networkx *descendants* (nx.ancestors gives subterms).
            parents = nx.descendants(go_graph, go_id)
        except nx.NetworkXError:
            # GO id not present in the ontology graph
            continue
        for parent in parents:
            branch_scores.setdefault(parent, []).append(p_value)

    significant_branches = {branch: np.mean(scores)
                            for branch, scores in branch_scores.items() if len(scores) >= 3}

    # The values are p-values, so the most enriched term is the one with the smallest value
    top_term = min(enrichment_results, key=enrichment_results.get, default='None')

    with open("enrichment_results.txt", "w", encoding="utf-8") as f:
        # Write summary
        f.write("SUMMARY\n")
        f.write("========\n")
        f.write(f"Number of enriched GO terms: {len(enrichment_results)}\n")
        f.write(f"Top enriched term: {top_term}\n")
        f.write("\n\n")

        # Write detailed results
        f.write("DETAILED RESULTS\n")
        f.write("================\n")
        for go_id, p_value in enrichment_results.items():
            term_name = term_names.get(go_id, go_id)
            f.write(f"{go_id}: {term_name} (p-value: {p_value:.2e})\n")

        # Write branch scores
        f.write("\n\nSIGNIFICANT GO BRANCHES\n")
        f.write("========================\n")
        for branch, score in sorted(significant_branches.items(), key=lambda x: x[1]):
            branch_name = go_graph.nodes[branch].get('name', branch)
            f.write(f"{branch}: {branch_name} (mean p-value: {score:.2e})\n")

# -----------------------------------------------------------------------------
# Main Script
# -----------------------------------------------------------------------------

def main():
    """Run the GO enrichment analysis for the domain family end to end.

    Steps, in order:
      1. Load the GO ontology graph.
      2. Load the family protein IDs from the PSI-BLAST and HMM result files.
      3. Fetch GO annotations for every family protein and for a background
         sample of 500 proteins, fetched in batches of 50.
      4. For each GO term seen in the family, run a one-sided Fisher's exact
         test (``alternative='greater'``) of family vs. background counts.
      5. Correct the p-values with Benjamini-Hochberg FDR and keep the terms
         flagged as significant.
      6. Plot the results and write the text report.
    """
    print("Loading GO ontology...")
    go_graph = load_go_ontology()

    psiblast_file = "psiblast_parsed.csv"
    hmm_file = "hmmsearch_output.csv"
    protein_ids = load_protein_ids(psiblast_file, hmm_file)

    print("Fetching GO annotations...")
    family_annotations = {}
    for pid in tqdm(protein_ids, desc="Fetching GO annotations"):
        family_annotations[pid] = fetch_go_annotations(pid)

    print(family_annotations)

    print("Fetching background annotations...")
    background_annotations = {}
    background_batches = fetch_random_proteins(batch_size=50, total_proteins=500)
    for batch in tqdm(background_batches, desc="Processing background proteins"):
        for pid in batch:
            background_annotations[pid] = fetch_go_annotations(pid)

    print("Calculating enrichment...")
    family_terms = flatten_annotations(family_annotations)
    background_terms = flatten_annotations(background_annotations)

    # Tally each term once instead of calling list.count() four times per
    # term, which rescans both lists on every iteration.
    family_counts = Counter(family_terms)
    background_counts = Counter(background_terms)
    n_family = len(family_terms)
    n_background = len(background_terms)

    terms = list(family_counts)
    pvalues = []
    for term in terms:
        in_family = family_counts[term]
        in_background = background_counts[term]  # 0 when absent from background
        _, p_value = fisher_exact([
            [in_family, n_family - in_family],
            [in_background, n_background - in_background]
        ], alternative='greater')
        pvalues.append(p_value)

    enrichment_results = {}
    if pvalues:
        rejected, p_corrected, _, _ = multipletests(pvalues, method='fdr_bh')
        # Only terms that survive the FDR correction are reported, together
        # with their corrected p-value.
        for term, p_value, significant in zip(terms, p_corrected, rejected):
            if significant:
                enrichment_results[term] = p_value
    else:
        # Nothing to correct: skip multipletests, which cannot handle an
        # empty p-value list.
        print("No GO terms found for the family; skipping multiple-testing correction.")

    print("Generating visualizations...")
    plot_wordcloud(enrichment_results, family_annotations)
    plot_branch_enrichment(enrichment_results, go_graph)

    print("Writing results to file...")
    write_summary_and_results(enrichment_results, family_annotations, go_graph)

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()


Loading GO ontology...


KeyboardInterrupt: 

### Motifs
1. Search for significantly conserved short motifs within your family. Use ELM classes and ProSite patterns (for ProSite, consider only the pattern “PA” lines, not the profiles). Count as true matches only those found inside disordered regions. Disordered regions for the entire SwissProt (as defined by MobiDB-lite) are available here