In [None]:
import graspy
from mgcpy.independence_tests.mgc import MGC
from mgcpy.independence_tests.dcorr import DCorr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy as sp
%matplotlib inline
from graspy.simulations import sbm
from graspy.plot import heatmap
np.random.seed(10)

#Generating simulation/test data
def data_generator(m_n_0, ss_m_n, p1, p2, p3, p4, p5, num_graphs):
    """Simulate two alternating classes of two-block graphs with ``sbm``.

    Graphs with an even index are labelled "male" and are generated with
    ``p=[[p1, p2], [p3, p4]]``; graphs with an odd index are labelled
    "female" and use ``p=[[p5, p2], [p3, p4]]``.  Every call passes block
    sizes ``[ss_m_n, m_n_0 - ss_m_n]`` as ``n``.

    Returns
    -------
    (male, female, data) : tuple of dict
        ``male`` is keyed ``"male0", "male1", ...`` and ``female`` is keyed
        ``"female0", "female1", ...`` (running count within each class).
        ``data`` holds both classes keyed by the overall graph index
        (``"male0"``, ``"female1"``, ``"male2"``, ...).

    Note: the entry stored in ``data`` comes from a second, separate ``sbm``
    call; it is not the same array that is stored in ``male``/``female``.
    """
    male, female, data = {}, {}, {}
    for graph_idx in range(num_graphs):
        block_sizes = [ss_m_n, m_n_0 - ss_m_n]
        if graph_idx % 2:
            # Odd index -> female class.
            female_probs = [[p5, p2], [p3, p4]]
            female["female{0}".format(len(female))] = sbm(n=block_sizes, p=female_probs)
            data["female{0}".format(graph_idx)] = sbm(n=block_sizes, p=female_probs)
        else:
            # Even index -> male class.
            male_probs = [[p1, p2], [p3, p4]]
            male["male{0}".format(len(male))] = sbm(n=block_sizes, p=male_probs)
            data["male{0}".format(graph_idx)] = sbm(n=block_sizes, p=male_probs)
    return male, female, data

# Simulate 100 graphs on 200 vertices, with a first block of 20 vertices.
male, female, data = data_generator(
    m_n_0=200, ss_m_n=20, p1=.30, p2=.2, p3=.2, p4=.3, p5=.45, num_graphs=100
)

# Show one example adjacency matrix per class.
heatmap(female["female1"], title='Female Graph Simulation')

plt.figure()
heatmap(male["male1"], title='Male Graph Simulation')

#Uses L_class_tool on all entries in dictionary
def a_dict_to_p_dict(dict, dimvec, prob_mat_0, prob_mat_1, key_0, key_1, ss_n_m):
    """Map every entry of ``dict`` to an expanded probability matrix.

    For each key, ``L_class_tool`` is called with ``prob_mat_0`` when the
    key contains ``key_0`` and with ``prob_mat_1`` otherwise.  The length of
    the stored value is passed as the matrix dimension.

    ``key_1`` is accepted but never consulted: any key that does not contain
    ``key_0`` is treated as belonging to the second class.

    Returns a new dictionary with the same keys as ``dict``.
    """
    expanded = {}
    for name, adjacency in dict.items():
        class_probs = prob_mat_0 if key_0 in name else prob_mat_1
        expanded[name] = L_class_tool(dimvec, class_probs, len(adjacency), ss_n_m)
    return expanded

# Expands a 2x2 block-probability matrix to the dimensions of the adjacency matrices: e.g. with
# SS_n = 100, [[0.3, 0.2], [0.2, 0.3]] becomes a 200 by 200 matrix whose top-left 100 by 100 entries are all 0.3, etc.
def L_class_tool(dim_vec, p_matrix, dim_mat, SS_n):
    """Expand a 2x2 block-probability matrix into a full edge-probability matrix.

    Parameters
    ----------
    dim_vec : sequence of int
        Block sizes; their sum is the side length of the returned matrix.
    p_matrix : 2x2 nested sequence
        ``p_matrix[a][b]`` is the probability written into block ``(a, b)``.
    dim_mat : int
        Side length of the leading square region that is filled.
    SS_n : int
        Number of leading vertices in the first block.  Rows/columns with
        index ``< SS_n`` belong to block 0, all others to block 1.

    Returns
    -------
    numpy.ndarray
        Float matrix of shape ``(sum(dim_vec), sum(dim_vec))``.  Cells outside
        the leading ``dim_mat`` x ``dim_mat`` region are left at 0.

    Raises
    ------
    IndexError
        If ``dim_mat`` exceeds ``sum(dim_vec)``.

    Notes
    -----
    Two defects of the earlier loop-based version are fixed here:

    * Row and column ``SS_n`` itself used to fall through to the ``[1][1]``
      value because the block tests were ``> SS_n`` rather than ``>= SS_n``.
    * The buffer used to be obtained by sampling a random graph, which
      advanced the global NumPy RNG and left random 0/1 entries in any cell
      that was not overwritten.  A zero-filled buffer is used instead.
    """
    total = int(np.sum(dim_vec))
    if dim_mat > total:
        raise IndexError(
            "dim_mat ({0}) exceeds the matrix size implied by dim_vec ({1})".format(dim_mat, total)
        )

    prob = np.zeros((total, total))
    filled = max(dim_mat, 0)
    # Clamp so a negative or oversized SS_n behaves like the comparison-based
    # original (everything in block 1, or everything in block 0).
    split = min(max(SS_n, 0), filled)

    region = prob[:filled, :filled]  # view: writes land in ``prob``
    region[:split, :split] = p_matrix[0][0]
    region[:split, split:] = p_matrix[0][1]
    region[split:, :split] = p_matrix[1][0]
    region[split:, split:] = p_matrix[1][1]
    return prob

def classifier_method(num_graphs, num_graphs_0, num_graphs_1, P_hat_0, P_hat_1, vec, m_n_0, ss_m_n, A_mat):
    """Score an adjacency matrix under two class-specific Bernoulli edge models.

    Parameters
    ----------
    num_graphs : int
        Total number of graphs; denominator of both class priors.
    num_graphs_0, num_graphs_1 : int
        Number of graphs in class 0 and class 1.
    P_hat_0, P_hat_1 : 2x2 nested sequence
        Block probabilities for each class, expanded with ``L_class_tool``.
    vec : sequence of int
        Block sizes forwarded to ``L_class_tool``.
    m_n_0 : int
        Matrix side length forwarded to ``L_class_tool``.
    ss_m_n : int
        Size of the leading vertex block.  It is forwarded to
        ``L_class_tool`` as the block boundary, and the likelihood is taken
        over the leading ``ss_m_n`` x ``ss_m_n`` cells only.
    A_mat : array-like
        Adjacency matrix to score.

    Returns
    -------
    numpy.ndarray
        ``[prior_0 * likelihood_0, prior_1 * likelihood_1]``.

    Notes
    -----
    The loops previously iterated directly over ``ss_m_n``.  That value is an
    integer (it is used as an index threshold by ``L_class_tool``), so the
    iteration raised ``TypeError``.  The likelihood now runs over indices
    ``0 .. ss_m_n - 1``.
    """
    pi_coef_hat_0 = num_graphs_0 / num_graphs
    pi_coef_hat_1 = num_graphs_1 / num_graphs

    P_0 = L_class_tool(vec, P_hat_0, m_n_0, ss_m_n)
    P_1 = L_class_tool(vec, P_hat_1, m_n_0, ss_m_n)

    # Restrict every matrix to the leading ss_m_n x ss_m_n cells.
    signal = (slice(0, ss_m_n), slice(0, ss_m_n))
    edges = np.asarray(A_mat)[signal]

    def likelihood(prob_matrix):
        # Product over cells of p**a * (1 - p)**(1 - a).
        probs = np.asarray(prob_matrix)[signal]
        return np.prod((probs ** edges) * ((1 - probs) ** (1 - edges)))

    class_0 = pi_coef_hat_0 * likelihood(P_0)
    class_1 = pi_coef_hat_1 * likelihood(P_1)
    return np.asarray([class_0, class_1])

def iterative_screen(a_dict, mat_n_m, y_labels, c, delta, opt):
    '''
    Screen vertices by the dependence between their adjacency rows and labels.

    For each vertex ``i``, row ``i`` of every graph in ``a_dict`` is stacked
    into a ``(len(a_dict), mat_n_m)`` matrix and an independence test
    statistic is computed against ``y_labels``.  Vertices whose absolute
    statistic exceeds ``c`` are reported.

    Parameters
    ----------
    a_dict: dictionary of adjacency matrices.  Graph ``j`` is looked up as
        ``"male{j}"`` when ``j`` is even and ``"female{j}"`` when ``j`` is odd,
        for ``j`` in ``range(len(a_dict))``.

    mat_n_m: length/width of the adjacency matrices in the dictionary.

    y_labels: the vector of labels, passed unchanged to the test.

    c: the correlation threshold value.

    delta: the quantile threshold value (currently unused).

    opt: ``"mgc"`` selects MGC; any other value selects unbiased DCorr.

    Returns
    -------
    values_flags: the result of ``np.nonzero`` on the ``(mat_n_m, 1)``
        boolean array ``|statistic| > c``, i.e. a tuple of two index arrays.
        The first holds the selected vertex indices; the second is all zeros.
    '''
    # Only the requested test object is needed.
    if opt == "mgc":
        tester = MGC()
    else:
        tester = DCorr(which_test='unbiased')

    count = len(a_dict)
    S_hat = np.zeros((mat_n_m, 1))
    for i in range(mat_n_m):
        # Row i of every graph, one graph per row of mat_hat.
        mat_hat = np.zeros((count, mat_n_m))
        for j in range(count):
            if j % 2 == 0:
                mat = a_dict["male{0}".format(j)]
            else:
                mat = a_dict["female{0}".format(j)]
            mat_hat[j] = mat[i]
        c_u, independence_test_metadata = tester.test_statistic(mat_hat, y_labels)
        S_hat[i][0] = c_u

    S_hat = np.absolute(S_hat)
    # Kept from the original: echoes the per-vertex statistics for inspection.
    print(S_hat)
    values_flags = np.nonzero(S_hat > c)
    return values_flags
    
def vertex_diff(a_matrix_1, a_matrix_2):
    """Return the sorted row indices at which the two matrices differ.

    Row ``i`` is reported when ``a_matrix_1[i][k] != a_matrix_2[i][k]`` for
    at least one ``k`` in ``range(len(a_matrix_1))``.  Only the first
    ``len(a_matrix_1)`` columns are compared, and scanning of a row stops at
    its first mismatch.
    """
    size = len(a_matrix_1)
    differing_rows = [
        row
        for row in range(size)
        if any(a_matrix_1[row][col] != a_matrix_2[row][col] for col in range(size))
    ]
    return sorted(differing_rows)