In [7]:
##### Import #####

import os
import sys

import time
import argparse
import json
import copy
from tqdm import tqdm

from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn

sys.path.append('/root/Concept/PytorchMPM')
sys.path.append('/root/Concept/PytorchMPM/learnable/learn_Psi')
from model.model_loop import MPMModelLearnedPhi, PsiModel2d

import random
# Single source of truth for every RNG seeded below.
SEED = 20010313
random.seed(SEED)
np.random.seed(SEED)
# torch.manual_seed already seeds all devices (CUDA included); the explicit CUDA call is kept for clarity.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Fall back to CPU so the notebook can still run on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [2]:
##### Data #####
# Loads the simulation config and one ground-truth clip (initial state + target frames).
# NOTE(review): absolute path on a different volume than the code paths in the import cell;
# consider making it configurable.

data_dir = '/xiaodi-fast-vol/PytorchMPM/learnable/learn_Psi/data/jelly_v2_every_iter/config_0000'

with open(os.path.join(data_dir, 'config_dict.json'), 'r') as f:
    config_dict = json.load(f)
n_grid = config_dict['n_grid']
dx = 1 / n_grid
dt = config_dict['dt']
frame_dt = config_dict['frame_dt']
# Simulation substeps per saved frame, rounded to the nearest integer.
n_iter_per_frame = int(frame_dt / dt + 0.5)
p_vol, p_rho = config_dict['p_vol'], config_dict['p_rho']
gravity = config_dict['gravity']
E_range = config_dict['E_range']
nu_range = config_dict['nu_range']
E_gt = config_dict['E']
nu_gt = config_dict['nu']

traj_list = sorted(s for s in os.listdir(data_dir) if 'traj_' in s)
assert traj_list, f'no traj_* entries found in {data_dir}'
traj_name = traj_list[0]
clip_idx = 0
clip_start = 110  # TODO: change back
clip_len = 100
clip_end = clip_start + clip_len
supervise_frame_interval = 10

data_dict = torch.load(os.path.join(data_dir, traj_name, 'data_dict.pth'), map_location="cpu")
traj_len = len(data_dict['x_traj'])
# The target slice reads frames clip_start+1 .. clip_end (inclusive). Fail early instead of
# silently truncating, which would otherwise surface as an IndexError inside get_loss.
assert clip_end < traj_len, (
    f'clip [{clip_start}, {clip_end}] needs {clip_end + 1} frames, trajectory has {traj_len}'
)

# Initial state at clip_start, and the clip_len ground-truth frames that follow it.
x_start, v_start, C_start, F_start = (
    data_dict[key][clip_start].to(device)
    for key in ('x_traj', 'v_traj', 'C_traj', 'F_traj')
)
x_traj, v_traj, C_traj, F_traj = (
    data_dict[key][clip_start + 1: clip_end + 1].to(device)
    for key in ('x_traj', 'v_traj', 'C_traj', 'F_traj')
)



In [27]:
##### Loss #####

def get_loss(a_1, a_2):
    """Roll out the clip with a fixed linear Psi model and return the position loss.

    Builds an ``MPMModelLearnedPhi`` with no hidden layers, so its Psi MLP is a
    single ``Linear`` layer. The layer's weight is set to ``[[a_1, a_2]]`` and its
    bias to 0.

    The rollout starts from the clip's initial state (``x_start``, ``v_start``,
    ``C_start``, ``F_start``). It advances ``clip_len`` frames of
    ``n_iter_per_frame`` model steps each. On every ``supervise_frame_interval``-th
    frame, it accumulates the MSE between the simulated and ground-truth
    positions, both scaled by ``x_scale``.

    Args:
        a_1, a_2: the two weights of the Psi Linear layer (Python numbers or
            0-d tensors).

    Returns:
        float: the mean MSE over the supervised frames.

    Raises:
        ValueError: if ``clip_len < supervise_frame_interval`` (no frame is supervised).
    """
    n_supervised = clip_len // supervise_frame_interval
    if n_supervised == 0:
        raise ValueError(
            f'clip_len={clip_len} < supervise_frame_interval={supervise_frame_interval}: '
            'no frame would be supervised'
        )

    # The state dict below only describes a single Linear(2 -> 1) layer, so the model
    # must have zero hidden layers. Don't depend on the `n_hidden_layer` global, which
    # is defined in a later cell.
    psi_hidden_layers = 0
    # NOTE(review): bias is 0 rather than the checkpoint's value. The printed losses
    # match the checkpoint-loaded model, presumably because a constant offset in Psi
    # does not affect the dynamics.
    state_dict = OrderedDict([
        ('psi_model.mlp.0.weight', torch.tensor([[float(a_1), float(a_2)]], device=device)),
        ('psi_model.mlp.0.bias', torch.zeros(1, device=device)),
    ])
    mpm_model = MPMModelLearnedPhi(2, n_grid, dx, dt, p_vol, p_rho, gravity, psi_model_input_type='eigen', base_model='fixed_corotated', n_hidden_layer=psi_hidden_layers).to(device)
    mpm_model.load_state_dict(state_dict)

    n_particles = len(x_start)
    material = torch.ones((n_particles,), dtype=torch.int, device=device)
    Jp = torch.ones((n_particles,), dtype=torch.float, device=device)

    criterion = nn.MSELoss()
    x_scale = 1e3  # scale positions up so the MSE is not vanishingly small

    x, v, C, F = x_start.clone(), v_start.clone(), C_start.clone(), F_start.clone()

    # NOTE(review): the model's parameters require grad, so this rollout may build an
    # autograd graph across all steps even though only the scalar is returned. Consider
    # torch.no_grad() or detaching between frames if the model does not use autograd
    # internally. Not done here because that cannot be confirmed from this notebook.
    loss = 0
    for clip_frame in range(clip_len):
        for _ in range(n_iter_per_frame):
            x, v, C, F, material, Jp = mpm_model(x, v, C, F, material, Jp)

        if (clip_frame + 1) % supervise_frame_interval == 0:
            loss += criterion(x * x_scale, x_traj[clip_frame] * x_scale)

    loss /= n_supervised

    return loss.item()

# Psi Linear-layer weights transcribed from the epoch-128 checkpoint.
a_1 = 11.71528149
a_2 = -19.02729416
print(get_loss(a_1, a_2))

2.425779104232788


In [29]:
##### Model #####
# Load a trained checkpoint, show its Psi Linear weights, and re-evaluate get_loss with them.

n_hidden_layer = 0
model_epoch = 128
model_path = f'/root/Concept/PytorchMPM/learnable/learn_Psi/log/loop_0layer_noclip_sgd_lr3/traj_0000_clip_0000/model/checkpoint_{model_epoch:04d}.pth'

mpm_model = MPMModelLearnedPhi(2, n_grid, dx, dt, p_vol, p_rho, gravity, psi_model_input_type='eigen', base_model='fixed_corotated', n_hidden_layer=n_hidden_layer).to(device)
# map_location keeps loading working regardless of the device the checkpoint was saved on.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load(model_path, map_location=device)
mpm_model.load_state_dict(state_dict)

for layer in mpm_model.psi_model.mlp:
    if isinstance(layer, nn.Linear):
        print(layer.weight, layer.bias)

# .tolist() yields plain Python floats (same input type as the hand-typed weights above)
# instead of 0-d device tensors.
a_1, a_2 = state_dict['psi_model.mlp.0.weight'].view(-1).tolist()
print(get_loss(a_1, a_2))

Parameter containing:
tensor([[ 11.7153, -19.0273]], device='cuda:0', requires_grad=True) Parameter containing:
tensor([-0.0042], device='cuda:0', requires_grad=True)
2.425779104232788
