# Import

In [1]:
import csv
import os
import pickle
import sys

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from tqdm import tqdm

## Add configuration file

In [2]:
sys.path.append("/home/jovyan/core/config/")
sys.path.append("/home/jovyan/core/util/")
sys.path.append("../PlotFunction/lineplot/")
sys.path.append("../PlotFunction/config/")

In [3]:
from ALL import config
from line_plot_1 import line_plot_1
from line_plot_error_1 import line_plot_error_1
from line_plot_1_layout import layout
from util import *

## Set condition

In [4]:
tqdm.pandas()
pd.set_option("display.max_columns", 100)
pd.set_option("display.max_rows", 50)

In [30]:
data_type = "AgNewsTitle"
vectorize_types = ["doc2vec", "sentenceBERT"]

# Read data

In [31]:
model_nums = config["clustering"]["gmm"]["max_model_num"]
covariance_types = config["clustering"]["gmm"]["covariance_types"]

In [32]:
stats_vals = ["aic", "bic", "mi", "logl"]

In [33]:
def load_stats_data(vectorize_type, stats_vals, covariance_types, model_nums):
    # データ型定義
    stats = {
        stats_val: {
            covariance_type: {
                model_num: pd.DataFrame for model_num in range(model_nums)
            }
            for covariance_type in covariance_types
        }
        for stats_val in stats_vals
    }

    # データ取得
    for stats_val in stats_vals:
        for covariance_type in covariance_types:
            for model_num in range(model_nums):
                stats_path = f"../../Postprocessing/data/{data_type}/{vectorize_type}/GMM/stats/{covariance_type}/{model_num}.csv"
                df = pd.read_csv(stats_path, index_col=0)
                stats[stats_val][covariance_type][model_num] = df.loc[:, stats_val]
    return stats

In [34]:
def load_lda_mi(data_type):
    lda_mi = pd.read_csv(
        f"../../Postprocessing/data/{data_type}/LDA/mi.csv", index_col=0
    )
    describe_lda_mi, _ = get_describe(lda_mi, axis=0)
    return describe_lda_mi

In [35]:
stats_dict = {}
for vectorize_type in vectorize_types:
    stats_dict[vectorize_type] = load_stats_data(
        vectorize_type, stats_vals, covariance_types, model_nums
    )

In [36]:
describe_lda_mi = load_lda_mi(data_type)

# Data shaping

In [37]:
def shape_stats_df(stats):
    stats_df = {
        stats_val: {covariance_type: pd.DataFrame() for covariance_type in covariance_types}
        for stats_val in stats_vals
    }

    for stats_val in stats_vals:
        for covariance_type in covariance_types:
            # model_numについて取得データstatsを結合
            stats_df[stats_val][covariance_type] = pd.concat(
                stats[stats_val][covariance_type], axis=1
            )
    return stats_df

In [38]:
def shape_describe(stats_df):
    describe = {
        stats_val: {covariance_type: dict() for covariance_type in covariance_types}
        for stats_val in stats_vals
    }

    for stats_val in stats_vals:
        for covariance_type in covariance_types:
            describe[stats_val][covariance_type], describe_keys = get_describe(
                stats_df[stats_val][covariance_type], axis=1
            )
    return describe, describe_keys

In [39]:
def shape_data(describe, describe_keys):
    data = {
        stats_val: {describe_key: pd.DataFrame() for describe_key in describe_keys}
        for stats_val in stats_vals
    }
    for stats_val in stats_vals:
        for describe_key in describe_keys:
            # covariance_typeについてデータを結合
            _data = {
                covariance_type: describe[stats_val][covariance_type][describe_key]
                for covariance_type in covariance_types
            }
            data[stats_val][describe_key] = pd.concat(_data, axis=1)
    return data

In [40]:
data = {}
for vectorize_type, stats in stats_dict.items():
    stats_df = shape_stats_df(stats)
    describe, describe_keys = shape_describe(stats_df)
    data[vectorize_type] = shape_data(describe, describe_keys)

In [41]:
stats_df["mi"]["diag"].style.highlight_max(axis=1)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29
2,0.314672,0.31431,0.314344,0.314664,0.314676,0.314345,0.314336,0.314284,0.314685,0.31431,0.314416,0.314676,0.314409,0.314385,0.31468,0.314672,0.314453,0.314676,0.31468,0.314676,0.314284,0.314484,0.314685,0.314409,0.31468,0.31467,0.314688,0.314416,0.31431,0.314279
4,0.503743,0.484292,0.503763,0.484272,0.484292,0.503743,0.484272,0.484339,0.503717,0.503717,0.503743,0.503743,0.503743,0.484272,0.503743,0.503733,0.484421,0.503733,0.503717,0.484269,0.503743,0.503753,0.484421,0.503763,0.484421,0.503723,0.48441,0.503733,0.503743,0.503753
6,0.509858,0.509705,0.509712,0.488431,0.488431,0.488431,0.509802,0.490024,0.490089,0.509674,0.509712,0.490111,0.489811,0.509712,0.490138,0.509714,0.488431,0.490141,0.509682,0.509712,0.49014,0.509712,0.509714,0.509712,0.509661,0.488431,0.509674,0.490134,0.509718,0.490187
8,0.509863,0.509937,0.478602,0.509937,0.509937,0.509857,0.509926,0.509937,0.509863,0.509863,0.509863,0.509863,0.509863,0.509863,0.509937,0.509937,0.478602,0.509857,0.509937,0.509937,0.478602,0.509937,0.509857,0.509857,0.509863,0.509937,0.509863,0.509857,0.509937,0.509937
10,0.470327,0.470327,0.496485,0.470327,0.470327,0.470327,0.470327,0.496485,0.470327,0.470327,0.496454,0.470327,0.496485,0.470327,0.470327,0.470332,0.470327,0.450864,0.470332,0.496454,0.470327,0.470327,0.470327,0.450864,0.470327,0.470327,0.470327,0.470327,0.470327,0.470327
20,0.462238,0.462152,0.462227,0.462077,0.462209,0.462077,0.462248,0.462077,0.462088,0.462164,0.462082,0.462244,0.462077,0.462093,0.462141,0.462093,0.462236,0.462077,0.462088,0.462248,0.462236,0.462236,0.462077,0.46218,0.462088,0.489393,0.489374,0.462232,0.462088,0.462093
40,0.4808,0.466008,0.466008,0.4808,0.466008,0.4808,0.466008,0.466005,0.466005,0.466008,0.466008,0.466008,0.466008,0.4808,0.466008,0.466008,0.466008,0.466008,0.466005,0.466005,0.466005,0.466008,0.466008,0.466008,0.4808,0.466008,0.466005,0.4808,0.466008,0.466008
80,0.475201,0.475201,0.475201,0.475201,0.475201,0.475201,0.475201,0.475201,0.475201,0.475201,0.475201,0.475201,0.475201,0.475201,0.491251,0.475201,0.475201,0.491251,0.475201,0.475201,0.475201,0.475201,0.475201,0.475201,0.475201,0.475201,0.475201,0.475201,0.475201,0.475201
160,0.494554,0.477805,0.477805,0.477805,0.494554,0.477805,0.477805,0.477805,0.477805,0.477805,0.477805,0.477805,0.477805,0.477805,0.477805,0.477805,0.477805,0.494554,0.477805,0.494554,0.477805,0.477805,0.477805,0.477805,0.477805,0.477805,0.477805,0.477805,0.477805,0.494554


# Make Chart

In [42]:
chart_data = {}
for vectorize_type, _data in data.items():
    mi_max_idx = _data["mi"]["mean"].stack().idxmax()
    chart_data[vectorize_type] = {
        "埋め込み次元": mi_max_idx[0],
        "mutual information": _data["mi"]["mean"].loc[mi_max_idx],
        "分散": _data["mi"]["std"].loc[mi_max_idx],
    }

In [43]:
chart_data

{'doc2vec': {'埋め込み次元': 6,
  'mutual information': 0.216079607330899,
  '分散': 0.0005672788329092607},
 'sentenceBERT': {'埋め込み次元': 80,
  'mutual information': 0.5200915147894267,
  '分散': 4.230357428712255e-05}}

In [44]:
chart_df = pd.DataFrame(chart_data)

In [45]:
chart_df["LDA"] = [
    np.NaN,
    describe_lda_mi["mean"].to_numpy()[0],
    describe_lda_mi["std"].to_numpy()[0],
]

In [46]:
chart_df.T.style.format(
    escape="latex", formatter={"document_count": "{:.0f}"}
)

Unnamed: 0,埋め込み次元,mutual information,分散
doc2vec,6.0,0.21608,0.000567
sentenceBERT,80.0,0.520092,4.2e-05
LDA,,0.02258,0.008628


In [48]:
print(
    chart_df.T.style.format(precision=3, escape="latex").to_latex(
        column_format="rrrr",
        position="h",
        position_float="centering",
        hrules=True,
        caption="miの比較と埋め込み次元",
        label="table:1",
        multicol_align="r",
    )
)

\begin{table}[h]
\centering
\caption{miの比較と埋め込み次元}
\label{table:1}
\begin{tabular}{rrrr}
\toprule
 & 埋め込み次元 & mutual information & 分散 \\
\midrule
doc2vec & 6.000 & 0.216 & 0.001 \\
sentenceBERT & 80.000 & 0.520 & 0.000 \\
LDA & nan & 0.023 & 0.009 \\
\bottomrule
\end{tabular}
\end{table}

