In [4]:
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import math
import scipy as sp
import pickle as pic
from matplotlib.backends.backend_pdf import PdfPages
from collections import defaultdict
import time
start_time = time.time()

In [5]:
class Node:
    
    def __init__(self, char_list, num_states, parent=None, left=None, right=None):
        # num_states includes the 0 state. For example, if the possible states are 0 or 1, then num_states=2
        self.chars = char_list
        self.parent = parent
        self.left = left
        self.right = right
        self.num_chars = len(self.chars)
        self.num_states = num_states
        
    def is_leaf(self):
        return not (self.left or self.right)

    def duplicate(self, p, q=None, dropout_rate=0):
        assert len(p) == len(self.chars), "invalid p vector"
        if q:
            for i in range(len(p)):
                assert len(q[i]) + 1 == self.num_states, "invalid q[" + str(i) + "] vector"
                
        new_chars = []
        if not q:
            q = [None for i in self.chars]

        for l in range(len(self.chars)):
            if self.chars[l] != 0:
                new_chars.append(self.chars[l])
            else:
                if np.random.random() < p[l]:
                    new_chars.append(np.random.choice(np.arange(1, self.num_states), 1, p=q[l])[0])
                else:
                    new_chars.append(self.chars[l])
        
        return Node(new_chars, self.num_states)
    
    def __str__(self):
        s = ''
        for x in self.chars:
            s += str(x) + '|'
        return s[:-1]
                    
def simulation(p, num_states, time, q=None):
    root = Node([0 for i in p], num_states)
    curr_gen = [root]
    for t in range(time):
        new_gen = []
        for n in curr_gen:
            c1 = n.duplicate(p, q)
            c2 = n.duplicate(p, q)
            c1.parent = n
            c2.parent = n
            n.left = c1
            n.right = c2
            new_gen.append(c1)
            new_gen.append(c2)
        curr_gen = new_gen
    return curr_gen

def find_lineage(node):
    lineage = [node]
    while node.parent:
        node = node.parent
        lineage.insert(0, node)
    return lineage

def find_root(samples):
    return find_lineage(samples[0])[0]

def print_tree(root):
    
    def tree_str(node, level=0):
        ret = "\t"*level+str(node)+"\n"
        if node.left:
            ret += tree_str(node.left, level+1)
        if node.right:
            ret += tree_str(node.right, level+1)
        return ret
    
    print(tree_str(root))
    
def generate_frequency_matrix(samples, subset=None):
    k = samples[0].num_chars
    m = samples[0].num_states + 1
    F = np.zeros((k,m), dtype=int)
    if not subset:
        subset = list(range(len(samples)))
    for i in subset:
        for j in range(k):
            F[j][samples[i].chars[j]] += 1
    return F
            
def split_data(F):
    k,m = F.shape[0], F.shape[1]
    split_data = []
    for i in range(k):
        for j in range(1, m-1):
            split_data.append((i,j))
    split_data.sort(key=lambda tup: F[tup[0]][tup[1]], reverse=True)
    index = 0
    
    for i in range(5):
        s = ''
        for j in range(1, 5):
            a, b = split_data[index][0], split_data[index][1]
            s += str((a,b)) + " freq =" + str(F[a][b]) + " "
            index += 1
        print(s)
            
    
def construct_connectivity_graph(samples, subset=None):
    n = len(samples)
    k = samples[0].num_chars
    m = samples[0].num_states
    G = nx.Graph()
    if not subset:
        subset = range(n)
    for i in subset:
        G.add_node(i)
    F = generate_frequency_matrix(samples, subset)
    for i in subset:
        for j in subset:
            if j <= i:
                continue
            n1 = samples[i]
            n2 = samples[j]
            #compute simularity score
            score = 0
            for l in range(k):
                x = n1.chars[l]
                y = n2.chars[l]
                if min(x, y) >= 0 and max(x,y) > 0:
                    if x==y:
                        score -= 3*(len(subset) - F[l][x] - F[l][-1])
                    elif min(x,y) == 0:
                        score += F[l][max(x,y)] - 1
                    else:
                        score += (F[l][x] + F[l][y]) - 2
                        
                if score != 0:
                    G.add_edge(i,j, weight=score)
    return G

def max_cut_heuristic(G, sdimension, iterations, show_steps=False):
    #n = len(G.nodes())
    d = sdimension+1
    emb = {}        
    for i in G.nodes():
        x = np.random.normal(size=d)
        x = x/np.linalg.norm(x)
        emb[i] = x
        
    def show_relaxed_objective():
        score = 0
        for e in G.edges():
            u = e[0]
            v = e[1]
            score += G[u][v]['weight']*np.linalg.norm(emb[u]-emb[v])
        print(score)
        
    for k in range(iterations):
        new_emb = {}
        for i in G.nodes:
            cm = np.zeros(d, dtype=float)
            for j in G.neighbors(i):
                cm -= G[i][j]['weight']*np.linalg.norm(emb[i]-emb[j])*emb[j]
            cm = cm/np.linalg.norm(cm)
            new_emb[i] = cm
        emb = new_emb
        
    #print("final relaxed objective:")
    #show_relaxed_objective()
    return_set = set()
    best_score = 0
    for k in range(3*d):
        b = np.random.normal(size=d)
        b = b/np.linalg.norm(b)
        S = set()
        for i in G.nodes():
            if np.dot(emb[i], b) > 0:
                S.add(i)
        this_score = evaluate_cut(S, G)
        if this_score > best_score:
            return_set = S
            best_score = this_score
    #print("score before hill climb = ", best_score)
    improved_S = improve_cut(G, return_set)
    #final_score = evaluate_cut(improved_S, G)
    #print("final score = ", final_score)
    return improved_S

def improve_cut(G, S):
    #n = len(G.nodes())
    ip = {}
    new_S = S.copy()
    for i in G.nodes():
        improvement_potential = 0
        for j in G.neighbors(i):
            if cut(i,j,new_S):
                improvement_potential -= G[i][j]['weight']
            else:
                improvement_potential += G[i][j]['weight']
        ip[i] = improvement_potential
        
    all_neg = False
    iters = 0
    while (not all_neg) and (iters < 2*len(G.nodes)):
        best_potential = 0
        best_index = 0
        for i in G.nodes():
            if ip[i] > best_potential:
                best_potential = ip[i]
                best_index = i
        if best_potential > 0:
            for j in G.neighbors(best_index):
                if cut(best_index,j,new_S):
                    ip[j] += 2*G[best_index][j]['weight']
                else:
                    ip[j] -= 2*G[best_index][j]['weight']
            ip[best_index] = -ip[best_index]
            if best_index in new_S:
                new_S.remove(best_index)
            else:
                new_S.add(best_index)
        else:
            all_neg = True
        iters += 1
    #print("number of hill climbing interations: ", iters)
    return new_S

def evaluate_cut(S, G, B=None, show_total=False):
    cut_score = 0
    total_good = 0
    total_bad = 0
    for e in G.edges():
        u = e[0]
        v = e[1]
        w_uv = G[u][v]['weight']
        total_good += float(w_uv)
        if cut(u,v,S):
            cut_score += float(w_uv)

    if B:
        for e in B.edges():
            u = e[0]
            v = e[1]
            w_uv = B[u][v]['weight']
            total_bad += float(w_uv)
            if cut(u,v,S):
                cut_score -= float(w_uv)
            
    if show_total:
        print("total good = ", total_good)
        print("total bad = ", total_bad)
    return(cut_score)

def greedy_cut(samples, subset=None):
    F = generate_frequency_matrix(samples, subset)
    k,m = F.shape[0], F.shape[1]
    freq = 0
    char = 0
    state = 0
    if not subset:
        subset = list(range(len(samples)))
    for i in range(k):
        for j in range(1, m-1):
            if F[i][j] > freq and F[i][j] < len(subset) - F[i][-1] :
                char, state = i,j
                freq = F[i][j]
    if freq == 0:
        return random_nontrivial_cut(subset)
    S = set()
    Sc = set()
    missing = set()
    #print(char, state)
    for i in subset:
        if samples[i].chars[char] == state:
            S.add(i)
        elif samples[i].chars[char] == -1:
            missing.add(i)
        else:
            Sc.add(i)
            
    if not Sc:
        if len(S) == len(subset) or len(S) == 0:
            print(F)
            print(char, state, len(subset))
        return S
    
    for i in missing:
        s_score = 0
        sc_score = 0
        for j in S:
            for l in range(k):
                if samples[i].chars[l] > 0 and samples[i].chars[l] == samples[j].chars[l]:
                    s_score += 1
        for j in Sc:
            for l in range(k):
                if samples[i].chars[l] > 0  and samples[i].chars[l] == samples[j].chars[l]:
                    sc_score += 1
        if s_score/len(S) > sc_score/len(Sc):
            S.add(i)
        else:
            Sc.add(i)
        
    if len(S) == len(subset) or len(S) == 0:
            print(F)
            print(char, state, len(subset))
    return S
    
def random_cut(subset):
    S = set()
    for i in subset:
        if np.random.random() > 0.5:
            S.add(i)
    return S

def random_nontrivial_cut(subset):
    assert len(subset) > 1
    S = set()
    lst = list(subset)
    S.add(lst[0])
    for i in range(2,len(lst)):
        if np.random.random() > 0.5:
            S.add(lst[i])
    return S


def cut(u, v, S):
    return ((u in S) and (not v in S)) or ((v in S) and (not u in S))

def num_incorrect(S, h):
    num = 0
    for i in range(int(2**h/2)):
        if not i in S:
            num += 1

    for i in range(int(2**h/2), 2**h):
        if i in S:
            num += 1

    return min(num, 2**h - num)

def find_tree_lineage(i, T):
    p = list(T.predecessors(i))
    curr_node = i
    ancestor_list = [curr_node]
    while p:
        curr_node = p[0]
        ancestor_list.insert(0, curr_node)
        p = list(T.predecessors(curr_node))
    return(ancestor_list)

        
def outgroup(i, j, k, T):
    assert i != j and i != k and j != k, str(i) + ' ' + str(j) + ' ' + str(k) + ' not distinct'
    
    Li = find_tree_lineage(i, T)
    Lj = find_tree_lineage(j, T)
    Lk = find_tree_lineage(k, T)
    l = 0
    while Li[l] == Lj[l] and Lj[l] == Lk[l]:
        l += 1
    if Li[l] != Lj[l] and Li[l] != Lk[l] and Lj[l] != Lk[l]:
        return None
    if Li[l] == Lj[l]:
        return k
    if Li[l] == Lk[l]:
        return j
    if Lj[l] == Lk[l]:
        return i
    
    
    
def evaluate_split(S, subset, T, sample_size=1000):
    # assume S \subseteq T.leaves
    def S_outgroup(i,j,k):
        if (not cut(i,j,subset)) and (not cut(j,k,subset)):
            return None
        if not cut(i,j,subset):
            return k
        if not cut(i,k,subset):
            return j
        return i
    
    TC = 0
    TI = 0
    unresolved = 0
    superset = np.array(list(S))
    num_sampled = 0
    for a in range(sample_size):
        chosen = np.random.choice(superset, 3, replace=False)
        oS = S_outgroup(chosen[0], chosen[1], chosen[2])
        oT = outgroup(chosen[0], chosen[1], chosen[2], T)
        if oS == None or oT == None:
            unresolved += 1
        else:
            if oS == oT:
                TC += 1
            else:
                TI += 1
    return TC/sample_size, TI/sample_size, unresolved/sample_size
                  
def remove_duplicates(nodes, indices):
    indices = list(indices)
    indices.sort(key=lambda i: nodes[i].chars)
    final_set = set()
    i = 0
    j = 1
    while j < len(indices):
        if nodes[indices[i]].chars != nodes[indices[j]].chars:
            final_set.add(indices[i])
            i = j
        j += 1
    final_set.add(indices[i])
    return final_set
            
def mult_chain(a,b):
    f = 1
    for i in range(a, b+1):
        f*=i
    return f

def nCr(n, k):
    if k > n:
        return 0
    if k > n/2:
        return nCr(n, n-k)
    return int(mult_chain(n-k+1,n)/mult_chain(1,k))

def similarity(u, v, samples):
    k = samples[0].num_chars
    return sum([1 for i in range(k) if samples[u].chars[i] == samples[v].chars[i] and samples[u].chars[i] > 0])

def construct_similarity_graph(samples, subset=None, threshold=0):
    G = nx.Graph()
    if not subset:
        subset = range(len(samples))
    for i in subset:
        G.add_node(i)
    F = generate_frequency_matrix(samples, subset)
    k,m = F.shape[0], F.shape[1]
    for i in range(k):
        for j in range(1,m-1):
            if F[i][j] == len(subset) - F[i][-1]:
                threshold += 1
    for i in subset:
        for j in subset:
            if j <= i:
                continue
            s = similarity(i,j, samples) 
            if s > threshold:
                G.add_edge(i,j, weight=(s-threshold))
    return G

def spectral_split(G, k=2, method='Fiedler', return_eig=False, display=False):
    L = nx.normalized_laplacian_matrix(G).todense()
    diag = sp.linalg.eig(L)
    if k == 2 and method == 'Fiedler':
        v2 = diag[1][:, 1] 
        x = {}
        vertices = list(G.nodes())
        for i in range(len(vertices)):
            x[vertices[i]] = v2[i]
        vertices.sort(key=lambda v: x[v])
        total_weight = 2*sum([G[e[0]][e[1]]['weight'] for e in G.edges()])
        S = set()
        num = 0
        denom = 0
        best_score = 10000000
        best_index = 0
        for i in range(len(vertices) - 1):
            v = vertices[i]
            S.add(v)
            cut_edges = 0
            neighbor_weight = 0
            for w in G.neighbors(v):
                neighbor_weight += G[v][w]['weight']
                if w in S:
                    cut_edges += G[v][w]['weight']
            denom += neighbor_weight
            num += neighbor_weight - 2*cut_edges
            if num == 0:
                best_index = i
                break
            if num/min(denom, total_weight-denom) < best_score:
                best_score = num/min(denom, total_weight-denom)
                best_index = i
        if display:
            print("number of samples = ", len(v2))
            print("lambda2 = ", diag[0][1])
            plt.hist(v2, density=True, bins=30)
            plt.hist([x[v] for v in vertices[:best_index+1]], density=True, bins=30)
            plt.show()
        if return_eig:
            return vertices[:best_index+1], diag
        return vertices[:best_index+1]

def spectral_improve_cut(S, G, display=False):
    delta_n = {}
    delta_d = {}
    ip = {}
    new_S = set(S)
    total_weight = 2*sum([G[e[0]][e[1]]['weight'] for e in G.edges()])
    num =  sum([G[e[0]][e[1]]['weight'] for e in G.edges() if cut(e[0], e[1], new_S)])
    denom = sum([sum([G[u][v]['weight'] for v in G.neighbors(u)]) for u in new_S])
    if num == 0:
        return list(new_S)
    curr_score = num/min(denom, total_weight-denom)
    
    def set_ip(u):
        if min(denom + delta_d[u], total_weight - denom - delta_d[u]) == 0:
            ip[u] = 1000
        else:
            ip[u] = (num + delta_n[u])/min(denom + delta_d[u], total_weight - denom - delta_d[u]) - num/min(denom, total_weight - denom)
    
    for u in G.nodes():
        d = sum([G[u][v]['weight'] for v in G.neighbors(u)])
        if d == 0:
            return [u]
        c = sum([G[u][v]['weight'] for v in G.neighbors(u) if cut(u,v,new_S)])
        delta_n[u] = d-2*c
        if u in new_S:
            delta_d[u] = -d
        else:
            delta_d[u] = d
        set_ip(u)
    #TODO
    all_neg = False
    iters = 0
    
    while (not all_neg) and (iters < len(G.nodes)):
        best_potential = 0
        best_index = None
        for v in G.nodes():
            if ip[v] < best_potential:
                best_potential = ip[v]
                best_index = v
        if not best_index is None:
            num += delta_n[best_index]
            denom += delta_d[best_index]
            for j in G.neighbors(best_index):
                if cut(best_index,j,new_S):
                    delta_n[j] += 2*G[best_index][j]['weight']
                else:
                    delta_n[j] -= 2*G[best_index][j]['weight']
                set_ip(j)
            delta_n[best_index] = -delta_n[best_index]
            delta_d[best_index] = -delta_d[best_index]
            set_ip(best_index)
            if best_index in new_S:
                new_S.remove(best_index)
            else:
                new_S.add(best_index)
            #print("curr scores:", num/min(denom, total_weight - denom))
        else:
            all_neg = True
        iters += 1
    if display:
        print("sgreed+ score, ",  num/min(denom, total_weight - denom))
    return list(new_S)

def evaluate_sparsity(S, G):
    total_weight = 2*sum([G[e[0]][e[1]]['weight'] for e in G.edges()])
    num =  sum([G[e[0]][e[1]]['weight'] for e in G.edges() if cut(e[0], e[1], S)])
    denom = sum([sum([G[u][v]['weight'] for v in G.neighbors(u)]) for u in S])
    return num/min(denom, total_weight - denom)
    
def build_tree_sep(samples, method='greedy', subset = None, sim_thresh=0, p = None, qs = None):
    assert method in ['greedy', 'egreedy', 'SDP', 'greedy+', 'spectral', 'sgreedy+']
    if not subset:
        subset = list(range(len(samples)))
    else:
        subset = list(subset)
    subset = remove_duplicates(samples, subset)
    T = nx.DiGraph()
    for i in subset:
        T.add_node(i)
    def build_helper(S):
        assert S, "error, S = "+ str(S)
        if len(S) == 1:
            return list(S)[0]
        left_set = set()
        if method == 'greedy':
            left_set = greedy_cut(samples, subset=S)
        elif method == 'egreedy':
            left_set = egreedy_cut(samples, p, qs, subset=S)
        elif method == 'SDP':
            G = construct_connectivity_graph(samples, subset=S)
            left_set = max_cut_heuristic(G, 3, 50)
        elif method == 'greedy+':
            G = construct_connectivity_graph(samples, subset=S)
            left_set = greedy_cut(samples, subset=S)
            left_set = improve_cut(G,left_set)
        elif method == 'spectral':
            G = construct_similarity_graph(samples, subset=list(S), threshold=sim_thresh)
            left_set = spectral_split(G)
            left_set = spectral_improve_cut(left_set, G)
        elif method == 'sgreedy+':
            G = construct_similarity_graph(samples, subset=S, threshold=sim_thresh)
            left_set = spectral_improve_cut(greedy_cut(samples, subset=S) , G)

        if len(left_set) == 0 or len(left_set) == len(S):
            left_set = greedy_cut(samples, subset=S)
        right_set = set()
        for i in S:
            if not i in left_set:
                right_set.add(i)
        root = len(T.nodes) - len(subset) + len(samples)
        T.add_node(root)
        left_child = build_helper(left_set)
        right_child = build_helper(right_set)
        T.add_edge(root, left_child)
        T.add_edge(root, right_child)
        return root
    build_helper(subset)
    return T

def triplets_correct_sep(T, Tt, sample_size=5000):
    TC = 0
    sample_set = np.array([v for v in T.nodes() if T.in_degree(v) == 1 and T.out_degree(v) == 0])
    for a in range(sample_size):
        chosen = np.random.choice(sample_set, 3, replace=False)
        if outgroup2(chosen[0], chosen[1], chosen[2], T)[0] == outgroup2(chosen[0], chosen[1], chosen[2], Tt)[0]:
            TC += 1
    return TC/sample_size

def outgroup2(i, j, k, T):
    assert i != j and i != k and j != k, str(i) + ' ' + str(j) + ' ' + str(k) + ' not distinct'
    
#     Li = find_tree_lineage(i, T)
#     Lj = find_tree_lineage(j, T)
#     Lk = find_tree_lineage(k, T)

    Li = [node for node in nx.ancestors(T, i)]
    Lj = [node for node in nx.ancestors(T, j)]
    Lk = [node for node in nx.ancestors(T, k)]
    
    ij_common = len(set(Li) & set(Lj))
    ik_common = len(set(Li) & set(Lk))
    jk_common = len(set(Lj) & set(Lk))
    index = min(ij_common, ik_common, jk_common)

    if ij_common == ik_common and ik_common == jk_common:
        return None, index
    if ij_common > ik_common and ij_common > jk_common:
        return k, index
    elif jk_common > ik_common and jk_common > ij_common:
        return i, index
    elif ik_common > ij_common and ik_common > jk_common:
        return j, index

def triplets_correct_stratified(T, Tt, sample_size=5000, min_size_depth = 20):
    correct_class = defaultdict(int)
    freqs = defaultdict(int)
    sample_set = np.array([v for v in T.nodes() if T.in_degree(v) == 1 and T.out_degree(v) == 0])
    
    for a in range(sample_size):
        chosen = np.random.choice(sample_set, 3, replace=False)
        out1, index = outgroup2(chosen[0], chosen[1], chosen[2], T)
        out2, index2 = outgroup2(chosen[0], chosen[1], chosen[2], Tt)
        correct_class[index] += (out1 == out2)
        freqs[index] += 1
        
    tot_tp = 0
    num_consid = 0
    
    for k in correct_class.keys():
        if freqs[k] > min_size_depth:

            num_consid += 1
            tot_tp += correct_class[k] / freqs[k]

    tot_tp /= num_consid
    return tot_tp

def get_colless(network):
    root = [n for n in network if network.in_degree(n) == 0][0]
    colless = [0]
    colless_helper(network, root, colless)
    n = len([n for n in network if network.out_degree(n) == 0 and network.in_degree(n) == 1]) 
    return colless[0], (colless[0] - n * np.log(n) - n * (np.euler_gamma - 1 - np.log(2)))/n

def colless_helper(network, node, colless):
    if network.out_degree(node) == 0:
        return 1
    else:
        leaves = []
        for i in network.successors(node):
            leaves.append(colless_helper(network, i, colless))
        colless[0] += abs(leaves[0] - leaves[1])
        return sum(leaves)

def triplets_correct_at_time_sep(T, Tt, method='all', bin_size = 10, sample_size=5000, sampling_depths=None):
    sample_set = set([v for v in T.nodes() if T.in_degree(v) == 1 and T.out_degree(v) == 0])
    children = {}
    num_triplets = {}
    nodes_at_depth = {}

    def find_children(node, total_time):
        t = total_time + Tt.nodes[node]['parent_lifespan']
        children[node] = []
        if Tt.out_degree(node) == 0:
            if node in sample_set:
                children[node].append(node)
            return

        for n in Tt.neighbors(node):
            find_children(n, t)
            children[node] += children[n]

        L, R = list(Tt.neighbors(node))[0], list(Tt.neighbors(node))[1]
        num_triplets[node] = len(children[L])*nCr(len(children[R]), 2) + len(children[R])*nCr(len(children[L]), 2)
        if num_triplets[node] > 0:
            bin_num = t//bin_size
            
            if bin_num in nodes_at_depth:
                nodes_at_depth[bin_num].append(node)
            else:
                nodes_at_depth[bin_num] = [node]
                
    root = [n for n in Tt if Tt.in_degree(n) == 0][0]
    find_children(root, 0)

    def sample_at_depth(d):
        denom = sum([num_triplets[v] for v in nodes_at_depth[d]])
        node = np.random.choice(nodes_at_depth[d], 1, [num_triplets[v]/denom for v in nodes_at_depth[d]])[0]
        L, R = list(Tt.neighbors(node))[0], list(Tt.neighbors(node))[1]
        if np.random.random() < (len(children[R])-1)/(len(children[R])+len(children[L])-2):
            outgrp = np.random.choice(children[L], 1)[0]
            ingrp = np.random.choice(children[R], 2, replace=False)
        else:
            outgrp = np.random.choice(children[R], 1)[0]
            ingrp = np.random.choice(children[L], 2, replace=False)
        return outgroup(ingrp[0], ingrp[1], outgrp, T) == outgrp

    if not sampling_depths:
        sampling_depths = [d for d in range(len(nodes_at_depth))]
    if method == 'aggregate':
        score = 0
        freq = 0
        for d in sampling_depths:
            if d in nodes_at_depth:
                max_children = 0
                for i in nodes_at_depth[d]:
                    if len(children[i]) > max_children:
                        max_children = len(children[i])
                if max_children > 10:
                    freq += 1
                    for a in range(sample_size):
                        score += int(sample_at_depth(d))
        return score/(sample_size*freq)
    elif method == 'all':
        ret = ['NA'] * len(sampling_depths)
        for d in sampling_depths:
            if d in nodes_at_depth:
                max_children = 0
                for i in nodes_at_depth[d]:
                    if len(children[i]) > max_children:
                        max_children = len(children[i])
                if max_children > 10:
                    score = 0
                    for a in range(sample_size):
                        score += int(sample_at_depth(d))
                    ret[d] = score/sample_size
        return np.array(ret)

def triplets_correct_at_depth_sep(T, Tt, method='all', sample_size=5000, sampling_depths=None):
    sample_set = set([v for v in T.nodes() if T.in_degree(v) == 1 and T.out_degree(v) == 0])
    children = {}
    num_triplets = {}
    nodes_at_depth = {}

    def find_children(node, depth):
        children[node] = []
        if Tt.out_degree(node) == 0:
            if node in sample_set:
                children[node].append(node)
            return

        for n in Tt.neighbors(node):
            find_children(n, depth+1)
            children[node] += children[n]

        L, R = list(Tt.neighbors(node))[0], list(Tt.neighbors(node))[1]
        num_triplets[node] = len(children[L])*nCr(len(children[R]), 2) + len(children[R])*nCr(len(children[L]), 2)
        if num_triplets[node] > 0:
            if depth in nodes_at_depth:
                nodes_at_depth[depth].append(node)
            else:
                nodes_at_depth[depth] = [node]
                
    root = [n for n in Tt if Tt.in_degree(n) == 0][0]
    find_children(root, 0)

    def sample_at_depth(d):
        denom = sum([num_triplets[v] for v in nodes_at_depth[d]])
        node = np.random.choice(nodes_at_depth[d], 1, [num_triplets[v]/denom for v in nodes_at_depth[d]])[0]
        L, R = list(Tt.neighbors(node))[0], list(Tt.neighbors(node))[1]
        if np.random.random() < (len(children[R])-1)/(len(children[R])+len(children[L])-2):
            outgrp = np.random.choice(children[L], 1)[0]
            ingrp = np.random.choice(children[R], 2, replace=False)
        else:
            outgrp = np.random.choice(children[R], 1)[0]
            ingrp = np.random.choice(children[L], 2, replace=False)
        return outgroup(ingrp[0], ingrp[1], outgrp, T) == outgrp

    if not sampling_depths:
        sampling_depths = [d for d in range(len(nodes_at_depth))]
        
    if method == 'aggregate':
        score = 0
        freq = 0
        for d in sampling_depths:
            if d in nodes_at_depth:
                max_children = 0
                for i in nodes_at_depth[d]:
                    if len(children[i]) > max_children:
                        max_children = len(children[i])
                if max_children > 10:
                    freq += 1
                    for a in range(sample_size):
                        score += int(sample_at_depth(d))
        return score/(sample_size*freq)
    elif method == 'all':
        ret = ['NA'] * len(sampling_depths)
        for d in sampling_depths:
            if d in nodes_at_depth:
                max_children = 0
                for i in nodes_at_depth[d]:
                    if len(children[i]) > max_children:
                        max_children = len(children[i])
                if max_children > 10:
                    score = 0
                    for a in range(sample_size):
                        score += int(sample_at_depth(d))
                    ret[d] = score/sample_size
        return np.array(ret) 

def eX(p, q, h):
    if p == 0.5:
        return 2*p*q*(h-1)
    return (2*q*p*(1-(2-2*p)**h)/(2*p-1))

def VarX(p, q, h):
    def cross_terms(p, q, h):
        if p == 0.5:
            return p**2*q**2*sum([(h-d)**2 for d in range(h)])
        s = 0
        for d in range(h):
            s += ((2-2*p)**d)*((2-2*p)**(h-d)-1)**2
        return p**2*q**2*s/(2*p-1)**2
    return 2*cross_terms(p,q,h) + eX(p,q,h) - eX(p,q,h)**2

def eY(p, q, h):
    return q*(1-(1-p)**h)

def Var_Y(p,q,h):
    Ey = q*(1-(1-p)**h)
    s = 0
    s += 2*(1-((1-p)/2)**h)/(1+p) - 4*(1-p)**h*(1-2**(-h)) + (2-2*p)*(1-p)**(2*h)*(1-(1/(2-2*p))**h)/(1-2*p)
    Eyiyj = q**2*s/4
    Eyiyi = p*q*(1-((1-p)/2)**h)/(1+p)
    return 2*Eyiyj + Eyiyi - Ey**2

def CoV(p, q, h):
    s = 2*(2-2*p)**h*(1-2**(-h)) - (2-2*p)*(1-p)**h*((2-2*p)**h - 1)/(1-2*p) - (1-(1-p)**h)/p + h*(1-p)**h
    twoExiyj = p*q**2*s/(1-2*p)
    Ey = q*(1-(1-p)**h)
    Ex = (2*q*p*(1-(2-2*p)**h)/(2*p-1))
    return twoExiyj + Ey - Ex*Ey

def generate_frequency_dict(samples, subset=None):
    k = samples[0].num_chars
    F = {}
    for n in range(k):
        F[n] = {}
    if not subset:
        subset = list(range(len(samples)))
    for i in subset:
        for j in range(k):
            if samples[i].chars[j] in F[j]:
                F[j][samples[i].chars[j]] += 1
            else:
                F[j][samples[i].chars[j]] = 1
    return F

def score_egreedy(samples, p, qs, subset = None):
    if not subset:
        subset = list(range(len(samples)))
    F = generate_frequency_dict(samples, subset)
    h = math.log2(len(subset))
    total = 0
    
    if len(subset) > 1:
        for char in F:
            for state in F[char]:
                if state == 0 or state == -1:
                    continue
                q = qs[char][str(state)]
                if q > 0:
                    total += (eX(p,q,h) + CoV(p,q,h)/Var_Y(p,q,h) * (F[char][state] - eY(p,q,h)))

    return total

def egreedy_cut(samples, p, qs, subset = None):
    if not subset:
        subset = list(range(len(samples)))
    F = generate_frequency_dict(samples, subset)
    k = samples[0].num_chars
    min_score = 9223372036854775807
    split_char = 0
    split_state = 0
    
    for char in F:
        for state in F[char]:
            if state == 0 or state == -1:
                continue
                
            S = set()
            Sc = set()
            missing = set()
            #print(char, state)
            for i in subset:
                if samples[i].chars[char] == state:
                    S.add(i)
                elif samples[i].chars[char] == -1:
                    missing.add(i)
                else:
                    Sc.add(i)
            
            if not Sc or not S:
                continue
                
            score_S = score_egreedy(samples, p, qs, S)
            score_Sc = score_egreedy(samples, p, qs, Sc)
            
#             print(score_S + score_Sc)
            if score_S + score_Sc < min_score:
                min_score = score_S + score_Sc
                split_char = char
                split_state = state
    
    if split_state == 0:
        return random_nontrivial_cut(subset)
    
    S = set()
    Sc = set()
    missing = set()
    #print(char, state)
    for i in subset:
        if samples[i].chars[split_char] == split_state:
            S.add(i)
        elif samples[i].chars[split_char] == -1:
            missing.add(i)
        else:
            Sc.add(i)
    
    for i in missing:
        s_score = 0
        sc_score = 0
        for j in S:
            for l in range(k):
                if samples[i].chars[l] > 0 and samples[i].chars[l] == samples[j].chars[l]:
                    s_score += 1
        for j in Sc:
            for l in range(k):
                if samples[i].chars[l] > 0  and samples[i].chars[l] == samples[j].chars[l]:
                    sc_score += 1
        if s_score/len(S) > sc_score/len(Sc):
            S.add(i)
        else:
            Sc.add(i)
            
    if not Sc:
        if len(S) == len(subset) or len(S) == 0:
            print('One side of split empty')
            print(F)
            print(char, state, len(subset))
        return S
    
#     if len(Sc) == 1 or len(S) == 1:
#         print('split of size 1 chosen')
    
    return S

In [16]:
# for drop in [10, 20, 30, 40, 50, 60]:
# for drop in [0]:
folder = "400cells_base_nb"
path = "/data/yosef2/users/richardz/projects/benchmarking/" + folder + "/"
nums = []
#     triplets = []
triplets_new = []
colless = []
types = []
methods = []

for method in ["greedy", "egreedy", "sgreedy+", "spectral", "greedy+", "SDP"]:
    for num in range(0, 10):
        dropout_cm = pd.read_csv(path + "dropout_cm" + str(num) + ".txt", sep = '\t', index_col = 0)
        dropout_cm = dropout_cm.applymap(str)

        samples = []
        for index, row in dropout_cm.iterrows():
            node = Node(list(row), 1001, parent=None, left=None, right=None)
            samples.append(node)

        subset = list(range(len(samples)))
        prune_samples = remove_duplicates(samples, subset)

        node_map = {}
        for i in prune_samples:
            node_map[i] = list(dropout_cm.iloc[i,:])

        net = pic.load(open(path + "dropout_net" + str(num) + ".pkl", 'rb'))
        ground = net.network
        leaves = [n for n in ground if ground.out_degree(n) == 0 and ground.in_degree(n) == 1]
        ground_net_map = {}
        for i in node_map:
            for j in leaves:
                if node_map[i] == j.char_vec:
                    ground_net_map[j] = i
                    break
        ground = nx.relabel_nodes(ground, ground_net_map)

        dropout_cm = dropout_cm.replace("-", -1)
        dropout_cm = dropout_cm.replace("*", -1)
        dropout_cm = dropout_cm.apply(pd.to_numeric)

        samples = []
        for index, row in dropout_cm.iterrows():
            node = Node(list(row), 1001, parent=None, left=None, right=None)
            samples.append(node)
    #     sample_map = {}
    #     for i in range(len(samples)):
    #         sample_map[i] = samples[i]
        if method == 'egreedy':
            p = 0.007550225842780328
#                 p = np.log(1 - (mut/100))/(-1 * 100)
            priors = pic.load(open(path + "priors" + str(num) + ".pkl", 'rb'))
            recon = build_tree_sep(samples, method="egreedy", p = p, qs = priors)
        else:
            recon = build_tree_sep(samples, method=method)
#             trip = triplets_correct_stratified(recon, ground)
        trip2 = triplets_correct_at_depth_sep(recon, ground, 'aggregate')
        print(num, trip2, method)
#             triplets.append(trip)
        triplets_new.append(trip2)
        nums.append(num)
        colless.append(get_colless(ground)[0])
        types.append(folder)
        methods.append(method)

data = [nums, triplets_new, colless, methods, types]
df = pd.DataFrame(data)
df = df.T
df = df.rename(columns = {0: 'Run', 1: 'TripletsCorrect', 2:'Colless', 3:'Method', 4: 'Type'})
df.to_csv(path + 'methods_triplets2.txt', sep = '\t', index = False)

0 0.6916333333333333 greedy
1 0.7673090909090909 greedy
2 0.7751666666666667 greedy
3 0.7358545454545454 greedy
4 0.7367428571428571 greedy
5 0.7442857142857143 greedy
6 0.7847 greedy
7 0.7366833333333334 greedy
8 0.7654285714285715 greedy
9 0.7945636363636364 greedy
0 0.5432166666666667 egreedy
1 0.5462 egreedy
2 0.6081833333333333 egreedy
3 0.5560545454545455 egreedy
4 0.5960285714285715 egreedy
5 0.5807714285714286 egreedy
6 0.56238 egreedy
7 0.6355333333333333 egreedy
8 0.5488571428571428 egreedy
9 0.6350727272727272 egreedy
0 0.6916833333333333 sgreedy+
1 0.7698 sgreedy+
2 0.7623666666666666 sgreedy+
3 0.7524181818181818 sgreedy+
4 0.7514857142857143 sgreedy+
5 0.7543714285714286 sgreedy+
6 0.77532 sgreedy+
7 0.7272 sgreedy+
8 0.7542857142857143 sgreedy+
9 0.7937454545454545 sgreedy+
0 0.4338666666666667 spectral
1 0.3503636363636364 spectral
2 0.5565166666666667 spectral
3 0.5758545454545455 spectral
4 0.5295714285714286 spectral
5 0.32474285714285717 spectral
6 0.53564 spectral




0 0.6910166666666666 SDP
1 0.7699636363636364 SDP
2 0.7637333333333334 SDP
3 0.7449454545454546 SDP
4 0.7214285714285714 SDP
5 0.7541142857142857 SDP
6 0.7739 SDP
7 0.74015 SDP
8 0.7852285714285714 SDP
9 0.7601636363636364 SDP


In [12]:
max_len = max([len(n) for n in triplets_new])
for i in range(len(triplets_new)):
    triplets_new[i]
    while(len(triplets_new[i]) < max_len):
        print(len(triplets_new[i]))
        triplets_new[i] = np.append(triplets_new[i], ['NA'], axis = 0)
df = pd.DataFrame(triplets_new)
# [len(n) for n in triplets_new]
df.to_csv(path + 'methods_triplets_depth.txt', sep = '\t', index = False)

15
16
17
16
17
16
17
15
16
17
17
17
16
17
15
16
17
16
17
16
17
15
16
17
17
17
16
17
15
16
17
16
17
16
17
15
16
17
17
17
16
17
15
16
17
16
17
16
17
15
16
17
17
17
16
17
15
16
17
16
17
16
17
15
16
17
17
17
16
17
15
16
17
16
17
16
17
15
16
17
17
17
16
17


In [13]:
path

'/data/yosef2/users/richardz/projects/benchmarking/400cells_base_m/'