# Visualizations

# 1. Geometry

This section covers Figure 1(d), (e), (f), Figure 2, Figure 6 and Figure 8. 

The visualization here (and any call to the SpaceGroupVisualizer object) requires an environment with `pyvista` correctly configured, which is absent from the singularity container. To have this environment set up, follow the non-singularity setup instructions in README.md.

## 1.1 Simple examples

In [None]:
from utils.visualize import SpaceGroupVisualizer
from IPython.display import clear_output

# Build the visualizer for the graphene_1 config. Replace `libcu_lib_path`
# with the directory holding the CUDA library files of your environment.
symm_visualizer = SpaceGroupVisualizer(
    base_cfg_path='utils/base_config.py',
    cfg_path = "config/graphene_1.py",
    libcu_lib_path = '/YOUR/ENV/PATH/TO/CUDA/LIB/FILES/', # '/opt/conda/envs/deepsolid/lib/',
)

# Clear the output produced so far in this cell.
clear_output(wait=True)

# Plot the supercell atoms non-interactively and keep the returned plotter.
plotter = symm_visualizer.plot_supercell_atom(
    interactive=False,
    show=True
)

# Save the scene to a PDF file.
plotter.save_graphic("graphene_1_supercell.pdf")

In [None]:
# Figure 1(d)

# Call plot_supercell_atom_symmetry in 'gpave' mode on the visualizer built
# in the previous cell and keep the returned plotter.
plotter = symm_visualizer.plot_supercell_atom_symmetry(
    mode='gpave',
    interactive=False,
    show=True,
)

# Save the scene to a PDF file.
plotter.save_graphic("graphene_1_symmetry.pdf")

In [None]:
# Draw the normal vectors used for ASU identification in canonicalization.
# Keep the plotter returned by this call and save it under its own file name.
# Previously the return value was discarded, so the stale `plotter` from the
# preceding cell was saved again over "graphene_1_symmetry.pdf".
# NOTE(review): assumes plot_supercell_group_asu returns the plotter, like the
# other plot_supercell_* methods used in this notebook -- confirm.
plotter = symm_visualizer.plot_supercell_group_asu(
    interactive=False,
    title="Normal vectors used for asu identification\n in canonicalization for graphene",
    show=True,
)
plotter.save_graphic("graphene_1_group_asu.pdf")

## 1.2 Auxiliary lines and texts

In [None]:
# example illustration to get additional lines for LiH - Figure 6(b)

from utils.visualize import SpaceGroupVisualizer
from IPython.display import clear_output

# Use a dedicated variable for the LiH visualizer. Rebinding `symm_visualizer`
# here would leave the graphene cells further down (sections 1.3 and 3.3)
# working with the LiH system when the notebook is run top to bottom.
# NOTE(review): cfg_path has no "config/" prefix, unlike "config/graphene_1.py"
# in section 1.1 -- confirm the path.
lih_visualizer = SpaceGroupVisualizer(
    base_cfg_path='utils/base_config.py',
    cfg_path="LiH_1_g225.py",
    libcu_lib_path='/YOUR/ENV/PATH/TO/CUDA/LIB/FILES/',  # '/opt/conda/envs/deepsolid/lib/',
)

clear_output(wait=True)

# Defer showing so that the extra lines can be added to the scene first.
plotter = lih_visualizer.plot_supercell_atom(
    interactive=False,
    show=False,
)


# extra lines for this plot
import numpy as np

L_Bohr = lih_visualizer.cell.L_Bohr
d = 0.5 * L_Bohr

# The 12 edges of the cube [0, d]^3, each sampled at 20 points:
# four edges along z, then four along y, then four along x.
lines = [
            np.array([ [d, d, z] for z in np.linspace(0, d, 20) ]),
            np.array([ [0, d, z] for z in np.linspace(0, d, 20) ]),
            np.array([ [d, 0, z] for z in np.linspace(0, d, 20) ]),
            np.array([ [0, 0, z] for z in np.linspace(0, d, 20) ]),
            np.array([ [d, y, d] for y in np.linspace(0, d, 20) ]),
            np.array([ [0, y, d] for y in np.linspace(0, d, 20) ]),
            np.array([ [d, y, 0] for y in np.linspace(0, d, 20) ]),
            np.array([ [0, y, 0] for y in np.linspace(0, d, 20) ]),
            np.array([ [x, d, d] for x in np.linspace(0, d, 20) ]),
            np.array([ [x, 0, d] for x in np.linspace(0, d, 20) ]),
            np.array([ [x, d, 0] for x in np.linspace(0, d, 20) ]),
            np.array([ [x, 0, 0] for x in np.linspace(0, d, 20) ]),
]

for line in lines:
    plotter.add_lines(line, color='black', width=4)

plotter.show()
plotter.save_graphic('LiH_1_g225_extra_lines.pdf')

## 1.3 Visualize a symmetric configuration from `symmscan_config` folder

In [None]:
# Figure 1(e)

# Plot the configuration from 'graphene-1-for-illustration.py' with the ASU
# shown; showing is deferred so that lines and text can be added first.
plotter = symm_visualizer.plot_supercell_symmscan_cfg(
    symscan_fname='graphene-1-for-illustration.py',
    interactive=False,
    show=False,
    show_asu=True,
)

# extra lines for this plot
import numpy as np

L_Bohr = symm_visualizer.cell.L_Bohr
# Two lines parallel to the z axis, each sampled at 20 points:
# line1 at (L_Bohr/2, cos(pi/6)/3 * L_Bohr), drawn in black;
# line2 at (0, cos(pi/6)*2/3 * L_Bohr), drawn in red and labelled 'ref. line'.
line1 = np.array([ 
            [1/2. * L_Bohr, np.cos(np.pi/6)/3 * L_Bohr, z] for z in np.linspace(-2, 2.2, 20)
        ])
line2 = np.array([ 
            [0., np.cos(np.pi/6)*2/3 * L_Bohr, z] for z in np.linspace(-2.2, 2, 20)
        ])
plotter.add_lines(line1, color='black', width=9)
plotter.add_lines(line2, color='red', width=9)
# The text position is hard-coded for this particular view.
plotter.add_text('ref. line', position=[215, 475], color='red', font_size=25)
plotter.show()
plotter.save_graphic('graphene_1_symmscan.pdf')

In [None]:
# Figure 1(f)

# Plot the configuration from 'graphene-1-for-illustration-shifted.py' with
# the ASU shown; showing is deferred so that lines and text can be added first.
plotter = symm_visualizer.plot_supercell_symmscan_cfg(
    symscan_fname='graphene-1-for-illustration-shifted.py',
    interactive=False,
    show=False,
    show_asu=True,
)

# extra lines for this plot
import numpy as np


L_Bohr = symm_visualizer.cell.L_Bohr

# adj_15 and opp_30 are the x and y offsets applied to the lines below.
# opp_15 and hyp_15 are computed but not used in this cell.
adj_15 = 1/2. * L_Bohr
opp_15 = adj_15 * np.tan(np.pi/12)
hyp_15 = adj_15 / np.cos(np.pi/12) 
opp_30 = np.cos(np.pi/6)/3 * L_Bohr

# Relative to the two lines of the Figure 1(e) cell, both lines are moved by
# (+adj_15, -opp_30): line1 ends up at (L_Bohr, 0), line2 at (L_Bohr/2, opp_30).
line1 = np.array([ 
            [1/2. * L_Bohr + adj_15, opp_30 - opp_30, z] for z in np.linspace(-2, 2.2, 20)
        ])
line2 = np.array([ 
            [0. + adj_15, opp_30*2 - opp_30, z] for z in np.linspace(-2.2, 2, 20)
        ])
plotter.add_lines(line1, color='black', width=9)
plotter.add_lines(line2, color='red', width=9)
# The text position is hard-coded for this particular view.
plotter.add_text('ref. line', position=[610, 380], color='red', font_size=25)
plotter.show()
plotter.save_graphic('graphene_1_symmscan_shifted.pdf')

In [None]:
# Figure 8

# Plot the configuration from 'graphene-1-reflection.py' with the ASU shown.
plotter = symm_visualizer.plot_supercell_symmscan_cfg(
    symscan_fname='graphene-1-reflection.py',
    interactive=False,
    # title=f"Symmetric electron configuration",
    show=False,
    show_asu=True,
)

plotter.show()
# NOTE(review): unlike the other figures in this notebook, this one is saved
# inside the symmscan_config/ folder -- confirm that this is intended.
plotter.save_graphic('symmscan_config/graphene_1_symmscan-reflection.pdf')

# 2. Training curves

In [None]:
from utils.visualize import DeepSolidVisualizer

# Build a visualizer over two training log directories, labelled 'OG' and
# 'DA', and load their training statistics.
# NOTE(review): these directory names have no leading underscore, unlike the
# '_log_graphene_*' directories used in sections 3 and 4 -- confirm the paths.
visualizer = DeepSolidVisualizer(
    log_dir_list = [
                    'log_graphene_OG_test/',
                    'log_graphene_DA_test/',
                    ],
    label_list = [
                    'OG',
                    'DA'
                ],
    libcu_lib_path = '/opt/conda/envs/deepsolid/lib/',
)
visualizer.load_train_stats()

### Energy (moving average)

In [None]:
# Plot the 'energy' training statistic with ma_window=200 on a fixed y-range.
# The commented-out arguments restrict the time range and save the figure.
visualizer.plot_train_stats(
    field='energy', 
    ylim=(-76.2,-76.0), 
    # t_range=(None,2000), 
    title='Energy',
    figsize=(10,8),
    # savepath='energy.pdf'
    ma_window=200,
)

### Variance (moving average)

In [None]:
# Plot the 'variance' training statistic with ma_window=200 on a fixed y-range.
# The commented-out arguments restrict the time range and save the figure.
visualizer.plot_train_stats(
    field='variance', 
    ylim=(0, 100), 
    # t_range=(None,2000), 
    title='Variance',
    figsize=(10,8),
    # savepath='variance.pdf'
    ma_window=200,
)

### Symmetry measure: Var[group-averaged net / net]. The network is perfectly symmetric when this equals zero

In [None]:
# Plot the 'symm_ratio_var' training statistic with ma_window=200 on y in [0, 1].
# The commented-out arguments restrict the time range and save the figure.
visualizer.plot_train_stats(
        field='symm_ratio_var', 
        ylim=(0,1), 
        # t_range=(None,2000), 
        title='Var[Symm ratio]',
        figsize=(10,8),
        # savepath='symm_ratio_var.pdf'
        ma_window=200,
)

# 3. Visualize invariance via diagonal translation of a symmetric configuration

## 3.1 Generate a list of electron configs by translating a symmetric config from `symmscan_config` folder

In [None]:
# generate a list of configurations to illustrate P3m1 symmetry
# NOTE(review): the scan file is written only under this log_dir, while later
# cells pass the same file name together with other log directories -- confirm
# whether the file has to be present in those directories as well.
log_dir = '_log_graphene_OG_test/'

from utils.loader import load_module
import numpy as np

# Load the graphene_1 config to obtain L_Bohr, then the symmetric electron
# configuration defined in the symmscan_config file.
base_cfg = load_module('base_cfg', 'utils/base_config.py').default()
cfg = load_module('cfg', 'config/graphene_1.py').get_config(base_cfg)
L_Bohr = cfg.system.pyscf_cell.L_Bohr
elecs, _ = load_module('symmscan', f'symmscan_config/graphene-1-for-illustration.py').symm_cfg(L_Bohr)

# 500 x 500 grid of translations (tx, ty, 0) with tx, ty in [-2 L_Bohr, 2 L_Bohr];
# `translates` has one translation per row.
tx = ty = np.linspace( - 2 * L_Bohr, 2 * L_Bohr, 500)
X, Y = np.meshgrid(tx, ty)
translates = np.array([X.flatten(), Y.flatten(), np.zeros_like(X.flatten())]).T

# ts keeps the (tx, ty) part of each translation; each row of xs is the
# flattened configuration with one translation added to it.
# assumes the last axis of elecs has length 3 so that `elecs + t` broadcasts -- TODO confirm
ts = translates[:,:2]
xs = np.array([ (elecs + t).flatten() for t in translates]) # simultaneous translations
np.savez(f'{log_dir}symmscan_-2_2_500.npz', xs=xs, ts=ts)

In [None]:
# generate a list of configurations to illustrate P6mm symmetry
# NOTE(review): the scan file is written only under this log_dir, while later
# cells pass the same file name together with other log directories -- confirm
# whether the file has to be present in those directories as well.
log_dir = '_log_graphene_OG_test/'

from utils.loader import load_module
import numpy as np

# Load the graphene_1 config to obtain L_Bohr, then the electron
# configuration defined in the 'graphene-1-reflection.py' symmscan_config file.
base_cfg = load_module('base_cfg', 'utils/base_config.py').default()
cfg = load_module('cfg', 'config/graphene_1.py').get_config(base_cfg)
L_Bohr = cfg.system.pyscf_cell.L_Bohr
elecs, _ = load_module('symmscan', f'symmscan_config/graphene-1-reflection.py').symm_cfg(L_Bohr)

# 500 x 500 grid of translations (tx, ty, 0) with tx, ty in [-2 L_Bohr, 2 L_Bohr];
# `translates` has one translation per row.
tx = ty = np.linspace( - 2 * L_Bohr, 2 * L_Bohr, 500)
X, Y = np.meshgrid(tx, ty)
translates = np.array([X.flatten(), Y.flatten(), np.zeros_like(X.flatten())]).T

# ts keeps the (tx, ty) part of each translation; each row of xs is the
# flattened configuration with one translation added to it.
# assumes the last axis of elecs has length 3 so that `elecs + t` broadcasts -- TODO confirm
ts = translates[:,:2]
xs = np.array([ (elecs + t).flatten() for t in translates]) # simultaneous translations
np.savez(f'{log_dir}symmscan_-2_2_500-reflection.npz', xs=xs, ts=ts)

## 3.2 Evaluate wavefunction and energy on these configurations

In [None]:
from eval import eval

# Evaluate the OG and PA checkpoints on both scan files, once per mode
# ('energy' and 'slogdet'). The call order matches the original unrolled
# version: all OG evaluations first, then all PA evaluations; within a mode,
# the plain scan file comes before the reflection scan file.
# NOTE: `eval` is the project function imported from the `eval` module above;
# it shadows the Python builtin of the same name in this notebook.
eval_runs = [
    ('OG', '_log_graphene_OG_test/'),
    ('PA', '_log_graphene_PA_test/'),
]
scan_files = [
    'symmscan_-2_2_500.npz',
    'symmscan_-2_2_500-reflection.npz',
]

for run_label, run_log_dir in eval_runs:
    for t in ['080000']:
        # The banner previously contained '\O' / '\P', which are invalid escape
        # sequences (a literal backslash was printed and newer Python versions
        # emit a warning); a newline is used instead.
        print(f'\n=====\n{run_label}, {t}\n======\n')
        for mode in ['energy', 'slogdet']:
            for scan_file in scan_files:
                eval(
                    log_dir=run_log_dir,
                    mode=mode,
                    input_file=scan_file,
                    batch_size=100,
                    libcu_lib_path='/opt/conda/envs/deepsolid/lib/',
                    num_processes=1,
                    process_id=0,
                    ckpt_restore_filename=f'qmcjax_ckpt_{t}_process0.npz',
                    save_freq=1,
                    x64=True,
                )

## 3.3 Generate scan plots

These code snippets are used for generating Figure 1(a)-(c) and Figure 9.

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import numpy as np
from eval import retrieve_evals

def remove_duplicates_within_epsilon(arr, epsilon):
    """Return the sorted values of ``arr`` with near-duplicates dropped.

    The values are sorted in ascending order and a value is kept only if it
    exceeds its immediate predecessor in the sorted order by more than
    ``epsilon``. The smallest value is always kept. Because each value is
    compared with its direct predecessor, a chain of values that are pairwise
    within ``epsilon`` of their neighbours collapses to its first element.

    Args:
        arr (array-like): One-dimensional input values.
        epsilon (float): The tolerance for considering values as duplicates.

    Returns:
        np.ndarray: A new sorted array with duplicates removed. An empty
        input yields an empty array.
    """
    sorted_arr = np.sort(np.asarray(arr))

    # Without this guard the mask below would have length 1 for an empty
    # input and the boolean indexing would raise an IndexError.
    if sorted_arr.size == 0:
        return sorted_arr

    # Gap between each element and its predecessor in sorted order.
    diffs = np.diff(sorted_arr)

    # The first element has no predecessor and is always kept.
    mask = np.concatenate(([True], diffs > epsilon))

    return sorted_arr[mask]

def plot_2d_scan(mode,
                 ckpt_list, scanfile2d, scansize, savepath=None, vmin=None, vmax=None,
                 xrange=None, yrange=None, title=None, poly2d_list=None, traj1d_list=None,
                 dpi=100, vertical=True, cmap='viridis', levels=None, fontsize=None,
                 cbarfrac=0.05, poly2dcolor='red'):
    """Draw filled-contour panels of a 2D scan, one panel per checkpoint.

    Args:
        mode (str): 'slogdet' or 'energy'. For 'slogdet' the retrieved values
            are turned into ``2 * vals - log(sum(exp(2 * vals)))``; for
            'energy' only the real part is kept.
        ckpt_list: Sequence of ``(log_dir, ckpt)`` pairs; ``ckpt`` is inserted
            into the checkpoint file name ``qmcjax_ckpt_{ckpt}_process0.npz``.
        scanfile2d (str): Passed to ``retrieve_evals`` as ``input_file``.
        scansize (tuple): Grid shape used to reshape the translations and values.
        savepath (str, optional): If given, the figure is saved there.
        vmin, vmax (float, optional): Limits of the shared colour normalisation.
        xrange, yrange (optional): Axis limits applied to every panel.
        title (str, optional): Figure title.
        poly2d_list (optional): Vertex arrays, each drawn as a dotted closed
            polygon outline on every panel.
        traj1d_list (optional): Dicts with keys 'x' and 'y', each drawn as a
            dotted red line on every panel.
        dpi (int): Resolution for the figure and for saving.
        vertical (bool): Stack the panels in a column sharing the x axis;
            otherwise place them in a row sharing the y axis.
        cmap: Colour map for the contours.
        levels (optional): Passed to ``contourf``.
        fontsize (optional): Tick label size for the panels and the colourbar.
        cbarfrac (float): ``fraction`` argument of the colourbar.
        poly2dcolor (str): Edge colour of the overlaid polygons.
    """
    assert mode in ['slogdet', 'energy']
    if mode == 'slogdet':
        # normalize log-probability
        def post_process_fn(vals):
            return 2 * vals - np.log(np.sum(np.exp(2 * vals)))
    else:
        # keep only real part of the energy
        def post_process_fn(vals):
            return np.real(vals)

    def get_ckpt_fname(ckpt_iter):
        return f'qmcjax_ckpt_{ckpt_iter}_process0.npz'

    ncols = len(ckpt_list)
    if not vertical:
        fig, axes = plt.subplots(1, ncols, figsize=(ncols*6, 6), sharey=True, dpi=dpi)
        plt.subplots_adjust(hspace=-0.2, wspace=-0.05)
    else:
        fig, axes = plt.subplots(ncols, 1, figsize=(6, ncols*6), sharex=True, dpi=dpi)
        plt.subplots_adjust(hspace=-0.2, wspace=0.1)
    if ncols == 1:
        # plt.subplots returns a bare Axes for a single panel
        axes = [axes]

    # One normalisation object shared by all panels.
    norm = plt.Normalize(vmin=vmin, vmax=vmax)

    for j, (log_dir, ckpt) in enumerate(ckpt_list):
        ax = axes[j]

        # retrieve the evaluated values for this checkpoint
        _, ts, vals = retrieve_evals(
            log_dir=log_dir,
            mode=mode,
            input_file=scanfile2d,
            ckpt_restore_filename=get_ckpt_fname(ckpt),
            num_processes=1,
        )
        vals = post_process_fn(vals)

        X = ts[:, 0].reshape(scansize)
        Y = ts[:, 1].reshape(scansize)
        Z = vals.reshape(scansize)

        im = ax.contourf(X, Y, Z, cmap=cmap, norm=norm, levels=levels)
        ax.set_xlim(xrange)
        ax.set_ylim(yrange)

        # place x, y axis
        ax.axhline(y=0, color='black', linestyle='--', linewidth=.3)
        ax.axvline(x=0, color='black', linestyle='--', linewidth=.3)

        # place additional polygons
        if poly2d_list is not None:
            for vertices in poly2d_list:
                ax.add_patch(
                    Polygon(vertices, closed=True, edgecolor=poly2dcolor, fill=False, linewidth=.8, linestyle=':')
                )
        ax.set_aspect('equal')

        # draw additional trajectories on THIS panel; plt.plot would draw on
        # the current axes of the figure instead of the panel being filled
        if traj1d_list is not None:
            for traj1d in traj1d_list:
                ax.plot(traj1d['x'], traj1d['y'], color='red', linewidth=.8, linestyle=':')

        if fontsize is not None:
            ax.tick_params(labelsize=fontsize)

    # The colourbar is built from the contour set of the last panel.
    cb = fig.colorbar(im, location='top', ax=axes, fraction=cbarfrac, pad=0.02)
    if fontsize is not None:
        cb.ax.tick_params(labelsize=fontsize)
    if title is not None:
        fig.suptitle(title)

    if savepath is not None:
        plt.savefig(savepath, dpi=dpi)
    plt.show()

def plot_1d_scan(mode, log_dir, log_label, ckpt_list, scanfile2d, scansize, prange, pfn,  savepath=None, title=None, poly2d_list=None, dpi=100):
    """Plot the values of a 2D scan along a 1D trajectory, one panel per checkpoint.

    The trajectory is given by a parameter range ``prange`` and a function
    ``pfn`` that maps a parameter value to a 2D point. For each parameter
    value, the scan grid point closest to ``pfn(p)`` is looked up and its
    value is plotted against ``p``.

    Args:
        mode (str): 'slogdet' or 'energy'. For 'slogdet' the retrieved values
            are turned into ``2 * vals - log(sum(exp(2 * vals)))``; for
            'energy' only the real part is kept.
        log_dir (str): Passed to ``retrieve_evals``.
        log_label (str): y-axis label of the first panel.
        ckpt_list: Checkpoint identifiers inserted into
            ``qmcjax_ckpt_{ckpt}_process0.npz``.
        scanfile2d (str): Passed to ``retrieve_evals`` as ``input_file``.
        scansize (tuple): Grid shape of the 2D scan.
        prange: Parameter values of the trajectory.
        pfn (callable): Maps a parameter value to a 2D point.
        savepath (str, optional): If given, the figure is saved there.
        title (str, optional): Figure title.
        poly2d_list (optional): Polygons (vertex arrays). Dotted vertical
            lines mark the parameter values whose grid point lies within 1e-2
            of the line through one of the polygon edges. ``None`` (the
            default) draws no such lines.
        dpi (int): Resolution for the figure and for saving.
    """
    assert mode in ['slogdet', 'energy']
    if mode == 'slogdet':
        # normalize log-probability
        def post_process_fn(vals):
            return 2 * vals - np.log(np.sum(np.exp(2 * vals)))
    else:
        # keep only real part of the energy
        def post_process_fn(vals):
            return np.real(vals)

    def get_ckpt_fname(ckpt_iter):
        return f'qmcjax_ckpt_{ckpt_iter}_process0.npz'

    ncols = len(ckpt_list)
    fig, axes = plt.subplots(1, ncols, figsize=(ncols*6, 6), dpi=dpi)
    if ncols == 1:
        # plt.subplots returns a bare Axes for a single panel
        axes = [axes]

    for j, ckpt in enumerate(ckpt_list):
        # retrieve the evaluated values for this checkpoint
        _, ts, vals = retrieve_evals(
            log_dir=log_dir,
            mode=mode,
            input_file=scanfile2d,
            ckpt_restore_filename=get_ckpt_fname(ckpt),
            num_processes=1,
        )
        vals = post_process_fn(vals)

        XY = ts.reshape(list(scansize) + [2])
        Z = vals.reshape(scansize)

        # get the closest point in XY to every point of the trajectory
        indices = [np.unravel_index(np.argmin(np.linalg.norm(XY - pfn(p), axis=2)), scansize) for p in prange]
        Ztraj = [Z[idx] for idx in indices]

        axes[j].plot(prange, Ztraj)

        # place lines at which the ref. line crosses the boundary of an asu;
        # skipped when no polygons are given (the default) or nothing is hit,
        # both of which previously raised an exception
        if poly2d_list is not None:
            traj_xy = [XY[idx] for idx in indices]
            hit_indices = [
                hit
                for vertices in poly2d_list
                for hit in intersect_polygon(traj_xy, vertices, 1e-2)
            ]
            if hit_indices:
                p_intersects = remove_duplicates_within_epsilon([prange[hit] for hit in hit_indices], 1e-2)
                for p in p_intersects:
                    axes[j].axvline(x=p, color='blue', linestyle=':', linewidth=.2)

    axes[0].set_ylabel(log_label)

    for j, ckpt in enumerate(ckpt_list):
        axes[j].set_xlabel(f'Iter {ckpt}')

    if title is not None:
        fig.suptitle(title)

    if savepath is not None:
        plt.savefig(savepath, dpi=dpi)
    plt.show()


def intersect_polygon(XYtraj, poly_vertices, tol):
    """Return indices of trajectory points lying close to a polygon edge line.

    For every edge of the polygon (consecutive vertices, wrapping around from
    the last to the first), the perpendicular distance of each trajectory
    point to the infinite line through that edge is computed. Indices of
    points with distance ``<= tol`` are collected edge by edge, so an index
    may occur more than once.

    NOTE: the distance is measured to the full line through the edge, not to
    the edge segment, so points beyond the edge end points are reported too.

    Args:
        XYtraj: Sequence of 2D points, shape (N, 2).
        poly_vertices: Sequence of 2D polygon vertices.
        tol (float): Distance threshold.

    Returns:
        list: Indices into ``XYtraj``; empty if the trajectory is empty.
    """
    traj = np.asarray(XYtraj)
    vertices = np.asarray(poly_vertices)
    if traj.size == 0:
        return []

    point_ids = np.arange(traj.shape[0])
    n_vertices = len(vertices)
    hits = []

    for i in range(n_vertices):
        start = vertices[i]
        end = vertices[(i + 1) % n_vertices]  # Wrap around to the first vertex
        edge = end - start
        edge_len = np.hypot(edge[0], edge[1])
        if edge_len == 0:
            # repeated vertex: no line is defined, and dividing would give NaN
            continue
        rel = traj - start
        # z-component of the 2D cross product, written out because np.cross
        # on 2-element vectors is deprecated in NumPy 2.0
        dist_arr = np.abs(edge[0] * rel[:, 1] - edge[1] * rel[:, 0]) / edge_len
        hits += list(point_ids[dist_arr <= tol])

    return hits


In [None]:
# Figure 1(a) - (b)

# get a 2d shape to be overlayed onto the figure
# ref_pos is the (x, y) part of the position entry of the second atom in the cell.
# NOTE(review): the original comment calls the ref. line blue, while the
# Figure 1(e) cell draws its 'ref. line' in red -- confirm which is meant.
ref_pos = symm_visualizer.cell.atom[1][1][:2] # position of the blue ref line
# First three vertices of every base ASU, projected to (x, y) and expressed
# relative to ref_pos.
poly2d_list = [asu['vertices'][:3, :2] - ref_pos for asu in symm_visualizer.base_group.base_asu_list]

# Two panels side by side (vertical=False): OG first, then PA, both at checkpoint 080000.
plot_2d_scan(
    mode='slogdet',   # use 'energy' to get energy scan plots
    ckpt_list=[
        ['_log_graphene_OG_test_multi/', '080000'],
        ['_log_graphene_PA_test_multi/', '080000'],
    ], 
    scanfile2d='symmscan_-2_2_500.npz', 
    scansize=(500,500), 
    # title='Log normalized probability',
    poly2d_list=poly2d_list, # 2d overlay
    # vmin=-100,
    # vmax=25,
    xrange=[-4.5,4.5],
    yrange=[-4.5,4.5],
    # dpi=200,
    cmap=plt.get_cmap('RdBu').reversed(),
    levels=15,
    vertical=False,
    savepath='wavefunction-scan.png'
)

In [None]:
# Figure 9

# get a 2d shape to be overlayed onto the figure
# ref_pos is the (x, y) part of the position entry of the second atom in the cell.
# NOTE(review): the original comment calls the ref. line blue, while the
# Figure 1(e) cell draws its 'ref. line' in red -- confirm which is meant.
ref_pos = symm_visualizer.cell.atom[1][1][:2] # position of the blue ref line
# First three vertices of every base ASU, projected to (x, y) and expressed
# relative to ref_pos.
poly2d_list = [asu['vertices'][:3, :2] - ref_pos for asu in symm_visualizer.base_group.base_asu_list]

# Same layout as the Figure 1(a)-(b) cell, but on the reflection scan file.
plot_2d_scan(
    mode='slogdet',   # use 'energy' to get energy scan plots
    ckpt_list=[
        ['_log_graphene_OG_test_multi/', '080000'],
        ['_log_graphene_PA_test_multi/', '080000'],
    ], 
    scanfile2d='symmscan_-2_2_500-reflection.npz', 
    scansize=(500,500), 
    # title='Log normalized probability',
    poly2d_list=poly2d_list, # 2d overlay
    # vmin=-100,
    # vmax=25,
    xrange=[-4.5,4.5],
    yrange=[-4.5,4.5],
    # dpi=200,
    cmap=plt.get_cmap('RdBu').reversed(),
    levels=15,
    vertical=False,
    savepath='wavefunction-scan-reflection.png'
)

In [None]:
# Figure 1(c) -- difference plot

ckpt_list=[
    ['_log_graphene_OG_test_multi/', '080000'],
    ['_log_graphene_PA_test_multi/', '080000'],
]
scansize=(500,500)
cmap=plt.get_cmap('bwr')
levels=15
xrange=[-4.5,4.5]
yrange=[-4.5,4.5]
# Define ref_pos here as well (same expression as in the Figure 1(a)-(b) and
# Figure 9 cells) so that this cell does not fail with a NameError when it is
# run without those cells having been executed first.
ref_pos = symm_visualizer.cell.atom[1][1][:2]
# First three vertices of every base ASU, projected to (x, y) and expressed
# relative to ref_pos.
poly2d_list = [asu['vertices'][:3, :2] - ref_pos for asu in symm_visualizer.base_group.base_asu_list]
savepath = 'wavefunction-scan-OG-PA-diff.png'
dpi = 100

get_ckpt_fname = lambda iter: f'qmcjax_ckpt_{iter}_process0.npz'
# normalize log-probability
post_process_fn = lambda vals: 2 * vals - np.log(np.sum(np.exp(2 * vals)))
Z_list = []
fig, axes = plt.subplots(figsize=(6, 6), dpi=dpi)
for logdir_ckpt in ckpt_list:
    log_dir = logdir_ckpt[0]
    ckpt = logdir_ckpt[1]
    # retrieve the evaluated values for this run
    _, ts, vals = retrieve_evals(
        log_dir=log_dir,
        mode='slogdet',
        input_file='symmscan_-2_2_500.npz',
        ckpt_restore_filename=get_ckpt_fname(ckpt),
        num_processes=1,
    )
    vals = post_process_fn(vals)

    # X and Y of the last run are used for plotting; both runs are evaluated
    # on the same input file.
    X = ts[:,0].reshape(scansize)
    Y = ts[:,1].reshape(scansize)
    Z = vals.reshape(scansize)
    Z_list.append(Z)

# first run (OG) minus second run (PA)
Z_diff = Z_list[0] - Z_list[1]

im = axes.contourf(X, Y, Z_diff, cmap=cmap, levels=levels)
axes.set_xlim(xrange)
axes.set_ylim(yrange)

# place x, y axis
axes.axhline(y=0, color='black', linestyle='--', linewidth=.3)
axes.axvline(x=0, color='black', linestyle='--', linewidth=.3)

# place additional polygons
if poly2d_list is not None:
    for vertices in poly2d_list:
        axes.add_patch(
            Polygon(vertices, closed=True, edgecolor='red', fill=False, linewidth=.8, linestyle=':')
        )
axes.set_aspect('equal')


fig.colorbar(im, location='top', ax=axes, fraction=0.04, pad=0.02)

if savepath is not None:
    plt.savefig(savepath, dpi=dpi)
plt.show()

# 4. Compare inference performance of different checkpoints

The code snippets below are used to produce Figure 5 and Figure 7.



In [None]:
# load stats from different configs and checkpoints

from utils.DeepSolid_sampler import DeepSolidSampler
import jax.numpy as jnp 
from IPython.display import clear_output

def get_stats(log_dir, sampling_cfg_str, ckpt_restore_filename, num_processes, n_for_each_est,
              libcu_lib_path='/opt/conda/envs/deepsolid/lib/'):
    """Load the samples and statistics of one checkpoint through DeepSolidSampler.

    Args:
        log_dir (str): Passed to ``DeepSolidSampler``.
        sampling_cfg_str (str): Passed to ``DeepSolidSampler``.
        ckpt_restore_filename (str): Passed to ``DeepSolidSampler``.
        num_processes (int): Passed to ``DeepSolidSampler``.
        n_for_each_est: Not used by this function; kept so that existing
            calls keep working.
        libcu_lib_path (str): Passed to ``DeepSolidSampler``. Defaults to the
            path that used to be hard-coded here.

    Returns:
        tuple: ``(sampler.stats_list, n_samples)`` where ``n_samples`` is
        ``len(sampler.samples)``, or ``None`` when ``sampler.samples`` is ``None``.
    """
    sampler = DeepSolidSampler(
        log_dir=log_dir,
        sampling_cfg_str=sampling_cfg_str,
        libcu_lib_path=libcu_lib_path,
        ckpt_restore_filename=ckpt_restore_filename,
        num_processes=num_processes
    )
    sampler.load_all_samples()
    sampler.load_stats()

    if sampler.samples is None:
        n_samples = None
    else:
        n_samples = len(sampler.samples)
    return sampler.stats_list, n_samples

# One entry per run in each of the lists below; they are zipped together.
log_dir_list = [
    '_log_graphene_OG_test_multi/',
    '_log_graphene_DA_test_multi/',
    '_log_graphene_GA_test_multi/',
    '_log_graphene_PA_test_multi/',
    '_log_graphene_PC_test_multi/',
]

label_list = [
    r'OG, $N=1000$', 
    r'DA, $N=90, k=12$',
    r'GA, $N=90, k=12$',  
    r'PA, $N=1000, k=12$',
    r'PC, $N=1000, k=12$',
]

# Candidate sampling configs per run, tried in order until one has samples.
# NOTE(review): the DA and GA runs list the PA sampling config -- confirm.
sampling_cfgs_list = [
    ['OG_batch1000_mcmc3e4.py'],
    ['PA_batch1000_mcmc3e4.py'],
    ['PA_batch1000_mcmc3e4.py'],
    ['PA_batch1000_mcmc3e4.py'],
    ['PC_batch1000_mcmc3e4.py'],
]

# Process count matching each candidate sampling config (one candidate per run).
num_processes_list = [[2] for _ in log_dir_list]

# Checkpoints '010000', '020000', ..., '080000' for every run.
ckpt_iters = [f'{it:06d}' for it in range(10000, 80001, 10000)]
t_meta_list = [list(ckpt_iters) for _ in log_dir_list]

color_list = [
    'black',
    'tab:blue', 
    'tab:orange', 
    'tab:red', 
    'tab:green',
]

linestyle_list = ['-o' for _ in log_dir_list]

# stats_t_meta_list[i] holds (checkpoint, stats, number of samples) tuples for
# run i, for every checkpoint that has samples.
stats_t_meta_list = []
for log_dir, label, sampling_cfgs_ls, num_processes_ls, t_list in zip(log_dir_list, label_list, sampling_cfgs_list, num_processes_list, t_meta_list):
    stats_t_list = []
    for t in t_list:
        print(f'{label}, {t}')
        # Reset both results for every checkpoint. Otherwise an empty candidate
        # list would reuse `samples_num` from the previous checkpoint (or fail
        # with a NameError on the very first one).
        stats, samples_num = None, None
        for sampling_cfg, num_processes in zip(sampling_cfgs_ls, num_processes_ls):
            stats, samples_num = get_stats(
                    log_dir=log_dir,
                    sampling_cfg_str=sampling_cfg,
                    ckpt_restore_filename=f'qmcjax_ckpt_{t}_process0.npz',
                    num_processes=num_processes,
                    n_for_each_est=100,
            )
            if samples_num is not None:
                break
            clear_output(wait=True)
        if samples_num is not None:
            stats_t_list.append((t,stats,samples_num))
    stats_t_meta_list.append(stats_t_list)

In [None]:
# convert checkpoint times to gpu hours


import numpy as np 
import matplotlib.pyplot as plt
import pickle

def crawl_time(fname):
    """Collect per-step sampling and optimization times from a log file.

    Lines containing 'Sampling duration' or 'Optimization duration' are
    parsed; the time is the number between the last ', ' of the line and
    ' seconds in total'. Both lists are cut to a common length, and the
    leading entries are dropped: the first 60 when more than 60 entries are
    available, otherwise only the first one.

    Args:
        fname (str): Path of the log file.

    Returns:
        tuple: ``(sample_times, optim_times, total_times)`` where the third
        list holds the element-wise sums of the first two.
    """
    def parse_seconds(text):
        tail = text.split(', ')[-1]
        return float(tail.split(' seconds in total')[0])

    sampling = []
    optimizing = []
    with open(fname) as handle:
        for row in handle:
            if 'Sampling duration' in row:
                sampling.append(parse_seconds(row))
            elif 'Optimization duration' in row:
                optimizing.append(parse_seconds(row))

    # drop the first few iterations, which include warmup time
    n_common = min(len(sampling), len(optimizing))
    n_skip = 60 if n_common > 60 else 1
    sampling = sampling[n_skip:n_common]
    optimizing = optimizing[n_skip:n_common]

    combined = list(np.array(sampling) + np.array(optimizing))
    return sampling, optimizing, combined

# replace xxx_slurm.out by an actual slurm output file, which logs the time taken for each step of training
# One log file per run.
flist = [
    '_log_graphene_OG_test_multi/_xxxxxx_0_slurm.out', 
    '_log_graphene_DA_test_multi/_xxxxxx_0_slurm.out',
    '_log_graphene_GA_test_multi/_xxxxxx_0_slurm.out',
    '_log_graphene_PA_test_multi/_xxxxxx_0_slurm.out',
    '_log_graphene_PC_test_multi/_xxxxxx_0_slurm.out',
]

# crawl_time returns (sampling times, optimization times, total times) per file;
# split them into one list per kind.
timelist = [crawl_time(f) for f in flist]
samplet_list = [t[0] for t in timelist]
optimt_list = [t[1] for t in timelist]
totalt_list = [t[2] for t in timelist]
# NOTE(review): this rebinds `t_meta_list`, a name the stats-loading cell
# above uses for its checkpoint lists; that cell defines it again before use.
t_meta_list = [samplet_list, optimt_list, totalt_list]

# Per-run factor applied to the mean time per step (seconds / 3600 = hours).
# NOTE(review): the factor 5 is presumably the number of GPUs -- confirm.
gpuhrs_multipliers = [
    5 / 3600,
    5 / 3600,
    5 / 3600,
    5 / 3600,
    5 / 3600,
]

# Mean total time per step of each run, scaled by its multiplier.
gpuhrs_mean_list = [ np.mean(ts) * m for ts, m in zip(totalt_list, gpuhrs_multipliers)]

In [None]:
from matplotlib import colors

field = 'mean' # can replace with 'variance' or 'symm_ratio_var'
transparent_ratio = 0.2
max_gpuhrs = 300  # checkpoints at or beyond this many GPU hours are not plotted

import matplotlib.pyplot as plt 
import numpy as np
import jax.numpy as jnp  # also imported by the stats-loading cell; repeated so this cell runs on its own

fig, ax = plt.subplots(figsize=(3.2,4))

for stats_t_list, label, color, linestyle, gpuhrs_mean in zip(stats_t_meta_list, label_list, color_list, linestyle_list, gpuhrs_mean_list):
    # A run without any loaded checkpoint has nothing to plot; indexing its
    # first entry would raise an IndexError.
    if not stats_t_list:
        continue

    # Number of estimates in the first checkpoint; every checkpoint is cut
    # to this many estimates.
    sim_num = len(stats_t_list[0][1])
    t_list = [int(a[0]) for a in stats_t_list]

    # Checkpoint iteration times the mean GPU hours per iteration.
    gpuhrs_all = np.array(t_list) * gpuhrs_mean

    # Select checkpoints with one mask applied to both the x values and the
    # statistics. Filtering the x values first and zipping afterwards only
    # stays aligned when the GPU hours increase monotonically.
    keep = gpuhrs_all < max_gpuhrs
    gpuhrs_list = gpuhrs_all[keep]
    kept_stats = [a for a, is_kept in zip(stats_t_list, keep) if is_kept]

    per_ckpt = [jnp.array([b[field] for b in a[1][:sim_num]]) for a in kept_stats]
    stats_list = np.array([vals.mean() for vals in per_ckpt])
    std_list = np.array([vals.std() for vals in per_ckpt])

    # Band of one standard error (std / sqrt(sim_num)) around the mean.
    upper_err_list = stats_list + std_list / np.sqrt(sim_num)
    lower_err_list = stats_list - std_list / np.sqrt(sim_num)

    # Same colour as the line, with the alpha channel scaled down.
    fill_color = np.array(colors.to_rgba(color))
    fill_color[3] *= transparent_ratio
    ax.fill_between(gpuhrs_list, lower_err_list, upper_err_list, facecolor=fill_color)
    ax.plot(gpuhrs_list, stats_list, linestyle, color=color, label=label)

ax.legend(fontsize=10,labelspacing=0)

ax.set_xlim([30, max_gpuhrs])

ax.tick_params(axis='both', which='major', labelsize=10)

plt.tight_layout()
plt.savefig('graphene_1.pdf', dpi=300)
plt.show()