In [None]:
##### Import #####

import os
import sys

import time
import argparse
import json
import copy
from tqdm import tqdm

from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn

import plotly.graph_objects as go
import matplotlib.pyplot as plt

sys.path.append('/root/Concept/PytorchMPM')
sys.path.append('/root/Concept/PytorchMPM/learnable/learn_Psi')
from model.model_loop import MPMModelLearnedPhi, PsiModel2d

import random
# Seed every RNG used by the notebook with one shared value so runs are repeatable.
SEED = 20010313
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Use the first GPU when one is present; otherwise fall back to the CPU so the
# notebook can still be executed (slowly) on a machine without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Set to True to run the whole notebook in float64 instead of the default dtype.
use_double = False

if use_double:
    torch.set_default_dtype(torch.float64)

In [None]:
##### Data #####

# Directory holding one simulated configuration: a config_dict.json plus traj_* sub-folders.
data_dir = '/xiaodi-fast-vol/PytorchMPM/learnable/learn_Psi/data/jelly_v2_every_iter/config_0000'

with open(os.path.join(data_dir, 'config_dict.json'), 'r') as f:
    config_dict = json.load(f)
n_grid = config_dict['n_grid']
dx = 1 / n_grid
dt = config_dict['dt']
frame_dt = config_dict['frame_dt']
# Round frame_dt / dt to the nearest integer number of solver steps per frame.
n_iter_per_frame = int(frame_dt / dt + 0.5)
p_vol, p_rho = config_dict['p_vol'], config_dict['p_rho']
gravity = config_dict['gravity']
E_gt = config_dict['E']
nu_gt = config_dict['nu']
E_range = config_dict['E_range']
nu_range = config_dict['nu_range']

traj_list = sorted(s for s in os.listdir(data_dir) if 'traj_' in s)
traj_name = traj_list[0]
clip_idx = 0
clip_start = 110 # TODO: change back
clip_len = 50
clip_end = clip_start + clip_len
supervise_frame_interval = 10

data_dict = torch.load(os.path.join(data_dir, traj_name, 'data_dict.pth'), map_location="cpu")
traj_len = len(data_dict['x_traj'])
# The target slice below reads indices clip_start + 1 .. clip_end. Slicing past the end
# would silently return a shorter clip and only fail later when a frame is indexed.
assert clip_end < traj_len, (
    f"clip [{clip_start}, {clip_end}] does not fit in trajectory of length {traj_len}"
)

_state_keys = ('x_traj', 'v_traj', 'C_traj', 'F_traj')

# State at the first frame of the clip (the rollout's initial condition).
x_start, v_start, C_start, F_start = (
    data_dict[key][clip_start].to(device) for key in _state_keys
)
# Ground-truth states for the clip_len frames that follow the initial one.
x_traj, v_traj, C_traj, F_traj = (
    data_dict[key][clip_start + 1: clip_end + 1].to(device) for key in _state_keys
)
if use_double:
    x_start, v_start, C_start, F_start = x_start.double(), v_start.double(), C_start.double(), F_start.double()
    x_traj, v_traj, C_traj, F_traj = x_traj.double(), v_traj.double(), C_traj.double(), F_traj.double()


In [None]:
##### Loss #####

def get_loss(a_1, a_2):
    """Roll out the clip with a linear Psi model and return the mean supervised loss.

    A zero-hidden-layer MPMModelLearnedPhi is built whose single linear layer has
    weight [[a_1, a_2]] and bias 0. Starting from the module-level start state it
    is stepped n_iter_per_frame times per frame for clip_len frames, without
    gradient tracking. On every supervise_frame_interval-th frame the MSE between
    the simulated and reference positions (both multiplied by 1e3) is accumulated.

    Args:
        a_1: first weight of the linear layer.
        a_2: second weight of the linear layer.

    Returns:
        The accumulated loss divided by clip_len // supervise_frame_interval,
        as a Python number (via .item()).
    """
    n_hidden_layer = 0
    x_scale = 1e3
    criterion = nn.MSELoss()
    n_supervised = clip_len // supervise_frame_interval

    with torch.no_grad():
        weights = OrderedDict()
        weights['psi_model.mlp.0.weight'] = torch.tensor([[a_1, a_2]], device=device)
        weights['psi_model.mlp.0.bias'] = torch.tensor([0.], device=device)

        mpm_model = MPMModelLearnedPhi(
            2, n_grid, dx, dt, p_vol, p_rho, gravity,
            psi_model_input_type='eigen',
            base_model='fixed_corotated',
            n_hidden_layer=n_hidden_layer,
        ).to(device)
        mpm_model.load_state_dict(weights)

        n_particles = len(x_start)
        material = torch.ones((n_particles,), dtype=torch.int, device=device)
        # NOTE(review): Jp is always float64, independent of use_double -- confirm intended.
        Jp = torch.ones((n_particles,), dtype=torch.double, device=device)

        # Full simulation state, in the argument order the model takes and returns.
        state = (x_start.clone(), v_start.clone(), C_start.clone(), F_start.clone(), material, Jp)

        loss = 0
        for clip_frame in range(clip_len):
            for _ in range(n_iter_per_frame):
                state = mpm_model(*state)

            if (clip_frame + 1) % supervise_frame_interval != 0:
                continue
            positions = state[0]
            loss += criterion(positions * x_scale, x_traj[clip_frame] * x_scale)

        loss /= n_supervised

    return loss.item()

# Evaluate the loss at the weights recorded at epoch 128.
a_1 = 11.71528149
a_2 = -19.02729416
print(get_loss(a_1, a_2))

In [None]:
##### Gradient #####

def get_grad(a_1, a_2):
    """Return the gradient of the clip loss with respect to the linear Psi weights.

    Builds the same zero-hidden-layer MPMModelLearnedPhi as the loss evaluation
    (weight [[a_1, a_2]], bias 0), rolls it out over the clip with gradient
    tracking enabled, averages the scaled position MSE over the supervised frames
    and back-propagates it.

    Args:
        a_1: first weight of the linear layer.
        a_2: second weight of the linear layer.

    Returns:
        The .grad of psi_model.mlp[0].weight after backward(), squeezed.
    """
    n_hidden_layer = 0
    x_scale = 1e3

    with torch.enable_grad():
        state_dict = OrderedDict()
        state_dict['psi_model.mlp.0.weight'] = torch.tensor([[a_1, a_2]], device=device)
        state_dict['psi_model.mlp.0.bias'] = torch.tensor([0.], device=device)

        mpm_model = MPMModelLearnedPhi(
            2, n_grid, dx, dt, p_vol, p_rho, gravity,
            psi_model_input_type='eigen',
            base_model='fixed_corotated',
            n_hidden_layer=n_hidden_layer,
        ).to(device)
        mpm_model.load_state_dict(state_dict)

        n_particles = len(x_start)
        material = torch.ones((n_particles,), dtype=torch.int, device=device)
        # NOTE(review): Jp is always float64, independent of use_double -- confirm intended.
        Jp = torch.ones((n_particles,), dtype=torch.double, device=device)

        criterion = nn.MSELoss()

        x = x_start.clone()
        v = v_start.clone()
        C = C_start.clone()
        F = F_start.clone()

        total = 0
        for frame_idx in range(clip_len):
            for _ in range(n_iter_per_frame):
                x, v, C, F, material, Jp = mpm_model(x, v, C, F, material, Jp)

            is_supervised = (frame_idx + 1) % supervise_frame_interval == 0
            if is_supervised:
                total += criterion(x * x_scale, x_traj[frame_idx] * x_scale)

        # Average over the number of supervised frames before differentiating.
        total /= clip_len // supervise_frame_interval

        total.backward()

        return mpm_model.psi_model.mlp[0].weight.grad.squeeze()

# Alternative probe point (epoch 128 weights): a_1, a_2 = 11.71528149, -19.02729416
a_1 = 32
a_2 = -32
print(get_grad(a_1, a_2))

In [None]:
##### Model #####

n_hidden_layer = 0
model_epoch = 128
model_path = f'/root/Concept/PytorchMPM/learnable/learn_Psi/log/loop_0layer_noclip_sgd_lr3/traj_0000_clip_0000/model/checkpoint_{model_epoch:04d}.pth'

mpm_model = MPMModelLearnedPhi(2, n_grid, dx, dt, p_vol, p_rho, gravity, psi_model_input_type='eigen', base_model='fixed_corotated', n_hidden_layer=n_hidden_layer).to(device)
# map_location remaps the stored tensors onto the notebook's device, so the checkpoint
# loads even if it was saved from a different device than the one in use here.
state_dict = torch.load(model_path, map_location=device)
# print(state_dict)
mpm_model.load_state_dict(state_dict)

# Show the learned parameters of every linear layer of the Psi MLP.
for layer in mpm_model.psi_model.mlp:
    if isinstance(layer, nn.Linear):
        print(layer.weight, layer.bias)

# get_loss rebuilds a weight tensor from its two arguments, so hand it plain Python
# floats rather than 0-dim tensors.
a_1, a_2 = state_dict['psi_model.mlp.0.weight'].view(-1).tolist()
print(get_loss(a_1, a_2))

In [None]:
##### 1D plot (Loss) #####

# Hold a1 fixed and sweep a2 to see how the loss varies along that line.
a1 = 32.

# Coarse grid: step 0.04 over [-31, -30).
y_25 = np.arange(-31, -30, 0.04, dtype=np.double)
z_25 = np.zeros(len(y_25))

for i, a2 in enumerate(y_25):
    z_25[i] = get_loss(a1, a2)
    print(f"a1={a1:.2f}, a2={a2:.2f}, loss={z_25[i]}")


# Fine grid (step 0.01), currently disabled:
# y_100 = np.arange(-31, -30, 0.01, dtype=np.double)
# z_100 = np.zeros(len(y_100))

# for i, a2 in enumerate(y_100):
#     z_100[i] = get_loss(a1, a2)
#     print(f"a1={a1:.2f}, a2={a2:.2f}, loss={z_100[i]}")

plt.figure()
plt.plot(y_25, z_25)
# plt.plot(y_100, z_100)
# plt.savefig('sigma_times_I.png')
plt.show()

In [None]:
# Compare the loss curve at the coarse and (optionally) the fine a2 resolution.
plt.figure(figsize=(16, 9), dpi=80)
plt.plot(y_25, z_25, label='interval=0.04')
# The 0.01-interval sweep is commented out in the sweep cell, so y_100 / z_100 may not
# exist; overlay that curve only when it has actually been computed.
if 'y_100' in globals() and 'z_100' in globals():
    plt.plot(y_100, z_100, label='interval=0.01')
# plt.savefig('sigma_times_I.png')
plt.title("Psi = a1*sigma1 + a2*sigma2, a1=32, float64")
plt.xlabel("a2")
plt.ylabel("MSE*1e6")
plt.legend()
plt.show()

In [None]:
##### 1D plot (Loss and Grad) #####

# Hold a1 fixed and record both the loss and d(loss)/d(a2) along a sweep of a2.
a1 = 0.

a2s = np.arange(-20, -10, 1, dtype=np.double)
losses = np.zeros(len(a2s), dtype=np.double)
grad2s = np.zeros(len(a2s), dtype=np.double)

for i, a2 in enumerate(a2s):
    losses[i] = get_loss(a1, a2)
    # get_grad returns a tensor living on `device`; .item() extracts the a2 component
    # as a Python float so it can be stored in the NumPy array regardless of device.
    grad2s[i] = get_grad(a1, a2)[1].item()
    print(f"a1={a1:.2f}, a2={a2:.2f}, loss={losses[i]}, grad2={grad2s[i]}")

In [None]:
# Loss (left axis, green) and a2-gradient (right axis, blue) over the same a2 sweep.
fig, ax1 = plt.subplots(figsize=(16, 9), dpi=80)
ax2 = ax1.twinx()

ax1.set_title(f"Psi = a1*sigma1 + a2*sigma2, a1={a1}")
ax1.set_xlabel("a2")
ax1.set_ylabel("MSE*1e6", color='g')
ax2.set_ylabel("Grad", color='b')

ax1.plot(a2s, losses, 'g-')
ax2.plot(a2s, grad2s, 'b-')
# Dotted zero line on the gradient axis marks where the gradient changes sign.
ax2.axhline(y=0, color='0.4', linestyle=':')

plt.show()

In [None]:
# Loss along a2 for a1=32. `y` / `z` are not defined at this point in a top-to-bottom
# run (they are created later as a 2D grid), so plot the a1=32 sweep stored in
# y_25 / z_25, which is the data this figure's title describes.
plt.figure(figsize=(16, 9), dpi=80)
plt.plot(y_25, z_25)
# plt.savefig('sigma_times_I.png')
plt.title("Psi = a1*sigma1 + a2*sigma2, a1=32")
plt.xlabel("a2")
plt.ylabel("MSE*1e6")
plt.show()

In [None]:
##### 2D plot #####

# Evaluate the loss on a regular (a1, a2) grid; z[i, j] is the loss at a1=x[i], a2=y[j].
x = np.arange(32, 36, 0.04)
y = np.arange(-35, -31, 0.04)
z = np.zeros((len(x), len(y)))

for i, a1 in enumerate(x):
    for j, a2 in enumerate(y):
        z[i, j] = get_loss(a1, a2)
        print(f"a1={a1:.2f}, a2={a2:.2f}, loss={z[i][j]}")

print(x)
print(y)
print(z)

# plotly's Surface indexes z as z[y_index][x_index], while z here is stored as
# [a1_index][a2_index]; transpose so that a1 lands on the x axis and a2 on the y axis.
fig = go.Figure(data=[go.Surface(z=z.T, x=x, y=y)])
fig.update_layout(title='Corrected Fix Corotated', autosize=False,
                  width=500, height=500,
                  margin=dict(l=65, r=50, b=65, t=90),
                  scene=dict(xaxis_title='a1', yaxis_title='a2', zaxis_title='loss'))
fig.show()

In [None]:
# Same loss surface, with z-contours drawn between 1.3 and 1.4 and projected onto the floor.
# plotly's Surface indexes z as z[y_index][x_index], while z is stored as
# [a1_index][a2_index] (a1 from x, a2 from y); transpose so the axes match the data.
fig = go.Figure(data=[go.Surface(
                                    z=z.T, x=x, y=y,
                                    contours = {
                                        "z": {"show": True, "start": 1.3, "end": 1.4, "size": 0.005}
                                    },
                                )])
fig.update_layout(title='Corrected Fix Corotated', autosize=False,
                  width=500, height=500,
                  margin=dict(l=65, r=50, b=65, t=90))
fig.update_traces(contours_z=dict(show=True, usecolormap=True,
                  highlightcolor="limegreen", project_z=True))
fig.show()