<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Parameters/Paths" data-toc-modified-id="Parameters/Paths-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Parameters/Paths</a></span></li><li><span><a href="#Load-dataset" data-toc-modified-id="Load-dataset-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Load dataset</a></span></li><li><span><a href="#Load-model" data-toc-modified-id="Load-model-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Load model</a></span></li><li><span><a href="#Examine-random-instances-from-validation" data-toc-modified-id="Examine-random-instances-from-validation-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Examine random instances from validation</a></span></li></ul></div>

In [None]:
import random

import numpy as np
import scipy as sp
import scipy.linalg as spla
import scipy.sparse as spsp
import scipy.sparse.linalg as spspla

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.data import DGLDataset
from dgl.data.utils import save_graphs, load_graphs, save_info, load_info

from src.utils.converters import extract_simplices, build_boundaries, build_laplacians, compute_Ltilde_pinv
from src.utils.postprocessing import neighborhood_smoothing
from src.dataset.ComplexesDatasetLazy import ComplexesDatasetLazy
from src.model.Dist2CycleRegressor import Dist2CycleRegressor


import matplotlib as mpl
import matplotlib.pyplot as plt

from src.viz.vizualization import *

# Parameters/Paths #

In [None]:
# --- Data locations ---------------------------------------------------------
# Raw complexes (points / simplices .npy files) and the processed dataset.
raw_dir = 'raw_data/3D/EX_Alpha_D3_HDim2_k10'
dataset_path = 'datasets/3D/LAZY_EX_Alpha_D3_HDim2_k10_boundary_1HD_Lfunc_k10_7_6_0.0'

# --- Trained model ----------------------------------------------------------
model_config = 'models/3D/model_params.npy'   # saved hyper-parameters
model_path = 'models/3D/trained_model.pkl'    # saved checkpoint

# Post-process predictions with neighborhood smoothing before plotting.
Laplacian_smoothing_logits = True

# --- Visualization ----------------------------------------------------------
cmap = 'jet'
default_vals = [0.0, 0.5, 0.5]   # default color values for 0-, 1- and 2-simplices
alphas = (0.0, 1.0, 0.4)         # opacity of 0-, 1- and 2-simplices
figsize = (18, 5)

# Load dataset #

In [None]:
# The processed dataset is stored as <parent dirs>/<dataset_name>.
dataset_path_split = dataset_path.split('/')
dataset_name = dataset_path_split[-1]

# Validation split of the lazily loaded complexes dataset.
val_dataset = ComplexesDatasetLazy(
    raw_dir=raw_dir,
    save_dir='/'.join(dataset_path_split[:-1]),
    saved_dataname=dataset_name,
    mode='val',
)

# Complex name -> position in the validation split (a repeated name keeps its last index).
# NOTE(review): this iterates the whole split, i.e. loads every validation graph once.
dataset_index = {gname: i for i, (_, _, gname) in enumerate(val_dataset)}

# Load model #

In [None]:
# Hyper-parameters saved with the trained model: a mapping stored inside a numpy
# object array, hence `.item()`.
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load trusted files.
model_params = np.load(model_config, allow_pickle=True).item()

# Rebuild the regressor with the architecture recorded in the saved hyper-parameters.
model = Dist2CycleRegressor(in_feats=model_params['in_feats'],
                            n_layers=model_params['n_layers'],
                            out_feats=model_params['out_feats'],
                            hidden_feats=model_params['hidden_feats'],
                            aggregator_type=model_params['aggregator_type'],
                            weighted_edges=model_params['weighted_edges'],
                            fc_bias=model_params['fc_bias'],
                            norm=model_params['norm'],
                            fc_activation=model_params['fc_activation'],
                            out_activation=model_params['out_activation'],
                            initialization=model_params['initialization'])

print(model)

# map_location='cpu': a checkpoint saved from GPU tensors would otherwise fail to
# load on a CPU-only machine. The model is not moved to a GPU in this notebook.
checkpoint = torch.load(model_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

# Input-feature settings from training. `infeaturescols` is the number of leading
# columns of the 'V' node field fed to the model (see pass_through_model).
infeaturesval = model_params['feats']  # NOTE(review): not used in the cells shown
infeaturescols = model_params['featsize']

model.eval()
loss_criterion = torch.nn.MSELoss()

In [None]:
def pass_through_model(graph, model, n_eigvecs=None):
    """Run ``model`` on ``graph`` and return its raw output (logits).

    The input features are the node field ``'lk_hom'`` concatenated column-wise
    with the first ``n_eigvecs`` columns of the node field ``'V'``, cast to
    float32. While the model runs, the node fields ``'V'`` and ``'lk_hom'`` and
    the edge field ``'S'`` are removed from the graph. This happens inside
    ``graph.local_scope()``, so the caller's graph is left intact afterwards and
    the function can safely be called again on the same graph.

    Args:
        graph: DGL graph carrying node fields 'V' and 'lk_hom' and edge field 'S'.
        model: callable invoked as ``model(graph, feats)``.
        n_eigvecs: number of leading columns of 'V' to use. Defaults to the
            notebook-level ``infeaturescols`` read from the model config.

    Returns:
        Whatever ``model(graph, feats)`` returns.

    Raises:
        An error from the DGL feature views if 'V', 'lk_hom' or 'S' is missing.
    """
    if n_eigvecs is None:
        n_eigvecs = infeaturescols

    # local_scope: feature deletions below are undone when the block exits.
    with graph.local_scope():
        eigenvecs = graph.ndata['V'][:, :n_eigvecs]
        feats = torch.hstack([graph.ndata['lk_hom'], eigenvecs]).float()

        # NOTE(review): presumably removed so the model cannot read the raw
        # fields from the graph and relies only on `feats` -- confirm.
        del graph.ndata['V']
        del graph.ndata['lk_hom']
        del graph.edata['S']

        return model(graph, feats)

# Examine random instances from validation #

In [None]:
plt.close('all')

num_instances = 6
# Seed for picking validation complexes; set to None for a different draw on every run.
sample_seed = 0

num_instances = min(num_instances, len(val_dataset))

# A dedicated RNG keeps the selection reproducible without touching global `random` state.
selected = random.Random(sample_seed).sample(range(len(val_dataset)), num_instances)

for i in selected:
    graph, labels, gname = val_dataset[i]

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = pass_through_model(graph, model)

    labels_ng = labels.detach().cpu().numpy()
    logits_ng = logits.squeeze(-1).detach().cpu().numpy()

    # Points and simplices of the complex, read from the raw data directory.
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only use a trusted raw_dir.
    pts = np.load(f'{raw_dir}/{gname}_pts.npy', allow_pickle=True)
    simplices = np.load(f'{raw_dir}/{gname}_simplices.npy', allow_pickle=True)

    if Laplacian_smoothing_logits:
        boundaries = build_boundaries(simplices)
        B = boundaries[0].todense()
        # If boundaries[0] is a scipy sparse matrix, todense() gives an np.matrix,
        # so the smoothed values may come back 2-D. Restore the shape of the
        # unsmoothed predictions.
        logits_ng_t = np.asarray(
            neighborhood_smoothing(B, logits_ng, power=2)
        ).reshape(logits_ng.shape)
        # Fall back to the raw predictions if smoothing produced NaNs.
        if not np.isnan(logits_ng_t).any():
            logits_ng = logits_ng_t

    # Compute the MSE on flattened float32 tensors, so a dtype or shape mismatch
    # between labels and (possibly float64 / reshaped) predictions cannot broadcast.
    mse = loss_criterion(
        torch.as_tensor(np.asarray(logits_ng, dtype=np.float32).reshape(-1)),
        labels.detach().cpu().float().reshape(-1),
    ).item()

    projection = '3d' if pts.shape[1] == 3 else None

    # Default color values per simplex dimension; each panel replaces only the 1-simplex entry.
    plain_distances = [np.array([default_vals[d]] * len(simplices[d])) for d in range(3)]

    fig = plt.figure(figsize=figsize)

    for position, edge_values, title in ((131, logits_ng, 'Ours'),
                                         (132, labels_ng, 'Reference')):
        ax = fig.add_subplot(position, projection=projection)
        # Fresh list per panel so the panels never share (and overwrite) one list.
        distances_L = list(plain_distances)
        distances_L[1] = edge_values
        _, ax_panel = complex_pyplot(pts, simplices, distances_L,
                                     fig=fig, ax=ax, alphas=alphas, cmap=cmap)
        ax_panel.set_title(title)
        ax_panel.set_axis_off()

    ax = fig.add_subplot(133)
    _, ax_d = distance_plot(labels_ng, logits_ng,
                            fig_in=fig, ax_in=ax,
                            title=f'MSE:{mse:.3f}')
    plt.show()