In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import tqdm
import os
import time
from tensorboardX import SummaryWriter

#from envs.burgers import Burgers
from buffer import OfflineReplayBuffer
from critic import ValueLearner, QPiLearner, QSarsaLearner
from bppo import BehaviorCloning, BehaviorProximalPolicyOptimization
from envs.burgers import Burgers

In [2]:
# Hyperparameters

# Experiment
env_name='burger'
path='logs'
log_freq=int(100)
seed=20241219
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
N=100 # Number of trajectories to collect for offline dataset

# For Value
v_steps=int(5000)
v_hidden_dim = 256
v_depth = 3
v_lr = 1e-4
v_batch_size = 64

# For Q
q_bc_steps=int(5000)
q_pi_steps=10 # Number of steps to update Q-network in each iteration. Only used if is_offpolicy_update=True.
q_hidden_dim = 256
q_depth = 3
q_lr = 1e-4
q_batch_size = 64
target_update_freq=2
tau=0.005 # Soft update rate for target Q network parameters. See Q_learner.update()
gamma=0.99 # Discount factor for calculating the return.
is_offpolicy_update=False # Whether to use advantage replacement (as proposed in the BPPO paper) in Q-learning.
# If False, use Q-learning to update the Q-network parameters in each iteration.
# If True, only update the Q-network parameters once, and keep using this Q-network.

# For BC
bc_steps=int(500)
bc_lr = 1e-4
bc_hidden_dim = 256
bc_depth = 3
bc_batch_size = 64

# For BPPO
bppo_steps=int(100)
bppo_hidden_dim = 256
bppo_depth = 3
bppo_lr = 1e-4
bppo_batch_size = 64
clip_ratio=0.25 # PPO clip ratio. The probability ratio between new and old policy is clipped to be in the range [1-clip_ratio, 1+clip_ratio]
entropy_weight=0.00 # Weight of entropy loss in PPO and BPPO. Can be set to 0.01 for medium tasks.
decay=0.96 # Decay rate of PPO clip ratio
omega=0.9 # Related to setting the weight of advantage (see PPO code)
is_clip_decay=True # Whether to decay the clip_ratio during training
is_bppo_lr_decay=True # Whether to decay the learning rate of BPPO during trainining
is_update_old_policy=True # Whether to update the old policy of BPPO in each iteration. The old policy is used to calculate the probability ratio.
is_state_norm=False # Whether to normalize the states of the dataset.

# Other Settings
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device=torch.device('cpu')
state_dim = 128
action_dim = 128
x_range=(-5,5)
nt=500
dt=0.001


In [3]:
from generate_burgers import load_burgers

dataset = load_burgers(
                x_range=x_range,
                nt = 500, # Number of time steps
                nx = state_dim, # Number of spatial nodes (grid points)
                dt= dt, # Temporal interval
                N = 100, # Number of samples (trajectories) to generate
                visualize=False # Whether to show the animation of state trajectory evolution
                )
target_state=dataset['Y_f']

for key in dataset.keys():
    if key!="meta_data":
        dataset[key]=dataset[key].squeeze(0)

print(dataset['observations'].shape)
print(dataset['actions'].shape)
print(dataset['rewards'].shape)
print(dataset['terminals'].shape)
print(dataset['timeouts'].shape)
print(dataset['Y_f'].shape)


env=Burgers(x_range, nx=state_dim, nt=nt, dt=dt, energy_penalty=0.01, device='cpu')
env.reset()
env.set_target_state(target_state)

TypeError: load_burgers() got an unexpected keyword argument 'visualize'

In [4]:
replay_buffer = OfflineReplayBuffer(device, state_dim, action_dim, len(dataset['actions']))
replay_buffer.load_dataset(dataset=dataset)
replay_buffer.compute_return(gamma) # Compute the discounted return for the trajectory, with a discount factor of gamma (default 0.99).

Computing the returns: 499it [00:00, 374880.48it/s]


In [5]:
# summarywriter logger
# path

current_time = time.strftime("%Y_%m_%d__%H_%M_%S", time.localtime())
path = os.path.join(path, str(seed))
os.makedirs(os.path.join(path, current_time))
print(f'Made log directory at {os.path.join(path, current_time)}')

logger_path = os.path.join(path, current_time)
logger = SummaryWriter(log_dir=logger_path, comment='')

Made log directory at logs\20241219\2024_12_22__14_14_26


In [6]:
# initilize
value = ValueLearner(device=device,
                        state_dim=state_dim,
                        hidden_dim=v_hidden_dim,
                        depth=v_depth,
                        value_lr=v_lr,
                        batch_size=v_batch_size)

Q_bc = QSarsaLearner(device=device,
                        state_dim=state_dim,
                        action_dim=action_dim,
                        hidden_dim=q_hidden_dim, depth=q_depth,
                        Q_lr=q_lr,
                        target_update_freq=target_update_freq,
                        tau=tau,
                        gamma=gamma,
                        batch_size=q_batch_size)
if is_offpolicy_update: 
    Q_pi=QPiLearner(device=device,
                        state_dim=state_dim,
                        action_dim=action_dim,
                        hidden_dim=q_hidden_dim,
                        depth=q_depth,
                        Q_lr=q_lr,
                        target_update_freq=target_update_freq,
                        tau=tau,
                        gamma=gamma,
                        batch_size=q_batch_size)
bc=BehaviorCloning(device=device,
                        state_dim=state_dim,
                        hidden_dim=bc_hidden_dim,
                        depth=bc_depth,
                        action_dim=action_dim,
                        policy_lr=bc_lr,
                        batch_size=bc_batch_size)
bppo=BehaviorProximalPolicyOptimization(device=device,
                        state_dim=state_dim,
                        hidden_dim=bppo_hidden_dim,
                        depth=bppo_depth,
                        action_dim=action_dim,
                        policy_lr=bppo_lr,
                        clip_ratio=clip_ratio,
                        entropy_weight=entropy_weight,
                        decay=decay,
                        omega=omega,
                        batch_size=bppo_batch_size)

In [7]:
# value training 
value_path = os.path.join(path, 'value.pt')
if os.path.exists(value_path):
    value.load(value_path)
else:
    for step in tqdm.tqdm(range(int(v_steps)), desc='value updating ......'):
        value_loss = value.update(replay_buffer)
        
        if step % int(log_freq) == 0:
            print(f"Step: {step}, Loss: {value_loss:.6f}")
            logger.add_scalar('value_loss', value_loss, global_step=(step+1))
    value.save(value_path)

# Q_bc training
Q_bc_path = os.path.join(path, 'Q_bc.pt')
if os.path.exists(Q_bc_path):
    Q_bc.load(Q_bc_path)
else:
    for step in tqdm.tqdm(range(int(q_bc_steps)), desc='Q_bc updating ......'):
        Q_bc_loss = Q_bc.update(replay_buffer, pi=None)
        if step % int(log_freq) == 0:
            print(f"Step: {step}, Loss: {Q_bc_loss:.6f}")
            logger.add_scalar('Q_bc_loss', Q_bc_loss, global_step=(step+1))
    Q_bc.save(Q_bc_path)

if is_offpolicy_update:
    Q_pi.load(Q_bc_path)



Value parameters loaded
Q function parameters loaded


  self._value.load_state_dict(torch.load(path, map_location=self._device))
  self._Q.load_state_dict(torch.load(path, map_location=self._device))


In [8]:
mean, std = 0., 1.

# bc training
best_bc_path = os.path.join(path, 'bc_best.pt')
if os.path.exists(best_bc_path):
    bc.load(best_bc_path)
else:
    best_bc_score = 0
    for step in tqdm.tqdm(range(int(bc_steps)), desc='bc updating ......'):
        bc_loss = bc.update(replay_buffer)
        if step % int(log_freq) == 0:
            current_bc_score = bc.offline_evaluate(env, seed)
            if current_bc_score > best_bc_score:
                best_bc_score = current_bc_score
                bc.save(best_bc_path)
                np.savetxt(os.path.join(path, 'best_bc.csv'), [best_bc_score], fmt='%f', delimiter=',')
            print(f"Step: {step}, Loss: {bc_loss:.4f}, Score: {current_bc_score:.4f}")
            logger.add_scalar('bc_loss', bc_loss, global_step=(step+1))
            logger.add_scalar('bc_score', current_bc_score, global_step=(step+1))
    bc.save(os.path.join(path, 'bc_best.pt'))
    bc.load(best_bc_path)



Behavior policy parameters loaded


  self._policy.load_state_dict(torch.load(path, map_location=self._device))


In [9]:
# bppo training
bppo.load(best_bc_path)
best_bppo_path = os.path.join(path, current_time, 'bppo_best.pt')
Q = Q_bc # If advantage replacement, then Q_{\pi k}=Q_{\pi\beta}
best_bppo_score = bppo.offline_evaluate(env, seed, eval_episodes=10)
print('best_bppo_score:',best_bppo_score,'-------------------------')
for step in tqdm.tqdm(range(int(bppo_steps)), desc='bppo updating ......'):
    print(f"\nEpoch {step+1}:")
    if step > 200:
        is_clip_decay = False
        is_bppo_lr_decay = False
    bppo_loss = bppo.update(replay_buffer, Q, value, is_clip_decay, is_bppo_lr_decay)
    current_bppo_score = bppo.offline_evaluate(env, seed, eval_episodes=10) # J_{\pi k}
    if current_bppo_score > best_bppo_score:
        best_bppo_score = current_bppo_score
        print('best_bppo_score:',best_bppo_score,'-------------------------')
        bppo.save(best_bppo_path)
        np.savetxt(os.path.join(path, current_time, 'best_bppo.csv'), [best_bppo_score], fmt='%f', delimiter=',')
        if is_update_old_policy:
            bppo.set_old_policy() # Set the old policy to the current policy
    if is_offpolicy_update: # If not using advantage replacement, calculate Q_{\pi k} by Q-learning
        for _ in tqdm(range(int(q_pi_steps)), desc='Q_pi updating ......'): 
            Q_pi_loss = Q_pi.update(replay_buffer, bppo)
        Q = Q_pi
    print(f"Step: {step}, Loss: {bppo_loss:.4f}, Score: {current_bppo_score:.4f}")
    logger.add_scalar('bppo_loss', bppo_loss, global_step=(step+1))
    logger.add_scalar('bppo_score', current_bppo_score, global_step=(step+1))

logger.close()

  self._policy.load_state_dict(torch.load(path, map_location=self._device))


Policy parameters loaded
best_bppo_score: -10000.0 -------------------------


bppo updating ......:   0%|          | 0/100 [00:00<?, ?it/s]


Epoch 1:


bppo updating ......:   1%|          | 1/100 [00:10<18:06, 10.97s/it]

Step: 0, Loss: -0.2860, Score: -10000.0000

Epoch 2:
Step: 1, Loss: 0.1025, Score: -10000.0000

Epoch 3:
Step: 2, Loss: 0.0300, Score: -10000.0000

Epoch 4:


bppo updating ......:   9%|▉         | 9/100 [00:11<01:09,  1.31it/s]

best_bppo_score: -9259.149132811173 -------------------------
Policy parameters saved in logs\20241219\2024_12_22__14_14_26\bppo_best.pt
Step: 3, Loss: 0.0297, Score: -9259.1491

Epoch 5:
Step: 4, Loss: -0.3098, Score: -10000.0000

Epoch 6:
Step: 5, Loss: -0.0427, Score: -10000.0000

Epoch 7:
Step: 6, Loss: 0.0134, Score: -10000.0000

Epoch 8:
Step: 7, Loss: 0.0345, Score: -10000.0000

Epoch 9:
Step: 8, Loss: 0.0289, Score: -10000.0000

Epoch 10:
Step: 9, Loss: 0.0324, Score: -10000.0000

Epoch 11:
Step: 10, Loss: 0.0332, Score: -10000.0000

Epoch 12:
Step: 11, Loss: 0.0294, Score: -10000.0000

Epoch 13:
Step: 12, Loss: 0.0332, Score: -10000.0000

Epoch 14:


bppo updating ......:  19%|█▉        | 19/100 [00:11<00:20,  4.02it/s]

Step: 13, Loss: 0.0294, Score: -10000.0000

Epoch 15:
Step: 14, Loss: 0.0277, Score: -10000.0000

Epoch 16:
Step: 15, Loss: 0.0309, Score: -10000.0000

Epoch 17:
Step: 16, Loss: 0.0279, Score: -10000.0000

Epoch 18:
Step: 17, Loss: 0.0298, Score: -10000.0000

Epoch 19:
Step: 18, Loss: 0.0292, Score: -10000.0000

Epoch 20:
Step: 19, Loss: 0.0345, Score: -10000.0000

Epoch 21:
Step: 20, Loss: 0.0287, Score: -10000.0000

Epoch 22:
Step: 21, Loss: 0.0338, Score: -10000.0000

Epoch 23:


bppo updating ......:  29%|██▉       | 29/100 [00:12<00:08,  8.61it/s]

Step: 22, Loss: 0.0347, Score: -10000.0000

Epoch 24:
Step: 23, Loss: 0.0335, Score: -10000.0000

Epoch 25:
Step: 24, Loss: 0.0338, Score: -10000.0000

Epoch 26:
Step: 25, Loss: 0.0387, Score: -10000.0000

Epoch 27:
Step: 26, Loss: 0.0336, Score: -10000.0000

Epoch 28:
Step: 27, Loss: 0.0321, Score: -10000.0000

Epoch 29:
Step: 28, Loss: 0.0368, Score: -10000.0000

Epoch 30:
Step: 29, Loss: 0.0391, Score: -10000.0000

Epoch 31:
Step: 30, Loss: 0.0340, Score: -10000.0000

Epoch 32:
Step: 31, Loss: 0.0363, Score: -10000.0000

Epoch 33:


bppo updating ......:  39%|███▉      | 39/100 [00:12<00:04, 15.17it/s]

Step: 32, Loss: 0.0350, Score: -10000.0000

Epoch 34:
Step: 33, Loss: 0.0346, Score: -10000.0000

Epoch 35:
Step: 34, Loss: 0.0316, Score: -10000.0000

Epoch 36:
Step: 35, Loss: 0.0340, Score: -10000.0000

Epoch 37:
Step: 36, Loss: 0.0339, Score: -10000.0000

Epoch 38:
Step: 37, Loss: 0.0378, Score: -10000.0000

Epoch 39:
Step: 38, Loss: 0.0325, Score: -10000.0000

Epoch 40:
Step: 39, Loss: 0.0292, Score: -10000.0000

Epoch 41:
Step: 40, Loss: 0.0375, Score: -10000.0000

Epoch 42:


bppo updating ......:  49%|████▉     | 49/100 [00:12<00:02, 22.13it/s]

Step: 41, Loss: 0.0361, Score: -10000.0000

Epoch 43:
Step: 42, Loss: 0.0334, Score: -10000.0000

Epoch 44:
Step: 43, Loss: 0.0335, Score: -10000.0000

Epoch 45:
Step: 44, Loss: 0.0356, Score: -10000.0000

Epoch 46:
Step: 45, Loss: 0.0334, Score: -10000.0000

Epoch 47:
Step: 46, Loss: 0.0355, Score: -10000.0000

Epoch 48:
Step: 47, Loss: 0.0355, Score: -10000.0000

Epoch 49:
Step: 48, Loss: 0.0337, Score: -10000.0000

Epoch 50:


bppo updating ......:  57%|█████▋    | 57/100 [00:12<00:01, 27.14it/s]

Step: 49, Loss: 0.0384, Score: -10000.0000

Epoch 51:
Step: 50, Loss: 0.0307, Score: -10000.0000

Epoch 52:
Step: 51, Loss: 0.0322, Score: -10000.0000

Epoch 53:
Step: 52, Loss: 0.0335, Score: -10000.0000

Epoch 54:
Step: 53, Loss: 0.0377, Score: -10000.0000

Epoch 55:
Step: 54, Loss: 0.0302, Score: -10000.0000

Epoch 56:
Step: 55, Loss: 0.0369, Score: -10000.0000

Epoch 57:
Step: 56, Loss: 0.0349, Score: -10000.0000

Epoch 58:


bppo updating ......:  62%|██████▏   | 62/100 [00:12<00:01, 31.34it/s]

Step: 57, Loss: 0.0326, Score: -10000.0000

Epoch 59:
Step: 58, Loss: 0.0338, Score: -10000.0000

Epoch 60:
Step: 59, Loss: 0.0373, Score: -10000.0000

Epoch 61:
Step: 60, Loss: 0.0368, Score: -10000.0000

Epoch 62:
Step: 61, Loss: 0.0392, Score: -10000.0000

Epoch 63:
Step: 62, Loss: 0.0373, Score: -10000.0000

Epoch 64:
Step: 63, Loss: 0.0376, Score: -10000.0000

Epoch 65:
Step: 64, Loss: 0.0344, Score: -10000.0000

Epoch 66:
Step: 65, Loss: 0.0376, Score: -10000.0000

Epoch 67:


bppo updating ......:  67%|██████▋   | 67/100 [00:12<00:00, 33.74it/s]

Step: 66, Loss: 0.0311, Score: -10000.0000

Epoch 68:
Step: 67, Loss: 0.0364, Score: -10000.0000

Epoch 69:
Step: 68, Loss: 0.0343, Score: -10000.0000

Epoch 70:
Step: 69, Loss: 0.0328, Score: -10000.0000

Epoch 71:


bppo updating ......:  72%|███████▏  | 72/100 [00:13<00:01, 18.00it/s]

best_bppo_score: -9092.083142119134 -------------------------
Policy parameters saved in logs\20241219\2024_12_22__14_14_26\bppo_best.pt
Step: 70, Loss: 0.0356, Score: -9092.0831

Epoch 72:
Step: 71, Loss: -0.2931, Score: -10000.0000

Epoch 73:


bppo updating ......:  76%|███████▌  | 76/100 [00:14<00:01, 13.50it/s]

Step: 72, Loss: -0.1428, Score: -9820.3164

Epoch 74:
Step: 73, Loss: -0.1199, Score: -10000.0000

Epoch 75:
Step: 74, Loss: -0.0802, Score: -10000.0000

Epoch 76:
Step: 75, Loss: -0.0392, Score: -10000.0000

Epoch 77:
Step: 76, Loss: -0.1682, Score: -10000.0000

Epoch 78:
Step: 77, Loss: -0.0383, Score: -10000.0000

Epoch 79:


bppo updating ......:  84%|████████▍ | 84/100 [00:14<00:01, 14.59it/s]

Step: 78, Loss: 0.0443, Score: -9674.4577

Epoch 80:
Step: 79, Loss: 0.0426, Score: -10000.0000

Epoch 81:
Step: 80, Loss: 0.0158, Score: -10000.0000

Epoch 82:
Step: 81, Loss: 0.0518, Score: -10000.0000

Epoch 83:
Step: 82, Loss: 0.0067, Score: -10000.0000

Epoch 84:
Step: 83, Loss: 0.0219, Score: -10000.0000

Epoch 85:
Step: 84, Loss: 0.0297, Score: -10000.0000

Epoch 86:
Step: 85, Loss: 0.0943, Score: -10000.0000

Epoch 87:
Step: 86, Loss: 0.0349, Score: -10000.0000

Epoch 88:


bppo updating ......:  94%|█████████▍| 94/100 [00:14<00:00, 23.01it/s]

Step: 87, Loss: 0.2578, Score: -10000.0000

Epoch 89:
Step: 88, Loss: 0.0350, Score: -10000.0000

Epoch 90:
Step: 89, Loss: 0.0676, Score: -10000.0000

Epoch 91:
Step: 90, Loss: 0.0375, Score: -10000.0000

Epoch 92:
Step: 91, Loss: 0.0623, Score: -10000.0000

Epoch 93:
Step: 92, Loss: 0.0343, Score: -10000.0000

Epoch 94:
Step: 93, Loss: 0.0378, Score: -10000.0000

Epoch 95:
Step: 94, Loss: 0.0367, Score: -10000.0000

Epoch 96:
Step: 95, Loss: 0.0367, Score: -10000.0000

Epoch 97:


bppo updating ......: 100%|██████████| 100/100 [00:14<00:00,  6.67it/s]

Step: 96, Loss: 0.0363, Score: -10000.0000

Epoch 98:
Step: 97, Loss: 0.0359, Score: -10000.0000

Epoch 99:
Step: 98, Loss: 0.0381, Score: -10000.0000

Epoch 100:
Step: 99, Loss: 0.0388, Score: -10000.0000





In [14]:
from generate_burgers import generate_initial_y

x=np.linspace(*x_range, state_dim)
y0=generate_initial_y(x)
#y0=torch.tensor(y0, dtype=torch.float32)

timesteps=16
y=y0

for _ in range(timesteps):
    env.reset(y0)
    action=bppo._policy(torch.tensor(y)).sample().detach().numpy()
    env.step(action)



print(((target_state-y)**2).mean())

RuntimeError: mat1 and mat2 must have the same dtype, but got Double and Float