## import 

In [None]:
!python -m pip install polars pyarrow

In [7]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from collections import defaultdict
import scipy.stats as stats
from tqdm import tqdm

import polars as pl

## Lower bound entropy of the whole dataset

In [None]:
data_dir = Path("../data/variations_wc/6switches_3values_min1_max20_10K")
lower_bound_complexity = []

# フォルダ名からgrammar_namesを取得
grammar_names = [d.name for d in data_dir.iterdir() if d.is_dir()]
print(len(grammar_names))
valid_grammar = 0
for grammar_name in grammar_names:
    if not (
        data_dir / grammar_name / "true_prob" / "lower_bound_entropy.value"
    ).exists():
        if not (
            data_dir / grammar_name / "true_prob" / "probability_split_1_of_1.csv.gz"
        ).exists():
            print(f"{grammar_name} is not valid")

In [4]:
data_dir = Path("../data/variations_wc/6switches_3values_min1_max20_10K")
lower_bound_complexity = []

# フォルダ名からgrammar_namesを取得
grammar_names = [d.name for d in data_dir.iterdir() if d.is_dir()]

for grammar_name in grammar_names:
    true_prob_dir = data_dir / grammar_name / "true_prob"

    try:
        # entropyとperplexityの値を読み込む
        with open(true_prob_dir / "lower_bound_entropy.value") as f:
            entropy = float(f.read().strip())

        with open(true_prob_dir / "lower_bound_perplexity.value") as f:
            perplexity = float(f.read().strip())

        lower_bound_complexity.append(
            {
                "grammar_name": grammar_name,
                "lower_bound_entropy": entropy,
                "lower_bound_perplexity": perplexity,
            }
        )
    except:
        print(f"Can't find {true_prob_dir}")

In [None]:
# 理論下限のプロット
fig, ax = plt.subplots(figsize=(25, 8))

# DataFrameに変換して扱いやすくする
lower_bound_df = pd.DataFrame(lower_bound_complexity)

# '2'の数が少ない順にソート
lower_bound_df["num_twos"] = lower_bound_df["grammar_name"].apply(
    lambda x: x.count("2")
)
lower_bound_df = lower_bound_df.sort_values(by="num_twos").drop(columns="num_twos")


# プロット
ax.bar(
    lower_bound_df["grammar_name"],
    lower_bound_df["lower_bound_perplexity"],
    color="skyblue",
)

# 見やすさの調整
ax.set_xticks(range(len(list(lower_bound_df["grammar_name"].values))))
ax.set_xticklabels(
    list(lower_bound_df["grammar_name"].values), rotation=45, fontsize=10
)
ax.set_xlabel("Grammar")
ax.set_ylabel("Lower Bound Perplexity")
ax.set_title("Lower Bound of Perplexity by Grammar")


# グリッドと余白の調整
plt.grid(axis="y", linestyle="--", alpha=0.7)
plt.tight_layout()

plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(25, 8))

# プロット
ax.bar(
    lower_bound_df["grammar_name"],
    lower_bound_df["lower_bound_entropy"],  # Changed from lower_bound_perplexity
    color="skyblue",
)

# 見やすさの調整
ax.set_xticks(range(len(list(lower_bound_df["grammar_name"].values))))
ax.set_xticklabels(list(lower_bound_df["grammar_name"].values), rotation=45)
ax.set_xlabel("Grammar")
ax.set_ylabel("Lower Bound Entropy")  # Updated ylabel
ax.set_title("Lower Bound of Entropy by Grammar")  # Updated title

# グリッドと余白の調整
plt.grid(axis="y", linestyle="--", alpha=0.7)
plt.tight_layout()

plt.show()

## Lower bound entropy of testset

In [7]:
import sys

sys.path.append("..")
from src.length_sampling.sampler import construct_pcfg_sampler
from src.length_sampling.grammars.pcfg import Grammar
from src.length_sampling.grammars.cfg import Nonterminal
from src.length_sampling.lower_bound_perplexity import (
    parts_to_perplexity,
    Parts,
)
import polars as pl
from pathlib import Path
import gzip
import math
import pyarrow

In [None]:
min_length = 1
max_length = 20
data_dir = Path("../data/variations_wc/6switches_3values_min1_max20_10K")
fairseq_data_dir = Path("../data/fairseq_train/6switches_3values_min1_max20_10K")
grammar_dir = Path("../data/grammars/variations/6switches_3values")

lower_bound_complexity = []


grammar_names = lower_bound_df["grammar_name"].values
print(f"Processing {len(grammar_names)} grammars")

for grammar_name in grammar_names:
    # grammarの読み込みとsamplerの構築
    grammar_file = grammar_dir / f"{grammar_name}.gr"
    grammar = Grammar.from_file(grammar_file, Nonterminal("S"), normalize=True)
    sampler = construct_pcfg_sampler(grammar)

    true_prob_dir = data_dir / grammar_name / "true_prob"

    # テストデータの文を読み込む
    test_file = fairseq_data_dir / grammar_name / "test.txt"
    with open(test_file) as f:
        test_sentences = [line.strip() for line in f]

    # true_probデータを読み込む
    dfs = []
    for file in true_prob_dir.glob("*.csv.gz"):
        with gzip.open(file, "rt") as f:
            df = pl.read_csv(
                f, new_columns=["sentence", "count", "true_log_prob", "true_prob"]
            )
            dfs.append(df)
    if dfs:
        df = pl.concat(dfs).filter(pl.col("count").is_not_null())

        # test_sentencesをDataFrameに変換
        test_df = pd.DataFrame(test_sentences, columns=["sentence"])

        # true_log_probをtest_dfにマージ
        test_df = test_df.merge(
            df.to_pandas()[["sentence", "true_log_prob"]], on="sentence", how="left"
        )
        print(
            f"Grammar {grammar_name}, len: {len(test_df)}, unique: {test_df['sentence'].nunique()}"
        )

        # 統計量の計算
        total_neg_log_prob = -1.0 * test_df["true_log_prob"].sum()
        test_df["sent_len"] = test_df["sentence"].map(
            lambda x: len(x.split()) + 1
        )  # +1 for EOS
        total_len = test_df["sent_len"].sum()
        num_samples = len(test_df)

        valid_lengths = sampler.valid_lengths(min_length=1, max_length=20)

        # Partsを使ってパープレキシティを計算
        parts = Parts(total_neg_log_prob, total_len, num_samples)
        perplexity = parts_to_perplexity(parts, len(valid_lengths))
        entropy = math.log(perplexity)

        lower_bound_complexity.append(
            {
                "grammar_name": grammar_name,
                "lower_bound_entropy": entropy,
                "lower_bound_perplexity": perplexity,
            }
        )

In [None]:
# 理論下限のプロット
fig, ax = plt.subplots(figsize=(25, 8))

# DataFrameに変換して扱いやすくする
lower_bound_df = pd.DataFrame(lower_bound_complexity)

# '2'の数が少ない順にソート
lower_bound_df["num_twos"] = lower_bound_df["grammar_name"].apply(
    lambda x: x.count("2")
)
lower_bound_df = lower_bound_df.sort_values(by="num_twos").drop(columns="num_twos")


# プロット
ax.bar(
    lower_bound_df["grammar_name"],
    lower_bound_df["lower_bound_perplexity"],
    color="skyblue",
)

# 見やすさの調整
ax.set_xticks(range(len(list(lower_bound_df["grammar_name"].values))))
ax.set_xticklabels(list(lower_bound_df["grammar_name"].values), rotation=45)
ax.set_xlabel("Grammar")
ax.set_ylabel("Lower Bound Perplexity of Test Data (nats)")
ax.set_title("Empirical Lower Bound of Perplexity of Test Data by Grammar")


# グリッドと余白の調整
plt.grid(axis="y", linestyle="--", alpha=0.7)
plt.tight_layout()

plt.show()

## Model performance

In [1]:
import math
from pathlib import Path
from tqdm import tqdm
import numpy as np

exp_name = "local_entropy"

data_dir = Path(f"../data/fairseq_train/{exp_name}")
grammar_names = [d.name for d in data_dir.iterdir() if d.is_dir()]

num_seeds = 5
# model_names = ["transformer", "lstm", "transformer_tiny"]
model_names = ["lstm", "transformer_4layer"]

results_dir = Path("../results").resolve()

result_list = []
for grammar_name in tqdm(grammar_names):
    split_data_file = data_dir / grammar_name / "test.txt"
    with open(split_data_file) as f:
        total_syms = sum(
            len(line.strip().split()) + 1 for line in f
        )  # add 1 for EOS token
    for model_name in model_names:
        model_result_dir = results_dir / f"{model_name}_results" / exp_name
        grammar_result_dir = model_result_dir / f"{grammar_name}"
        seed_stats = []
        for seed_i in range(num_seeds):
            split_result_file = grammar_result_dir / f"seed{seed_i}" / "test.scores.txt"
            with open(split_result_file) as f:
                scores = [float(line.strip()) for line in f]
            total_sents = len(scores)
            neg_log_probs = [-1.0 * score for score in scores]
            sym_cross_entropy = (sum(neg_log_probs) / total_syms) / math.log(2)
            sym_perplexity = math.exp(sym_cross_entropy)
            sent_cross_entropy = (sum(neg_log_probs) / total_sents) / math.log(2)
            sent_perplexity = math.exp(sent_cross_entropy)
            # seed_stats.append(
            #     {
            #         "sym_cross_entropy": sym_cross_entropy,
            #         "sym_perplexity": sym_perplexity,
            #         "sent_cross_entropy": sent_cross_entropy,
            #         "sent_perplexity": sent_perplexity,
            #     }
            # )

            # sym_entropies = [s["sym_cross_entropy"] for s in seed_stats]
            # sym_perplexities = [s["sym_perplexity"] for s in seed_stats]
            # sent_entropies = [s["sent_cross_entropy"] for s in seed_stats]
            # sent_perplexities = [s["sent_perplexity"] for s in seed_stats]

            result_list.append(
                {
                    "model_name": model_name,
                    "seed": seed_i,
                    "grammar_name": grammar_name,
                    "sym_cross_entropy": sym_cross_entropy,
                    "sym_perplexity": sym_perplexity,
                    "sent_cross_entropy": sent_cross_entropy,
                    "sent_perplexity": sent_perplexity,
                }
            )
    print(total_syms, total_sents, total_syms / total_sents)

In [20]:
import pandas as pd

result_df = (
    pd.DataFrame(result_list)
    .sort_values(["model_name", "grammar_name"])
    .reset_index(drop=True)
)
result_df

Unnamed: 0,model_name,seed,grammar_name,sym_cross_entropy,sym_perplexity,sent_cross_entropy,sent_perplexity
0,lstm,0,Q16_S16_s10180,1.039505,2.827818,36.184035,5.182350e+15
1,lstm,1,Q16_S16_s10180,1.037321,2.821648,36.108003,4.802932e+15
2,lstm,2,Q16_S16_s10180,1.042292,2.835708,36.281020,5.710143e+15
3,lstm,3,Q16_S16_s10180,1.044746,2.842677,36.366457,6.219442e+15
4,lstm,4,Q16_S16_s10180,1.037053,2.820891,36.098661,4.758273e+15
...,...,...,...,...,...,...,...
355,transformer_4layer,0,Q32_S8_s8247,0.823107,2.277565,8.748644,6.302134e+03
356,transformer_4layer,1,Q32_S8_s8247,0.822775,2.276810,8.745122,6.279979e+03
357,transformer_4layer,2,Q32_S8_s8247,0.819156,2.268585,8.706654,6.042989e+03
358,transformer_4layer,3,Q32_S8_s8247,0.831020,2.295660,8.832754,6.855142e+03


In [21]:
# save results
result_df.to_csv(results_dir / f"length_sampling_{exp_name}_results.csv", index=False)