# Import

In [1]:
import csv
import os
import pickle
import sys

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from tqdm import tqdm

## Add configuration file

In [2]:
# Extend the module search path so the project-local imports in the next cell resolve.
# NOTE(review): the first two entries are absolute, machine-specific paths.
sys.path.extend(
    [
        "/home/jovyan/core/config/",
        "/home/jovyan/core/util/",
        "../PlotFunction/lineplot/",
        "../PlotFunction/config/",
    ]
)

In [3]:
from ALL import config
from line_plot_1 import line_plot_1
from line_plot_error_1 import line_plot_error_1
from line_plot_1_layout import layout
from util import *

## Set condition

In [80]:
# Notebook display limits for wide/long DataFrames, then tqdm's pandas integration.
pd.set_option("display.max_rows", 50)
pd.set_option("display.max_columns", 100)
tqdm.pandas()

In [101]:
# Dataset under analysis; interpolated into the input and output paths used below.
data_type = "AgNewsTitle"
# Embedding methods to compare; one stats dict is loaded per entry.
vectorize_types = ["doc2vec", "sentenceBERT"]

In [102]:
# Upper bound of the model index: files 0 .. model_nums - 1 are read per covariance type.
model_nums = config["clustering"]["gmm"]["max_model_num"]
# Sub-directory of the stats folder to read from.
normalization = "normalized"
# Hard-coded list of covariance types for this notebook. The config value
# (config["clustering"]["gmm"]["covariance_types"]) used to be read here and was
# immediately overwritten by this list, so that dead assignment has been dropped.
covariance_types = ["spherical", "diag", "full"]
# Embedding dimensions per vectorizer; 384 is appended for sentenceBERT.
# NOTE(review): vector_dims is not referenced again in the cells shown here.
vector_dims = {
    "doc2vec": config["vectorize"]["doc2vec"]["dims"],
    "sentenceBERT": config["vectorize"]["sentenceBERT"]["dims"] + [384],
}

In [103]:
# Column names pulled from each stats CSV by load_stats_data.
stats_vals = ["aic", "bic", "mi", "logl"]

# Stats

## Read data

In [104]:
def load_stats_data(vectorize_type, stats_vals, covariance_types, model_nums):
    """Load per-model GMM statistics for one embedding method.

    For every covariance type and every ``model_num`` in ``range(model_nums)``
    the file ``.../{covariance_type}/{model_num}.csv`` is read exactly once and
    split into the requested statistic columns.

    The path also depends on the notebook-level globals ``data_type`` and
    ``normalization``.

    Args:
        vectorize_type: Embedding method name used in the path.
        stats_vals: Column names to extract from each CSV.
        covariance_types: Covariance-type sub-directories to read.
        model_nums: Number of model files per covariance type.

    Returns:
        dict: ``stats[stats_val][covariance_type][model_num]`` holding the
        ``stats_val`` column of the corresponding CSV.

    Raises:
        FileNotFoundError: If one of the CSV files does not exist.
        KeyError: If a requested column is missing from a CSV.
    """
    # Result skeleton, keyed in the same order as the arguments.
    stats = {
        stats_val: {covariance_type: {} for covariance_type in covariance_types}
        for stats_val in stats_vals
    }
    if not stats:
        # Nothing was requested, so no file needs to be touched.
        return stats

    # Read each file once, then fan the columns out to every requested statistic.
    for covariance_type in covariance_types:
        for model_num in range(model_nums):
            stats_path = f"../../Postprocessing/data/{data_type}/{vectorize_type}/GMM/stats/{normalization}/{covariance_type}/{model_num}.csv"
            df = pd.read_csv(stats_path, index_col=0)
            for stats_val in stats:
                stats[stats_val][covariance_type][model_num] = df.loc[:, stats_val]
    return stats

In [105]:
def load_lda_mi(data_type):
    """Read the LDA ``mi.csv`` of ``data_type`` and return its summary.

    The summary is the first of the two values returned by
    ``get_describe(..., axis=0)``; the second value is discarded.
    """
    mi_table = pd.read_csv(
        f"../../Postprocessing/data/{data_type}/LDA/mi.csv", index_col=0
    )
    summary, _unused = get_describe(mi_table, axis=0)
    return summary

In [106]:
# One nested stats dict per embedding method, keyed by the method name.
stats_dict = {
    vectorize_type: load_stats_data(
        vectorize_type, stats_vals, covariance_types, model_nums
    )
    for vectorize_type in vectorize_types
}

In [107]:
# Summary of the LDA MI scores for the selected dataset.
# NOTE(review): not referenced again in the cells shown here.
describe_lda_mi = load_lda_mi(data_type)

## Data shaping

In [108]:
def shape_stats_df(stats):
    """Combine the per-model series of every setting into one DataFrame.

    The keys are taken from ``stats`` itself, so the result does not depend on
    the notebook-level ``stats_vals`` / ``covariance_types`` variables.

    Args:
        stats: Nested mapping ``stats[stats_val][covariance_type][model_num]``
            of pandas objects, as returned by ``load_stats_data``.

    Returns:
        dict: ``stats_df[stats_val][covariance_type]``, a DataFrame built by
        concatenating the per-model entries along ``axis=1`` (one column per
        ``model_num``).
    """
    return {
        stats_val: {
            # Join the collected data over model_num.
            covariance_type: pd.concat(per_model, axis=1)
            for covariance_type, per_model in per_covariance.items()
        }
        for stats_val, per_covariance in stats.items()
    }

In [109]:
def shape_describe(stats_df):
    """Summarise every per-setting DataFrame with ``get_describe(..., axis=1)``.

    The keys are taken from ``stats_df`` itself rather than from the
    notebook-level ``stats_vals`` / ``covariance_types`` variables.

    Args:
        stats_df: Nested mapping ``stats_df[stats_val][covariance_type]`` of
            DataFrames, as returned by ``shape_stats_df``.

    Returns:
        tuple: ``(describe, describe_keys)`` where
        ``describe[stats_val][covariance_type]`` is the first value returned by
        ``get_describe`` and ``describe_keys`` is the second value from the
        last call. ``describe_keys`` is an empty list when ``stats_df``
        contains no DataFrame at all.
    """
    describe = {}
    # Defined up front so an empty input returns cleanly instead of failing
    # on an unbound name.
    describe_keys = []
    for stats_val, per_covariance in stats_df.items():
        describe[stats_val] = {}
        for covariance_type, frame in per_covariance.items():
            describe[stats_val][covariance_type], describe_keys = get_describe(
                frame, axis=1
            )
    return describe, describe_keys

In [110]:
def shape_data(describe, describe_keys):
    """Regroup the summaries so covariance types become columns.

    The keys are taken from ``describe`` itself rather than from the
    notebook-level ``stats_vals`` / ``covariance_types`` variables.

    Args:
        describe: Nested mapping ``describe[stats_val][covariance_type]``
            whose values can be indexed by each entry of ``describe_keys``.
        describe_keys: Summary names to extract.

    Returns:
        dict: ``data[stats_val][describe_key]``, a DataFrame built by
        concatenating ``describe[stats_val][covariance_type][describe_key]``
        along ``axis=1`` (one column per covariance type).
    """
    data = {}
    for stats_val, per_covariance in describe.items():
        data[stats_val] = {
            # Join the summaries over covariance_type.
            describe_key: pd.concat(
                {
                    covariance_type: summary[describe_key]
                    for covariance_type, summary in per_covariance.items()
                },
                axis=1,
            )
            for describe_key in describe_keys
        }
    return data

In [111]:
# Shape the loaded stats of every embedding method into per-summary tables.
# The loop variables (notably stats_df) stay bound to the last method afterwards.
data_stats = {}
for vectorize_type in stats_dict:
    stats = stats_dict[vectorize_type]
    stats_df = shape_stats_df(stats)
    describe, describe_keys = shape_describe(stats_df)
    data_stats[vectorize_type] = shape_data(describe, describe_keys)

In [112]:
# Per-model MI values for sentenceBERT with full covariance, column-wise maxima
# highlighted. The table is rebuilt from stats_dict with an explicit embedding name
# instead of reading the `stats_df` variable left over from the shaping loop, so the
# cell no longer depends on which method happened to be processed last.
shape_stats_df(stats_dict["sentenceBERT"])["mi"]["full"].style.highlight_max(axis=0)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29
2,0.311385,0.364783,0.356637,0.311357,0.311312,0.362891,0.356646,0.365019,0.311309,0.364783,0.355033,0.311323,0.355048,0.32032,0.311323,0.311385,0.365925,0.311312,0.311323,0.311312,0.365009,0.356618,0.311302,0.355032,0.311323,0.311334,0.311323,0.355038,0.364783,0.364928
4,0.488087,0.474213,0.488007,0.474213,0.474213,0.488043,0.474213,0.474229,0.488016,0.488001,0.488001,0.488001,0.488006,0.474213,0.488001,0.488016,0.474264,0.488017,0.488016,0.474196,0.488001,0.487976,0.474264,0.488011,0.474264,0.487999,0.474256,0.488082,0.488001,0.487976
6,0.506245,0.506267,0.506228,0.48348,0.48347,0.48348,0.506175,0.506191,0.506141,0.506395,0.506228,0.506141,0.506184,0.506228,0.506136,0.506314,0.48348,0.506141,0.506188,0.506228,0.506141,0.506228,0.506314,0.506228,0.50643,0.48348,0.506415,0.506141,0.506319,0.506133
8,0.49149,0.491514,0.483398,0.491449,0.491412,0.491218,0.491427,0.491487,0.491498,0.491554,0.491514,0.49142,0.491517,0.491554,0.491514,0.491424,0.483398,0.491228,0.491424,0.491499,0.483398,0.491505,0.491228,0.491198,0.491517,0.491514,0.491488,0.491228,0.491424,0.491424
10,0.50532,0.50532,0.434075,0.50532,0.50532,0.50532,0.50532,0.434075,0.50532,0.50532,0.434075,0.50532,0.434075,0.50532,0.50532,0.50532,0.50532,0.422667,0.50532,0.434075,0.50532,0.50532,0.50532,0.422639,0.50532,0.50532,0.50532,0.50532,0.50532,0.50532
20,0.483216,0.483358,0.483239,0.483276,0.483298,0.483276,0.483174,0.483276,0.483276,0.483386,0.483276,0.483216,0.483287,0.483276,0.483435,0.483276,0.48323,0.483276,0.483276,0.483174,0.48324,0.48324,0.483276,0.48334,0.483276,0.449003,0.448974,0.483239,0.483276,0.483276
40,0.460044,0.480637,0.480637,0.460043,0.480637,0.460044,0.480637,0.480637,0.480637,0.480637,0.480637,0.480637,0.480637,0.460043,0.480637,0.480637,0.480637,0.480637,0.480637,0.480637,0.480637,0.480637,0.480637,0.480637,0.460044,0.480637,0.480637,0.460044,0.480637,0.480637
80,0.482401,0.482401,0.482401,0.481962,0.482401,0.482378,0.481962,0.482401,0.481962,0.481962,0.481962,0.481962,0.481962,0.482401,0.456301,0.482401,0.482401,0.456301,0.482378,0.482401,0.482378,0.482378,0.481962,0.482401,0.482401,0.482401,0.482401,0.482401,0.482401,0.481962
160,0.479265,0.488163,0.488159,0.488159,0.479265,0.488159,0.488159,0.488163,0.488163,0.488159,0.488155,0.488159,0.488159,0.488163,0.488163,0.488159,0.488163,0.479265,0.488159,0.463153,0.488163,0.488159,0.488163,0.488159,0.488159,0.488159,0.488159,0.488163,0.488159,0.479265
384,0.501512,0.500649,0.501232,0.50108,0.500549,0.500931,0.500962,0.501652,0.501403,0.501475,0.501097,0.501146,0.501376,0.501428,0.501003,0.501626,0.500991,0.50155,0.501249,0.500579,0.501502,0.50112,0.5014,0.501162,0.501607,0.501393,0.500539,0.501558,0.500586,0.501231


In [119]:
# Largest mean MI for sentenceBERT with full covariance. data_stats already holds
# the "mean" summary per covariance type, so it is read from there instead of
# re-running get_describe on the `stats_df` variable left over from the shaping loop.
data_stats["sentenceBERT"]["mi"]["mean"]["full"].max()

0.5024390241159054

# Make Chart

In [114]:
# Row label (per dataset) used to pick the doc2vec row of the mean-MI table.
# NOTE(review): there is no entry for data_type = "AgNewsTitle", which is the key the
# KeyError in the following cell reports — add the intended value before re-running.
best_dim = {"AgNews": 8, "20News": 80}

In [115]:
# Rows of the comparison table: mean MI per covariance type, taken at the configured
# dimension for doc2vec and at 384 for sentenceBERT.
chart_data = {}
if data_type not in best_dim:
    # Same exception type as the bare lookup, but with an actionable message.
    raise KeyError(
        f"best_dim has no entry for data_type={data_type!r}; "
        f"known data types: {sorted(best_dim)}"
    )
chart_data["doc2vec"] = data_stats["doc2vec"]["mi"]["mean"].loc[best_dim[data_type], :]
chart_data["sentenceBERT"] = data_stats["sentenceBERT"]["mi"]["mean"].loc[384, :]

KeyError: 'AgNewsTitle'

In [116]:
# Preview the chart data as a DataFrame (one column per embedding method).
pd.DataFrame(chart_data)

In [97]:
# Keep the table for formatting and export in the following cells.
chart_df = pd.DataFrame(chart_data)

In [98]:
# Transposed view (one row per embedding method) with LaTeX-escaped text.
# The former formatter entry for "document_count" was dropped: the transposed table
# has no such column, so it never affected the rendering.
chart_df.T.style.format(escape="latex")

Unnamed: 0,spherical,diag,full
doc2vec,0.535979,0.535291,0.522142
sentenceBERT,0.582425,0.590015,0.593456


In [99]:
# Write the table to ../data/<data_type>/CovarianceChart.csv. make_filepath comes from
# the star-import of util; presumably it prepares and returns the path — confirm.
chart_df.to_csv(make_filepath(f"../data/{data_type}/CovarianceChart.csv"))

In [100]:
# Emit the transposed table as LaTeX source (print keeps the raw text copy-pasteable).
# column_format is derived from the table width — one right-aligned column for the row
# labels plus one per data column — so it stays valid if the set of covariance types
# changes; with three covariance types it is still "rrrr".
print(
    chart_df.T
    .style.format(precision=3, escape="latex")
    .to_latex(
        column_format="r" * (len(chart_df.index) + 1),
        position="h",
        position_float="centering",
        hrules=True,
        caption="miの比較と埋め込み次元",
        label="table:1",
        multicol_align="r",
    )
)

\begin{table}[h]
\centering
\caption{miの比較と埋め込み次元}
\label{table:1}
\begin{tabular}{rrrr}
\toprule
 & spherical & diag & full \\
\midrule
doc2vec & 0.536 & 0.535 & 0.522 \\
sentenceBERT & 0.582 & 0.590 & 0.593 \\
\bottomrule
\end{tabular}
\end{table}

