In [None]:
import networkx as nx
from cdt.causality.graph import SAM
from cdt.data import load_dataset
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import japanize_matplotlib
import decimal
import semopy as sm
from semopy import inspector

from tqdm.notebook import tqdm
import pickle
import os

# Notebook-wide display settings: numpy prints with 5 digits of precision and
# pandas shows up to 1000 columns / rows before truncating.
np.set_printoptions(precision=5)
pd.options.display.max_columns=1000
pd.options.display.max_rows = 1000
# IPython magic (not plain Python): render matplotlib figures inline.
%matplotlib inline

import warnings
# NOTE(review): this silences every warning for the rest of the session,
# including any emitted by pandas or semopy — consider narrowing it.
warnings.filterwarnings('ignore')

In [None]:
def make_mod(sam_matrix, threshold, origin_data_path=DATA_PATH):
    '''
    Build the semopy model description (the input for SEM) from a SAM matrix.

    Parameters
    ----------
    sam_matrix: 2-D array-like (nested list, numpy array or DataFrame)
        Matrix output by SAM. Entry [i][j] is the score for an edge from
        data column i to data column j.

    threshold: float
        Threshold on the edge-generation probability; only entries strictly
        greater than it become edges.

    origin_data_path: str
        CSV of the original data. Only its header is read, to name the
        variables; the row/column order of `sam_matrix` is assumed to follow
        the CSV's column order.

    Returns
    -------
    mod: str
        One "<effect> ~ <cause>" line per edge, joined by newlines, with
        spaces in column names replaced by "_". Empty string when no entry
        exceeds the threshold.
    '''
    # Only the header is needed for variable names, so skip the data rows.
    columns = [name.replace(' ', '_')
               for name in pd.read_csv(origin_data_path, nrows=0).columns]
    # np.asarray makes the documented nested-list input work: a plain list
    # cannot be compared with `>` against a float. DataFrames/arrays behave
    # exactly as before.
    is_edge = np.asarray(sam_matrix) > threshold

    edge_model = [
        f'{columns[j]} ~ {columns[i]}'
        for i, row in enumerate(is_edge)
        for j, value in enumerate(row)
        if value
    ]
    return '\n'.join(edge_model)

def cal_fit_index(mod, origin_data_path=DATA_PATH, detail=True):
    '''
    Fit the structural equation model described by `mod` and return its fit
    indices.

    Parameters
    ----------
    mod: str
        Model description produced by make_mod.

    origin_data_path: str
        CSV with the observed data. Spaces in its column names are replaced
        by "_" so they match the variable names make_mod emits.

    detail: bool
        Currently unused: the code that pickled detailed results
        (inspector output / calc_stats) when this was True is disabled.

    Returns
    -------
    numpy.ndarray
        [rmsea, gfi, agfi, aic], each rounded to 5 decimals.
    '''
    # NOTE(review): this changes numpy's print options for the whole session
    # (overriding precision=5 set in the setup cell), not just this call.
    np.set_printoptions(precision=10, suppress=True)

    data = pd.read_csv(origin_data_path).rename(
        columns=lambda name: name.replace(' ', '_'))

    # The call order below (load_dataset -> optimize -> fit -> statistics)
    # is kept exactly as-is; the statistics are gathered from `optimizer`.
    model = sm.Model(mod)
    model.load_dataset(data)
    optimizer = sm.Optimizer(model)
    optimizer.optimize()
    model.fit(data)
    fit_stats = sm.gather_statistics(optimizer)

    indices = [fit_stats.rmsea, fit_stats.gfi, fit_stats.agfi, fit_stats.aic]
    return np.round(indices, 5)

def cal_rmsea(file_name, threshold=0.5, origin_data_path=DATA_PATH):
    '''
    Score a SAM result by structural penalties (lower is better).

    Despite the name, the RMSEA is not part of the returned score. The RMSEA
    term, and the former 0.1 fallback when fitting failed, are disabled, so
    the score is built only from the penalties listed under Returns. The SEM
    is still fitted, so a fitting error propagates to the caller.

    Parameters
    ----------
    file_name: str
        Name (without ".csv") of the SAM result file under
        f'{SAM_RESULT}/{VERSION}/'.

    threshold: float
        Threshold on the edge-generation probability.

    origin_data_path: str
        CSV with the observed data, used both for variable names and for
        fitting.

    Returns
    -------
    int or float
        0, plus 0.1 for each of: 'pH' does not appear in the model string;
        check_acycle(file_name, threshold) is truthy;
        check_pH_cause(file_name, threshold) is truthy.
    '''
    sam_matrix = pd.read_csv(f'{SAM_RESULT}/{VERSION}/{file_name}.csv')

    # Forward origin_data_path: previously both calls fell back to DATA_PATH,
    # silently ignoring a caller-supplied path.
    mod = make_mod(sam_matrix, threshold, origin_data_path=origin_data_path)

    # NOTE(review): the fit indices are unused; the call appears to be kept
    # only so that a model that cannot be fitted raises here — confirm
    # whether it can be dropped.
    cal_fit_index(mod, origin_data_path=origin_data_path, detail=False)

    score = 0

    # Substring check on the whole model text.
    if 'pH' not in mod:
        score += 0.1

    if check_acycle(file_name, threshold):
        score += 0.1

    if check_pH_cause(file_name, threshold):
        score += 0.1

    return score

# Detailed SEM display
def SEM(file_name, threshold, sam_result=SAM_RESULT, sem_result=SEM_RESULT, 
        version=VERSION, origin_data_path=DATA_PATH, show_cal_exp=False):
    '''
    Estimate causal effects with SEM, draw the path diagram and report the
    fit indices.

    Parameters
    ----------
    file_name: str
        Name (without ".csv") of the SAM result file under
        f'{sam_result}/{version}/'.

    threshold: float
        Threshold on the edge-generation probability.

    sam_result, sem_result, version: str
        Directory of SAM results, directory for SEM outputs, and the version
        sub-directory used under both.

    origin_data_path: str
        CSV of the original data, used for variable names, fit indices and
        the plotted model.

    show_cal_exp: bool
        If True, also print make_calculated_expression(...) for the SAM
        matrix.

    Returns
    -------
    tuple
        (rmsea, gfi, agfi, aic) as returned by cal_fit_index.
    '''
    graph_folder = f'{sem_result}/{version}'
    # presumably check_dir creates the directory when missing — confirm
    check_dir(sem_result)
    check_dir(graph_folder)

    df_origin = convert_df_under_bar(pd.read_csv(origin_data_path))

    sam_matrix = pd.read_csv(f'{sam_result}/{version}/{file_name}.csv')
    # Forward origin_data_path: previously make_mod and cal_fit_index fell
    # back to DATA_PATH, so a custom data path only affected the plot.
    mod = make_mod(sam_matrix, threshold, origin_data_path=origin_data_path)

    rmsea, gfi, agfi, aic = cal_fit_index(mod, origin_data_path=origin_data_path)
    print(f'RMSEA: {rmsea}')

    # Fit a separate model on the same data for the path diagram with
    # estimates; the file name encodes RMSEA and threshold without dots.
    model = sm.Model(mod)
    model.fit(df_origin)
    png_path = (f'{graph_folder}/rmsea{str(rmsea).replace(".", "")}'
                f'_thr{str(threshold).replace(".", "")}_{file_name}.png')
    display(sm.semplot(model, png_path, plot_ests=True))

    if show_cal_exp:
        print(make_calculated_expression(sam_matrix, threshold))

    return rmsea, gfi, agfi, aic