In [1]:
import torch
import pandas as pd

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Data loading and preparation

In [2]:
edges = pd.read_csv("../data/interim/edges.csv")

In [3]:
overcl = pd.read_csv("../data/interim/clones_over.csv")

In [4]:
overcl.columns = ["node1","clone"]

In [5]:
overcl.clone.unique()

array(['16', '10', '0', 'diploid', '12', '8', '13', '14', '1', '7', '15',
       '4', '6', '9', '5', '11', '2', '3'], dtype=object)

In [6]:
edges = edges.drop(columns = ["clone"])


In [7]:
overcl = edges[edges["type"] == "sc2vis"].merge(overcl,left_on = "node1",right_on ="node1",how = "left")

In [8]:
overcl.clone = overcl.clone.fillna("diploid")

In [9]:
overcl.clone.unique()

array(['diploid', '11', '2', '10', '8', '6', '1', '13', '4', '15', '14',
       '9', '7', '0', '5', '12', '16', '3'], dtype=object)

In [10]:
edges= edges.merge(overcl[["clone","node1"]],left_on = "node1",right_on ="node1",how = "left")

In [11]:
emb_vis = pd.read_csv("../data/interim/embedding_visium_scvi.csv", index_col = 0)
emb_rna = pd.read_csv("../data/interim/embedding_rna2vis_scvi.csv", index_col = 0)

In [12]:
all_nodes_graph = set(edges.node1.to_list() + edges.node2.to_list())
all_nodes_emb = set(emb_vis.index).union(set(emb_rna.index))
all_nodes = list(all_nodes_graph.intersection(all_nodes_emb))
node_encoder = {all_nodes[i]:i for i in range(len(all_nodes))}
emb_vis = emb_vis.loc[emb_vis.index.isin(all_nodes)]
emb_rna = emb_rna.loc[emb_rna.index.isin(all_nodes)]

In [13]:
edges = edges[edges.node1.isin(all_nodes)]

In [14]:
edges = edges[edges.node2.isin(all_nodes)]

In [15]:
edges.node1 = edges.node1.map(node_encoder)

In [16]:
edges.node2 = edges.node2.map(node_encoder)

In [17]:
emb_vis = emb_vis.rename(index = node_encoder)
emb_rna = emb_rna.rename(index = node_encoder)

In [18]:
len(all_nodes)

34079

In [19]:
edge_index = torch.tensor([edges.node1,
                           edges.node2], dtype=torch.long)

In [20]:
features = pd.concat([emb_vis, emb_rna])

In [21]:
features = features.sort_index()

In [22]:
x = torch.tensor(features.values , dtype=torch.float)

In [23]:
edges[edges.type == "vis2grid"]

Unnamed: 0.1,Unnamed: 0,node1,node2,type,celltype_major,celltype_minor,clone
2908700,298910,23191,22073,vis2grid,,,
2908701,298911,23191,2569,vis2grid,,,
2908702,298912,21298,19845,vis2grid,,,
2908703,298913,21298,33989,vis2grid,,,
2908704,298914,21298,8210,vis2grid,,,
...,...,...,...,...,...,...,...
2928253,318463,18955,13368,vis2grid,,,
2928254,318464,18955,6262,vis2grid,,,
2928255,318465,18955,17910,vis2grid,,,
2928256,318466,14742,28303,vis2grid,,,


In [24]:
edges.clone = edges.clone.fillna("missing")

In [25]:
edges[edges.clone == "missing"]

Unnamed: 0.1,Unnamed: 0,node1,node2,type,celltype_major,celltype_minor,clone
2908700,298910,23191,22073,vis2grid,,,missing
2908701,298911,23191,2569,vis2grid,,,missing
2908702,298912,21298,19845,vis2grid,,,missing
2908703,298913,21298,33989,vis2grid,,,missing
2908704,298914,21298,8210,vis2grid,,,missing
...,...,...,...,...,...,...,...
2928253,318463,18955,13368,vis2grid,,,missing
2928254,318464,18955,6262,vis2grid,,,missing
2928255,318465,18955,17910,vis2grid,,,missing
2928256,318466,14742,28303,vis2grid,,,missing


In [26]:
nodes_atr = edges[["node1","type","celltype_major", "clone"]].drop_duplicates()

In [27]:
nodes_atr = nodes_atr.sort_values(by = "node1")

In [28]:
nodes_atr.clone.unique()

array(['diploid', '9', 'missing', '8', '11', '0', '13', '10', '6', '4',
       '1', '5', '12', '14', '2', '7', '15', '16', '3'], dtype=object)

In [29]:
nodes_atr.celltype_major = nodes_atr.celltype_major.fillna("missing")

In [30]:
nodes_atr

Unnamed: 0,node1,type,celltype_major,clone
885200,0,sc2vis,T.cells,diploid
2865600,1,sc2vis,Normal Epithelial,9
2924234,2,vis2grid,missing,missing
1089400,3,sc2vis,Normal Epithelial,8
1253900,4,sc2vis,Cancer Epithelial,diploid
...,...,...,...,...
2909396,34074,vis2grid,missing,missing
1492900,34075,sc2vis,T.cells,diploid
290300,34076,sc2vis,Normal Epithelial,0
1039900,34077,sc2vis,T.cells,diploid


In [31]:
from collections import Counter

Counter(nodes_atr.celltype_major)

Counter({'T.cells': 5983,
         'Normal Epithelial': 4033,
         'missing': 4992,
         'Cancer Epithelial': 8503,
         'Plasmablasts': 1275,
         'Myeloid': 3868,
         'Endothelial': 1346,
         'CAFs': 2592,
         'B.cells': 1018,
         'PVL': 469})

In [33]:
ct_list = list(nodes_atr.celltype_major.unique())
ct_list.remove("missing")

In [34]:
Counter(nodes_atr.clone)

Counter({'diploid': 19626,
         '9': 483,
         'missing': 4992,
         '8': 583,
         '11': 367,
         '0': 1502,
         '13': 304,
         '10': 511,
         '6': 645,
         '4': 801,
         '1': 1236,
         '5': 816,
         '12': 315,
         '14': 224,
         '2': 925,
         '7': 599,
         '15': 67,
         '16': 61,
         '3': 22})

In [35]:
clone_list = list(nodes_atr.clone.unique())
clone_list.remove("missing")
clone_list.remove("diploid")
clone_dict = {x:int(x) for x in clone_list}
clone_dict["missing"] = -1
clone_dict["diploid"] = len(clone_dict)-1

In [36]:
clone_dict

{'9': 9,
 '8': 8,
 '11': 11,
 '0': 0,
 '13': 13,
 '10': 10,
 '6': 6,
 '4': 4,
 '1': 1,
 '5': 5,
 '12': 12,
 '14': 14,
 '2': 2,
 '7': 7,
 '15': 15,
 '16': 16,
 '3': 3,
 'missing': -1,
 'diploid': 17}

In [37]:
    ct_list = list(nodes_atr.celltype_major.unique())
    ct_list.remove("missing")

In [38]:
type_dict = {ct_list[i]:i for i in range(len(ct_list))}

In [39]:
type_dict["missing"] = -1

In [40]:
type_dict

{'T.cells': 0,
 'Normal Epithelial': 1,
 'Cancer Epithelial': 2,
 'Plasmablasts': 3,
 'Myeloid': 4,
 'Endothelial': 5,
 'CAFs': 6,
 'B.cells': 7,
 'PVL': 8,
 'missing': -1}

In [41]:
nodes_atr.clone = nodes_atr.clone.map(clone_dict)

In [42]:
edges.celltype_major = edges.celltype_major.fillna("missing")

In [43]:
nodes_atr.celltype_major = nodes_atr.celltype_major.map(type_dict)

In [44]:
nodes_atr = nodes_atr.set_index("node1")

In [45]:
nodes_atr.head()

Unnamed: 0_level_0,type,celltype_major,clone
node1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,sc2vis,0,17
1,sc2vis,1,9
2,vis2grid,-1,-1
3,sc2vis,1,8
4,sc2vis,2,17


In [46]:
features = features.join(nodes_atr)

In [47]:
features.clone = features.clone.fillna(-1)

In [48]:
features.celltype_minor = features.celltype_major.fillna(-1)

  features.celltype_minor = features.celltype_major.fillna(-1)


In [49]:
y_clone = torch.tensor(features.clone.values,dtype=torch.long)


In [50]:
y_type = torch.tensor(features.celltype_minor.values,dtype=torch.long)

In [51]:
y_type.unique()

tensor([-1,  0,  1,  2,  3,  4,  5,  6,  7,  8])

In [52]:
y_clone.unique()

tensor([-1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17])

In [53]:
data = Data(x=x, edge_index=edge_index, y_clone = y_clone,y_type = y_type, edge_type = edges.type.values)

In [54]:
data.validate(raise_on_error=True)


True

In [62]:
torch.save(data,"../data/processed/gnn_data.pt")

In [55]:
hold_out = torch.tensor(np.where(data.y_clone == -1)[0],dtype=torch.long)

In [56]:
tot_size = data.x.shape[0] -len(hold_out)
train_size = int(0.8*tot_size)
hold_in = np.arange(data.x.shape[0])
hold_in = [x for x in hold_in if x not in  hold_out]

In [57]:
len(hold_out)

4992

In [58]:
data.y_type[hold_in].unique()

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

In [59]:
data.y_clone[hold_in].unique()

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17])

In [60]:
from sklearn.model_selection import train_test_split

train_samp,test_samp,_,_ = train_test_split(hold_in,data.y_clone[hold_in], stratify = data.y_clone[hold_in], test_size=0.2, random_state=42)


In [61]:
data.train_mask = torch.tensor(train_samp,dtype=torch.long)
data.test_mask = torch.tensor(test_samp,dtype=torch.long)
data.hold_out = hold_out

In [62]:
data.train_mask = torch.tensor(train_samp,dtype=torch.long)
data.test_mask = torch.tensor(test_samp,dtype=torch.long)
data.hold_out = hold_out

In [68]:
clone_dict

{'9': 9,
 '8': 8,
 '11': 11,
 '0': 0,
 '13': 13,
 '10': 10,
 '6': 6,
 '4': 4,
 '1': 1,
 '5': 5,
 '12': 12,
 '14': 14,
 '2': 2,
 '7': 7,
 '15': 15,
 '16': 16,
 '3': 3,
 'missing': -1,
 'diploid': 17}

In [69]:
node_encoder

{'CATTATGGTAATGCCT-1': 0,
 'TTGGTCATCCATCCGG-1': 1,
 'TAGACCATTGCTTAGA-1': 2,
 'CGAGGAGAGCCATCGT-1': 3,
 'CTAGATGCATGGATTG-1': 4,
 'GATTGCTGTGAACCAA-1': 5,
 'TGGTTAAAGGTAGACT-1': 6,
 'GAGTAATGTAACGTGA-1': 7,
 'TGAAGGAAGACTAATG-1': 8,
 'ATCTTAGGTAACTTGC-1': 9,
 'TAGTGACCGTCGATTA-1': 10,
 'ATTTAGGAGGTGAGAA-1': 11,
 'CCATTAGTCATTACCG-1': 12,
 'GGATAATGTTGTCAGT-1': 13,
 'AACTGGGTCTATCGGC-1': 14,
 'AAGTTCGCAATAACCG-1': 15,
 'ACAAACCAGACTATTC-1': 16,
 'ACATTACTCATAATGG-1': 17,
 'TCATAGGCACATGCTT-1': 18,
 'TTCGCGCCAGGTTTCT-1': 19,
 'AAGTGAACAAGTTAAC-1': 20,
 'GAATAACGTGACACAA-1': 21,
 'GGCAAGCTCACTTGAT-1': 22,
 'CCATCACAGTTACAGA-1': 23,
 'TCCTTGAAGTTAGGGT-1': 24,
 'CTTGCGCCAAGGAATA-1': 25,
 'TGGTGAAGTGCGCAAG-1': 26,
 'ACTTGCTCAATCCACT-1': 27,
 'GAGACCCAGCAACTTT-1': 28,
 'GACAAACTCCTTTCTA-1': 29,
 'CCCTGACCAGCACTAA-1': 30,
 'CTTTGGATCGTCAACC-1': 31,
 'AATTCTCAGACGCAAC-1': 32,
 'ATTGAGTCACAACATG-1': 33,
 'AGTCAATCATAAGTCG-1': 34,
 'CTTTAGGTCTCCATCT-1': 35,
 'TAGTGTGCAACATGAT-1': 36,
 'CTTAAGTTC

In [70]:
type_dict

{'T.cells': 0,
 'Normal Epithelial': 1,
 'Cancer Epithelial': 2,
 'Plasmablasts': 3,
 'Myeloid': 4,
 'Endothelial': 5,
 'CAFs': 6,
 'B.cells': 7,
 'PVL': 8,
 'missing': -1}

In [73]:
    import pickle   

    torch.save(data, "../data/processed/data.pt")
    node_encoder.update(clone_dict)
    node_encoder.update(type_dict)
    with open('../data/processed/node_encoding.pkl', 'wb') as fp:
        pickle.dump(node_encoder, fp)