## Load related packages

In [None]:
import argparse
import os
import sys
import time
import math
import numpy as np
import torch
import torch.nn.functional as F
import scipy.sparse as sp
import h5py
import random
from copy import copy
from datetime import datetime
import argparse
from IPython import embed
from generate_burgers import burgers_numeric_solve
import pdb
import torch
import torch.nn as nn
import matplotlib.pylab as plt
import tqdm
from multiprocessing import cpu_count
from torch.utils.data import Dataset, DataLoader
# Shared configuration for the Burgers PID control experiments.  The notebook
# cells below call ``parser.parse_args([])`` to obtain these defaults and then
# override individual fields; a plain-Python run reads them from ``sys.argv``.
parser = argparse.ArgumentParser(description='Generating PDE data')
parser.add_argument('--experiment', type=str, default='burgers_pid',
                    help='Experiment for which data should create for')
parser.add_argument('--exp_path', default='/user/project/pde_gen_control', type=str, help='experiment folder')
parser.add_argument('--date_time', default='2023-11-25_cNet_without_last_relu', type=str, help='experiment date')
parser.add_argument('--model_type', default='PID', type=str, help='model type.')
parser.add_argument('--gpuid', type=int, default=3,
                    help='Used device id')
parser.add_argument('--num_f', default=500, type=int,
                    help='the number of force data')
parser.add_argument('--num_u0', default=60, type=int,
                    help='the number of initial data')
parser.add_argument('--train_samples', type=int, default=24000,
                    help='Samples in the training dataset')
# parser.add_argument('--valid_samples', type=int, default=0,
#                     help='Samples in the validation dataset')
parser.add_argument('--test_samples', type=int, default=6000,
                    help='Samples in the test dataset')
# NOTE(review): ``type=eval`` evaluates the raw command-line string as Python
# code.  It is kept for compatibility with existing invocations
# (``--log True`` / ``--log False``) but must never be fed untrusted input.
parser.add_argument('--log', type=eval, default=False,
                    help='pipe the output to log file')
parser.add_argument('--max_iter_steps', type=int, default=10,
                    help='max iter steps for tuning params of PID')
parser.add_argument('--max_training_iters', type=int, default=4000,
                    help='max number of iters for tuning params of PID')
parser.add_argument('--save_iters', type=int, default=500,
                    help='save weight each save_iters iters')
parser.add_argument('--model_mode', type=str, default='train',
                    help='train or eval')
# NOTE(review): the default 'train' is not a usable path; callers are expected
# to override it before the weights are loaded.
parser.add_argument('--model_weight_path', type=str, default='train',
                    help='path of model weight')
# The next three options were historically declared with a single dash only.
# The single-dash spelling is kept and the conventional double-dash spelling is
# added as an alias; ``dest`` (dataset_path / train_batch_size / lr) is unchanged.
parser.add_argument('-dataset_path', '--dataset_path', type=str, default='/user/project/pde_gen_control/datasets/dataset_control_burgers/varying_f_1e4',
                    help='path of dataset path')
parser.add_argument('-train_batch_size', '--train_batch_size', type=int, default=16,
                    help='batch size for training')
parser.add_argument('-lr', '--lr', type=float, default=1e-4,
                    help='learning rate for training')
parser.add_argument('--is_partially_controllable', type=int, default=0,
                    help='0: fully controllable; 1. partially_controllable')
parser.add_argument('--simulation_method', type=str, default="solver",
                    help='solver or surrogate_model')
parser.add_argument('--pde_1d_surrogate_model_checkpoint', type=str, default="/user/project/pde_gen_control/results/2023-12-02_1d_surrogate_model",
                    help='1d pde surrogate_model checkpoint path')
parser.add_argument('--f_max', type=float, default=5000,
                    help='bound of the control force: the control is clamped to [-f_max, f_max]')
parser.add_argument('--is_partially_observable', type=int, default=0,
                    help='0: fully observable; 1. partially_observable')
parser.add_argument('--use_model', type=int, default=0,
                    help='0: controller interacts with the solver ; 1: controller interacts with the surrogate model')

In [None]:
from pde_1d_control_PID import cNet,Controller,cycle,plot_result,plot_result_process
from data_burgers_1d import Burgers1DSimple, Burgers
# NOTE(review): presumably the rescaling constant associated with the "1e5"
# dataset -- confirm.  Neither name is referenced in the visible code.
RESCALER = RESCALER_1e5 = 6.4519
from matplotlib import colors
import matplotlib.pyplot as plt
import numpy as np
def plot_result_pre(idx, u, f, u_gt, plot_u0 = True, control_energy_bool = True):
    '''
    Plot the controlled state against the target state for one sample.

    Args:
        idx: index of the data to plot
        u: [batch_size, num_t+1, s] ddpm/PID/SAC predicted u
        f: [batch_size, num_t, s] ddpm/PID/SAC predicted control
        u_gt: [batch_size, num_t+1, s] ground truth u; only its first and
            last time frames are used
        plot_u0: whether to plot u at t=0
        control_energy_bool: whether to display the control energy

    The title shows J, the mean squared difference between the last frame of
    ``u`` and the last frame of ``u_gt``, and optionally the control energy,
    the sum of squares of ``f`` for the selected sample.
    '''
    u = u[idx]
    u_gt = u_gt[idx]
    f = f[idx]
    uf = u_gt[-1]

    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(4, 2)
    # Take the spatial resolution and the final time index from the data
    # rather than assuming 128 grid points and 10 time steps.
    x_pos = np.linspace(0, 1, u.shape[-1])
    t_final = u.shape[0] - 1

    if plot_u0:
        ax.plot(x_pos, u[0, :], label='PID, $u_{t=0}$', color='orange')
        ax.plot(x_pos, u_gt[0, :], label='Target, $u_{t=0}$', color='peru', ls='-.')
    ax.plot(x_pos, u[-1, :], label=f"PID $u_{{t={t_final}}}$", color='lightskyblue')
    ax.plot(x_pos, uf, label=f'Target $u_{{t={t_final}}}$', color='royalblue', ls='-.')
    ax.legend(ncol=2, loc='center', bbox_to_anchor=(1.6, 0.5))

    # float() turns 0-dim tensors / numpy scalars into plain Python floats so
    # the format spec below works regardless of the container type or device.
    J = float(((u[-1] - u_gt[-1])**2).mean())
    control_energy = float((f**2).sum())
    if control_energy_bool:
        ax.set_title(f'J={J:.4f}, control energy={control_energy:.4f}')
    else:
        ax.set_title(f'J={J:.4f}')
    ax.set_xlabel('x')
    ax.set_ylabel('u')
    plt.show()
def metric(u_controlled: torch.Tensor,u_target: torch.Tensor, f: torch.Tensor, target='final_u'):
    '''
    Evaluates the control based on the state deviation and the control cost.
    Note that f and u should NOT be rescaled. (Should be directly input to the solver)

    Arguments:
        u_controlled:
            States obtained under the control f
            size: (batch_size, Nt, Nx)
        u_target:
            Ground truth states
            size: (batch_size, Nt, Nx) (currently Nt = 11, Nx = 128)
        f:
            Generated control force
            size: (batch_size, Nt - 1, Nx) (currently Nt = 11, Nx = 128)
        target:
            Which deviation to evaluate; only 'final_u' is defined.

    Returns:
        J_actual:
            Deviation of controlled u from target u for each sample, in MSE.
            When target is 'final_u', evaluate only at the final time stamp
            size: (batch_size)

        control_energy:
            Cost of the control force for each sample, in sum of square.
            size: (batch_size)

    Raises:
        ValueError: if ``target`` is not 'final_u'.
    '''

    assert len(u_target.size()) == len(f.size()) == 3

    # eval J_actual
    if target == 'final_u':
        J_actual = (u_controlled[:, -1, :] - u_target[:, -1, :]).square().mean(-1)
    else:
        raise ValueError('Undefined target to evaluate')

    control_energy = f.square().sum((-1, -2))

    # Both metrics were previously computed and then dropped; return them as
    # the docstring promises.
    return J_actual, control_energy

# Path segment placed between args.exp_path and the run name (args.date_time)
# when the output directory ``exp_path`` is assembled in the cells below.
results_path="/results_0102/"

## Analyze the Quarter control surrogate 1-step loss

In [None]:
# Number of control steps in each rollout.  Defined before the environment
# check so that the command-line path has it as well.
simulation_steps=10
try:
    # ``get_ipython`` only exists inside IPython/Jupyter; outside it this
    # raises NameError, which is the one failure we want to fall back on.
    get_ipython().run_line_magic('matplotlib', 'inline')
except NameError:
    # Plain Python run: take the whole configuration from the command line.
    args = parser.parse_args()
    is_jupyter = False
else:
    # Notebook run: start from the parser defaults and override them in place.
    is_jupyter = True
    args = parser.parse_args([])
    # args.date_time="2024-01-20_PID_surrogate_simu_full_ob_partial_ctr_clamp"
    args.date_time="2024-01-24_PID_solver_full_ob_partial_ctr_clamp"
    args.exp_path='/user/project/pde_gen_control'
    args.gpuid=7
    args.model_mode="train"
    # Fully observable state, partially controllable force.
    args.is_partially_observable=0
    args.is_partially_controllable=1
    args.f_max=5
    args.dataset_path='/user/project/pde_gen_control/dataset_control_burgers/free_u_f_1e5_front_rear_quarter'
    args.train_batch_size=50
    args.max_training_iters=10
    args.save_iters=5
    args.model_weight_path="/user/project/pde_gen_control/results/2024-01-09_PID_surrogate_simu_partial_clamp-0.1/burgers_pid/model_weights-9990.pth"
    args.simulation_method="solver"
    args.use_model=1
    args.pde_1d_surrogate_model_checkpoint="/user/project/pde_gen_control/checkpoints/pde_1d/full_ob_partial_ctr_1-step"
# Runs that roll the controller out on the surrogate model get their own
# output folder.
if args.use_model==1:
    args.date_time=args.date_time+"_use_model"
device=torch.device (f'cuda:{args.gpuid}')
exp_path=args.exp_path+results_path+args.date_time+"/"+args.experiment+f"/{args.f_max}"

# Masks are 1 on the first and last 32 grid points and 0 on the middle 64
# (indices 32..95); ``None`` means no masking.
if args.is_partially_controllable==1:
    control_mask=torch.ones((args.train_batch_size,128),device=device)
    control_mask[:,32:96]=0
else:
    control_mask=None
if args.is_partially_observable==1:
    observed_mask=torch.ones((args.train_batch_size,128),device=device)
    observed_mask[:,32:96]=0
else:
    observed_mask=None

# Number of observed grid points.  Computed unconditionally because the sweep
# loop below reads ``s_ob`` whether or not the surrogate model is used.
if args.is_partially_observable==1:
    s_ob=64
else:
    s_ob=128

model=cNet(ns=128)
model=model.to(device=device)
# Map the checkpoint straight onto the selected device rather than a
# hard-coded 'cuda:0', which may be a different (or busy) GPU.
model.load_state_dict(torch.load(args.model_weight_path,map_location=device))
controller=Controller(model=model,control_mask=control_mask,obsereved_mask=observed_mask)
model.eval()
print(device)

# The surrogate model is imported lazily so solver-only runs do not need it.
if args.simulation_method=="solver":
    simu_surrogate_model=None
elif args.simulation_method=="surrogate_model":
    from model.pde_1d_surrogate_model.burgers_operator import Simu_surrogate_model
    simu_surrogate_model=Simu_surrogate_model(path=args.pde_1d_surrogate_model_checkpoint,device=device)
if args.use_model==1:
    milestone=500
    from model.pde_1d_surrogate_model.burgers_operator import Simu_surrogate_model
    simu_surrogate_model=Simu_surrogate_model(path=args.pde_1d_surrogate_model_checkpoint,device=device,s_ob=s_ob,milestone=milestone)

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(exp_path,exist_ok=True)
# Build the test split and pull a single batch from it.
_dataset_kwargs=dict(
    dataset="burgers",
    input_steps=1,
    output_steps=10,
    time_interval=1,
    is_y_diff=False,
    split="test",
    transform=None,
    pre_transform=None,
    verbose=False,
    root_path=args.dataset_path,
    device='cuda',
    rescaler=1,
)
test_dataset=Burgers1DSimple(**_dataset_kwargs)
dl=cycle(DataLoader(test_dataset,batch_size=args.train_batch_size,shuffle=False,pin_memory=True,num_workers=24))
data=next(dl).to(device)

# Index 0 along dim 1 is used as the initial state and index 10 as the target.
# NOTE(review): entries from index 11 on presumably hold the dataset's
# reference controls -- confirm against Burgers1DSimple.
u0,ud=data[:,0],data[:,10]
control_list=data[:,11:]
u_target=data[:,[10]]   # same frame as ``ud`` but with the time axis kept

# Trajectory buffers hold simulation_steps+1 frames each; frame 0 is u0.
_traj_shape=(args.train_batch_size,simulation_steps+1,ud.shape[-1])
ut_trj,ut_trj_free,ut_trj_gd,ut_trj_free_gd=(
    torch.zeros(_traj_shape,device=device) for _ in range(4)
)
# Control buffer: one entry per simulation step.
f=torch.zeros((args.train_batch_size,simulation_steps,ud.shape[-1]),device=device)
for _buf in (ut_trj,ut_trj_free,ut_trj_gd,ut_trj_free_gd):
    _buf[:,0,:]=u0

# One row per f_max value visited by the sweep, one column per sample.
_metric_shape=(9,args.train_batch_size)
metrices_J,metrices_control_energy,metrices_MSE,metrices_MSE_relative=(
    torch.zeros(_metric_shape,device=device) for _ in range(4)
)
k=0   # row index into the metric tables
# Sweep over the clamp bound f_max.  For each value: load the controller
# trained with that bound, optionally roll it out on the surrogate model,
# replay/roll out on the numerical solver, then record metrics and plot.
for args.f_max in [0.1,0.2,0.5,1.0,1.5,2.0,5.0,10.0,100.0]:
    model=cNet(ns=s_ob)
    model=model.to(device=device)
    args.model_weight_path=f"/user/project/pde_gen_control/results_0106/2023-12-19_PID_solver_partial_clamp-{args.f_max}/burgers_pid/model_weights-9990.pth"
    # args.model_weight_path=f"/user/project/pde_gen_control/results/2024-01-09_PID_surrogate_simu_partial_clamp-{args.f_max}/burgers_pid/model_weights-9990.pth"
    # Load onto the selected device instead of a hard-coded 'cuda:0'.
    model.load_state_dict(torch.load(args.model_weight_path,map_location=device))
    controller=Controller(model=model,control_mask=control_mask,obsereved_mask=None)
    model.eval()
    exp_path=args.exp_path+results_path+args.date_time+"/"+args.experiment+f"/{args.f_max}"
    os.makedirs(exp_path,exist_ok=True)
    with torch.no_grad():
        # ---- Phase 1: closed-loop rollout on the surrogate model ----------
        if args.use_model==1:
            # Work on a copy: the partially observable branch writes into
            # ``ut`` in place, and ``u0`` is a view of ``data`` that must stay
            # intact for the solver rollout and for the next f_max value.
            ut=u0.clone()
            for i in tqdm.trange(simulation_steps):
                control=controller(ut-ud)
                control=control.reshape(control.shape[0],1,-1)
                control=torch.clamp(control, -args.f_max, args.f_max)
                if args.is_partially_observable==1:
                    # Only the first and last 32 grid points are simulated.
                    ut_partial=torch.cat([ut[:,:32],ut[:,-32:]],dim=-1)
                    ut_partial=simu_surrogate_model.simulation(ut=ut_partial,ft=control)
                    ut_partial=ut_partial.reshape(ut_partial.shape[0],ut_partial.shape[-1])
                    ut[:,:32]=ut_partial[:,:32]
                    ut[:,-32:]=ut_partial[:,-32:]
                else:
                    ut=simu_surrogate_model.simulation(ut=ut,ft=control)
                    ut=ut.reshape(ut.shape[0],ut.shape[-1])
                ut_trj[:,i+1,:]=ut.clone()
                f[:,i:i+1,:]=control.clone()

        # ---- Phase 2: rollout on the numerical solver ---------------------
        # With use_model==1 the controls recorded in phase 1 are replayed;
        # otherwise the controller acts directly on the solver state.
        ut_free=u0
        ut_gd=u0
        ut_free_gd=u0
        batch_idx=torch.arange(args.train_batch_size)
        for i in tqdm.trange(simulation_steps):
            if args.use_model==1:
                control=f[:,i]
            else:
                control=controller(ut_gd-ud)
            control=control.reshape(control.shape[0],1,-1)
            control=torch.clamp(control, -args.f_max, args.f_max)
            if args.simulation_method=="solver":
                trajectory=burgers_numeric_solve(ut_gd, control, visc=0.01, T=1e-1, dt=1e-4, num_t=1, mode='PID')
                # Uncontrolled reference: same solver call with zero force.
                ut_free_gd=burgers_numeric_solve(ut_free_gd, control*0, visc=0.01, T=1e-1, dt=1e-4, num_t=1, mode='PID')
                # NOTE(review): the solver output is indexed [b, b, 1] to get a
                # [batch_size, ns] state -- presumably it returns every
                # (state, control) pairing; confirm against generate_burgers.
                ut_gd=trajectory[batch_idx,batch_idx,1]
                ut_free_gd=ut_free_gd[batch_idx,batch_idx,1]
                # No surrogate trajectories exist in this configuration.
                if args.use_model!=1:
                    ut_trj=None
                ut_trj_free=None
            # NOTE: for any other simulation_method ut_gd is never advanced.
            if args.use_model!=1:
                f[:,i:i+1,:]=control.clone()
            ut_trj_gd[:,i+1,:]=ut_gd.clone()
            ut_trj_free_gd[:,i+1,:]=ut_free_gd

        # ---- Metrics ------------------------------------------------------
        ut_trj_gd_partial=torch.cat([ut_trj_gd[:,:,:32],ut_trj_gd[:,:,-32:]],dim=-1)
        if args.use_model==1:
            ut_trj_partial=torch.cat([ut_trj[:,:,:32],ut_trj[:,:,-32:]],dim=-1)
        u_target_partial=torch.cat([u_target[:,:,:32],u_target[:,:,-32:]],dim=-1)
        if args.is_partially_observable==1:
            # Only take the observed part of the state into consideration.
            J_actual = (ut_trj_gd_partial[:, -1, :]- u_target_partial[:, -1, :]).square().mean(-1)
        else:
            J_actual = (ut_trj_gd[:, -1, :] - u_target[:, -1, :]).square().mean(-1)
        control_energy = f.square().sum((-1, -2))
        if args.use_model==1:
            print("ut_trj_gd_partial shape: ",ut_trj_gd_partial.shape)
            if args.is_partially_observable==1:
                surrogate_trj,solver_trj=ut_trj_partial,ut_trj_gd_partial
            else:
                surrogate_trj,solver_trj=ut_trj,ut_trj_gd
            # Per-sample MSE between surrogate and solver trajectories.
            MSE=(solver_trj-surrogate_trj).square().mean((-2,-1))
            # Single relative error over the whole batch; it is broadcast
            # across row k of the metric table below.
            MSE_relative=torch.norm(surrogate_trj-solver_trj)/torch.norm(solver_trj)
            print("MSE shape",MSE.shape)
            print("metrices_MSE shape:",metrices_MSE.shape)
            metrices_MSE[k]=MSE
            metrices_MSE_relative[k]=MSE_relative
        metrices_J[k]=(J_actual)
        metrices_control_energy[k]=control_energy
        k=k+1

        # ---- Plot ---------------------------------------------------------
        # ``is None`` rather than ``== None``: the latter invokes tensor
        # comparison when ut_trj is a tensor.
        plot_result(u0=u0.detach().cpu(),
                    ut=None if ut_trj is None else ut_trj.detach().cpu(),
                    ud=ud.detach().cpu(),
                    ut_free=ut_trj_free,
                    ut_gd=ut_trj_gd.detach().cpu(),
                    ut_free_gd=ut_trj_free_gd.detach().cpu(),
                    f=f,path=exp_path)
# Write the metric tables as .npy files.  ``exp_path`` still holds the folder
# of the last f_max value visited by the sweep, so that is where they land.
_to_save=[
    ("/metrices_J.npy",metrices_J),
    ("/metrices_control_energy.npy",metrices_control_energy),
]
if args.use_model==1:
    # Surrogate-vs-solver errors only exist when the surrogate was rolled out.
    _to_save.append(("/metrices_MSE.npy",metrices_MSE))
    _to_save.append(("/metrices_MSE_relative.npy",metrices_MSE_relative))
for _suffix,_table in _to_save:
    np.save(exp_path+_suffix,_table.cpu().numpy())

# Overview plot for the last sweep entry.  The surrogate trajectory is moved to
# the CPU only when the surrogate was used; otherwise it is passed as-is.
_ut_for_plot=ut_trj.detach().cpu() if args.use_model==1 else ut_trj
plot_result(
    u0=u0.detach().cpu(),
    ut=_ut_for_plot,
    ud=ud.detach().cpu(),
    ut_free=ut_trj_free,
    ut_gd=ut_trj_gd.detach().cpu(),
    ut_free_gd=ut_trj_free_gd.detach().cpu(),
    f=f,
    path=exp_path,
)

# Per-sample comparison of the solver rollout with the target, samples 1..5.
for _sample in range(1,6):
    plot_result_pre(idx=_sample, u=ut_trj_gd.detach().cpu(), f=f, u_gt=u_target.detach().cpu(), plot_u0 = True)

## Partially observable, partially controllable PID with the 1-step-loss surrogate model

In [None]:
# Number of control steps in each rollout.  Defined before the environment
# check so that the command-line path has it as well.
simulation_steps=10
try:
    # ``get_ipython`` only exists inside IPython/Jupyter; outside it this
    # raises NameError, which is the one failure we want to fall back on.
    get_ipython().run_line_magic('matplotlib', 'inline')
except NameError:
    # Plain Python run: take the whole configuration from the command line.
    args = parser.parse_args()
    is_jupyter = False
else:
    # Notebook run: start from the parser defaults and override them in place.
    is_jupyter = True
    args = parser.parse_args([])
    # args.date_time="2024-01-09_PID_surrogate_simu_partial_ob_partial_ctr_clamp"
    args.date_time="2024-01-09_PID_solver_partial_ob_partial_ctr_clamp"
    args.exp_path='/user/project/pde_gen_control'
    args.gpuid=1
    args.model_mode="train"
    # Partially observable state and partially controllable force.
    args.is_partially_observable=1
    args.is_partially_controllable=1
    args.f_max=5
    args.dataset_path='/user/project/pde_gen_control/dataset_control_burgers/free_u_f_1e5_front_rear_quarter'
    args.train_batch_size=50
    args.max_training_iters=10
    args.save_iters=5
    args.model_weight_path="/user/project/pde_gen_control/results/2024-01-09_PID_surrogate_simu_partial_ob_partial_ctr_clamp-0.1/burgers_pid/model_weights-9990.pth"
    args.simulation_method="solver"
    args.use_model=1
    args.pde_1d_surrogate_model_checkpoint="/user/project/pde_gen_control/checkpoints/pde_1d/partial_ob_partial_ctr_1-step"
# Runs that roll the controller out on the surrogate model get their own
# output folder.
if args.use_model==1:
    args.date_time=args.date_time+"_use_model"
device=torch.device (f'cuda:{args.gpuid}')
exp_path=args.exp_path+results_path+args.date_time+"/"+args.experiment+f"/{args.f_max}"

# Masks are 1 on the first and last 32 grid points and 0 on the middle 64
# (indices 32..95); ``None`` means no masking.
if args.is_partially_controllable==1:
    control_mask=torch.ones((args.train_batch_size,128),device=device)
    control_mask[:,32:96]=0
else:
    control_mask=None
if args.is_partially_observable==1:
    observed_mask=torch.ones((args.train_batch_size,128),device=device)
    observed_mask[:,32:96]=0
else:
    observed_mask=None

# Number of observed grid points; defined regardless of ``use_model`` so the
# name is always available to later code.
if args.is_partially_observable==1:
    s_ob=64
else:
    s_ob=128

model=cNet(ns=64)
model=model.to(device=device)
# Map the checkpoint straight onto the selected device rather than a
# hard-coded 'cuda:0', which may be a different (or busy) GPU.
model.load_state_dict(torch.load(args.model_weight_path,map_location=device))
controller=Controller(model=model,control_mask=control_mask,obsereved_mask=observed_mask)
model.eval()
print(device)

# The surrogate model is imported lazily so solver-only runs do not need it.
if args.simulation_method=="solver":
    simu_surrogate_model=None
elif args.simulation_method=="surrogate_model":
    from model.pde_1d_surrogate_model.burgers_operator import Simu_surrogate_model
    simu_surrogate_model=Simu_surrogate_model(path=args.pde_1d_surrogate_model_checkpoint,device=device)
if args.use_model==1:
    milestone=500
    from model.pde_1d_surrogate_model.burgers_operator import Simu_surrogate_model
    simu_surrogate_model=Simu_surrogate_model(path=args.pde_1d_surrogate_model_checkpoint,device=device,s_ob=s_ob,milestone=milestone)

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(exp_path,exist_ok=True)
# Test split of the Burgers control dataset; a single batch is drawn from it.
test_dataset=Burgers1DSimple(
    dataset="burgers", input_steps=1, output_steps=10, time_interval=1,
    is_y_diff=False, split="test", transform=None, pre_transform=None,
    verbose=False, root_path=args.dataset_path, device='cuda', rescaler=1,
)
_loader=DataLoader(
    test_dataset,
    batch_size=args.train_batch_size,
    shuffle=False,
    pin_memory=True,
    num_workers=24,
)
dl=cycle(_loader)
data=next(dl).to(device)

# Index 0 along dim 1 is used as the initial state and index 10 as the target.
# NOTE(review): entries from index 11 on presumably hold the dataset's
# reference controls -- confirm against Burgers1DSimple.
u0=data[:,0]
ud=data[:,10]
u_target=data[:,[10]]   # same frame as ``ud`` but with the time axis kept
control_list=data[:,11:]

_n_batch=args.train_batch_size
_n_x=ud.shape[-1]


def _zeros(*shape):
    """Return a zero tensor of the given shape on the experiment device."""
    return torch.zeros(shape,device=device)


# State trajectories (simulation_steps+1 frames, frame 0 = u0) and controls
# (simulation_steps frames).
ut_trj=_zeros(_n_batch,simulation_steps+1,_n_x)
ut_trj_free=_zeros(_n_batch,simulation_steps+1,_n_x)
ut_trj_gd=_zeros(_n_batch,simulation_steps+1,_n_x)
ut_trj_free_gd=_zeros(_n_batch,simulation_steps+1,_n_x)
f=_zeros(_n_batch,simulation_steps,_n_x)
ut_trj[:,0,:]=u0
ut_trj_free[:,0,:]=u0
ut_trj_gd[:,0,:]=u0
ut_trj_free_gd[:,0,:]=u0

# Metric tables: one row per f_max value of the sweep, one column per sample.
metrices_J=_zeros(9,_n_batch)
metrices_control_energy=_zeros(9,_n_batch)
metrices_MSE=_zeros(9,_n_batch)
metrices_MSE_relative=_zeros(9,_n_batch)
k=0   # row index into the metric tables
# Sweep over the clamp bound f_max.  For each value: load the controller
# trained with that bound, optionally roll it out on the surrogate model,
# replay/roll out on the numerical solver, then record metrics and plot.
for args.f_max in [0.1,0.2,0.5,1.0,1.5,2.0,5.0,10.0,100.0]:
    model=cNet(ns=64)
    model=model.to(device=device)
    args.model_weight_path=f"/user/project/pde_gen_control/results_0106/2023-12-29_PID_solver_partial_ob_partial_ctr_clamp-{args.f_max}/burgers_pid/model_weights-9990.pth"
    # args.model_weight_path=f"/user/project/pde_gen_control/results/2024-01-09_PID_surrogate_simu_partial_ob_partial_ctr_clamp-{args.f_max}/burgers_pid/model_weights-9990.pth"
    # Load onto the selected device instead of a hard-coded 'cuda:0'.
    model.load_state_dict(torch.load(args.model_weight_path,map_location=device))
    controller=Controller(model=model,control_mask=control_mask,obsereved_mask=observed_mask)
    model.eval()
    exp_path=args.exp_path+results_path+args.date_time+"/"+args.experiment+f"/{args.f_max}"
    os.makedirs(exp_path,exist_ok=True)
    with torch.no_grad():
        # ---- Phase 1: closed-loop rollout on the surrogate model ----------
        if args.use_model==1:
            # Work on a copy: the observed slices of ``ut`` are overwritten in
            # place below.  ``u0`` is a view of ``data``; aliasing it here
            # would corrupt the initial state used by the solver rollout and
            # by every later f_max value.
            ut=u0.clone()
            for i in tqdm.trange(simulation_steps):
                control=controller(ut-ud)
                control=control.reshape(control.shape[0],1,-1)
                control=torch.clamp(control, -args.f_max, args.f_max)
                # Only the first and last 32 grid points are simulated.
                ut_partial=torch.cat([ut[:,:32],ut[:,-32:]],dim=-1)
                ut_partial=simu_surrogate_model.simulation(ut=ut_partial,ft=control)
                ut_partial=ut_partial.reshape(ut_partial.shape[0],ut_partial.shape[-1])
                ut[:,:32]=ut_partial[:,:32]
                ut[:,-32:]=ut_partial[:,-32:]
                ut_trj[:,i+1,:]=ut.clone()
                f[:,i:i+1,:]=control.clone()

        # ---- Phase 2: rollout on the numerical solver ---------------------
        # With use_model==1 the controls recorded in phase 1 are replayed;
        # otherwise the controller acts directly on the solver state.
        ut_free=u0
        ut_gd=u0
        ut_free_gd=u0
        batch_idx=torch.arange(args.train_batch_size)
        for i in tqdm.trange(simulation_steps):
            if args.use_model==1:
                control=f[:,i]
            else:
                control=controller(ut_gd-ud)
            control=control.reshape(control.shape[0],1,-1)
            control=torch.clamp(control, -args.f_max, args.f_max)
            if args.simulation_method=="solver":
                trajectory=burgers_numeric_solve(ut_gd, control, visc=0.01, T=1e-1, dt=1e-4, num_t=1, mode='PID')
                # Uncontrolled reference: same solver call with zero force.
                ut_free_gd=burgers_numeric_solve(ut_free_gd, control*0, visc=0.01, T=1e-1, dt=1e-4, num_t=1, mode='PID')
                # NOTE(review): the solver output is indexed [b, b, 1] to get a
                # [batch_size, ns] state -- presumably it returns every
                # (state, control) pairing; confirm against generate_burgers.
                ut_gd=trajectory[batch_idx,batch_idx,1]
                ut_free_gd=ut_free_gd[batch_idx,batch_idx,1]
                # No surrogate trajectories exist in this configuration.
                if args.use_model!=1:
                    ut_trj=None
                ut_trj_free=None
            # NOTE: for any other simulation_method ut_gd is never advanced.
            if args.use_model!=1:
                f[:,i:i+1,:]=control.clone()
            ut_trj_gd[:,i+1,:]=ut_gd.clone()
            ut_trj_free_gd[:,i+1,:]=ut_free_gd

        # ---- Metrics ------------------------------------------------------
        ut_trj_gd_partial=torch.cat([ut_trj_gd[:,:,:32],ut_trj_gd[:,:,-32:]],dim=-1)
        # ut_trj is None when the surrogate was not rolled out, so only slice
        # it when it exists.
        if args.use_model==1:
            ut_trj_partial=torch.cat([ut_trj[:,:,:32],ut_trj[:,:,-32:]],dim=-1)
        u_target_partial=torch.cat([u_target[:,:,:32],u_target[:,:,-32:]],dim=-1)
        if args.is_partially_observable==1:
            # Only take the observed part of the state into consideration.
            J_actual = (ut_trj_gd_partial[:, -1, :]- u_target_partial[:, -1, :]).square().mean(-1)
        else:
            J_actual = (ut_trj_gd[:, -1, :] - u_target[:, -1, :]).square().mean(-1)
        control_energy = f.square().sum((-1, -2))
        if args.use_model==1:
            print("ut_trj_gd_partial shape: ",ut_trj_gd_partial.shape)
            # Per-sample MSE between surrogate and solver trajectories on the
            # observed grid points.
            MSE=(ut_trj_gd_partial-ut_trj_partial).square().mean((-2,-1))
            # Single relative error over the whole batch; it is broadcast
            # across row k of the metric table below.
            MSE_relative=torch.norm(ut_trj_partial-ut_trj_gd_partial)/torch.norm(ut_trj_gd_partial)
            print("MSE shape",MSE.shape)
            print("metrices_MSE shape:",metrices_MSE.shape)
            metrices_MSE[k]=MSE
            metrices_MSE_relative[k]=MSE_relative
        metrices_J[k]=(J_actual)
        metrices_control_energy[k]=control_energy
        k=k+1

        # ---- Plot ---------------------------------------------------------
        # ``is None`` rather than ``== None``: the latter invokes tensor
        # comparison when ut_trj is a tensor.
        plot_result(u0=u0.detach().cpu(),
                    ut=None if ut_trj is None else ut_trj.detach().cpu(),
                    ud=ud.detach().cpu(),
                    ut_free=ut_trj_free,
                    ut_gd=ut_trj_gd.detach().cpu(),
                    ut_free_gd=ut_trj_free_gd.detach().cpu(),
                    f=f,path=exp_path)
# Save the metric tables as .npy files.  ``exp_path`` is whatever the sweep
# left it at, i.e. the folder of the last f_max value.
_surrogate_used=args.use_model==1
np.save(exp_path+"/metrices_J.npy",metrices_J.cpu().numpy())
np.save(exp_path+"/metrices_control_energy.npy",metrices_control_energy.cpu().numpy())
if _surrogate_used:
    # Surrogate-vs-solver errors are only filled in when the surrogate ran.
    np.save(exp_path+"/metrices_MSE.npy",metrices_MSE.cpu().numpy())
    np.save(exp_path+"/metrices_MSE_relative.npy",metrices_MSE_relative.cpu().numpy())

# Overview plot for the last sweep entry.  Arguments shared by both cases are
# collected once; only ``ut`` depends on whether the surrogate was used.
_plot_kwargs=dict(
    u0=u0.detach().cpu(),
    ud=ud.detach().cpu(),
    ut_free=ut_trj_free,
    ut_gd=ut_trj_gd.detach().cpu(),
    ut_free_gd=ut_trj_free_gd.detach().cpu(),
    f=f,
    path=exp_path,
)
if _surrogate_used:
    _plot_kwargs["ut"]=ut_trj.detach().cpu()
else:
    _plot_kwargs["ut"]=ut_trj
plot_result(**_plot_kwargs)

# Solver rollout versus target for samples 1 through 5 of the batch.
_sample=1
while _sample<=5:
    plot_result_pre(idx=_sample, u=ut_trj_gd.detach().cpu(), f=f, u_gt=u_target.detach().cpu(), plot_u0 = True)
    _sample+=1