In [None]:
import pandas as pd
import torch
import esm
from tqdm.notebook import tqdm

In [None]:
# Load the pretrained ESM-1b checkpoint (esm1b_t33_650M_UR50S) and the matching
# batch converter that turns (label, sequence) pairs into token tensors.
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
batch_converter = alphabet.get_batch_converter()
# Inference only: eval() switches off training-mode behaviour such as dropout,
# so repeated runs give the same embeddings (the ESM README does the same
# before extracting representations).
model = model.eval().to("cuda:1")

In [None]:
def getSequenceRepresentation(Data, repr_layer=33):
    """Embed protein sequences with the module-level ESM model.

    Relies on the notebook globals ``model`` and ``batch_converter`` created
    when the model was loaded.

    Args:
        Data: list of ``(label, sequence)`` pairs, the input format accepted by
            ``batch_converter``.
        repr_layer: index of the layer whose hidden states are returned.
            Defaults to 33, the layer this notebook has always used (the last
            layer of the t33 checkpoint loaded above).

    Returns:
        A list with one numpy array per input pair, in input order, holding that
        sequence's per-residue representations (the leading beginning-of-sequence
        token and any tokens after the sequence are dropped). Presumably shaped
        (len(sequence), embedding_dim) — TODO confirm against the model.
    """
    _, _, batch_tokens = batch_converter(Data)
    # Follow the model's device rather than hard-coding it, so moving the model
    # (another GPU, or the CPU) needs no change here.
    device = next(model.parameters()).device
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
        # Contacts are never used here, so do not ask the model to compute them.
        results = model(batch_tokens, repr_layers=[repr_layer], return_contacts=False)
    # Single device -> host copy; every slice below is taken from this CPU tensor.
    token_representations = results["representations"][repr_layer].cpu()
    del results, batch_tokens
    # NOTE: token 0 is always a beginning-of-sequence token, so the first residue
    # is token 1; tokens past len(seq) (presumably end-of-sequence/padding) are cut.
    sequence_representations = [
        # clone() so each array owns its data instead of being a view that keeps
        # the whole padded batch tensor alive.
        token_representations[i, 1 : len(seq) + 1].clone().numpy()
        for i, (_, seq) in enumerate(Data)
    ]
    del token_representations
    return sequence_representations

In [None]:
# Parsed ClinVar variants. Columns used later in this notebook: GeneID_x,
# GeneID_y, GeneSymbol, variant, variantSeq.
# Alternate location (presumably the same file): /ssdata/clinvar/clinvar/parseClinvarWithVariant.csv
df = pd.read_csv(
    filepath_or_buffer="/data/projects/processBio/clinvar/clinvar/parseClinvarWithVariant.csv"
)

In [None]:
# Inspect both gene-ID columns next to the sequence that gets embedded below.
df.loc[:, ["GeneID_x", "GeneID_y", "variantSeq"]]

In [None]:
# Gene symbols of rows whose GeneID_x and GeneID_y disagree (the _x/_y suffixes
# presumably come from an upstream merge — TODO confirm).
df.loc[df["GeneID_x"] != df["GeneID_y"], "GeneSymbol"]

In [None]:
def prepSeq(s, loc, windowSize=510):
    """Cut a window of sequence `s` around the 0-based position `loc`.

    The window is ``s[loc - windowSize : loc + windowSize]`` with its start
    clamped to 0 and its end clamped to ``len(s)``: up to `windowSize` residues
    before `loc`, `loc` itself, and up to ``windowSize - 1`` residues after it,
    i.e. at most ``2 * windowSize`` residues (1020 with the default, presumably
    chosen to fit ESM-1b's input length limit — TODO confirm). Near either end
    of `s` the window is truncated, not shifted.

    Raises:
        TypeError: if `s` has no ``len()`` (e.g. a float NaN from pandas).
    """
    start = loc - windowSize
    if start < 0:
        start = 0
    stop = loc + windowSize
    seq_len = len(s)
    if stop > seq_len:
        stop = seq_len
    return s[start:stop]

In [None]:
# Embed each row's variant sequence, one sequence per forward pass.
# `representations` stays index-aligned with `df`: rows whose inputs cannot be
# prepared get an empty-list placeholder (listed by a later cell).
representations = []
batchSize = 1  # NOTE(review): not used by the loop below.
for idx, row in tqdm(df.iterrows(), total=df.shape[0]):
    try:
        # Assumes `variant` looks like "Arg123Trp", so [3:-3] is the 1-based
        # residue number — TODO confirm; prepSeq expects a 0-based position.
        window = prepSeq(row.variantSeq, int(row.variant[3:-3]) - 1)
    except TypeError:
        # Presumably missing (NaN/float) variant or sequence fields, which
        # cannot be sliced. Only input preparation sits inside the try, so
        # errors raised by the model surface instead of becoming "missing" rows.
        representations.append([])
        continue
    representations.append(getSequenceRepresentation([(idx, window)]))

In [None]:
# Rows the embedding loop skipped (it stored an empty list for them).
# NOTE(review): a row whose sequence window came out empty still yields a
# one-element list, so it does not show up here.
df[[not len(rep) for rep in representations]]

In [None]:
import pickle

In [None]:
# Persist the per-row embeddings (index-aligned with df). The `with` block
# guarantees the file is flushed and closed before its size is checked below.
with open("/data/projects/processBio/clinvar/clinvar/parseClinvarRepresentations.pkl", "wb") as fh:
    pickle.dump(representations, fh)

In [None]:
# Report the on-disk size (human-readable) of the pickle written above.
! du -hs /data/projects/processBio/clinvar/clinvar/parseClinvarRepresentations.pkl