In [None]:
import torch_geometric as tg
import networkx as nx
import torch, time, pickle, os, sys, argparse
import random
from torch_geometric.data import Data
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import pandas as pd
from tqdm import tqdm
import os.path as osp
from itertools import count

In [None]:
# Name of the stored graph dataset to load (an alternative name is kept for reference):
# case='vlarge_4t_quantile_raw_redshift_75_all'
case = 'vlarge_all_4t_z0.0_quantile_raw'
data_file = osp.expanduser(f'~/../../../scratch/gpfs/cj1223/GraphStorage/{case}/data.pkl')
# NOTE: unpickling can execute arbitrary code; only load files you trust.
with open(data_file, 'rb') as handle:
    data = pickle.load(handle)

In [None]:
# Load the pickled per-feature transformers and remember their keys in order.
transform_path = osp.expanduser('~/../../scratch/gpfs/cj1223/GraphStorage')
transformer_file = osp.join(transform_path, 'transformers', 'quantile_allfeat_1.pkl')
# Use a context manager so the file handle is closed after loading
# (the bare ``pickle.load(open(...))`` form leaks it).
# NOTE: unpickling can execute arbitrary code; only load files you trust.
with open(transformer_file, 'rb') as transformer_handle:
    transformer = pickle.load(transformer_handle)
tkeys = list(transformer.keys())

In [None]:
def hierarchy_pos(G, root=None, width=1., vert_gap = 0.2, vert_loc = 0, leaf_vs_root_factor = 0.5):

    '''
    If the graph is a tree this will return the positions to plot this in a
    hierarchical layout.

    Based on Joel's answer at https://stackoverflow.com/a/29597209/2966723,
    but with some modifications.

    We include this because it may be useful for plotting transmission trees,
    and there is currently no networkx equivalent (though it may be coming soon).

    There are two basic approaches we think of to allocate the horizontal
    location of a node.

    - Top down: we allocate horizontal space to a node.  Then its ``k``
      descendants split up that horizontal space equally.  This tends to result
      in overlapping nodes when some have many descendants.
    - Bottom up: we allocate horizontal space to each leaf node.  A node at a
      higher level gets the entire space allocated to its descendant leaves.
      Based on this, leaf nodes at higher levels get the same space as leaf
      nodes very deep in the tree.

    We use both of these approaches simultaneously, with ``leaf_vs_root_factor``
    determining how much of the horizontal position comes from each one.  The
    final x is ``leaf_vs_root_factor * bottom_up + (1 - leaf_vs_root_factor) * top_down``,
    so ``0`` gives pure top down, while ``1`` gives pure bottom up.


    :Arguments:

    **G** the graph (must be a tree)

    **root** the root node of the tree
    - if the tree is directed and this is not given, the root will be found and used
    - if the tree is directed and this is given, then the positions will be
      just for the descendants of this node.
    - if the tree is undirected and not given, then a random choice will be used.

    **width** horizontal space allocated for this branch - avoids overlap with other branches

    **vert_gap** gap between levels of hierarchy

    **vert_loc** vertical location of root

    **leaf_vs_root_factor** weight of the bottom-up (leaf based) x position;
      the top-down position gets weight ``1 - leaf_vs_root_factor``

    :Returns:

    dict mapping every positioned node to an ``(x, y)`` tuple.  The x values
    are rescaled so that the largest one equals ``width``; the rescaling is
    skipped when every x is 0.

    :Raises:

    TypeError if ``G`` is not a tree.
    '''
    if not nx.is_tree(G):
        raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')

    if root is None:
        if isinstance(G, nx.DiGraph):
            root = next(iter(nx.topological_sort(G)))  #allows back compatibility with nx version 1.11
        else:
            # relies on the file-level ``import random``
            root = random.choice(list(G.nodes))

    def _hierarchy_pos(G, root, leftmost, width, leafdx = 0.2, vert_gap = 0.2, vert_loc = 0,
                    xcenter = 0.5, rootpos = None,
                    leafpos = None, parent = None):
        '''
        Recursively place ``root`` and everything below it.

        see hierarchy_pos docstring for most arguments

        leftmost: x position given to the first leaf found under ``root``
        leafdx: horizontal spacing between consecutive leaves
        rootpos: dict of top-down positions assigned so far
        leafpos: dict of bottom-up positions assigned so far
        parent: parent of this branch. - only affects it if non-directed

        Returns ``(rootpos, leafpos, leaf_count)`` where ``leaf_count`` is the
        number of leaves found under ``root``.
        '''

        if rootpos is None:
            rootpos = {root:(xcenter,vert_loc)}
        else:
            rootpos[root] = (xcenter, vert_loc)
        if leafpos is None:
            leafpos = {}
        children = list(G.neighbors(root))
        leaf_count = 0
        if not isinstance(G, nx.DiGraph) and parent is not None:
            # in an undirected tree the parent shows up as a neighbour
            children.remove(parent)
        if len(children)!=0:
            # top-down: children share this node's width equally
            rootdx = width/len(children)
            nextx = xcenter - width/2 - rootdx/2
            for child in children:
                nextx += rootdx
                rootpos, leafpos, newleaves = _hierarchy_pos(G,child, leftmost+leaf_count*leafdx,
                                    width=rootdx, leafdx=leafdx,
                                    vert_gap = vert_gap, vert_loc = vert_loc-vert_gap,
                                    xcenter=nextx, rootpos=rootpos, leafpos=leafpos, parent = root)
                leaf_count += newleaves

            # bottom-up: centre this node over the span of its children
            child_xs = [leafpos[child][0] for child in children]
            leafpos[root] = ((min(child_xs)+max(child_xs))/2, vert_loc)
        else:
            leaf_count = 1
            leafpos[root]  = (leftmost, vert_loc)
        return rootpos, leafpos, leaf_count

    xcenter = width/2.
    if isinstance(G, nx.DiGraph):
        leafcount = len([node for node in nx.descendants(G, root) if G.out_degree(node)==0])
    elif isinstance(G, nx.Graph):
        leafcount = len([node for node in nx.node_connected_component(G, root) if G.degree(node)==1 and node != root])
    # A tree consisting of the root alone has no counted leaves; treat the root
    # as the single leaf so the leaf spacing below does not divide by zero.
    leafcount = max(leafcount, 1)
    rootpos, leafpos, leaf_count = _hierarchy_pos(G, root, 0, width,
                                                    leafdx=width*1./leafcount,
                                                    vert_gap=vert_gap,
                                                    vert_loc = vert_loc,
                                                    xcenter = xcenter)
    # blend the bottom-up (leafpos) and top-down (rootpos) x coordinates
    pos = {}
    for node in rootpos:
        pos[node] = (leaf_vs_root_factor*leafpos[node][0] + (1-leaf_vs_root_factor)*rootpos[node][0], leafpos[node][1])
    xmax = max(x for x,y in pos.values())
    # Stretch so the right-most node sits at ``width``.  When every x is 0
    # (e.g. a single chain with leaf_vs_root_factor=1) there is nothing to scale.
    if xmax != 0:
        for node in pos:
            pos[node]= (pos[node][0]*width/xmax, pos[node][1])
    return pos

In [None]:
# Build the node colours and the plotting layout for one graph of `data`.
j = 11  # index of the graph to draw
k = 3   # feature column whose inverse-transformed value colours the nodes
feats = data[j].x.numpy()
G = tg.utils.to_networkx(data[j])

# Inverse-transform column k in a single call instead of once per node.
# assumes the transformer acts element-wise on a column (true for sklearn scalers
# and quantile transformers) — TODO confirm for the pickled object
featr = list(transformer[tkeys[k]].inverse_transform(feats[:, k].reshape(-1, 1))[:, 0])
feat = featr[-1]  # value for the last node; a later cell displays `feat`
# Map the q-th node to the q-th feature row.  A plain zip replaces the previous
# nx.betweenness_centrality call, whose (expensive) result was only used for its keys.
di = dict(zip(G.nodes(), featr))

# Distinct redshifts, taking the inverse-transformed column 0 as a scale factor a (z = 1/a - 1).
# NOTE(review): this indexes `transformer[0]` while the line above uses `tkeys[k]` — confirm 0 is a valid key.
zs = np.unique(1/transformer[0].inverse_transform(feats[:,0].reshape(-1,1))-1)

# Vertical level of every node = rank of its column-0 value among the distinct values.
vals, posy = np.unique(feats[:, 0], return_inverse=True)

# Horizontal positions come from the tree layout; the graph is reversed before layout
# (presumably so that edges point away from the root — TODO confirm edge direction).
hpos = hierarchy_pos(G.reverse())
hpos = {node: [xy[0]*1.05, -posy[node]] for node, xy in hpos.items()}

In [None]:
# Eight evenly spaced percentiles (0 to 100) of the node feature values, rounded to one decimal.
percentile_grid = np.linspace(0, 100, 8)
rs = np.round(np.percentile(featr, percentile_grid), 1)
rs

In [None]:
# NOTE(review): `data[j]` is passed to tg.utils.to_networkx above, so it looks like a
# torch_geometric graph object; indexing it with an array of 291 float zeros looks like
# leftover debugging (the 291 is hard-coded) and may raise — confirm and remove.
data[j][np.zeros(291)]

In [None]:
# Target value(s) stored on the selected graph, as a NumPy array.
targs = data[j].y.numpy()

In [None]:
# Draw the full tree: x from the hierarchical layout, y from the column-0 level,
# colour from the rank of the inverse-transformed feature, size from feature column 3.
plt.rcParams['font.size'] = 20
fig, ax = plt.subplots(figsize=(13, 15))

# Colour index of a node = rank of its feature value among the distinct values.
nx.set_node_attributes(G, di, 'n_prog')
progs = set(nx.get_node_attributes(G, 'n_prog').values())
mapping = dict(zip(sorted(progs), count()))
nodes = G.nodes()
colors = [mapping[G.nodes[n]['n_prog']] for n in nodes]
print('Made graph')

# Tick every third level.  Labels are read from `zs` back to front
# (assumes a larger column-0 level corresponds to a lower redshift — TODO confirm).
ax.set(ylabel='Redshift')
ax.set(yticks=-np.arange(len(zs))[1::3])
ax.set_yticklabels(np.round(zs[:-1],1)[::-3])

ec = nx.draw_networkx_edges(G, hpos, alpha=0.9, arrows=False, ax=ax)
nc = nx.draw_networkx_nodes(G, hpos, nodelist=nodes, node_color=colors,
                            node_size=(feats[:,3]-feats[:,3].min()+1)**3/2, cmap=plt.cm.jet, ax=ax)
cbar = fig.colorbar(nc, ax=ax)
# The colour scale runs over rank indices 0..N-1, so pin the tick positions explicitly
# and label each tick with the feature value at that rank.  Setting tick labels alone
# would relabel whatever ticks matplotlib picked, so labels would not match the colours.
sorted_vals = np.array(sorted(progs), dtype=float)
tick_pos = np.unique(np.linspace(0, len(sorted_vals)-1, 8))
tick_labels = np.round(np.interp(tick_pos, np.arange(len(sorted_vals)), sorted_vals), 1)
cbar.set_ticks(tick_pos)
cbar.set_ticklabels(tick_labels)
cbar.set_label(r'Halo mass [log($M_h/M_{\odot}$)]')

ax.tick_params(left=True, bottom=False, labelleft=True, labelbottom=True)
ax.set_xticks([]);
ax.set_title(r'Sample merger tree with $M_{*,final}=10^{8.5} M_{\odot}$' )

# Dashed box marking the cut-out region.
# NOTE(review): the box limits are hard-coded for this particular tree — confirm if j changes.
box_top, box_bottom = -33.5, -68
box_left, box_right = 0, 1.1
ax.vlines(box_left, box_top, box_bottom, color='gray', linestyle='dashed')
ax.vlines(box_right, box_top, box_bottom, color='gray', linestyle='dashed')
ax.hlines(box_top, box_left, box_right, color='gray', linestyle='dashed')
ax.hlines(box_bottom, box_left, box_right, color='gray', linestyle='dashed', label='75% cut')
ax.legend(loc='lower left')
fig.savefig('../paper_figures/full_tree_with_cutout.png')

In [None]:
# Scratch: rescale feature column 0 as 2*(value/max - 0.5).
# NOTE(review): this overwrites the `posy` built by the layout cell; the later layout
# cell rebuilds `posy` before reading it again.
posy=2*(feats[:,0]/max(feats[:,0])-0.5)

In [None]:
# Scratch: rounded redshifts (last entry dropped) read back to front, every fifth value.
np.round(zs[:-1],1)[::-5]

In [None]:
# Scratch: level indices 1, 4, 7, ... (the plot cell uses their negatives as y tick positions).
np.arange(len(zs))[1::3]

In [None]:
# Scratch: first target entry shifted by 9.
# NOTE(review): the +9 offset is not explained in this notebook — presumably it undoes a
# normalisation of the stored target; confirm.
targs[0]+9

In [None]:
# Scratch: show `feat`, the inverse-transformed feature value of the last node handled by the layout cell.
feat

In [None]:
# Scratch: rounded redshifts (last entry dropped) read back to front, every second value.
np.round(zs[:-1],1)[::-2]

In [None]:
# One evenly spaced coordinate per distinct value of feature column 0,
# starting at -1 with step 2/ly.
a = np.unique(feats[:, 0])
ly = len(a)
level_index = np.arange(ly)
ys = 2 * level_index / ly - 1

In [None]:
def convert(d,p):
    """Re-index an edge list after nodes have been removed.

    Parameters
    ----------
    d : list
        For every kept edge, the original id of the node it points to.
        Entries equal to ``d[0]`` are mapped to the new id ``0``.
    p : list
        For every kept edge, the original id of the node it starts from.
        Must not contain duplicates; entry ``i`` receives the new id ``i + 1``.

    Returns
    -------
    (dfin, new_ids)
        ``dfin`` is a list with the new id for every entry of ``d`` and
        ``new_ids`` is ``np.arange(1, 1 + len(p))``, the new ids of ``p``.
        When ``p`` contains duplicates a message is printed and ``dfin`` is
        empty.  When ``d`` is empty ``dfin`` is empty as well.

    Raises
    ------
    ValueError
        If an entry of ``d`` differs from ``d[0]`` and does not occur in ``p``.
    """
    dfin=[]
    if len(p)!=len(np.unique(p)):
        print('Wrong order of prog/desc')
    elif len(d)!=0:  # nothing to re-index for an empty edge list (avoids d[0] failing)
        # One-off lookup table: new id of every entry of p.  This replaces a
        # linear p.index() scan per edge.
        try:
            new_id = {node: idx for idx, node in enumerate(p, start=1)}
        except TypeError:
            # unhashable entries: rely on the p.index() fallback below
            new_id = {}
        no=d[0]
        for desc in d:
            if desc==no:
                dfin.append(0)
                continue
            try:
                target = new_id.get(desc)
            except TypeError:
                target = None
            if target is None:
                # Falls back to equality search; raises ValueError when desc is absent.
                target = p.index(desc)+1
            dfin.append(target)
    return dfin, np.arange(1, 1+len(p))

In [None]:
# Cut every graph at a percentile of feature column 0: keep only the nodes whose
# column-0 value lies above the percentile and rebuild a smaller graph from them.
mls=[]  # per percentile: number of kept nodes for each graph
lss=[]  # per percentile: original number of nodes for each graph
for percentile in [75]:
    ml=[]
    ls=[]
    dat=[]  # the cut graphs for this percentile
    for d in tqdm(data[:15]):
        hals=d.x.numpy()
        mask=hals[:,0]>np.percentile(hals[:,0], percentile)
        de, pr = d.edge_index
        ml.append(mask.sum())
        ls.append(len(hals))
        # Edge i is filtered with the mask entry of node i+1.
        # assumes edge i belongs to node i+1 — TODO confirm against the graph builder
        edge_mask=mask[1:]
        if edge_mask.any():
            desc, progs = convert(list(np.array(pr)[edge_mask]), list(np.array(de)[edge_mask]) )
        else:
            # no edge survives: use a single placeholder self-loop on node 0
            desc, progs=[0],[0]
        edge_index = torch.tensor([progs, desc], dtype=torch.long)

        # Decide this BEFORE patching the mask.  Re-testing the mask after
        # mask[0]=True can never be true, which made the placeholder edge_attr unreachable.
        no_node_kept = not mask.any()
        if no_node_kept:
            mask[0]=True  # always keep at least node 0
        x = torch.tensor(hals[mask], dtype=torch.float)
        y=d.y

        if no_node_kept:
            edge_attr=torch.tensor([1], dtype=torch.float)
        else:
            # NOTE(review): when only node 0 survives this is empty while edge_index holds
            # the placeholder self-loop — confirm whether that case needs a placeholder too.
            edge_attr=d.edge_attr[edge_mask]
        graph=Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
        dat.append(graph)
    mls.append(ml)
    lss.append(ls)
#     case=f'vlarge_4t_quantile_raw_redshift_{percentile}_all'
#     print("Saving dataset "+ case)
#     if not osp.exists(f'../../../../../scratch/gpfs/cj1223/GraphStorage/{case}'):
#         os.mkdir(f'../../../../../scratch/gpfs/cj1223/GraphStorage/{case}')

#     with open(f'../../../../../scratch/gpfs/cj1223/GraphStorage/{case}/data.pkl', 'wb') as handle:
#         pickle.dump(dat, handle)

In [None]:
# Build the node colours and the plotting layout for one graph of `dat` (the cut graphs).
j = 11  # index of the graph to draw
k = 3   # feature column whose inverse-transformed value colours the nodes
feats = dat[j].x.numpy()
G = tg.utils.to_networkx(dat[j])

# Inverse-transform column k in a single call instead of once per node.
# assumes the transformer acts element-wise on a column (true for sklearn scalers
# and quantile transformers) — TODO confirm for the pickled object
featr = list(transformer[tkeys[k]].inverse_transform(feats[:, k].reshape(-1, 1))[:, 0])
feat = featr[-1]  # value for the last node, kept as a notebook variable like before
# Map the q-th node to the q-th feature row.  A plain zip replaces the previous
# nx.betweenness_centrality call, whose (expensive) result was only used for its keys.
di = dict(zip(G.nodes(), featr))

# Distinct redshifts, taking the inverse-transformed column 0 as a scale factor a (z = 1/a - 1).
# NOTE(review): this indexes `transformer[0]` while the line above uses `tkeys[k]` — confirm 0 is a valid key.
zs = np.unique(1/transformer[0].inverse_transform(feats[:,0].reshape(-1,1))-1)

# Vertical level of every node = rank of its column-0 value among the distinct values.
vals, posy = np.unique(feats[:, 0], return_inverse=True)

# Horizontal positions come from the tree layout; the graph is reversed before layout
# (presumably so that edges point away from the root — TODO confirm edge direction).
hpos = hierarchy_pos(G.reverse())
hpos = {node: [xy[0]*1.05, -posy[node]] for node, xy in hpos.items()}

In [None]:
# Eight evenly spaced percentiles (0 to 100) of the node feature values, rounded to one decimal.
percentile_grid = np.linspace(0, 100, 8)
rs = np.round(np.percentile(featr, percentile_grid), 1)
rs

In [None]:
# Target value(s) stored on the selected cut graph, as a NumPy array.
targs = dat[j].y.numpy()

In [None]:
# Draw the cut tree: x from the hierarchical layout, y from the column-0 level,
# colour from the rank of the inverse-transformed feature, size from feature column 3.
plt.rcParams['font.size'] = 20
fig, ax = plt.subplots(figsize=(13, 15))

# Colour index of a node = rank of its feature value among the distinct values.
nx.set_node_attributes(G, di, 'n_prog')
progs = set(nx.get_node_attributes(G, 'n_prog').values())
mapping = dict(zip(sorted(progs), count()))
nodes = G.nodes()
colors = [mapping[G.nodes[n]['n_prog']] for n in nodes]
print('Made graph')

# Tick every third level.  Labels are read from `zs` back to front
# (assumes a larger column-0 level corresponds to a lower redshift — TODO confirm).
ax.set(ylabel='Redshift')
ax.set(yticks=-np.arange(len(zs))[1::3])
ax.set_yticklabels(np.round(zs[:-1],1)[::-3])

ec = nx.draw_networkx_edges(G, hpos, alpha=0.9, arrows=False, ax=ax, width=2)
nc = nx.draw_networkx_nodes(G, hpos, nodelist=nodes, node_color=colors,
                            node_size=(feats[:,3]-feats[:,3].min()+1)**3*2, cmap=plt.cm.jet, ax=ax)
cbar = fig.colorbar(nc, ax=ax)
# The colour scale runs over rank indices 0..N-1, so pin the tick positions explicitly
# and label each tick with the feature value at that rank.  Setting tick labels alone
# would relabel whatever ticks matplotlib picked, so labels would not match the colours.
sorted_vals = np.array(sorted(progs), dtype=float)
tick_pos = np.unique(np.linspace(0, len(sorted_vals)-1, 8))
tick_labels = np.round(np.interp(tick_pos, np.arange(len(sorted_vals)), sorted_vals), 1)
cbar.set_ticks(tick_pos)
cbar.set_ticklabels(tick_labels)
cbar.set_label(r'Halo mass [log($M_h/M_{\odot}$)]')

ax.tick_params(left=True, bottom=False, labelleft=True, labelbottom=True)
ax.set_xticks([]);
ax.set_title(r'Sample merger tree cut by 75%' )
fig.savefig('../paper_figures/small_tree_with_cutout.png')