# Experiments of stage 2

In [3]:
import sys

# Make the local `devograph` package importable. The path is relative to the
# notebook's working directory (the DevoGraph/stage_2 folder); adjust if you run
# the notebook from somewhere else.
DEVOGRAPH_ROOT = '../../DevoGraph/'
if DEVOGRAPH_ROOT not in sys.path:  # avoid duplicate entries when the cell is re-run
    sys.path.append(DEVOGRAPH_ROOT)

In [6]:
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
import os
import numpy as np
import torch as th


from copy import deepcopy
from importlib import reload
import devograph.datasets.datasets as data

# Render matplotlib figures inline in the notebook, at high ("retina") resolution.
%matplotlib inline
%config InlineBackend.figure_format='retina'

In [7]:
# Re-import devograph.datasets.datasets so that edits to the package source are
# picked up without restarting the kernel (a development convenience; consider
# `%load_ext autoreload` + `%autoreload 2` in the import cell instead).
reload(data)

<module 'devograph.datasets.datasets' from '/Users/watarukawakami/GSoC/DevoGraph/stage_2/../../DevoGraph/devograph/datasets/datasets.py'>

In [None]:
# Directory holding the CE data files; later cells build paths as f'{data_path}<file name>',
# so the trailing slash matters. `~` is expanded once here so every consumer gets a real
# path: pandas readers/writers expand it themselves, but plain `open()` and other
# non-pandas code do not.
data_path = os.path.expanduser('~/.CEData/')

## Pre-process sampled data
* The sampled data shows the input format that stage 2 expects:
---
* cell column: cell name
* time column: the time point when the cell data is sampled
* x, y, z columns: the 3-D positions
* size column: the volume of the cell
---
* All stage 2 input files should follow this format.

In [None]:
# Load the sampled raw data and drop the unused `cellTime` column.
# Rows whose `time` value is the literal string 'time' are presumably repeated
# header lines inside the CSV; remove them by index label.
raw_data = pd.read_csv(f'{data_path}raw-data-part-1a.csv').drop(columns=['cellTime'])
raw_data = raw_data.drop(index=raw_data.index[raw_data['time'] == 'time'])
raw_data

In [None]:
# Convert the columns to their numeric types in one pass (they were presumably
# read as strings because of the embedded header rows dropped above).
raw_data = raw_data.astype({
    'time': 'int',
    'x': 'float',
    'y': 'float',
    'z': 'float',
    'size': 'float',
})

In [None]:
# Save the cleaned table; pandas writes the row index as the first column by default.
# NOTE(review): this targets a local DevoGraph checkout rather than `data_path`;
# presumably it is the file later served from the repository's data/CE_raw_data.csv
# URL used when building the graphs — confirm.
raw_data.to_csv('~/DevoGraph/data/CE_raw_data.csv')

In [None]:
# Read the daughter -> mother lineage table from the Excel workbook and give its
# two columns descriptive names.
lineage_data = (
    pd.read_excel(
        f'{data_path}cell-by-cell-data-v2.xlsx',
        sheet_name='daughter-of-database',
        engine='openpyxl',
        usecols=['CELL NAME', 'CELL NAME.1'],
    )
    .rename(columns={'CELL NAME': 'daughter', 'CELL NAME.1': 'mother'})
)

In [None]:
# Preview the lineage table (columns: daughter, mother).
lineage_data

In [None]:
# Save the daughter/mother table under data_path so the graph-building step can read it
# (pandas also writes the row index as the first column by default).
lineage_data.to_csv(f'{data_path}CE_lineage_data.csv')

## Build a DGL graph based on the given sampled data

In [None]:
# GitHub `?token=` query parameters are credentials and expire quickly: never hard-code one.
# If the repository is private, export a fresh token as DEVOGRAPH_DATA_TOKEN before running.
CE_RAW_DATA_URL = 'https://raw.githubusercontent.com/LspongebobJH/DevoGraph/main/data/CE_raw_data.csv'
_data_token = os.environ.get('DEVOGRAPH_DATA_TOKEN')
ce_raw_data_url = f'{CE_RAW_DATA_URL}?token={_data_token}' if _data_token else CE_RAW_DATA_URL

# load graphs from disk. if there's no existing graphs, download them from url
datasets = data.CETemporalGraphKNN(
        time_start=0, time_end=10, knn_k=3,
        url=ce_raw_data_url)

# convert the temporal graph datasets into a batch graph with directed links that connect
# cells across different frames according to the lineage tree given in the second param
# (the lineage CSV written to data_path in the pre-processing section)
res_g, batch_node_interval = data.to_temporal_directed(datasets, f'{data_path}CE_lineage_data.csv')

# store the directed temporal graphs into datasets for later convenient operations
datasets.set_batch_graph(res_g)
datasets.set_info({'batch_node_interval': batch_node_interval})

In [None]:
# Summary of the per-frame datasets and of the merged (batch) graph.
print("number of frames:", len(datasets))
print("number of nodes in the batch graph:", datasets.batch_graph.number_of_nodes())
print("number of edges in the batch graph:", datasets.batch_graph.number_of_edges())

In [None]:
# node features stored on each frame graph
# pos: 3-D position (x, y, z); size: volume; time: the timestamp relative to the temporal graph series
datasets[0].ndata.keys()

## 3-D visualization

In [None]:
def plot_frame_graph(ax, g, edge_color='tab:grey'):
    """Draw one frame graph as a 3-D scatter of its nodes plus straight-line edges.

    Parameters
    ----------
    ax : matplotlib Axes created with ``projection='3d'``.
    g : a frame graph from ``datasets``; ``g.ndata['pos']`` holds one (x, y, z)
        row per node.
    edge_color : colour used for the edge lines.
    """
    pos = g.ndata['pos']
    ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], marker='o')
    # (E, 2) tensor of (src, dst) node ids. Unlike np.array(list(zip(...))),
    # it stays 2-D when the frame has no edges, so the indexing below still works.
    edges = th.stack(g.edges(), dim=1).long()
    for edge_coor in pos[edges].numpy():
        # edge_coor is (2, 3); its transpose gives the x, y and z pairs ax.plot expects.
        ax.plot(*edge_coor.T, color=edge_color)


FRAMES_TO_PLOT = (3, 4)  # frame indices shown side by side

fig = plt.figure(figsize=(15, 15))
for panel, frame in enumerate(FRAMES_TO_PLOT, start=1):
    ax = fig.add_subplot(1, len(FRAMES_TO_PLOT), panel, projection='3d')
    plot_frame_graph(ax, datasets[frame])
    ax.set_title(f'frame {frame}')

plt.show()

In [None]:
# Two frames in one 3-D plot. FRAME_B is shifted along x by X_OFFSET so it does not
# overlap FRAME_A. Grey lines are each frame's own edges. Dashed green lines are
# batch-graph edges running from FRAME_A nodes to FRAME_B nodes, i.e. the directed
# cross-frame links added by `to_temporal_directed`.
FRAME_A, FRAME_B = 3, 4
X_OFFSET = 1000
MAX_LINEAGE_EDGES = 50  # cap on the number of dashed links drawn, for readability

fig = plt.figure(figsize=(20, 20))
ax = fig.add_subplot(111, projection='3d')

for frame, node_color, x_shift in ((FRAME_A, 'tab:blue', 0), (FRAME_B, 'tab:red', X_OFFSET)):
    frame_graph = datasets[frame]
    # clone() so the shift never leaks back into the dataset; re-running the cell is safe.
    frame_pos = frame_graph.ndata['pos'].clone()
    frame_pos[:, 0] += x_shift
    ax.scatter(frame_pos[:, 0], frame_pos[:, 1], frame_pos[:, 2], marker='o', color=node_color)
    # (E, 2) tensor of (src, dst) ids; stays 2-D even for a frame without edges.
    frame_edges = th.stack(frame_graph.edges(), dim=1).long()
    for edge_coor in frame_pos[frame_edges].numpy():
        ax.plot(*edge_coor.T, color='tab:grey')

# Keep the batch-graph edges whose source lies in FRAME_A's node-id interval and
# whose destination lies in FRAME_B's (half-open [start, end) ranges). Local names
# are used so the global `batch_node_interval` from the graph-building cell is untouched.
batch_graph = datasets.batch_graph
frame_intervals = datasets.info['batch_node_interval']
a_start, a_end = int(frame_intervals[FRAME_A][0]), int(frame_intervals[FRAME_A][1])
b_start, b_end = int(frame_intervals[FRAME_B][0]), int(frame_intervals[FRAME_B][1])

src, dst = (ids.numpy() for ids in batch_graph.edges())
is_cross_frame = (src >= a_start) & (src < a_end) & (dst >= b_start) & (dst < b_end)
# Boolean masking preserves edge-id order, so these are the first matches.
link_src = src[is_cross_frame][:MAX_LINEAGE_EDGES]
link_dst = dst[is_cross_frame][:MAX_LINEAGE_EDGES]

batch_pos = batch_graph.ndata['pos'].numpy()
link_starts = batch_pos[link_src]
link_ends = batch_pos[link_dst]  # fancy indexing copies, so the shift below stays local
link_ends[:, 0] += X_OFFSET
for start, end in zip(link_starts, link_ends):
    # np.stack(..., axis=1) -> rows (x0, x1), (y0, y1), (z0, z1) as ax.plot expects.
    ax.plot(*np.stack([start, end], axis=1), color='green', linestyle='dashed')

ax.set_title(f'frames {FRAME_A} and {FRAME_B} (x + {X_OFFSET}) with cross-frame links')
plt.show()