# Import

In [113]:
import csv
import os
import pickle
import sys

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from tqdm import tqdm

## Add configuration and helper-module paths

In [114]:
# Make the project's config/util modules and the plotting helpers importable.
# NOTE(review): the first two entries are absolute paths tied to this container's
# layout (/home/jovyan) — consider deriving them from an environment variable.
sys.path.extend(
    [
        "/home/jovyan/core/config/",
        "/home/jovyan/core/util/",
        "../PlotFunction/lineplot/",
        "../PlotFunction/config/",
    ]
)

In [115]:
from ALL import config
from line_plot_1 import line_plot_1
from line_plot_error_1 import line_plot_error_1
from line_plot_1_layout import layout
from util import *

## Set conditions

In [116]:
# Register tqdm's pandas integration (e.g. `progress_apply`) and widen the
# number of rows/columns pandas shows in notebook output.
tqdm.pandas()
for _option, _value in (("display.max_columns", 100), ("display.max_rows", 50)):
    pd.set_option(_option, _value)

In [117]:
# Dataset under analysis and the embedding methods compared on it.
data_type = "AgNews"
vectorize_types = [
    "doc2vec",
    "sentenceBERT",
]

In [118]:
# GMM experiment settings; `config` is the project configuration imported from `ALL`.
# Number of model runs per setting: stats files 0.csv .. {model_nums - 1}.csv are read.
model_nums = config["clustering"]["gmm"]["max_model_num"]
normalization = "normalized"
# Fixed explicit list (not config["clustering"]["gmm"]["covariance_types"]): this order
# becomes the column order of the comparison tables built below.
covariance_types = ["spherical", "diag", "full"]
vector_dims = {
    "doc2vec": config["vectorize"]["doc2vec"]["dims"],
    # 384 is appended to the configured sentenceBERT dims (presumably the model's
    # native output size — TODO confirm).
    "sentenceBERT": config["vectorize"]["sentenceBERT"]["dims"] + [384],
}

In [119]:
# Statistic columns extracted from every GMM stats CSV
# (presumably AIC, BIC, mutual information and log-likelihood).
stats_vals = [
    "aic",
    "bic",
    "mi",
    "logl",
]

# Stats

## Read data

In [120]:
def load_stats_data(
    vectorize_type,
    stats_vals,
    covariance_types,
    model_nums,
    data_type=None,
    normalization=None,
    base_dir="../../Postprocessing/data",
):
    """Load per-run GMM statistics for one embedding method.

    Reads ``{base_dir}/{data_type}/{vectorize_type}/GMM/stats/{normalization}/
    {covariance_type}/{model_num}.csv`` (first column used as the index) for every
    covariance type and every ``model_num`` in ``range(model_nums)``. It then splits
    each file into the requested statistic columns.

    Args:
        vectorize_type: embedding-method directory name (e.g. ``"doc2vec"``).
        stats_vals: column names to extract from every CSV (e.g. ``["aic", "mi"]``).
        covariance_types: GMM covariance-type directory names to read.
        model_nums: number of model runs; files ``0.csv`` .. ``{model_nums - 1}.csv``
            are read.
        data_type: dataset directory name. Defaults to the notebook-level ``data_type``.
        normalization: normalization directory name. Defaults to the notebook-level
            ``normalization``.
        base_dir: root directory of the post-processed data.

    Returns:
        Nested dict ``stats[stats_val][covariance_type][model_num]`` -> ``pd.Series``
        holding column ``stats_val`` of the corresponding CSV.

    Raises:
        FileNotFoundError: if an expected CSV file is missing.
        KeyError: if a CSV lacks one of ``stats_vals``.
    """
    # Fall back to the notebook-level settings so existing call sites keep working.
    if data_type is None:
        data_type = globals()["data_type"]
    if normalization is None:
        normalization = globals()["normalization"]

    stats = {
        stats_val: {covariance_type: {} for covariance_type in covariance_types}
        for stats_val in stats_vals
    }
    if not stats:
        # No statistics requested: nothing to read.
        return stats

    for covariance_type in covariance_types:
        for model_num in range(model_nums):
            stats_path = (
                f"{base_dir}/{data_type}/{vectorize_type}/GMM/stats/"
                f"{normalization}/{covariance_type}/{model_num}.csv"
            )
            # Each CSV contains every statistic column, so read it once and split it
            # instead of re-reading the same file for each statistic.
            df = pd.read_csv(stats_path, index_col=0)
            for stats_val in stats_vals:
                stats[stats_val][covariance_type][model_num] = df.loc[:, stats_val]
    return stats

In [121]:
def load_lda_mi(data_type):
    """Read the LDA mutual-information table for ``data_type`` and summarize it.

    The CSV is read with its first column as the index and then passed to
    ``get_describe`` (project ``util`` module) with ``axis=0``. Only the first value
    ``get_describe`` returns is kept; the second is discarded.
    """
    mi_path = f"../../Postprocessing/data/{data_type}/LDA/mi.csv"
    lda_mi_table = pd.read_csv(mi_path, index_col=0)
    summary, _unused_keys = get_describe(lda_mi_table, axis=0)
    return summary

In [122]:
# Raw GMM statistics for every embedding method:
# stats_dict[vectorize_type][stats_val][covariance_type][model_num] -> pd.Series
stats_dict = {
    vectorize_type: load_stats_data(
        vectorize_type, stats_vals, covariance_types, model_nums
    )
    for vectorize_type in vectorize_types
}

In [123]:
# Summary of the LDA mutual-information table for the current dataset.
# NOTE(review): not referenced by any later cell shown in this notebook.
describe_lda_mi = load_lda_mi(data_type=data_type)

## Data shaping

In [124]:
def shape_stats_df(stats):
    """Join the per-run statistic Series into one DataFrame per statistic/covariance type.

    Args:
        stats: nested mapping ``stats[stats_val][covariance_type][model_num]`` ->
            ``pd.Series`` (as returned by ``load_stats_data``).

    Returns:
        Nested dict ``stats_df[stats_val][covariance_type]`` -> ``pd.DataFrame``. Its
        columns are the ``model_num`` keys and its index is the Series index.
    """
    # Iterate over the keys actually present in ``stats`` (not the notebook-level
    # ``stats_vals`` / ``covariance_types`` globals) so the result depends only on the input.
    return {
        stats_val: {
            # model_num keys become the columns of the joined table
            covariance_type: pd.concat(per_model, axis=1)
            for covariance_type, per_model in per_covariance.items()
        }
        for stats_val, per_covariance in stats.items()
    }

In [125]:
def shape_describe(stats_df):
    """Summarize every statistic/covariance-type table with ``get_describe``.

    ``get_describe`` comes from the project ``util`` module (``from util import *``).
    It is called with ``axis=1`` on each table and returns a pair: the first item is
    stored per table and the second item is reported back as ``describe_keys``.

    Args:
        stats_df: nested dict ``stats_df[stats_val][covariance_type]`` ->
            ``pd.DataFrame`` (as built by ``shape_stats_df``).

    Returns:
        ``(describe, describe_keys)``.
        ``describe[stats_val][covariance_type]`` is the first value returned by
        ``get_describe`` for that table.
        ``describe_keys`` is the second value from the last call. It is presumably the
        same list of summary names (e.g. ``"mean"``) for every table — TODO confirm in
        util. It is ``[]`` when ``stats_df`` contains no tables.
    """
    describe = {}
    # Initialized so an empty ``stats_df`` returns cleanly instead of raising
    # UnboundLocalError.
    describe_keys = []
    # Iterate over the keys present in ``stats_df`` (not the notebook-level
    # ``stats_vals`` / ``covariance_types`` globals) so the result depends only on the input.
    for stats_val, tables in stats_df.items():
        describe[stats_val] = {}
        for covariance_type, table in tables.items():
            describe[stats_val][covariance_type], describe_keys = get_describe(
                table, axis=1
            )
    return describe, describe_keys

In [126]:
def shape_data(describe, describe_keys):
    """Regroup summaries so that each (statistic, summary) pair is one table.

    Args:
        describe: nested mapping ``describe[stats_val][covariance_type][describe_key]``.
            The values are presumably ``pd.Series`` indexed by the stats-table rows
            (TODO confirm against ``get_describe``).
        describe_keys: summary names to extract (e.g. ``"mean"``).

    Returns:
        Nested dict ``data[stats_val][describe_key]`` -> ``pd.DataFrame`` with one
        column per covariance type, in the order the covariance types appear in
        ``describe``.
    """
    # Iterate over the keys present in ``describe`` (not the notebook-level
    # ``stats_vals`` / ``covariance_types`` globals) so the result depends only on the input.
    return {
        stats_val: {
            describe_key: pd.concat(
                {
                    covariance_type: summaries[describe_key]
                    for covariance_type, summaries in per_covariance.items()
                },
                axis=1,
            )
            for describe_key in describe_keys
        }
        for stats_val, per_covariance in describe.items()
    }

In [127]:
# data_stats[vectorize_type][stats_val][describe_key] -> DataFrame with one column per
# covariance type, built for every embedding method.
# The loop variables (stats, stats_df, describe, describe_keys) remain bound afterwards
# to the values of the last method processed.
data_stats = {}
for vectorize_type in stats_dict:
    stats = stats_dict[vectorize_type]
    stats_df = shape_stats_df(stats)
    describe, describe_keys = shape_describe(stats_df)
    data_stats.update({vectorize_type: shape_data(describe, describe_keys)})

In [128]:
# Per-run MI table for the "full" covariance type (columns = model runs), with the
# maximum of each column highlighted. It is rebuilt explicitly from `stats_dict` rather
# than read from a loop variable left over by an earlier cell, so the method shown
# is stated here.
inspect_vectorize_type = vectorize_types[-1]
shape_stats_df(stats_dict[inspect_vectorize_type])["mi"]["full"].style.highlight_max(axis=0)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29
2,0.411392,0.369643,0.369701,0.369692,0.369743,0.379131,0.369695,0.411385,0.369695,0.36963,0.411392,0.37891,0.411392,0.369675,0.379047,0.369643,0.369695,0.383159,0.378996,0.378952,0.383186,0.369701,0.369695,0.369695,0.36963,0.369695,0.411392,0.36969,0.379256,0.379014
4,0.60905,0.60905,0.609073,0.611119,0.611101,0.60905,0.60905,0.611101,0.611114,0.60908,0.609073,0.60905,0.60905,0.60905,0.60905,0.609058,0.609058,0.60908,0.611101,0.60905,0.609058,0.609058,0.60905,0.609073,0.60905,0.60905,0.611114,0.609058,0.60905,0.609058
6,0.591362,0.619362,0.6194,0.619362,0.619362,0.6194,0.619407,0.575024,0.6194,0.619392,0.539264,0.591353,0.619362,0.619362,0.619377,0.61937,0.619362,0.6194,0.6194,0.539264,0.591353,0.619385,0.619362,0.619392,0.6194,0.6194,0.591285,0.619392,0.619392,0.619392
8,0.585515,0.525922,0.583474,0.61287,0.591226,0.583474,0.61287,0.612361,0.612878,0.61287,0.612361,0.61287,0.591302,0.575703,0.575649,0.612368,0.612368,0.612811,0.583439,0.612805,0.61287,0.612781,0.583393,0.612368,0.612844,0.612376,0.575697,0.585238,0.525922,0.525922
10,0.523871,0.585483,0.607121,0.579297,0.554333,0.523871,0.607135,0.607106,0.607092,0.579297,0.607106,0.607121,0.585279,0.607121,0.607121,0.607135,0.585472,0.523871,0.607121,0.585483,0.607121,0.607121,0.607091,0.607106,0.579206,0.607106,0.579267,0.607135,0.607121,0.585472
20,0.599257,0.586119,0.599252,0.582847,0.599248,0.545441,0.599248,0.545441,0.599252,0.545441,0.599248,0.545441,0.586095,0.545441,0.599257,0.586095,0.599252,0.599248,0.599244,0.545441,0.582847,0.599252,0.586095,0.545441,0.545441,0.599257,0.599229,0.599252,0.599248,0.545441
40,0.59541,0.542874,0.59541,0.579093,0.59541,0.542874,0.59541,0.59541,0.59541,0.579093,0.59541,0.59541,0.59541,0.59541,0.542874,0.59541,0.59541,0.542874,0.542874,0.542874,0.542874,0.579093,0.59541,0.59541,0.59541,0.59541,0.59541,0.59541,0.59541,0.59541
80,0.578202,0.578202,0.578202,0.541331,0.578202,0.578202,0.578202,0.578202,0.578616,0.578209,0.578616,0.578222,0.578202,0.578202,0.578209,0.578202,0.578202,0.578202,0.578202,0.541331,0.578202,0.541331,0.578616,0.578202,0.578202,0.578202,0.578202,0.578202,0.578202,0.541331
160,0.580713,0.580488,0.583348,0.54479,0.54479,0.580713,0.580713,0.583348,0.580713,0.583348,0.580936,0.580713,0.583348,0.583348,0.580713,0.580713,0.580713,0.54479,0.580713,0.580713,0.580713,0.580713,0.54479,0.580713,0.580713,0.580713,0.580936,0.580713,0.580713,0.580713
384,0.610895,0.610894,0.610836,0.610859,0.610857,0.610907,0.61086,0.512121,0.610834,0.61092,0.610868,0.610875,0.610905,0.610857,0.610867,0.610973,0.610914,0.610844,0.610834,0.610875,0.610911,0.610891,0.610857,0.610834,0.610891,0.610857,0.610926,0.610861,0.610894,0.610896


# Make Chart

In [129]:
# doc2vec embedding dimension used for the comparison, per dataset.
# NOTE(review): magic numbers — presumably chosen from the MI tables; record the criterion.
best_dim = dict([("AgNews", 8), ("20News", 80)])

In [130]:
# Mean MI per covariance type at the chosen embedding dimension of each method.
chart_dims = {"doc2vec": best_dim[data_type], "sentenceBERT": 384}
chart_data = {
    method: data_stats[method]["mi"]["mean"].loc[dim, :]
    for method, dim in chart_dims.items()
}

In [131]:
# Preview: rows = covariance types, columns = embedding methods.
pd.DataFrame.from_dict(chart_data, orient="columns")

Unnamed: 0,doc2vec,sentenceBERT
spherical,0.454906,0.582227
diag,0.429935,0.585635
full,0.475413,0.607587


In [132]:
# Comparison table reused by the export cells below.
chart_df = pd.DataFrame.from_dict(chart_data)

In [133]:
# Preview of the exported table (rows = embedding methods, columns = covariance types).
# All three columns are float MI scores, so no column-specific formatter is needed;
# labels are LaTeX-escaped as in the export cell.
chart_df.T.style.format(escape="latex")

Unnamed: 0,spherical,diag,full
doc2vec,0.454906,0.429935,0.475413
sentenceBERT,0.582227,0.585635,0.607587


In [134]:
# Persist the comparison table (rows = covariance types, columns = embedding methods).
covariance_chart_path = make_filepath(f"../data/{data_type}/CovarianceChart.csv")
chart_df.to_csv(covariance_chart_path)

In [135]:
# Render the comparison (rows = embedding methods, columns = covariance types) as a
# LaTeX table with three decimals; print() emits the raw LaTeX source.
latex_styler = chart_df.T.style.format(precision=3, escape="latex")
latex_table = latex_styler.to_latex(
    column_format="rrrr",  # index column + three covariance-type columns
    position="h",
    position_float="centering",
    hrules=True,
    caption="miの比較と埋め込み次元",
    label="table:1",
    multicol_align="r",
)
print(latex_table)

\begin{table}[h]
\centering
\caption{miの比較と埋め込み次元}
\label{table:1}
\begin{tabular}{rrrr}
\toprule
 & spherical & diag & full \\
\midrule
doc2vec & 0.455 & 0.430 & 0.475 \\
sentenceBERT & 0.582 & 0.586 & 0.608 \\
\bottomrule
\end{tabular}
\end{table}

