<a href="https://colab.research.google.com/github/Adiaslow/OligomerizationTest/blob/main/oligomerization_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title Install Packages
_ = !pip install -q py3Dmol biopython

In [None]:
# @title Imports

import sys
!{sys.executable} -m pip install --upgrade pip
!{sys.executable} -m pip install --upgrade pandas numpy biopython py3Dmol

# Standard-library and third-party imports used by the rest of the notebook
import os
# Set ColabFold's user agent before colabfold is imported (later cells import it)
os.environ['COLABFOLD_USER_AGENT'] = 'colabfold/batch'

import json
import pandas as pd
import warnings
import py3Dmol
import glob
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
from Bio import SeqIO
from google.colab import files
import hashlib
import base64
from html import escape
from IPython.display import display, HTML

Collecting pip
  Downloading pip-24.3.1-py3-none-any.whl.metadata (3.7 kB)
Downloading pip-24.3.1-py3-none-any.whl (1.8 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m65.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-24.3.1
Collecting pandas
  Downloading pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (89 kB)
Collecting numpy
  Downloading numpy-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
Collecting biopython
  Using cached biopython-1.84-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.many

In [None]:
# @title ColabFold Setup
def setup_alphafold_environment():
    """Set up the AlphaFold environment in Google Colab"""
    python_version = f"{sys.version_info.major}.{sys.version_info.minor}"

    # Install base ColabFold
    if not os.path.isfile("COLABFOLD_READY"):
        print("Installing colabfold...")
        os.system("pip install -q --no-warn-conflicts 'colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold'")
        if os.environ.get('TPU_NAME', False) != False:
            os.system("pip uninstall -y jax jaxlib")
            os.system("pip install --no-warn-conflicts --upgrade dm-haiku==0.0.10 'jax[cuda12_pip]'==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
        os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold")
        os.system("ln -s /usr/local/lib/python3.*/dist-packages/alphafold alphafold")
        os.system("touch COLABFOLD_READY")

    from colabfold.download import download_alphafold_params
    from colabfold.utils import setup_logging
    from colabfold.batch import get_queries, run, set_model_type
    from colabfold.plot import plot_msa_v2
    from colabfold.colabfold import plot_protein

    return download_alphafold_params, setup_logging, get_queries, run, set_model_type

# Run the one-time ColabFold setup when this cell executes. The returned entry
# points are discarded here; OligomerAnalysis imports what it needs itself.
print("Setting up AlphaFold environment...")
setup_alphafold_environment()
print("Setup complete!")

Setting up AlphaFold environment...
Installing colabfold...
Setup complete!


In [None]:
# @title Classes
@dataclass
class OligomerConfig:
    """Configuration class for oligomerization analysis.

    Holds the ColabFold prediction settings used by ``OligomerAnalysis``. Most
    fields are forwarded to ColabFold's batch ``run`` call under the same keyword
    name, so their exact semantics are ColabFold's.
    """
    num_recycles: int = 3  # forwarded as run(num_recycles=...)
    use_templates: bool = False  # forwarded as run(use_templates=...)
    msa_mode: str = "mmseqs2_uniref_env"  # ColabFold MSA mode name
    pair_mode: str = "unpaired_paired"  # ColabFold MSA pairing mode name
    model_type: str = "alphafold2_multimer_v3"  # ColabFold model type name
    num_relax: int = 0  # forwarded as run(num_relax=...)
    max_msa: Optional[str] = None  # forwarded as run(max_msa=...); None leaves it to ColabFold
    num_seeds: int = 20  # forwarded as run(num_seeds=...)
    use_dropout: bool = False  # forwarded as run(use_dropout=...)
    save_all: bool = False  # forwarded as run(save_all=...)
    save_recycles: bool = False  # forwarded as run(save_recycles=...)
    dpi: int = 200  # forwarded as run(dpi=...); presumably the resolution of ColabFold's plots

class OligomerAnalysis:
    """Predict homo-oligomers of FASTA sequences with ColabFold and report scores.

    For each sequence and each requested copy number, the sequence is repeated with
    ``:`` chain separators and predicted with ColabFold's batch ``run``. The
    resulting ``*_scores_rank_*_*.json`` files are summarised into two files inside
    ``results_dir``:
    - ``<job>_detailed.csv``: one row per model.
    - ``<job>_summary.csv``: a single row.
    """

    # The config annotation is a string so this class can be defined independently
    # of the cell/module that defines OligomerConfig.
    def __init__(self, config: "OligomerConfig"):
        """Store the config, create ``results_dir`` and bind ColabFold entry points.

        Args:
            config: Prediction settings read by :meth:`run_prediction`.
        """
        self.config = config
        self.results_dir = "oligomer_analysis"
        os.makedirs(self.results_dir, exist_ok=True)

        # Imported lazily: defining this class does not need ColabFold, only
        # constructing an instance does.
        from colabfold.download import download_alphafold_params
        from colabfold.utils import setup_logging
        from colabfold.batch import get_queries, run, set_model_type

        self.download_params = download_alphafold_params
        self.setup_logging = setup_logging
        self.get_queries = get_queries
        self.run = run
        self.set_model_type = set_model_type

    def read_fasta(self, fasta_path: str) -> Dict[str, str]:
        """Read sequences from a FASTA file.

        Returns:
            Mapping of record ID to sequence string. If an ID occurs more than once,
            a warning is emitted and the last record wins.
        """
        sequences: Dict[str, str] = {}
        for record in SeqIO.parse(fasta_path, "fasta"):
            # record.id is already the first whitespace-delimited word of the header;
            # splitting it again failed (IndexError) for an empty header.
            if record.id in sequences:
                warnings.warn(f"Duplicate FASTA id {record.id!r}; keeping the last record")
            sequences[record.id] = str(record.seq)
        return sequences

    def generate_oligomer_sequence(self, sequence: str, n_copies: int) -> str:
        """Return ``n_copies`` copies of ``sequence`` joined by ``:`` chain breaks."""
        return ":".join([sequence] * n_copies)

    def create_job_name(self, seq_id: str, n_copies: int) -> str:
        """Return ``<seq_id>_oligomer_<n>_<5-char sha1 of that prefix>``.

        The name is deterministic, so re-running the same job reuses the same
        result directory.
        """
        base_name = f"{seq_id}_oligomer_{n_copies}"
        return base_name + "_" + hashlib.sha1(base_name.encode()).hexdigest()[:5]

    def run_prediction(self, sequence: str, job_name: str, n_copies: int) -> Dict[str, Any]:
        """Run a ColabFold prediction for one (possibly multi-chain) sequence.

        Args:
            sequence: Chain sequences joined with ``:``.
            job_name: Used for the query CSV, the result directory name and the
                ColabFold job id.
            n_copies: Number of chains; more than one marks the query as a complex.

        Returns:
            The flat score statistics from :meth:`_extract_metrics` on success.
            If any step raises, the error is printed and a dict with
            ``status == "failed: <reason>"`` and zeroed scores is returned instead.
        """
        try:
            # Query file in the "id,sequence" CSV layout read by get_queries.
            query_path = os.path.join(self.results_dir, f"{job_name}.csv")
            with open(query_path, "w") as f:
                f.write(f"id,sequence\n{job_name},{sequence}\n")

            result_dir = os.path.join(self.results_dir, job_name)
            os.makedirs(result_dir, exist_ok=True)
            self.setup_logging(Path(os.path.join(result_dir, "log.txt")))

            queries, _ = self.get_queries(query_path)
            # Decide complex-ness from the copy count we generated.
            is_complex = n_copies > 1
            # Resolve the configured model type once and use the same value for the
            # parameter download and the run (previously both were hard-coded).
            model_type = self.set_model_type(is_complex, self.config.model_type)
            self.download_params(model_type, Path("."))

            results = self.run(
                queries=queries,
                result_dir=result_dir,
                use_templates=self.config.use_templates,
                custom_template_path=None,
                num_relax=self.config.num_relax,
                msa_mode=self.config.msa_mode,
                model_type=model_type,
                num_models=5,
                num_recycles=self.config.num_recycles,
                num_seeds=self.config.num_seeds,
                use_dropout=self.config.use_dropout,
                model_order=[1, 2, 3, 4, 5],
                is_complex=is_complex,
                pair_mode=self.config.pair_mode,
                max_msa=self.config.max_msa,
                save_all=self.config.save_all,
                save_recycles=self.config.save_recycles,
                dpi=self.config.dpi,
                data_dir=Path("."),
                rank_by="iptm",  # rank models by ipTM
            )

            if results is None:
                raise ValueError("No results generated")

            return self._extract_metrics(results, job_name)

        except Exception as e:
            print(f"Error in prediction for {job_name}: {str(e)}")
            return {
                "sequence_id": job_name,
                "n_copies": n_copies,
                "status": f"failed: {str(e)}",
                "best_plddt_avg": 0.0,
                "best_plddt_min": 0.0,
                "best_plddt_max": 0.0,
                "best_ptm_score": 0.0,
                "best_interface_score": 0.0
            }

    def _extract_metrics(self, results: Dict[str, Any], job_name: str) -> Dict[str, Any]:
        """Summarise per-model scores of a finished job across all its models.

        Every ``*_scores_rank_*_*.json`` in ``<results_dir>/<job_name>`` contributes
        one value per metric:
        - ``mean_plddt``: mean of ``plddt``.
        - ``max_pae``: the ``max_pae`` value.
        - ``mean_pae``: mean of ``pae``.
        - ``ptm`` and ``iptm``: the score file's values.

        min/max/mean/median/sd are then taken across models. A metric absent from
        every file keeps 0.0 statistics.

        Args:
            results: Return value of ColabFold's ``run``. Unused; scores are read
                back from disk.
            job_name: Job whose result directory is scanned.

        Returns:
            Flat dict keyed ``"<metric>_<stat>"``. If reading or parsing fails, the
            error is printed and every value is 0.0.
        """
        stat_names = ("min", "max", "mean", "median", "sd")
        metric_names = ("mean_plddt", "max_pae", "mean_pae", "ptm", "iptm")
        per_model: Dict[str, List[float]] = {name: [] for name in metric_names}

        try:
            job_dir = Path(self.results_dir) / job_name
            for score_file in job_dir.glob("*_scores_rank_*_*.json"):
                with open(score_file, "r") as f:
                    scores = json.load(f)

                if "plddt" in scores:
                    per_model["mean_plddt"].append(float(np.mean(scores["plddt"])))
                if "max_pae" in scores:
                    per_model["max_pae"].append(float(scores["max_pae"]))
                if "pae" in scores:
                    per_model["mean_pae"].append(float(np.mean(np.array(scores["pae"]))))
                if "ptm" in scores:
                    per_model["ptm"].append(float(scores["ptm"]))
                # ipTM was listed in the output but never collected, so it was always 0.0.
                if "iptm" in scores:
                    per_model["iptm"].append(float(scores["iptm"]))

            flat_metrics: Dict[str, Any] = {}
            for name in metric_names:
                values = per_model[name]
                if values:
                    arr = np.array(values)
                    stats = {
                        "min": float(np.min(arr)),
                        "max": float(np.max(arr)),
                        "mean": float(np.mean(arr)),
                        "median": float(np.median(arr)),
                        "sd": float(np.std(arr)) if len(arr) > 1 else 0.0,
                    }
                else:
                    stats = dict.fromkeys(stat_names, 0.0)
                for stat in stat_names:
                    flat_metrics[f"{name}_{stat}"] = stats[stat]
            return flat_metrics

        except Exception as e:
            print(f"Error extracting metrics for {job_name}: {str(e)}")
            return {f"{name}_{stat}": 0.0 for name in metric_names for stat in stat_names}

    def _analyze_single_model(self, model_file: Path) -> Dict[str, float]:
        """Compute per-model statistics from one ColabFold score JSON file.

        Returns ``plddt_*`` and ``pae_*`` (min/max/mean/median/sd) when those arrays
        are present, ``ptm`` when present, and ``iptm`` (0.0 if absent).
        """
        with open(model_file, 'r') as f:
            scores = json.load(f)

        metrics = {}

        if "plddt" in scores:
            plddt_values = np.array(scores["plddt"])
            metrics.update({
                "plddt_min": float(np.min(plddt_values)),
                "plddt_max": float(np.max(plddt_values)),
                "plddt_mean": float(np.mean(plddt_values)),
                "plddt_median": float(np.median(plddt_values)),
                "plddt_sd": float(np.std(plddt_values)) if len(plddt_values) > 1 else 0.0
            })

        if "pae" in scores:
            pae_values = np.array(scores["pae"]).flatten()
            metrics.update({
                "pae_min": float(np.min(pae_values)),
                "pae_max": float(np.max(pae_values)),
                "pae_mean": float(np.mean(pae_values)),
                "pae_median": float(np.median(pae_values)),
                "pae_sd": float(np.std(pae_values)) if len(pae_values) > 1 else 0.0
            })

        if "ptm" in scores:
            metrics["ptm"] = float(scores["ptm"])

        # Always emit an iptm column so every row has it.
        if "iptm" in scores:
            metrics["iptm"] = float(scores["iptm"])
        else:
            metrics["iptm"] = 0.0

        return metrics

    def _create_model_summary(self, model_metrics: List[Dict[str, float]], top_model_metrics: Dict[str, float]) -> Dict[str, float]:
        """Aggregate per-model metrics and append the top model's own values.

        For each of ``plddt_mean``, ``pae_mean``, ``ptm`` and ``iptm`` present in at
        least one model, adds ``all_models_<key>_{min,max,mean,median,sd}``. Every
        entry of ``top_model_metrics`` is copied as ``top_model_<key>``.
        """
        summary = {}

        metric_keys = ["plddt_mean", "pae_mean", "ptm", "iptm"]
        for key in metric_keys:
            values = [metrics[key] for metrics in model_metrics if key in metrics]
            if values:
                summary.update({
                    f"all_models_{key}_min": float(np.min(values)),
                    f"all_models_{key}_max": float(np.max(values)),
                    f"all_models_{key}_mean": float(np.mean(values)),
                    f"all_models_{key}_median": float(np.median(values)),
                    f"all_models_{key}_sd": float(np.std(values)) if len(values) > 1 else 0.0
                })

        for key, value in top_model_metrics.items():
            summary[f"top_model_{key}"] = value

        return summary

    @staticmethod
    def _write_csv(path: str, keys: List[str], rows: List[Dict[str, Any]]) -> None:
        """Write ``rows`` to ``path`` as CSV with header ``keys``.

        Each cell is ``str(row.get(key, ""))``, so missing keys become empty
        fields. Values containing commas or quotes are quoted by the csv module
        rather than corrupting the row.
        """
        import csv  # local import: the notebook's import cell does not import csv

        with open(path, "w", newline="") as f:
            writer = csv.writer(f, lineterminator="\n")
            writer.writerow(keys)
            for row in rows:
                writer.writerow([str(row.get(key, "")) for key in keys])

    def analyze_oligomers(self, fasta_path: str, oligomer_range: Tuple[int, int]) -> None:
        """Predict every copy number in ``oligomer_range`` for every FASTA sequence.

        For each (sequence, copy number) job this writes two files:
        - ``<results_dir>/<job>_detailed.csv``: one row per model, columns sorted.
        - ``<results_dir>/<job>_summary.csv``: aggregate statistics plus the first
          (top-ranked) model's values.

        Jobs whose prediction failed, or that produced no score files, are reported
        and skipped.

        Args:
            fasta_path: FASTA file with one or more protein sequences.
            oligomer_range: Inclusive ``(min_copies, max_copies)``.

        Raises:
            ValueError: if the range does not satisfy ``1 <= min_copies <= max_copies``.
        """
        min_copies, max_copies = oligomer_range
        # Validate before any work: 0 copies would submit an empty sequence.
        if not 1 <= min_copies <= max_copies:
            raise ValueError(
                f"oligomer_range must satisfy 1 <= min <= max, got {oligomer_range!r}"
            )

        sequences = self.read_fasta(fasta_path)

        for seq_id, sequence in sequences.items():
            print(f"\nProcessing sequence: {seq_id}")

            for n_copies in range(min_copies, max_copies + 1):
                print(f"Testing {n_copies}-mer")

                oligomer_seq = self.generate_oligomer_sequence(sequence, n_copies)
                job_name = self.create_job_name(seq_id, n_copies)

                prediction = self.run_prediction(oligomer_seq, job_name, n_copies)
                # run_prediction only sets "status" when it caught an exception. In
                # that case, don't report on whatever (possibly stale) files exist.
                status = str(prediction.get("status", ""))
                if status.startswith("failed"):
                    print(f"Skipping reports for {job_name}: {status}")
                    continue

                job_dir = os.path.join(self.results_dir, job_name)
                # Sorted by filename; assumes ColabFold's rank numbers sort lexically
                # so the first file is the top-ranked model — TODO confirm for >999 models.
                model_files = sorted(Path(job_dir).glob("*_scores_rank_*_*.json"))

                if not model_files:
                    print(f"No results found for {job_name}")
                    continue

                model_metrics = []
                for model_file in model_files:
                    metrics = self._analyze_single_model(model_file)
                    metrics["model_name"] = model_file.stem
                    model_metrics.append(metrics)

                detailed_output_path = os.path.join(self.results_dir, f"{job_name}_detailed.csv")
                # Union of keys so a column present in only some models is not dropped.
                detailed_keys = sorted(set().union(*model_metrics))
                self._write_csv(detailed_output_path, detailed_keys, model_metrics)

                top_model_metrics = model_metrics[0]
                summary_metrics = self._create_model_summary(model_metrics, top_model_metrics)

                summary_output_path = os.path.join(self.results_dir, f"{job_name}_summary.csv")
                self._write_csv(summary_output_path, sorted(summary_metrics), [summary_metrics])

                print(f"Created detailed report: {detailed_output_path}")
                print(f"Created summary report: {summary_output_path}")

In [None]:
def main(oligomer_range: Tuple[int, int] = (1, 2)) -> None:
    """Upload a FASTA file in Colab and run the oligomer analysis on it.

    Args:
        oligomer_range: Inclusive ``(min_copies, max_copies)`` to model for every
            sequence; defaults to monomer and dimer.

    Raises:
        ValueError: if no file was uploaded.
    """
    print("Configuring oligomer analysis...")

    config = OligomerConfig(
        num_recycles = 3,
        use_templates = False,
        msa_mode = "mmseqs2_uniref_env",
        pair_mode = "unpaired_paired",
        model_type = "alphafold2_multimer_v3",
        num_relax = 0,
        max_msa = None,
        num_seeds = 1,
        use_dropout = False,
        save_all = False,
        save_recycles = False,
        dpi = 200
    )

    print("\nPlease upload your FASTA file...")
    uploaded = files.upload()
    # An empty upload (dialog cancelled) previously surfaced as a bare IndexError.
    if not uploaded:
        raise ValueError("No FASTA file was uploaded; re-run the cell and select a file.")
    fasta_path = next(iter(uploaded))
    if len(uploaded) > 1:
        print(f"Multiple files uploaded; analysing only {fasta_path!r}.")

    analyzer = OligomerAnalysis(config)

    # analyze_oligomers writes CSV reports instead of returning a DataFrame.
    analyzer.analyze_oligomers(
        fasta_path=fasta_path,
        oligomer_range=oligomer_range
    )

    print(f"\nAnalysis complete. Results files have been created in the {analyzer.results_dir} directory.")
    print("You can find detailed reports in *_detailed.csv files")
    print("You can find summary reports in *_summary.csv files")

if __name__ == "__main__":
    results = main()  # main() returns None; name kept for any later cells that reference it

Setting up AlphaFold environment...

Please upload your FASTA file...


Saving test.fasta to test (14).fasta

Processing sequence: TEST_SEQUENCE_A
Testing 1-mer
2024-12-17 02:23:49,893 Running on GPU
2024-12-17 02:23:49,896 Found 5 citations for tools or databases
2024-12-17 02:23:49,897 Skipping TEST_SEQUENCE_A_oligomer_1_2d05c (already done)
2024-12-17 02:23:49,897 Done
Created detailed report: oligomer_analysis/TEST_SEQUENCE_A_oligomer_1_2d05c_detailed.csv
Created summary report: oligomer_analysis/TEST_SEQUENCE_A_oligomer_1_2d05c_summary.csv
Testing 2-mer
2024-12-17 02:23:49,904 Running on GPU
2024-12-17 02:23:49,907 Found 5 citations for tools or databases
2024-12-17 02:23:49,908 Skipping TEST_SEQUENCE_A_oligomer_2_74582 (already done)
2024-12-17 02:23:49,908 Done
Created detailed report: oligomer_analysis/TEST_SEQUENCE_A_oligomer_2_74582_detailed.csv
Created summary report: oligomer_analysis/TEST_SEQUENCE_A_oligomer_2_74582_summary.csv

Processing sequence: TEST_SEQUENCE_B
Testing 1-mer
2024-12-17 02:23:49,917 Running on GPU
2024-12-17 02:23:49,919 Fo