### Loss plots

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np
import os

In [None]:
def _parse_loss_log_name(filename):
    """Return ``(growth, epoch, epoch_step)`` as ints parsed from a log name.

    The name is split on "_": field 3 is the growth index, field 5 the epoch
    and field 8 the epoch step (followed by the file extension).
    """
    fields = filename.split("_")
    return (
        int(fields[3]),
        int(fields[5]),
        int(fields[8].split(".")[0])
    )


def create_loss_logs_lists(directory, stop_at_growth=None):
    """Load the JSON loss logs in ``directory`` grouped by growth index.

    Files are visited in numeric (growth, epoch, epoch_step) order, so that
    e.g. step 10 comes after step 2, and the entries of all files sharing a
    growth index are concatenated into one list.

    Args:
        directory: str, path of the folder holding the JSON log files. Every
            file name must follow the layout understood by
            `_parse_loss_log_name`.
        stop_at_growth: optional int. When given, loading stops at the first
            file whose growth index is >= this value. The default (None)
            loads every file. The previous code compared the growth *string*
            against the integer 2, which never matched, so loading everything
            is the behavior callers have always observed.

    Returns:
        List with one list of log entries per growth index, in ascending
        growth order.
    """
    loss_logs = []
    growth_log = []
    current_growth = None

    for log in sorted(os.listdir(directory), key=_parse_loss_log_name):
        growth = _parse_loss_log_name(log)[0]
        if stop_at_growth is not None and growth >= stop_at_growth:
            break

        with open(os.path.join(directory, log), "r") as f:
            entries = json.load(f)

        if current_growth is None:
            current_growth = growth

        if growth != current_growth:
            # A new growth index starts: close the previous group first.
            loss_logs.append(growth_log)
            growth_log = []
            current_growth = growth

        growth_log.extend(entries)

    # Flush the last group, unless nothing was collected for it.
    if growth_log:
        loss_logs.append(growth_log)

    return loss_logs


In [None]:
def get_bounded_losses(global_steps, losses, bounds):
    """Keep only the (step, loss) pairs whose loss lies within ``bounds``.

    Works on nested inputs: ``losses[i][j]`` is paired with
    ``global_steps[i][j]`` and the outer structure is preserved, so an outer
    entry whose losses are all out of range becomes an empty list.

    Args:
        global_steps: list of lists of global steps.
        losses: list of lists of loss values, parallel to `global_steps`.
        bounds: indexable pair; ``bounds[0]`` is the inclusive lower limit
            and ``bounds[1]`` the inclusive upper limit.

    Returns:
        Tuple ``(bounded_global_steps, bounded_losses)`` of nested lists.
    """
    def _in_bounds(value):
        return bounds[0] <= value <= bounds[1]

    bounded_global_steps = []
    bounded_losses = []
    for outer_idx, inner_losses in enumerate(losses):
        kept_steps = []
        kept_losses = []
        for inner_idx, loss in enumerate(inner_losses):
            if _in_bounds(loss):
                kept_steps.append(global_steps[outer_idx][inner_idx])
                kept_losses.append(loss)
        bounded_global_steps.append(kept_steps)
        bounded_losses.append(kept_losses)

    return bounded_global_steps, bounded_losses


def get_bounded_losses_flat(global_steps, losses, bounds):
    """Keep only the (step, loss) pairs whose loss lies within ``bounds``.

    Flat counterpart of the nested filter: ``losses[i]`` is paired with
    ``global_steps[i]``.

    Args:
        global_steps: list of global steps.
        losses: list of loss values, parallel to `global_steps`.
        bounds: indexable pair; ``bounds[0]`` is the inclusive lower limit
            and ``bounds[1]`` the inclusive upper limit.

    Returns:
        Tuple ``(bounded_global_steps, bounded_losses)`` of flat lists.
    """
    kept_steps = []
    kept_losses = []
    for idx, loss in enumerate(losses):
        if bounds[0] <= loss <= bounds[1]:
            kept_steps.append(global_steps[idx])
            kept_losses.append(loss)

    return kept_steps, kept_losses


In [None]:
def rollavg_cumsum_edges(a, n):
    """Centered rolling mean whose window shrinks at the array edges.

    Each output element is the mean of the input values inside a window of
    ``n`` elements centered on it; near either end the window is clipped to
    the array, and the mean is taken over the elements actually present.

    Unlike the previous slicing-based version, this also works when the
    window is wider than the input, when the effective window is 1, and for
    empty input. Results are unchanged for inputs that already worked.

    Args:
        a: 1-D array-like of numbers.
        n: int, window length. Even values are reduced by one so the window
            can be centered; anything below 1 is treated as 1.

    Returns:
        np.ndarray with the same length as `a` holding the rolling means.
    """
    if n % 2 != 1:
        n = n - 1
    n = max(1, n)
    half = n // 2

    values = np.asarray(a)
    size = len(values)

    # prefix[k] is the sum of the first k values, so any window sum is a
    # difference of two prefix entries.
    prefix = np.concatenate(([0], np.cumsum(values)))

    positions = np.arange(size)
    window_start = np.maximum(positions - half, 0)
    window_stop = np.minimum(positions + half + 1, size)  # exclusive

    return (
        (prefix[window_stop] - prefix[window_start])
        / (window_stop - window_start)
    )

def plot_losses(global_steps_flat, losses_flat, global_steps, losses, params):
    """Plot one loss curve per growth index plus a rolling mean of all losses.

    Args:
        global_steps_flat: list of global steps across all growth indices.
        losses_flat: list of loss values parallel to `global_steps_flat`.
        global_steps: list of lists of global steps, one list per growth
            index.
        losses: list of lists of loss values, parallel to `global_steps`.
        params: dict with the keys "fig_size_x", "fig_size_y", "title",
            "num_xticks", "num_yticks", "yscale" and "mean_window_steps".
    """
    plt.figure(figsize=(params["fig_size_x"], params["fig_size_y"]))
    plt.title(params["title"])
    plt.xlabel("Global Step")
    plt.ylabel("Loss")

    x_min = np.min(global_steps_flat)
    x_max = np.max(global_steps_flat)
    # Clamp to at least 1: when the step range is smaller than the requested
    # number of ticks the floored spacing is 0, which np.arange cannot use.
    x_step = max(1, int((x_max - x_min) // params["num_xticks"]))
    plt.xticks(np.arange(x_min, x_max, step=x_step))

    plt.yscale(params["yscale"])
    if params["yscale"] == "linear":
        y_min = np.min(losses_flat)
        y_max = np.max(losses_flat)
        y_range = y_max - y_min
        # With identical losses the spacing would be 0; keep default ticks.
        if y_range > 0:
            plt.yticks(
                np.arange(y_min, y_max, step=y_range / params["num_yticks"])
            )

    for growth_idx, growth_steps in enumerate(global_steps):
        block_idx = (growth_idx + 1) // 2
        # Odd growth indices are labelled "T", even ones "S".
        # NOTE(review): presumably transition/stable phases - confirm.
        suffix = "T" if growth_idx % 2 == 1 else "S"
        plt.plot(
            growth_steps,
            losses[growth_idx],
            label="{}{}".format(4 * 2 ** block_idx, suffix)
        )

    # Plot sliding means of loss.
    loss_means = rollavg_cumsum_edges(
        np.array(losses_flat), params["mean_window_steps"]
    )
    plt.plot(
        global_steps_flat,
        loss_means,
        label="mean",
        color="black"
    )
    plt.legend(loc="upper right")

    plt.hlines(y=0., xmin=0, xmax=global_steps_flat[-1], color="black")
    plt.xlim(left=x_min, right=x_max*1.05)
    # Positional flag: the `b` keyword was renamed to `visible` in
    # Matplotlib 3.5 and newer releases reject `b`.
    plt.grid(True)


In [None]:
def investigate_loss(
    loss_logs, loss_key, percentile_bounds, mean_window_steps, yscale
):
    """Plot a single loss after trimming values outside percentile bounds.

    Extracts the global steps and the `loss_key` values from the logs,
    computes the requested percentiles over all values, prints them, drops
    every value outside that range and hands the remainder to `plot_losses`.

    Args:
        loss_logs: list of lists of log entries; each entry is a dict with a
            "global_step" value and a "losses" dict containing `loss_key`.
        loss_key: str, name of the loss to plot.
        percentile_bounds: indexable pair (lower, upper) of percentiles.
        mean_window_steps: int, window length for the rolling mean.
        yscale: str, y-axis scale passed through to the plot.
    """
    # Nested view: one list per growth index.
    steps_by_growth = [
        [int(entry["global_step"]) for entry in growth_log]
        for growth_log in loss_logs
    ]
    losses_by_growth = [
        [float(entry["losses"][loss_key]) for entry in growth_log]
        for growth_log in loss_logs
    ]

    # Flat view: every growth index concatenated in order.
    all_steps = [step for steps in steps_by_growth for step in steps]
    all_losses = [loss for group in losses_by_growth for loss in group]

    loss_bounds = np.percentile(
        a=all_losses, q=[percentile_bounds[0], percentile_bounds[1]]
    )
    print("Bounds = {}".format(loss_bounds))

    bounded_global_steps, bounded_losses = get_bounded_losses(
        global_steps=steps_by_growth,
        losses=losses_by_growth,
        bounds=loss_bounds
    )
    bounded_global_steps_flat, bounded_losses_flat = get_bounded_losses_flat(
        global_steps=all_steps,
        losses=all_losses,
        bounds=loss_bounds
    )

    plot_params = {
        "title": "Loss: {}".format(loss_key),
        "fig_size_x": 30,
        "fig_size_y": 10,
        "num_xticks": 25,
        "num_yticks": 15,
        "mean_window_steps": mean_window_steps,
        "yscale": yscale
    }
    plot_losses(
        global_steps_flat=bounded_global_steps_flat,
        losses_flat=bounded_losses_flat,
        global_steps=bounded_global_steps,
        losses=bounded_losses,
        params=plot_params
    )


In [None]:
def create_all_loss_plots(loss_logs, loss_configs):
    """Create a plot for every configured loss that the logs contain.

    The set of available losses is read from the first log entry found;
    configs whose "loss_key" is not among them are skipped silently. If the
    logs hold no entries at all, a message is printed and nothing is plotted
    (previously this raised IndexError).

    Args:
        loss_logs: list of lists of log entries; each entry is a dict with a
            "losses" dict keyed by loss name.
        loss_configs: list of dicts with the keys "loss_key",
            "percentile_bounds", "mean_window_steps" and "yscale".
    """
    first_entry = next(
        (growth_log[0] for growth_log in loss_logs if growth_log), None
    )
    if first_entry is None:
        print("No loss log entries found; nothing to plot.")
        return

    available_losses = first_entry["losses"]
    for config in loss_configs:
        if config["loss_key"] not in available_losses:
            continue
        investigate_loss(
            loss_logs=loss_logs,
            loss_key=config["loss_key"],
            percentile_bounds=config["percentile_bounds"],
            mean_window_steps=config["mean_window_steps"],
            yscale=config["yscale"]
        )


In [None]:
# One plot configuration per loss. Every loss shares the same rolling-mean
# window and y-axis scale; only the percentile trimming bounds differ.
loss_configs = [
    {
        "loss_key": loss_key,
        "percentile_bounds": percentile_bounds,
        "mean_window_steps": 50,
        "yscale": "linear"
    }
    for loss_key, percentile_bounds in [
        # Network totals.
        ("generator_total_loss", (5, 99)),
        ("encoder_total_loss", (1, 99)),
        ("discriminator_total_loss", (1, 95)),
        # Discriminator terms.
        ("D(G(z))", (1, 95)),
        ("D(G(x))", (1, 95)),
        ("D(x)", (1, 99)),
        ("D(G(z))-D(x)", (1, 95)),
        ("D(G(x))-D(x)", (1, 95)),
        ("D(G(z))_gradient_penalty", (1, 95)),
        ("D(G(x))_gradient_penalty", (1, 95)),
        ("epsilon_drift_penalty", (1, 95)),
        ("D(G(z))_wgan_gp", (1, 95)),
        ("D(G(x))_wgan_gp", (1, 95)),
        # L1 / L2 distance terms.
        ("z-E(G(z))_L1", (1, 99)),
        ("z-E(G(z))_L2", (1, 99)),
        ("E(x)-E(G(E(x)))_L1", (1, 99)),
        ("E(x)-E(G(E(x)))_L2", (1, 99)),
        ("G(z)-G(E(G(z)))_L1", (1, 99)),
        ("G(z)-G(E(G(z)))_L2", (1, 99)),
        ("x-G(E(x))_L1", (1, 99)),
        ("x-G(E(x))_L2", (1, 99)),
        ("Ge(x)-E(G(x))_L1", (1, 99)),
        ("Ge(x)-E(G(x))_L2", (1, 99)),
        ("x-G(x)_L1", (1, 99)),
        ("x-G(x)_L2", (1, 99)),
    ]
]

#### Copy loss data

In [None]:
%%bash
# Clear out whatever a previous copy left inside ./loss_logs. The glob does
# not match dotfiles, and the directory itself is kept.
rm -rf loss_logs/*
# Recursively copy the experiment's loss_logs directory from Cloud Storage
# into the current directory. stdout and stderr are both discarded, so a
# failed copy prints no diagnostics in this cell.
# NOTE(review): "gs://..." looks like a placeholder bucket path - confirm and
# fill in the real bucket before running.
gsutil -m cp -r gs://.../trained_models/experiment/loss_logs . >/dev/null 2>&1

#### Create loss plots

In [None]:
# Load the copied JSON logs, grouped by growth index.
loss_logs = create_loss_logs_lists("loss_logs")

# Show the very first log entry as a quick look at the entry layout.
print(loss_logs[0][0])

# Draw one figure per configured loss that is present in the logs.
create_all_loss_plots(loss_logs, loss_configs)