In [1]:
import os
import argparse
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import QuantileTransformer
import numpy as np
import seaborn as sns
from scipy import stats
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
from scipy.stats import pearsonr
from sklearn.metrics import r2_score
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score


import torch.nn.functional as F
import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.nn import GATConv, global_mean_pool
from torch.nn import Linear, Dropout
from torch_geometric.data import Dataset, DataLoader
from models import SelfAttention, HybridModel
from tqdm import tqdm 
import wandb

import dgl

In [2]:
import hashlib
def get_hash(x):
  """Return the SHA-1 hex digest of the string ``x``.

  ``x`` is encoded with ``str.encode()`` (default encoding) before hashing.
  """
  digest = hashlib.sha1(x.encode())
  return digest.hexdigest()

In [3]:
# Directory where the serialized graph files (.pt) are located.
# Raw string: the path contains backslashes ("\E", "\S", "\P") that are not
# valid escape sequences and would otherwise trigger a syntax warning.
directory = r'D:\Edward\Smita_Lab_Work\PyGs\PyGs'

# Collect all .pt files. os.listdir returns entries in arbitrary order, so
# sort them: the de-duplication step that follows keeps the first graph seen
# per name, and a fixed order makes that choice reproducible.
files = sorted(f for f in os.listdir(directory) if f.endswith('.pt'))

# Initialize an empty list to store the graphs
graphs = []

# Loop through the files and load each graph, showing a progress bar
for file in tqdm(files, desc="Loading graphs"):
    file_path = os.path.join(directory, file)
    # NOTE(review): torch.load unpickles arbitrary objects; only point this at
    # trusted files.
    graph = torch.load(file_path)
    graphs.append(graph)

# Now `graphs` contains all the loaded graph objects
print(f"Loaded {len(graphs)} graphs.")


# Drop every graph whose name contains 'X'. The former explicit check for
# 'NXVPMVATV' was redundant because that string itself contains 'X'.
graphs = [x for x in graphs if 'X' not in x.name]

# Name key: the part of the graph name after the "Immuno" marker.
strings = [x.name.split("Immuno")[1] for x in graphs]

# Total keys vs. distinct keys; a difference means duplicate graphs exist.
print(len(strings), len(set(strings)))

Loading graphs: 100%|█████████████████████████████████████████████████████████████████| 24607/24607 [00:07<00:00, 3461.10it/s]

Loaded 24607 graphs.
24603 23339





In [4]:
# Keep only the first graph encountered for each name key (the text after
# "Immuno" in graph.name); later graphs with an already-seen key are skipped.
new_graphs = []
names = set()

for graph in graphs:
    key = graph.name.split("Immuno")[1]
    if key in names:
        continue
    names.add(key)
    new_graphs.append(graph)

strings = [x.name.split("Immuno")[1] for x in new_graphs]

# Both numbers must match once duplicates are gone.
print(len(strings), len(set(strings)))

23339 23339


In [5]:
# Occurrences of each graph name key; every count should be 1 after de-duplication.
count = pd.Series(strings).value_counts()
print("Element Count", count, sep="\n")

Element Count
PKTHMTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRRVLSFIKGTK_29174    1
PRTHMTHHAVSDHEATLRCWALSFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWASVVVPSGQEQRYTCHVQHEGLPKPLTLRIVQKAPIYKR_9ef65    1
PRTHMTHHAVSDHEATLRCWALSFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWASVVVPSGQEQRYTCHVQHEGLPKPLTLRIVNKFMSFYK_ea184    1
PRTHMTHHAVSDHEATLRCWALSFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWASVVVPSGQEQRYTCHVQHEGLPKPLTLRIVMPVFIIKR_d3bdf    1
PRTHMTHHAVSDHEATLRCWALSFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWASVVVPSGQEQRYTCHVQHEGLPKPLTLRIVLLFPSIIY_f638d    1
                                                                                                            ..
PPKTHMTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRNTYGEGFDY_8d6bf    1
PPKTHMTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRNTYGEGFDY_4ef5f    1
PPKTHMTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRNTYASPRF

In [6]:
# Score table. Raw string: the path contains backslashes ("\E", "\S", "\c")
# that are not valid escape sequences and would otherwise trigger a syntax warning.
# NOTE(review): read_table defaults to a tab separator although the file is
# named .csv — confirm the file really is tab-delimited.
expanded_df = pd.read_table(r'D:\Edward\Smita_Lab_Work\complete_score_Mprops_1_2.csv')
# Drop rows without a foreignness score (list form of `subset` works on every pandas version).
expanded_df = expanded_df.dropna(subset=['Foreignness_Score'])
# Lookup key: peptide string immediately followed by the allele string.
expanded_df['pep_pair'] = expanded_df['peptide'] + expanded_df['allele']

# pep_pair -> value lookups; if a pep_pair repeated, the last row would win.
f_dict = dict(zip(expanded_df['pep_pair'],expanded_df['smoothed_foreign']))
fp2_dict = dict(zip(expanded_df['pep_pair'],expanded_df['Mprop1']))
new_imm_dict = dict(zip(expanded_df['pep_pair'],expanded_df['immunogenicity']))

expanded_pep_pair = expanded_df['pep_pair'].tolist()

In [7]:
# Number of table rows (pep_pair keys) remaining after dropping missing scores.
print(len(expanded_pep_pair))

24539


In [8]:
# Preview the first few keys: each is the peptide followed by the allele name.
expanded_pep_pair[:4]

['LSNSGKDVPKHLA-A*11:01',
 'TTLFHTFYELHLA-A*24:02',
 'KFGDLTNNFHLA-A*24:02',
 'KLFESKAELHLA-A*02:01']

In [9]:
# Cut off the h-bonding features for now by dropping the last two columns of
# every graph's node-feature matrix.
# NOTE: this slices in place, so re-running this cell removes two more columns
# each time; run it exactly once per freshly loaded set of graphs.
for data in graphs:
    data.x = data.x[:, :-2]

# Allele -> sequence lookup. Raw string: the path contains backslashes
# ("\E", "\S", "\H") that are not valid escape sequences.
hla_df = pd.read_csv(r'D:\Edward\Smita_Lab_Work\HLA_27_seqs_csv.csv')
hla_dict_true = dict(zip(hla_df['allele'], hla_df['seqs']))
print(len(hla_dict_true))

27


In [10]:
# pep_pair -> (full sequence, graph name key, peptide)
#   full sequence  : allele sequence from hla_dict_true followed by the peptide
#   graph name key : last 99 characters of the full sequence + "_" + first 5
#                    hex characters of its SHA-1 digest
name_mapper = {}

for seq in expanded_pep_pair:
    # A pep_pair looks like "<peptide>HLA-<allele>"; split at the allele prefix.
    pep, hla = seq.split("HLA-")
    name = hla_dict_true["HLA-" + hla] + pep
    graph_key = name[-99:] + "_" + get_hash(name)[:5]
    name_mapper[seq] = (name, graph_key, pep)

print(len(name_mapper), len(set(name_mapper.keys())))
print(list(name_mapper.values())[:3])

24539 24539
[('SHSMRYFYTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEPRAPWIEQEGPEYWDQETRNVKAQSQTDRVDLGTLRGYYNQSEDGSHTIQIMYGCDVGPDGRFLRGYRQDAYDGKDYIALNEDLRSWTAADMAAQITKRKWEAAHAAEQQRAYLEGRCVEWLRRYLENGKETLQRTDPPKTHMTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRLSNSGKDVPK', 'PKTHMTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRLSNSGKDVPK_601a6', 'LSNSGKDVPK'), ('SHSMRYFSTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEPRAPWIEQEGPEYWDEETGKVKAHSQTDRENLRIALRYYNQSEAGSHTLQMMFGCDVGSDGRFLRGYHQYAYDGKDYIALKEDLRSWTAADMAAQITKRKWEAAHVAEQQRAYLEGTCVDGLRRYLENGKETLQRTDPPKTHMTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRTTLFHTFYEL', 'PKTHMTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRTTLFHTFYEL_36be2', 'TTLFHTFYEL'), ('SHSMRYFSTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEPRAPWIEQEGPEYWDEETGKVKAHSQTDRENLRIALRYYNQSEAGSHTLQMMFGCDVGSDGRFLRGYHQYAYDGKDYIALKEDLRSWTAADMAAQITKRKWEAAHVAEQQRAYLEGTC

In [11]:
# Make the table (name_mapper) and the graph list (new_graphs) match 1:1.

# graph -> table: drop table entries whose graph name key has no loaded graph.
strings = [x.name.split("Immuno")[1] for x in new_graphs]
names = set(strings)
to_remove = []
for x, y in name_mapper.items():
    if y[1] not in names:
        to_remove.append(x)

# Deletion happens in a second pass because a dict cannot be resized while
# it is being iterated.
for i in to_remove:
    del name_mapper[i]

print(len(name_mapper))

# table -> graph: drop graphs whose name key has no remaining table entry.
to_remove = set()
mapper_names = set(y[1] for x, y in name_mapper.items())

for i in strings:
    if i not in mapper_names:
        to_remove.add(i)

new_graphs = [x for x in new_graphs if x.name.split("Immuno")[1] not in to_remove]
# Keep `strings` as the list of name keys (it previously got rebound to the
# graph objects themselves, silently changing what the variable holds).
strings = [x.name.split("Immuno")[1] for x in new_graphs]
print(len(strings))
print(len(to_remove))

23339
23339
0


In [12]:
# Index the retained graphs by their name key (text after "Immuno" in the name).
graph_mapper = dict(
    (x.name.split("Immuno")[1], x)
    for x in new_graphs
)
print(len(graph_mapper))

23339


In [13]:
# Attach labels and coordinates to every graph that has a table entry.
# NOTE: this mutates the graphs in place; running the cell again would append
# the coordinates to graph.x a second time.
for x, y in name_mapper.items():
    immuno_score, f_score = new_imm_dict[x], f_dict[x]
    graph = graph_mapper[y[1]]

    # Graph-level label: a two-element float32 tensor
    # [immunogenicity, smoothed foreignness].
    graph.y = torch.tensor([immuno_score, f_score], dtype=torch.float).to(dtype=torch.float32)

    # Append the node coordinates as extra feature columns, then cast to float32.
    graph.x = torch.cat([graph.x, graph.coords], dim=-1).to(dtype=torch.float32)

In [14]:
def pad_graph(graph, max_nodes, feature_size, coord_size):
    """Zero-pad a graph's node features and coordinates up to ``max_nodes`` rows.

    The graph is modified in place and also returned. Graphs that already have
    at least ``max_nodes`` nodes are returned untouched.

    Args:
        graph: object exposing ``x``, ``coords`` and ``num_nodes`` attributes.
        max_nodes: target number of nodes.
        feature_size: number of columns of the zero rows appended to ``graph.x``.
        coord_size: number of columns of the zero rows appended to ``graph.coords``.

    Returns:
        The same ``graph`` object.
    """
    deficit = max_nodes - graph.num_nodes
    if deficit <= 0:
        return graph

    # Build both padded tensors before assigning either, so a failure in the
    # second concatenation leaves the graph unmodified.
    x_padded = torch.cat([graph.x, torch.zeros(deficit, feature_size)], dim=0)
    coords_padded = torch.cat([graph.coords, torch.zeros(deficit, coord_size)], dim=0)

    graph.x, graph.coords = x_padded, coords_padded
    graph.num_nodes = max_nodes
    return graph

# Column counts of the zero rows used for padding; they have to match the
# current widths of graph.x and graph.coords for the concatenation to succeed.
feature_size = 23  # Replace with the size of your feature vectors
coord_size = 3     # Replace with the size of your coordinate vectors

# Every graph is padded up to the size of the largest one.
max_nodes = max(graph.num_nodes for graph in graph_mapper.values())

# pad_graph works in place, so these values are the same objects as in graph_mapper.
padded_graphs = {
    name: pad_graph(graph, max_nodes, feature_size, coord_size)
    for name, graph in graph_mapper.items()
}

In [15]:
def to_dgl(pt_geometric_graph):
    """Build a DGL graph from a graph object carrying ``edge_index``, ``x`` and ``num_nodes``.

    Side effect: ``pt_geometric_graph.edge_attr`` is overwritten with a column
    of ones (one row per edge).

    Returns:
        A DGL graph with node data ``'x'`` and edge data ``'edge_attr'``.
    """
    edge_index = pt_geometric_graph.edge_index

    # The edge count is taken as the size of edge_index's second dimension
    # (its two rows unpack into source and destination ids below).
    pt_geometric_graph.edge_attr = torch.ones((edge_index.size(1), 1))

    src, dst = edge_index
    dgl_graph = dgl.graph((src, dst), num_nodes=pt_geometric_graph.num_nodes)
    dgl_graph.ndata['x'] = pt_geometric_graph.x  # node features
    dgl_graph.edata['edge_attr'] = pt_geometric_graph.edge_attr  # constant edge feature
    return dgl_graph

# Rebind graph_mapper so that each name key now maps to the DGL version of its
# padded graph (previously it mapped to the original graph objects).
graph_mapper = {name: to_dgl(graph) for name, graph in padded_graphs.items()}

In [16]:
# Longest and shortest full sequence (first element of each name_mapper value).
full_lengths = [len(entry[0]) for entry in name_mapper.values()]
print(max(full_lengths))
print(min(full_lengths))

283
281


In [17]:
# Right-pad a sequence string to a fixed length
def pad_peptide_sequence(sequence, max_length=11, padding_char='J'):
    """Return ``sequence`` right-padded with ``padding_char`` to ``max_length``.

    Sequences already at least ``max_length`` long are returned unchanged
    (they are never truncated).
    """
    return sequence.ljust(max_length, padding_char)

# Pad the full sequence to 283 characters and the peptide to the default
# length; the graph name key in the middle is left as is.
name_mapper = {
    pair: (pad_peptide_sequence(full_seq, 283), graph_key, pad_peptide_sequence(peptide))
    for pair, (full_seq, graph_key, peptide) in name_mapper.items()
}

In [18]:
def one_hot_encode_sequence(sequence, amino_acids = 'ACDEFGHIKLMNPQRSTVWY', padding_char='J'):
    """One-hot encode a sequence string.

    Args:
        sequence: string to encode, one character per row of the result.
        amino_acids: alphabet; each character's position is its column index.
        padding_char: extra symbol appended to the alphabet (last column).

    Returns:
        A NumPy array of shape ``(len(sequence), len(alphabet))`` holding a
        single 1 per row. Characters outside the alphabet are reported on
        stdout and their row is left all zeros.
    """
    # Map each alphabet character (amino acids + padding) to its column index
    char_to_int = {c: i for i, c in enumerate(amino_acids + padding_char)}

    # Initialize the one-hot encoded matrix for the sequence
    one_hot_encoded = np.zeros((len(sequence), len(char_to_int)))

    # Fill the one-hot encoded matrix with appropriate values
    for i, char in enumerate(sequence):
        if char in char_to_int:  # Only encode known characters
            one_hot_encoded[i, char_to_int[char]] = 1
        else:
            # The previous message passed "{}" and the character as separate
            # print arguments, so the placeholder was never filled in.
            print(f"unknown character: {char}")

    return one_hot_encoded

# pep_pair -> one-hot matrix of the padded full sequence
encoded_full_sequence_map = {
    pair: one_hot_encode_sequence(full_seq)
    for pair, (full_seq, _graph_key, _peptide) in name_mapper.items()
}
print("peptide")
# pep_pair -> one-hot matrix of the padded peptide
encoded_peptide_map = {
    pair: one_hot_encode_sequence(peptide)
    for pair, (_full_seq, _graph_key, peptide) in name_mapper.items()
}

peptide


In [23]:
# Flatten name_mapper into (pep_pair, full sequence, graph name key, peptide)
# tuples; every list below follows this same order, so index i refers to the
# same sample in all of them.
names = [(x, a, b, c) for x, (a, b, c) in name_mapper.items()]
pair_keys = [entry[0] for entry in names]

encoded_full_sequence = [encoded_full_sequence_map[key] for key in pair_keys]
encoded_peptide_sequence = [encoded_peptide_map[key] for key in pair_keys]

protein_reg_values = [fp2_dict[key] for key in pair_keys]
protein_immuno_values = [new_imm_dict[key] for key in pair_keys]
protein_reg_values_f = [f_dict[key] for key in pair_keys]

# Graphs are looked up by the graph name key (third tuple element).
dgl_filtered_graphs = [graph_mapper[graph_key] for _, _, graph_key, _ in names]

In [27]:
# TO CHECK
# Duplicate audit keyed on (encoded full sequence, Mprop1, immunogenicity,
# smoothed foreignness):
#   dupe        - samples whose key was already seen
#   double_dupe - of those, samples whose graph is also identical to the first
#                 graph seen with that key; their indices go into to_remove
to_remove = set()
cache = dict()
dupe = 0
double_dupe = 0 

for n, (a, b, c, d) in enumerate(zip(encoded_full_sequence, protein_reg_values, protein_immuno_values, protein_reg_values_f)):
    # Shape plus raw bytes identify a one-hot matrix exactly and are far
    # cheaper to build and hash than a nested tuple of every cell.
    overlap = (a.shape, a.tobytes(), b, c, d)
    if overlap in cache:
        dupe+=1
        first = dgl_filtered_graphs[cache[overlap]]
        current = dgl_filtered_graphs[n]
        first_src, first_dst = first.edges()
        current_src, current_dst = current.edges()
        # Compare both edge endpoints: checking only the source ids would call
        # two graphs identical even when their edges lead to different nodes.
        if (first.num_nodes() == current.num_nodes() and 
            first.num_edges() == current.num_edges() and 
            first.ndata['x'].tolist() == current.ndata['x'].tolist() and 
            first.edata['edge_attr'].tolist() == current.edata['edge_attr'].tolist() and 
            first_src.tolist() == current_src.tolist() and 
            first_dst.tolist() == current_dst.tolist()):
            double_dupe+=1
            to_remove.add(n)
    else:
        cache[overlap] = n
print(dupe, double_dupe)

0 0


In [26]:
# TO CHECK
# Looser duplicate audit keyed only on (encoded full sequence, Mprop1):
#   dupe        - samples whose key was already seen
#   double_dupe - of those, samples whose graph is also identical to the first
#                 graph seen with that key; their indices go into to_remove
to_remove = set()
cache = dict()
dupe = 0
double_dupe = 0 

for n, (a, b) in enumerate(zip(encoded_full_sequence, protein_reg_values)):
    # Shape plus raw bytes identify a one-hot matrix exactly and are far
    # cheaper to build and hash than a nested tuple of every cell.
    overlap = (a.shape, a.tobytes(), b)
    if overlap in cache:
        dupe+=1
        first = dgl_filtered_graphs[cache[overlap]]
        current = dgl_filtered_graphs[n]
        first_src, first_dst = first.edges()
        current_src, current_dst = current.edges()
        # Compare both edge endpoints: checking only the source ids would call
        # two graphs identical even when their edges lead to different nodes.
        if (first.num_nodes() == current.num_nodes() and 
            first.num_edges() == current.num_edges() and 
            first.ndata['x'].tolist() == current.ndata['x'].tolist() and 
            first.edata['edge_attr'].tolist() == current.edata['edge_attr'].tolist() and 
            first_src.tolist() == current_src.tolist() and 
            first_dst.tolist() == current_dst.tolist()):
            double_dupe+=1
            to_remove.add(n)
    else:
        cache[overlap] = n
print("dupes", dupe, double_dupe)

dupes 0 0
