In [None]:
# default_exp embedding

In [None]:
#!pip install fair-esm

In [None]:
# export
import torch
import esm

from tqdm.notebook import tqdm,trange

In [None]:
# export
# Load the pretrained ESM-1b model and the batch converter (tokenizer) that matches its alphabet.
DEVICE = "cuda:1"  # device used for inference; change it here to move the model
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
batch_converter = alphabet.get_batch_converter()

# eval() disables dropout so that repeated runs give identical embeddings.
model = model.eval().to(DEVICE)

In [None]:
# Sanity check: the device holding the model's weights. Reading it from the parameters
# works for any torch.nn.Module, whether or not the model defines a `.device` attribute.
next(model.parameters()).device

In [None]:
# export
def getSequenceRepresentation(Data, repr_layer=33):
    """
    Embed the given sequences with the ESM-1b model.

    Arguments:
    - Data : list[(id, sequence),] -- passed unchanged to `batch_converter`
    - repr_layer : int -- index of the layer whose hidden states are returned
      (defaults to 33, the layer this function always used)

    Returns:
    - list of numpy arrays, one per (id, sequence) pair and in input order.
      Array i has len(Data[i][1]) rows, one per residue: the leading
      beginning-of-sequence token and every token after the residues are dropped.
    """
    _, _, batch_tokens = batch_converter(Data)
    # Follow the model to whatever device it lives on instead of repeating a device string.
    device = next(model.parameters()).device
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
        # Contact maps are never used, so don't ask the model to compute them.
        results = model(batch_tokens, repr_layers=[repr_layer], return_contacts=False)
    # Copy the per-token hidden states to host memory, then drop the device-side results.
    token_representations = results["representations"][repr_layer].cpu()
    del results, batch_tokens
    # Token 0 is always the beginning-of-sequence token, so residue k sits at token k + 1.
    # clone() gives each array its own storage; otherwise every array would be a view
    # that keeps the whole padded batch tensor alive after the `del` below.
    sequence_representations = [
        token_representations[i, 1 : len(seq) + 1].clone().numpy()
        for i, (_, seq) in enumerate(Data)
    ]
    del token_representations
    return sequence_representations

In [None]:
def prepSeq(row, windowSize=510):
    """
    Cut a window out of row["sequence"] around the variant position.

    Arguments:
    - row : row/mapping with "sequence" (str) and "loc" (1-based position, int or numeric str)
    - windowSize : residues kept before the variant; the window ends
      `windowSize` positions after the 0-based variant index (exclusive)

    Returns sequence[max(0, c - windowSize) : min(len(sequence), c + windowSize)],
    where c = int(row["loc"]) - 1.
    """
    sequence = row["sequence"]
    center = int(row["loc"]) - 1  # 1-based position -> 0-based index

    # Clamp the start at 0 so that a negative bound cannot wrap around.
    start = center - windowSize
    if start < 0:
        start = 0

    stop = center + windowSize
    if stop > len(sequence):
        stop = len(sequence)

    return sequence[start:stop]

In [None]:
import pandas as pd
# Missense-variant table; each row carries "sequence", "loc", "og" and "newAA".
CLINVAR_VARIANTS_CSV = "/data/projects/processBio/clinvar/clinvar/missenseVariants.csv"
df = pd.read_csv(CLINVAR_VARIANTS_CSV)

In [None]:
# Compare the sequences of rows 1 and 2 residue by residue (positions printed 1-based).
seq_a, seq_b = df.loc[1, "sequence"], df.loc[2, "sequence"]
for pos, (i, j) in enumerate(zip(seq_a, seq_b), start=1):
    if i != j:
        print(pos, i, j)
# zip stops at the shorter sequence, so report a length difference explicitly.
if len(seq_a) != len(seq_b):
    print("length differs:", len(seq_a), len(seq_b))

In [None]:
# Three-letter -> one-letter codes for the 20 standard amino acids.
# Built once here instead of on every row.
AA_THREE_TO_ONE = {
    "Ala": "A",
    "Arg": "R",
    "Asn": "N",
    "Asp": "D",
    "Cys": "C",
    "Gln": "Q",
    "Glu": "E",
    "Gly": "G",
    "His": "H",
    "Ile": "I",
    "Leu": "L",
    "Lys": "K",
    "Met": "M",
    "Phe": "F",
    "Pro": "P",
    "Ser": "S",
    "Thr": "T",
    "Trp": "W",
    "Tyr": "Y",
    "Val": "V",
}


def makeReference(row):
    """
    Rebuild the reference protein sequence for one missense-variant row.

    Arguments:
    - row : a row of `df` (e.g. from `df.apply(..., axis=1)`) with
      "og" and "newAA" (three-letter amino-acid codes), "loc" (1-based
      residue position) and "sequence".

    Returns the sequence with the residue at `loc` set to `og`.

    Raises:
    - KeyError if "og" or "newAA" is not one of the 20 codes above.
    - ValueError if `loc` lies outside the sequence, or if the sequence does
      not carry `og` at `loc`.

    NOTE(review): the check below requires sequence[loc] == og, so the returned
    string always equals row["sequence"]. If the variant (mutant) sequence was
    intended, substitute `var` for `og` in the return -- confirm the intent.
    """
    seq = row["sequence"]
    og = AA_THREE_TO_ONE[row["og"]]
    loc = int(row["loc"]) - 1  # 1-based position -> 0-based index
    var = AA_THREE_TO_ONE[row["newAA"]]  # looked up so that unknown codes still fail loudly

    # Reject loc <= 0 or loc past the end; a negative index would silently wrap around.
    if not 0 <= loc < len(seq):
        raise ValueError(
            f"loc {row['loc']} is outside a sequence of length {len(seq)}"
        )
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if seq[loc] != og:
        raise ValueError(
            f"expected {og} at position {loc + 1}, found {seq[loc]} (variant {og}->{var})"
        )
    return seq[:loc] + og + seq[loc + 1:]

In [None]:
# One rebuilt reference sequence per variant row; makeReference is passed directly as the row function.
df = df.assign(referenceSeq=df.apply(makeReference, axis=1))

In [None]:
# (index label, windowed sequence) pairs. The labels let each embedding be mapped back to its df row.
# iterrows() has no len(), so tqdm needs `total` to draw a real progress bar.
seqs = list(zip(df.index, [prepSeq(row) for _, row in tqdm(df.iterrows(), total=len(df))]))

In [None]:
# Embed the windows batchSize at a time. Each entry of `representations` is the list
# returned by getSequenceRepresentation for one batch, in the same order as `seqs`.
batchSize = 1
representations = [
    getSequenceRepresentation(seqs[start : start + batchSize])
    for start in trange(0, len(seqs), batchSize)
]

In [None]:
import numpy as np

In [None]:
import pickle

In [None]:
# Persist the embeddings. The `with` block closes (and so flushes) the file deterministically.
with open('/data/projects/processBio/clinvar/clinvar/embeddings.pkl', "wb") as fh:
    pickle.dump(representations, fh)

In [None]:
# Persist the (index label, window) pairs. The `with` block closes (and so flushes) the file deterministically.
with open('/data/projects/processBio/clinvar/clinvar/seqs.pkl', "wb") as fh:
    pickle.dump(seqs, fh)

In [None]:
# Check the seqs <-> embeddings correspondence instead of only eyeballing it:
# every window got exactly one embedding, and the first embedding has one row per residue.
assert sum(len(batch) for batch in representations) == len(seqs)
assert representations[0][0].shape[0] == len(seqs[0][1])
representations[0][0].shape

In [None]:
# Length of the first windowed sequence. getSequenceRepresentation keeps one row per
# residue, so this should equal the first axis of the embedding shape above.
len(seqs[0][1])

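There is a one-to-one correspondence between `seqs.pkl` and `embeddings.pkl`: entry *i* of each refers to the same variant window (row `seqs[i][0]` of `df`). With `batchSize=1`, each `embeddings.pkl` entry is a one-element list holding that window's per-residue embedding array.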