In [1]:
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import scipy as sp
import pickle as pic
from matplotlib.backends.backend_pdf import PdfPages
from collections import defaultdict
import time
import math
start_time = time.time()

In [2]:
def eX(p, q, h):
    if p == 0.5:
        return 2*p*q*(h-1)
    return (2*q*p*(1-(2-2*p)**h)/(2*p-1))

def VarX(p, q, h):
    def cross_terms(p, q, h):
        if p == 0.5:
            return p**2*q**2*sum([(h-d)**2 for d in range(h)])
        s = 0
        for d in range(h):
            s += ((2-2*p)**d)*((2-2*p)**(h-d)-1)**2
        return p**2*q**2*s/(2*p-1)**2
    return 2*cross_terms(p,q,h) + eX(p,q,h) - eX(p,q,h)**2

def eY(p, q, h):
    return q*(1-(1-p)**h)

def Var_Y(p,q,h):
    Ey = q*(1-(1-p)**h)
    s = 0
    s += 2*(1-((1-p)/2)**h)/(1+p) - 4*(1-p)**h*(1-2**(-h)) + (2-2*p)*(1-p)**(2*h)*(1-(1/(2-2*p))**h)/(1-2*p)
    Eyiyj = q**2*s/4
    Eyiyi = p*q*(1-((1-p)/2)**h)/(1+p)
    return 2*Eyiyj + Eyiyi - Ey**2

def CoV(p, q, h):
    s = 2*(2-2*p)**h*(1-2**(-h)) - (2-2*p)*(1-p)**h*((2-2*p)**h - 1)/(1-2*p) - (1-(1-p)**h)/p + h*(1-p)**h
    twoExiyj = p*q**2*s/(1-2*p)
    Ey = q*(1-(1-p)**h)
    Ex = (2*q*p*(1-(2-2*p)**h)/(2*p-1))
    return twoExiyj + Ey - Ex*Ey

In [3]:
class Node:
    
    def __init__(self, char_list, num_states, parent=None, left=None, right=None):
        # num_states includes the 0 state. For example, if the possible states are 0 or 1, then num_states=2
        self.chars = char_list
        self.parent = parent
        self.left = left
        self.right = right
        self.num_chars = len(self.chars)
        self.num_states = num_states
        
    def is_leaf(self):
        return not (self.left or self.right)

    def duplicate(self, p, q=None, dropout_rate=0):
        assert len(p) == len(self.chars), "invalid p vector"
        if q:
            for i in range(len(p)):
                assert len(q[i]) + 1 == self.num_states, "invalid q[" + str(i) + "] vector"
                
        new_chars = []
        if not q:
            q = [None for i in self.chars]

        for l in range(len(self.chars)):
            if self.chars[l] != 0:
                new_chars.append(self.chars[l])
            else:
                if np.random.random() < p[l]:
                    new_chars.append(np.random.choice(np.arange(1, self.num_states), 1, p=q[l])[0])
                else:
                    new_chars.append(self.chars[l])
        
        return Node(new_chars, self.num_states)
    
    def __str__(self):
        s = ''
        for x in self.chars:
            s += str(x) + '|'
        return s[:-1]
                    
def simulation(p, num_states, time, q=None):
    root = Node([0 for i in p], num_states)
    curr_gen = [root]
    for t in range(time):
        new_gen = []
        for n in curr_gen:
            c1 = n.duplicate(p, q)
            c2 = n.duplicate(p, q)
            c1.parent = n
            c2.parent = n
            n.left = c1
            n.right = c2
            new_gen.append(c1)
            new_gen.append(c2)
        curr_gen = new_gen
    return curr_gen

def find_lineage(node):
    lineage = [node]
    while node.parent:
        node = node.parent
        lineage.insert(0, node)
    return lineage

def find_root(samples):
    return find_lineage(samples[0])[0]

def print_tree(root):
    
    def tree_str(node, level=0):
        ret = "\t"*level+str(node)+"\n"
        if node.left:
            ret += tree_str(node.left, level+1)
        if node.right:
            ret += tree_str(node.right, level+1)
        return ret
    
    print(tree_str(root))
    
def generate_frequency_matrix(samples, subset=None):
    k = samples[0].num_chars
    m = samples[0].num_states + 1
    F = np.zeros((k,m), dtype=int)
    if not subset:
        subset = list(range(len(samples)))
    for i in subset:
        for j in range(k):
            F[j][samples[i].chars[j]] += 1
    return F
            
def split_data(F):
    k,m = F.shape[0], F.shape[1]
    split_data = []
    for i in range(k):
        for j in range(1, m-1):
            split_data.append((i,j))
    split_data.sort(key=lambda tup: F[tup[0]][tup[1]], reverse=True)
    index = 0
    
    for i in range(5):
        s = ''
        for j in range(1, 5):
            a, b = split_data[index][0], split_data[index][1]
            s += str((a,b)) + " freq =" + str(F[a][b]) + " "
            index += 1
        print(s)
            
    
def construct_connectivity_graph(samples, subset=None):
    n = len(samples)
    k = samples[0].num_chars
    m = samples[0].num_states
    G = nx.Graph()
    if not subset:
        subset = range(n)
    for i in subset:
        G.add_node(i)
    F = generate_frequency_matrix(samples, subset)
    for i in subset:
        for j in subset:
            if j <= i:
                continue
            n1 = samples[i]
            n2 = samples[j]
            #compute simularity score
            score = 0
            for l in range(k):
                x = n1.chars[l]
                y = n2.chars[l]
                if min(x, y) >= 0 and max(x,y) > 0:
                    if x==y:
                        score -= 3*(len(subset) - F[l][x] - F[l][-1])
                    elif min(x,y) == 0:
                        score += F[l][max(x,y)] - 1
                    else:
                        score += (F[l][x] + F[l][y]) - 2
                        
                if score != 0:
                    G.add_edge(i,j, weight=score)
    return G

def max_cut_heuristic(G, sdimension, iterations, show_steps=False):
    #n = len(G.nodes())
    d = sdimension+1
    emb = {}        
    for i in G.nodes():
        x = np.random.normal(size=d)
        x = x/np.linalg.norm(x)
        emb[i] = x
        
    def show_relaxed_objective():
        score = 0
        for e in G.edges():
            u = e[0]
            v = e[1]
            score += G[u][v]['weight']*np.linalg.norm(emb[u]-emb[v])
        print(score)
        
    for k in range(iterations):
        new_emb = {}
        for i in G.nodes:
            cm = np.zeros(d, dtype=float)
            for j in G.neighbors(i):
                cm -= G[i][j]['weight']*np.linalg.norm(emb[i]-emb[j])*emb[j]
            cm = cm/np.linalg.norm(cm)
            new_emb[i] = cm
        emb = new_emb
        
    #print("final relaxed objective:")
    #show_relaxed_objective()
    return_set = set()
    best_score = 0
    for k in range(3*d):
        b = np.random.normal(size=d)
        b = b/np.linalg.norm(b)
        S = set()
        for i in G.nodes():
            if np.dot(emb[i], b) > 0:
                S.add(i)
        this_score = evaluate_cut(S, G)
        if this_score > best_score:
            return_set = S
            best_score = this_score
    #print("score before hill climb = ", best_score)
    improved_S = improve_cut(G, return_set)
    #final_score = evaluate_cut(improved_S, G)
    #print("final score = ", final_score)
    return improved_S

def improve_cut(G, S):
    #n = len(G.nodes())
    ip = {}
    new_S = S.copy()
    for i in G.nodes():
        improvement_potential = 0
        for j in G.neighbors(i):
            if cut(i,j,new_S):
                improvement_potential -= G[i][j]['weight']
            else:
                improvement_potential += G[i][j]['weight']
        ip[i] = improvement_potential
        
    all_neg = False
    iters = 0
    while (not all_neg) and (iters < 2*len(G.nodes)):
        best_potential = 0
        best_index = 0
        for i in G.nodes():
            if ip[i] > best_potential:
                best_potential = ip[i]
                best_index = i
        if best_potential > 0:
            for j in G.neighbors(best_index):
                if cut(best_index,j,new_S):
                    ip[j] += 2*G[best_index][j]['weight']
                else:
                    ip[j] -= 2*G[best_index][j]['weight']
            ip[best_index] = -ip[best_index]
            if best_index in new_S:
                new_S.remove(best_index)
            else:
                new_S.add(best_index)
        else:
            all_neg = True
        iters += 1
    #print("number of hill climbing interations: ", iters)
    return new_S

def evaluate_cut(S, G, B=None, show_total=False):
    cut_score = 0
    total_good = 0
    total_bad = 0
    for e in G.edges():
        u = e[0]
        v = e[1]
        w_uv = G[u][v]['weight']
        total_good += float(w_uv)
        if cut(u,v,S):
            cut_score += float(w_uv)

    if B:
        for e in B.edges():
            u = e[0]
            v = e[1]
            w_uv = B[u][v]['weight']
            total_bad += float(w_uv)
            if cut(u,v,S):
                cut_score -= float(w_uv)
            
    if show_total:
        print("total good = ", total_good)
        print("total bad = ", total_bad)
    return(cut_score)

def greedy_cut(samples, subset=None):
    F = generate_frequency_matrix(samples, subset)
    k,m = F.shape[0], F.shape[1]
    freq = 0
    char = 0
    state = 0
    if not subset:
        subset = list(range(len(samples)))
    for i in range(k):
        for j in range(1, m-1):
            if F[i][j] > freq and F[i][j] < len(subset) - F[i][-1] :
                char, state = i,j
                freq = F[i][j]
    if freq == 0:
        return random_nontrivial_cut(subset)
    S = set()
    Sc = set()
    missing = set()
    #print(char, state)
    for i in subset:
        if samples[i].chars[char] == state:
            S.add(i)
        elif samples[i].chars[char] == -1:
            missing.add(i)
        else:
            Sc.add(i)
            
    if not Sc:
        if len(S) == len(subset) or len(S) == 0:
            print(F)
            print(char, state, len(subset))
        return S
    
    for i in missing:
        s_score = 0
        sc_score = 0
        for j in S:
            for l in range(k):
                if samples[i].chars[l] > 0 and samples[i].chars[l] == samples[j].chars[l]:
                    s_score += 1
        for j in Sc:
            for l in range(k):
                if samples[i].chars[l] > 0  and samples[i].chars[l] == samples[j].chars[l]:
                    sc_score += 1
        if s_score/len(S) > sc_score/len(Sc):
            S.add(i)
        else:
            Sc.add(i)
        
    if len(S) == len(subset) or len(S) == 0:
            print(F)
            print(char, state, len(subset))
    return S
    
def random_cut(subset):
    S = set()
    for i in subset:
        if np.random.random() > 0.5:
            S.add(i)
    return S

def random_nontrivial_cut(subset):
    assert len(subset) > 1
    S = set()
    lst = list(subset)
    S.add(lst[0])
    for i in range(2,len(lst)):
        if np.random.random() > 0.5:
            S.add(lst[i])
    return S


def cut(u, v, S):
    return ((u in S) and (not v in S)) or ((v in S) and (not u in S))

def num_incorrect(S, h):
    num = 0
    for i in range(int(2**h/2)):
        if not i in S:
            num += 1

    for i in range(int(2**h/2), 2**h):
        if i in S:
            num += 1

    return min(num, 2**h - num)

def find_tree_lineage(i, T):
    p = list(T.predecessors(i))
    curr_node = i
    ancestor_list = [curr_node]
    while p:
        curr_node = p[0]
        ancestor_list.insert(0, curr_node)
        p = list(T.predecessors(curr_node))
    return(ancestor_list)

        
def outgroup(i, j, k, T):
    assert i != j and i != k and j != k, str(i) + ' ' + str(j) + ' ' + str(k) + ' not distinct'
    
    Li = find_tree_lineage(i, T)
    Lj = find_tree_lineage(j, T)
    Lk = find_tree_lineage(k, T)
    l = 0
    while Li[l] == Lj[l] and Lj[l] == Lk[l]:
        l += 1
    if Li[l] != Lj[l] and Li[l] != Lk[l] and Lj[l] != Lk[l]:
        return None
    if Li[l] == Lj[l]:
        return k
    if Li[l] == Lk[l]:
        return j
    if Lj[l] == Lk[l]:
        return i
    
    
    
def evaluate_split(S, subset, T, sample_size=1000):
    # assume S \subseteq T.leaves
    def S_outgroup(i,j,k):
        if (not cut(i,j,subset)) and (not cut(j,k,subset)):
            return None
        if not cut(i,j,subset):
            return k
        if not cut(i,k,subset):
            return j
        return i
    
    TC = 0
    TI = 0
    unresolved = 0
    superset = np.array(list(S))
    num_sampled = 0
    for a in range(sample_size):
        chosen = np.random.choice(superset, 3, replace=False)
        oS = S_outgroup(chosen[0], chosen[1], chosen[2])
        oT = outgroup(chosen[0], chosen[1], chosen[2], T)
        if oS == None or oT == None:
            unresolved += 1
        else:
            if oS == oT:
                TC += 1
            else:
                TI += 1
    return TC/sample_size, TI/sample_size, unresolved/sample_size
                  
def remove_duplicates(nodes, indices):
    indices = list(indices)
    indices.sort(key=lambda i: nodes[i].chars)
    final_set = set()
    i = 0
    j = 1
    while j < len(indices):
        if nodes[indices[i]].chars != nodes[indices[j]].chars:
            final_set.add(indices[i])
            i = j
        j += 1
    final_set.add(indices[i])
    return final_set
            
def mult_chain(a,b):
    f = 1
    for i in range(a, b+1):
        f*=i
    return f

def nCr(n, k):
    if k > n:
        return 0
    if k > n/2:
        return nCr(n, n-k)
    return int(mult_chain(n-k+1,n)/mult_chain(1,k))

def similarity(u, v, samples):
    k = samples[0].num_chars
    return sum([1 for i in range(k) if samples[u].chars[i] == samples[v].chars[i] and samples[u].chars[i] > 0])

def construct_similarity_graph(samples, subset=None, threshold=0):
    G = nx.Graph()
    if not subset:
        subset = range(len(samples))
    for i in subset:
        G.add_node(i)
    F = generate_frequency_matrix(samples, subset)
    k,m = F.shape[0], F.shape[1]
    for i in range(k):
        for j in range(1,m-1):
            if F[i][j] == len(subset) - F[i][-1]:
                threshold += 1
    for i in subset:
        for j in subset:
            if j <= i:
                continue
            s = similarity(i,j, samples) 
            if s > threshold:
                G.add_edge(i,j, weight=(s-threshold))
    return G

def spectral_split(G, k=2, method='Fiedler', return_eig=False, display=False):
    L = nx.normalized_laplacian_matrix(G).todense()
    diag = sp.linalg.eig(L)
    if k == 2 and method == 'Fiedler':
        v2 = diag[1][:, 1] 
        x = {}
        vertices = list(G.nodes())
        for i in range(len(vertices)):
            x[vertices[i]] = v2[i]
        vertices.sort(key=lambda v: x[v])
        total_weight = 2*sum([G[e[0]][e[1]]['weight'] for e in G.edges()])
        S = set()
        num = 0
        denom = 0
        best_score = 10000000
        best_index = 0
        for i in range(len(vertices) - 1):
            v = vertices[i]
            S.add(v)
            cut_edges = 0
            neighbor_weight = 0
            for w in G.neighbors(v):
                neighbor_weight += G[v][w]['weight']
                if w in S:
                    cut_edges += G[v][w]['weight']
            denom += neighbor_weight
            num += neighbor_weight - 2*cut_edges
            if num == 0:
                best_index = i
                break
            if num/min(denom, total_weight-denom) < best_score:
                best_score = num/min(denom, total_weight-denom)
                best_index = i
        if display:
            print("number of samples = ", len(v2))
            print("lambda2 = ", diag[0][1])
            plt.hist(v2, density=True, bins=30)
            plt.hist([x[v] for v in vertices[:best_index+1]], density=True, bins=30)
            plt.show()
        if return_eig:
            return vertices[:best_index+1], diag
        return vertices[:best_index+1]

def spectral_improve_cut(S, G, display=False):
    delta_n = {}
    delta_d = {}
    ip = {}
    new_S = set(S)
    total_weight = 2*sum([G[e[0]][e[1]]['weight'] for e in G.edges()])
    num =  sum([G[e[0]][e[1]]['weight'] for e in G.edges() if cut(e[0], e[1], new_S)])
    denom = sum([sum([G[u][v]['weight'] for v in G.neighbors(u)]) for u in new_S])
    if num == 0:
        return list(new_S)
    curr_score = num/min(denom, total_weight-denom)
    
    def set_ip(u):
        if min(denom + delta_d[u], total_weight - denom - delta_d[u]) == 0:
            ip[u] = 1000
        else:
            ip[u] = (num + delta_n[u])/min(denom + delta_d[u], total_weight - denom - delta_d[u]) - num/min(denom, total_weight - denom)
    
    for u in G.nodes():
        d = sum([G[u][v]['weight'] for v in G.neighbors(u)])
        if d == 0:
            return [u]
        c = sum([G[u][v]['weight'] for v in G.neighbors(u) if cut(u,v,new_S)])
        delta_n[u] = d-2*c
        if u in new_S:
            delta_d[u] = -d
        else:
            delta_d[u] = d
        set_ip(u)
    #TODO
    all_neg = False
    iters = 0
    
    while (not all_neg) and (iters < len(G.nodes)):
        best_potential = 0
        best_index = None
        for v in G.nodes():
            if ip[v] < best_potential:
                best_potential = ip[v]
                best_index = v
        if not best_index is None:
            num += delta_n[best_index]
            denom += delta_d[best_index]
            for j in G.neighbors(best_index):
                if cut(best_index,j,new_S):
                    delta_n[j] += 2*G[best_index][j]['weight']
                else:
                    delta_n[j] -= 2*G[best_index][j]['weight']
                set_ip(j)
            delta_n[best_index] = -delta_n[best_index]
            delta_d[best_index] = -delta_d[best_index]
            set_ip(best_index)
            if best_index in new_S:
                new_S.remove(best_index)
            else:
                new_S.add(best_index)
            #print("curr scores:", num/min(denom, total_weight - denom))
        else:
            all_neg = True
        iters += 1
    if display:
        print("sgreed+ score, ",  num/min(denom, total_weight - denom))
    return list(new_S)

def evaluate_sparsity(S, G):
    total_weight = 2*sum([G[e[0]][e[1]]['weight'] for e in G.edges()])
    num =  sum([G[e[0]][e[1]]['weight'] for e in G.edges() if cut(e[0], e[1], S)])
    denom = sum([sum([G[u][v]['weight'] for v in G.neighbors(u)]) for u in S])
    return num/min(denom, total_weight - denom)
    
def build_tree_sep(samples, method='greedy', subset = None, sim_thresh=0, p = None, qs = None):
    assert method in ['greedy', 'egreedy', 'SDP', 'greedy+', 'spectral', 'sgreedy+']
    if not subset:
        subset = list(range(len(samples)))
    else:
        subset = list(subset)
    subset = remove_duplicates(samples, subset)
    T = nx.DiGraph()
    for i in subset:
        T.add_node(i)
    def build_helper(S):
        assert S, "error, S = "+ str(S)
        if len(S) == 1:
            return list(S)[0]
        left_set = set()
        if method == 'greedy':
            left_set = greedy_cut(samples, subset=S)
        elif method == 'egreedy':
            left_set = egreedy_cut(samples, p, qs, subset=S)
        elif method == 'SDP':
            G = construct_connectivity_graph(samples, subset=S)
            left_set = max_cut_heuristic(G, 3, 50)
        elif method == 'greedy+':
            G = construct_connectivity_graph(samples, subset=S)
            left_set = greedy_cut(samples, subset=S)
            left_set = improve_cut(G,left_set)
        elif method == 'spectral':
            G = construct_similarity_graph(samples, subset=list(S), threshold=sim_thresh)
            left_set = spectral_split(G)
            left_set = spectral_improve_cut(left_set, G)
        elif method == 'sgreedy+':
            G = construct_similarity_graph(samples, subset=S, threshold=sim_thresh)
            left_set = spectral_improve_cut(greedy_cut(samples, subset=S) , G)

        if len(left_set) == 0 or len(left_set) == len(S):
            left_set = greedy_cut(samples, subset=S)
        right_set = set()
        for i in S:
            if not i in left_set:
                right_set.add(i)
        root = len(T.nodes) - len(subset) + len(samples)
        T.add_node(root)
        left_child = build_helper(left_set)
        right_child = build_helper(right_set)
        T.add_edge(root, left_child)
        T.add_edge(root, right_child)
        return root
    build_helper(subset)
    return T

def triplets_correct_sep(T, Tt, sample_size=5000):
    TC = 0
    sample_set = np.array([v for v in T.nodes() if T.in_degree(v) == 1 and T.out_degree(v) == 0])
    for a in range(sample_size):
        chosen = np.random.choice(sample_set, 3, replace=False)
        if outgroup2(chosen[0], chosen[1], chosen[2], T)[0] == outgroup2(chosen[0], chosen[1], chosen[2], Tt)[0]:
            TC += 1
    return TC/sample_size

def outgroup2(i, j, k, T):
    assert i != j and i != k and j != k, str(i) + ' ' + str(j) + ' ' + str(k) + ' not distinct'
    
#     Li = find_tree_lineage(i, T)
#     Lj = find_tree_lineage(j, T)
#     Lk = find_tree_lineage(k, T)

    Li = [node for node in nx.ancestors(T, i)]
    Lj = [node for node in nx.ancestors(T, j)]
    Lk = [node for node in nx.ancestors(T, k)]
    
    ij_common = len(set(Li) & set(Lj))
    ik_common = len(set(Li) & set(Lk))
    jk_common = len(set(Lj) & set(Lk))
    index = min(ij_common, ik_common, jk_common)

    if ij_common == ik_common and ik_common == jk_common:
        return None, index
    if ij_common > ik_common and ij_common > jk_common:
        return k, index
    elif jk_common > ik_common and jk_common > ij_common:
        return i, index
    elif ik_common > ij_common and ik_common > jk_common:
        return j, index

def triplets_correct_stratified(T, Tt, sample_size=5000, min_size_depth = 20):
    correct_class = defaultdict(int)
    freqs = defaultdict(int)
    sample_set = np.array([v for v in T.nodes() if T.in_degree(v) == 1 and T.out_degree(v) == 0])
    
    for a in range(sample_size):
        chosen = np.random.choice(sample_set, 3, replace=False)
        out1, index = outgroup2(chosen[0], chosen[1], chosen[2], T)
        out2, index2 = outgroup2(chosen[0], chosen[1], chosen[2], Tt)
        correct_class[index] += (out1 == out2)
        freqs[index] += 1
        
    tot_tp = 0
    num_consid = 0
    
    for k in correct_class.keys():
        if freqs[k] > min_size_depth:

            num_consid += 1
            tot_tp += correct_class[k] / freqs[k]

    tot_tp /= num_consid
    return tot_tp

def get_colless(network):
    root = [n for n in network if network.in_degree(n) == 0][0]
    colless = [0]
    colless_helper(network, root, colless)
    n = len([n for n in network if network.out_degree(n) == 0 and network.in_degree(n) == 1]) 
    return colless[0], (colless[0] - n * np.log(n) - n * (np.euler_gamma - 1 - np.log(2)))/n

def colless_helper(network, node, colless):
    if network.out_degree(node) == 0:
        return 1
    else:
        leaves = []
        for i in network.successors(node):
            leaves.append(colless_helper(network, i, colless))
        colless[0] += abs(leaves[0] - leaves[1])
        return sum(leaves)

def triplets_correct_at_time_sep(T, Tt, method='all', bin_size = 10, sample_size=5000, sampling_depths=None):
    sample_set = set([v for v in T.nodes() if T.in_degree(v) == 1 and T.out_degree(v) == 0])
    children = {}
    num_triplets = {}
    nodes_at_depth = {}

    def find_children(node, total_time):
        t = total_time + Tt.nodes[node]['parent_lifespan']
        children[node] = []
        if Tt.out_degree(node) == 0:
            if node in sample_set:
                children[node].append(node)
            return

        for n in Tt.neighbors(node):
            find_children(n, t)
            children[node] += children[n]

        L, R = list(Tt.neighbors(node))[0], list(Tt.neighbors(node))[1]
        num_triplets[node] = len(children[L])*nCr(len(children[R]), 2) + len(children[R])*nCr(len(children[L]), 2)
        if num_triplets[node] > 0:
            bin_num = t//bin_size
            
            if bin_num in nodes_at_depth:
                nodes_at_depth[bin_num].append(node)
            else:
                nodes_at_depth[bin_num] = [node]
                
    root = [n for n in Tt if Tt.in_degree(n) == 0][0]
    find_children(root, 0)

    def sample_at_depth(d):
        denom = sum([num_triplets[v] for v in nodes_at_depth[d]])
        node = np.random.choice(nodes_at_depth[d], 1, [num_triplets[v]/denom for v in nodes_at_depth[d]])[0]
        L, R = list(Tt.neighbors(node))[0], list(Tt.neighbors(node))[1]
        if np.random.random() < (len(children[R])-1)/(len(children[R])+len(children[L])-2):
            outgrp = np.random.choice(children[L], 1)[0]
            ingrp = np.random.choice(children[R], 2, replace=False)
        else:
            outgrp = np.random.choice(children[R], 1)[0]
            ingrp = np.random.choice(children[L], 2, replace=False)
        return outgroup(ingrp[0], ingrp[1], outgrp, T) == outgrp

    if not sampling_depths:
        sampling_depths = [d for d in range(len(nodes_at_depth))]
    if method == 'aggregate':
        score = 0
        freq = 0
        for d in sampling_depths:
            if d in nodes_at_depth:
                max_children = 0
                for i in nodes_at_depth[d]:
                    if len(children[i]) > max_children:
                        max_children = len(children[i])
                if max_children > 10:
                    freq += 1
                    for a in range(sample_size):
                        score += int(sample_at_depth(d))
        return score/(sample_size*freq)
    elif method == 'all':
        ret = ['NA'] * len(sampling_depths)
        for d in sampling_depths:
            if d in nodes_at_depth:
                max_children = 0
                for i in nodes_at_depth[d]:
                    if len(children[i]) > max_children:
                        max_children = len(children[i])
                if max_children > 10:
                    score = 0
                    for a in range(sample_size):
                        score += int(sample_at_depth(d))
                    ret[d] = score/sample_size
        return np.array(ret)

def triplets_correct_at_depth_sep(T, Tt, method='all', sample_size=5000, sampling_depths=None):
    sample_set = set([v for v in T.nodes() if T.in_degree(v) == 1 and T.out_degree(v) == 0])
    children = {}
    num_triplets = {}
    nodes_at_depth = {}

    def find_children(node, depth):
        children[node] = []
        if Tt.out_degree(node) == 0:
            if node in sample_set:
                children[node].append(node)
            return

        for n in Tt.neighbors(node):
            find_children(n, depth+1)
            children[node] += children[n]

        L, R = list(Tt.neighbors(node))[0], list(Tt.neighbors(node))[1]
        num_triplets[node] = len(children[L])*nCr(len(children[R]), 2) + len(children[R])*nCr(len(children[L]), 2)
        if num_triplets[node] > 0:
            if depth in nodes_at_depth:
                nodes_at_depth[depth].append(node)
            else:
                nodes_at_depth[depth] = [node]
                
    root = [n for n in Tt if Tt.in_degree(n) == 0][0]
    find_children(root, 0)

    def sample_at_depth(d):
        denom = sum([num_triplets[v] for v in nodes_at_depth[d]])
        node = np.random.choice(nodes_at_depth[d], 1, [num_triplets[v]/denom for v in nodes_at_depth[d]])[0]
        L, R = list(Tt.neighbors(node))[0], list(Tt.neighbors(node))[1]
        if np.random.random() < (len(children[R])-1)/(len(children[R])+len(children[L])-2):
            outgrp = np.random.choice(children[L], 1)[0]
            ingrp = np.random.choice(children[R], 2, replace=False)
        else:
            outgrp = np.random.choice(children[R], 1)[0]
            ingrp = np.random.choice(children[L], 2, replace=False)
        return outgroup(ingrp[0], ingrp[1], outgrp, T) == outgrp

    if not sampling_depths:
        sampling_depths = [d for d in range(len(nodes_at_depth))]
        
    if method == 'aggregate':
        score = 0
        freq = 0
        for d in sampling_depths:
            if d in nodes_at_depth:
                max_children = 0
                for i in nodes_at_depth[d]:
                    if len(children[i]) > max_children:
                        max_children = len(children[i])
                if max_children > 10:
                    freq += 1
                    for a in range(sample_size):
                        score += int(sample_at_depth(d))
        return score/(sample_size*freq)
    elif method == 'all':
        ret = ['NA'] * len(sampling_depths)
        for d in sampling_depths:
            if d in nodes_at_depth:
                max_children = 0
                for i in nodes_at_depth[d]:
                    if len(children[i]) > max_children:
                        max_children = len(children[i])
                if max_children > 10:
                    score = 0
                    for a in range(sample_size):
                        score += int(sample_at_depth(d))
                    ret[d] = score/sample_size
        return np.array(ret) 

In [35]:
def generate_frequency_dict(samples, subset=None):
    k = samples[0].num_chars
    F = {}
    for n in range(k):
        F[n] = {}
    if not subset:
        subset = list(range(len(samples)))
    for i in subset:
        for j in range(k):
            if samples[i].chars[j] in F[j]:
                F[j][samples[i].chars[j]] += 1
            else:
                F[j][samples[i].chars[j]] = 1
    return F

def score_egreedy(samples, p, qs, subset = None):
    if not subset:
        subset = list(range(len(samples)))
    F = generate_frequency_dict(samples, subset)
    h = math.log2(len(subset))
    total = 0
    
    if len(subset) > 1:
        for char in F:
            for state in F[char]:
                if state == 0 or state == -1:
                    continue
                q = qs[char][str(state)]
                if q > 0:
                    total += (eX(p,q,h) + CoV(p,q,h)/Var_Y(p,q,h) * (F[char][state] - eY(p,q,h)))

    return total

def egreedy_cut(samples, p, qs, subset = None):
    if not subset:
        subset = list(range(len(samples)))
    F = generate_frequency_dict(samples, subset)
    k = samples[0].num_chars
    min_score = 9223372036854775807
    split_char = 0
    split_state = 0
    
    for char in F:
        for state in F[char]:
            if state == 0 or state == -1:
                continue
                
            S = set()
            Sc = set()
            missing = set()
            #print(char, state)
            for i in subset:
                if samples[i].chars[char] == state:
                    S.add(i)
                elif samples[i].chars[char] == -1:
                    missing.add(i)
                else:
                    Sc.add(i)
            
            if not Sc or not S:
                continue
                
            score_S = score_egreedy(samples, p, qs, S)
            score_Sc = score_egreedy(samples, p, qs, Sc)
            
#             print(score_S + score_Sc)
            if score_S + score_Sc < min_score:
                min_score = score_S + score_Sc
                split_char = char
                split_state = state
    
    if split_state == 0:
        return random_nontrivial_cut(subset)
    
    S = set()
    Sc = set()
    missing = set()
    #print(char, state)
    for i in subset:
        if samples[i].chars[split_char] == split_state:
            S.add(i)
        elif samples[i].chars[split_char] == -1:
            missing.add(i)
        else:
            Sc.add(i)
    
    for i in missing:
        s_score = 0
        sc_score = 0
        for j in S:
            for l in range(k):
                if samples[i].chars[l] > 0 and samples[i].chars[l] == samples[j].chars[l]:
                    s_score += 1
        for j in Sc:
            for l in range(k):
                if samples[i].chars[l] > 0  and samples[i].chars[l] == samples[j].chars[l]:
                    sc_score += 1
        if s_score/len(S) > sc_score/len(Sc):
            S.add(i)
        else:
            Sc.add(i)
            
    if not Sc:
        if len(S) == len(subset) or len(S) == 0:
            print('One side of split empty')
            print(F)
            print(char, state, len(subset))
        return S
    
#     if len(Sc) == 1 or len(S) == 1:
#         print('split of size 1 chosen')
    
    return S

In [33]:
# def generate_frequency_dict(samples, subset=None):
#     k = samples[0].num_chars
#     F = {}
#     for n in range(k):
#         F[n] = {}
#     if not subset:
#         subset = list(range(len(samples)))
#     for i in subset:
#         for j in range(k):
#             if samples[i].chars[j] in F[j]:
#                 F[j][samples[i].chars[j]] += 1
#             else:
#                 F[j][samples[i].chars[j]] = 1
#     return F

# def score_egreedy(samples, p, qs, subset = None):
#     if not subset:
#         subset = list(range(len(samples)))
#     F = generate_frequency_matrix(samples, subset)
#     k,m = F.shape[0], F.shape[1]
#     h = math.log2(len(subset))

#     total = 0
#     if len(subset) > 1:
#         for char in range(k):
#             for state in range(1, m-1):
#                 if state == 0 or state == -1:
#                     continue
#                 q = qs[char][str(state)]
#                 if q > 0:
#                     total += (eX(p,q,h) + CoV(p,q,h)/Var_Y(p,q,h) * (F[char][state] - eY(p,q,h)))

#     return total

# def egreedy_cut(samples, p, qs, subset = None):
#     if not subset:
#         subset = list(range(len(samples)))
#     F = generate_frequency_dict(samples, subset)
#     k = samples[0].num_chars
#     min_score = 9223372036854775807
#     split_char = 0
#     split_state = 0
    
#     for char in F:
#         for state in F[char]:
#             if state == 0 or state == -1:
#                 continue
                
#             S = set()
#             Sc = set()
#             missing = set()
#             #print(char, state)
#             for i in subset:
#                 if samples[i].chars[char] == state:
#                     S.add(i)
#                 elif samples[i].chars[char] == -1:
#                     missing.add(i)
#                 else:
#                     Sc.add(i)
            
#             if not Sc or not S:
#                 continue
                
#             score_S = score_egreedy(samples, p, qs, S)
#             score_Sc = score_egreedy(samples, p, qs, Sc)
            
#             print(score_S + score_Sc)
            
#             if score_S + score_Sc < min_score:
#                 min_score = score_S + score_Sc
#                 split_char = char
#                 split_state = state
    
#     if split_state == 0:
#         return random_nontrivial_cut(subset)
    
#     S = set()
#     Sc = set()
#     missing = set()
#     #print(char, state)
#     for i in subset:
#         if samples[i].chars[split_char] == split_state:
#             S.add(i)
#         elif samples[i].chars[split_char] == -1:
#             missing.add(i)
#         else:
#             Sc.add(i)
    
#     for i in missing:
#         s_score = 0
#         sc_score = 0
#         for j in S:
#             for l in range(k):
#                 if samples[i].chars[l] > 0 and samples[i].chars[l] == samples[j].chars[l]:
#                     s_score += 1
#         for j in Sc:
#             for l in range(k):
#                 if samples[i].chars[l] > 0  and samples[i].chars[l] == samples[j].chars[l]:
#                     sc_score += 1
#         if s_score/len(S) > sc_score/len(Sc):
#             S.add(i)
#         else:
#             Sc.add(i)
            
#     if not Sc:
#         if len(S) == len(subset) or len(S) == 0:
#             print('One side of split empty')
#             print(F)
#             print(char, state, len(subset))
#         return S
    
#     return S

In [34]:
path = "/data/yosef2/users/richardz/projects/benchmarking/400cells_base_higher_herit/"

num = 0

dropout_cm = pd.read_csv(path + "dropout_cm" + str(num) + ".txt", sep = '\t', index_col = 0)
dropout_cm = dropout_cm.applymap(str)

samples = []
for index, row in dropout_cm.iterrows():
    node = Node(list(row), 1001, parent=None, left=None, right=None)
    samples.append(node)

subset = list(range(len(samples)))
prune_samples = remove_duplicates(samples, subset)

node_map = {}
for i in prune_samples:
    node_map[i] = list(dropout_cm.iloc[i,:])

net = pic.load(open(path + "dropout_net" + str(num) + ".pkl", 'rb'))
ground = net.network
leaves = [n for n in ground if ground.out_degree(n) == 0 and ground.in_degree(n) == 1]
ground_net_map = {}
for i in node_map:
    for j in leaves:
        if node_map[i] == j.char_vec:
            ground_net_map[j] = i
            break
ground = nx.relabel_nodes(ground, ground_net_map)

dropout_cm = dropout_cm.replace("-", -1)
dropout_cm = dropout_cm.replace("*", -1)
dropout_cm = dropout_cm.apply(pd.to_numeric)

samples = []
for index, row in dropout_cm.iterrows():
    node = Node(list(row), 1001, parent=None, left=None, right=None)
    samples.append(node)
    
priors = pic.load(open(path + "priors" + str(num) + ".pkl", 'rb'))
recon = build_tree_sep(samples, method="egreedy", p = 0.007550225842780328, qs = priors)
trip2 = triplets_correct_at_depth_sep(recon, ground, 'aggregate')
trip2

22641.428284555175
22259.626770999228
22871.96841828107
22084.516599664486
22919.257505775535
22887.73327025321
22871.970068294606
22520.20721145006
22887.72889969066
22879.848892757334
22887.73565716106
22403.13840263782
22903.498098453685
22903.49367078103
22848.32013053478
22903.49012339235
22879.849434545406
22879.854532863472
22927.131642155786
22847.16024856683
22895.618758453344
22887.736326386057
22919.256948858292
22879.85513870782
22871.96873095202
22887.75575122211
22288.003805506265
21839.968271689202
22486.458123712946
22008.3454165284
22585.177891703584
22425.133701387247
22683.3999748998
22595.97677755225
22502.94831855459
22631.18164658103
22601.602774010145
22302.493738162037
22114.45120814791
22675.53294390435
22667.670442954033
22644.05354792428
22659.79315175708
22395.427802458522
22572.546041940495
22691.27548260676
22644.054126212017
22714.892556241608
22691.265574639303
22667.66469510341
22623.54698412058
22659.793815669214
22659.79144016093
22693.568074398954
22

27428.460225710427
27396.062992195235
27371.782820545315
27412.250904683882
27412.26418279492
27351.807508133963
27404.163706458545
27327.42728943192
27030.458798610416
27371.767600335505
27396.058073340257
27420.353143792618
27387.973028768607
27345.694314135588
27357.878686523352
27387.964094836505
27396.05920466423
27316.13344959397
26908.179758034374
27363.66709808438
27412.25888179591
27412.25378900777
27213.31431970283
27420.34686585226
27315.225000126324
27387.95541371328
27304.672505443075
27412.253758700015
27309.121417975988
27404.152108823106
27436.55007232997
27396.05990738416
27420.354769463258
27412.271201593754
27284.780291222425
27379.851559734874
27444.641536556792
27387.954801327003
27444.632375033383
27396.07244772138
27412.254907553623
27387.962665268074
27412.26579780476
27404.16372152881
27436.555680246598
27404.16066621706
27428.456075009457
26834.806820774927
25759.63376861705
27396.06462011843
26537.173892481558
27135.01963563381
27363.667451263955
26787.482166

18322.424643689566
18492.421774042457
18613.713724346337
18598.495582925196
18606.0987304883
18613.706964331785
18606.111049768948
18552.84568364697
18613.71362910464
18343.169367032144
18526.052350936196
18397.378864040795
18613.702704744814
18583.274300157846
18598.49387902454
18613.715955217565
18568.055633184442
18583.271953262083
18531.625237356136
18590.881538295067
18269.71128159876
18598.48499110466
18628.936765233982
18322.424643689566
18613.713724346337
18575.663650454077
18583.28289287162
18590.88429152524
18613.708423648448
18590.879028061707
18598.49251495947
18451.605316939866
18606.100251936397
18598.490445257845
18636.53812896625
18560.454002274484
18606.10691149917
17748.672809914977
18552.836567674378
18583.28446865452
18590.880757286242
18575.668415461
17896.206544373945
18560.451394639203
18542.832760858524
18590.88673592648
18613.716571718433
18598.497631745464
18583.271953262083
18590.881538295067
18590.879601856737
18621.322319479234
18514.8251135245
18412.173247

11450.682135113115
11285.85278080921
11941.391460942792
11639.125664903002
11948.436027182608
11955.484522894727
11969.573835934152
11920.251098695742
11990.708276135105
11822.247235665789
11948.439118217671
11615.67268618617
11253.538932620473
11785.04677690833
11385.684846334796
11955.47997952759
11880.025421328788
11470.412983152244
11948.436065278842
11941.394412275678
11934.345376435203
11705.746693331343
11920.255451168421
11962.528125091823
11920.255070243968
11962.523179268966
11941.391543118727
11934.345707806173
11934.344522071531
11888.416872972417
11428.073827680168
11818.938563710211
11979.301876319398
11964.163039616315
12028.68575139667
11953.656131684718
12042.792230668429
12049.845028088392
12014.571730286896
11845.130739417735
12042.796470281735
12035.740188950149
12042.791701817532
12035.74098790758
12035.73800812935
11943.989777581923
12049.851191314217
15179.175850107395
14158.827829843456
14314.409222836073
15149.779293040678
15142.426862789001
14907.914489887615


2816.3367072185456
2685.279095682984
2794.339568616096
2805.337728928575
2799.8406595541737
2783.3430796261864
2821.835703815067
2768.302026137555
2783.3423132316047
2980.7365328165492
2969.6182857435533
2959.992716025118
2931.760755817434
2980.7365328165492
2571.281605607893
2991.851511957129
2959.991818975439
2959.992716025118
2986.2949021320937
2578.6412728029004
2975.177120122968
2137.934081673081
2128.5867230193635
2183.980043370652
2189.135915383476
2183.9800933118595
2183.9804051952797
1878.0395277669893
2507.940721364045
2725.686955009065
2620.4106579301806
2845.3429693494536
2881.7625597882998
2914.9353813921152
2797.9346172906644
2866.340384121711
2920.4649181905734
2620.6774966307894
2677.280484276218
2696.965842632463
2702.4036871734897
2740.461729665603
2694.304340464779
2729.58701799032
2673.8748606473305
2735.0233215362655
2450.4929343688077
2480.4353116265597
2580.1826969793065
2735.0233215362655
2792.4647611337555
2903.876058696995
2909.405093167121
2892.818101577939
2

588.3156540887553
602.4815449160204
682.9778529754789
682.9778529754789
682.9778529754789
682.9778529754789
619.0987647508691
619.0987647508691
619.0987647508691
655.7298848479724
671.3000620089807
651.8373119964355
682.9778529754789
0
0
0
0
534.3339117283442
534.3339117283442
508.6004376308877
523.305171915861
534.3339117283442
534.3339117283442
534.3339117283442
515.622107599872
451.00292998275546
462.96260111276047
473.6221075998721
452.4203613981845
508.6004376308877
519.6289103096378
504.92412742478007
534.3339117283442
464.664385569327
482.3845821335837
534.3339117283442
534.3339117283442
534.3339117283442
534.3339117283442
473.6221006685174
473.6221006685174
473.6221006685174
508.6003541499038
523.3051804911389
504.92412742478007
534.3339117283442
63.99999999999988
63.99999999999988
63.99999999999988
63.99999999999988
63.99999999999988
63.99999999999988
63.99999999999988
47.99999999999989
63.99999999999988
63.99999999999988
63.99999999999988
63.99999999999988
63.99999999999988
6

10786.138180930573
9502.635449465395
9682.072357018811
9701.491849965541
9682.075944546765
9759.6163394191
9773.331548026403
9766.473810863348
9701.491494769645
9706.34609841667
9739.042884477172
9766.4800435026
9701.490161868767
9759.61816375735
9587.064230178192
9759.615974916993
9780.19053094235
9759.620662931393
9766.47745325203
9739.04655186331
9686.953923904988
9773.336021752646
9578.972169823297
8667.465556649957
9368.003126735017
9780.193272330687
9682.075944546765
9673.147243762114
9395.8359543736
9535.908092289043
9706.34609841667
9739.042884477172
9752.756627466895
9780.19053094235
9739.04655186331
9766.473798725443
8848.96519964438
9368.027124281052
8667.465556649957
9749.954609294926
9461.84271967119
9535.908092289043
9706.34609841667
9739.042884477172
9570.447821527458
9773.333082706602
1324.6862519221145
1397.3762179697367
1431.662732419739
1407.1726702482156
1324.6862519221145
1324.6862519221145
1358.6255846641752
1397.3763353644556
1384.3492602721674
1446.357613992562


287.51665334852794
287.51665334852794
306.62297650245995
287.51665334852794
219.45294136141212
225.98501469910184
216.4683046341999
225.15959772580635
229.25100921998512
235.78307166835503
225.15959772580635
254.36092617759314
254.36092617759314
249.85882138959602
267.39231540161046
260.55741002870747
270.8097687177658
249.85882138959602
260.92087491547693
270.80972783879946
254.36092617759314
254.36092617759314
270.8097425830015
254.36092617759314
37.99999999999993
37.99999999999993
39.99999999999992
0
98.69225418215798
106.68214586229453
112.00874406659491
153.82985979357414
133.55015947569044
142.2414708014726
153.82985979357414
141.36565510657027
142.24146250162727
0
0
0
39.99999999999992
73.7056487071902
64.2019997974893
73.7056487071902
61.82608567167313
64.20199840556096
39.99999999999992
31.999999999999943
39.99999999999992
31.999999999999943
0
0
0
7467.679911250651
7330.3697156503795
7630.636018771623
7907.916336657419
7927.638972497016
7793.975914528785
7907.918371907907
7760

420.3242513026775
380.3969063905693
388.1745403478918
381.05362975828376
388.1745403478918
381.05362975828376
388.1745403478918
399.0051983434567
388.1745403478918
420.32426165876245
380.3969063905693
220.16733711787091
210.33565322195116
89.99999999999983
80.83337903201509
80.83337946207351
78.45746781175677
43.999999999999915
43.999999999999915
43.999999999999915
0
0
199.68138469302872
199.68138469302872
205.9776987038561
199.68138469302872
199.68138469302872
205.9776930919343
199.68138469302872
199.68138469302872
194.38930049551124
194.38930049551124
200.1834969488384
199.68138469302872
199.68138469302872
199.68138469302872
199.68138469302872
199.68138469302872
199.68138469302872
199.68138469302872
199.68138469302872
194.38930049551124
199.68138469302872
148.3370420511189
148.3370420511189
143.96834437686053
148.3370420511189
143.96834151822208
148.3370420511189
148.3370420511189
148.3370420511189
138.64174668343554
148.3370420511189
148.3370420511189
148.3370420511189
148.337042051

103.9999999999998
92.71294872490968
53.9999999999999
51.9999999999999
51.9999999999999
112.00874852501596
112.00875251472284
112.00874852501596
117.33534862148527
112.00875251472284
118.08156475426779
112.00875251472284
109.34545573279996
112.00875251472284
76.08156089624029
83.99999999999983
76.08156089624029
114.67204411085257
118.08156475426779
112.00875251472284
112.00874852501596
112.00874852501596
109.34545573279996
114.67204411085257
118.08156475426779
112.00875251472284
112.00874852501596
78.45747691059358
78.45747691059358
78.45747691059358
78.45747691059358
0
0
0
0
0
0
0
41.999999999999915
45.99999999999989
41.999999999999915
41.999999999999915
41.999999999999915
41.999999999999915
39.999999999999915
41.999999999999915
41.999999999999915
41.999999999999915
39.999999999999915
41.999999999999915
41.999999999999915
0
0
0
0
0
0
0
3871.842616068577
4183.713991392547
4166.10889951714
4166.107853361268
4154.370552388679
4021.412808490484
4142.633946303995
4142.634075237186
3959.4244

748.3600383902956
760.9052370773207
857.866453627353
831.9665438111124
748.3600383902956
827.6500966288154
831.9665059557489
641.98831833412
689.44584408759
689.44584408759
697.7713764913025
827.6498070785783
836.2831807288234
827.6500966288154
666.0949842962431
689.44584408759
689.4457159581893
825.6302764768395
823.3331497275267
840.5997401708275
823.3331644359411
807.6951657087932
825.6302764768395
823.3331497275267
831.9665438111124
831.9665059557489
731.679755377779
831.9665438111124
827.6500966288154
731.072921734113
700.7348032960401
722.7474583849583
690.0277427384982
731.072921734113
658.132475018213
658.1328744116054
647.5172739013074
722.7474889852087
658.1328744116054
55.99999999999989
53.9999999999999
53.9999999999999
53.9999999999999
53.9999999999999
0
0
0
538.9535122552014
535.0609577897667
531.168326897105
477.2436516498549
527.2757899834978
527.2759219594614
618.1843410064537
560.7240399023542
642.65596369318
560.7240399023542
610.456188258893
618.1843410064537
618.184

1205.3260155651926
1222.6488807919002
1231.8089118328019
1099.100187699593
1240.9685249925717
1222.648684150445
1212.8832964976148
1240.9685249925717
1236.388530480377
1245.5481744005654
1231.808603547212
1222.6488807919002
1205.326034287948
1236.3885655163206
1222.648684150445
969.1523857574723
971.1200322796793
1086.2643936829515
1072.9001973453346
1104.0831753745342
1072.9001973453346
1006.8667991502831
1064.5808522055986
1062.192862615736
1095.1739167431645
1077.3549244218218
1095.1735920666933
1160.8828830387515
1147.3266648390209
1130.9003390299347
1160.8828830387515
1160.8828506185046
1108.9082257769664
1245.5481744005654
1222.648684150445
1212.8832964976148
1222.6488807919002
1227.2287907005102
1236.3889020326362
1236.3885655163206
1205.3260155651926
1212.8832964976148
1240.9685249925717
1205.3260155651926
1006.7634412041074
1165.4012707675474
1124.9008998550478
1117.53589257838
1212.8832964976148
1222.6488807919002
1210.3638719357807
1236.3885655163206
1236.388530480377
1077.4

2056.876568610544
1981.4672529691286
1826.371749704771
1984.876562678826
1978.2354066592563
1944.2446915047776
2005.9475287007692
1725.8969375695294
1963.8035418384902
1918.6879505467598
1881.914957413132
1942.7312033775706
1947.9984821671496
1921.6597683999616
1963.8035418384902
1953.2673479450195
1937.4630064823536
1953.2664988811425
1915.4561042368894
1942.7314343562666
1932.1948299804508
1921.6597683999616
2318.6132519831517
2307.8036010231694
2318.613243546681
2292.4840973284136
2318.6127843517525
2329.4233048243304
2285.7409908741074
2296.995264082799
2070.9450172149645
2340.2330385509094
2138.4869049623203
2261.228463640172
2250.484656008714
2218.667899665381
2261.2286941938823
2235.3583786117624
2271.9722259243836
2250.484738677493
2266.6002838041245
2239.7394796717276
2239.740600092775
2214.004882352235
2239.7398174590403
2228.996253560646
2245.1116615145834
2245.111798687272
2013.9241525181803
2088.5435084169303
2359.8772721755613
2381.6246222856416
2120.240876141993
2147.481

631.1666870741731
647.8175715637988
625.4764714656062
639.4921568337579
618.177236590213
702.7914388797531
707.0335872861032
707.0333433213821
683.49198383847
647.8176902657561
647.8176921384504
639.4921853275231
656.1432882264775
601.8695689616322
593.7123352251282
610.0267934533457
605.94816877485
647.8174586723759
651.9802760426444
647.8175363851502
651.98030759997
639.4922480006869
643.6548983639458
656.1431494643642
565.4482582714933
639.4921938851242
631.1666889468673
647.8175734364929
639.4921631771061
643.6548964912529
643.6548296038515
639.4921568337579
685.6549288860994
620.586574606501
702.7914388797531
715.5172550896959
698.5495249181077
698.5496117235505
702.7915488430668
707.0333433213821
291.3145874468275
292.8472904839249
342.14941677046215
304.98436816193396
304.98436871967795
298.1494768083683
301.56688444335157
294.7319925320431
308.4017732471552
349.2612078485515
298.14949548255237
298.1494793581371
291.3145880045715
298.14947020846046
343.56691982985575
318.5033339

7869.63170145942
8110.931213501719
8170.830989300564
8211.665427132495
8059.002089289409
7869.311653411932
8288.89137158374
8302.757292377662
8259.444743277149
8179.907526667828
8295.82659469209
8302.758847039242
8229.915806888375
8302.76194869457
7806.790491779411
7491.805818414753
7765.309365658246
7505.539695927185
7794.299296212677
7842.3721089437695
7912.455740455892
7883.194364339084
7748.153935738413
7884.892724856906
7777.515282901198
7794.307113893879
7775.352416581431
7933.130273489176
7864.21149153936
7891.784730160149
7852.138937097196
7912.455083187376
7898.670654064103
7465.7229506757685
7726.341704475986
7905.561329292591
7919.348631048547
7891.781352083719
7808.210046561931
7842.377978545182
7905.563248334104
7842.374586525525
7905.561857841798
7797.818017446439
7775.352416581431
7878.000092098408
7861.892237823919
7891.781987413839
7891.787426815292
7919.349043210643
7884.891404640698
7891.792462197632
180.80049853844434
176.99482896227167
193.32495089987273
175.617805

226.3833421786679
229.80085483833764
229.80085483833764
229.80085483833764
186.79291415422873
178.74148881562797
186.79291415422873
193.32498026569755
186.79291415422873
233.21820478875847
229.80085483833764
206.40242108437263
146.91451125413715
148.98985301572267
190.05888696756816
180.2608104350186
190.0589577282409
180.80046898760338
66.57791782985626
37.99999999999993
82.71249029228447
80.0491957243049
82.71250571486088
80.0491957243049
82.71250571486088
82.71250571486088
82.71250571486088
49.946532878549974
49.946532878549974
54.69835898180824
49.946532878549974
85.37578938421862
82.71250571486088
49.946532878549974
82.71249029228447
74.72259672160747
82.71250571486088
49.94652780101622
47.5706163735961
49.946532878549974
47.5706163735961
49.946532878549974
49.946532878549974
49.946532878549974
25.999999999999947
52.322440717854064
49.946532878549974
49.94652780101622
49.946532878549974
25.99999999999995
25.999999999999947
25.999999999999947
25.999999999999947
25.999999999999947
2

1639.2769937965386
1884.955860392655
1890.111039166122
1808.7898187683493
1693.7897552570073
1809.771433557905
1751.111774006851
1813.905232571224
1808.7898187683493
1819.0208085872753
1798.5574289939611
3073.316785452192
3085.4518700612875
3123.662367218357
3116.968844254887
3190.557384138414
3166.691338230569
3184.784784361252
3167.4666107059243
3196.330763319557
3179.013498988273
3173.2376701731155
3179.0118479408357
3166.6907554321465
3106.929023158614
3179.0124453523276
3202.1034158355656
3184.78538154781
3179.012375970783
3070.2857281317524
3073.316785452192
2980.132712643032
3207.876515936457
3213.6485785043233
3179.012375970783
2566.60013229995
2293.0498159622366
2599.773324571106
2504.9421981795344
2415.711323046088
2588.714633263746
2583.1862036257053
2566.599256238388
2236.4235598147498
2500.0298442838157
2524.3923357728645
2541.887179199078
2583.185903169493
2984.63542655713
2878.7413537811103
2843.4051036570463
2722.0302441684785
2930.6852988338046
2924.0967940632027
3013.

78.45747951590891
76.08156664885776
76.08156664885776
78.45747951590891
76.08156664885776
76.08156664885776
45.9999999999999
76.08156664885776
78.45747951590891
76.08156664885776
76.08156664885776
45.99999999999991
53.999999999999886
45.99999999999991
87.96112775920548
76.08156664885776
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
410.7149396862737
395.55984245139484
376.7539141889521
365.33040853951263
376.7539141889521
922.6804843792349
912.2656914436952
979.909869746007
966.1696791381688
912.2656914436952
931.2313681708102
922.6804843792349
979.9095815516256
828.3059731479274
745.6396518515332
826.480992843612
832.6932253511493
828.3059731479274
828.3059731479274
779.1394805334
800.0812128249837
841.4682623091614
845.8554946182511
398.3136030482024
380.3169307400662
405.6659805639326
398.3136030482024
1003.6627441755817
1039.0696722988368
1034.431315336855
904.8234637321937
978.2470860747335
1039.06978263738
1043.7080401872756
970.7500982813268
889.5241825636352
975.4607121161421
970.7500982813268

2986.476054379626
2992.3464315611773
405.6658573449536
431.3991762881495
396.06970988084714
427.72285994932355
414.579759277585
420.37055455765517
377.94318700588934
431.3990974056606
405.6658573449536
349.2765900045711
334.9893541216245
349.2765900045711
431.3990974056606
427.72285994932355
405.44330358857565
396.06970988084714
413.0181727320402
352.8139245007762
150.21659038746304
151.9581806190295
151.9581806190295
151.9581806190295
151.9581806190295
141.3050039762308
349.2765900045711
341.2656309834623
381.2390910192823
405.44330358857565
377.94318700588934
427.72285994932355
424.0468141006904
354.0215608954943
331.8947265304157
367.0265697861854
331.8947265304157
374.13280744898
361.23632760311784
331.8947265304157
198.50456184882202
184.85223774086063
192.03566572219822
189.22052250420663
183.03117350259504
198.50453634070544
184.85223774086063
183.03117350259504
184.85223774086063
198.50453634070544
187.43869066552784
184.85223774086063
183.03117350259504
183.03117350259504
184.

1020.2218879237978
893.1555394306674
1515.9115839470817
1480.9997872482402
1521.2165105699396
1537.1260887367002
1526.5201902429808
1371.3249981178997
1526.519237891569
1537.1261882009655
1515.9129892515853
1401.068890928793
1444.6146605561578
1467.23221700229
1463.3364867608802
1468.6064711675679
1489.6764866564502
1403.9933335604928
1473.8729559676765
1484.4087376557925
1010.684238671769
1104.6605754694772
1072.1685289627455
1064.7791870405645
1016.406856332015
1028.2099210381537
996.0044786949504
1055.1195600272338
1069.7244949722365
1049.9460627862723
1084.5578473597493
1054.8912885822567
441.92243679130905
471.3537380348401
467.2746826760256
437.4998193718773
479.5105345112323
483.58896287443997
471.3534989668129
1516.0928918602895
1496.2747988766191
1542.1042871471925
1536.7659801056811
1448.7367849753423
1552.7800737536356
1454.7888476185572
1536.7657902129379
1520.6163467104636
1485.9131390987488
1499.5760103891446
1536.7656425984264
1531.4271630440103
1520.7540344028052
1542.1

233.2178447526791
246.8875656183652
243.47018303263783
236.63535774853938
246.88759496605394
236.63535774853938
223.7848952543069
230.102647859165
233.2178447526791
236.63535774853938
39.999999999999915
33.999999999999936
88.03905102753009
88.03905102753009
98.69224582566173
88.03905102753009
88.03905102753009
31.99999999999994
31.99999999999994
88.03905102753009
98.69224790220434
88.03905102753009
88.03905102753009
96.02894869413608
99.4501702867208
88.03905102753009
88.03905102753009
0
59.45017028672088
59.450170970367544
57.07425801529478
59.45017028672088
31.99999999999994
31.99999999999994
31.99999999999994
0
57.07426610299059
57.07426610299059
57.07426610299059
57.07426610299059
57.07426610299059
57.07426610299059
57.07426610299059
57.07426610299059
64.20199846924308
57.07426610299059
57.07426610299059
57.07426610299059
57.07426610299059
57.07426610299059
57.07426610299059
64.20199803670155
57.07426610299059
64.2019995997647
57.07426610299059
57.07426610299059
57.07426610299059
3

0.5391333333333334

In [36]:
for folder in ["400cells_base_higher_herit"]:
#     folder = "400cells_" + drop + "_drop"
    path = "/data/yosef2/users/richardz/projects/benchmarking/" + folder + "/"
    nums = []
    triplets = []
    triplets_new = []
    colless = []
    types = []
    methods = []

    for method in ["egreedy", "greedy"]:
        for num in range(0, 10):
            dropout_cm = pd.read_csv(path + "dropout_cm" + str(num) + ".txt", sep = '\t', index_col = 0)
            dropout_cm = dropout_cm.applymap(str)

            samples = []
            for index, row in dropout_cm.iterrows():
                node = Node(list(row), 1001, parent=None, left=None, right=None)
                samples.append(node)

            subset = list(range(len(samples)))
            prune_samples = remove_duplicates(samples, subset)

            node_map = {}
            for i in prune_samples:
                node_map[i] = list(dropout_cm.iloc[i,:])

            net = pic.load(open(path + "dropout_net" + str(num) + ".pkl", 'rb'))
            ground = net.network
            leaves = [n for n in ground if ground.out_degree(n) == 0 and ground.in_degree(n) == 1]
            ground_net_map = {}
            for i in node_map:
                for j in leaves:
                    if node_map[i] == j.char_vec:
                        ground_net_map[j] = i
                        break
            ground = nx.relabel_nodes(ground, ground_net_map)

            dropout_cm = dropout_cm.replace("-", -1)
            dropout_cm = dropout_cm.replace("*", -1)
            dropout_cm = dropout_cm.apply(pd.to_numeric)

            samples = []
            for index, row in dropout_cm.iterrows():
                node = Node(list(row), 1001, parent=None, left=None, right=None)
                samples.append(node)
        #     sample_map = {}
        #     for i in range(len(samples)):
        #         sample_map[i] = samples[i]
        
            if method == 'egreedy':
                priors = pic.load(open(path + "priors" + str(num) + ".pkl", 'rb'))
                recon = build_tree_sep(samples, method="egreedy", p = 0.007550225842780328, qs = priors)
            else:
                recon = build_tree_sep(samples, method=method)
            trip = triplets_correct_stratified(recon, ground)
            trip2 = triplets_correct_at_depth_sep(recon, ground, 'aggregate')
            print(num, trip, trip2, method)
            triplets.append(trip)
            triplets_new.append(trip2)
            nums.append(num)
            colless.append(get_colless(ground)[0])
            types.append(folder)
            methods.append(method)

    data = [nums, triplets, triplets_new, colless, methods, types]
    df = pd.DataFrame(data)
    df = df.T
    df = df.rename(columns = {0: 'Run', 1: 'TripletsCorrect', 2: 'TripletsCorrect2', 3:'Colless', 4:'Method',5: 'Type'})
#     df.to_csv(path + 'methods_triplets.txt', sep = '\t', index = False)

0 0.6633748715024531 0.5460666666666667 egreedy
1 0.7777529259411957 0.6186666666666667 egreedy
2 0.7270518244512935 0.5568166666666666 egreedy
3 0.7568456047406811 0.5622833333333334 egreedy
4 0.7438180121364876 0.5629833333333333 egreedy
5 0.7465887311411259 0.6021692307692308 egreedy
6 0.6236969603500747 0.5791428571428572 egreedy
7 0.5017815139486564 0.49557142857142855 egreedy
8 0.44327123161655124 0.6063846153846154 egreedy
9 0.661529577816303 0.5948 egreedy
0 0.8016572759285835 0.7635666666666666 greedy
1 0.8184715119477257 0.7441833333333333 greedy
2 0.8769611695135737 0.78985 greedy
3 0.824382351164874 0.7332166666666666 greedy
4 0.7138941682886175 0.7101 greedy
5 0.7754950895217146 0.7599384615384616 greedy
6 0.6378550772146505 0.6839285714285714 greedy
7 0.7740110913193782 0.7474428571428572 greedy
8 0.8156946097013457 0.7647076923076923 greedy
9 0.8289273285709466 0.7513230769230769 greedy


In [None]:
df.to_csv(path + 'egreedy_triplets.txt', sep = '\t', index = False)