# Imports

In [1]:
import pickle
import numpy as np
from IPython.display import display
import os
import glob
from ipywidgets import interact, widgets
from IPython.display import display
import pyvista as pv
from matplotlib import pyplot as plt
from scipy.linalg import lstsq
import igl
from matplotlib import cm
import polyscope as ps


# Configuration (workspace path, results folder, shading options)

In [2]:
# Notebooks have no __file__, so abspath("__file__") resolves to "<cwd>/__file__";
# the two dirname() calls therefore yield the parent of the current working directory.
workspace_path = os.path.dirname(os.path.dirname(os.path.abspath("__file__")))
import sys
# Guard so re-running this cell does not keep appending the same entry to sys.path.
if workspace_path not in sys.path:
    sys.path.append(workspace_path)
import meshplot as mp
# NOTE(review): star import; helpers used later (e.g. compute_vertex_normals_numpy)
# presumably come from here — confirm and prefer explicit names.
from lib.utils import *


# Results sub-folder (under <workspace>/results) whose .pkl runs are visualized.
code_file = 'real_data_opti'

meshes_path = os.path.join(workspace_path, "results", code_file)
# Shared meshplot shading options for the interactive viewers below.
shading_params = {
    "width": 600, "height": 600,
    "antialias": True,
    "colormap": "viridis",
    "wireframe": False, "wire_width": 0.03, "wire_color": "black",
    "line_color": "red",
}

In [3]:
def _full_width_checkbox(description, value):
    """Build a full-width ipywidgets Checkbox with the given label and initial state."""
    return widgets.Checkbox(description=description, value=value, layout=widgets.Layout(width='100%'))


# Display options read by the loading/plotting cells below.
latest_runs_only = _full_width_checkbox('Only Last 5 runs?', True)
visualize_gauss_curv = _full_width_checkbox('Visualize Gauss Curvature?', False)
plot_gauss_pos_neg = _full_width_checkbox('Plot Gauss Curvature Pos/Neg?', True)
show_normals = _full_width_checkbox('Show Normals?', False)
display(widgets.VBox([latest_runs_only, visualize_gauss_curv, plot_gauss_pos_neg, show_normals]))

VBox(children=(Checkbox(value=True, description='Only Last 5 runs?', layout=Layout(width='100%')), Checkbox(va…

# Read Files

In [4]:
# read files based on last created
all_pkl_files = glob.glob(os.path.join(meshes_path, "*.pkl"))
all_pkl_files.sort(key=os.path.getmtime)

if latest_runs_only.value:
    all_pkl_files = all_pkl_files[-5:]
num_cols = 3
num_rows = (len(all_pkl_files) // num_cols) + 1

all_dicts = []
dict_name = []
for pkl_file in all_pkl_files:
    try:
        dict_name.append(os.path.basename(pkl_file))
    except Exception as e:
        print(e, pkl_file)
# print(os.listdir(meshes_path))

# Create a widget with checkboxes for each mesh default is false
checkboxes = [widgets.Checkbox(description=name, value=False, layout=widgets.Layout(width='100%')) for name in dict_name]
display(widgets.VBox(checkboxes))

VBox(children=(Checkbox(value=False, description='bear-v34-0-left-rad1-tlin0-2023-10-16_13-14-39-test-v1.pkl',…

# Visualize Meshes

## Slide through optimization

In [10]:
def _load_selected_dicts(pkl_files, selection_boxes):
    """Load the pickled result dict of every ticked run.

    Returns a list index-aligned with ``pkl_files`` (and hence with ``checkboxes``
    and ``dict_name``). Unticked runs and runs that fail to load get ``{}``, so
    one bad file cannot shift the indices of the runs after it.
    """
    loaded = []
    for pkl_file, box in zip(pkl_files, selection_boxes):
        mesh_data = {}
        if box.value:
            try:
                # NOTE: pickle.load can execute arbitrary code — only open trusted result files.
                with open(pkl_file, 'rb') as f:
                    mesh_data = pickle.load(f)
            except Exception as e:
                print(e, pkl_file)
        loaded.append(mesh_data)
    return loaded


def _gauss_colors(verts, faces, color_map):
    """Per-vertex RGB colours from igl Gaussian curvature, or None when ``color_map`` is None."""
    if color_map is None:
        return None
    gcurv = igl.gaussian_curvature(verts, faces)
    if plot_gauss_pos_neg.value:
        # Binarise around gauss_curv_thresh (note: the split point is not 0).
        gcurv[gcurv < gauss_curv_thresh] = -1
        gcurv[gcurv > gauss_curv_thresh] = 1
    return color_map.to_rgba(gcurv)[:, :3]


def _closest_face_set(face_sets, num_verts):
    """Index of the face array whose count of referenced vertices is closest to ``num_verts``."""
    return int(np.argmin([np.abs(np.unique(faces).shape[0] - num_verts) for faces in face_sets]))


def _draw_normals(plot, subplot_idx, verts, faces):
    """Add a short line segment along each vertex normal to subplot ``subplot_idx``."""
    # Normals are computed on a 1000x-scaled copy of the vertices (presumably for
    # numerical robustness on small meshes — TODO confirm).
    verts_normals = compute_vertex_normals_numpy(verts * 1000.0, faces)
    plot.add_lines_to_subplot(s=subplot_idx, beginning=verts + verts_normals * 0.0,
                              ending=verts + verts_normals * 0.01, shading=shading_params)


all_dicts = _load_selected_dicts(all_pkl_files, checkboxes)
# Runs that are ticked AND loaded; subplot k shows run selected_idx[k].
selected_idx = [i for i, mesh_data in enumerate(all_dicts) if checkboxes[i].value and mesh_data]

count_checboxes = len(selected_idx)
max_iterations = max((len(all_dicts[i]['verts_seq']) for i in selected_idx), default=1)
mp_plot = None
plotted_once = [False] * count_checboxes
face_counter = [0] * count_checboxes
last_num_verts = [None] * count_checboxes

num_cols = 3
# Ceiling division, at least one row (avoids a trailing empty row).
num_rows = max(1, -(-count_checboxes // num_cols))
last_it = 0

gauss_curv_thresh = -0.007


def plot_verts(it):
    """Show iteration ``it`` of every selected run, one meshplot subplot per run.

    The first call creates the meshplot figure. Later calls update vertex positions
    in place, and rebuild a subplot only when its vertex count changes (so a
    different face array is needed).
    """
    global mp_plot, last_it

    # Fix: test the checkbox's value, not the (always truthy) widget object.
    color_map = None
    if visualize_gauss_curv.value:
        color_map = cm.ScalarMappable(cmap='coolwarm')
        color_map.set_clim(-0.005, 0.005)

    for ck_itr, i in enumerate(selected_idx):
        mesh_dict = all_dicts[i]
        v_numpy = mesh_dict['verts_seq']
        if it >= len(v_numpy):
            # Shorter run than the longest selected one: keep its last drawn state.
            # Checked before any v_numpy[it] access in every branch.
            continue
        faces_idx_sim = mesh_dict['faces']
        if not isinstance(faces_idx_sim, list):
            faces_idx_sim = [faces_idx_sim]
        verts = v_numpy[it]
        subplot_idx = [num_rows, num_cols, ck_itr]

        if mp_plot is None:
            # First draw of the whole figure, using face set face_counter[ck_itr] (initially 0).
            faces = faces_idx_sim[face_counter[ck_itr]]
            mp_plot = mp.subplot(verts, faces, c=_gauss_colors(verts, faces, color_map),
                                 shading=shading_params, s=subplot_idx, label=f'{dict_name[i]}')
            if show_normals.value:
                _draw_normals(mp_plot, subplot_idx, verts, faces)
            last_num_verts[ck_itr] = verts.shape[0]
            plotted_once[ck_itr] = True
        elif not plotted_once[ck_itr] or last_num_verts[ck_itr] != verts.shape[0]:
            # New subplot, or vertex count changed: pick the matching face array and rebuild.
            face_counter[ck_itr] = _closest_face_set(faces_idx_sim, verts.shape[0])
            faces = faces_idx_sim[face_counter[ck_itr]]
            mp.subplot(verts, faces, c=_gauss_colors(verts, faces, color_map),
                       shading=shading_params, s=subplot_idx, data=mp_plot, label=f'{dict_name[i]}')
            plotted_once[ck_itr] = True
            last_num_verts[ck_itr] = verts.shape[0]
            if show_normals.value:
                _draw_normals(mp_plot, subplot_idx, verts, faces)
        else:
            # Same topology as last time: just move the vertices (and recolour).
            faces = faces_idx_sim[face_counter[ck_itr]]
            mp_plot.update_object(s=subplot_idx, v=verts, c=_gauss_colors(verts, faces, color_map))
            if show_normals.value:
                mp_plot.remove_object_type(s=subplot_idx, obj_type='Lines')
                _draw_normals(mp_plot, subplot_idx, verts, faces)

    last_it = it


if selected_idx:
    interact(plot_verts, it=widgets.IntSlider(min=0, max=max_iterations - 1, step=1, value=0))
else:
    print("No run selected (or none loaded) - tick a checkbox above and re-run this cell.")

interactive(children=(IntSlider(value=0, description='it', max=20), Output()), _dom_classes=('widget-interact'…

## Display last frame

In [None]:
shading_params_last = {
    "width": 600, "height": 600,
    "antialias": True,
    "colormap": "viridis",
    "wireframe": False, "wire_width": 0.03, "wire_color": "black",
    "line_color": "red",
}

# Reload ticked runs; the list stays index-aligned with `checkboxes` ({} for unticked/failed).
all_dicts = []
for pkl_file, box in zip(all_pkl_files, checkboxes):
    mesh_dict = {}
    if box.value:
        try:
            # NOTE: pickle.load can execute arbitrary code — only open trusted result files.
            with open(pkl_file, 'rb') as f:
                mesh_dict = pickle.load(f)
        except Exception as e:
            print(e, pkl_file)
    all_dicts.append(mesh_dict)


def _face_set_for_frame(verts_seq, frame_no):
    """Index into the list of face arrays that matches frame ``frame_no``.

    Counts how many times the vertex count changes along ``verts_seq`` before the
    first frame whose vertex count equals that of ``frame_no`` (frame 0 included,
    which the previous k>=1 search missed, leaving face_idx undefined or stale).
    """
    target = len(verts_seq[frame_no])
    segment = 0
    prev_count = len(verts_seq[0])
    for frame in verts_seq:
        count = len(frame)
        if count != prev_count:
            segment += 1
            prev_count = count
        if count == target:
            return segment
    return segment


# vis_frame_nos = [999, 1999, 2999]
vis_frame_nos = [-1]
# Face array index to use; set to None to infer it from vertex-count changes.
use_face_idx = -1
for vis_frame_no in vis_frame_nos:
    print("Vis_frame_no:", vis_frame_no)
    for i, mesh_dict in enumerate(all_dicts):
        if not checkboxes[i].value or not mesh_dict:
            continue
        verts_seq = mesh_dict['verts_seq']
        # Convert only the requested frame: np.array() on the whole sequence fails when remeshing makes it ragged.
        frame_verts = np.asarray(verts_seq[vis_frame_no])
        print(os.path.basename(all_pkl_files[i]))
        print("Max - Min :", frame_verts.max(0) - frame_verts.min(0))
        print("Axis Max val:", frame_verts.max(0))
        print("Axis Min val:", frame_verts.min(0))
        print("*************")
        faces_idx_sim = mesh_dict['faces']
        if not isinstance(faces_idx_sim, list):
            faces_idx_sim = [faces_idx_sim]
        face_idx = _face_set_for_frame(verts_seq, vis_frame_no) if use_face_idx is None else use_face_idx
        mp.plot(frame_verts * 10, faces_idx_sim[face_idx], shading=shading_params_last)


In [None]:
# For each selected run, plot how the per-axis minimum and maximum vertex coordinates evolve over iterations.
for i, mesh_dict in enumerate(all_dicts):
    if not checkboxes[i].value or not mesh_dict:
        continue
    # Reduce frame by frame so runs whose vertex count changes (ragged sequence) still work.
    frames = [np.asarray(v) for v in mesh_dict['verts_seq']]
    min_xyz_t = np.stack([frame.min(axis=0) for frame in frames])
    max_xyz_t = np.stack([frame.max(axis=0) for frame in frames])

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    for ax, bounds, kind in ((axs[0], min_xyz_t, 'min'), (axs[1], max_xyz_t, 'max')):
        for axis_idx, axis_name in enumerate('xyz'):
            ax.plot(bounds[:, axis_idx], label=f'{kind} {axis_name}')
        ax.set(title=f'{kind} coordinate per iteration', xlabel='iteration')
        ax.legend()
    fig.suptitle(dict_name[i])
    plt.show()

## Polyscope: Save Screenshots of the Final Frame

In [None]:
# Render the final frame of each selected run with polyscope and save a PNG per run.
ps.init()
base_path = os.path.join(workspace_path, "results", 'poly_imgs', code_file)
os.makedirs(base_path, exist_ok=True)
for i, mesh_dict in enumerate(all_dicts):
    if not checkboxes[i].value or not mesh_dict:
        continue
    v_numpy = mesh_dict['verts_seq']
    faces_idx_sim = mesh_dict['faces']
    if not isinstance(faces_idx_sim, list):
        faces_idx_sim = [faces_idx_sim]

    # Same structure name every run, so each registration replaces the previous run's mesh.
    # Assumes the last face array matches the last frame's vertices.
    ps.register_surface_mesh("Mesh", v_numpy[-1], faces_idx_sim[-1], material='ceramic')
    ps.look_at([-1, 2, -1], [0, 0, 0])
    img_filename = dict_name[i].split('.pkl')[0]
    ps.screenshot(os.path.join(base_path, img_filename + ".png"))

# Visualize Texture (Heat Diffusion) Optimization

In [None]:
# Run of interest: the last ticked run that loaded successfully.
moi_dict = None
for i, checkbox in enumerate(checkboxes):
    if checkbox.value and all_dicts[i]:
        moi_dict = all_dicts[i]
if moi_dict is None:
    raise ValueError("Tick (and load) at least one run above before running this cell.")

print(moi_dict.keys())
# NOTE(review): assumes stored images are (batch, H, W, T), so this gives (batch, T, H, W) — confirm.
diffusion_sim_img = moi_dict['diffusion_sim_img'].transpose(0, 3, 1, 2)
diffusion_gt_img = moi_dict['diffusion_gt_img'].transpose(0, 3, 1, 2)
print(diffusion_sim_img.shape, diffusion_gt_img.shape)

plt.close('all')
# Three panels: ground truth, simulation, and their absolute difference.
fig, (ax, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 10))


def _draw_frame(t):
    """Redraw all three panels at time index ``t``.

    Every axis is cleared first, so images no longer pile up on ax2/ax3.
    """
    gt_frame = diffusion_gt_img[0, t]
    sim_frame = diffusion_sim_img[0, t]
    for axis in (ax, ax2, ax3):
        axis.clear()
    # Fixed colour range so frames are comparable (the initial draw now uses it too).
    ax.imshow(gt_frame, cmap='hot', vmin=0, vmax=1)
    ax2.imshow(sim_frame, cmap='hot', vmin=0, vmax=1)
    ax3.imshow(np.abs(gt_frame - sim_frame), cmap='hot')
    ax.set_title('Heat flow at t = {}'.format(t))
    ax2.set_title('Simulated')
    ax3.set_title('|GT - simulated|')


_draw_frame(0)


def update(val):
    """Slider callback: redraw the panels at time step ``val`` and display the figure."""
    _draw_frame(int(val))
    fig.canvas.draw()
    display(fig)


interact(update, val=widgets.IntSlider(min=0, max=diffusion_gt_img.shape[1]-1, step=1, value=0));

# Create GIFs of the meshes

In [None]:
# Disabled PyVista setup for rendering without a display (presumably on a headless
# machine); uncomment before running the GIF/movie export below if needed.
# pv.start_xvfb(wait=0.1)
# print(pv.OFF_SCREEN)
# p = pv.Plotter(notebook=True)

In [None]:

# Disabled: export an MP4 (or GIF) of each selected run's vertex trajectory with PyVista.
# NOTE(review): before re-enabling, confirm that mesh_dict['verts_seq'] supports
# `.shape` (other cells treat it as a plain sequence, e.g. len()/indexing), and that
# faces_idx_sim is a single array in the face format pv.PolyData expects (other cells
# wrap it in a list of face arrays) — TODO confirm.
# for i, mesh_dict in enumerate(all_dicts):
#     if not checkboxes[i].value:
#         continue
#     plotter = pv.Plotter()
#     v_numpy = mesh_dict['verts_seq']
#     faces_idx_sim = mesh_dict['faces']
#     gif_path = os.path.join(meshes_path, dict_name[i].split('.')[0])
#     # plotter.open_gif(f'{gif_path}.gif')
#     plotter.open_movie(f'{gif_path}.mp4')
#     for k in range(v_numpy.shape[0]):
#         plotter.clear()
#         plotter.add_mesh(pv.PolyData(v_numpy[k], faces_idx_sim), show_edges=True)
#         plotter.write_frame()
#     plotter.close()

# Fit Plane to Points

In [None]:
def fit_plane_to_points(points):
    """Least-squares fit of the plane ``z = a*x + b*y + c`` to a point cloud.

    Parameters
    ----------
    points : ndarray, shape (N, 3)
        Columns 0, 1 and 2 are used as x, y and z. The minimised residual is the
        vertical (z) offset, not the orthogonal distance to the plane.

    Returns
    -------
    tuple
        ``(a, b, c)``, the plane coefficients.
    """
    x, y, z = points[:, 0], points[:, 1], points[:, 2]
    design = np.column_stack((x, y, np.ones(len(points))))
    solution = lstsq(design, z)[0]
    slope_x, slope_y, offset = solution
    return slope_x, slope_y, offset

In [None]:
# NOTE(review): img_size is not referenced by any visible cell — confirm whether it is still needed.
img_size = (128, 160)

In [None]:
# Scale from mesh units to heat-map pixels (the original magic factor; its origin is not documented here).
PIXELS_PER_UNIT = 64

for i, mesh_dict in enumerate(all_dicts):
    if not checkboxes[i].value or not mesh_dict:
        continue
    # Only the final frame is needed; avoids np.array() over a possibly ragged sequence.
    final_verts = np.asarray(mesh_dict['verts_seq'][-1])

    # Fit a plane z = a*x + b*y + c to the final vertices.
    a, b, c = fit_plane_to_points(final_verts)

    # Squared vertical residual of every vertex, and the RMS of the fit.
    dist_calc = (final_verts[:, 0] * a + final_verts[:, 1] * b + c - final_verts[:, 2]) ** 2
    rms_error = np.sqrt(np.mean(dist_calc))

    # Rasterise x/y into pixels. Flooring both the points and their minimum gives
    # indices in [0, extent], and the image is sized extent + 1. The previous
    # int(max - min) size could be one smaller than the truncated index
    # int(x) - int(min), raising IndexError.
    pix = np.floor(final_verts[:, :2] * PIXELS_PER_UNIT).astype(int)
    pix -= pix.min(axis=0)
    residual_img = np.zeros(tuple(pix.max(axis=0) + 1))
    # Vertices landing in the same pixel keep only one of their values (NumPy does not specify which).
    residual_img[pix[:, 0], pix[:, 1]] = dist_calc

    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(residual_img, cmap='hot')
    fig.colorbar(im, ax=ax)
    ax.set_title(f'RMS error: {rms_error}')
    plt.show()