In [1]:
import json

import jsonlines
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.stats import norm

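## Checking heavy tails

We compare the distribution of the stochastic noise norms with a normal distribution. $\rho_{mR}$ and $\rho_{eR}$ are the fractions of norms above the outlier fences $Q_3 + 1.5\,\mathrm{IQR}$ and $Q_3 + 3\,\mathrm{IQR}$ (the default coefficients), each divided by the corresponding fraction for a normal distribution; values above 1 indicate a heavier-than-normal right tail.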

In [3]:
# JSON file holding a "stochastic_norms" list; read by calc_heaviness and plot_norms.
HIST_PATH = "hist_path.json"

In [None]:
def calc_heaviness(norms_path, coef_1=1.5, coef_2=3):
    """Summarize how heavy-tailed the noise norms stored in ``norms_path`` are.

    Reads the ``"stochastic_norms"`` list from a JSON file and measures the
    fraction of values above the IQR outlier fences ``Q3 + coef * (Q3 - Q1)``.
    Each fraction is divided by the same fraction for a normal distribution, so
    a ratio above 1 means the tail beyond that fence is heavier than normal.

    Args:
        norms_path: Path to a JSON file with a ``"stochastic_norms"`` list of numbers.
        coef_1: IQR multiplier of the first (closer) fence.
        coef_2: IQR multiplier of the second (farther) fence.

    Returns:
        Tuple ``(mean, std, rho_mR, rho_eR)``: sample mean and population
        standard deviation of the norms, and the tail ratios for the two fences.

    Raises:
        ValueError: if the ``"stochastic_norms"`` list is empty.
    """
    with open(norms_path, "r") as file:
        norms = np.asarray(json.load(file)["stochastic_norms"], dtype=float)
    if norms.size == 0:
        raise ValueError(f"No 'stochastic_norms' values in {norms_path}")

    q1, q3 = np.quantile(norms, [0.25, 0.75])
    iqr = q3 - q1

    # Normal-reference exceedance probabilities. They do not depend on the
    # normal's mean/std, so the standard normal suffices. For the default
    # coefficients they are ~0.0035 and ~1.2e-6; deriving them (instead of
    # hard-coding) keeps the ratios meaningful when coef_1 / coef_2 change.
    z1, z3 = norm.ppf([0.25, 0.75])
    p_mRN = norm.sf(z3 + coef_1 * (z3 - z1))
    p_eRN = norm.sf(z3 + coef_2 * (z3 - z1))

    # Fraction of samples beyond each empirical fence.
    p_mR = np.mean(norms > q3 + coef_1 * iqr)
    p_eR = np.mean(norms > q3 + coef_2 * iqr)

    return norms.mean(), norms.std(), p_mR / p_mRN, p_eR / p_eRN


# Shows (mean, std, rho_mR, rho_eR) for the norms stored at HIST_PATH.
calc_heaviness(norms_path=HIST_PATH)

In [None]:
# Apply seaborn's default theme to every matplotlib figure created afterwards.
sns.set_theme()


def plot_norms(norms_path, title="Dataset, n steps"):
    """Plot a histogram of the noise norms with a fitted normal density overlaid.

    The normal curve uses the mean/std that ``calc_heaviness`` computes for the
    same file, and the heavy-tail ratios are written near the top-right of the axes.

    Args:
        norms_path: Path to a JSON file with a ``"stochastic_norms"`` list.
        title: Axes title; the default is a placeholder to replace per experiment.

    Returns:
        ``(fig, ax)`` of the created figure, so callers can save or adjust it.
    """
    with open(norms_path, "r") as file:
        norms = json.load(file)["stochastic_norms"]

    fig, ax = plt.subplots(figsize=(7, 5))
    ax.hist(norms, color="royalblue", edgecolor="white", bins=100, density=True)

    # Statistics must describe the same file that is being plotted.
    mu, sigma, ro_mR, ro_eR = calc_heaviness(norms_path)
    x = np.linspace(min(norms), max(norms), 100)
    ax.plot(x, norm.pdf(x, mu, sigma), color="black")
    ax.set_xlabel("Noise norm", fontsize=14)
    ax.set_ylabel("Density", fontsize=14)
    ax.grid(True)
    textstr = "\n".join(
        (
            r"$\mu=%.2f$" % (mu,),
            r"$\sigma=%.2f$" % (sigma,),
            r"$\rho_{mR}=%.3f$" % (ro_mR,),
            r"$\rho_{eR}=%.0f$" % (ro_eR,),
        )
    )
    ax.text(
        0.7, 0.9, textstr, transform=ax.transAxes, fontsize=14, verticalalignment="top"
    )
    ax.set_title(title, fontsize=16)
    return fig, ax


# The trailing semicolon keeps IPython from echoing the call's return value;
# the figure itself is still rendered.
plot_norms(HIST_PATH);

## Model comparison

In [11]:
def get_mean_line(logs):
    """Point-wise mean of several equally long curves (reduction over axis 0)."""
    stacked = np.asarray(logs)
    return np.mean(stacked, axis=0)


def get_std_lines(logs):
    """Return the mean +/- one (population) standard deviation band of curves.

    Curves are stacked along axis 0; the result is ``(upper, lower)``.
    """
    stacked = np.array(logs)
    center = stacked.mean(axis=0)
    spread = stacked.std(axis=0)
    return center + spread, center - spread


def get_worst_line(logs, worst_type: str):
    """Point-wise envelope of several curves stacked along axis 0.

    ``worst_type == "max"`` selects the point-wise maximum; any other value
    selects the point-wise minimum.
    """
    reducer = np.max if worst_type == "max" else np.min
    return reducer(np.array(logs), axis=0)


def get_quantile_line(logs, q):
    """Point-wise ``q``-quantile across curves stacked along axis 0."""
    stacked = np.array(logs)
    quantile_curve = np.quantile(stacked, q, axis=0)
    return quantile_curve

In [14]:
# Number of consecutive raw train-loss values averaged into one plotted point
# (default chunk size of preroc_logs); also used to place points on the x-axis.
LOG_N_STEP = 10  # replace with your value
# One JSONL log file per model to compare.
PATHS_TO_COMPARE = [
    "logs_first_model.jsonl",
    "logs_second_model.jsonl",
]
# Legend labels, in the same order as PATHS_TO_COMPARE.
LABELS = ["first_model", "second_model"]

assert len(PATHS_TO_COMPARE) == len(LABELS), (
    f"{len(PATHS_TO_COMPARE)} log paths but {len(LABELS)} labels"
)

In [None]:
def preroc_logs(raw_logs, stage, log_n_step=LOG_N_STEP):
    """Condense one record's raw curve into the points that get plotted.

    Args:
        raw_logs: For ``stage == "train"``, a flat sequence of per-step values.
            For any other stage, a sequence of per-evaluation value lists.
        stage: ``"train"``, or anything else for validation-style curves.
        log_n_step: Number of consecutive train values averaged into one point.

    Returns:
        np.ndarray. For train: the mean of every complete chunk of
        ``log_n_step`` values (a trailing partial chunk is dropped). For other
        stages: the mean of each evaluation list with more than two values.
    """
    if stage == "train":
        values = np.asarray(raw_logs)
        n_chunks = len(values) // log_n_step
        # Keep exactly the complete chunks. np.add.reduceat(...)[:-1] also
        # discarded the last *complete* chunk whenever len(values) was an exact
        # multiple of log_n_step.
        chunks = values[: n_chunks * log_n_step].reshape(n_chunks, log_n_step)
        return chunks.sum(axis=1) / log_n_step
    # NOTE(review): evaluations with <= 2 values are skipped; the reason is not
    # visible here (presumably partial/sanity-check evaluations) - TODO confirm.
    return np.array([np.array(x).mean() for x in raw_logs if len(x) > 2])

In [15]:
logs = []

for path in PATHS_TO_COMPARE:
    # Every JSONL record contributes one processed curve per key; each record's
    # "model" entry holds the raw curves (presumably one record per run/seed -
    # TODO confirm).
    model_log = {"train_loss": [], "val_loss": [], "val_metric": []}
    with jsonlines.open(path) as reader:
        for obj in reader:
            for key, curves in model_log.items():
                # Only the train loss is chunk-averaged; the other curves are
                # reduced per evaluation (see preroc_logs).
                stage = "train" if key == "train_loss" else "val"
                curves.append(preroc_logs(obj["model"][key], stage))
    logs.append({"model": model_log})

In [None]:
# Curve compared in the plot below: "val_loss", "val_metric" or "train_loss".
METRICS_TYPE = "val_loss"
# Fail here on a typo instead of with a KeyError inside the plotting cell.
assert METRICS_TYPE in ("train_loss", "val_loss", "val_metric"), METRICS_TYPE

In [None]:
clipped_index, base_index = 0, 1
fig, ax = plt.subplots(figsize=(9, 5))
colors = sns.color_palette("husl", 2)[::-1]

for model_index in (clipped_index, base_index):
    metrics = logs[model_index]["model"][METRICS_TYPE]
    # Median line with a 5%-95% band across the records loaded for this model;
    # get_quantile_line needs all of this model's curves to have equal length.
    metrics_low = get_quantile_line(metrics, 0.05)
    metrics_med = get_quantile_line(metrics, 0.5)
    metrics_high = get_quantile_line(metrics, 0.95)
    # x positions assume one point per LOG_N_STEP steps (matches the train-loss
    # chunking in preroc_logs; TODO confirm for val curves). Computed per model
    # so models with different curve lengths still plot.
    steps = [(k + 1) * LOG_N_STEP - 1 for k in range(len(metrics_med))]
    color = colors[model_index]
    ax.plot(steps, metrics_med, label=LABELS[model_index], color=color)
    ax.fill_between(steps, metrics_low, metrics_high, alpha=0.3, color=color)

ax.set_xlabel("Number of steps", fontsize=14)
ax.set_ylabel(METRICS_TYPE.replace("_", " "), fontsize=14)
ax.grid(True)
ax.set_title(f"Dataset, {LABELS[clipped_index]} vs {LABELS[base_index]}", fontsize=16)
ax.legend(fontsize=12);