In [183]:
import torch
import pandas as pd

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pickle   
from collections import Counter
import os
# Switch the working directory to the project root so that the relative "data/..."
# paths used in the cells below resolve against it.
# NOTE(review): this is an absolute, machine-specific path; os.chdir raises if the
# directory does not exist on the machine running the notebook.
os.chdir("/home/o313a/clonal_GNN/")


In [195]:
# Edge list of the graph; the node identifier columns are read as strings.
edges = pd.read_csv("data/interim/edges_xenium_new.csv",
dtype = {"node1":str, "node2":str})
# Clone assignment table; its two columns are renamed to node1 / clone.
overcl = pd.read_csv("data/interim/clones_over.csv")
overcl.columns = ["node1","clone"]
# Drop the last two characters of every identifier
# (presumably a barcode suffix such as "-1" — TODO confirm against the CSV).
overcl.node1 = [x[:-2] for x in overcl.node1]
# Cell-type annotation taken from the "scFFPE-Seq" sheet; columns renamed to node1 / cell_type.
cell_type = pd.read_excel("data/raw/Requested_Cell_Barcode_Type_Matrices.xlsx", sheet_name="scFFPE-Seq")
cell_type.columns = ["node1","cell_type"]
# Left merge: every annotated cell is kept; cells without a clone assignment get NaN in `clone`.
overcl = cell_type.merge(overcl, on = "node1", how = "left")

In [196]:
# Sanity check: identifiers that appear in the node2 column but never in node1.
set(edges.node2).difference(set(edges.node1))

set()

In [197]:
# edges = edges[edges["type"]!= "sc2sc"]

# Data loading and preparation

In [198]:

def validate_counts(counter, threshold, label):
    """Assert that every count in ``counter`` is strictly greater than ``threshold``.

    Args:
        counter: Mapping from a category to its number of occurrences.
        threshold: Largest count that is still rejected.
        label: Name of the kind of category; used as the prefix of the error message.

    Raises:
        AssertionError: For the first category whose count is not above ``threshold``.
    """
    for category, n_items in counter.items():
        is_large_enough = n_items > threshold
        assert is_large_enough, f"{label} {category} has less than {threshold + 1} items"
        

def filter_and_encode(df, node_encoder, all_nodes, use_index=False):
    """Keep only rows that refer to known nodes and replace node names by their ids.

    Args:
        df: Either an edge table with ``node1`` / ``node2`` columns
            (``use_index=False``) or a table indexed by node name
            (``use_index=True``).
        node_encoder: Mapping from node name to integer id.
        all_nodes: Collection of node names to keep.
        use_index: If True, filter and encode the index instead of the
            ``node1`` / ``node2`` columns.

    Returns:
        A new DataFrame; the input frame is left unchanged.
    """
    if use_index:
        kept = df.loc[df.index.isin(all_nodes)]
        return kept.rename(index=node_encoder)

    # An edge survives only if both of its endpoints are known nodes.
    both_known = df["node1"].isin(all_nodes) & df["node2"].isin(all_nodes)
    # `.assign` builds a new frame, so nothing is written into a filtered slice
    # (the original attribute assignment triggered SettingWithCopyWarning).
    return df.loc[both_known].assign(
        node1=lambda kept: kept["node1"].map(node_encoder),
        node2=lambda kept: kept["node2"].map(node_encoder),
    )


def preprocess_data(edges, overcl, spatial_edges):
    """Build the node label table and drop edges that touch unlabelled nodes.

    Args:
        edges: Edge table with ``node1``, ``node2`` and ``type`` columns.
        overcl: Node table providing ``node1``, ``clone`` and ``cell_type``.
        spatial_edges: Edge type that is excluded before merging edges onto ``overcl``.

    Returns:
        Tuple ``(edges, overcl)``: the edge table with the ``clone`` / ``cell_type``
        of ``node1`` attached and filtered, and the de-duplicated node label table.

    Raises:
        AssertionError: If a cell type or clone occurs 20 times or fewer.
    """
    label_columns = ["node1", "clone", "cell_type"]
    non_spatial = edges[edges["type"] != spatial_edges]
    merged = overcl.merge(non_spatial, on="node1", how="left")
    overcl = merged[label_columns].drop_duplicates()

    # Every cell type and every clone must have more than 20 members.
    for column, label in (("cell_type", "Cell type"), ("clone", "Clone")):
        validate_counts(Counter(overcl[column]), 20, label)

    annotated = edges.merge(overcl[["clone", "node1", "cell_type"]], on="node1", how="left")

    # Source nodes of non-"xen2grid" edges that lack a clone or a cell type.
    not_grid = annotated["type"] != "xen2grid"
    unlabelled = annotated.clone.isna() | annotated.cell_type.isna()
    to_drop = list(set(annotated[not_grid & unlabelled].node1))

    # Remove every edge that has such a node at either end.
    touches_dropped = annotated.node1.isin(to_drop) | annotated.node2.isin(to_drop)
    return annotated[~touches_dropped], overcl


def read_and_merge_embeddings(paths, edges):
    """Load both embedding tables and re-index all nodes with consecutive integer ids.

    Only nodes that occur in the graph (as ``node1`` or ``node2``) and in at least
    one of the two embedding tables are kept.

    Args:
        paths: Dict with the keys ``"spatial"`` and ``"rna"`` holding CSV paths.
            The first CSV column is used as the node name index.
        edges: Edge table with ``node1`` and ``node2`` columns.

    Returns:
        Tuple ``(emb_vis, emb_rna, edges, node_encoder)`` where the two embedding
        tables are indexed by node id, ``edges`` has encoded ``node1`` / ``node2``
        columns and ``node_encoder`` maps node name -> id.
    """
    def _read_embedding(path):
        # Node names are compared as strings everywhere, so normalise the index.
        table = pd.read_csv(path, index_col=0)
        table.index = table.index.map(str)
        return table

    emb_vis = _read_embedding(paths["spatial"])
    emb_rna = _read_embedding(paths["rna"])

    graph_nodes = set(edges.node1) | set(edges.node2)
    embedded_nodes = set(emb_vis.index) | set(emb_rna.index)

    # Sort before numbering: the iteration order of a set of strings changes with
    # the interpreter's hash seed, which made the node ids differ between sessions.
    all_nodes = sorted(graph_nodes & embedded_nodes)
    node_encoder = {node: idx for idx, node in enumerate(all_nodes)}

    emb_vis = filter_and_encode(emb_vis, node_encoder, all_nodes, use_index=True)
    emb_rna = filter_and_encode(emb_rna, node_encoder, all_nodes, use_index=True)
    edges = filter_and_encode(edges, node_encoder, all_nodes)

    return emb_vis, emb_rna, edges, node_encoder


def create_data_object(edges, emb_vis, emb_rna, node_encoder):
    """Assemble the graph ``Data`` object together with the label encodings.

    Args:
        edges: Encoded edge table with ``node1``, ``node2``, ``weight``, ``type``,
            ``clone`` and ``cell_type`` columns. It is not modified.
        emb_vis: Embedding table indexed by node id.
        emb_rna: Embedding table indexed by node id.
        node_encoder: Mapping node name -> node id; returned unchanged in the
            encoding dictionary.

    Returns:
        Tuple ``(data, encodings)`` where ``encodings`` has the keys ``"nodes"``,
        ``"clones"`` and ``"types"``.

    Raises:
        AssertionError: If the feature matrix and the two label vectors do not
            have the same number of rows.
    """
    # Stack both endpoint columns into one (2, n_edges) array first; handing torch
    # a Python list of ndarrays is very slow and emits a UserWarning.
    edge_index = torch.tensor(
        np.vstack([edges.node1.values, edges.node2.values]), dtype=torch.long
    )
    edge_weight = torch.tensor(edges.weight.values, dtype=torch.float)

    # NOTE(review): assumes that after sorting by node id there is exactly one row
    # per node, so that row i of `x` belongs to node i — confirm for the inputs.
    features = pd.concat([emb_vis, emb_rna]).sort_index()
    x = torch.tensor(features.values, dtype=torch.float)

    # Fill unknown labels on a copy so the caller's edge table is left untouched.
    labelled = edges.assign(
        clone=edges.clone.fillna("missing"),
        cell_type=edges.cell_type.fillna("missing"),
    )

    # One row per (node1, cell_type, clone) combination, ordered by node id.
    nodes_attr = (
        labelled[["node1", "cell_type", "clone"]]
        .drop_duplicates()
        .sort_values(by="node1")
    )
    clone_dict = create_encoding_dict(nodes_attr, "clone", extras=["diploid", "missing"])
    type_dict = create_encoding_dict(nodes_attr, "cell_type", extras=["missing"])

    nodes_attr = nodes_attr.assign(
        clone=nodes_attr["clone"].map(clone_dict),
        cell_type=nodes_attr["cell_type"].map(type_dict),
    ).set_index("node1")
    features = features.join(nodes_attr)

    y_clone = torch.tensor(features.clone.values, dtype=torch.long)
    y_type = torch.tensor(features.cell_type.values, dtype=torch.long)

    data = Data(
        x=x,
        edge_index=edge_index,
        y_clone=y_clone,
        y_type=y_type,
        edge_type=edges.type.values,
        edge_attr=edge_weight,
    )
    assert data.validate(raise_on_error=True), "Data not valid"
    # Catches a join that changed the number of rows (e.g. a node with two label rows).
    assert data.x.shape[0] == data.y_clone.shape[0] == data.y_type.shape[0], "Data not valid"

    return data, {"nodes": node_encoder, "clones": clone_dict, "types": type_dict}


def create_encoding_dict(df, column, extras=()):
    """Build a label -> integer code mapping for ``df[column]``.

    Args:
        df: Table that contains ``column``.
        column: Name of the label column.
        extras: Special labels that are excluded from the regular numbering.
            Labels listed here that do not occur in the column are ignored.

    Returns:
        Dict mapping every label to its code. ``"missing"`` is always mapped to -1.
        If ``"diploid"`` is among ``extras``, the regular labels are encoded as
        ``int(label)`` and ``"diploid"`` gets a code not used by any of them;
        otherwise the regular labels are numbered in order of first appearance.
    """
    # Filtering (instead of list.remove) tolerates extras that are absent.
    items = [value for value in df[column].unique() if value not in extras]

    if "diploid" in extras:
        dt = {value: int(value) for value in items}
        # With labels 0..n-1 this is n, the same code as before. If n is already
        # taken (non-contiguous labels), fall back to the next free integer.
        diploid_code = len(dt)
        if diploid_code in dt.values():
            diploid_code = max(dt.values()) + 1
        dt["missing"] = -1
        dt["diploid"] = diploid_code
    else:
        dt = {item: idx for idx, item in enumerate(items)}
        dt["missing"] = -1

    return dt


# CSV files with the two node embedding tables, keyed the way
# read_and_merge_embeddings expects ("spatial" and "rna").
embedding_paths = {
    "spatial": "data/interim/embedding_spatial_xenium.csv",
    "rna": "data/interim/embedding_rna_xenium.csv"
}
# NOTE(review): `edges` and `overcl` are overwritten by the processed tables, so
# re-running this cell without re-running the loading cell feeds the already
# processed tables back into preprocess_data.
edges, overcl = preprocess_data(edges, overcl,"sc2xen")

# Restrict graph and embeddings to the shared nodes and switch to integer node ids.
emb_vis, emb_rna, edges, node_encoder = read_and_merge_embeddings(embedding_paths, edges)

# Graph object plus the node / clone / cell-type encoding dictionaries.
data, encoding_dict = create_data_object(edges, emb_vis, emb_rna, node_encoder)


In [199]:
# Persist the graph object and the encoding dictionaries.
# Note: this runs before `edge_attr` is reshaped and `edge_type` is deleted in the
# cells below, so the saved object still carries the unreshaped `edge_attr` and
# the `edge_type` array.
torch.save(data, "data/processed/data_xen.pt")
with open('data/processed/full_encoding_xen.pkl', 'wb') as fp:
    pickle.dump(encoding_dict, fp)

In [200]:
# Give the edge attributes an explicit feature dimension: one column per edge.
data.edge_attr = data.edge_attr.reshape((-1, 1))

# Nodes whose clone label is -1 (the code of "missing") are held out.
hold_out_mask = (data.y_clone == -1).cpu().numpy()
hold_out_indices = np.flatnonzero(hold_out_mask)
hold_out = torch.tensor(hold_out_indices, dtype=torch.long)

# All remaining nodes. A boolean mask replaces the per-index `index not in hold_out`
# test, which compared against the whole tensor for every node (quadratic time).
hold_in_indices = np.arange(data.x.shape[0])
hold_in = list(hold_in_indices[~hold_out_mask])

In [201]:
# Distinct clone codes among the hold-in nodes (those whose clone label is not -1).
data.y_clone[hold_in].unique()

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26])

In [202]:
from torch_geometric.loader import NeighborLoader

# Mini-batch loader over the graph: the seed nodes of each batch are taken from
# `hold_in`, 128 at a time, and `num_neighbors=[10] * 3` passes a sampling size
# of 10 for each of three iterations (see the NeighborLoader documentation).
loader = NeighborLoader(
    data,
    num_neighbors=[10] * 3,
    batch_size=128,input_nodes = hold_in
)

In [203]:
# Remove the edge-type attribute from the graph object. `loader` was constructed
# from this same `data` object in the previous cell.
# NOTE(review): presumably done because the string array cannot be batched — confirm.
del data.edge_type


In [204]:

# Iterate over every mini-batch and check that the clone code -1 ("missing")
# does not occur among the clone labels of the batch.
for dat in loader:
    assert -1 not in dat.y_clone.unique()

In [205]:
# `dat` is the last batch left over from the loop in the previous cell;
# show which clone codes it contains.
dat.y_clone.unique()

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 26])