<a href="https://colab.research.google.com/github/aapratt/PROSTATA/blob/main/PROSTATA_tool.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Code is provided under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License.
# Install dependencies and download weights

In [1]:
!pip install fair-esm
!pip install biopython

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0
Collecting biopython
  Downloading biopython-1.83-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m33.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: biopython
Successfully installed biopython-1.83


In [2]:
from io import StringIO, BytesIO
from pathlib import Path
from urllib.request import urlretrieve

import numpy as np
import torch
from torch import nn
import esm
from esm.pretrained import load_model_and_alphabet_hub
from Bio import SeqIO

In [3]:
# Names of the checkpoints that make up the prediction ensemble.
# Each name is used twice by the rest of the notebook: as the stem of the remote
# weight file in the download cell and as the local filename passed to torch.load.
# Every entry also matches the name of a model class defined in this notebook.
model_names = [
    "ESMForSingleMutationPosOuter",
    "ESMForSingleMutationPosConcat",
    "ESMForSingleMutation_pos_cat_cls",
    "ESMForSingleMutation_pos",
    "ESMForSingleMutation_cls",
]

In [4]:
# Download every ensemble checkpoint into the working directory, one file per model name.
# Files that are already present are skipped so re-running the cell does not fetch them again.
# Each download goes to a temporary "<name>.part" file that is renamed only after
# urlretrieve returns, so an interrupted transfer never leaves a truncated file under
# the final name. Delete a checkpoint file manually to force it to be fetched again.
WEIGHTS_BASE_URL = (
    "https://a025generative-modeling-for-design.obs.ru-moscow-1.hc.sbercloud.ru"
    "/prostata/mix_ds_s669_weights"
)

for model_name in model_names:
    checkpoint_path = Path(model_name)
    if checkpoint_path.exists():
        continue
    partial_path = checkpoint_path.with_name(checkpoint_path.name + ".part")
    urlretrieve(f"{WEIGHTS_BASE_URL}/{model_name}_mix_ds_s669", partial_path)
    partial_path.replace(checkpoint_path)

In [5]:
import torch.nn.functional as F

# Width of the hidden layer in the concatenation head below.
HIDDEN_UNITS_POS_CONTACT = 5


class ESMForSingleMutationPosConcat(nn.Module):
    """ESM-2 backbone with a two-layer head over concatenated site embeddings.

    Both sequences are passed through the same backbone. The layer-33
    representation at index ``pos + 1`` is taken from each, the two vectors are
    concatenated along the last axis, and the result goes through
    ``fc1 -> ReLU -> fc2`` to give one value per selected position.
    """

    def __init__(self):
        super().__init__()
        # The loader returns (model, alphabet); only the model is kept here.
        backbone, _alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm2 = backbone
        self.fc1 = nn.Linear(2 * 1280, HIDDEN_UNITS_POS_CONTACT)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_CONTACT, 1)

    def _site_embedding(self, token_ids, pos):
        """Return the layer-33 representation of ``token_ids`` at index ``pos + 1``."""
        hidden = self.esm2.forward(token_ids, repr_layers=[33])["representations"][33]
        # +1 skips token index 0 (presumably the special token the ESM batch
        # converter puts in front of the sequence; confirm against the tokenizer).
        return hidden[:, pos + 1]

    def forward(self, token_ids1, token_ids2, pos):
        """Score a mutation from two token batches.

        Args:
            token_ids1: token batch of the first (wild-type) sequence.
            token_ids2: token batch of the second (mutant) sequence.
            pos: index tensor of 0-based residue positions to read out.

        Returns:
            Output of ``fc2``: one value per batch item and selected position.
        """
        first_site = self._site_embedding(token_ids1, pos)
        second_site = self._site_embedding(token_ids2, pos)
        paired = torch.cat((first_site, second_site), 2)
        hidden = F.relu(self.fc1(paired))
        return self.fc2(hidden)


# Width of the hidden layer in the outer-product head below.
HIDDEN_UNITS_POS_OUTER = 5


class ESMForSingleMutationPosOuter(nn.Module):
    """ESM-2 backbone with a two-layer head over the outer product of site embeddings.

    The layer-33 representations at index ``pos + 1`` of both sequences are
    combined into a (1280 x 1280) outer product, flattened, and passed through
    ``fc1 -> ReLU -> fc2``. On construction, the leading backbone parameters
    are frozen (see ``_freeze_esm2_layers``).
    """

    def __init__(self):
        super().__init__()
        # The loader returns (model, alphabet); only the model is kept here.
        backbone, _alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm2 = backbone
        self._freeze_esm2_layers()
        self.fc1 = nn.Linear(1280 * 1280, HIDDEN_UNITS_POS_OUTER)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_OUTER, 1)

    def _freeze_esm2_layers(self):
        """Disable gradients for the first ``2 + 16 * 30`` backbone parameters.

        Parameters are taken in ``named_parameters()`` order. The count treats
        the backbone as 2 initial parameters followed by blocks of 16
        parameters, leaving the last 3 of 33 blocks (and anything after them)
        trainable.
        """
        total_blocks = 33
        initial_layers = 2
        layers_per_block = 16
        trainable_blocks = 3
        frozen_count = initial_layers + layers_per_block * (total_blocks - trainable_blocks)
        for index, (_, param) in enumerate(self.esm2.named_parameters()):
            if index >= frozen_count:
                break
            param.requires_grad = False

    def _site_embedding(self, token_ids, pos):
        """Return the layer-33 representation of ``token_ids`` at index ``pos + 1``."""
        hidden = self.esm2.forward(token_ids, repr_layers=[33])["representations"][33]
        # +1 skips token index 0 (presumably the special token the ESM batch
        # converter puts in front of the sequence; confirm against the tokenizer).
        return hidden[:, pos + 1]

    def forward(self, token_ids1, token_ids2, pos):
        """Score a mutation from two token batches.

        Args:
            token_ids1: token batch of the first (wild-type) sequence.
            token_ids2: token batch of the second (mutant) sequence.
            pos: index tensor of 0-based residue positions to read out.

        Returns:
            Output of ``fc2``: one value per batch item and selected position.
        """
        first_site = self._site_embedding(token_ids1, pos)
        second_site = self._site_embedding(token_ids2, pos)
        # Column vector times row vector gives the outer product per position.
        pairwise = torch.matmul(first_site.unsqueeze(3), second_site.unsqueeze(2))
        # Collapse the two embedding axes into one feature axis for fc1.
        flat = pairwise.flatten(start_dim=2)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


class ESMForSingleMutation_pos(nn.Module):
    """ESM-2 backbone with a linear head over a weighted mix of site embeddings.

    The layer-33 representations at index ``pos + 1`` of the two sequences are
    combined as ``const1 * first + const2 * second`` and fed to a single linear
    layer. ``const1`` and ``const2`` are learnable and start at +1 and -1, so
    the initial combination is the plain difference of the two embeddings.

    The attributes are called ``esm1v`` / ``esm1v_alphabet`` although they hold
    what ``esm2_t33_650M_UR50D`` returns; other cells read ``esm1v_alphabet``
    by that name.
    """

    def __init__(self):
        super().__init__()
        backbone, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm1v = backbone
        self.esm1v_alphabet = alphabet
        self.classifier = nn.Linear(1280, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1, 1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1, 1280)))

    def _last_layer(self, token_ids):
        """Return the layer-33 representations for a token batch."""
        return self.esm1v.forward(token_ids, repr_layers=[33])["representations"][33]

    def forward(self, token_ids1, token_ids2, pos):
        """Score a mutation from two token batches.

        Args:
            token_ids1: token batch of the first (wild-type) sequence.
            token_ids2: token batch of the second (mutant) sequence.
            pos: index tensor of 0-based residue positions to read out.

        Returns:
            Output of ``classifier``: one value per batch item and selected position.
        """
        first_hidden = self._last_layer(token_ids1)
        second_hidden = self._last_layer(token_ids2)
        # +1 skips token index 0 (presumably a leading special token; confirm).
        site = pos + 1
        mixed = self.const1 * first_hidden[:, site, :] + self.const2 * second_hidden[:, site, :]
        return self.classifier(mixed)


class ESMForSingleMutation_cls(nn.Module):
    """ESM-2 backbone with a linear head over a weighted mix of index-0 embeddings.

    Only the layer-33 representation at token index 0 of each sequence is used;
    the ``pos`` argument is accepted for a uniform call signature but ignored.
    The two vectors are combined as ``const1 * first + const2 * second``
    (learnable, initialised to +1 and -1) and fed to a single linear layer.

    The attributes are called ``esm1v`` / ``esm1v_alphabet`` although they hold
    what ``esm2_t33_650M_UR50D`` returns; other cells read ``esm1v_alphabet``
    by that name.
    """

    def __init__(self):
        super().__init__()
        backbone, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm1v = backbone
        self.esm1v_alphabet = alphabet
        self.classifier = nn.Linear(1280, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1, 1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1, 1280)))

    def _last_layer(self, token_ids):
        """Return the layer-33 representations for a token batch."""
        return self.esm1v.forward(token_ids, repr_layers=[33])["representations"][33]

    def forward(self, token_ids1, token_ids2, pos):
        """Score a mutation from two token batches.

        Args:
            token_ids1: token batch of the first (wild-type) sequence.
            token_ids2: token batch of the second (mutant) sequence.
            pos: unused by this model.

        Returns:
            Output of ``classifier`` with an extra leading axis of size 1
            (added by ``unsqueeze(0)`` before the linear layer).
        """
        first_cls = self._last_layer(token_ids1)[:, 0, :]
        second_cls = self._last_layer(token_ids2)[:, 0, :]
        mixed = self.const1 * first_cls + self.const2 * second_cls
        return self.classifier(mixed.unsqueeze(0))


class ESMForSingleMutation_pos_cat_cls(nn.Module):
    """ESM-2 backbone with a linear head over index-0 and site features combined.

    For both token index 0 and index ``pos + 1``, the layer-33 representations
    of the two sequences are mixed as ``const1 * first + const2 * second``
    (the same learnable weights for both, initialised to +1 and -1). The two
    mixed vectors are concatenated along the last axis and fed to one linear layer.

    The attributes are called ``esm1v`` / ``esm1v_alphabet`` although they hold
    what ``esm2_t33_650M_UR50D`` returns; other cells read ``esm1v_alphabet``
    by that name.
    """

    def __init__(self):
        super().__init__()
        backbone, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm1v = backbone
        self.esm1v_alphabet = alphabet
        self.classifier = nn.Linear(2 * 1280, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1, 1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1, 1280)))

    def _last_layer(self, token_ids):
        """Return the layer-33 representations for a token batch."""
        return self.esm1v.forward(token_ids, repr_layers=[33])["representations"][33]

    def forward(self, token_ids1, token_ids2, pos):
        """Score a mutation from two token batches.

        Args:
            token_ids1: token batch of the first (wild-type) sequence.
            token_ids2: token batch of the second (mutant) sequence.
            pos: index tensor of 0-based residue positions to read out.

        Returns:
            Output of ``classifier`` applied to the concatenated features.
        """
        first_hidden = self._last_layer(token_ids1)
        second_hidden = self._last_layer(token_ids2)
        # +1 skips token index 0 (presumably a leading special token; confirm).
        site = pos + 1
        cls_mix = self.const1 * first_hidden[:, 0, :] + self.const2 * second_hidden[:, 0, :]
        site_mix = self.const1 * first_hidden[:, site, :] + self.const2 * second_hidden[:, site, :]
        # The index-0 mix lacks the position axis, so add a leading axis before joining.
        features = torch.cat([cls_mix.unsqueeze(0), site_mix], dim=-1)
        return self.classifier(features)

# Compute DeltaDDG

In [25]:
# Inputs for the single-mutation example; the trailing `@param` markers expose both
# values as Colab form fields.
#   seq           - wild-type amino-acid sequence, one letter per residue.
#   mutation_code - "<wild-type residue><1-based position><mutant residue>", e.g. "Q46V";
#                   the tokenization cell subtracts 1 from the position to index `seq`.
# NOTE(review): the original comment called this a hardcoded p53 test sequence, but the
# leading "MGSSHHHHHH..." looks like an expression tag - confirm which protein this is.
seq = "MGSSHHHHHHSSGLVPRGSHMVKLTLSALPALSPAVAVPAYDPRAQIPGIVHFGVGAFHRSHQAMYLDRLLNSGRGAGWAICGVGVLPQDARMRDVLAEQDHLYTLVTRSPDGQAQARVIGAIVEFLFAPDDPERVLERLADPTTRIVSLTVTEGGYSVSNATGEFDPTPPDIAHDLTPGAVPRTFFGFLTEGLRRRRERGLPPFTVVSCDNMPGNGEVTRRALTAFARLQDPELGDWIAHNVAFPNSMVDRITPATTEQDRQDIAAAYGIEDAWPVVAESFAQWVLEDRFTQGRPALETVGVQVVSDVEPYELMKLRLLNASHQALAYLGLLAGYRFVHEVCQDPLFARFLLDYMTQEATPTLRPVPGIDLGAYRRELIARFSNPAIRDPLTRLTVDSSERIPKFLLPVIRDQLARGGELARCALVIASWRAYLATVLEEGSASFPDQHAQALAEAVRRDAQQPGAFLDLEAVFGELGRNARFRTAYLSAWESLRRQGPLGAMRALMGEESSPSNVTSLSGR"  # @param {type:"string"}
mutation_code = "Q46V"  # @param {type:"string"}

In [15]:
# Inputs for the multi-point example: the wild-type sequence and the list of
# point-mutation codes to apply. The multi-point cell further down assigns both
# names again before using them.
seq = "MGSSHHHHHHSSGLVPRGSHMVKLTLSALPALSPAVAVPAYDPRAQIPGIVHFGVGAFHRSHQAMYLDRLLNSGRGAGWAICGVGVLPQDARMRDVLAEQDHLYTLVTRSPDGQAQARVIGAIVEFLFAPDDPERVLERLADPTTRIVSLTVTEGGYSVSNATGEFDPTPPDIAHDLTPGAVPRTFFGFLTEGLRRRRERGLPPFTVVSCDNMPGNGEVTRRALTAFARLQDPELGDWIAHNVAFPNSMVDRITPATTEQDRQDIAAAYGIEDAWPVVAESFAQWVLEDRFTQGRPALETVGVQVVSDVEPYELMKLRLLNASHQALAYLGLLAGYRFVHEVCQDPLFARFLLDYMTQEATPTLRPVPGIDLGAYRRELIARFSNPAIRDPLTRLTVDSSERIPKFLLPVIRDQLARGGELARCALVIASWRAYLATVLEEGSASFPDQHAQALAEAVRRDAQQPGAFLDLEAVFGELGRNARFRTAYLSAWESLRRQGPLGAMRALMGEESSPSNVTSLSGR"

# Full pool of candidate point mutations, kept for reference:
#['Q46V', 'Q231K', 'A268R', 'A283T', 'A297P', 'R432A', 'T437G', 'Q449R']
# Codes currently selected from that pool.
mutation_codes = ['Q46V','A297P','Q449R']

In [26]:
# Get wildtype sequence, mutation position and mutated sequence, then tokenize both.
# Reads `seq` and `mutation_code`; leaves `mut_pos`, `esm2_batch_tokens1` (wild type)
# and `esm2_batch_tokens2` (mutant) for the ensemble cell that scores this mutation.
MAX_SEQ_LEN = 1022  # both sequences are cut to this many residues before tokenization

wt_aa = mutation_code[0]
mut_aa = mutation_code[-1]
mut_pos = int(mutation_code[1:-1]) - 1  # 1-based position in the code -> 0-based index

# Fail early on a code that does not describe this sequence. Without these checks a
# position of 0 would silently edit the last residue, and a wrong wild-type letter
# would go unnoticed and score a mutation other than the one requested.
if not 0 <= mut_pos < min(len(seq), MAX_SEQ_LEN):
    raise ValueError(
        f"Mutation {mutation_code}: position {mut_pos + 1} is outside the scored sequence "
        f"(length {min(len(seq), MAX_SEQ_LEN)})"
    )
if seq[mut_pos] != wt_aa:
    raise ValueError(
        f"Mutation {mutation_code}: expected {wt_aa} at position {mut_pos + 1}, "
        f"found {seq[mut_pos]}"
    )

wt = seq
mut = wt[:mut_pos] + mut_aa + wt[mut_pos + 1:]

# Use the GPU when one is available; otherwise everything stays on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint is a whole pickled module, so it has to be loaded with
# weights_only=False (the argument exists from PyTorch 1.13 on and newer releases
# default to True). Unpickling runs code stored in the file: only load checkpoints
# from a source you trust.
model = torch.load("ESMForSingleMutation_cls", map_location=torch.device("cpu"), weights_only=False)
esm2_alphabet = model.esm1v_alphabet
esm2batch_converter = esm2_alphabet.get_batch_converter()
_, _, esm2_batch_tokens1 = esm2batch_converter([("", wt[:MAX_SEQ_LEN])])
_, _, esm2_batch_tokens2 = esm2batch_converter([("", mut[:MAX_SEQ_LEN])])
esm2_batch_tokens1 = esm2_batch_tokens1.to(device)
esm2_batch_tokens2 = esm2_batch_tokens2.to(device)

In [31]:
# compute one multi-point mutation
#
# The mutations in `mutation_codes` are applied one after another: each step is scored
# against the sequence that already carries the previous steps, and the per-step ensemble
# means are summed. The work is split into three stages so that every checkpoint is
# loaded once instead of once per mutation:
#   1. build and tokenize every (background, mutant) pair,
#   2. run each ensemble member over all pairs,
#   3. print the results in the same order as before.
# The stages use their own variable names, so `seq` stays the wild-type sequence and the
# `mut_pos` / `esm2_batch_tokens*` values of the single-mutation cells are not overwritten.
seq = "MGSSHHHHHHSSGLVPRGSHMVKLTLSALPALSPAVAVPAYDPRAQIPGIVHFGVGAFHRSHQAMYLDRLLNSGRGAGWAICGVGVLPQDARMRDVLAEQDHLYTLVTRSPDGQAQARVIGAIVEFLFAPDDPERVLERLADPTTRIVSLTVTEGGYSVSNATGEFDPTPPDIAHDLTPGAVPRTFFGFLTEGLRRRRERGLPPFTVVSCDNMPGNGEVTRRALTAFARLQDPELGDWIAHNVAFPNSMVDRITPATTEQDRQDIAAAYGIEDAWPVVAESFAQWVLEDRFTQGRPALETVGVQVVSDVEPYELMKLRLLNASHQALAYLGLLAGYRFVHEVCQDPLFARFLLDYMTQEATPTLRPVPGIDLGAYRRELIARFSNPAIRDPLTRLTVDSSERIPKFLLPVIRDQLARGGELARCALVIASWRAYLATVLEEGSASFPDQHAQALAEAVRRDAQQPGAFLDLEAVFGELGRNARFRTAYLSAWESLRRQGPLGAMRALMGEESSPSNVTSLSGR"

#['Q46V', 'Q231K', 'A268R', 'A283T', 'A297P', 'R432A', 'T437G', 'Q449R']
mutation_codes = ['Q46V','A297P','Q449R']

MAX_SEQ_LEN = 1022  # sequences are cut to this many residues before tokenization
# Use the GPU when one is available; otherwise everything stays on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoints are whole pickled modules, so they have to be loaded with
# weights_only=False (available from PyTorch 1.13; newer releases default to True).
# Unpickling runs code stored in the file: only load checkpoints you trust.
# This first load is only needed for the tokenizer, so the module is dropped right away.
tokenizer_source = torch.load("ESMForSingleMutation_cls", map_location=torch.device("cpu"), weights_only=False)
batch_converter = tokenizer_source.esm1v_alphabet.get_batch_converter()
del tokenizer_source

# Stage 1: apply the mutations cumulatively and tokenize each step.
seq_list = []  # sequence after each successive mutation
steps = []     # (code, background tokens, mutant tokens, 0-based position) per step
background = seq
for code in mutation_codes:
    wt_res, mut_res = code[0], code[-1]
    position = int(code[1:-1]) - 1  # 1-based position in the code -> 0-based index
    # Reject codes that do not describe the current sequence instead of scoring garbage.
    if not 0 <= position < min(len(background), MAX_SEQ_LEN):
        raise ValueError(f"Mutation {code}: position {position + 1} is outside the scored sequence")
    if background[position] != wt_res:
        raise ValueError(
            f"Mutation {code}: expected {wt_res} at position {position + 1}, found {background[position]}"
        )
    mutated = background[:position] + mut_res + background[position + 1:]
    _, _, wt_tokens = batch_converter([("", background[:MAX_SEQ_LEN])])
    _, _, mut_tokens = batch_converter([("", mutated[:MAX_SEQ_LEN])])
    steps.append((code, wt_tokens.to(device), mut_tokens.to(device), position))
    seq_list.append(mutated)
    background = mutated

# Stage 2: load each ensemble member once and score every step with it.
step_outputs = [[] for _ in steps]  # step_outputs[i] holds one array per model, in model_names order
for model_name in model_names:
    model = torch.load(model_name, map_location=torch.device("cpu"), weights_only=False)
    model.eval()
    model.to(device)
    with torch.no_grad():
        for (_, wt_tokens, mut_tokens, position), outputs in zip(steps, step_outputs):
            outputs.append(
                model(token_ids1=wt_tokens, token_ids2=mut_tokens, pos=torch.LongTensor([position]))
                .cpu()
                .numpy()
            )

# Stage 3: report. The printed DDG is the negated model output.
res_list = []
for (code, _, _, _), mutated, outputs in zip(steps, seq_list, step_outputs):
    print(mutated)
    for model_name, output in zip(model_names, outputs):
        print(f"Model {model_name} DDG prediction is {-1*output[0,0,0]}")
    step_mean = np.mean(outputs)
    print(f"Predicted DDG for the mutation {code} is {-1*step_mean}")
    res_list.append(step_mean)

sumres=sum(res_list)
print(f"Predicted DDG for the mutations {mutation_codes} is {-1*sumres}")


MGSSHHHHHHSSGLVPRGSHMVKLTLSALPALSPAVAVPAYDPRAVIPGIVHFGVGAFHRSHQAMYLDRLLNSGRGAGWAICGVGVLPQDARMRDVLAEQDHLYTLVTRSPDGQAQARVIGAIVEFLFAPDDPERVLERLADPTTRIVSLTVTEGGYSVSNATGEFDPTPPDIAHDLTPGAVPRTFFGFLTEGLRRRRERGLPPFTVVSCDNMPGNGEVTRRALTAFARLQDPELGDWIAHNVAFPNSMVDRITPATTEQDRQDIAAAYGIEDAWPVVAESFAQWVLEDRFTQGRPALETVGVQVVSDVEPYELMKLRLLNASHQALAYLGLLAGYRFVHEVCQDPLFARFLLDYMTQEATPTLRPVPGIDLGAYRRELIARFSNPAIRDPLTRLTVDSSERIPKFLLPVIRDQLARGGELARCALVIASWRAYLATVLEEGSASFPDQHAQALAEAVRRDAQQPGAFLDLEAVFGELGRNARFRTAYLSAWESLRRQGPLGAMRALMGEESSPSNVTSLSGR
Model ESMForSingleMutationPosOuter DDG prediction is -2.270242214202881
Model ESMForSingleMutationPosConcat DDG prediction is -1.2209725379943848
Model ESMForSingleMutation_pos_cat_cls DDG prediction is -2.0761771202087402
Model ESMForSingleMutation_pos DDG prediction is -2.2429118156433105
Model ESMForSingleMutation_cls DDG prediction is -0.16187715530395508
Predicted DDG for the mutation Q46V is -1.5944361686706543
MGSSHHHHHHSSGLVPRGSHMVKLTLSALPALSPAVAVPAYDPRAVIPGIVHFGV

In [2]:
# Multi-point variants to score: each inner list is one combination of
# single-point mutation codes ("<wild-type residue><1-based position><mutant residue>").
# The next cell assigns this list to `mutation_codes_list` and scores every combination.
# NOTE(review): the name suggests these are the best combinations by Tm - confirm
# where the list comes from.
tm_best=[['Q46V', 'A268R', 'A283T', 'A297P', 'R432A', 'Q449R'],
['Q46V', 'A268R', 'A297P', 'R432A', 'Q449R'],
['Q46V', 'Q231K', 'A283T', 'A297P', 'R432A', 'Q449R'],
['Q46V', 'Q231K', 'A268R', 'A283T', 'A297P', 'R432A'],
['Q46V', 'A283T', 'A297P', 'R432A', 'Q449R'],
['Q46V', 'Q231K', 'A268R', 'A297P', 'R432A', 'Q449R'],
['Q46V', 'A268R', 'A283T', 'A297P', 'R432A'],
['Q46V', 'Q231K', 'A268R', 'A297P', 'R432A'],
['Q46V', 'Q231K', 'A268R', 'A283T', 'A297P', 'R432A', 'Q449R'],
['Q46V', 'Q231K', 'A297P', 'R432A', 'Q449R'],
['Q46V', 'A268R', 'A297P', 'R432A'],
['Q46V', 'A297P', 'R432A', 'Q449R'],
['Q46V', 'Q231K', 'A283T', 'A297P', 'R432A'],
['Q46V', 'Q231K', 'A297P', 'R432A'],
['Q46V', 'A283T', 'A297P', 'R432A'],
['Q231K', 'A268R', 'A283T', 'A297P', 'R432A', 'Q449R'],
['Q231K', 'A268R', 'A297P', 'R432A', 'Q449R'],
['Q46V', 'A297P', 'R432A'],
['Q46V', 'Q231K', 'A268R', 'A283T', 'R432A', 'Q449R'],
['Q46V', 'Q231K', 'A268R', 'R432A', 'Q449R'],
['Q46V', 'A268R', 'A283T', 'A297P', 'Q449R']]



In [3]:
#compute ddG for list of multi-point mutations
#
# Every entry of `tm_best` is a list of point-mutation codes. Within one entry the
# mutations are applied one after another, each step is scored against the sequence that
# already carries the previous steps, and the per-step ensemble means are summed.
# `total_list` ends up with one summed DDG per entry, in the order of `tm_best`.
# The work is split into three stages so that every checkpoint is loaded once for the
# whole cell instead of once per mutation step:
#   1. build and tokenize every (background, mutant) pair,
#   2. run each ensemble member over all pairs,
#   3. print the results in the same order as before.
# The stages use their own variable names, so `seq` stays the wild-type sequence and the
# `mut_pos` / `esm2_batch_tokens*` values of the single-mutation cells are not overwritten.

#['Q46V', 'Q231K', 'A268R', 'A283T', 'A297P', 'R432A', 'T437G', 'Q449R']

seq = "MGSSHHHHHHSSGLVPRGSHMVKLTLSALPALSPAVAVPAYDPRAQIPGIVHFGVGAFHRSHQAMYLDRLLNSGRGAGWAICGVGVLPQDARMRDVLAEQDHLYTLVTRSPDGQAQARVIGAIVEFLFAPDDPERVLERLADPTTRIVSLTVTEGGYSVSNATGEFDPTPPDIAHDLTPGAVPRTFFGFLTEGLRRRRERGLPPFTVVSCDNMPGNGEVTRRALTAFARLQDPELGDWIAHNVAFPNSMVDRITPATTEQDRQDIAAAYGIEDAWPVVAESFAQWVLEDRFTQGRPALETVGVQVVSDVEPYELMKLRLLNASHQALAYLGLLAGYRFVHEVCQDPLFARFLLDYMTQEATPTLRPVPGIDLGAYRRELIARFSNPAIRDPLTRLTVDSSERIPKFLLPVIRDQLARGGELARCALVIASWRAYLATVLEEGSASFPDQHAQALAEAVRRDAQQPGAFLDLEAVFGELGRNARFRTAYLSAWESLRRQGPLGAMRALMGEESSPSNVTSLSGR"
mutation_codes_list=tm_best

MAX_SEQ_LEN = 1022  # sequences are cut to this many residues before tokenization
# Use the GPU when one is available; otherwise everything stays on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoints are whole pickled modules, so they have to be loaded with
# weights_only=False (available from PyTorch 1.13; newer releases default to True).
# Unpickling runs code stored in the file: only load checkpoints you trust.
# This first load is only needed for the tokenizer, so the module is dropped right away.
tokenizer_source = torch.load("ESMForSingleMutation_cls", map_location=torch.device("cpu"), weights_only=False)
batch_converter = tokenizer_source.esm1v_alphabet.get_batch_converter()
del tokenizer_source

# Stage 1: for every variant, apply its mutations cumulatively and tokenize each step.
variant_steps = []  # per variant: list of (code, background tokens, mutant tokens, 0-based position)
for codes in mutation_codes_list:
    background = seq  # every variant starts again from the wild type
    steps = []
    for code in codes:
        wt_res, mut_res = code[0], code[-1]
        position = int(code[1:-1]) - 1  # 1-based position in the code -> 0-based index
        # Reject codes that do not describe the current sequence instead of scoring garbage.
        if not 0 <= position < min(len(background), MAX_SEQ_LEN):
            raise ValueError(f"Mutation {code}: position {position + 1} is outside the scored sequence")
        if background[position] != wt_res:
            raise ValueError(
                f"Mutation {code}: expected {wt_res} at position {position + 1}, found {background[position]}"
            )
        mutated = background[:position] + mut_res + background[position + 1:]
        _, _, wt_tokens = batch_converter([("", background[:MAX_SEQ_LEN])])
        _, _, mut_tokens = batch_converter([("", mutated[:MAX_SEQ_LEN])])
        steps.append((code, wt_tokens.to(device), mut_tokens.to(device), position))
        background = mutated
    variant_steps.append(steps)

# Stage 2: load each ensemble member once and score every step of every variant with it.
# variant_outputs[v][s] collects one output array per model, in model_names order.
variant_outputs = [[[] for _ in steps] for steps in variant_steps]
for model_name in model_names:
    model = torch.load(model_name, map_location=torch.device("cpu"), weights_only=False)
    model.eval()
    model.to(device)
    with torch.no_grad():
        for steps, per_step_outputs in zip(variant_steps, variant_outputs):
            for (_, wt_tokens, mut_tokens, position), outputs in zip(steps, per_step_outputs):
                outputs.append(
                    model(token_ids1=wt_tokens, token_ids2=mut_tokens, pos=torch.LongTensor([position]))
                    .cpu()
                    .numpy()
                )

# Stage 3: report. The printed DDG is the negated model output.
total_list=[]
for codes, steps, per_step_outputs in zip(mutation_codes_list, variant_steps, variant_outputs):
    res_list=[]
    print(codes)
    for (code, _, _, _), outputs in zip(steps, per_step_outputs):
        step_mean = np.mean(outputs)
        print(f"Predicted DDG for the mutation {code} is {-1*step_mean}")
        res_list.append(step_mean)

    sumres=sum(res_list)
    print(f"Predicted DDG for the mutations {codes} is {-1*sumres}")
    total_list.append(-1*sumres)
    print(-1*sumres)
print(total_list)


['Q46V', 'A268R', 'A283T', 'A297P', 'R432A', 'Q449R']


NameError: name 'torch' is not defined

In [29]:
# Ensemble prediction for one mutation.
# Reads `esm2_batch_tokens1`, `esm2_batch_tokens2` and `mut_pos`, which must come from the
# single-mutation tokenization cell; re-run that cell first if any other cell has
# reassigned them since. Leaves the mean raw model output in `res`; the reported DDG is
# the negated value.

# Use the GPU when one is available; otherwise everything stays on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wt_tokens = esm2_batch_tokens1.to(device)
mut_tokens = esm2_batch_tokens2.to(device)

res = []
for model_name in model_names:
    # The checkpoints are whole pickled modules, so they have to be loaded with
    # weights_only=False (available from PyTorch 1.13; newer releases default to True).
    # Unpickling runs code stored in the file: only load checkpoints you trust.
    model = torch.load(model_name, map_location=torch.device("cpu"), weights_only=False)
    model.eval()
    model.to(device)

    with torch.no_grad():
        res.append(
            model(token_ids1=wt_tokens, token_ids2=mut_tokens, pos=torch.LongTensor([mut_pos]))
            .cpu()
            .numpy()
        )
    # Negate the number itself: gluing a "-" in front of it printed "--x" for negative outputs.
    print(f"Model {model_name} DDG prediction is {-res[-1][0,0,0]}")
res = np.mean(res)

Model ESMForSingleMutationPosOuter DDG prediction is -2.270242214202881
Model ESMForSingleMutationPosConcat DDG prediction is -1.2209725379943848
Model ESMForSingleMutation_pos_cat_cls DDG prediction is -2.0761771202087402
Model ESMForSingleMutation_pos DDG prediction is -2.2429118156433105
Model ESMForSingleMutation_cls DDG prediction is -0.16187715530395508


In [28]:
# Report the ensemble mean as DDG. The value is negated numerically: gluing a "-" in
# front of it printed "--x" whenever `res` was negative.
print(f"Predicted DDG for the mutation {mutation_code} is {-res}")

Predicted DDG for the mutation Q46V is 1.5944361686706543
