# MULAN inference

In [1]:
import os
os.environ["HF_HOME"] = '/workspace/data/transformers_cache'
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import torch
from torch.utils.data import DataLoader

from mulan.dataset import ProteinDataset, data_collate_fn_dynamic
from mulan.model import StructEsmForMaskedLM
from mulan.model_utils import auto_detect_base_tokenizer

## Preparation
To install repo run the command inside the folder MULAN:
```bash
git clone https://github.com/DFrolova/MULAN.git
cd MULAN
pip install -e .
```

In case of experimental structures, you firstly need to preprocess them with `pdbfixer` to restore missing atoms, remove water from the PDB file, etc.
Firstly, you need to install `pdbfixer`. It should be installed in a separate environment for structure processing:
```bash
conda create -n pdb_processing python=3.7
conda activate pdb_processing
conda install pdbfixer=1.8.1 -c conda-forge
conda install biopython=1.78
```
Then, you run the script, providing correct paths for `initial_structure_path` and `preprocessed_structure_path`.
```bash
python scripts/preprocess_experimental_structures.py
```

Also, you should download and put the foldseek binary file into the `mulan/bin` folder following the instructions provided in the [SaProt repo](https://github.com/westlake-repl/SaProt?tab=readme-ov-file#convert-protein-structure-into-structure-aware-sequence). 
Currently, they provide it [here](https://drive.google.com/file/d/1B_9t3n_nlj8Y3Kpc_mMjtMdY0OPYa7Re/view).
Do not forget to add the rights for execution for foldseek (`chmod +x bin/foldseek`).
If you do not need foldseek sequences (you use only MULAN based on ESM-2), you can pass `extract_foldseek_in_tokenizer=False` when initializing the `ProteinDataset`.
Thus, you do not need to download and use the foldseek binary file.

### Define data paths and properties

1) **Collect protein structures.**
Put the protein structures (either `.pdb` or `.cif` files) you need to encode into the folder. 
The whole path to this folder is `protein_data_path`.
If you need to extract certain chains from the structure (for experimental structures), you need to specify the chain name in the name of the file in the format: `{protein_id}_{chain}.pdb`.
Othervise, the first chain would be extracted.

In [5]:
# Data 
protein_data_path = '/workspace/data/docking/test_structures/PDB/tmp_preprocessed_structures/' # specify the path to the folder with pdb files you need to pass to the model
saved_dataset_path = '/workspace/data/docking/test_structures/PDB_dataset/' # specify the path where to save the preprocessed dataset
is_experimental_structure = True # flag for either AlphaFold structures or experimental ones

# Model
use_foldseek_sequences = False # True if use SaProt initialization for MULAN. Else False

In [3]:
# # Data 
# protein_data_path = <> # specify the path to the folder with pdb files you need to pass to the model
# saved_dataset_path = <> # specify the path where to save the preprocessed dataset
# is_experimental_structure = False # flag for either AlphaFold structures or experimental ones

# # Model
# use_foldseek_sequences = False # True if use SaProt initialization for MULAN. Else False

### Initialize dataset and dataloader

2) **Load the dataset and preprocess the data.**
By default, the dataset would not aggeragate proteins into batches (1 portein = 1 batch). 
However, you can pass `use_sorted_batching=True` to the dataset (and still pass `batch_size=1` to the dataloader!) to aggregate proteins with similar lengths into the batch (maximim `batch_limit` tokens per batch) for a faster inference. 
Further steps of code suppose that each batch contains one protein.

Note that initializing dataset requires firstly to preprocess and tokenize all provided protein structures.
The results will be stored in the `saved_dataset_path` folder.
Further, if you reuse this dataset, the preprocessing step is not required if you store the data inside the `saved_dataset_path`.

During preprocessing, foldseek sequences are also extracted. 
This is done because SaProt-initialized MULAN uses Foldseek.
If you use only ESM-2, you can pass argument `extract_foldseek_in_tokenizer=False` into the `ProteinDataset` class.

In [6]:
dataset = ProteinDataset(
    protein_data_path=protein_data_path, 
    saved_dataset_path=saved_dataset_path,
    use_foldseek_sequences=use_foldseek_sequences,
    # batch_limit=5000, 
    # use_sorted_batching=True,
    # extract_foldseek_in_tokenizer=False,
    is_experimental_structure=is_experimental_structure,
)

INIT ANGLES None


Tokenizing data:   0%|                                                                                                      | 0/1 [00:00<?, ?it/s]

structure <Structure id=>
structure[0] <Model id=0>


Tokenizing data: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.31it/s]
Preproc structural data: 1it [00:00, 2006.84it/s]
1it [00:00, 3187.16it/s]


Protein lengths before: 102 102
Protein lengths after: 102 102
Check for lengths 102 102 EEDTAILYPFTISGNDRNGNFTINFKGTPNSTNNGCIGYSYNGDWEKIEWEGSCDGNGNLVVEVPMSKIPAGVTSGEIQIWWHSGDLKMTDYKALEHHHHHH
self.angles 1 torch.Size([102, 11])
use_sorted_batching False
1 1 1


Preproc angles: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 3858.61it/s]

self.angles 1 (1, 104, 7)
self.plddts 1 (1, 102)
[[100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100.
  100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100.
  100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100.
  100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100.
  100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100.
  100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100.
  100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100. 100.
  100. 100. 100. 100.]]






### Load the model

Download the pre-trained model from [Zenodo](TODO) and put in into the checkpoint path, which you need to specify in the code. 

In [8]:
# Model
checkpoint_path = '/workspace/data/docking/lang_model/data/lora_train_results/mulan_small/' # path to the folder containing `config.json` and `model.safetensors` files.

In [9]:
# # Model
# checkpoint_path = <> # path to the folder containing `config.json` and `model.safetensors` files.

In [10]:
model = StructEsmForMaskedLM.from_pretrained(
    checkpoint_path,
    device_map="auto",
)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model.to(device)

Some weights of StructEsmForMaskedLM were not initialized from the model checkpoint at /workspace/data/docking/lang_model/data/lora_train_results/mulan_small/ and are newly initialized: ['contact_head.regression.bias', 'contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


StructEsmForMaskedLM(
  (esm): StructEsmModel(
    (embeddings): StructEsmEmbeddings(
      (word_embeddings): Embedding(33, 320, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 320, padding_idx=1)
      (struct_embeddings): StructEmbeddings(
        (MLP): Linear(in_features=7, out_features=320, bias=True)
        (encoder): EsmEncoder(
          (layer): ModuleList(
            (0): EsmLayer(
              (attention): EsmAttention(
                (self): EsmSelfAttention(
                  (query): Linear(in_features=320, out_features=320, bias=True)
                  (key): Linear(in_features=320, out_features=320, bias=True)
                  (value): Linear(in_features=320, out_features=320, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )
                (output): EsmSelfOutput(
                  (dense): Linear(in_features=320, out_features=320, bias=True)
                  (dropo

### Prepare the dataloader

In [11]:
esm_tokenizer = auto_detect_base_tokenizer(model.config, use_foldseek_sequences)

# Initialize dataloader
def data_collator(x): 
    if use_foldseek_sequences:
        one_letter_aas = esm_tokenizer.all_tokens[5:]
    else: 
        one_letter_aas = dataset.tokenizer.one_letter_aas

    return data_collate_fn_dynamic(x, 
        esm_tokenizer=esm_tokenizer,
        nan_value=np.deg2rad(dataset.tokenizer.nan_fill_value),
        mask_inputs=False,
        all_amino_acids=one_letter_aas,
        use_foldseek_sequences=use_foldseek_sequences)

dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=data_collator)

## Run inference

Run inference to extract residue-level embeddings from the last layer of MULAN. 
Each batch contains one protein sequence.

In [12]:
with torch.no_grad():
    for i, (batch, batch_names) in enumerate(zip(dataloader, dataset.protein_names)):
        struct_inputs = [struct_input.to(device) for struct_input in batch['struct_inputs']]
        # extract embeddings for each batch (1 sequence per batch)
        embeddings = model(
            input_ids=batch['input_ids'].to(device),
            attention_mask=batch['attention_mask'].to(device),
            struct_inputs=struct_inputs,
            output_hidden_states=True
        )['hidden_states'][-1]
        embeddings = embeddings[0][1:-1] # residue-level embeddings for the case of 1 protein per batch
        
        # If you want to get protein-level embeddings, you shoud perform average pooling:
        protein_embedding = embeddings.mean(dim=0)
        print(batch_names, embeddings.shape, protein_embedding.shape)

['5aot_A'] torch.Size([102, 320]) torch.Size([320])
