In [None]:
from processBioDBs.utilities import getSequence

In [None]:
import pandas as pd

In [None]:
# Load the VEP-annotated gnomAD exome sites table (tab-separated). Row 93 supplies the
# column labels; pandas discards the rows before it (presumably VEP '##' metadata lines).
df = pd.read_csv(
    "/data/projects/processBio/gnomad/gnomad.exomes.r2.1.1.sites.vcf.vep",
    sep="\t",
    header=93,
)

In [None]:
# (rows, columns) of the full VEP table.
df.shape

In [None]:
# Keep only rows whose Consequence is exactly "missense_variant".
# NOTE(review): VEP can report several comma-joined consequences for one row
# (e.g. "missense_variant,splice_region_variant"); this exact match excludes those — confirm intended.
missense = df.loc[df["Consequence"].eq("missense_variant")]

In [None]:
# (rows, columns) after restricting to missense variants.
missense.shape

In [None]:
def _parse_vep_extra(extra):
    """Parse a VEP ``Extra`` field of the form "K1=V1;K2=V2;..." into a dict.

    Each ';'-separated item is split on its FIRST '=' only. A value containing '='
    is therefore kept intact instead of raising ValueError, and an item with no '='
    maps to an empty string.
    """
    return dict(item.partition("=")[::2] for item in extra.split(";"))


# Parsed Extra annotations (SYMBOL, ENSP, ...) as one dict per row.
missense = missense.assign(INFO=missense.Extra.apply(_parse_vep_extra))

In [None]:
# Distinct gene symbols across missense rows; rows without a SYMBOL contribute "".
# The parsed column is named "INFO" (upper case); "info" raised KeyError.
symbols = {info.get("SYMBOL", "") for info in missense["INFO"]}

In [None]:
# Number of distinct gene symbols ("" is counted once if any row lacks a SYMBOL).
len(symbols)

In [None]:
from tqdm.notebook import tqdm

In [None]:
# Unique Ensembl protein IDs referenced by the missense rows (one lookup each below).
enspIDS = {info["ENSP"] for info in missense["INFO"]}

In [None]:
from multiprocessing import Pool

# `ensembl_rest` is used by f() but was never imported, so a fresh kernel raised NameError.
import ensembl_rest


def f(eid):
    """Return the sequence string for Ensembl ID ``eid`` via ``ensembl_rest.sequence_id``."""
    return ensembl_rest.sequence_id(eid)["seq"]


# Map each Ensembl protein ID to its reference sequence (one sequential request per ID).
# NOTE(review): the `Pool` import above is unused; a failed request aborts the loop.
idMapping = {}

for eid in tqdm(enspIDS):
    idMapping[eid] = f(eid)

In [None]:
def getSeq(row, flank=510):
    """Build the mutated protein sequence window around a missense variant.

    Looks up the reference protein for ``row["INFO"]["ENSP"]`` in the global
    ``idMapping``, checks that the reference residue matches, substitutes the
    variant residue, and keeps at most ``flank`` residues on each side.

    Parameters
    ----------
    row : pandas.Series
        Must provide ``Amino_acids`` ("REF/ALT"), ``Protein_position`` (1-based)
        and ``INFO`` (a dict with an ``"ENSP"`` key).
    flank : int, default 510
        Residues kept on each side of the variant. NOTE(review): 510 presumably keeps
        the result (at most 2*510+1 = 1021 residues) within ESM-1b's input limit — confirm.

    Returns
    -------
    str
        The mutated window. Returns "" when the protein sequence is unknown or empty,
        when the position lies outside the protein, or when the reference residue
        does not match.
    """
    og, var = row.Amino_acids.split("/")
    loc = int(row.Protein_position) - 1  # VEP positions are 1-based; convert to 0-based
    # A protein whose fetch is missing is treated the same as an empty sequence.
    s = idMapping.get(row["INFO"]["ENSP"], "")
    # loc < 0 prevents a negative index from silently wrapping to the C-terminus.
    if not s or loc < 0 or loc >= len(s) or s[loc] != og:
        return ""
    # Slicing clamps at the end of the string, so no explicit min() is needed.
    return s[max(0, loc - flank):loc] + var + s[loc + 1:loc + 1 + flank]

In [None]:
# Mutated sequence window per variant ("" when it could not be built).
missense = missense.assign(seq=missense.apply(getSeq, axis=1))

In [None]:
# Fraction of variants for which no mutated sequence could be built.
1 - missense["seq"].ne("").sum() / len(missense)

In [None]:
# NOTE(review): renders the whole frame (pandas truncates the view); .head() would be lighter.
missense

In [None]:
# Checkpoint the missense table (with INFO and seq columns) so embedding can restart from here.
missense.to_pickle(path="/data/projects/processBio/gnomad/gnomad.missenseVariants.pd.pkl")

# Embed Sequences

In [None]:
# Reload the checkpoint written above (a trusted local pickle).
missense = pd.read_pickle(
    filepath_or_buffer="/data/projects/processBio/gnomad/gnomad.missenseVariants.pd.pkl"
)

In [None]:
import torch
import esm
import os
import torch.nn as nn
# Enumerate GPUs in PCI bus order and expose only physical GPUs 1, 2 and 3 to this
# process. These variables only take effect if set before CUDA is first initialised.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '1,2,3',
})

# Load the pretrained ESM-1b checkpoint plus the converter that turns
# (label, sequence) pairs into (labels, strings, token tensor).
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
batch_converter = alphabet.get_batch_converter()

# NOTE(review): multi-GPU execution via nn.DataParallel(model.cuda()) was left disabled;
# the model is moved to a single device in the next cell.

In [None]:
# Move the model to the first visible GPU and switch it to inference mode;
# eval() disables dropout so embeddings are deterministic (as in the ESM README).
model = model.to("cuda:0").eval()

In [None]:
WINDOW_SIZE = 500
# Flank length getSeq used when it built `seq`: s[max(0, loc-510):loc] + var + s[loc+1:loc+511].
# The variant therefore sits at index min(loc, 510) of `seq`, NOT at the absolute protein position.
SEQ_FLANK = 510


def _window_around_variant(row):
    """Return up to WINDOW_SIZE residues on each side of the variant within ``row.seq``."""
    loc = int(row.Protein_position) - 1  # 0-based position in the full protein
    center = min(loc, SEQ_FLANK)  # 0-based position of the variant inside row.seq
    # Symmetric window [center - W, center + W]; slicing clamps at the end of the string.
    return row.seq[max(0, center - WINDOW_SIZE):center + WINDOW_SIZE + 1]


# (label, sequence) pairs for the ESM batch converter; labels are the missense index values.
# NOTE(review): the later validity filter (Protein_position - 1 < len(seq)) predates this
# indexing and now drops variants beyond residue ~510 of long proteins — revisit.
Data = list(missense.apply(_window_around_variant, axis=1).items())

In [None]:
from tqdm.notebook import trange

In [None]:
BATCHSIZE = 1
representations = []  # one (residues, features) array per entry of Data, in order
for start in trange(0, len(Data), BATCHSIZE):
    batch_labels, batch_strs, batch_tokens = batch_converter(Data[start : start + BATCHSIZE])
    # Non-padding token count per sequence (residues + BOS + EOS).
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
    with torch.no_grad():
        # Contact maps are never used, so don't compute them.
        results = model(batch_tokens.to("cuda:0"), repr_layers=[33], return_contacts=False)
    token_representations = results["representations"][33].cpu()
    del results
    # Emit one array per sequence (not just the first of the batch); drop BOS, EOS and padding.
    for i, tokens_len in enumerate(batch_lens.tolist()):
        representations.append(token_representations[i, 1 : tokens_len - 1].clone().numpy())
    del token_representations, batch_labels, batch_strs, batch_tokens

In [None]:
# Attach the per-residue embeddings; `representations` is in Data order, which follows the
# row order of `missense` (Data was built from missense.apply(...).items()).
missense = missense.assign(representation=representations)

In [None]:
from processBioDBs.utilities import prepSeq,getRep

In [None]:
# NOTE(review): interactive IPython source inspection left over from development; safe to drop.
prepSeq??

In [None]:
# Preview the first rows, now including the `representation` column.
missense.head()

In [None]:
# Per-variant feature `xi` derived from the residue embeddings by prepSeq.
# NOTE(review): prepSeq lives in processBioDBs.utilities; presumably it uses
# originalWindowSize to locate the variant inside the embedded window — confirm.
missense = missense.assign(
    xi=missense.apply(
        lambda r: prepSeq(
            r["representation"],
            int(r["Protein_position"]) - 1,  # 1-based VEP position -> 0-based
            originalWindowSize=500,
        ),
        axis=1,
    )
)

In [None]:
# Rows whose `xi` entry is a missing scalar (None/NaN); NaNs inside arrays are checked below.
missense[missense.xi.isna()]

In [None]:
# numpy is otherwise imported only in a later cell, so Restart & Run All hit NameError here.
import numpy as np

# Rows that have a mutated sequence but whose feature array still contains NaNs.
badRows = missense[(missense.xi.apply(lambda xi: np.isnan(xi).any())) & (missense.seq.str.len() > 0)]

In [None]:
# Are all NaN rows explained by the 0-based variant index lying outside `seq`?
# Out of range means index >= len(seq), the complement of the validity filter used before stacking.
badRows.apply(lambda row: len(row.seq) <= int(row.Protein_position) - 1, axis=1).all()

In [None]:
import numpy as np

In [None]:
# Stack `xi` for rows that have a mutated sequence and whose 0-based variant index is
# inside that sequence into a single matrix (one row per kept variant).
X = np.stack(
    missense.loc[
        lambda d: (d["seq"].str.len() > 0)
        & d.apply(lambda row: int(row["Protein_position"]) - 1 < len(row["seq"]), axis=1),
        "xi",
    ].to_numpy()
)

In [None]:
# Guard: refuse to continue (and save) if the embedding matrix contains NaNs.
x_has_nan = np.isnan(X).any()
assert not x_has_nan, "X contains NaN embeddings; inspect badRows before saving"
x_has_nan

In [None]:
# Persist the embedding matrix in .npy format.
np.save(
    file="/data/projects/processBio/gnomad/gnomadValidMissenseVariantEmbeddings.npy",
    arr=X,
)

In [None]:
# Shape of the saved matrix; the first dimension is the number of kept variants.
X.shape