# Gradient and computational time analysis (via the example of graphene)

## 1. Computational time analysis

In [None]:
import numpy as np 
import matplotlib.pyplot as plt
import pickle

def crawl_time(fname, warmup=60):
    """Parse per-step sampling/optimization timings from a training slurm log.

    A line containing 'Sampling duration' (or, otherwise, 'Optimization
    duration') is parsed by taking the text after its last ', ' and before
    ' seconds in total' as a float.

    Args:
        fname: path of the slurm output file.
        warmup: number of leading steps dropped as warm-up when more than
            `warmup` steps are logged (default 60, the former hard-coded
            value). With `warmup` or fewer steps, only the first step is dropped.

    Returns:
        (sample_time_list, optim_time_list, total_time_list): the two parsed
        series truncated to a common length, and their element-wise sum.
    """
    def get_time(line):
        # '..., 12.3 seconds in total' -> 12.3
        return float(line.split(', ')[-1].split(' seconds in total')[0])

    sample_time_list = []
    optim_time_list = []
    with open(fname) as file:
        for line in file:
            if 'Sampling duration' in line:
                sample_time_list.append(get_time(line))
            elif 'Optimization duration' in line:
                optim_time_list.append(get_time(line))

    # Remove the first few iterations, which include warm-up time, and align
    # both series to the same length.
    min_t = min(len(sample_time_list), len(optim_time_list))
    start = warmup if min_t > warmup else 1
    sample_time_list = sample_time_list[start:min_t]
    optim_time_list = optim_time_list[start:min_t]
    total_time_list = list(np.array(sample_time_list) + np.array(optim_time_list))

    return sample_time_list, optim_time_list, total_time_list

# Slurm output files to parse, one per training setup. Replace the `xxxxxx`
# placeholders with real job output files; each file must log the time taken
# by every training step.
flist = [
    '_log_graphene_OG_test_multi/_xxxxxx_0_slurm.out',
    '_log_graphene_DA_test_multi/_xxxxxx_0_slurm.out'
]

# Display name of each setup, in the same order as `flist`.
labels = [
    "OG",
    "DA",
]

# One (sampling, optimization, total) triple of per-step time lists per file,
# then split into three parallel lists indexed like `flist`.
timelist = list(map(crawl_time, flist))
samplet_list = [sample for sample, _optim, _total in timelist]
optimt_list = [optim for _sample, optim, _total in timelist]
totalt_list = [total for _sample, _optim, total in timelist]

In [None]:
# Estimate the total GPU hours of one training run from the mean per-step time.
# Example: OG training with 5 GPUs and 80000 iterations.
i = 0               # specify which setup to check (index into `labels` / `totalt_list`)
num_gpus = 5        # specify the number of GPUs used
num_steps = 80000   # specify the number of total iterations

pre = 4             # decimal places of the reported numbers
step_times = np.asarray(totalt_list[i])
mean = np.mean(step_times)
# Standard error of the mean per-step time. ddof=1 (sample std) matches the
# ddof used for the parameter std later in this notebook.
ste = np.std(step_times, ddof=1) / np.sqrt(len(step_times))

sec_to_gpuhrs = num_gpus * num_steps / 3600
# Round the final GPU-hour figures rather than the per-step seconds: rounding
# first and then scaling by num_gpus * num_steps / 3600 amplifies the error.
labels[i], round(mean * sec_to_gpuhrs, pre), round(ste * sec_to_gpuhrs, pre)

## 2. Gradient evaluation

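Example code for computing the gradients of the network at several training checkpoints: from each checkpoint, a single gradient (optimization) step is simulated `sim_num` times with freshly sampled MCMC data, and the resulting parameters are saved for the analysis below.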

In [None]:
from utils.loader import mpatch_load_cfg
import pickle, os, time
from absl import logging

sim_num = 50  # number of times to simulate the one-gradient step
# Optional tag appended to the output filename as `_mcmc{mcmc_steps}`. It is only
# used to build the filename below; with None the parameters are written to
# `params_ckpt_{ckpt}.pk`, which is the name the std-analysis cell reads.
# (Previously this name was never defined and the save step raised NameError.)
mcmc_steps = None

meta_list = [
    [
        '_log_graphene_OG_test_multi/',
        ['010000', '020000', '030000', '040000', '050000', '060000', '070000', '080000'],
    ],
    [
        '_log_graphene_DA_test_multi/',
        ['010000', '020000', '030000', '040000', '050000', '060000', '070000', '080000'],
    ],
]

# read environment variables for the distributed setup
if 'COORD_IP' in os.environ and 'PORT' in os.environ:
    coord_address = str(os.environ['COORD_IP']).strip()+":"+str(os.environ['PORT']).strip()
else:
    coord_address = None

num_processes_str = os.environ.get('NUM_JOBS')
num_processes = int(num_processes_str) if num_processes_str else 1

process_id_str = os.environ.get('SLURM_ARRAY_TASK_ID')
process_id = int(process_id_str) if process_id_str else 0

job_id_str = os.environ.get('SLURM_ARRAY_JOB_ID')
job_id = int(job_id_str) if job_id_str else None

timeout_str = os.environ.get('TIMEOUT')
timeout = int(timeout_str) if timeout_str else None

dist_initialize = False

for log_dir, ckpts in meta_list:
    cfg = mpatch_load_cfg(
        log_dir=log_dir,
        mode='train',
        libcu_lib_path='/opt/conda/envs/deepsolid/lib/',
        resume=True,
        coord_address=coord_address,
        num_processes=num_processes,
        process_id=process_id,
        job_id=job_id,
        timeout=timeout,
        x64=True,
        dist_initialize=dist_initialize,
    )

    # jax / DeepSolid are imported only after mpatch_load_cfg has run --
    # presumably so its settings (library path, x64, distributed setup) take
    # effect first. TODO confirm before moving these imports to the top.
    import jax
    from DeepSolid import constants

    # NOTE(review): this key is overwritten by method_dict['sharded_key'] below
    # before it is used.
    seed = int(1e6 * time.time())
    key = jax.random.PRNGKey(seed)
    sharded_key = constants.make_different_rng_key_on_all_devices(key)
    sharded_key, subkeys = constants.p_split(sharded_key)

    if dist_initialize is False:
        from utils.process import process
        from DeepSolid.utils.kfac_ferminet_alpha import utils as kfac_utils
        dist_initialize = True

    for ckpt in ckpts:
        method_dict, result_dict = process(
            cfg,
            process_id,
            get_gradient_for_one_step=True,
            ckpt_restore_filename=f'{log_dir}qmcjax_ckpt_{ckpt}_process0.npz',
        )
        sharded_key = method_dict['sharded_key']
        new_params_list = []

        # Every simulation starts from the same restored parameters and walker
        # data (old_params / old_data); only the random keys differ.
        for sim_idx in range(sim_num):
            logging.info(f'sim {sim_idx}')
            sharded_key, subkeys = kfac_utils.p_split(sharded_key)
            new_data, _ = method_dict['mcmc_step'](method_dict['old_params'], method_dict['old_data'], subkeys)

            new_data = method_dict['mask_if_required'](data=new_data, sharded_key=sharded_key)
            processed_data, processed_data_with_keys, sharded_key = method_dict['augment_if_required'](data=new_data, sharded_key=sharded_key)

            sharded_key, subkeys = kfac_utils.p_split(sharded_key)
            new_params, _, _ = method_dict['optimizer_step'](
                params=method_dict['old_params'],
                rng=subkeys,
                data_iterator=iter([processed_data_with_keys]) if method_dict['need_key_for_optim'] else iter([processed_data]),
            )
            new_params_list.append(new_params)

        if mcmc_steps is not None:
            out_fname = f"{log_dir}params_ckpt_{ckpt}_mcmc{mcmc_steps}.pk"
        else:
            out_fname = f"{log_dir}params_ckpt_{ckpt}.pk"
        with open(out_fname, "wb") as f:
            pickle.dump(new_params_list, f)
        


Retrieve std estimates of the gradients. These are the same as the std of `new_params`: all `new_params` are obtained from the same initial parameters, so the only randomness comes from sampling.

In [None]:
import matplotlib.pyplot as plt
import pickle
import jax.numpy as jnp
from matplotlib import colors
import numpy as np

def convert_dict_to_array(input):
    """Flatten a nested parameter structure into a single 1-D jax array.

    Dict values are visited in the dict's iteration order and concatenated.
    Any other object is converted with `jnp.array(...)` and flattened; if that
    conversion fails (e.g. a list containing dicts, or a list of arrays with
    different shapes), the object is iterated and each item is flattened
    recursively. Empty dicts/sequences contribute no entries.

    Args:
        input: a dict, an array-like, or a (nested) sequence of those.

    Returns:
        1-D jax array holding all leaf values in traversal order.
    """
    if isinstance(input, dict):
        parts = [convert_dict_to_array(input[key]) for key in input]
    else:
        try:
            return jnp.array(input).flatten()
        except (TypeError, ValueError):
            # TypeError: items jnp.array cannot convert (e.g. dicts in a list).
            # ValueError: items with mismatched shapes that cannot be stacked.
            parts = [convert_dict_to_array(item) for item in input]
    # jnp.concatenate rejects an empty list, so handle empty containers here.
    return jnp.concatenate(parts) if parts else jnp.zeros((0,))

meta_list = [
    [
        '_log_graphene_OG_test_multi/',
        ['010000', '020000', '030000', '040000', '050000', '060000', '070000', '080000'],
    ],
    [
        '_log_graphene_DA_test_multi/',
        ['010000', '020000', '030000', '040000', '050000', '060000', '070000', '080000'],
    ],
]

# params_all[setup][checkpoint] = [params_mean, params_std, params_std_std],
# each a flat vector with one entry per network parameter.
params_all = []
for log_dir, ckpts in meta_list:
    params_ckpts = []
    for ckpt in ckpts:
        fname = f"{log_dir}params_ckpt_{ckpt}.pk"
        print(f'Reading {fname}...', end='\r')
        # NOTE: pickle.load can execute arbitrary code; only load files written by
        # the gradient-evaluation cell of this notebook.
        with open(fname, "rb") as f:
            new_params_dicts = pickle.load(f)
        # One row per pickled simulation (presumably sim_num rows), one column per
        # parameter.
        # NOTE(review): unless jax x64 is enabled in this kernel, float64 params
        # are presumably downcast to float32 here -- confirm precision is adequate.
        new_params_list = jnp.array([convert_dict_to_array(param_dict) for param_dict in new_params_dicts])

        params_mean = jnp.mean(new_params_list, axis=0)
        params_std = jnp.std(new_params_list, axis=0, ddof=1)
        # Spread of the absolute deviations |x - mean|; the plot cell divides it by
        # sqrt(sim_num) to form error bars. jnp.abs is the exact form of
        # sqrt(d**2), which can underflow to 0 for tiny deviations.
        params_std_std = jnp.std(jnp.abs(new_params_list - params_mean), axis=0, ddof=1)

        params_ckpts.append([params_mean, params_std, params_std_std])

    params_all.append(params_ckpts)

In [None]:
# Convert each setup's mean wall-clock seconds per training step into GPU hours
# per step (x number of GPUs / 3600 s). Multiplying by a checkpoint's step
# number (done in the plotting cell) gives the GPU hours spent up to it.

# One entry per setup in `flist` / `totalt_list` (same order): num_GPUs / 3600.
gpuhrs_multipliers = [
    5 / 3600,
    5 / 3600,
    5 / 3600,
    5 / 3600,
    5 / 3600,
]
# zip() silently drops setups that have no multiplier, so check the lengths.
assert len(gpuhrs_multipliers) >= len(totalt_list), (
    f"need a GPU-hour multiplier for each of the {len(totalt_list)} setups, "
    f"got {len(gpuhrs_multipliers)}"
)

gpuhrs_mean_list = [np.mean(ts) * m for ts, m in zip(totalt_list, gpuhrs_multipliers)]


In [None]:
# std plot: RMS over all parameters of the per-parameter std of the one-step
# updated parameters, against GPU hours spent up to each checkpoint.
from matplotlib.ticker import ScalarFormatter

label_list = [
                r'OG, $N=1000$',
                r'DA, $N=90, k=12$',
]
color_list = [
                'black',
                'tab:blue',
]
transparent_ratio = 0.2  # alpha multiplier of the error band
sim_num = 50 # number of times the one-gradient step has been simulated

# Checkpoints with 30 < int(GPU hours) <= 300 are plotted while the x-axis stops
# at 285; presumably points just beyond 285 are kept so the curves reach the
# right edge of the frame.
gpuhrs_min, gpuhrs_max = 30, 300
xlim = [30, 285]


def _prune(values, mask):
    """Return the entries of `values` whose matching `mask` entry is True."""
    return [v for v, keep in zip(values, mask) if keep]


fig, ax = plt.subplots(figsize=(5, 3))

for params_ckpts, meta_info, label, color, gpuhrs_mean in zip(params_all, meta_list, label_list, color_list, gpuhrs_mean_list):
    # GPU hours up to each checkpoint = checkpoint step number * GPU hours per step
    t = [int(a) * 1. for a in meta_info[1]]
    gpuhrs = np.array(t) * gpuhrs_mean

    std_vecs = [p[1] for p in params_ckpts]
    error_bars = [p[2]/jnp.sqrt(sim_num) for p in params_ckpts]
    # Root-mean-square over parameters of the std vector and of its error bar.
    sum_stds = [jnp.sqrt(jnp.sum(v**2) / len(v)) for v in std_vecs]
    sum_err_bars = [jnp.sqrt(jnp.sum(e**2) / len(e)) for e in error_bars]
    sum_lower_err = [s - e for s, e in zip(sum_stds, sum_err_bars)]
    sum_upper_err = [s + e for s, e in zip(sum_stds, sum_err_bars)]

    # Build the keep-mask once from the *unpruned* GPU hours so every series is
    # filtered by the same checkpoints. A closure over `gpuhrs` would see the
    # already-pruned list once `gpuhrs` is reassigned, truncating the y-series
    # from the front and misaligning them with the x values.
    keep = [gpuhrs_min < int(a) <= gpuhrs_max for a in gpuhrs]
    gpuhrs = _prune(gpuhrs, keep)
    sum_stds = _prune(sum_stds, keep)
    sum_lower_err = _prune(sum_lower_err, keep)
    sum_upper_err = _prune(sum_upper_err, keep)

    ax.plot(gpuhrs, sum_stds, '-o', label=label, color=color)

    fill_color = np.array(colors.to_rgba(color))
    fill_color[3] *= transparent_ratio
    ax.fill_between(gpuhrs, sum_lower_err, sum_upper_err, facecolor=fill_color)

ax.set_xlim(xlim)
ax.set_xlabel('GPU hours', fontsize=10)
ax.set_ylabel('RMS std of updated params', fontsize=10)
ax.legend(loc='lower left', fontsize=10, labelspacing=0)
ax.tick_params(axis='both', which='major', labelsize=10)

# Scientific notation (rendered with mathtext) for the y-axis tick labels
formatter = ScalarFormatter(useMathText=True)
formatter.set_scientific(True)
formatter.set_powerlimits((-2, 2))  # use scientific notation outside 1e-2..1e2

# Apply the formatter to the y axis only
ax.yaxis.set_major_formatter(formatter)

fig.savefig('gradient_stab_std_gpuhrs.pdf', bbox_inches='tight')