# In [None]:
import pickle
import os
import json
import numpy as np
import matplotlib.pyplot as plt

root_dir = "."
# Maps each run's config["random_state"] -> (train_results, test_results) for
# every run under results/train_test whose config.json has dataset == "sst2".
results_sst2 = {}
results_root = os.path.join(root_dir, "results", "train_test")
# sorted() makes the load order (and thus which run wins on a duplicate seed)
# deterministic instead of depending on filesystem listing order.
for result_id in sorted(os.listdir(results_root)):
    run_dir = os.path.join(results_root, result_id)
    # os.listdir also returns plain files (e.g. .DS_Store); only run
    # directories contain config.json / train.pkl / test.pkl.
    if not os.path.isdir(run_dir):
        continue
    with open(os.path.join(run_dir, "config.json"), encoding="utf-8") as f:
        config = json.load(f)
    if config["dataset"] != "sst2":
        continue
    # NOTE: pickle.load can execute arbitrary code; only load result files
    # produced by your own experiments.
    with open(os.path.join(run_dir, "train.pkl"), "rb") as f:
        train_results = pickle.load(f)
    with open(os.path.join(run_dir, "test.pkl"), "rb") as f:
        test_results = pickle.load(f)
    run_seed = config["random_state"]
    if run_seed in results_sst2:
        # Two runs share a seed: the later one (in sorted order) replaces the
        # earlier one, exactly as before, but now visibly.
        print(f"Warning: duplicate random_state {run_seed!r} in {result_id}; overwriting earlier run")
    results_sst2[run_seed] = (train_results, test_results)

# Number of training labels drawn per seed. The same seed + size reproduce the
# same indices via RandomState.choice — presumably mirroring the subsample
# used during training; confirm against the experiment script.
N_TRAIN_SUBSAMPLE = 400

# One entry per run: mean label of the seeded training subsample, and mean
# label of the full test set (the fraction of 1s, assuming 0/1 labels).
train_priors = []
test_priors = []
for random_state, (train_results, test_results) in results_sst2.items():
    # Fresh generator per run so the subsample depends only on the run's seed.
    rs = np.random.RandomState(random_state)
    # np.asarray lets plain Python lists support fancy indexing and .sum().
    train_labels = np.asarray(train_results["train_labels"])
    test_labels = np.asarray(test_results["test_labels"])
    if len(train_labels) < N_TRAIN_SUBSAMPLE:
        # Checked before rs.choice, so the RNG stream is unaffected; gives a
        # clearer message than numpy's "larger sample than population".
        raise ValueError(
            f"run with random_state={random_state!r} has only "
            f"{len(train_labels)} training labels; need {N_TRAIN_SUBSAMPLE}"
        )
    train_idx = rs.choice(len(train_labels), N_TRAIN_SUBSAMPLE, replace=False)
    train_subsample_labels = train_labels[train_idx]
    train_priors.append(train_subsample_labels.sum() / len(train_subsample_labels))
    test_priors.append(test_labels.sum() / len(test_labels))

# Compute one set of 20 bin edges over both samples: the panels share the
# x-axis, and per-panel bins would give bars of different widths/positions,
# making the two distributions hard to compare.
shared_bins = np.histogram_bin_edges(
    np.concatenate([train_priors, test_priors]), bins=20
)

fig, ax = plt.subplots(2, 1, sharex=True)
ax[0].hist(train_priors, bins=shared_bins)
ax[0].grid(True)
ax[0].set_title("Train priors distribution")
ax[1].hist(test_priors, bins=shared_bins)
ax[1].grid(True)
ax[1].set_title("Test priors distribution")
ax[1].set_xlabel("Label prior (mean label)")
fig.tight_layout()

# Range of priors across runs: (train, test) minimum, then (train, test) maximum.
print(np.min(train_priors), np.min(test_priors))
print(np.max(train_priors), np.max(test_priors))