In [None]:
import fnmatch
import itertools as it
import json
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import cm
from mpl_toolkits.axes_grid1 import make_axes_locatable
import networkx as nx
import numpy as np
import os
import pandas as pd
import seaborn as sns

%matplotlib inline

In [None]:
# --- typography ------------------------------------------------------------
font_family = "Times"
fs_title, fs_label, fs_small = 20, 18, 14  # title / axis-label / tick-and-annotation font sizes

# --- global plot style -----------------------------------------------------
# Apply the seaborn style first: seaborn style dicts may also touch font rcParams,
# so the font overrides below must come after it to take effect.
sns.set_style('ticks')
mpl.rcParams.update({
    "font.family": font_family,
    "mathtext.fontset": "stix",
})

colors = sns.color_palette()
fig_height = 6  # base figure height used to size the figures below

In [None]:
# Root folder of one training run; the placeholders are filled via str.format.
dir_pattern = (
    '../logs/2d_heisenberg_checkpoints/'
    'conditional_heisenberg_{rows}x{cols}/{model_id}/ns{ns}/{train_id}/'
)
# Destination folder for exported figures.
figures_dir = './figures/'

In [None]:
# Training-run folder name per lattice size "<rows>x<cols>". All runs share the same
# hyper-parameter prefix; only the trailing suffix differs between runs.
train_ids = {
    lattice: "iter100_lr0.001_wd0.0_bs100_dropout0.1_samplestruct2_lrschedulewarmup_cosine13102022-" + suffix
    for lattice, suffix in (
        ("2x5", "110204"),
        ("2x6", "110404"),
        ("2x7", "110604"),
        ("2x8", "110704"),
        ("2x9", "110815"),
        ("4x4", "181536"),
        ("4x5", "181556"),
        ("5x5", "181610"),
        ("6x5", "181649"),
        ("7x5", "181750"),
        ("8x5", "182303"),
        ("9x5", "182505"),
    )
}

In [None]:
save_figures = True
file_type = 'pdf'

# get system
model_id = 'gcn_proj_3_16-transformer_l4_d128_h4_featone_hot'
rows = 2
cols = 9
ns = 1000
snapshots = 20000
split = "test"
tick_multiples = 2  # spacing of the site-index ticks on the correlation heatmaps

lattice = f"{rows}x{cols}"
if lattice not in train_ids:
    raise KeyError(f"no training run registered for a {lattice} lattice; known sizes: {sorted(train_ids)}")
train_id = train_ids[lattice]

# data and results directories
res_dir = dir_pattern.format(rows=rows, cols=cols, model_id=model_id, ns=ns, train_id=train_id, split=split)
model_properties_dir = os.path.join(res_dir, 'properties', split, 'model')
shadow_properties_dir = os.path.join(res_dir, 'properties', split, 'shadow')
data_dir = os.path.join(res_dir, 'data', f'{rows}x{cols}', split)

# Per-Hamiltonian error of the model's correlation predictions.
# Presumably a JSON object mapping hamiltonian id -> MSE; verify against the writer.
correlation_mse_file = os.path.join(model_properties_dir, 'model_correlations_mse.json')

with open(correlation_mse_file, 'r') as f:
    errors = json.load(f)

if not errors:
    raise ValueError(f"{correlation_mse_file} contains no Hamiltonians")

# Hamiltonian ids ordered from lowest to highest error; the lowest-error one is plotted below.
sorted_hamiltonian_ids = sorted(errors, key=errors.get)
best_hamiltonian_id = sorted_hamiltonian_ids[0]

# Coupling Graph

In [None]:
def plot_couplings(couplings, rows, cols, node_size, figsize, save_as=None):
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    graph = nx.from_numpy_matrix(np.matrix(couplings), create_using=nx.DiGraph)
    mapping = {i: i + 1 for i in graph.nodes}
    graph = nx.relabel_nodes(graph, mapping)
    pos = {i: ((i-1) % cols, -((i-1) // cols)) for i in graph.nodes()}
    edge_widths = [(x + 1) ** 2 for x in list(nx.get_edge_attributes(graph, "weight").values())]
    
    edges, weights = zip(*nx.get_edge_attributes(graph,'weight').items())

    nx.draw(
        graph, pos, node_color="white", with_labels=True, font_color="black", edge_cmap=plt.cm.Blues,
        node_size=node_size, width=edge_widths, horizontalalignment='center', edgecolors="black", edgelist=edges, edge_color=weights,
        arrows=False, ax=ax, verticalalignment='center_baseline', font_size=fs_small, font_family=font_family
    )
    
    if save_as is not None:
        folder, fn = os.path.split(save_as)
        
        if not os.path.exists(folder):
            os.makedirs(folder)
            
        plt.savefig(save_as, bbox_inches='tight', pad_inches=0.0, dpi=200)
        plt.close(fig)

In [None]:
# plot the coupling graph of Hamiltonian `best_hamiltonian_id`
scale = 0.5
# figure width follows the lattice aspect ratio cols / rows
coupling_figsize = (scale * fig_height * (cols / rows), scale * fig_height)

save_as = (
    os.path.join(figures_dir, f'2d_heisenberg/2d_heisenberg_{split}_coupling_{rows}x{cols}_{best_hamiltonian_id}.{file_type}')
    if save_figures
    else None
)

coupling_matrix = np.load(os.path.join(data_dir, f'coupling_matrix_id{best_hamiltonian_id}.npy'))
plot_couplings(coupling_matrix, rows, cols, figsize=coupling_figsize, node_size=700, save_as=save_as)

# Two-Point Correlation Functions

In [None]:
def _make_subplot_correlation(ax, data, x_tick_locs, x_tick_marks, y_tick_locs, y_tick_marks, label, cmap):
    """Render one correlation-matrix heatmap on ``ax`` with a fixed [-1, 1] colour range.

    Column ticks are drawn along the top and row ticks along the left; the panel
    caption ``label`` is set as the x-axis label and the frame is hidden.
    Returns the image artist so the caller can attach a colorbar to it.
    """
    image = ax.imshow(data, cmap=plt.get_cmap(cmap), vmin=-1, vmax=1)

    x_axis, y_axis = ax.xaxis, ax.yaxis
    x_axis.tick_top()
    y_axis.tick_left()
    x_axis.set_ticks(x_tick_locs, x_tick_marks, fontsize=fs_small)
    y_axis.set_ticks(y_tick_locs, y_tick_marks, fontsize=fs_small)
    ax.set_xlabel(label, fontsize=fs_label, labelpad=10)

    # hide the frame by painting every spine white
    for spine in ax.spines.values():
        spine.set_color('white')

    return image
    

def plot_correlation(cm_true, cm_pred, rows, cols, figsize, title=None, save_as=None, cmap='RdBu', tick_multiples=5, plot_error=True):
    fig, axes = plt.subplots(1, 3 if plot_error else 2, figsize=figsize)
    tick_locs = np.array([i for i in np.arange(tick_multiples, rows * cols + 1, tick_multiples) - 1] ) # + [rows * cols - 1]
    tick_marks = tick_locs + 1
    
    # True Correlation Function
    im = _make_subplot_correlation(axes[0], cm_true, tick_locs, tick_marks, tick_locs, tick_marks, "True Correlation Function", cmap)
        
    # Model Prediction
    _make_subplot_correlation(axes[1], cm_pred, tick_locs, tick_marks, [], [], "Transformer Prediction", cmap)
        
    # Absolute Error
    if plot_error:
        _make_subplot_correlation(axes[2], np.abs(cm_pred - cm_true), tick_locs, tick_marks, [], [], "Transformer Absolute Error", cmap)
    
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
        
    # colorbar
    bar = fig.colorbar(im, pad=0.01, shrink=0.85, ax=axes.ravel().tolist())
    bar.set_label(r"$C_{ij}$", fontsize=fs_label, rotation=0, labelpad=20)
    bar.ax.tick_params(labelsize=fs_small)

    for _,s in bar.ax.spines.items():
        s.set_color('white')
    
    if title is not None:
        fig.suptitle(title, fontsize=fs_title, y=0.95)
    
    if save_as is not None:
        folder, fn = os.path.split(save_as)
        
        if not os.path.exists(folder):
            os.makedirs(folder)
            
        plt.savefig(save_as, bbox_inches='tight', pad_inches=0.1, dpi=200)
        plt.close(fig)

In [None]:
def make_correlation_plots(res_dir, rows, cols, ntrain, split, idx, figsize, model_id, train_id, snapshots, 
                           title=None, save_as=None, cmap='RdBu', tick_multiples=5, plot_error=True):  
    cmat_true = np.load(os.path.join(data_dir, f'correlation_matrix_id{idx}.npy'))
    cmat_pred = np.load(os.path.join(model_properties_dir, 'correlations', f'correlations_model_id{idx}.npy'))
    
    plot_correlation(cmat_true, cmat_pred, rows, cols, figsize, title=title, 
                     save_as=save_as, cmap=cmap, tick_multiples=tick_multiples, plot_error=plot_error)

In [None]:
########
# params
cmap = 'RdBu'
plot_error = False
show_title = False  # set True to put a figure-level title above the panels
# one fig_height-wide slot per heatmap panel, plus 1 extra unit of width
figsize = (fig_height * (3 if plot_error else 2) + 1, fig_height)

# plot correlation functions
save_as = None
if save_figures:
    save_as = os.path.join(figures_dir, f'2d_heisenberg/2d_heisenberg_{split}_correlation_{rows}x{cols}_{best_hamiltonian_id}.{file_type}')

title = "Two-Point Correlation Function" if show_title else None
make_correlation_plots(res_dir, rows, cols, ns, split, idx=best_hamiltonian_id, figsize=figsize, model_id=model_id, 
                       train_id=train_id, snapshots=snapshots, title=title, save_as=save_as, cmap=cmap,
                       tick_multiples=tick_multiples, plot_error=plot_error)