In [7]:
import warnings
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tms.utils.utils import load_results
from tms.training.experiments import run_experiments
from tms.llc import estimate_llc, get_llc_data
import os

In [3]:
# Configuration for this analysis: experiment version tag and where its runs are stored.
version, data_path = "1.9.0", "../data"

# Load the saved runs for `version`, then estimate LLCs for them.
# NOTE(review): `results` / `llc_estimates` are kept as module-level names because
# other cells presumably pass them to compare_dataframes_and_results — confirm.
results = load_results(data_path, version)
llc_estimates = estimate_llc(results, version)

In [1]:
def _mean_llc_by_run(llc_df, batch_size, learning_rate, position, n_runs):
    """Return the mean LLC for each run index ``0 .. n_runs - 1`` at one snapshot.

    A row of ``llc_df`` qualifies when its ``batch_size``, ``lr`` and
    ``snapshot_index`` columns equal ``batch_size``, ``learning_rate`` and
    ``position``, its ``t_sgld`` is greater than 150 and its ``llc_type`` is not
    ``"mean"``. Qualifying ``llc`` values are averaged per ``index`` column value.
    Runs with no qualifying rows get NaN, the same as ``.mean()`` of an empty
    selection.
    """
    # The mask is built once per snapshot instead of once per run.
    mask = (
        (llc_df['batch_size'] == batch_size)
        & (llc_df['lr'] == learning_rate)
        & (llc_df['snapshot_index'] == position)
        & (llc_df['t_sgld'] > 150)  # NOTE(review): 150 looks like an SGLD burn-in cutoff — confirm
        & (llc_df['llc_type'] != "mean")
    )
    selected = llc_df.loc[mask]
    # Group the llc Series by the 'index' Series (not by name) so a frame whose
    # row index is also named 'index' cannot make the group key ambiguous.
    per_run = selected['llc'].groupby(selected['index']).mean()
    return per_run.reindex(range(n_runs))


def _llc_loss_by_sparsity(llc_df, run_results, batch_size, learning_rate, position):
    """Map each run's sparsity to a list of ``(mean_llc, loss)`` pairs at ``position``.

    NOTE(review): assumes ``run['logs']['loss']`` exposes ``.values`` indexable
    by ``position`` (e.g. a pandas Series) — confirm against load_results.
    """
    mean_llcs = _mean_llc_by_run(llc_df, batch_size, learning_rate, position, len(run_results))
    by_sparsity = defaultdict(list)
    for run_index, run in enumerate(run_results):
        loss = run['logs']['loss'].values[position]
        by_sparsity[run['parameters']['sparsity']].append((mean_llcs.iloc[run_index], loss))
    return by_sparsity


def _scatter_llc_vs_loss(ax, llc_loss_by_sparsity, title, position):
    """Scatter LLC (x) against loss (y) on ``ax``, one series per non-zero sparsity.

    Points with a NaN LLC are dropped, and a sparsity with no remaining points is
    skipped. The title, axis labels and legend are set only when at least one
    series was drawn, matching the original behaviour for empty panels.
    """
    drew_any = False
    for sparsity, llc_loss in llc_loss_by_sparsity.items():
        if sparsity == 0:
            continue
        valid = [(llc, loss) for llc, loss in llc_loss if not np.isnan(llc)]
        if not valid:
            continue
        llcs, losses = zip(*valid)
        ax.scatter(llcs, losses, label=f"Sparsity: {round(sparsity,3)}")
        drew_any = True
    if drew_any:
        ax.set_title(f"Pair {title}, Position {position}")
        ax.set_xlabel("LLC")
        ax.set_ylabel("Loss")
        ax.legend()


def compare_dataframes_and_results(
    df_results_pairs,
    positions=(9, 18, 27, 36, -1),
    hyperparam_combos=((300, 0.001),),
    pair_titles=(
        "Initialized at random 4-gon",
        "Initialized at optimal parameters for sparse inputs",
    ),
    save_dir="../results",
):
    """Plot loss against mean LLC, one figure per (hyperparameter combo, snapshot).

    Each figure has one panel per ``(llc_estimates, results)`` pair. Each
    panel scatters, per non-zero sparsity, every run's mean LLC against its loss
    at that snapshot. Every figure is saved under ``save_dir``, shown, then closed.

    Args:
        df_results_pairs: sequence of ``(llc_estimates, results)`` pairs.
            ``llc_estimates`` is a DataFrame with columns ``index``, ``batch_size``,
            ``lr``, ``snapshot_index``, ``t_sgld``, ``llc_type`` and ``llc``.
            ``results`` is a list of run dicts with ``['parameters']['log_ivl']``,
            ``['parameters']['sparsity']`` and ``['logs']['loss']``.
        positions: snapshot positions to plot. They index ``log_ivl`` and the loss
            log, and are matched against ``snapshot_index``.
        hyperparam_combos: iterable of ``(batch_size, learning_rate)`` pairs.
        pair_titles: panel titles by pair position. Pairs beyond this sequence
            are titled by their position number.
        save_dir: directory for the saved figures. It is created if missing.

    Raises:
        ValueError: if ``df_results_pairs`` is empty.
    """
    df_results_pairs = list(df_results_pairs)
    if not df_results_pairs:
        raise ValueError("df_results_pairs must contain at least one (llc_estimates, results) pair")
    hyperparam_combos = list(hyperparam_combos)
    # With several combos, tag file names so figures do not overwrite each other.
    # A single combo keeps the original file names.
    tag_files_with_hparams = len(hyperparam_combos) > 1
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    n_pairs = len(df_results_pairs)
    # Suppress UserWarnings only while plotting, instead of mutating the global filter.
    with warnings.catch_warnings():
        warnings.simplefilter(action='ignore', category=UserWarning)
        for batch_size, learning_rate in hyperparam_combos:
            print(f"Batch size: {batch_size}, Learning rate: {learning_rate}\n")

            for position in positions:
                fig, axes = plt.subplots(1, n_pairs, figsize=(15 * n_pairs, 10), squeeze=False)

                for pair_index, (pair_llc_df, pair_results) in enumerate(df_results_pairs):
                    # As before, the suptitle/file name uses the last pair's log_ivl.
                    steps = pair_results[0]['parameters']['log_ivl']
                    title = pair_titles[pair_index] if pair_index < len(pair_titles) else str(pair_index)
                    llc_loss_by_sparsity = _llc_loss_by_sparsity(
                        pair_llc_df, pair_results, batch_size, learning_rate, position
                    )
                    _scatter_llc_vs_loss(axes[0, pair_index], llc_loss_by_sparsity, title, position)

                fig.tight_layout()
                fig.suptitle(f"Loss and LLC After Epoch {steps[position]}", fontsize=16)
                fig.subplots_adjust(top=0.9)

                file_name = f'loss_vs_llc_epoch_{steps[position]}'
                if tag_files_with_hparams:
                    # '.' -> 'p' so matplotlib does not read e.g. ".001" as a file extension.
                    file_name += f'_bs{batch_size}_lr{learning_rate}'.replace('.', 'p')
                fig.savefig(os.path.join(save_dir, file_name))
                plt.show()
                plt.close(fig)  # free the figure; many are created in this loop