# Welcome to AMBROSIA: a carbohydrate-binding residue predictor


# Notebook initialization

In [3]:
#@title ## Load AMBROSIA Github repository
import os
if os.getcwd() == "/content":
    !git clone https://github.com/DSIMB/AMBROSIA.git
    %cd AMBROSIA

Cloning into 'AMBROSIA'...
remote: Enumerating objects: 48, done.[K
remote: Counting objects: 100% (48/48), done.[K
remote: Compressing objects: 100% (33/33), done.[K
remote: Total 48 (delta 14), reused 40 (delta 12), pack-reused 0 (from 0)[K
Receiving objects: 100% (48/48), 10.51 MiB | 3.89 MiB/s, done.
Resolving deltas: 100% (14/14), done.
/content/AMBROSIA


In [4]:
#@title ## Install necessary dependencies
!pip install torch fair-esm ankh plotly py3Dmol matplotlib

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Collecting ankh
  Downloading ankh-1.10.0-py3-none-any.whl.metadata (18 kB)
Collecting py3Dmol
  Downloading py3Dmol-2.4.2-py2.py3-none-any.whl.metadata (1.9 kB)
Collecting biopython<2.0,>=1.80 (from ankh)
  Downloading biopython-1.84-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting datasets<3.0.0,>=2.7.1 (from ankh)
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting sentencepiece<0.2.0,>=0.1.97 (from ankh)
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets<3.0.0,>=2.7.1->ankh)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets<3.0.0,>=2.7.1->ankh)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets<3.0.0,>=2.7

In [5]:
#@title ## Import Necessary libraries

import torch
import torch.nn as nn
import torch.nn.functional as F
import esm
import ankh
from plotly import graph_objects as go
from google.colab import drive
from collections import defaultdict
import numpy as np
import py3Dmol
from google.colab import files
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
from urllib.request import urlopen
# drive.mount('/content/drive')


In [7]:
#@title # Read input sequence
sequence_name = "Anti-sigma-W factor RsiW"
input_sequence = """
MGNNTGTILEIKGNKAIVMTNTCDFIAITRMPEMFVGQQVDLNNSAIKSKSNPLKYFAIAGMFVLILCSVLIYQLVKPSAVFAYVDVDINPSLELLIDKKANVIEVKTLNSDADALVKDIRLVNKSLTNAVKIIIKESQNKGFIRPDTKNAVLISASINPGKSISSAVSSEKILDVIVSDLQKTDFSIGAVSIKAEVVKVDPIERSEAVKNNISMGRYKLFEEITESDENIDIEKAKTEGLSKIIEEYETKEQEKTIASVDKDNSYKPVQDNKEILDKPKNSTTKDNPKVADNKKPENNNSQKYSNGNSNSSKSSAVKPNKAEDQFKASRSNSENNSSNNRDQSKNTNKKSSDEKKTLDQGSKPITTDDGTKSLNNKNNNKNNDEKPKNHPAKENKQENGNNNQQKSKEKNKK
"""

# Pasted sequences are often wrapped over several lines, may use Windows
# line endings or contain spaces, so drop *all* whitespace (not only '\n')
# and normalise to upper case before handing the sequence to the models.
input_sequence = "".join(input_sequence.split()).upper()
assert input_sequence, "input_sequence is empty"
assert input_sequence.isalpha(), "input_sequence must contain only amino-acid letters"

# Residue labels such as "M1", "G2", ... (numbering starts at start_index),
# used as x-axis categories in the probability plot.
start_index = 1
sequence_labels = [f"{aa}{i}" for i, aa in enumerate(input_sequence, start=start_index)]

# Generate embeddings

In [8]:
#@title ## Generate ESM-2 embedding

esm2_data = [(sequence_name, input_sequence)]

# ESM-2 650M (t33 = 33 layers); per-residue features are read from layer 33.
ESM2_REPR_LAYER = 33
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

_, _, batch_tokens = batch_converter(esm2_data)

# Extract per-residue representations (on CPU)
with torch.no_grad():
    esm2_output = model(batch_tokens, repr_layers=[ESM2_REPR_LAYER])
# Drop the first and last token positions (the special tokens framing the
# sequence) so one row per residue remains; the shape check cell verifies this.
esm2_embedding = esm2_output["representations"][ESM2_REPR_LAYER][0, 1:-1]

# Release the language model now: the next cell loads the ~7.5 GB Ankh model,
# and keeping ESM-2 bound to a name would hold both in RAM at the same time.
del model, esm2_output

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt


In [9]:
#@title ## Generate Ankh embedding

model, tokenizer = ankh.load_large_model()
model.eval()

# The tokenizer is called with is_split_into_words=True, so pass the
# sequence as a list with one item per residue.
ankh_data = [list(input_sequence)]
ankh_inputs = tokenizer.batch_encode_plus(ankh_data,
                                          add_special_tokens=True,
                                          padding=True,
                                          is_split_into_words=True,
                                          return_tensors="pt")
with torch.no_grad():
    ankh_output = model(input_ids=ankh_inputs['input_ids'],
                        attention_mask=ankh_inputs['attention_mask'])
# Drop the last position (presumably the end-of-sequence token added by
# add_special_tokens=True); the shape check cell verifies one row per residue.
ankh_embedding = ankh_output.last_hidden_state[0, :-1]

# Only the embedding is needed from here on; free the large model.
del model, ankh_output

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/2.85k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/31.2k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.58k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/849 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/7.52G [00:00<?, ?B/s]

In [11]:
#@title ## Check resulting embedding dimensions

print(esm2_embedding.shape, len(input_sequence), ankh_embedding.shape)
assert esm2_embedding.shape[0] == len(input_sequence) == ankh_embedding.shape[0], "Something went wrong during embedding generation"

torch.Size([413, 1280]) 413 torch.Size([413, 1536])


# AMBROSIA model definition and parameter loading

In [None]:
# Activation names accepted by CNN(activation_function=...) -> callable applied
# after each hidden convolution.
activation_string_to_function = {
    'relu': F.relu,
    'tanh': torch.tanh,
}

class CNN(nn.Module):
    """Stack of 1-D convolutions that outputs one logit per sequence position.

    Args:
        in_channels: Number of input features per position (embedding size).
        hidden_layers: Output channels of each hidden conv layer. A list or
            tuple gives several layers; any other value is treated as a single
            hidden layer size. A final conv always maps to 1 channel.
        dropout: Dropout probability applied after each hidden convolution.
        kernel_size: Convolution width; padding='same' keeps the length.
        activation_function: Key of activation_string_to_function; unknown
            names fall back to ReLU.
    """

    def __init__(self, in_channels, hidden_layers=(512,), dropout=0.1,
                 kernel_size=11, activation_function='relu'):
        super().__init__()
        if isinstance(hidden_layers, (list, tuple)):
            hidden_layers = list(hidden_layers)
        else:
            hidden_layers = [hidden_layers]

        # Channel sizes along the stack: input -> hidden layers -> 1 logit.
        channels = [in_channels, *hidden_layers, 1]
        self.conv_layers = nn.ModuleList(
            nn.Conv1d(c_in, c_out, kernel_size=kernel_size, padding='same')
            for c_in, c_out in zip(channels[:-1], channels[1:])
        )
        self.dropout = nn.Dropout(p=dropout)
        # Fall back to the ReLU *function* (a string default is not callable).
        self.activation_function = activation_string_to_function.get(activation_function, F.relu)

    def forward(self, x):
        """Forward pass through the convolutional layers.

        Args:
            x: Tensor of shape (in_channels, length) or
               (batch, in_channels, length).

        Returns:
            Logits of shape (length,) or (batch, length). Only the 1-channel
            dimension is squeezed, so length-1 inputs and batch size 1 keep
            their dimensions.
        """
        for conv in self.conv_layers[:-1]:
            x = self.activation_function(self.dropout(conv(x)))
        x = self.conv_layers[-1](x)
        return x.squeeze(-2)

In [None]:
# Input channel count of each CNN = hidden size of its embedding
# (the shape check printed 1280 for ESM-2 and 1536 for Ankh).
embedding_sizes = {
    "esm2": int(esm2_embedding.shape[1]),
    "ankh": int(ankh_embedding.shape[1]),
}
# Per model type: fold index -> loaded CNN, and fold index -> logits.
models = defaultdict(dict)
results = defaultdict(dict)

# Directory holding the per-fold AMBROSIA weights.
# NOTE(review): this is a Google Drive path, but `drive.mount(...)` is
# commented out in the imports cell; mount Drive first (or point MODEL_DIR at
# wherever the weights actually live), otherwise loading presumably fails on
# a fresh runtime.
MODEL_DIR = "/content/drive/MyDrive/ambrosia_models/models_yangfan"
param_paths = {
    "ankh": MODEL_DIR + "/ambrosia_ankh_fold{}.pt",
    "esm2": MODEL_DIR + "/ambrosia_esm2_fold{}.pt",
}


In [None]:
# Load the per-fold AMBROSIA CNNs and run each on its matching embedding.
N_FOLDS = 5
# Must match the kernel size of the saved weights, or load_state_dict fails
# on the conv weight shapes.
CNN_KERNEL_SIZE = 32
MODEL_TYPES = ["ankh", "esm2"]
embeddings_by_type = {"ankh": ankh_embedding, "esm2": esm2_embedding}

for model_type in MODEL_TYPES:
    for fold in range(N_FOLDS):
        model = CNN(in_channels=embedding_sizes[model_type], kernel_size=CNN_KERNEL_SIZE)
        param_path = param_paths[model_type].format(fold)
        # The files are fed to load_state_dict, so they should be plain state
        # dicts; weights_only=True refuses arbitrary pickled objects.
        state_dict = torch.load(param_path, map_location='cpu', weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        models[model_type][fold] = model

# Perform inference and store the per-fold logits.
with torch.no_grad():
    for model_type in MODEL_TYPES:
        # Embeddings are (length, channels); Conv1d wants (channels, length).
        embedding = embeddings_by_type[model_type].transpose(0, 1)
        for fold in range(N_FOLDS):
            results[model_type][fold] = models[model_type][fold](embedding)

In [None]:
# Combine per-fold logits into ensemble probabilities
def calculate_probabilities_and_classes(results, reduction="sum"):
    """Combine per-fold logits into ensemble probabilities.

    Despite its name, only probabilities are returned; classes come from
    thresholding them downstream (e.g. > 0.5).

    Args:
        results: Mapping model_type -> {fold: logits tensor}. The "ankh" and
            "esm2" entries also get their own ensemble.
        reduction: How fold logits are combined element-wise before the
            sigmoid: "sum" (default) or "mean". Both share the same 0.5
            decision boundary (sigmoid(x) > 0.5 iff x > 0); "mean" keeps
            probabilities on the scale of a single model.

    Returns:
        Tuple (all_probs, ankh_probs, esm2_probs) of numpy arrays.

    Raises:
        ValueError: if reduction is neither "sum" nor "mean".
    """
    combiners = {"sum": np.sum, "mean": np.mean}
    if reduction not in combiners:
        raise ValueError(f"reduction must be 'sum' or 'mean', got {reduction!r}")
    combine = combiners[reduction]

    def to_probabilities(logit_arrays):
        # Combine fold logits element-wise, then squash to (0, 1).
        return torch.sigmoid(torch.tensor(combine(logit_arrays, axis=0))).numpy()

    # Keep insertion order (model types, then folds) so the summation order,
    # and hence the float result, is deterministic.
    logits_by_type = {
        model_type: [fold_logits.numpy() for fold_logits in folds.values()]
        for model_type, folds in results.items()
    }
    all_logits = [arr for arrays in logits_by_type.values() for arr in arrays]

    return (
        to_probabilities(all_logits),
        to_probabilities(logits_by_type.get("ankh", [])),
        to_probabilities(logits_by_type.get("esm2", [])),
    )

all_probs, ankh_probs, esm2_probs = calculate_probabilities_and_classes(results)


def _add_probability_trace(fig, x, y, name, visible):
    """Add one filled line trace of per-residue probabilities to fig."""
    fig.add_trace(go.Scatter(
        x=x,
        y=y,
        mode='lines',
        fill='tozeroy',
        name=name,
        visible=visible
    ))


# Function to plot probabilities
def plot_probabilities(sequence_labels, meta_probs, esm2_probs, ankh_probs,
                       fold_probs_esm2, fold_probs_ankh, threshold=0.5, title=''):
    """Show per-residue probabilities as an interactive Plotly figure.

    Only the meta (all-model) curve is visible initially; the ESM-2 / Ankh
    ensembles and every fold's curve can be toggled from the legend.
    NOTE(review): the ensemble curves come from summed fold logits (see
    calculate_probabilities_and_classes), so "Average" in the legend names
    is loose wording.

    Args:
        sequence_labels: x-axis categories, one per residue (e.g. "M1").
        meta_probs, esm2_probs, ankh_probs: per-residue probabilities.
        fold_probs_esm2, fold_probs_ankh: sequences of per-fold probability
            arrays (any number of folds).
        threshold: y position of the dashed decision line.
        title: figure title.
    """
    fig = go.Figure()

    _add_probability_trace(fig, sequence_labels, meta_probs, 'Meta Average Probabilities', True)
    _add_probability_trace(fig, sequence_labels, esm2_probs, 'ESM2 Average Probabilities', 'legendonly')
    _add_probability_trace(fig, sequence_labels, ankh_probs, 'Ankh Average Probabilities', 'legendonly')

    # Interleave folds in the legend (ESM2 fold k, then Ankh fold k),
    # tolerating a different number of folds per model type.
    for fold in range(max(len(fold_probs_esm2), len(fold_probs_ankh))):
        if fold < len(fold_probs_esm2):
            _add_probability_trace(fig, sequence_labels, fold_probs_esm2[fold],
                                   f'ESM2 Fold {fold} Probabilities', 'legendonly')
        if fold < len(fold_probs_ankh):
            _add_probability_trace(fig, sequence_labels, fold_probs_ankh[fold],
                                   f'Ankh Fold {fold} Probabilities', 'legendonly')

    # Decision threshold across the full plot width (x in paper coordinates).
    fig.add_shape(
        type='line',
        xref='paper',
        x0=0,
        x1=1,
        yref='y',
        y0=threshold,
        y1=threshold,
        line=dict(color='Red', dash='dash'),
        name='Threshold'
    )

    # tickvals/ticktext only take effect with tickmode='array'; label every
    # 10th residue on the categorical x-axis.
    tick_labels = list(sequence_labels)[::10]
    fig.update_layout(
        title=title,
        xaxis_title='Amino Acid Position',
        yaxis_title='Probabilities',
        xaxis=dict(tickmode='array', tickvals=tick_labels, ticktext=tick_labels),
        yaxis=dict(range=[0, 1]),
        template='plotly_white'
    )

    fig.show()


# Individual fold probabilities: sigmoid of each fold's own logits
# (results values are already tensors; no numpy round-trip needed).
fold_probs_esm2 = [torch.sigmoid(logits).numpy() for logits in results["esm2"].values()]
fold_probs_ankh = [torch.sigmoid(logits).numpy() for logits in results["ankh"].values()]

# Plot all probabilities
plot_probabilities(sequence_labels, all_probs, esm2_probs, ankh_probs, fold_probs_esm2, fold_probs_ankh,
                   title=f"AMBROSIA carbohydrate-binding probabilities: {sequence_name}")

# [Experimental] View predictions on a 3D structure



In [None]:
# Function to read PDB file from URL
def fetch_pdb_from_url(url):
    """Download a PDB file and return its lines without line terminators.

    Args:
        url: Location of the PDB file (any scheme urllib's urlopen accepts).

    Returns:
        list[str]: the UTF-8 decoded content split with str.splitlines().
    """
    # Context manager closes the connection even if decoding fails.
    with urlopen(url) as response:
        return response.read().decode('utf-8').splitlines()

# Prompt user for input method
input_method = input("Enter 'upload' to upload a PDB file or 'url' to provide a URL: ").strip().lower()

if input_method == 'upload':
    uploaded = files.upload()
    if not uploaded:
        raise ValueError("No PDB file was uploaded.")
    pdb_filename = next(iter(uploaded))
    # splitlines() drops line terminators, matching fetch_pdb_from_url, so
    # both branches produce the same line format for the B-factor rewrite.
    with open(pdb_filename, 'r') as pdb_file:
        pdb_data = pdb_file.read().splitlines()
elif input_method == 'url':
    pdb_url = input("Enter the URL of the PDB file: ").strip()
    pdb_data = fetch_pdb_from_url(pdb_url)
else:
    raise ValueError("Invalid input method. Enter 'upload' or 'url'.")
# Modify B-factors in the PDB file: write the ensemble probability (x100)
# into columns 61-66 of every ATOM record, one probability per residue in
# file order.
# NOTE(review): assumes the structure's residues appear in the same order as
# input_sequence with none missing (true for full-length AlphaFold models);
# gaps in experimental structures would shift the mapping.
new_pdb_data = []
prob_index = 0
last_res_key = None
current_bfactor = None
for raw_line in pdb_data:
    # Normalise terminators: readlines() keeps '\n', splitlines() does not.
    line = raw_line.rstrip('\r\n')
    if not line.startswith("ATOM"):
        new_pdb_data.append(line)
        continue

    # A residue is identified by chain ID + residue number + insertion code
    # (PDB columns 22-27).
    res_key = line[21:27]
    if res_key != last_res_key:
        last_res_key = res_key
        if prob_index < len(all_probs):
            current_bfactor = f"{all_probs[prob_index]*100:6.2f}"
            prob_index += 1
        else:
            # More residues than predictions: leave every atom of the
            # remaining residues untouched.
            current_bfactor = None

    if current_bfactor is None:
        new_pdb_data.append(line)
    else:
        # Pad short records so the B-factor always lands in columns 61-66.
        padded = line.ljust(66)
        new_pdb_data.append(padded[:60] + current_bfactor + padded[66:])

# Save the modified structure: one record per line, trailing newline at EOF.
modified_pdb_filename = "modified_structure.pdb"
with open(modified_pdb_filename, 'w') as pdb_out:
    pdb_out.write('\n'.join(new_pdb_data) + '\n')

Enter 'upload' to upload a PDB file or 'url' to provide a URL: url
Enter the URL of the PDB file: https://alphafold.ebi.ac.uk/files/AF-B8I2U8-F1-model_v4.pdb


In [None]:
# Read back the PDB file whose B-factors now hold the predicted probabilities.
with open(modified_pdb_filename, 'r') as pdb_handle:
    pdb_content = pdb_handle.read()

# Render the structure with py3Dmol as a grey cartoon.
view = py3Dmol.view(width=800, height=600)
view.addModel(pdb_content, "pdb")
view.setStyle({'cartoon': {'color': 'grey'}})

# Overlay red sticks on residues predicted as binding (probability > 0.5).
# NOTE(review): residue numbers are taken as 1-based sequence positions,
# which holds for AlphaFold models but not necessarily for experimental
# structures with a different numbering.
for position, probability in enumerate(all_probs, start=1):
    if not probability > 0.5:
        continue
    view.addStyle({'resi': str(position)}, {'stick': {'color': 'red'}})

view.zoomTo()
view.show()