In [None]:
import torch
import json
import os
import sys

sys.path.append('/root/Concept/PytorchMPM/learnable/learn_Psi')
from model.model import MPMModelLearnedPhi, PsiModel2d

print('importing finished.')

device = torch.device('cuda:0')

# Simulation configuration of the trajectory the Psi model was trained on.
data_dir = '/xiaodi-fast-vol/PytorchMPM/learnable/learn_Psi/data/jelly_v2/config_0000'
with open(os.path.join(data_dir, 'config_dict.json'), 'r') as f:
    config_dict = json.load(f)

# Grid / time-stepping parameters.
n_grid = config_dict['n_grid']
dx = 1 / n_grid
dt = config_dict['dt']
frame_dt = config_dict['frame_dt']
# Substeps per output frame, rounded to the nearest integer.
n_iter_per_frame = int(frame_dt / dt + 0.5)

# Particle / external-force parameters.
p_vol, p_rho = config_dict['p_vol'], config_dict['p_rho']
gravity = config_dict['gravity']

# Material parameters: the config's values (named *_gt here) and their ranges.
# Each key is read exactly once.
E_gt = config_dict['E']
nu_gt = config_dict['nu']
E_range = config_dict['E_range']
nu_range = config_dict['nu_range']

# Saved checkpoints of the learned-Psi model that the following cells visualize.
checkpoint_dir = '/root/Concept/PytorchMPM/learnable/learn_Psi/log/eigen_3layer_worked/traj_0000_clip_0000/model'
psi_model_input_type = 'eigen'

mpm_model = MPMModelLearnedPhi(2, n_grid, dx, dt, p_vol, p_rho, gravity, psi_model_input_type=psi_model_input_type).to(device)
# Trailing `;` keeps Jupyter from echoing the full module repr.
mpm_model.eval();

In [None]:
import plotly.graph_objects as go

# Evaluate the learned Psi on a 2-D grid of sorted sigma pairs (sigma_1 >= sigma_2)
# for every 10th of the first 300 checkpoints, building one Plotly frame per
# checkpoint. Later cells read: sigma1, sigma2, sigma, checkpoint_list, frames,
# trace_fig, surface, min_psi, max_psi.

min_sigma = 0.5
max_sigma = 1.5
sigma_step = 1e-2

sigma1 = torch.arange(min_sigma, max_sigma, sigma_step, device=device)
sigma2 = torch.arange(min_sigma, max_sigma, sigma_step, device=device)
# torch.meshgrid defaults to 'ij' indexing: sigma_meshgrid[0][i, j] == sigma1[i].
sigma_meshgrid = torch.meshgrid(sigma1, sigma2)
sigma = torch.empty((len(sigma1), len(sigma2), 2), device=device)
# Column 0 holds the larger value of each pair and column 1 the smaller, so the
# grid is symmetric in (i, j) and the x/y orientation of the surface is immaterial.
sigma[:, :, 0] = torch.maximum(sigma_meshgrid[0], sigma_meshgrid[1])
sigma[:, :, 1] = torch.minimum(sigma_meshgrid[0], sigma_meshgrid[1])
sigma = sigma.view(-1, 2)
# Push the two values apart by 1e-2 so they are never exactly equal on the grid
# diagonal (presumably to avoid degenerate repeated singular values — confirm).
sigma = sigma + torch.tensor([5e-3, -5e-3], device=device)

# NOTE(review): lexicographic order — assumes checkpoint file names are zero-padded.
checkpoint_list = sorted(os.listdir(checkpoint_dir))

frames = []
trace_fig = None  # first frame's surface; used as the animated figure's initial trace

# +/-inf sentinels keep the running range correct for any magnitude of Psi.
min_psi = float('inf')
max_psi = float('-inf')

for checkpoint_name in checkpoint_list[0:300:10]:
    print(checkpoint_name)

    state_dict = torch.load(os.path.join(checkpoint_dir, checkpoint_name), map_location=device)
    mpm_model.load_state_dict(state_dict)

    # Values are only plotted, so do not build an autograd graph.
    with torch.no_grad():
        psi = mpm_model.psi_model(torch.diag_embed(sigma))
    psi = psi.view(len(sigma1), len(sigma2))

    min_psi = min(min_psi, psi.min().item())
    max_psi = max(max_psi, psi.max().item())

    surface = go.Surface(z=psi.detach().cpu().numpy(), x=sigma1.cpu().numpy(), y=sigma2.cpu().numpy())
    frames.append(go.Frame(data=[surface], name=checkpoint_name))

    if trace_fig is None:
        trace_fig = surface

print(len(frames))

In [None]:
# Animated 3-D surface of the learned Psi across checkpoints, built from the
# `frames`, `trace_fig`, `min_psi`/`max_psi` produced by the sampling cell.
fig = go.Figure(frames=frames)
# Show the first frame's surface initially so the static view agrees with the
# slider, which starts at step 0.
if trace_fig is not None:
    fig.add_trace(trace_fig)


def frame_args(duration):
    """Build Plotly animation options.

    Args:
        duration: frame and transition duration, in milliseconds.

    Returns:
        Options dict used as the second element of an ``animate`` ``args`` list.
    """
    return {
        "frame": {"duration": duration},
        "mode": "immediate",
        "fromcurrent": True,
        "transition": {"duration": duration, "easing": "linear"},
    }


# One slider step per frame; selecting a step jumps to that frame instantly.
sliders = [
    {
        "pad": {"b": 10, "t": 60},
        "len": 0.9,
        "x": 0.1,
        "y": 0,
        "steps": [
            {
                "args": [[f.name], frame_args(0)],
                "label": str(k),
                "method": "animate",
            }
            for k, f in enumerate(fig.frames)
        ],
    }
]

# Axis ranges are fixed to the sampled sigma box and the global Psi range so
# the view does not rescale from frame to frame.
fig.update_layout(
    title='NN Pred',
    autosize=False,
    width=500,
    height=500,
    margin=dict(l=65, r=50, b=65, t=90),
    scene=dict(
        xaxis=dict(range=[min_sigma, max_sigma], autorange=False),
        yaxis=dict(range=[min_sigma, max_sigma], autorange=False),
        zaxis=dict(range=[min_psi, max_psi], autorange=False),
    ),
    updatemenus=[
        {
            "buttons": [
                {
                    # Play every frame (None), 50 ms per frame.
                    "args": [None, frame_args(50)],
                    "label": "&#9654;",  # play symbol
                    "method": "animate",
                },
                {
                    # Animate to [None] with zero duration: Plotly's pause idiom.
                    "args": [[None], frame_args(0)],
                    "label": "&#9724;",  # pause symbol
                    "method": "animate",
                },
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 70},
            "type": "buttons",
            "x": 0.1,
            "y": 0,
        }
    ],
    sliders=sliders,
)

fig.show()

In [None]:
# Norm of dPsi/dsigma for the learned Psi over the sigma grid, one animation
# frame per sampled checkpoint. Later cells read: frames, trace_fig, surface,
# min_grad_norm, max_grad_norm.

# +/-inf sentinels keep the running range correct for any gradient magnitude.
min_grad_norm = float('inf')
max_grad_norm = float('-inf')

frames = []
trace_fig = None  # first frame's surface; used as the animated figure's initial trace

# Differentiate w.r.t. a detached leaf so the shared `sigma` grid is not switched
# to requires_grad as a side effect of running this cell.
sigma_leaf = sigma.detach().requires_grad_()

for checkpoint_name in checkpoint_list[0:300:10]:
    state_dict = torch.load(os.path.join(checkpoint_dir, checkpoint_name), map_location=device)
    mpm_model.load_state_dict(state_dict)

    with torch.enable_grad():
        psi = mpm_model.psi_model(torch.diag_embed(sigma_leaf))
        # Gradient of the sum yields per-row gradients, assuming psi_model treats
        # each row independently (no cross-sample ops) — TODO confirm.
        grad = torch.autograd.grad(psi.sum(), sigma_leaf, allow_unused=True)[0]
    if grad is None:
        raise RuntimeError(f"psi_model output does not depend on sigma (checkpoint {checkpoint_name})")
    grad_norm = torch.linalg.norm(grad, dim=-1)
    grad_norm = grad_norm.view(len(sigma1), len(sigma2))

    min_grad_norm = min(min_grad_norm, grad_norm.min().item())
    max_grad_norm = max(max_grad_norm, grad_norm.max().item())

    surface = go.Surface(z=grad_norm.detach().cpu().numpy(), x=sigma1.cpu().numpy(), y=sigma2.cpu().numpy())
    frames.append(go.Frame(data=[surface], name=checkpoint_name))

    if trace_fig is None:
        trace_fig = surface

print(f"len(frames) = {len(frames)}")
print(f"grad_norm_range = [{min_grad_norm}, {max_grad_norm}]")

In [None]:
# Animated 3-D surface of ||dPsi/dsigma|| for the learned Psi across checkpoints,
# built from the `frames`, `trace_fig`, `min_grad_norm`/`max_grad_norm` produced
# by the grad-norm sampling cell.
fig = go.Figure(frames=frames)
# Show the first frame's surface initially so the static view agrees with the
# slider, which starts at step 0.
if trace_fig is not None:
    fig.add_trace(trace_fig)


def frame_args(duration):
    """Build Plotly animation options.

    Args:
        duration: frame and transition duration, in milliseconds.

    Returns:
        Options dict used as the second element of an ``animate`` ``args`` list.
    """
    return {
        "frame": {"duration": duration},
        "mode": "immediate",
        "fromcurrent": True,
        "transition": {"duration": duration, "easing": "linear"},
    }


# One slider step per frame; selecting a step jumps to that frame instantly.
sliders = [
    {
        "pad": {"b": 10, "t": 60},
        "len": 0.9,
        "x": 0.1,
        "y": 0,
        "steps": [
            {
                "args": [[f.name], frame_args(0)],
                "label": str(k),
                "method": "animate",
            }
            for k, f in enumerate(fig.frames)
        ],
    }
]

# Axis ranges are fixed to the sampled sigma box and the global grad-norm range
# so the view does not rescale from frame to frame.
fig.update_layout(
    title='grad(Psi): NN Prediction',
    autosize=False,
    width=800,
    height=800,
    margin=dict(l=65, r=50, b=65, t=90),
    scene=dict(
        xaxis=dict(range=[min_sigma, max_sigma], autorange=False),
        yaxis=dict(range=[min_sigma, max_sigma], autorange=False),
        zaxis=dict(range=[min_grad_norm, max_grad_norm], autorange=False),
    ),
    updatemenus=[
        {
            "buttons": [
                {
                    # Play every frame (None), 50 ms per frame.
                    "args": [None, frame_args(50)],
                    "label": "&#9654;",  # play symbol
                    "method": "animate",
                },
                {
                    # Animate to [None] with zero duration: Plotly's pause idiom.
                    "args": [[None], frame_args(0)],
                    "label": "&#9724;",  # pause symbol
                    "method": "animate",
                },
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 70},
            "type": "buttons",
            "x": 0.1,
            "y": 0,
        }
    ],
    sliders=sliders,
)

fig.show()

In [None]:
# Reference surface: ||dPsi/dsigma|| for a Psi model built with
# psi_model_input_type='enu' and fixed E / nu (presumably a closed-form model
# parameterized by E and nu — confirm).
# NOTE(review): the plot is titled "GT", but E / nu are hard-coded literals while
# the config's E_gt / nu_gt are loaded and unused — confirm these match.
gt_model = MPMModelLearnedPhi(
    2, n_grid, dx, dt, p_vol, p_rho, gravity,
    psi_model_input_type='enu', guess_E=1351.1224365234375, guess_nu=0.08396488428115845,
).to(device)

# Local names only: this cell must not overwrite `frames`, `trace_fig`,
# `min_grad_norm` / `max_grad_norm` (read by the animated grad(Psi) figure) or
# switch the shared `sigma` grid to requires_grad.
gt_sigma_leaf = sigma.detach().requires_grad_()
with torch.enable_grad():
    gt_psi = gt_model.psi_model(torch.diag_embed(gt_sigma_leaf))
    gt_grad = torch.autograd.grad(gt_psi.sum(), gt_sigma_leaf, allow_unused=True)[0]
if gt_grad is None:
    raise RuntimeError("gt psi_model output does not depend on sigma")
gt_grad_norm = torch.linalg.norm(gt_grad, dim=-1).view(len(sigma1), len(sigma2))

print(f"gt_grad_norm_range = [{gt_grad_norm.min().item()}, {gt_grad_norm.max().item()}]")

fig = go.Figure(data=go.Surface(z=gt_grad_norm.detach().cpu().numpy(), x=sigma1.cpu().numpy(), y=sigma2.cpu().numpy()))
fig.update_layout(title='grad(Psi): GT', autosize=False,
                  width=800, height=800,
                  margin=dict(l=65, r=50, b=65, t=90))
fig.show()