# ColabFold A3M to Boltz2 Multimer Prediction

This notebook demonstrates how to convert ColabFold-generated A3M monomer MSA files into paired MSAs for Boltz2 multimer structure predictions.

## üöÄ Quick Start: CLI One-Liner

**The fastest way** - predict directly from A3M files with a single command:

```bash
# Basic prediction
boltz2 --base-url http://localhost:8002 multimer-msa \
    chain_A.a3m chain_B.a3m \
    -c A,B \
    -o complex.cif

# Save all outputs (structure + scores + CSVs)
boltz2 --base-url http://localhost:8002 multimer-msa \
    chain_A.a3m chain_B.a3m \
    -c A,B \
    -o complex.cif \
    --save-all --save-csv
```

This command automatically:
1. Parses the A3M files
2. Auto-detects the best pairing mode (TaxID or UniRef ID)
3. Pairs sequences using greedy matching (like ColabFold)
4. Submits the prediction to Boltz2
5. Saves the output structure (and scores JSON with `--save-all`)

---

## Features
- **Auto-detection**: Automatically chooses the best pairing mode based on your A3M files
- **ColabFold compatible**: Works with standard ColabFold output (UniRef100 format)
- **TaxID support**: Also supports taxonomy-annotated A3M files for more precise pairing

## Prerequisites
- Boltz2 NIM running locally (default: `http://localhost:8002`)
- ColabFold-generated A3M files for each chain


In [1]:
# Import required libraries
import asyncio
from pathlib import Path
import tempfile

from boltz2_client import (
    Boltz2Client,
    Polymer,
    PredictionRequest,
    AlignmentFileRecord,
    convert_a3m_to_multimer_csv,
    create_paired_msa_per_chain,
    SpeciesMapper
)
from boltz2_client.a3m_to_csv_converter import A3MParser

print("‚úÖ Imports successful!")


‚úÖ Imports successful!


## 1. Understanding ColabFold A3M Formats

ColabFold generates A3M files in two main formats:

### Format 1: Standard ColabFold (UniRef100 IDs only)
```
>Query|-|Query
MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLG
>UniRef100_A0A2N5EEG3  340  0.994
MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLG
```

### Format 2: With Taxonomy Annotations (OX= field)
```
>Query|-|Query
MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLG
>UniRef100_P01116 Ras GTPase OS=Homo sapiens OX=9606
MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLG
```

Our converter **auto-detects** the format and chooses the best pairing strategy!


In [2]:
# Load the species mapper for TaxID lookups
speclist_path = Path('../boltz2_client/data/speclist.txt')
if speclist_path.exists():
    count = SpeciesMapper.load_speclist(speclist_path)
    print(f"‚úÖ Loaded {count:,} species mappings from speclist.txt")
else:
    print("‚ÑπÔ∏è Using built-in species mappings (54 common species)")

# Show stats
stats = SpeciesMapper.get_mapping_stats()
print(f"Total species available for TaxID lookup: {stats['total']:,}")


‚úÖ Loaded 27,836 species mappings from speclist.txt
Total species available for TaxID lookup: 27,890


## 2. Create Example A3M Files

Let's create two example A3M files representing a heterodimer (Chain A and Chain B) with taxonomy annotations.


In [3]:
# Define protein sequences for our heterodimer
CHAIN_A_SEQ = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
CHAIN_B_SEQ = "MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPAD"

print(f"Chain A: {len(CHAIN_A_SEQ)} residues")
print(f"Chain B: {len(CHAIN_B_SEQ)} residues")
print(f"Total complex: {len(CHAIN_A_SEQ) + len(CHAIN_B_SEQ)} residues")


Chain A: 65 residues
Chain B: 63 residues
Total complex: 128 residues


In [4]:
# Create A3M content with TaxID annotations (OX= fields)
# This simulates ColabFold output with taxonomy information

CHAIN_A_A3M = f""">Query|-|Query Chain A
{CHAIN_A_SEQ}
>UniRef100_P01116 Ras GTPase OS=Homo sapiens OX=9606 GN=KRAS
{CHAIN_A_SEQ}
>UniRef100_P01112 Ras GTPase OS=Mus musculus OX=10090 GN=Hras
{CHAIN_A_SEQ.replace('M', 'L', 1)}
>UniRef100_Q62636 Ras GTPase OS=Rattus norvegicus OX=10116 GN=Nras
{CHAIN_A_SEQ.replace('K', 'R', 1)}
>UniRef100_P79800 Ras GTPase OS=Bos taurus OX=9913 GN=KRAS
{CHAIN_A_SEQ.replace('E', 'D', 1)}
>UniRef100_P08644 Ras GTPase OS=Gallus gallus OX=9031 GN=KRAS
{CHAIN_A_SEQ.replace('S', 'T', 1)}
"""

CHAIN_B_A3M = f""">Query|-|Query Chain B
{CHAIN_B_SEQ}
>UniRef100_Q9Y6K9 Protein kinase OS=Homo sapiens OX=9606 GN=PKC
{CHAIN_B_SEQ}
>UniRef100_P23456 Protein kinase OS=Mus musculus OX=10090 GN=Pkc
{CHAIN_B_SEQ.replace('M', 'L', 1)}
>UniRef100_Q65432 Protein kinase OS=Bos taurus OX=9913 GN=PKC
{CHAIN_B_SEQ.replace('V', 'I', 1)}
>UniRef100_P98765 Protein kinase OS=Danio rerio OX=7955 GN=pkc
{CHAIN_B_SEQ.replace('E', 'D', 1)}
"""

print("‚úÖ Created A3M content with TaxID annotations (OX= fields)")
print(f"\nChain A MSA: 6 sequences")
print(f"Chain B MSA: 5 sequences")
print(f"\nChain A A3M preview:")
for line in CHAIN_A_A3M.split("\n")[:4]:
    print(f"  {line[:70]}{'...' if len(line) > 70 else ''}")


‚úÖ Created A3M content with TaxID annotations (OX= fields)

Chain A MSA: 6 sequences
Chain B MSA: 5 sequences

Chain A A3M preview:
  >Query|-|Query Chain A
  MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG
  >UniRef100_P01116 Ras GTPase OS=Homo sapiens OX=9606 GN=KRAS
  MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG


In [5]:
# Save A3M content to temporary files
tmpdir = tempfile.mkdtemp(prefix="boltz2_notebook_")
chain_a_file = Path(tmpdir) / "chain_A.a3m"
chain_b_file = Path(tmpdir) / "chain_B.a3m"

chain_a_file.write_text(CHAIN_A_A3M)
chain_b_file.write_text(CHAIN_B_A3M)

print(f"üìÅ Files saved to: {tmpdir}")
print(f"   - {chain_a_file.name}")
print(f"   - {chain_b_file.name}")


üìÅ Files saved to: /tmp/boltz2_notebook_sy50ei7_
   - chain_A.a3m
   - chain_B.a3m


## 3. Parse and Analyze A3M Files

Let's examine what the parser extracts from our A3M files, including TaxIDs.


In [6]:
# Parse A3M files
msa_a = A3MParser.parse_file(chain_a_file)
msa_b = A3MParser.parse_file(chain_b_file)

print("üìä Chain A MSA Analysis:")
print(f"   Sequences: {len(msa_a.sequences)}")
print(f"   Query: {msa_a.query_sequence[:40]}...")
print(f"   TaxIDs found: {msa_a.get_tax_ids()}")

print(f"\nüìä Chain B MSA Analysis:")
print(f"   Sequences: {len(msa_b.sequences)}")
print(f"   Query: {msa_b.query_sequence[:40]}...")
print(f"   TaxIDs found: {msa_b.get_tax_ids()}")

# Find common TaxIDs (these will be paired)
common_taxids = msa_a.get_tax_ids() & msa_b.get_tax_ids()
print(f"\nüîó Common TaxIDs (will be paired): {common_taxids}")
print(f"   Expected paired sequences: {len(common_taxids)} + 1 (query) = {len(common_taxids) + 1}")


üìä Chain A MSA Analysis:
   Sequences: 6
   Query: MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIV...
   TaxIDs found: {'9606', '9913', '10090', '10116', '9031'}

üìä Chain B MSA Analysis:
   Sequences: 5
   Query: MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLW...
   TaxIDs found: {'9606', '7955', '10090', '9913'}

üîó Common TaxIDs (will be paired): {'9606', '10090', '9913'}
   Expected paired sequences: 3 + 1 (query) = 4


## 4. Convert to Paired CSV Format

Now we convert the A3M files to Boltz2's paired CSV format. The converter will:
1. **Auto-detect** whether to use TaxID or UniRef ID pairing
2. Match sequences from the same organism across chains
3. Generate per-chain CSV files with matching keys


In [7]:
# Convert with auto-detection (default ColabFold-style behavior)
result = convert_a3m_to_multimer_csv(
    a3m_files={'A': chain_a_file, 'B': chain_b_file},
    pairing_strategy='greedy',  # Default, like ColabFold
    # use_tax_id=None (default) triggers auto-detection
)

print("‚úÖ Conversion complete!")
print(f"   Paired sequences: {result.num_pairs}")
print(f"   Chain IDs: {result.chain_ids}")
print(f"\nüìÑ Per-chain CSV preview (Chain A):")
print(result.csv_per_chain['A'][:400])


‚úÖ Conversion complete!
   Paired sequences: 4
   Chain IDs: ['A', 'B']

üìÑ Per-chain CSV preview (Chain A):
key,sequence
1,MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG
2,LKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG
3,MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG
4,MKTVRQDRLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG


## 5. Submit Prediction to Boltz2 NIM

Now we submit the prediction with the paired MSA to Boltz2.


In [8]:
# Create per-chain MSA structures for Boltz2
msa_per_chain = create_paired_msa_per_chain(result)

# Create polymers with paired MSA
protein_A = Polymer(
    id="A",
    molecule_type="protein",
    sequence=result.query_sequences['A'],
    msa=msa_per_chain['A']
)

protein_B = Polymer(
    id="B",
    molecule_type="protein",
    sequence=result.query_sequences['B'],
    msa=msa_per_chain['B']
)

print("‚úÖ Polymers created with paired MSA:")
print(f"   Chain A: {len(protein_A.sequence)} residues")
print(f"   Chain B: {len(protein_B.sequence)} residues")


‚úÖ Polymers created with paired MSA:
   Chain A: 65 residues
   Chain B: 63 residues


In [9]:
# Create prediction request
request = PredictionRequest(
    polymers=[protein_A, protein_B],
    recycling_steps=3,
    sampling_steps=50,  # Lower for demo; use 200 for production
    diffusion_samples=1
)

async def run_prediction():
    """Submit prediction to Boltz2 NIM."""
    client = Boltz2Client(base_url="http://localhost:8002")
    
    # Check server health
    print("üì° Checking Boltz2 NIM server...")
    health = await client.health_check()
    print(f"   Server status: {health.status}")
    
    # Submit prediction
    print("\n‚è≥ Running prediction (this may take a minute)...")
    response = await client.predict(request)
    
    print("\n‚úÖ PREDICTION COMPLETE!")
    print(f"   Structures returned: {len(response.structures)}")
    
    return response

# Run the prediction
response = await run_prediction()


üì° Checking Boltz2 NIM server...
   Server status: healthy

‚è≥ Running prediction (this may take a minute)...

‚úÖ PREDICTION COMPLETE!
   Structures returned: 1


In [10]:
# Save and display structure
if response.structures:
    structure = response.structures[0]
    cif_content = structure.cif_data if hasattr(structure, 'cif_data') else str(structure)
    
    # Save structure
    output_cif = Path(tmpdir) / "heterodimer_complex.cif"
    output_cif.write_text(cif_content)
    print(f"üíæ Structure saved to: {output_cif}")
    
    # Count atoms
    atom_count = cif_content.count('ATOM ')
    print(f"   Total atoms: {atom_count}")
    
    # Show CIF preview
    print("\nüìÑ CIF Preview:")
    print("-" * 60)
    for line in cif_content.split('\n')[:12]:
        print(f"  {line[:70]}")


üíæ Structure saved to: /tmp/boltz2_notebook_sy50ei7_/heterodimer_complex.cif
   Total atoms: 979

üìÑ CIF Preview:
------------------------------------------------------------
  format='mmcif' structure="data_model\n_entry.id model\n_struct.entry_i


## 6. Summary

This notebook demonstrated the complete workflow:

1. **Created A3M files** with taxonomy annotations (OX= fields)
2. **Parsed and analyzed** the MSA files to extract TaxIDs
3. **Auto-detected** the pairing mode (TaxID in this case)
4. **Converted** to Boltz2's paired CSV format
5. **Submitted prediction** to Boltz2 NIM
6. **Received structure** with both chains

### Key Features

| Feature | Description |
|---------|-------------|
| **Auto-detection** | Automatically chooses TaxID or UniRef ID pairing |
| **ColabFold compatible** | Works with standard ColabFold output |
| **Greedy pairing** | Matches sequences in ‚â•2 chains (like ColabFold) |
| **27,890 species** | Built-in TaxID mapping database |

### CLI Alternative: One-Command Prediction

The easiest way - predict directly from A3M files with a single command:

```bash
# All-in-one: Convert + Predict in one step
boltz2 --base-url http://localhost:8002 multimer-msa \
    chain_A.a3m chain_B.a3m \
    -c A,B \
    -o complex.cif

# Save all outputs (structure + confidence scores + CSVs)
boltz2 --base-url http://localhost:8002 multimer-msa \
    chain_A.a3m chain_B.a3m \
    -c A,B \
    -o complex.cif \
    --save-all --save-csv

# With custom quality settings
boltz2 --base-url http://localhost:8002 multimer-msa \
    chain_A.a3m chain_B.a3m \
    -c A,B \
    -o complex.cif \
    --sampling-steps 400

# Force UniRef ID pairing for standard ColabFold output
boltz2 --base-url http://localhost:8002 multimer-msa \
    chain_A.a3m chain_B.a3m \
    -c A,B \
    --pairing-mode uniref
```

**Output files with `--save-all --save-csv`:**
```
output/
‚îú‚îÄ‚îÄ complex.cif              # 3D structure
‚îú‚îÄ‚îÄ complex.scores.json      # Confidence, pLDDT, pTM scores
‚îú‚îÄ‚îÄ complex_chain_A.csv      # Paired MSA
‚îî‚îÄ‚îÄ complex_chain_B.csv
```

### Multi-Endpoint Load Balancing

For high-throughput processing with multiple GPUs/NIMs:

```bash
# Use 4 Boltz2 NIMs with automatic load balancing
boltz2 --multi-endpoint \
    --base-url "http://gpu1:8000,http://gpu2:8000,http://gpu3:8000,http://gpu4:8000" \
    multimer-msa chain_A.a3m chain_B.a3m -c A,B

# Strategies: round_robin, least_loaded (default), random
```

### CLI Alternative: Convert Only

If you just want to convert A3M files to CSV without prediction:

```bash
boltz2 convert-msa chain_A.a3m chain_B.a3m -c A,B -o paired.csv
```


In [None]:
# Output location
print(f"üìÅ All output files saved to: {tmpdir}")
print(f"\nFiles:")
for f in Path(tmpdir).iterdir():
    print(f"   - {f.name}")
print("\n‚úÖ Notebook complete! You can delete the temp directory when done.")
