In [1]:
import warnings
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tms.utils.utils import load_results
from tms.training.experiments import run_experiments
from tms.llc import estimate_llc, get_llc_data
import os

In [2]:
# Configuration for this analysis: which saved result set to read and where it lives.
# data_path is relative to the notebook's working directory.
version = "1.5.0"
data_path = "../data"
# Load the stored experiment results through the project helper.
# Later cells index this as results[i]['parameters'] and results[i]['logs']['loss'],
# so each entry is expected to expose at least those keys.
results = load_results(data_path, version)

In [3]:
# Display the hyperparameter dict recorded for the first loaded run (sanity check of the load).
results[0]['parameters']

{'m': 6,
 'n': 2,
 'num_samples': 100,
 'batch_size': 1024,
 'num_epochs': 20000,
 'sparsity': 0.0,
 'lr': 0.005,
 'momentum': 0.9,
 'weight_decay': 0.0,
 'init_kgon': 6,
 'no_bias': False,
 'init_zerobias': False,
 'prior_std': 0,
 'seed': 0,
 'log_ivl': [1,
  2,
  3,
  4,
  5,
  6,
  7,
  9,
  11,
  13,
  16,
  20,
  25,
  31,
  38,
  46,
  56,
  69,
  85,
  104,
  127,
  156,
  191,
  234,
  286,
  351,
  429,
  526,
  643,
  788,
  964,
  1180,
  1445,
  1768,
  2165,
  2650,
  3243,
  3970,
  4859,
  5948,
  7280,
  8910,
  10906,
  13349,
  16340,
  20000]}

In [4]:
# Run LLC estimation over the loaded results, restricted to snapshot index -1
# (presumably the last saved snapshot -- confirm against estimate_llc).
# NOTE(review): the output recorded for this cell ends in
# "TypeError: 'NoneType' object is not callable", so llc_estimates was not
# produced in the saved run; the cause lies inside estimate_llc or what it calls
# and needs fixing before the plotting function below can be exercised.
llc_estimates = estimate_llc(results, version, snapshot_indices=[-1])

bias: False, num_features: 6, num_hidden_units: 2


  0%|          | 0/500 [00:00<?, ?it/s]

Running llc estimation for run 0
Running llc estimation for snapshot -1


Sweeping hyperparameters:   0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/500 [00:00<?, ?it/s]




TypeError: 'NoneType' object is not callable

In [1]:
def _mean_llc_by_run(llc_estimates, batch_size, learning_rate, position, min_t_sgld):
    """Return a Series mapping run index -> mean LLC for one hyperparameter/snapshot selection."""
    selected = llc_estimates[
        (llc_estimates['batch_size'] == batch_size) &
        (llc_estimates['lr'] == learning_rate) &
        (llc_estimates['snapshot_index'] == position) &
        (llc_estimates['t_sgld'] > min_t_sgld) &
        (llc_estimates['llc_type'] != "mean")
    ]
    # One pass over the frame instead of re-filtering it once per run.
    return selected.groupby('index')['llc'].mean()


def _scatter_llc_vs_loss(ax, llc_estimates, results, batch_size, learning_rate, position, min_t_sgld):
    """Scatter (mean LLC, loss) per run on `ax`, one series per non-zero sparsity; return True if anything was drawn."""
    mean_llc = _mean_llc_by_run(llc_estimates, batch_size, learning_rate, position, min_t_sgld)

    llc_loss_by_sparsity = defaultdict(list)
    for index in range(len(results)):
        # Runs with no matching estimate rows get NaN, like the mean of an empty selection.
        llc = mean_llc.loc[index] if index in mean_llc.index else np.nan
        loss = results[index]['logs']['loss'].values[position]
        sparsity = results[index]['parameters']['sparsity']
        llc_loss_by_sparsity[sparsity].append((llc, loss))

    plotted = False
    for sparsity, llc_loss in llc_loss_by_sparsity.items():
        if sparsity == 0:
            # Dense (sparsity 0) runs are deliberately left out of the plot.
            continue
        valid = [(llc, loss) for llc, loss in llc_loss if not np.isnan(llc)]
        if not valid:
            continue
        llcs, losses = zip(*valid)
        # Plot only the runs that actually have an LLC estimate.
        ax.scatter(llcs, losses, label=f"Sparsity: {round(sparsity,3)}")
        plotted = True
    return plotted


def compare_dataframes_and_results(df_results_pairs, positions = [9, 18, 27, 36, -1], hyperparam_combos = [(300, 0.001)], min_t_sgld = 150):
    """Plot loss against mean LLC for several (llc_estimates, results) pairs side by side.

    For every (batch_size, learning_rate) combo and every position, one figure is
    created with one subplot per pair. Each subplot scatters, per run, the mean
    LLC of the matching estimate rows against the run's loss at that position,
    grouped by sparsity. The figure is shown and saved to
    '../results/loss_vs_llc_epoch_<step>'.

    Parameters
    ----------
    df_results_pairs : sized iterable of (llc_estimates, results)
        llc_estimates is a DataFrame with the columns 'index', 'batch_size',
        'lr', 'snapshot_index', 't_sgld', 'llc_type' and 'llc'. results is
        indexed by run; each run provides ['parameters']['sparsity'],
        ['logs']['loss'], and the first run provides ['parameters']['log_ivl'].
    positions : list of int
        Used three ways: matched against 'snapshot_index', as an index into each
        run's loss log, and as an index into 'log_ivl' for the figure title and
        file name. The default list is never mutated.
    hyperparam_combos : list of (batch_size, learning_rate)
        Values matched exactly against the 'batch_size' and 'lr' columns.
    min_t_sgld : number, optional
        Only rows with t_sgld strictly greater than this are averaged. Defaults
        to 150, the previously hard-coded threshold (presumably SGLD burn-in --
        confirm).

    Raises
    ------
    ValueError
        If df_results_pairs is empty (there would be no subplot to draw).
    """
    if len(df_results_pairs) == 0:
        raise ValueError("df_results_pairs must contain at least one (llc_estimates, results) pair")

    # NOTE(review): these descriptions are tied to the order of df_results_pairs;
    # confirm they match what the caller passes. Further pairs fall back to their index
    # instead of silently reusing the previous pair's title.
    pair_titles = {
        0: "Initialized at random 4-gon",
        1: "Initialized at optimal parameters for sparse inputs",
    }

    # Silence UserWarnings only for the duration of this call rather than for the whole kernel.
    with warnings.catch_warnings():
        warnings.simplefilter(action='ignore', category=UserWarning)

        for batch_size, learning_rate in hyperparam_combos:
            print(f"Batch size: {batch_size}, Learning rate: {learning_rate}\n")

            for position in positions:
                fig, axes = plt.subplots(1, len(df_results_pairs), figsize=(15*len(df_results_pairs), 10))
                if len(df_results_pairs) == 1:
                    # plt.subplots returns a bare Axes for a single column; normalise to a list.
                    axes = [axes]

                for pair_index, (llc_estimates, results) in enumerate(df_results_pairs):
                    # As before, the step labels of the last pair are the ones used for the figure.
                    steps = results[0]['parameters']['log_ivl']
                    ax = axes[pair_index]

                    plotted = _scatter_llc_vs_loss(
                        ax, llc_estimates, results, batch_size, learning_rate, position, min_t_sgld
                    )

                    title = pair_titles.get(pair_index, pair_index)
                    ax.set_title(f"Pair {title}, Position {position}")
                    ax.set_xlabel("LLC")
                    ax.set_ylabel("Loss")
                    if plotted:
                        # A legend without any labelled artist only produces a warning.
                        ax.legend()

                plt.tight_layout()
                plt.suptitle(f"Loss and LLC After Epoch {steps[position]}", fontsize=16)
                plt.subplots_adjust(top=0.9)
                plt.savefig(f'../results/loss_vs_llc_epoch_{steps[position]}')
                plt.show()
                # Release the figure; one is created per position and combo.
                plt.close(fig)