<span style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">An Exception was encountered at '<a href="#papermill-error-cell">In [10]</a>'.</span>

### Installations

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 

import torch
print("Available GPUs:", torch.cuda.device_count())
print("Current device:", torch.cuda.current_device())
print("Device name:", torch.cuda.get_device_name(0))

Available GPUs: 1
Current device: 0
Device name: NVIDIA GeForce RTX 3090


In [2]:
!nvidia-smi -L
!pip install -q jax flax optax ml_collections 

GPU 0: NVIDIA GeForce RTX 3090 (UUID: GPU-261c922f-1397-45bd-524f-13d036befda3)
GPU 1: NVIDIA GeForce RTX 3090 (UUID: GPU-163ab9b1-70e4-90b0-565c-a56335ad6f19)


### Optional: Load Google drive and change directory to neurips_2023_demo

In [3]:
import sys

if "google.colab" in sys.modules:
    print("Running on Google Colab")

    from google.colab import drive
    drive.mount('/content/drive')
    %cd /content/drive/MyDrive/neurips_2023_15410/

### Load libraries

Requirements: data_utils_pytorch, model_utils, train_mse_utils / train_xent_utils in the current directory

In [4]:
loss = 'mse'
# custom modules
import data_utils_pytorch
import model_utils as models

if loss == 'mse':
    import train_mse_utils as train_utils
elif loss == 'xent':
    import train_xent_utils as train_utils
else:
    print('Unsupported loss function')
    raise


# in use imports
import jax
from jax import numpy as jnp
import optax
from ml_collections import config_dict

#usual imports
import numpy as np
import pandas as pd
from sys import argv
import gc
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams.update({'font.size':15})
import seaborn as sns

### Helper functions

In [5]:
def poly(coeffs, x):
    "Given polynomial coefficients coeffs, evaluates the polynomal f(x)"
    degree = len(coeffs)-1
    output = 0
    for i in range(len(coeffs)):
        output += coeffs[i] * x**(degree-i)
    return output

pandas helper functions

In [6]:
def get_rows_where_col_equals(df, col, value):
	return df.loc[df[col] == value].copy()
 
def get_rows_where_col_in(df, col, values):
	return df.loc[df[col].isin(values)].copy()

def get_rows_where_col_greater(df, col, value):
	return df.loc[df[col] > value].copy()

def get_rows_where_col_less(df, col, value):
	return df.loc[df[col] < value].copy()

### Model definition

In this demo, we will use a 8 layer ReLU FCN to obtain the phase diagram of early training. The model can be changed to Myrtle CNNs, ResNets or linear FCNs using the model_util.py

In [7]:
def create_train_state(config, batches):
    """
    Description: Creates a Flax train state with learning rate eta = c / sharpness init
    Input: 
        1. config: ml_collections dict which contains the model and optimizer hyperparameters
        2. batches: batches used for sharpness estimation
    Output:
        1. state: Flax state with learning rate eta = c / sharpness init
        2. sharpness_init: sharpness at initialization
    """

    #create model
    model = models.Myrtle(num_filters=config.width,
    num_layers=config.depth,
    num_classes=config.num_classes,
    use_bias=config.use_bias,
    varw=config.varw
    )

    #initialize the parameters
    example = jax.random.normal(config.init_rng, shape = config.in_dim)
    init_params = model.init(config.init_rng, example)['params'] #initialize parameters

    #create a dummy state and optimizer for sharpness calculation
    _opt = optax.sgd(learning_rate = 0.1, momentum = config.momentum)
    _state = train_utils.TrainState.create(apply_fn = model.apply, params = init_params, opt = _opt)

    #calculate hessian and learning rate
    sharpness_init = train_utils.estimate_hessian(_state, batches, num_batches = config.measure_batches, power_iterations = config.power_iterations)
    lr_rate = config.lr_const/sharpness_init

    # create a new state with learning rate lr_rate; 
    # updating learning rate does not work since the recemt flax update
    opt = optax.sgd(learning_rate = lr_rate, momentum = config.momentum)
    state = train_utils.TrainState.create(apply_fn = model.apply, params = init_params, opt = opt)
    return state, sharpness_init

### Training loop

In [8]:
def train_and_evaluate(config, train_ds):
    """
    Description: Creates a training state and trains for 10 steps
    Input: 
        1. config: ml_collections dictionary containing all the hyperparameters
        2. train_ds: tuple (x_train, y_train) of the training data
    Output: 
        1. Divergence: bool; Flag for training divergence
        2. train_results: np array containing training loss, sharpness and accuracy trajectories
    """

    train_results = list()

    #create train and validation batches using rng
    rng = config.sgd_rng
    # create train and test batches for measurements: measure batches are called train_batches and val_batches; training batches are called batches
    train_batches = data_utils_pytorch.data_stream(rng, train_ds, config.measure_examples)

    #create model
    state, sharpness_init = create_train_state(config, train_batches)
    init_params = state.params

    #print(f'Top Hessian init: {sharpness_init:0.4f}, Learning rate: {config.lr_const/sharpness_init:0.4f}')

    #measure metrics at initialization
    train_loss_init, train_acc_init = train_utils.measure_state(state, train_batches, config.num_train, config.measure_examples)

    #store results at init
    result_init = jnp.asarray([0, 0, train_loss_init, train_loss_init, train_acc_init, sharpness_init, sharpness_init])
    train_results.append(result_init)
    print(f'{0}, {0}, {train_loss_init:0.4f}, {train_acc_init:0.4f}, {sharpness_init:0.4f}, {sharpness_init:0.4f}')

    divergence = False

    for epoch in range(config.num_epochs):

        rng, _ = jax.random.split(rng)
        batches = data_utils_pytorch.data_stream(rng, train_ds, config.batch_size)

        for batch_ix in range(config.num_steps):
            #get the next batch and calculate the step
            batch = next(batches)
            x, y = batch
            step = config.num_batches*epoch + batch_ix
            
            #compute gradient and other metrics and update the model
            state, loss_batch = train_utils.train_batch(state, batch)

            # hard cut-off on loss divergence; catapult loss is at most of O(width); max width in this experiment is 2048
            if loss_batch > 10**5: 
                divergence = True
                break

            # evaluate training loss, accuracy and sharpness at step t            
            train_loss_step, train_acc_step = train_utils.measure_state(state, train_batches, config.num_train, config.measure_examples)
            sharpness_step = train_utils.estimate_hessian(state, batches, num_batches = config.measure_batches, power_iterations = config.power_iterations)
    
            #store results
            result_step = jnp.asarray([step+1, epoch+1, train_loss_step, train_loss_init, train_acc_step, sharpness_step, sharpness_init])
            train_results.append(result_step)
            print(f'{step+1}, {epoch+1}, {train_loss_step:0.4f}, {train_acc_step:0.4f}, {sharpness_step:0.4f}, {sharpness_init:0.4f}')
    
    del state
    train_results = jnp.asarray(train_results)
    train_results = jax.device_get(train_results)
    return divergence, train_results


### Hyperparameters

In [None]:
dataset = 'cifar10'
num_classes = 10

train_ds, test_ds, info = data_utils_pytorch.load_data_pytorch(dataset, num_classes)

#Hyperparameters
config = config_dict.ConfigDict()
config.num_train, config.num_test = info.num_train, info.num_test

config.act = 'relu'
config.model = f'fc_{config.act}'
config.use_bias = False 
config.varw = 2.0 # this is the variance of all but last layer
config.varwL = 1.0 # last layer variance

widths = [128, 256]
config.depth = 10
config.in_dim = info.in_dim
config.num_classes = info.num_classes 

# batch size
config.batch_size = 256
# batch size for hessian measurement
config.measure_examples = 512
#number of batches for hessian measurement
config.measure_batches = 1

config.num_batches = data_utils_pytorch.estimate_num_batches(config.num_train, config.batch_size)

#number of power iterations
config.power_iterations = 20

#optimizer related hyperparameters
# Sample learning rates in powers of 2 starting with 2**0.0
lr_exp_start = jax.device_put(0.0)
lr_step = 0.2 # step size for increasing the learning rate
config.momentum = jax.device_put(0.0) # momentum for sgd
config.num_steps = 10 # train for ten steps
config.num_epochs = 1 # one epoch only

# averages
init_averages = 1 #number of initialization averages
sgd_runs = 1 # number of sgd runs for each initialization

Files already downloaded and verified


Files already downloaded and verified




### Early time experiment:

* We will train a 8 layer ReLU FCN for 10 steps using MSE loss using SGD with learning rate $\eta = c / \lambda_0^H$ and batch size $B = 512$. Here, $\lambda_0^H$ is the sharpness at initialization. The same experiment can be performed with other architectures by replacing the model definition in the create_train_state function

* We will sample learning rate constants $c$ in powers of 2, i.e., $c = 2^x$ with $x$ incremented in steps of $0.1$. 

* We will start with $x = x_{min} = 0.0$, wich corresponds to $\eta = 1 / \lambda_0^H$, as we are interested in the catapult dynamics.

* To reduce computational time, we will run the experiment for only one random initialization.


Note: This experiment takes about 45 mins to run on a V100 with 16 GB RAM. 


<span id="papermill-error-cell" style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">Execution using papermill encountered an exception here and stopped:</span>

In [10]:
# for storing trajectories
dfs = list()

#run for all widths
for width in widths:
    config.width = width
    #run for init_averages different initializations
    for iteration in range(1, init_averages+1):
        config.init_rng = jax.random.PRNGKey(iteration)
        #run for sgd_run different mini batch sequences
        for run in range(1, sgd_runs + 1):
            config.sgd_rng = jax.random.PRNGKey(run)
            lr_exp = lr_exp_start
            divergence = False

            while not divergence:
                config.lr_exp = lr_exp
                config.lr_const = 2**lr_exp
                print(f'w: {config.width}, d: {config.depth}, I: {iteration}, J: {run}, x: {lr_exp:0.1f}, B: {config.batch_size}, t: {config.num_steps}')

                divergence, train_results = train_and_evaluate(config, train_ds)

                if not divergence:
                    # append training results
                    df = pd.DataFrame(train_results, columns = ['step', 'epoch', 'train_loss_step', 'train_loss_init', 'train_accuracy', 'sharpness_step', 'sharpness_init'], dtype = float)
                    df['lr_exp'] = config.lr_exp; df['lr_const'] = config.lr_const; df['batch_size'] = config.batch_size; df['num_steps'] = config.num_epochs; df['I'] = iteration; df['J'] = run; df['width'] = config.width; df['depth'] = config.depth
                    dfs.append(df)
                    del df

                else:
                    print('Divergence')

                del train_results

                lr_exp += lr_step

                collected = gc.collect()
                #print("Garbage collector: collected %d objects." % (collected))

# concatenate results
dfs = pd.concat(dfs, axis = 0, ignore_index = True)


w: 128, d: 10, I: 1, J: 1, x: 0.0, B: 256, t: 10


2024-11-05 04:56:55.521065: W external/xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 4.39GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


0, 0, 0.4071, 0.1000, 160.2212, 160.2212


1, 1, 0.1420, 0.0996, 50.2012, 160.2212


2, 1, 0.1098, 0.0997, 35.0602, 160.2212


3, 1, 0.0930, 0.0994, 28.2017, 160.2212


4, 1, 0.0825, 0.1000, 24.0795, 160.2212


5, 1, 0.0756, 0.1004, 18.6472, 160.2212


6, 1, 0.0706, 0.1004, 16.8140, 160.2212


7, 1, 0.0666, 0.1012, 16.0409, 160.2212


8, 1, 0.0634, 0.1007, 14.9781, 160.2212


9, 1, 0.0608, 0.1009, 14.7899, 160.2212


10, 1, 0.0588, 0.1009, 14.7583, 160.2212


w: 128, d: 10, I: 1, J: 1, x: 0.2, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.1253, 0.0995, 41.8742, 160.2211


2, 1, 0.0997, 0.0996, 30.0762, 160.2211


3, 1, 0.0858, 0.0999, 24.5538, 160.2211


4, 1, 0.0769, 0.0998, 18.7700, 160.2211


5, 1, 0.0709, 0.1003, 16.6683, 160.2211


6, 1, 0.0666, 0.1017, 15.4974, 160.2211


7, 1, 0.0631, 0.1009, 15.1189, 160.2211


8, 1, 0.0603, 0.1011, 14.1143, 160.2211


9, 1, 0.0581, 0.1009, 14.0815, 160.2211


10, 1, 0.0563, 0.1006, 14.1330, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 0.4, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.1103, 0.0992, 34.2364, 160.2211


2, 1, 0.0907, 0.0991, 25.5066, 160.2211


3, 1, 0.0794, 0.0997, 20.9239, 160.2211


4, 1, 0.0719, 0.0996, 16.3771, 160.2211


5, 1, 0.0669, 0.1018, 15.2322, 160.2211


6, 1, 0.0632, 0.1014, 14.4123, 160.2211


7, 1, 0.0601, 0.1008, 14.2166, 160.2211


8, 1, 0.0577, 0.1004, 13.3309, 160.2211


9, 1, 0.0558, 0.1009, 13.4058, 160.2211


10, 1, 0.0542, 0.1006, 13.5155, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 0.6, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.0974, 0.0996, 27.5449, 160.2211


2, 1, 0.0828, 0.0990, 21.3164, 160.2211


3, 1, 0.0738, 0.0998, 15.7423, 160.2211


4, 1, 0.0676, 0.1001, 14.9457, 160.2211


5, 1, 0.0633, 0.1016, 14.0046, 160.2211


6, 1, 0.0602, 0.1009, 13.3518, 160.2211


7, 1, 0.0576, 0.1008, 13.2508, 160.2211


8, 1, 0.0555, 0.1008, 12.5362, 160.2211


9, 1, 0.0538, 0.1004, 12.7956, 160.2211


10, 1, 0.0525, 0.1009, 13.0310, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 0.8, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.0869, 0.0993, 21.9658, 160.2211


2, 1, 0.0760, 0.0988, 15.8357, 160.2211


3, 1, 0.0689, 0.1002, 14.1270, 160.2211


4, 1, 0.0639, 0.1015, 13.6485, 160.2211


5, 1, 0.0603, 0.1010, 12.8012, 160.2211


6, 1, 0.0577, 0.1009, 12.2562, 160.2211


7, 1, 0.0554, 0.1003, 12.2353, 160.2211


8, 1, 0.0537, 0.1004, 11.6861, 160.2211


9, 1, 0.0522, 0.1011, 12.3766, 160.2211


10, 1, 0.0511, 0.1016, 12.6569, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 1.0, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2212, 160.2212


1, 1, 0.0791, 0.0989, 18.9812, 160.2212


2, 1, 0.0704, 0.0991, 14.7621, 160.2212


3, 1, 0.0647, 0.1006, 12.8429, 160.2212


4, 1, 0.0606, 0.1010, 12.3511, 160.2212


5, 1, 0.0577, 0.1007, 11.5469, 160.2212


6, 1, 0.0555, 0.1005, 11.0778, 160.2212


7, 1, 0.0536, 0.1006, 11.1396, 160.2212


8, 1, 0.0521, 0.1001, 10.7226, 160.2212


9, 1, 0.0509, 0.1011, 11.8203, 160.2212


10, 1, 0.0500, 0.1017, 12.1200, 160.2212


w: 128, d: 10, I: 1, J: 1, x: 1.2, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.0736, 0.1001, 17.8167, 160.2211


2, 1, 0.0661, 0.0999, 13.6054, 160.2211


3, 1, 0.0613, 0.1008, 11.6075, 160.2211


4, 1, 0.0579, 0.1007, 11.0689, 160.2211


5, 1, 0.0555, 0.1017, 10.2674, 160.2211


6, 1, 0.0536, 0.1009, 9.8695, 160.2211


7, 1, 0.0521, 0.1006, 10.0199, 160.2211


8, 1, 0.0508, 0.1003, 9.7126, 160.2211


9, 1, 0.0498, 0.1013, 11.0834, 160.2211


10, 1, 0.0490, 0.1014, 11.3027, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 1.4, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2212, 160.2212


1, 1, 0.0703, 0.1004, 17.0811, 160.2212


2, 1, 0.0630, 0.1026, 12.5102, 160.2212


3, 1, 0.0588, 0.1021, 10.4564, 160.2212


4, 1, 0.0558, 0.1015, 9.8236, 160.2212


5, 1, 0.0538, 0.1018, 9.0415, 160.2212


6, 1, 0.0522, 0.1014, 8.7048, 160.2212


7, 1, 0.0509, 0.1014, 8.8622, 160.2212


8, 1, 0.0498, 0.1009, 8.7485, 160.2212


9, 1, 0.0490, 0.1012, 10.0091, 160.2212


10, 1, 0.0483, 0.1010, 10.2777, 160.2212


w: 128, d: 10, I: 1, J: 1, x: 1.6, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.0687, 0.0999, 16.5644, 160.2211


2, 1, 0.0609, 0.1011, 11.4919, 160.2211


3, 1, 0.0569, 0.1021, 9.3692, 160.2211


4, 1, 0.0543, 0.1019, 8.6789, 160.2211


5, 1, 0.0525, 0.1016, 7.8843, 160.2211


6, 1, 0.0511, 0.1020, 7.5456, 160.2211


7, 1, 0.0500, 0.1016, 7.6899, 160.2211


8, 1, 0.0491, 0.1010, 7.7044, 160.2211


9, 1, 0.0483, 0.1017, 8.7397, 160.2211


10, 1, 0.0477, 0.1014, 9.0127, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 1.8, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.0684, 0.1001, 16.1955, 160.2211


2, 1, 0.0596, 0.1001, 10.5599, 160.2211


3, 1, 0.0557, 0.1014, 8.3933, 160.2211


4, 1, 0.0533, 0.1001, 7.6231, 160.2211


5, 1, 0.0516, 0.1010, 6.7770, 160.2211


6, 1, 0.0504, 0.1018, 6.4765, 160.2211


7, 1, 0.0494, 0.1014, 6.7013, 160.2211


8, 1, 0.0486, 0.1013, 6.5867, 160.2211


9, 1, 0.0479, 0.1022, 7.3404, 160.2211


10, 1, 0.0474, 0.1007, 7.6089, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 2.0, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.0688, 0.1001, 15.9391, 160.2211


2, 1, 0.0589, 0.1001, 9.6606, 160.2211


3, 1, 0.0550, 0.0996, 7.5106, 160.2211


4, 1, 0.0526, 0.1002, 6.6742, 160.2211


5, 1, 0.0511, 0.0999, 5.9144, 160.2211


6, 1, 0.0499, 0.1010, 5.6182, 160.2211


7, 1, 0.0490, 0.1004, 5.7612, 160.2211


8, 1, 0.0483, 0.1009, 5.5515, 160.2211


9, 1, 0.0476, 0.1018, 6.0857, 160.2211


10, 1, 0.0471, 0.1002, 6.3510, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 2.2, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.0697, 0.1002, 15.7443, 160.2211


2, 1, 0.0583, 0.1005, 8.7884, 160.2211


3, 1, 0.0545, 0.1001, 6.6910, 160.2211


4, 1, 0.0522, 0.0995, 5.7607, 160.2211


5, 1, 0.0508, 0.0992, 5.2191, 160.2211


6, 1, 0.0497, 0.0998, 4.8285, 160.2211


7, 1, 0.0488, 0.0992, 4.9026, 160.2211


8, 1, 0.0481, 0.1004, 4.5988, 160.2211


9, 1, 0.0475, 0.1001, 4.9636, 160.2211


10, 1, 0.0471, 0.0991, 5.1955, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 2.4, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2212, 160.2212


1, 1, 0.0707, 0.1001, 15.5726, 160.2212


2, 1, 0.0579, 0.1003, 7.9839, 160.2212


3, 1, 0.0540, 0.1002, 5.9388, 160.2212


4, 1, 0.0518, 0.1001, 4.9194, 160.2212


5, 1, 0.0505, 0.0993, 4.6530, 160.2212


6, 1, 0.0495, 0.0996, 4.2213, 160.2212


7, 1, 0.0487, 0.0985, 4.2390, 160.2212


8, 1, 0.0480, 0.0997, 3.8795, 160.2212


9, 1, 0.0474, 0.1001, 4.1261, 160.2212


10, 1, 0.0470, 0.0995, 4.2973, 160.2212


w: 128, d: 10, I: 1, J: 1, x: 2.6, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.0717, 0.1001, 15.4796, 160.2211


2, 1, 0.0573, 0.1003, 7.1997, 160.2211


3, 1, 0.0535, 0.1003, 5.2665, 160.2211


4, 1, 0.0514, 0.1004, 4.4176, 160.2211


5, 1, 0.0501, 0.1009, 4.1940, 160.2211


6, 1, 0.0492, 0.1004, 3.7556, 160.2211


7, 1, 0.0485, 0.1002, 3.7371, 160.2211


8, 1, 0.0478, 0.1002, 3.3660, 160.2211


9, 1, 0.0473, 0.1003, 3.5688, 160.2211


10, 1, 0.0469, 0.1001, 3.6784, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 2.8, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.0727, 0.1001, 15.6742, 160.2211


2, 1, 0.0565, 0.1004, 6.4178, 160.2211


3, 1, 0.0528, 0.1006, 4.7735, 160.2211


4, 1, 0.0509, 0.1002, 4.0121, 160.2211


5, 1, 0.0497, 0.1006, 3.7971, 160.2211


6, 1, 0.0488, 0.1009, 3.3961, 160.2211


7, 1, 0.0481, 0.1003, 3.3752, 160.2211


8, 1, 0.0475, 0.0990, 3.0194, 160.2211


9, 1, 0.0471, 0.0998, 3.1970, 160.2211


10, 1, 0.0467, 0.1008, 3.2587, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 3.0, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2212, 160.2212


1, 1, 0.0741, 0.1001, 16.2621, 160.2212


2, 1, 0.0553, 0.1001, 5.6400, 160.2212


3, 1, 0.0520, 0.1006, 4.3041, 160.2212


4, 1, 0.0503, 0.1006, 3.6190, 160.2212


5, 1, 0.0492, 0.1009, 3.4240, 160.2212


6, 1, 0.0484, 0.0999, 3.0905, 160.2212


7, 1, 0.0478, 0.1000, 3.0701, 160.2212


8, 1, 0.0472, 0.0996, 2.7426, 160.2212


9, 1, 0.0468, 0.1001, 2.9360, 160.2212


10, 1, 0.0465, 0.0997, 2.9535, 160.2212


w: 128, d: 10, I: 1, J: 1, x: 3.2, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.0760, 0.1000, 17.3443, 160.2211


2, 1, 0.0538, 0.1006, 4.8386, 160.2211


3, 1, 0.0511, 0.1006, 3.8295, 160.2211


4, 1, 0.0496, 0.1010, 3.2421, 160.2211


5, 1, 0.0486, 0.0995, 3.1020, 160.2211


6, 1, 0.0479, 0.0994, 2.8476, 160.2211


7, 1, 0.0474, 0.0989, 2.8335, 160.2211


8, 1, 0.0469, 0.0991, 2.5548, 160.2211


9, 1, 0.0465, 0.1001, 2.7617, 160.2211


10, 1, 0.0463, 0.0994, 2.7518, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 3.4, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2212, 160.2212


1, 1, 0.0786, 0.1000, 18.9680, 160.2212


2, 1, 0.0522, 0.1008, 4.0257, 160.2212


3, 1, 0.0501, 0.0993, 3.3704, 160.2212


4, 1, 0.0489, 0.0996, 2.9061, 160.2212


5, 1, 0.0481, 0.0997, 2.8317, 160.2212


6, 1, 0.0475, 0.0991, 2.6684, 160.2212


7, 1, 0.0471, 0.0999, 2.6627, 160.2212


8, 1, 0.0467, 0.0992, 2.4290, 160.2212


9, 1, 0.0463, 0.0998, 2.6368, 160.2212


10, 1, 0.0461, 0.0992, 2.6333, 160.2212


w: 128, d: 10, I: 1, J: 1, x: 3.6, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2212, 160.2212


1, 1, 0.0823, 0.1000, 21.2198, 160.2212


2, 1, 0.0505, 0.0988, 3.2692, 160.2212


3, 1, 0.0491, 0.0994, 2.9085, 160.2212


4, 1, 0.0482, 0.0999, 2.5696, 160.2212


5, 1, 0.0476, 0.0994, 2.5854, 160.2212


6, 1, 0.0472, 0.0990, 2.5098, 160.2212


7, 1, 0.0468, 0.0998, 2.5144, 160.2212


8, 1, 0.0464, 0.0999, 2.3464, 160.2212


9, 1, 0.0461, 0.0998, 2.5631, 160.2212


10, 1, 0.0459, 0.0998, 2.5678, 160.2212


w: 128, d: 10, I: 1, J: 1, x: 3.8, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.0876, 0.1000, 24.2791, 160.2211


2, 1, 0.0491, 0.0982, 2.5849, 160.2211


3, 1, 0.0482, 0.0987, 2.4507, 160.2211


4, 1, 0.0476, 0.0997, 2.2180, 160.2211


5, 1, 0.0472, 0.0993, 2.3202, 160.2211


6, 1, 0.0468, 0.0988, 2.3364, 160.2211


7, 1, 0.0465, 0.0993, 2.3580, 160.2211


8, 1, 0.0462, 0.0993, 2.2939, 160.2211


9, 1, 0.0460, 0.0984, 2.4902, 160.2211


10, 1, 0.0458, 0.0996, 2.5124, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 4.0, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.0953, 0.1000, 28.3368, 160.2211


2, 1, 0.0481, 0.0997, 1.9435, 160.2211


3, 1, 0.0475, 0.1001, 1.9633, 160.2211


4, 1, 0.0470, 0.0992, 1.8456, 160.2211


5, 1, 0.0467, 0.0998, 2.0140, 160.2211


6, 1, 0.0464, 0.0991, 2.1044, 160.2211


7, 1, 0.0462, 0.0994, 2.1605, 160.2211


8, 1, 0.0460, 0.0990, 2.2014, 160.2211


9, 1, 0.0458, 0.1003, 2.3761, 160.2211


10, 1, 0.0457, 0.0992, 2.4166, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 4.2, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.1069, 0.1000, 33.8955, 160.2211


2, 1, 0.0477, 0.0998, 1.5050, 160.2211


3, 1, 0.0471, 0.1007, 1.5238, 160.2211


4, 1, 0.0468, 0.0998, 1.4629, 160.2211


5, 1, 0.0465, 0.1003, 1.6823, 160.2211


6, 1, 0.0462, 0.1000, 1.8212, 160.2211


7, 1, 0.0460, 0.0998, 1.9375, 160.2211


8, 1, 0.0458, 0.1001, 2.1173, 160.2211


9, 1, 0.0456, 0.0995, 2.2501, 160.2211


10, 1, 0.0455, 0.0989, 2.2827, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 4.4, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2212, 160.2212


1, 1, 0.1257, 0.1000, 42.2624, 160.2212


2, 1, 0.0476, 0.0994, 1.3098, 160.2212


3, 1, 0.0471, 0.1007, 1.3427, 160.2212


4, 1, 0.0467, 0.1000, 1.3338, 160.2212


5, 1, 0.0463, 0.1009, 1.7000, 160.2212


6, 1, 0.0461, 0.1001, 1.8396, 160.2212


7, 1, 0.0458, 0.1006, 2.0855, 160.2212


8, 1, 0.0457, 0.1012, 2.3829, 160.2212


9, 1, 0.0455, 0.1012, 2.5327, 160.2212


10, 1, 0.0454, 0.1003, 2.5209, 160.2212


w: 128, d: 10, I: 1, J: 1, x: 4.6, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.1580, 0.1000, 55.8445, 160.2211


2, 1, 0.0477, 0.1001, 1.2923, 160.2211


3, 1, 0.0472, 0.1015, 1.2960, 160.2211


4, 1, 0.0467, 0.1010, 1.3131, 160.2211


5, 1, 0.0463, 0.0999, 1.8048, 160.2211


6, 1, 0.0460, 0.1007, 1.9384, 160.2211


7, 1, 0.0457, 0.0993, 2.2511, 160.2211


8, 1, 0.0456, 0.1000, 2.5292, 160.2211


9, 1, 0.0454, 0.0994, 2.6621, 160.2211


10, 1, 0.0453, 0.1006, 2.6405, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 4.8, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.2163, 0.1000, 78.9179, 160.2211


2, 1, 0.0481, 0.0999, 1.3591, 160.2211


3, 1, 0.0474, 0.1004, 1.2405, 160.2211


4, 1, 0.0468, 0.1001, 1.0481, 160.2211


5, 1, 0.0463, 0.1003, 1.6738, 160.2211


6, 1, 0.0459, 0.0994, 1.8411, 160.2211


7, 1, 0.0456, 0.1007, 2.1564, 160.2211


8, 1, 0.0455, 0.0999, 2.5429, 160.2211


9, 1, 0.0453, 0.1005, 2.5911, 160.2211


10, 1, 0.0452, 0.1001, 2.5339, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 5.0, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2211, 160.2211


1, 1, 0.3253, 0.1000, 119.3887, 160.2211


2, 1, 0.0495, 0.0999, 1.8736, 160.2211


3, 1, 0.0484, 0.1007, 1.4269, 160.2211


4, 1, 0.0476, 0.0999, 1.3402, 160.2211


5, 1, 0.0470, 0.1009, 1.0635, 160.2211


6, 1, 0.0465, 0.0992, 1.1830, 160.2211


7, 1, 0.0460, 0.1005, 1.3039, 160.2211


8, 1, 0.0457, 0.1001, 1.8627, 160.2211


9, 1, 0.0455, 0.0996, 1.9530, 160.2211


10, 1, 0.0453, 0.1006, 2.0212, 160.2211


w: 128, d: 10, I: 1, J: 1, x: 5.2, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2212, 160.2212


1, 1, 0.5388, 0.1000, 192.7224, 160.2212


2, 1, 0.0544, 0.0990, 5.0748, 160.2212


3, 1, 0.0497, 0.1001, 1.3039, 160.2212


4, 1, 0.0488, 0.0992, 1.1645, 160.2212


5, 1, 0.0481, 0.1009, 0.7769, 160.2212


6, 1, 0.0476, 0.0997, 0.7733, 160.2212


7, 1, 0.0470, 0.0997, 1.3480, 160.2212


8, 1, 0.0466, 0.1002, 1.3081, 160.2212


9, 1, 0.0462, 0.1000, 1.8658, 160.2212


10, 1, 0.0459, 0.0999, 1.8395, 160.2212


w: 128, d: 10, I: 1, J: 1, x: 5.4, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2212, 160.2212


1, 1, 0.9792, 0.1000, 330.3762, 160.2212


2, 1, 0.0891, 0.0999, 29.8593, 160.2212


3, 1, 0.0510, 0.1006, 1.0395, 160.2212


4, 1, 0.0502, 0.1012, 0.7769, 160.2212


5, 1, 0.0498, 0.1002, 0.7380, 160.2212


6, 1, 0.0495, 0.0997, 0.7083, 160.2212


7, 1, 0.0492, 0.0990, 0.6777, 160.2212


8, 1, 0.0490, 0.1002, 0.6995, 160.2212


9, 1, 0.0487, 0.1006, 0.7121, 160.2212


10, 1, 0.0484, 0.0996, 0.6349, 160.2212


w: 128, d: 10, I: 1, J: 1, x: 5.6, B: 256, t: 10


0, 0, 0.4071, 0.1000, 160.2212, 160.2212


1, 1, 1.9380, 0.1000, 602.2175, 160.2212


2, 1, 1.2359, 0.0999, 586.8378, 160.2212


3, 1, 0.5964, 0.1001, 670.6859, 160.2212


4, 1, 39.2308, 0.1004, 106962.4375, 160.2212


5, 1, 156136283043892612929341620224.0000, 0.0997, nan, 160.2212
Divergence


w: 256, d: 10, I: 1, J: 1, x: 0.0, B: 256, t: 10


2024-11-05 05:28:05.205077: W external/xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.21GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


2024-11-05 05:28:08.361694: W external/xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.21GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


2024-11-05 05:28:15.166621: W external/xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 8.77GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


2024-11-05 05:28:16.490636: W external/xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.41GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


2024-11-05 05:28:29.026410: W external/xla/xla/tsl/framework/bfc_allocator.cc:497] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.57GiB (rounded to 7058878976)requested by op 
2024-11-05 05:28:29.026712: W external/xla/xla/tsl/framework/bfc_allocator.cc:508] *___________________________________________________________________________________________________
E1105 05:28:29.026752  824005 pjrt_stream_executor_client.cc:3085] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 7058878944 bytes.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 7058878944 bytes.

Post processing

In [None]:
# early training dynamics
dfs = dfs.astype(float)
dfs['norm_sharp'] = dfs['sharpness_step'] / dfs['sharpness_init']
dfs['norm_loss'] = dfs['train_loss_step'] / dfs['train_loss_init']

### Plot training trajectories

In [None]:
lr_plot = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
dfs['lr_exp'] = dfs['lr_exp'].round(1)
dfs['lr_const'] = dfs['lr_const'].round(1)

df_plot = get_rows_where_col_in(dfs, 'lr_exp', lr_plot)
print(df_plot['lr_exp'].unique())

for i, width in enumerate(widths):
    
    df_width = get_rows_where_col_equals(df_plot, 'width', width)

    fig, axes = plt.subplots(1, 2, figsize = (12, 4))
    ax = axes[0]
    ax = sns.lineplot(x = 'step', y = 'train_loss_step', data = df_width, hue = 'lr_const', palette = 'crest', legend = 'full', ax = ax)
    ax.set_xlabel('step')
    ax.set_ylabel(r'Training loss')
    ax.set_title(f'width: {width}')
    ax.get_legend().remove()

    ax = axes[1]
    ax = sns.lineplot(x = 'step', y = 'norm_sharp', data = df_width, hue = 'lr_const', palette = 'crest', legend = 'full', ax = ax)
    ax.set_xlabel('step')
    ax.set_ylabel(r'$\frac{\lambda_t^H}{\lambda_0^H}$')
    ax.set_title(f'width: {width}')
    ax.legend(title = r'$c$', loc='center left', bbox_to_anchor=(1, 0.5))


### Estimate critical constants

In [None]:
# note: For zero initialized networks, sharpness does not decrease appreciably during early training. 
# To avoid random fluctuations in sharpness to be considered as sharpness catapult, we compare the normalized sharpness with 1+epsilon to define sharpness catapult.

epsilon = 0.05

#estimate c_loss
df_filtered = dfs[dfs['norm_loss'] > 1.0]
df_grouped = df_filtered.groupby('width')['lr_const'].min()
df_closs = df_grouped.reset_index().astype(float)
df_closs['1w'] = 1 / df_closs['width']

#estimate c_sharp
df_filtered = dfs[dfs['norm_sharp'] > 1 + epsilon]
df_grouped = df_filtered.groupby('width')['lr_const'].min()
df_csharp = df_grouped.reset_index().astype(float)
df_csharp['1w'] = 1 / df_csharp['width']

#estimate c_max
df_grouped = dfs.groupby('width')['lr_const'].max()
df_cmax = df_grouped.reset_index().astype(float)
df_cmax['1w'] = 1 / df_cmax['width']



### Plot the phase diagram

In [None]:
order = 1 #order of curve to fit
# for this phase diagram order 1 works better

fig, ax = plt.subplots(1, 1, figsize = (8, 6))
colors = sns.color_palette('tab10', 3)

ax.scatter('lr_const', '1w', data = df_closs, marker = 's', color = 'black', label = r'$\langle c_{loss} \rangle$', facecolors = 'none', zorder = 2, s = 64)
ax.scatter('lr_const', '1w', data = df_csharp, marker = '^', color = 'black', label = r'$\langle c_{sharp} \rangle$', facecolors = 'none', zorder = 2, s = 64)
ax.scatter('lr_const', '1w', data = df_cmax, marker = 'D', color = 'black', label = r'$\langle c_{max} \rangle$', facecolors = 'none', zorder = 2, s = 64)

# fit curves
y = df_closs['1w'].values

ymin = np.min(y)
ymax = np.max(y)

y_new = np.array([2**i for i in np.arange(np.log2(ymin) - 0.1, np.log2(ymax) + 0.1, 0.05)])

x = df_closs['lr_const'].values
coeffs = np.polyfit(y, x, order)
smooth_loss = poly(coeffs, y_new)
ax.plot(smooth_loss, y_new, '-', color = 'black')

x = df_csharp['lr_const'].values
coeffs = np.polyfit(y, x, order)
smooth_sharp = poly(coeffs, y_new)
ax.plot(smooth_sharp, y_new, '-', color = 'black')


x = df_cmax['lr_const'].values
coeffs = np.polyfit(y, x, order)
smooth_max = poly(coeffs, y_new)
ax.plot(smooth_max, y_new, '-', color = 'black')

x_max = np.max(x) + 5

#Fill colors
y1 = 0.0*np.ones(len(y_new))
y2 = 2*np.ones(len(y_new))
ax.fill(np.append(y1, y2[::-1]), np.append(y_new, y_new[::-1]), '#CFF5E7')

y1 = 2*np.ones(len(y_new))
y2 = smooth_loss
ax.fill(np.append(y1, y2[::-1]), np.append(y_new, y_new[::-1]), '#FFFAD7')

y1 = smooth_loss
y2 = smooth_sharp
ax.fill(np.append(y1, y2[::-1]), np.append(y_new, y_new[::-1]), '#FCDDB0')

y1 = smooth_sharp
y2 = smooth_max
ax.fill(np.append(y1, y2[::-1]), np.append(y_new, y_new[::-1]), '#FF9F9F')

y1 = smooth_max
y2 = x_max*np.ones(len(y_new))
ax.fill(np.append(y1, y2[::-1]), np.append(y_new, y_new[::-1]), '#E97777')


ax.axvline(x = 2.0, linestyle = '--', color = 'gray')
ax.set_xscale('log', base = 2)
ax.set_yscale('log', base = 2)
ax.legend(fontsize = 20, loc='upper left', facecolor = 'white', framealpha = 0.5)
ax.set_ylabel(r'$1/w$')
ax.set_xlabel(r'$c$')
ax.set_xlim(0.5, x_max)
ax.set_ylim(y_new[0], y_new[-1])



Key results: 


*   critical constants $c_{loss}, c_{sharp}$ do not scale with $1 / w$
*   $c_{loss} = 2$ independent of depth and width.