In [None]:
from ProtMamba_ssm.core import *
from ProtMamba_ssm.dataloaders import *
from ProtMamba_ssm.utils import *
from ProtMamba_ssm.modules import *
import torch
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
import pickle
import pandas as pd
import os

In [None]:
from Bio import Align

def delete_masks(seq):
    """Return `seq` with the FIM mask tokens <mask-1> ... <mask-5> and the <cls> token removed.

    Tokens are stripped in the order <mask-1>, ..., <mask-5>, <cls>.
    """
    tokens_to_strip = [f"<mask-{k}>" for k in range(1, 6)] + ["<cls>"]
    for token in tokens_to_strip:
        seq = seq.replace(token, "")
    return seq

# Shared global pairwise aligner used by align_sequences.
aligner = Align.PairwiseAligner()
aligner.mode = 'global'
# Unit scoring: +1 per match, -1 per mismatch, and -1 for both opening and extending
# a gap (gap_score sets all open and extend gap scores at once), i.e. -1 per gap position.
aligner.match_score = 1
aligner.mismatch_score = -1
aligner.gap_score = -1

def align_sequences(ref_seq, query_seq, print_alignments=False):
    """Globally align two sequences and return their normalized mismatch fraction.

    Uses the module-level ``aligner``. The best (first) alignment is turned into two gapped
    strings of equal length. The distance is the fraction of aligned columns in which the
    two strings differ; a gap facing a residue counts as a difference.

    Args:
        ref_seq: reference sequence string.
        query_seq: query sequence string.
        print_alignments: if True, print the score and the best alignment.

    Returns:
        (distance, gapped_ref, gapped_query)
    """
    best = aligner.align(ref_seq, query_seq)[0]
    if print_alignments:
        print("Score = %.1f:" % best.score)
        print(best)
    gapped_ref = best[0]
    gapped_query = best[1]
    assert len(gapped_ref) == len(gapped_query)
    n_different = (np.array(list(gapped_ref)) != np.array(list(gapped_query))).sum()
    return n_different / len(gapped_ref), gapped_ref, gapped_query

# Sanity check: a diverged pair next to a self-alignment (distance 0)
seq1 = "ACDEFGHIKLMNPQRST"
seq2 = "ACDEEGHKLMNQRSTVWY"
align_sequences(seq1, seq2), align_sequences(seq1, seq1)

In [None]:
import string
from Bio import SeqIO
import pyhmmer

# Amino-acid alphabet shared by every pyhmmer digitize / Builder / Background / Pipeline call below
alphabet = pyhmmer.easel.Alphabet.amino()

# Translation table that deletes every lowercase letter plus "." and "*" in a single
# str.translate pass; uppercase residues and "-" are left untouched.
deletekeys = dict.fromkeys(string.ascii_lowercase + ".*")
translation = str.maketrans(deletekeys)

def remove_insertions(sequence: str) -> str:
    """Return `sequence` without a3m insertion characters (lowercase letters, '.', '*').

    Needed to load aligned sequences from an MSA so that all rows share the same columns.
    """
    stripped = sequence.translate(translation)
    return stripped

def read_msa(filename: str):
    """Read an MSA (FASTA/a3m) as a list of (description, sequence) with insertions removed.

    Each sequence is passed through remove_insertions, so lowercase letters, '.' and '*'
    are dropped.
    """
    entries = []
    for record in SeqIO.parse(filename, "fasta"):
        entries.append((record.description, remove_insertions(str(record.seq))))
    return entries

def read_msa_unaligned(filename: str):
    """Read an MSA file and return its sequences ungapped, as (description, sequence) tuples.

    For each FASTA record, '.', '-' and '*' are removed and the remaining characters are
    upper-cased, so lowercase a3m insertions are kept as ordinary residues.
    """
    def _ungap(raw):
        for char in (".", "-", "*"):
            raw = raw.replace(char, "")
        return raw.upper()

    return [(rec.description, _ungap(str(rec.seq))) for rec in SeqIO.parse(filename, "fasta")]

def check_msa(msa):
    """Assert that every sequence in `msa` is unique.

    `msa` is a list of (description, sequence) entries such as read_msa returns; only the
    second element of each entry is compared. Raises AssertionError on duplicates.
    """
    distinct = {entry[1] for entry in msa}
    assert len(distinct) == len(msa), "There are repeated sequences in the MSA"
    
def make_hmm_from_a3m_msa(msa_filepath, hmm_filename=None):
    """Build a pyhmmer profile HMM from an a3m MSA.

    Sequences are read with read_msa (insertions removed) and named by their 0-based index.
    They are digitized with the module-level amino-acid `alphabet` and fed to
    pyhmmer.plan7.Builder.build_msa.

    Args:
        msa_filepath: path to the a3m file.
        hmm_filename: if given, the HMM is also written to "<hmm_filename>.hmm".

    Returns:
        The HMM object returned by Builder.build_msa.
    """
    records = read_msa(msa_filepath)
    # Duplicate sequences are not rejected here (check_msa(records) is not called).
    text_seqs = []
    for idx, (_, aligned_seq) in enumerate(records):
        text_seqs.append(pyhmmer.easel.TextSequence(name=str(idx).encode("utf-8"), sequence=aligned_seq))
    digital_msa = pyhmmer.easel.TextMSA(name=b"msa", sequences=text_seqs).digitize(alphabet)
    builder = pyhmmer.plan7.Builder(alphabet)
    background = pyhmmer.plan7.Background(alphabet)
    hmm, _, _ = builder.build_msa(digital_msa, background)
    if hmm_filename is not None:
        with open(f"{hmm_filename}.hmm", "wb") as handle:
            hmm.write(handle)
    return hmm

def align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_path=None, sequences_list=None):
    """Search sequences with a profile HMM (pyhmmer) and return the score of each hit.

    Exactly one source of sequences is used: `sequences_list` (an iterable of plain sequence
    strings) takes precedence over `sequences_path` (an MSA file read with
    read_msa_unaligned, i.e. gaps removed and upper-cased).

    The bias filter is disabled and F1=F2=F3=1.0, so the pipeline's prefilters are opened up.

    Args:
        hmm: profile HMM, e.g. from make_hmm_from_a3m_msa.
        sequences_path: path to an MSA/FASTA file of sequences to score.
        sequences_list: sequences to score (takes precedence over `sequences_path`).

    Returns:
        dict mapping sequence string -> {"score": hit score, "evalue": hit E-value}.
        Sequences without a hit are absent. Identical sequences share one entry.

    Raises:
        NotImplementedError: if neither `sequences_list` nor `sequences_path` is given.
    """
    # Normalize both input kinds to one flat list of sequence strings, so hit names
    # (positions in this list) always map back to a full sequence.
    if sequences_list is not None:
        seqs = list(sequences_list)
    elif sequences_path is not None:
        seqs = [seq for _, seq in read_msa_unaligned(sequences_path)]
    else:
        raise NotImplementedError("Missing sequences to align/score")
    text_seqs = [pyhmmer.easel.TextSequence(name=str(i).encode("utf-8"), sequence=seq) for i, seq in enumerate(seqs)]
    seq_block = pyhmmer.easel.TextSequenceBlock(text_seqs).digitize(alphabet)
    background = pyhmmer.plan7.Background(alphabet)
    pipeline = pyhmmer.plan7.Pipeline(alphabet, background=background, bias_filter=False, F1=1.0, F2=1.0, F3=1.0)
    hits = pipeline.search_hmm(hmm, seq_block)
    if len(hits) != len(seqs):
        print(f"Number of hits: {len(hits)} is different from the number of sequences in the MSA: {len(seqs)}")
    all_hits = {}
    for hit in hits:
        seq = seqs[int(hit.name.decode("utf-8"))]
        all_hits[seq] = {"score": hit.score, "evalue": hit.evalue}
    return all_hits

In [None]:
from transformers import AutoTokenizer, EsmForProteinFolding
from Bio.PDB import *
import tmscoring

# Shared Bio.PDB parser used to load reference, query and experimental structures below
pdb_parser = PDBParser()

def align_structures_CEalign(path_ref, path_query, key=None):
    """Superimpose a query structure onto a reference with CE-align and return the RMSD.

    The first model of each file is passed to the aligner. Afterwards the query structure
    is saved next to the query file as "<stem>_aligned[-<key>]<ext>". CEAligner.align is
    presumably what moves the query coordinates in place.

    Args:
        path_ref: path to the reference PDB file.
        path_query: path to the query PDB file (str or Path).
        key: optional tag appended to the output name as "-<key>"; None or "" adds nothing.

    Returns:
        The RMSD reported by the aligner (``CEAligner.rms``).
    """
    # Explicit import: `cealign` is not brought into scope by the notebook's imports.
    from Bio.PDB.cealign import CEAligner

    # Previously key=None crashed on string concatenation.
    suffix = f"-{key}" if key else ""
    ref_structure = pdb_parser.get_structure("reference", path_ref)
    query_structure = pdb_parser.get_structure("query", path_query)

    ce_aligner = CEAligner()
    ce_aligner.set_reference(ref_structure[0])
    ce_aligner.align(query_structure[0])
    rmsd = ce_aligner.rms
    # splitext keeps dots in directory names intact (str.split(".") did not).
    root, ext = os.path.splitext(str(path_query))
    io = PDBIO()
    io.set_structure(query_structure)
    io.save(f"{root}_aligned{suffix}{ext}")
    return rmsd

def align_structures_TMscore(path_ref, path_query, key=None):
    """Align a query structure to a reference with `tmscoring`; return (TM-score, RMSD).

    The alignment parameters are optimised first. Both the TM-score and the RMSD are then
    evaluated at the optimum. The aligned result is written (``appended=True``) next to the
    query file as "<stem>_aligned[-<key>]<ext>".

    Args:
        path_ref: path to the reference PDB file.
        path_query: path to the query PDB file (str or Path).
        key: optional tag appended to the output name as "-<key>"; None or "" adds nothing.

    Returns:
        (tmscore, rmsd)
    """
    # Previously key=None crashed on string concatenation.
    suffix = f"-{key}" if key else ""
    alignment = tmscoring.TMscoring(path_ref, path_query)
    alignment.optimise()
    optimum = alignment.get_current_values()
    tmscore = alignment.tmscore(**optimum)
    # RMSD of the structures superimposed according to the TM-score optimum
    rmsd = alignment.rmsd(**optimum)
    # splitext keeps dots in directory names intact (str.split(".") did not).
    root, ext = os.path.splitext(str(path_query))
    alignment.write(outputfile=f"{root}_aligned{suffix}{ext}", appended=True)
    return tmscore, rmsd

def compute_structure(seq, model, struct_path, ref_struct_path, alignment_func=align_structures_TMscore, max_length=750):
    """Fold `seq` with ESMFold, save the PDB and compare it with a reference structure.

    A sequence is not folded when either of these holds:
      - it is longer than `max_length`;
      - it contains a mask/special token, "." or "-".
    In both cases a reason is printed, no file is written and every metric is returned as 0.

    Args:
        seq: amino-acid sequence string.
        model: folding model exposing ``infer`` and ``output_to_pdb`` (ESMFold).
        struct_path: where the predicted PDB is written.
        ref_struct_path: reference PDB passed to `alignment_func`.
        alignment_func: callable ``(ref_path, query_path, key=...) -> (tmscore, rmsd)``.
            align_structures_CEalign returns only an RMSD, so it cannot be used here as is.
        max_length: longest sequence that is folded (750 was the former hard-coded limit).

    Returns:
        (ptm, pae, mean_plddt, pos_plddt, rmsd, tmscore). When the sequence is folded,
        `pae` and `pos_plddt` are numpy arrays; otherwise all six values are 0.
    """
    forbidden_tokens = list(MASK_TO_ID.keys()) + ["<eos>", "<pad>", "<unk>", "<mask>", "<cls>", "<null_1>", ".", "-"]
    if len(seq) > max_length:
        print(f"Sequence {struct_path} is too long")
        return 0, 0, 0, 0, 0, 0
    if any(token in seq for token in forbidden_tokens):
        # Previously reported as "too long", which was misleading
        print(f"Sequence {struct_path} contains special tokens or gaps and was not folded")
        return 0, 0, 0, 0, 0, 0

    with torch.no_grad():
        output = model.infer([seq])
    pdb = model.output_to_pdb(output)
    ptm = output["ptm"].item()
    pae = output["predicted_aligned_error"].cpu().numpy()
    # pLDDT averaged only over atoms flagged in atom37_atom_exists:
    # over all atoms for the mean, and per position (summing over axis 2) for pos_plddt.
    atom_mask = output["atom37_atom_exists"]
    masked_plddt = output["plddt"] * atom_mask
    mean_plddt = (masked_plddt.sum(dim=(1, 2)) / atom_mask.sum(dim=(1, 2))).item()
    pos_plddt = (masked_plddt.sum(dim=(2,)) / atom_mask.sum(dim=(2,))).cpu().numpy()
    with open(struct_path, "w") as f:
        f.write(pdb[0])
    tmscore, rmsd = alignment_func(ref_struct_path, struct_path, key="")
    return ptm, pae, mean_plddt, pos_plddt, rmsd, tmscore

## Import generated sequences

In [None]:
# Pickle of generated sequences to analyse; its name is reused for every output file below
name_data = "check-131k(13-18)_gen_seqs_full"
# NOTE: pickle.load can execute arbitrary code; only open pickles produced by this project
with open(f"figures/generated_sequences/{name_data}.pkl", "rb") as f:
    gen_seqs = pickle.load(f)
# Answer "y" (case-insensitive, surrounding spaces ignored) for a FIM-generated dataset
fim_generation = input("Is it a FIM generated dataset? (y/n): ").strip().lower() == "y"

# NOTE(review): `is_fim` is not read in the cells shown here; `fim_generation` drives the analysis
is_fim = True
dataset_name = "encoded_MSAs_test.pkl"
fim_strategy = "multiple_span"
# Number of natural MSA sequences sampled per family for the baseline comparisons
num_natural = 100
# Load the encoded MSA dataset (file `dataset_name`; absolute, machine-specific data path)
dataset = Uniclust30_Dataset(dataset_name,
                             filepath="/data1/common/OpenProteinSet/",
                             sample=False,
                             max_msa_len=-1,
                             max_patches=5,
                             mask_fraction=0.2,
                             fim_strategy=fim_strategy,
                             max_position_embeddings=2048,
                             add_position_ids="1d")

## Pairwise align generated sequences with natural sequences

In [None]:
if fim_generation:
    # FIM datasets: count residues changed inside the in-filled spans
    for j in tqdm(gen_seqs.keys()):
        for dict_seqs in gen_seqs[j].values():
            for values in dict_seqs.values():
                new = delete_masks(values["generated_input_fim"])
                orig = delete_masks(values["original_input_fim"])
                # (number of differing positions, total length of the generated FIM spans)
                values["fim_distance/fim_size"] = (sum(np.array(list(new)) != np.array(list(orig))), len(new))
else:
    # Non-FIM datasets:
    #   - compare every generated sequence with the natural MSA used as context;
    #   - build a natural-vs-natural baseline from a random subset of that MSA.
    all_hamming_ctx = {}
    for j in tqdm(gen_seqs.keys()):
        data = dataset[j]
        family_id = dataset.cluster_names[j]
        # Decode on CPU directly; a GPU round-trip is not needed just to decode tokens
        all_context = decode_sequence(data["input_ids"].cpu().numpy())
        list_sequences_msa = [reorder_masked_sequence(elem + "<cls>") for elem in all_context.split("<cls>")[1:-1]]

        # Distinct indices, never more than the MSA holds. Sampling with replacement made
        # duplicates collapse, so fewer than `num_natural` sequences were used.
        n_natural = min(num_natural, len(list_sequences_msa))
        rd_idxs = np.random.choice(len(list_sequences_msa), n_natural, replace=False)
        all_hamming_ctx[family_id] = []
        for idx in sorted(rd_idxs):  # keep MSA order
            tmp_seq = list_sequences_msa[idx]
            all_hamming_ctx[family_id].append([
                align_sequences(ctx_seq, tmp_seq, print_alignments=False)[0]
                for ctx_seq in list_sequences_msa
                if ctx_seq != tmp_seq
            ])

        for dict_seqs in gen_seqs[j].values():
            for seq, values in dict_seqs.items():
                # Same for every context sequence, so compute it once
                gen_seq = reorder_masked_sequence(seq)
                values["hamming"] = np.array([
                    align_sequences(ctx_seq, gen_seq, print_alignments=False)[0]
                    for ctx_seq in list_sequences_msa
                ])
    with open(f"figures/generated_sequences/all_hamming_ctx_{name_data}.pkl", "wb") as f:
        pickle.dump(all_hamming_ctx, f)

### Make dataframe with all sequences

In [None]:
# One row per generated sequence, keyed by the generated string.
# NOTE(review): because rows are keyed by `seq`, an identical sequence produced under several
# (family, sampling-parameter) combinations keeps only the fields of the last one visited.
all_gen_seqs = {}
for j, runs_by_params in tqdm(gen_seqs.items()):
    for params, dict_seqs in runs_by_params.items():
        n_seqs_ctx, temperature, top_k, top_p = params
        for seq, values in dict_seqs.items():
            perplexity = values["perplexity"]
            row = {
                "family": j,
                "family_id": dataset.cluster_names[j],
                "perplexity": perplexity,
                "n_seqs_ctx": n_seqs_ctx,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p,
            }
            all_gen_seqs[seq] = row
            if not fim_generation:
                row["hamming"] = values["hamming"]
                row["min_hamming"] = np.min(values["hamming"])
                row["generated_sequence"] = reorder_masked_sequence(seq)
                row["sequence_length"] = len(row["generated_sequence"])
                continue
            # FIM: rebuild full sequences from `original_input` plus the original/generated spans
            row["original_input"] = values["original_input"]
            row["original_input_fim"] = values["original_input_fim"]
            row["generated_input_fim"] = values["generated_input_fim"]
            row["fim_distance"], row["fim_size"] = values["fim_distance/fim_size"]
            row["original_sequence"] = reorder_masked_sequence(values["original_input"] + values["original_input_fim"])
            row["generated_sequence"] = reorder_masked_sequence(values["original_input"] + values["generated_input_fim"])
            assert len(row["original_sequence"]) == len(row["generated_sequence"])

df = pd.DataFrame.from_dict(all_gen_seqs, orient="index").reset_index(drop=True)
df.head()

In [None]:
# Summary statistics of the numeric columns of the generated-sequence table
df.describe()

In [None]:
# Checkpoint the table before HMMER scoring
df.to_pickle(f"figures/generated_sequences/dataframe_{name_data}.pkl")

## HMMER scores

In [None]:
def _store_hmmer_scores(table, seq_column, sequences, scores, suffix):
    """Write HMMER results into the columns "score_<suffix>" / "evalue_<suffix>" of `table`.

    Every row whose `seq_column` value is in `sequences` receives the score and E-value found
    in `scores` (as returned by align_and_score_sequences_in_a3m_with_hmm). Sequences without
    a hit get score 0 and E-value 1. `table` is modified in place. This vectorised form
    replaces one full-table `.loc` scan per sequence.
    """
    rows = table[seq_column].isin(set(sequences))
    if not rows.any():
        return
    matched = table.loc[rows, seq_column]
    table.loc[rows, f"score_{suffix}"] = matched.map(lambda s: scores[s]["score"] if s in scores else 0)
    table.loc[rows, f"evalue_{suffix}"] = matched.map(lambda s: scores[s]["evalue"] if s in scores else 1)


families = df["family_id"].unique()
all_scores_ctx = {}
for family_id in tqdm(families):
    msa_filepath = f"figures/pdb_structures/msas/{family_id}.a3m"
    # Check for the file explicitly instead of a bare `except:`, which hid every other error
    if not os.path.exists(msa_filepath):
        raise FileNotFoundError(f"Missing MSA of family {family_id}")
    hmm = make_hmm_from_a3m_msa(msa_filepath)
    # Score every generated sequence of this family against the family HMM
    family_df = df[df["family_id"] == family_id]
    sequences = family_df["generated_sequence"].values
    scores = align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_list=sequences)
    _store_hmmer_scores(df, "generated_sequence", sequences, scores, "gen")
    if fim_generation:
        sequences = family_df["original_sequence"].values
        scores = align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_list=sequences)
        _store_hmmer_scores(df, "original_sequence", sequences, scores, "orig")
    else:
        # Baseline: natural MSA sequences scored with their own family HMM
        scores = align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_path=msa_filepath)
        all_scores_ctx[family_id] = {"score": [hit["score"] for hit in scores.values()],
                                     "evalue": [hit["evalue"] for hit in scores.values()]}
if not fim_generation:
    with open(f"figures/generated_sequences/all_hmmer_ctx_{name_data}.pkl", "wb") as f:
        pickle.dump(all_scores_ctx, f)

In [None]:
# Persist the table now that it also holds the HMMER score_*/evalue_* columns
df.to_pickle(f"figures/generated_sequences/dataframe_{name_data}.pkl")
df.head()

## Structure prediction (ESMFold)

In [None]:
# Load ESMFold from the Hugging Face checkpoint and move it to the first GPU.
# NOTE(review): esm_tokenizer is not used in the cells shown here.
esmfold_checkpoint = "facebook/esmfold_v1"
esm_tokenizer = AutoTokenizer.from_pretrained(esmfold_checkpoint)
model = EsmForProteinFolding.from_pretrained(esmfold_checkpoint, low_cpu_mem_usage=True).cuda("cuda:0")
# Only the `esm` submodule is cast to fp16; the rest of the model keeps its dtype
model.esm = model.esm.half()
# Allow TF32 matmuls on GPUs that support them (faster, slightly lower precision)
torch.backends.cuda.matmul.allow_tf32 = True
# Optional, disabled: model.trunk.set_chunk_size(64), presumably to reduce memory use; confirm

In [None]:
from pathlib import Path
# Load the OpenFold reference structures, keyed by file stem (the family id)
structures_dir = Path("figures/pdb_structures/")
structures_paths = {}
for pdb_path in structures_dir.glob("*.pdb"):
    structures_paths[pdb_path.stem] = pdb_path
print("Reference structures: ", *structures_paths)

In [None]:
# Check whether ESMFold structures of the family representatives were already computed.
# Expected files: "<name>_esmfold.pdb" and its TM-aligned copy "<name>_esmfold_aligned.pdb".
representatives_dir = Path("figures/pdb_structures/esmfold/representatives/")
representatives_paths = list(path.stem for path in representatives_dir.glob("*.pdb"))

missing_representatives = False
for name in structures_paths.keys():
    if (name + "_esmfold" not in representatives_paths) or (name + "_esmfold_aligned" not in representatives_paths):
        print(f"Missing representative structure for {name}")
        missing_representatives = True

if missing_representatives and input("Do you want to compute the representative structures using esmfold? (y/n): ") == "y":
    ppb = PPBuilder()

    # Sequence and per-residue confidence of chain A of every reference structure
    openproteinset = {}
    for name, ref_struct_path in structures_paths.items():
        chain = pdb_parser.get_structure("exp", ref_struct_path)[0]["A"]
        peptides = ppb.build_peptides(chain)
        sequence = "".join(str(peptide.get_sequence()) for peptide in peptides)
        # Presumably the reference B-factor column stores pLDDT (0-100); rescaled to [0, 1].
        # Residues without a CA atom (e.g. hetero groups) are skipped instead of raising KeyError.
        plddt = [residue["CA"].get_bfactor() / 100 for residue in chain if "CA" in residue]
        openproteinset[name] = {"sequence": sequence, "plddt": plddt}
        assert len(openproteinset[name]["sequence"]) == len(openproteinset[name]["plddt"])

    representatives_dir.mkdir(parents=True, exist_ok=True)
    conf = {}
    for name, entry in openproteinset.items():
        ref_struct_path = str(structures_paths[name])
        # Named "<name>_esmfold.pdb" so that the alignment output "<name>_esmfold_aligned.pdb"
        # matches the names checked above (it used to be "<name>.pdb", which never matched).
        struct_path = str(representatives_dir / f"{name}_esmfold.pdb")
        ptm, pae, mean_plddt, pos_plddt, rmsd, tmscore = compute_structure(entry["sequence"], model, struct_path, ref_struct_path)
        conf[name] = {"ptm": ptm, "pae": pae, "mean_plddt": mean_plddt, "pos_plddt": pos_plddt, "rmsd": rmsd, "tmscore": tmscore}
    with open("figures/generated_sequences/all_structures_representatives.pkl", "wb") as f:
        pickle.dump(conf, f)

In [None]:
# df = pd.read_pickle(f"figures/generated_sequences/dataframe_{name_data}.pkl")
# with open(f"figures/generated_sequences/all_structures_ctx_{name_data}.pkl", "rb") as f:
#     all_structures_ctx = pickle.load(f)
# families = df["family_id"].unique()

In [None]:
def _masked_plddt(pos_plddt, inds_masks, fim_size, sequence_length):
    """Mean per-position pLDDT over the FIM (in-filled) spans of one folded sequence.

    Args:
        pos_plddt: per-position pLDDT returned by compute_structure (positions on axis 1),
            or the scalar 0 it returns for sequences it did not fold.
        inds_masks: (start, end) slice bounds of the FIM spans, from
            reorder_masked_sequence(..., return_ids=True).
        fim_size: expected total number of FIM positions (asserted).
        sequence_length: expected length of axis 1 of `pos_plddt` (asserted).

    Returns:
        The mean over all FIM positions, or 0 when no structure was predicted.
    """
    if not isinstance(pos_plddt, np.ndarray):
        return 0
    assert pos_plddt.shape[1] == sequence_length
    span_values = [value for span in inds_masks for value in pos_plddt[0, span[0]:span[1]]]
    assert len(span_values) == fim_size
    return np.mean(span_values)


dir_path = f"figures/pdb_structures/esmfold/generated_{name_data}/"
natural_dir = "figures/pdb_structures/esmfold/natural/"
# exist_ok: re-running this cell to resume must not fail on an existing folder
os.makedirs(dir_path, exist_ok=True)
if not fim_generation:
    os.makedirs(natural_dir, exist_ok=True)
# NOTE(review): resetting here discards any reloaded all_structures_ctx, so the
# "family already done" check below only helps within a single run; confirm intent.
all_structures_ctx = {}
if "ptm_gen" not in df.columns:
    # NaN means "not folded yet"; rows with ptm_gen > 0 are skipped below, which allows resuming
    df["ptm_gen"] = np.nan

for family_id in tqdm(families):
    print(f"Family {family_id}")
    family_df = df[df["family_id"] == family_id]
    ref_struct_path = str(structures_paths[family_id])

    # --- generated sequences ---
    for seq in tqdm(family_df["generated_sequence"].values):
        gen_rows = df["generated_sequence"] == seq
        indx = df[gen_rows].index[0]
        if df.loc[indx, "ptm_gen"] > 0:
            # Already folded: keep every stored metric (masked_plddt_gen used to be overwritten
            # here with a stale pos_plddt left over from a previous sequence)
            continue
        struct_path = dir_path + f"{family_id}_{indx}_gen.pdb"
        ptm, pae, mean_plddt, pos_plddt, rmsd, tmscore = compute_structure(seq, model, struct_path, ref_struct_path)
        df.loc[gen_rows, ["ptm_gen",
                          # "pae_gen",
                          "mean_plddt_gen",
                          "rmsd_gen",
                          "tmscore_gen"]] = ptm, mean_plddt, rmsd, tmscore
        if fim_generation:
            # pLDDT of the in-filled positions of the structure folded just above
            full_gen, inds_masks = reorder_masked_sequence(
                df.loc[indx, "original_input"] + df.loc[indx, "generated_input_fim"], return_ids=True)
            assert full_gen == seq
            df.loc[gen_rows, "masked_plddt_gen"] = _masked_plddt(
                pos_plddt, inds_masks, df.loc[indx, "fim_size"], len(full_gen))

    if fim_generation:
        # --- original (pre-FIM) sequences ---
        for seq in tqdm(family_df["original_sequence"].values):
            orig_rows = df["original_sequence"] == seq
            indx = df[orig_rows].index[0]
            struct_path = dir_path + f"{family_id}_{indx}_orig.pdb"
            ptm, pae, mean_plddt, pos_plddt, rmsd, tmscore = compute_structure(seq, model, struct_path, ref_struct_path)
            full_orig, inds_masks = reorder_masked_sequence(
                df.loc[indx, "original_input"] + df.loc[indx, "original_input_fim"], return_ids=True)
            assert full_orig == seq
            plddt_masked = _masked_plddt(pos_plddt, inds_masks, df.loc[indx, "fim_size"], len(full_orig))
            df.loc[orig_rows, ["ptm_orig",
                               # "pae_orig",
                               "mean_plddt_orig",
                               "masked_plddt_orig",
                               "rmsd_orig",
                               "tmscore_orig"]] = ptm, mean_plddt, plddt_masked, rmsd, tmscore
            # Compare the original and FIM-generated structures only when both were folded:
            # compute_structure writes no PDB (and returns ptm 0) for sequences it rejects.
            tmscore, rmsd = 0, 0
            if ptm > 0 and df.loc[indx, "ptm_gen"] > 0:
                struct_path_gen = dir_path + f"{family_id}_{indx}_gen.pdb"
                tmscore, rmsd = align_structures_TMscore(struct_path, struct_path_gen, key="")
            df.loc[orig_rows, ["tmscore_orig_gen", "rmsd_orig_gen"]] = tmscore, rmsd
    else:
        # --- natural baseline: fold a random subset of the family MSA ---
        if family_id not in all_structures_ctx:
            msa = read_msa_unaligned(f"figures/pdb_structures/msas/{family_id}.a3m")
            all_structures_ctx[family_id] = {}
            # Never request more distinct sequences than the MSA contains (replace=False would raise)
            n_sample = min(num_natural, len(msa))
            for i in np.random.choice(len(msa), n_sample, replace=False):
                _, seq = msa[i]
                struct_path = natural_dir + f"{family_id}_{i}.pdb"
                ptm, pae, mean_plddt, pos_plddt, rmsd, tmscore = compute_structure(seq, model, struct_path, ref_struct_path)
                all_structures_ctx[family_id][i] = {"ptm": ptm, "mean_plddt": mean_plddt, "rmsd": rmsd, "tmscore": tmscore}

    # Checkpoint after every family
    df.to_pickle(f"figures/generated_sequences/dataframe_{name_data}.pkl")
    if not fim_generation:
        with open(f"figures/generated_sequences/all_structures_ctx_{name_data}.pkl", "wb") as f:
            pickle.dump(all_structures_ctx, f)

In [None]:
# Final save of the table, including the structure-prediction columns
df.to_pickle(f"figures/generated_sequences/dataframe_{name_data}.pkl")

In [None]:
# Preview of the final table
df.head()