In [1]:
import esm
import torch
import time
import gc
import os
import numpy as np
%run "../scripts/data_processing.py"
%run "../scripts/node_edge_generation.py"
%run "../scripts/graph_functions.py"

#Data path
# Absolute form of "../../Data/"; os.path.abspath resolves the relative path
# against the current working directory of the kernel.
datadir = os.path.abspath("../../Data/")

In [2]:
#Start graph
# startGraph comes from one of the %run scripts in the first cell.
# NOTE(review): the meaning of its two path arguments is not visible here -- confirm.
startGraph("./pass_ent.txt", "./Results/")

#Get Disease intel
uri = "bolt://localhost:7687"
driver = GraphDatabase.driver(uri)

disease_dic = {}

# Per disease: name, description, number of distinct proteins linked through a
# direct_evidence relationship (score > 0.4, sequence shorter than 1500), and the
# list of [uniprot id, sequence] pairs for those proteins; most proteins first.
# Kept as a one-element list so that query_1[0] keeps working.
query_1 = ['match (p:Protein)-[di:direct_evidence]-(d:Disease)\n \
            where  toFloat(di.score) > 0.4 \
            and size(p.seq) < 1500 \
            return d.name as Disease, d.description as Description, \
            count(distinct p.uniprot) as Prot_count, collect(distinct [p.uniprot, p.seq]) as All_seqs  \
            order by Prot_count desc']

# Close the connection even when the query or the display raises, so a failed
# run does not leave an open driver behind in the kernel.
try:
    results_1 = run_query(query_1[0], driver)
    display(results_1)
finally:
    #Close connection
    driver.close()

# The sequence-length restriction is applied inside the Cypher query
# (size(p.seq) < 1500), so no additional filtering of results_1 is done here.

Graph stopped
Database import successful!
Graph started


Unnamed: 0,Disease,Description,Prot_count,All_seqs
0,genetic disorder,Genetic diseases are diseases in which inherit...,372,"[[P35716, MVQQAESLEAESNLPREALDTEEGEFMACSPVALDE..."
1,cancer,"A tumor composed of atypical neoplastic, often...",301,"[[Q13480, MSGGEVVCSGWLRKSPPEKKLKRYAWKRRWFVLRSG..."
2,body height,The distance from the sole to the crown of the...,268,"[[Q9H336, MKCTAREWLRVTTVLFMARAIPAMVVPNATLLEKLL..."
3,body mass index,An indicator of body density as determined by ...,234,"[[O75912, MDAAGRGCHLLPLPAARGPARAPAAAAAAAASPPGP..."
4,HIV infection,An infection caused by the human immunodeficie...,225,"[[O43242, MKQEGSARRRGADKAKPPPGGGEQEPPPPPAPQDVE..."
...,...,...,...,...
5572,Proximal myotonic myopathy,"Myotonic dystrophy type 2 (MD2), also known as...",1,"[[P62633, MSSNECFKCGRSGHWARECPTGGGRGRGMRSRGRGG..."
5573,Choanal atresia-deafness-cardiac defects-dysmo...,Choanal atresia - deafness - cardiac defects -...,1,"[[P83876, MSYMLPHLHNGWQVDQAILSEEDRVVVIRFGHDWDP..."
5574,Autosomal dominant spastic paraplegia type 4,Autosomal dominant spastic paraplegia type 4 (...,1,"[[Q9UBP0, MNSPGGRGKKKGSGGASNPVPPRPPPPCLAPAPPAA..."
5575,Epithelial recurrent erosion dystrophy,Epithelial recurrent erosion dystrophy (ERED) ...,1,"[[Q9UMD9, MDVTKKNKRDGTEVTERIVTETVTTRLTSLPPKGGT..."


## Sequence embedding

In [3]:
# Release cached GPU memory left over from earlier runs and report the
# current usage before loading the model.
torch.cuda.empty_cache()
print(f"Allocated memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
print(f"Cached memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
# Load ESM-2 model
#https://github.com/facebookresearch/esm#available-models
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# The model is only used for inference: switch to eval mode so that
# train-only behaviour (e.g. dropout) is disabled and embeddings are deterministic.
model.eval()
# Converts a list of (label, sequence) pairs into labels, strings and a token tensor.
batch_converter = alphabet.get_batch_converter()

Allocated memory: 0.00 MB
Cached memory: 0.00 MB
Allocated memory: 0.00 MB
Cached memory: 0.00 MB


In [4]:
def EmbedGen(inputlist, model, converter=None):
    """Embed every sequence of ``inputlist`` with ``model``, one sequence at a time.

    Parameters
    ----------
    inputlist : sequence of (label, sequence) pairs
        Passed unchanged to the batch converter.
    model : callable torch module
        Called as ``model(tokens, repr_layers=[33], return_contacts=False)`` and
        expected to return a mapping with ``["representations"][33]``.
    converter : callable, optional
        Tokeniser returning ``(labels, strings, tokens)``. Defaults to the
        notebook-level ``batch_converter`` so existing calls keep working.

    Returns
    -------
    dict
        ``label -> layer-33 representation`` (CPU tensor). Each sequence is fed as
        a batch of one, so tensors carry a leading dimension of size 1. If a label
        occurs more than once, the first embedding is kept.

    Notes
    -----
    The device is resolved once and used for both the model and the tokens.
    Previously tokens were always sent to ``"cuda"``, which crashed on machines
    without a GPU even though the model move was guarded.
    """
    if converter is None:
        converter = batch_converter

    embdic = {}
    # Nothing to embed: skip tokenisation, which cannot handle an empty batch.
    if len(inputlist) == 0:
        return embdic

    on_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if on_gpu else "cpu")

    with torch.no_grad():
        #Load model to the selected device (GPU when available)
        model = model.to(device)
        #Tokenise + pad sequences
        batch_labels, batch_strs, batch_tokens = converter(inputlist)

        #Iterate over sequences one by one to keep peak memory low
        for label, tokens in zip(batch_labels, batch_tokens):
            #Clear memory left over from the previous iteration
            gc.collect()
            if on_gpu:
                torch.cuda.empty_cache()

            #Generate embeddings, append to dictionary
            tokens = tokens.to(device=device, non_blocking=True).unsqueeze(0)
            out = model(tokens, repr_layers=[33], return_contacts=False)["representations"][33]
            embdic.setdefault(label, out.cpu())

            # Drop device-side references so the next cleanup can free them.
            del tokens, out

    return embdic

# Embed the proteins of every disease that has between 20 and 40 associated
# proteins (inclusive); alldic maps disease name -> {protein label: embedding}.
alldic = {}
count = 0
truestart = time.time()
for _, entry in results_1.iterrows():
    n_prots = entry["Prot_count"]
    # Skip diseases outside the 20-40 protein window.
    if not (20 <= n_prots <= 40):
        continue

    disease = entry["Disease"]
    print("#"*10, "\n", count, disease, n_prots)

    # setdefault keeps the first embedding set if a disease name repeats.
    alldic.setdefault(disease, EmbedGen(entry["All_seqs"], model))

    elapsed = round(time.time()-truestart, 2)
    print(f"Time taken (total) = {elapsed}")
    count += 1

########## 
 0 intraocular pressure measurement 40
Time taken (total) = 29.56
########## 
 1 opioid dependence 40
Time taken (total) = 66.35
########## 
 2 ovarian cancer 40
Time taken (total) = 100.85
########## 
 3 prostate adenocarcinoma 40
Time taken (total) = 137.72
########## 
 4 Agitation 40
Time taken (total) = 174.47
########## 
 5 renal cell carcinoma 40
Time taken (total) = 208.8
########## 
 6 protein measurement 39
Time taken (total) = 240.42
########## 
 7 unspecified peripheral T-cell lymphoma 39
Time taken (total) = 273.84
########## 
 8 Cone rod dystrophy 38
Time taken (total) = 304.63
########## 
 9 brain injury 38
Time taken (total) = 339.5
########## 
 10 inflammatory bowel disease 38
Time taken (total) = 368.08
########## 
 11 Isolated NADH-CoQ reductase deficiency 37
Time taken (total) = 383.09
########## 
 12 obsessive-compulsive disorder 37
Time taken (total) = 409.15
########## 
 13 Primary ciliary dyskinesia 37
Time taken (total) = 433.39
########## 
 14 hair 

## Compute pairwise-distance (pdist) diversity metric

In [6]:
from scipy.spatial.distance import pdist, squareform
# Diversity of a disease = mean pairwise cosine distance between the
# position-0 embedding vectors of all proteins stored for that disease.
diverse_dict = {}

for dis, protein_embeds in alldic.items():
    # Slice out token position 0 of every stored tensor
    # (presumably the classification/BOS token -- confirm against the tokeniser).
    first_tokens = [embedding[:, 0, :] for embedding in protein_embeds.values()]

    # One row per protein
    listembd = torch.vstack(first_tokens).numpy()

    # Condensed pairwise cosine distances, averaged into a single score
    diversity = np.mean(pdist(listembd, metric='cosine'))
    diverse_dict.setdefault(dis, diversity)

# Tabulate the scores, most diverse disease first
diverse_df = pd.DataFrame(diverse_dict.items(), columns=["Disease", "Diversity_(pdis)"])
diverse_df = diverse_df.sort_values(by=["Diversity_(pdis)"], ascending=False)
display(diverse_df)

Unnamed: 0,Disease,Diversity_(pdis)
125,Blackfan-Diamond anemia,0.031649
30,Mitochondrial disease,0.027962
62,Rare familial disorder with hypertrophic cardi...,0.023731
85,myelodysplastic syndrome,0.019125
16,leukemia,0.017834
...,...,...
135,"osteoarthritis, knee",0.008919
129,metastatic colorectal cancer,0.008299
140,Myoclonus,0.008070
90,thyroid cancer,0.007067


In [59]:
# Show the current contents of diverse_dict as a table, highest diversity first.
display(
    pd.DataFrame(
        diverse_dict.items(),
        columns=["Disease", "Diversity_(pdis)"],
    ).sort_values(by=["Diversity_(pdis)"], ascending=False)
)

Unnamed: 0,Disease,Diversity_(pdis)
4,Fragile X syndrome,0.021274
5,Hepatic steatosis,0.021046
6,intermittent vascular claudication,0.021044
0,insulin resistance,0.021023
1,prediabetes syndrome,0.020885
13,systemic lupus erythematosus,0.016576
10,alcohol drinking,0.01527
11,breast neoplasm,0.015187
3,myocardial infarction,0.014101
14,lung adenocarcinoma,0.014086


In [17]:
# NOTE(review): self-referential reassignment that relies on leftover kernel state.
# It only works while listembd is still a list/tuple of tensors; the diversity cell
# above leaves listembd as a NumPy array, and re-running this cell after itself
# also feeds it a NumPy array -- confirm whether this cell is still needed.
listembd = torch.stack(listembd).numpy()


In [23]:
def dict_to_array(embed_dict):
    """Stack the values of ``embed_dict`` into one NumPy array.

    Parameters
    ----------
    embed_dict : dict
        Mapping whose values are equally shaped torch tensors or array-likes.

    Returns
    -------
    numpy.ndarray
        Values stacked along a new leading axis, in dictionary order.

    Raises
    ------
    ValueError
        If ``embed_dict`` is empty (previously this leaked a ``StopIteration``).

    Notes
    -----
    Non-tensor values used to make the function silently return ``None``; they
    are now converted with ``numpy.asarray`` and stacked as well. Tensors are
    detached and moved to the CPU first, so GPU or grad-tracking tensors work too.
    """
    values = list(embed_dict.values())
    if not values:
        raise ValueError("dict_to_array() received an empty dictionary; nothing to stack")

    # Fast path: everything is a tensor, stack once in torch and convert once.
    if all(isinstance(value, torch.Tensor) for value in values):
        return torch.stack(values).detach().cpu().numpy()

    # Mixed or non-tensor values: convert each one individually, then stack.
    return np.stack([
        value.detach().cpu().numpy() if isinstance(value, torch.Tensor) else np.asarray(value)
        for value in values
    ])

# Stack the embeddings stored for a single disease into one array.
# NOTE(review): assumes 'insulin resistance' is a key of alldic in the current
# kernel state (raises KeyError otherwise). The "Diversity: ..." output shown
# under this cell is not produced by any visible code -- confirm the missing step.
array = dict_to_array(alldic['insulin resistance'])

Diversity: 0.0210
