In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import tqdm
import os
import time
from tensorboardX import SummaryWriter

#from envs.burgers import Burgers
from buffer import OfflineReplayBuffer
from critic import ValueLearner, QPiLearner, QSarsaLearner
from bppo import BehaviorCloning, BehaviorProximalPolicyOptimization

In [2]:
# ---------------------------------------------------------------------------
# Run configuration: every tunable value read by the cells below.
# ---------------------------------------------------------------------------

# --- Experiment / bookkeeping ----------------------------------------------
env_name = 'burger'
path = 'logs'          # Root directory for logs and checkpoints.
log_freq = 20          # Log (and, for BC, evaluate) every `log_freq` update steps.
seed = 20241219
# Seed each RNG source used in this notebook so runs are repeatable.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
N = 100  # Intended number of offline trajectories. NOTE(review): the data-loading cell passes N=1 literally and never reads this.

# --- State-value network V(s) ----------------------------------------------
v_steps = 5000
v_hidden_dim = 256
v_depth = 3
v_lr = 1e-4
v_batch_size = 64

# --- Action-value networks Q(s, a) -----------------------------------------
q_bc_steps = 5000
q_pi_steps = 10  # Q_pi updates performed after each BPPO step; only read when is_offpolicy_update=True.
q_hidden_dim = 256
q_depth = 3
q_lr = 1e-4
q_batch_size = 64
target_update_freq = 2
tau = 0.005   # Soft-update rate for the target Q network (see the Q learner's update()).
gamma = 0.99  # Discount factor; passed to the Q learners and to replay_buffer.compute_return().
is_offpolicy_update = False
# is_offpolicy_update=False: advantage replacement (BPPO paper) - the behaviour
#   critic Q_bc is trained once and reused for every BPPO step.
# is_offpolicy_update=True: a separate Q_pi learner is created, initialised from
#   the Q_bc checkpoint, and updated for q_pi_steps steps after every BPPO step.

# --- Behaviour cloning ------------------------------------------------------
bc_steps = 500
bc_lr = 1e-4
bc_hidden_dim = 256
bc_depth = 3
bc_batch_size = 64

# --- BPPO -------------------------------------------------------------------
bppo_steps = 100
bppo_hidden_dim = 256
bppo_depth = 3
bppo_lr = 1e-4
bppo_batch_size = 64
clip_ratio = 0.25      # PPO clip: the new/old probability ratio is clipped to [1-clip_ratio, 1+clip_ratio].
entropy_weight = 0.00  # Weight of the entropy loss in PPO/BPPO; 0.01 can be used for medium tasks.
decay = 0.96           # Decay rate of the PPO clip ratio.
omega = 0.9            # Related to the advantage weighting (see the PPO code).
is_clip_decay = True         # Decay clip_ratio during training?
is_bppo_lr_decay = True      # Decay the BPPO learning rate during training?
is_update_old_policy = True  # Refresh the old policy (denominator of the probability ratio) when the score improves?
is_state_norm = False        # Normalize dataset states?

# --- Miscellaneous ----------------------------------------------------------
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
state_dim = 128
action_dim = 128
x_range = (-5, 5)



In [3]:
#env=Burgers(
#    n=state_dim,
#    m=action_dim,
#    T=8,
#    x_range=[-1,1],
#    energy_penalty=0.01,
#    device='cpu'
#)

In [4]:
from generate_burgers import load_burgers

# Generate a single trajectory for the offline dataset.
burgers_kwargs = dict(
    x_range=x_range,
    nt=500,           # Number of time steps
    nx=state_dim,     # Number of spatial nodes (grid points)
    dt=0.001,         # Temporal interval
    N=1,              # Number of samples (trajectories) to generate
    visualize=False,  # Whether to show the animation of state trajectory evolution
)
dataset = load_burgers(**burgers_kwargs)

# Only one trajectory was generated, so drop the leading axis of every entry
# except the metadata entry.
array_keys = [name for name in dataset.keys() if name != "meta_data"]
for name in array_keys:
    dataset[name] = dataset[name].squeeze(0)

# Show the resulting shapes as a sanity check.
for field in ('observations', 'actions', 'rewards', 'terminals', 'timeouts'):
    print(dataset[field].shape)

Generating samples: 100%|██████████| 1/1 [00:00<00:00,  5.37it/s]
Setting terminal flags: 100%|██████████| 1/1 [00:00<00:00, 150.63it/s]
Setting rewards: 100%|██████████| 1/1 [00:00<00:00, 982.04it/s]

Y_bar shape:  (1, 500, 128)
Y_f shape:  (1, 128)
U shape:  (1, 500, 128)
Terminals shape:  (1, 500)
Timeouts shape:  (1, 500)
Rewards shape:  (1, 500)
(500, 128)
(500, 128)
(500,)
(500,)
(500,)





In [5]:
# Size the buffer to hold exactly the transitions in the loaded dataset.
num_transitions = len(dataset['actions'])
replay_buffer = OfflineReplayBuffer(device, state_dim, action_dim, num_transitions)
replay_buffer.load_dataset(dataset=dataset)
# Compute the discounted return of the trajectory using discount factor gamma.
replay_buffer.compute_return(gamma)

Computing the returns: 499it [00:00, 176323.31it/s]


In [6]:
# Set up the TensorBoard logger.
# Directory layout:
#   <path>/<seed>/              checkpoints reused by every run with this seed
#   <path>/<seed>/<timestamp>/  event files and BPPO outputs of this run

current_time = time.strftime("%Y_%m_%d__%H_%M_%S", time.localtime())

# Append the seed only once. `path` is rebuilt from itself, so without this
# guard re-running the cell would produce logs/<seed>/<seed>/... and later
# cells would no longer find the cached checkpoints.
if os.path.basename(os.path.normpath(path)) != str(seed):
    path = os.path.join(path, str(seed))

logger_path = os.path.join(path, current_time)
# exist_ok=True: re-running within the same second yields the same timestamp
# and must not raise FileExistsError.
os.makedirs(logger_path, exist_ok=True)
print(f'Made log directory at {logger_path}')

logger = SummaryWriter(log_dir=logger_path, comment='')

Made log directory at logs\20241219\2024_12_22__08_48_18


In [7]:
# Instantiate every learner.
# The construction order (value, Q_bc, [Q_pi], bc, bppo) is kept fixed;
# presumably each constructor draws its initial weights from the seeded
# global RNG, so reordering could change the initial parameters - TODO confirm.
value = ValueLearner(
    device=device,
    state_dim=state_dim,
    hidden_dim=v_hidden_dim,
    depth=v_depth,
    value_lr=v_lr,
    batch_size=v_batch_size,
)

# Both Q learners take the same constructor arguments.
q_learner_kwargs = dict(
    device=device,
    state_dim=state_dim,
    action_dim=action_dim,
    hidden_dim=q_hidden_dim,
    depth=q_depth,
    Q_lr=q_lr,
    target_update_freq=target_update_freq,
    tau=tau,
    gamma=gamma,
    batch_size=q_batch_size,
)
Q_bc = QSarsaLearner(**q_learner_kwargs)
if is_offpolicy_update:
    # Q_pi is only needed when Q is re-estimated after each BPPO step.
    Q_pi = QPiLearner(**q_learner_kwargs)

# Arguments common to both policy learners.
policy_common_kwargs = dict(device=device, state_dim=state_dim, action_dim=action_dim)
bc = BehaviorCloning(
    hidden_dim=bc_hidden_dim,
    depth=bc_depth,
    policy_lr=bc_lr,
    batch_size=bc_batch_size,
    **policy_common_kwargs,
)
bppo = BehaviorProximalPolicyOptimization(
    hidden_dim=bppo_hidden_dim,
    depth=bppo_depth,
    policy_lr=bppo_lr,
    clip_ratio=clip_ratio,
    entropy_weight=entropy_weight,
    decay=decay,
    omega=omega,
    batch_size=bppo_batch_size,
    **policy_common_kwargs,
)

In [8]:
def _train_or_load(learner, checkpoint_path, num_steps, update_fn, tag):
    """Load ``learner`` from ``checkpoint_path`` if it exists, else train and save it.

    Args:
        learner: object exposing ``load(path)`` and ``save(path)``.
        checkpoint_path: file checked for a previously saved learner.
        num_steps: number of ``update_fn`` calls to make when training.
        update_fn: zero-argument callable that performs one update step and
            returns the loss (must be formattable with ``:.4f``).
        tag: name used in the progress-bar label (``'<tag> updating ......'``)
            and in the logged scalar (``'<tag>_loss'``).

    Reads the notebook globals ``log_freq`` and ``logger``.
    """
    if os.path.exists(checkpoint_path):
        learner.load(checkpoint_path)
        return

    for step in tqdm.tqdm(range(int(num_steps)), desc=f'{tag} updating ......'):
        loss = update_fn()
        if step % int(log_freq) == 0:
            # tqdm.write prints without breaking the progress bar, which a
            # bare print() inside the loop would do.
            tqdm.tqdm.write(f"Step: {step}, Loss: {loss:.4f}")
            logger.add_scalar(f'{tag}_loss', loss, global_step=(step+1))
    learner.save(checkpoint_path)


# value training (skipped when a checkpoint for this seed already exists)
value_path = os.path.join(path, 'value.pt')
_train_or_load(value, value_path, v_steps,
               lambda: value.update(replay_buffer), tag='value')

# Q_bc training (skipped when a checkpoint for this seed already exists)
Q_bc_path = os.path.join(path, 'Q_bc.pt')
_train_or_load(Q_bc, Q_bc_path, q_bc_steps,
               lambda: Q_bc.update(replay_buffer, pi=None), tag='Q_bc')

# Q_pi starts from the Q_bc checkpoint written (or found) above.
if is_offpolicy_update:
    Q_pi.load(Q_bc_path)



value updating ......:   1%|          | 37/5000 [00:01<02:19, 35.63it/s] 

Step: 0, Loss: 638.8598
Step: 20, Loss: 320.1189
Step: 40, Loss: 565.1950
Step: 60, Loss: 424.6550


value updating ......:   3%|▎         | 139/5000 [00:01<00:31, 152.38it/s]

Step: 80, Loss: 801.4885
Step: 100, Loss: 742.3171
Step: 120, Loss: 941.8500
Step: 140, Loss: 743.4591


value updating ......:   4%|▍         | 205/5000 [00:01<00:22, 217.60it/s]

Step: 160, Loss: 462.5739
Step: 180, Loss: 396.0893
Step: 200, Loss: 694.8433
Step: 220, Loss: 744.7766


value updating ......:   6%|▌         | 306/5000 [00:02<00:17, 275.69it/s]

Step: 240, Loss: 564.7543
Step: 260, Loss: 804.4673
Step: 280, Loss: 397.1973
Step: 300, Loss: 684.2120


value updating ......:   7%|▋         | 370/5000 [00:02<00:17, 267.53it/s]

Step: 320, Loss: 743.1808
Step: 340, Loss: 754.7490
Step: 360, Loss: 350.7657


value updating ......:   9%|▊         | 434/5000 [00:02<00:15, 289.17it/s]

Step: 380, Loss: 659.4980
Step: 400, Loss: 463.6422
Step: 420, Loss: 497.9687
Step: 440, Loss: 791.5464


value updating ......:  10%|▉         | 497/5000 [00:02<00:15, 294.60it/s]

Step: 460, Loss: 792.3071
Step: 480, Loss: 606.4875
Step: 500, Loss: 534.1238
Step: 520, Loss: 773.9961


value updating ......:  12%|█▏        | 601/5000 [00:03<00:13, 324.15it/s]

Step: 540, Loss: 1053.7471
Step: 560, Loss: 559.0447
Step: 580, Loss: 518.7889
Step: 600, Loss: 608.3584


value updating ......:  13%|█▎        | 670/5000 [00:03<00:13, 332.75it/s]

Step: 620, Loss: 701.5894
Step: 640, Loss: 554.2498
Step: 660, Loss: 597.8645
Step: 680, Loss: 708.7838


value updating ......:  15%|█▍        | 738/5000 [00:03<00:12, 333.91it/s]

Step: 700, Loss: 561.6718
Step: 720, Loss: 565.0482
Step: 740, Loss: 325.8413
Step: 760, Loss: 726.3159


value updating ......:  17%|█▋        | 836/5000 [00:03<00:13, 307.76it/s]

Step: 780, Loss: 708.1761
Step: 800, Loss: 528.6220
Step: 820, Loss: 444.6699
Step: 840, Loss: 592.1289


value updating ......:  18%|█▊        | 903/5000 [00:04<00:13, 299.17it/s]

Step: 860, Loss: 601.7285
Step: 880, Loss: 662.8354
Step: 900, Loss: 853.8428
Step: 920, Loss: 357.7036


value updating ......:  20%|██        | 1008/5000 [00:04<00:12, 327.25it/s]

Step: 940, Loss: 431.4635
Step: 960, Loss: 863.9988
Step: 980, Loss: 741.7407
Step: 1000, Loss: 558.1211


value updating ......:  22%|██▏       | 1076/5000 [00:04<00:12, 321.40it/s]

Step: 1020, Loss: 787.9465
Step: 1040, Loss: 674.3970
Step: 1060, Loss: 536.1116
Step: 1080, Loss: 675.0549


value updating ......:  23%|██▎       | 1144/5000 [00:04<00:11, 321.49it/s]

Step: 1100, Loss: 406.8063
Step: 1120, Loss: 833.5507
Step: 1140, Loss: 415.9622
Step: 1160, Loss: 478.1833


value updating ......:  24%|██▍       | 1209/5000 [00:05<00:13, 279.01it/s]

Step: 1180, Loss: 687.3192
Step: 1200, Loss: 662.6204
Step: 1220, Loss: 225.9887


value updating ......:  26%|██▌       | 1308/5000 [00:05<00:11, 307.80it/s]

Step: 1240, Loss: 712.2952
Step: 1260, Loss: 597.4591
Step: 1280, Loss: 549.0647
Step: 1300, Loss: 274.6313


value updating ......:  28%|██▊       | 1378/5000 [00:05<00:11, 325.80it/s]

Step: 1320, Loss: 531.5466
Step: 1340, Loss: 814.0281
Step: 1360, Loss: 733.6257
Step: 1380, Loss: 532.2000


value updating ......:  29%|██▉       | 1447/5000 [00:05<00:10, 331.19it/s]

Step: 1400, Loss: 545.4302
Step: 1420, Loss: 367.0625
Step: 1440, Loss: 779.0391
Step: 1460, Loss: 584.6986


value updating ......:  30%|███       | 1519/5000 [00:06<00:10, 342.63it/s]

Step: 1480, Loss: 630.8165
Step: 1500, Loss: 528.8572
Step: 1520, Loss: 416.2148
Step: 1540, Loss: 659.9594


value updating ......:  33%|███▎      | 1633/5000 [00:06<00:09, 358.16it/s]

Step: 1560, Loss: 770.9537
Step: 1580, Loss: 506.3866
Step: 1600, Loss: 393.1492
Step: 1620, Loss: 761.0797


value updating ......:  34%|███▍      | 1704/5000 [00:06<00:09, 344.62it/s]

Step: 1640, Loss: 520.5682
Step: 1660, Loss: 569.3446
Step: 1680, Loss: 413.8614
Step: 1700, Loss: 549.5835


value updating ......:  35%|███▌      | 1772/5000 [00:06<00:10, 319.77it/s]

Step: 1720, Loss: 725.3757
Step: 1740, Loss: 402.0027
Step: 1760, Loss: 1012.4977
Step: 1780, Loss: 541.9210


value updating ......:  37%|███▋      | 1837/5000 [00:07<00:10, 315.29it/s]

Step: 1800, Loss: 478.0930
Step: 1820, Loss: 654.5812
Step: 1840, Loss: 818.4982
Step: 1860, Loss: 624.0108


value updating ......:  39%|███▉      | 1939/5000 [00:07<00:09, 323.48it/s]

Step: 1880, Loss: 547.3718
Step: 1900, Loss: 769.2697
Step: 1920, Loss: 917.1426


value updating ......:  39%|███▉      | 1972/5000 [00:07<00:09, 315.92it/s]

Step: 1940, Loss: 874.3759
Step: 1960, Loss: 404.5074
Step: 1980, Loss: 817.7238
Step: 2000, Loss: 496.8673


value updating ......:  41%|████▏     | 2073/5000 [00:07<00:09, 323.94it/s]

Step: 2020, Loss: 752.0190
Step: 2040, Loss: 605.4684
Step: 2060, Loss: 731.6558
Step: 2080, Loss: 457.7534


value updating ......:  43%|████▎     | 2139/5000 [00:08<00:08, 318.53it/s]

Step: 2100, Loss: 512.4647
Step: 2120, Loss: 496.7715
Step: 2140, Loss: 893.2159
Step: 2160, Loss: 496.4716


value updating ......:  44%|████▍     | 2207/5000 [00:08<00:08, 312.79it/s]

Step: 2180, Loss: 607.5563
Step: 2200, Loss: 862.6934
Step: 2220, Loss: 701.3168


value updating ......:  46%|████▌     | 2307/5000 [00:08<00:08, 321.60it/s]

Step: 2240, Loss: 572.6505
Step: 2260, Loss: 706.7531
Step: 2280, Loss: 613.0215
Step: 2300, Loss: 566.8564


value updating ......:  47%|████▋     | 2373/5000 [00:08<00:08, 320.73it/s]

Step: 2320, Loss: 505.9854
Step: 2340, Loss: 555.7771
Step: 2360, Loss: 589.0645
Step: 2380, Loss: 585.1517


value updating ......:  49%|████▉     | 2439/5000 [00:08<00:07, 322.04it/s]

Step: 2400, Loss: 422.7511
Step: 2420, Loss: 605.2994
Step: 2440, Loss: 536.0291
Step: 2460, Loss: 489.8204


value updating ......:  51%|█████     | 2543/5000 [00:09<00:07, 339.01it/s]

Step: 2480, Loss: 515.4106
Step: 2500, Loss: 817.1881
Step: 2520, Loss: 355.2080
Step: 2540, Loss: 658.7552


value updating ......:  52%|█████▏    | 2611/5000 [00:09<00:07, 317.69it/s]

Step: 2560, Loss: 705.6857
Step: 2580, Loss: 615.5677
Step: 2600, Loss: 868.4702
Step: 2620, Loss: 425.8678


value updating ......:  54%|█████▎    | 2677/5000 [00:09<00:08, 281.03it/s]

Step: 2640, Loss: 774.0927
Step: 2660, Loss: 539.6527
Step: 2680, Loss: 465.8843


value updating ......:  55%|█████▍    | 2741/5000 [00:09<00:07, 298.44it/s]

Step: 2700, Loss: 657.1459
Step: 2720, Loss: 931.4008
Step: 2740, Loss: 449.6349
Step: 2760, Loss: 708.7882


value updating ......:  56%|█████▌    | 2807/5000 [00:10<00:07, 308.59it/s]

Step: 2780, Loss: 557.0690
Step: 2800, Loss: 473.7819
Step: 2820, Loss: 952.0511


value updating ......:  57%|█████▊    | 2875/5000 [00:10<00:06, 312.08it/s]

Step: 2840, Loss: 407.2787
Step: 2860, Loss: 957.2119
Step: 2880, Loss: 626.4003
Step: 2900, Loss: 945.8451


value updating ......:  59%|█████▉    | 2973/5000 [00:10<00:06, 313.64it/s]

Step: 2920, Loss: 561.4493
Step: 2940, Loss: 547.9085
Step: 2960, Loss: 490.9080
Step: 2980, Loss: 461.2778


value updating ......:  61%|██████    | 3038/5000 [00:10<00:06, 290.18it/s]

Step: 3000, Loss: 552.9227
Step: 3020, Loss: 605.9680
Step: 3040, Loss: 833.0936


value updating ......:  62%|██████▏   | 3107/5000 [00:11<00:06, 313.78it/s]

Step: 3060, Loss: 454.4328
Step: 3080, Loss: 301.8225
Step: 3100, Loss: 508.6702
Step: 3120, Loss: 1070.9968


value updating ......:  64%|██████▎   | 3179/5000 [00:11<00:05, 333.57it/s]

Step: 3140, Loss: 600.9622
Step: 3160, Loss: 685.0624
Step: 3180, Loss: 534.8539


value updating ......:  65%|██████▍   | 3243/5000 [00:11<00:06, 283.15it/s]

Step: 3200, Loss: 507.1703
Step: 3220, Loss: 423.0503
Step: 3240, Loss: 663.2650


value updating ......:  66%|██████▌   | 3305/5000 [00:11<00:05, 293.19it/s]

Step: 3260, Loss: 481.0671
Step: 3280, Loss: 518.9271
Step: 3300, Loss: 644.1630
Step: 3320, Loss: 514.4845


value updating ......:  68%|██████▊   | 3407/5000 [00:12<00:04, 320.89it/s]

Step: 3340, Loss: 539.7743
Step: 3360, Loss: 324.8545
Step: 3380, Loss: 911.8685
Step: 3400, Loss: 529.6162


value updating ......:  69%|██████▉   | 3473/5000 [00:12<00:04, 312.34it/s]

Step: 3420, Loss: 353.6418
Step: 3440, Loss: 350.6833
Step: 3460, Loss: 592.7999
Step: 3480, Loss: 806.2081


value updating ......:  71%|███████   | 3539/5000 [00:12<00:04, 318.60it/s]

Step: 3500, Loss: 742.9350
Step: 3520, Loss: 684.9734
Step: 3540, Loss: 524.0303
Step: 3560, Loss: 529.0164


value updating ......:  72%|███████▏  | 3608/5000 [00:12<00:04, 300.91it/s]

Step: 3580, Loss: 709.4185
Step: 3600, Loss: 865.5955
Step: 3620, Loss: 742.2057


value updating ......:  74%|███████▎  | 3675/5000 [00:12<00:04, 300.82it/s]

Step: 3640, Loss: 566.9672
Step: 3660, Loss: 598.7563
Step: 3680, Loss: 723.9078
Step: 3700, Loss: 661.6117


value updating ......:  76%|███████▌  | 3780/5000 [00:13<00:03, 333.13it/s]

Step: 3720, Loss: 851.9012
Step: 3740, Loss: 295.1391
Step: 3760, Loss: 522.5359
Step: 3780, Loss: 591.1624
Step: 3800, Loss: 404.7218


value updating ......:  78%|███████▊  | 3885/5000 [00:13<00:03, 339.29it/s]

Step: 3820, Loss: 422.8117
Step: 3840, Loss: 695.4582
Step: 3860, Loss: 670.3823
Step: 3880, Loss: 495.1016


value updating ......:  79%|███████▉  | 3955/5000 [00:13<00:03, 333.86it/s]

Step: 3900, Loss: 587.4418
Step: 3920, Loss: 357.8606
Step: 3940, Loss: 639.3068
Step: 3960, Loss: 679.0305


value updating ......:  80%|████████  | 4023/5000 [00:14<00:03, 318.84it/s]

Step: 3980, Loss: 428.6386
Step: 4000, Loss: 419.2843
Step: 4020, Loss: 644.9881
Step: 4040, Loss: 506.1121


value updating ......:  82%|████████▏ | 4088/5000 [00:14<00:03, 288.39it/s]

Step: 4060, Loss: 506.4437
Step: 4080, Loss: 353.2959
Step: 4100, Loss: 410.2612


value updating ......:  83%|████████▎ | 4154/5000 [00:14<00:02, 296.15it/s]

Step: 4120, Loss: 966.2755
Step: 4140, Loss: 657.4513
Step: 4160, Loss: 970.6635
Step: 4180, Loss: 792.2530


value updating ......:  85%|████████▌ | 4260/5000 [00:14<00:02, 314.79it/s]

Step: 4200, Loss: 563.3919
Step: 4220, Loss: 632.3127
Step: 4240, Loss: 473.7846


value updating ......:  86%|████████▌ | 4292/5000 [00:14<00:02, 309.27it/s]

Step: 4260, Loss: 474.1518
Step: 4280, Loss: 815.3550
Step: 4300, Loss: 390.4822
Step: 4320, Loss: 460.0091


value updating ......:  88%|████████▊ | 4400/5000 [00:15<00:01, 341.07it/s]

Step: 4340, Loss: 500.9609
Step: 4360, Loss: 808.7103
Step: 4380, Loss: 620.7711
Step: 4400, Loss: 612.9519


value updating ......:  89%|████████▉ | 4470/5000 [00:15<00:01, 337.28it/s]

Step: 4420, Loss: 653.4590
Step: 4440, Loss: 662.5223
Step: 4460, Loss: 732.6746
Step: 4480, Loss: 199.2470


value updating ......:  91%|█████████ | 4538/5000 [00:15<00:01, 313.61it/s]

Step: 4500, Loss: 680.5910
Step: 4520, Loss: 628.6027
Step: 4540, Loss: 896.7379
Step: 4560, Loss: 457.5738


value updating ......:  93%|█████████▎| 4643/5000 [00:15<00:01, 329.16it/s]

Step: 4580, Loss: 546.0240
Step: 4600, Loss: 673.7510
Step: 4620, Loss: 472.0642
Step: 4640, Loss: 816.7435


value updating ......:  94%|█████████▍| 4713/5000 [00:16<00:00, 331.22it/s]

Step: 4660, Loss: 831.9493
Step: 4680, Loss: 445.5744
Step: 4700, Loss: 763.8932
Step: 4720, Loss: 388.0011


value updating ......:  96%|█████████▌| 4786/5000 [00:16<00:00, 344.88it/s]

Step: 4740, Loss: 726.6476
Step: 4760, Loss: 653.7624
Step: 4780, Loss: 709.2326
Step: 4800, Loss: 666.7553


value updating ......:  97%|█████████▋| 4855/5000 [00:16<00:00, 308.79it/s]

Step: 4820, Loss: 544.1670
Step: 4840, Loss: 660.7374
Step: 4860, Loss: 644.0764


value updating ......:  98%|█████████▊| 4918/5000 [00:16<00:00, 270.01it/s]

Step: 4880, Loss: 753.3221
Step: 4900, Loss: 739.7057
Step: 4920, Loss: 407.8567


value updating ......: 100%|██████████| 5000/5000 [00:17<00:00, 291.73it/s]


Step: 4940, Loss: 614.8984
Step: 4960, Loss: 602.3511
Step: 4980, Loss: 855.9218
Value parameters saved in logs\20241219\value.pt


Q_bc updating ......:   0%|          | 25/5000 [00:00<00:19, 249.73it/s]

Step: 0, Loss: 0.1865
Step: 20, Loss: 0.1759
Step: 40, Loss: 0.1910


Q_bc updating ......:   2%|▏         | 80/5000 [00:00<00:18, 260.16it/s]

Step: 60, Loss: 0.3480
Step: 80, Loss: 0.1583
Step: 100, Loss: 0.2342


Q_bc updating ......:   3%|▎         | 158/5000 [00:00<00:21, 229.78it/s]

Step: 120, Loss: 0.1752
Step: 140, Loss: 0.2106
Step: 160, Loss: 0.1978


Q_bc updating ......:   4%|▍         | 210/5000 [00:00<00:20, 231.98it/s]

Step: 180, Loss: 0.1700
Step: 200, Loss: 0.2530
Step: 220, Loss: 0.2174


Q_bc updating ......:   5%|▌         | 266/5000 [00:01<00:18, 250.63it/s]

Step: 240, Loss: 0.2284
Step: 260, Loss: 0.1076
Step: 280, Loss: 0.1891


Q_bc updating ......:   6%|▋         | 323/5000 [00:01<00:19, 242.33it/s]

Step: 300, Loss: 0.2864
Step: 320, Loss: 0.2702
Step: 340, Loss: 0.1504


Q_bc updating ......:   8%|▊         | 399/5000 [00:01<00:18, 245.31it/s]

Step: 360, Loss: 0.1220
Step: 380, Loss: 0.2180
Step: 400, Loss: 0.3750


Q_bc updating ......:   9%|▉         | 449/5000 [00:01<00:19, 236.23it/s]

Step: 420, Loss: 0.2318
Step: 440, Loss: 0.1581
Step: 460, Loss: 0.3800


Q_bc updating ......:  11%|█         | 532/5000 [00:02<00:17, 259.65it/s]

Step: 480, Loss: 0.1936
Step: 500, Loss: 0.2596
Step: 520, Loss: 0.2920


Q_bc updating ......:  12%|█▏        | 586/5000 [00:02<00:16, 262.87it/s]

Step: 540, Loss: 0.3227
Step: 560, Loss: 0.2476
Step: 580, Loss: 0.2543


Q_bc updating ......:  13%|█▎        | 640/5000 [00:02<00:16, 261.14it/s]

Step: 600, Loss: 0.3158
Step: 620, Loss: 0.4349
Step: 640, Loss: 0.2450


Q_bc updating ......:  14%|█▍        | 692/5000 [00:02<00:18, 239.21it/s]

Step: 660, Loss: 0.4265
Step: 680, Loss: 0.3000
Step: 700, Loss: 0.2516


Q_bc updating ......:  15%|█▍        | 747/5000 [00:03<00:16, 255.19it/s]

Step: 720, Loss: 0.1603
Step: 740, Loss: 0.2044
Step: 760, Loss: 0.2220


Q_bc updating ......:  17%|█▋        | 832/5000 [00:03<00:15, 265.53it/s]

Step: 780, Loss: 0.1471
Step: 800, Loss: 0.2628
Step: 820, Loss: 0.3209


Q_bc updating ......:  18%|█▊        | 886/5000 [00:03<00:16, 256.02it/s]

Step: 840, Loss: 0.1854
Step: 860, Loss: 0.1750
Step: 880, Loss: 0.2398


Q_bc updating ......:  19%|█▉        | 939/5000 [00:03<00:16, 253.54it/s]

Step: 900, Loss: 0.2097
Step: 920, Loss: 0.1982
Step: 940, Loss: 0.2158


Q_bc updating ......:  20%|█▉        | 990/5000 [00:04<00:17, 233.87it/s]

Step: 960, Loss: 0.2997
Step: 980, Loss: 0.2584


Q_bc updating ......:  21%|██        | 1042/5000 [00:04<00:16, 236.76it/s]

Step: 1000, Loss: 0.3455
Step: 1020, Loss: 0.2381
Step: 1040, Loss: 0.1859


Q_bc updating ......:  22%|██▏       | 1094/5000 [00:04<00:16, 237.73it/s]

Step: 1060, Loss: 0.1864
Step: 1080, Loss: 0.3411
Step: 1100, Loss: 0.1916


Q_bc updating ......:  23%|██▎       | 1144/5000 [00:04<00:16, 237.59it/s]

Step: 1120, Loss: 0.2869
Step: 1140, Loss: 0.1726
Step: 1160, Loss: 0.1860


Q_bc updating ......:  24%|██▍       | 1215/5000 [00:05<00:18, 208.38it/s]

Step: 1180, Loss: 0.3771
Step: 1200, Loss: 0.1895


Q_bc updating ......:  25%|██▌       | 1270/5000 [00:05<00:15, 237.72it/s]

Step: 1220, Loss: 0.3059
Step: 1240, Loss: 0.2429
Step: 1260, Loss: 0.2065


Q_bc updating ......:  26%|██▋       | 1322/5000 [00:05<00:15, 234.85it/s]

Step: 1280, Loss: 0.2375
Step: 1300, Loss: 0.2230
Step: 1320, Loss: 0.2477


Q_bc updating ......:  28%|██▊       | 1376/5000 [00:05<00:14, 252.10it/s]

Step: 1340, Loss: 0.3009
Step: 1360, Loss: 0.2473
Step: 1380, Loss: 0.2089


Q_bc updating ......:  29%|██▊       | 1427/5000 [00:05<00:15, 235.20it/s]

Step: 1400, Loss: 0.2635
Step: 1420, Loss: 0.2895
Step: 1440, Loss: 0.3266


Q_bc updating ......:  30%|██▉       | 1482/5000 [00:06<00:14, 246.86it/s]

Step: 1460, Loss: 0.2950
Step: 1480, Loss: 0.2123
Step: 1500, Loss: 0.1835


Q_bc updating ......:  31%|███▏      | 1563/5000 [00:06<00:13, 258.96it/s]

Step: 1520, Loss: 0.2453
Step: 1540, Loss: 0.2273
Step: 1560, Loss: 0.2030


Q_bc updating ......:  32%|███▏      | 1614/5000 [00:06<00:13, 245.97it/s]

Step: 1580, Loss: 0.3687
Step: 1600, Loss: 0.1661
Step: 1620, Loss: 0.2506


Q_bc updating ......:  34%|███▍      | 1691/5000 [00:06<00:13, 243.67it/s]

Step: 1640, Loss: 0.2626
Step: 1660, Loss: 0.2934
Step: 1680, Loss: 0.2973


Q_bc updating ......:  35%|███▍      | 1747/5000 [00:07<00:12, 260.24it/s]

Step: 1700, Loss: 0.1709
Step: 1720, Loss: 0.3165
Step: 1740, Loss: 0.1700


Q_bc updating ......:  36%|███▌      | 1802/5000 [00:07<00:12, 261.09it/s]

Step: 1760, Loss: 0.1634
Step: 1780, Loss: 0.1633
Step: 1800, Loss: 0.2431


Q_bc updating ......:  37%|███▋      | 1855/5000 [00:07<00:12, 251.96it/s]

Step: 1820, Loss: 0.1514
Step: 1840, Loss: 0.2745
Step: 1860, Loss: 0.1535


Q_bc updating ......:  38%|███▊      | 1908/5000 [00:07<00:12, 253.93it/s]

Step: 1880, Loss: 0.2892
Step: 1900, Loss: 0.1647
Step: 1920, Loss: 0.4185


Q_bc updating ......:  39%|███▉      | 1963/5000 [00:08<00:12, 252.14it/s]

Step: 1940, Loss: 0.1176
Step: 1960, Loss: 0.2595
Step: 1980, Loss: 0.1299


Q_bc updating ......:  41%|████      | 2040/5000 [00:08<00:12, 236.88it/s]

Step: 2000, Loss: 0.2841
Step: 2020, Loss: 0.1969
Step: 2040, Loss: 0.2763


Q_bc updating ......:  42%|████▏     | 2090/5000 [00:08<00:12, 239.91it/s]

Step: 2060, Loss: 0.2336
Step: 2080, Loss: 0.1896
Step: 2100, Loss: 0.1856


Q_bc updating ......:  43%|████▎     | 2170/5000 [00:08<00:11, 254.28it/s]

Step: 2120, Loss: 0.2079
Step: 2140, Loss: 0.2043
Step: 2160, Loss: 0.3079


Q_bc updating ......:  44%|████▍     | 2222/5000 [00:09<00:11, 246.98it/s]

Step: 2180, Loss: 0.3409
Step: 2200, Loss: 0.2267
Step: 2220, Loss: 0.1611


Q_bc updating ......:  46%|████▌     | 2276/5000 [00:09<00:10, 254.06it/s]

Step: 2240, Loss: 0.2198
Step: 2260, Loss: 0.1329
Step: 2280, Loss: 0.2133


Q_bc updating ......:  47%|████▋     | 2328/5000 [00:09<00:11, 233.11it/s]

Step: 2300, Loss: 0.2505
Step: 2320, Loss: 0.2189
Step: 2340, Loss: 0.2811


Q_bc updating ......:  48%|████▊     | 2405/5000 [00:09<00:10, 244.48it/s]

Step: 2360, Loss: 0.1565
Step: 2380, Loss: 0.2720
Step: 2400, Loss: 0.3689


Q_bc updating ......:  49%|████▉     | 2456/5000 [00:10<00:10, 244.47it/s]

Step: 2420, Loss: 0.1740
Step: 2440, Loss: 0.2697
Step: 2460, Loss: 0.3131


Q_bc updating ......:  50%|█████     | 2505/5000 [00:10<00:11, 221.20it/s]

Step: 2480, Loss: 0.2183
Step: 2500, Loss: 0.3246
Step: 2520, Loss: 0.1906


Q_bc updating ......:  52%|█████▏    | 2575/5000 [00:10<00:10, 225.85it/s]

Step: 2540, Loss: 0.2872
Step: 2560, Loss: 0.3447
Step: 2580, Loss: 0.1875


Q_bc updating ......:  52%|█████▏    | 2621/5000 [00:10<00:11, 212.18it/s]

Step: 2600, Loss: 0.2646
Step: 2620, Loss: 0.1711
Step: 2640, Loss: 0.2961


Q_bc updating ......:  54%|█████▍    | 2699/5000 [00:11<00:09, 239.61it/s]

Step: 2660, Loss: 0.2020
Step: 2680, Loss: 0.2234
Step: 2700, Loss: 0.1789


Q_bc updating ......:  55%|█████▌    | 2756/5000 [00:11<00:08, 253.62it/s]

Step: 2720, Loss: 0.3858
Step: 2740, Loss: 0.2639
Step: 2760, Loss: 0.2822


Q_bc updating ......:  56%|█████▌    | 2808/5000 [00:11<00:08, 251.21it/s]

Step: 2780, Loss: 0.2118
Step: 2800, Loss: 0.2934
Step: 2820, Loss: 0.3490


Q_bc updating ......:  58%|█████▊    | 2888/5000 [00:11<00:08, 259.42it/s]

Step: 2840, Loss: 0.1860
Step: 2860, Loss: 0.2021
Step: 2880, Loss: 0.1794


Q_bc updating ......:  59%|█████▉    | 2940/5000 [00:12<00:08, 252.69it/s]

Step: 2900, Loss: 0.1799
Step: 2920, Loss: 0.3831
Step: 2940, Loss: 0.3301


Q_bc updating ......:  60%|█████▉    | 2991/5000 [00:12<00:08, 235.98it/s]

Step: 2960, Loss: 0.2831
Step: 2980, Loss: 0.2451
Step: 3000, Loss: 0.2136


Q_bc updating ......:  61%|██████▏   | 3064/5000 [00:12<00:08, 239.21it/s]

Step: 3020, Loss: 0.1852
Step: 3040, Loss: 0.3188
Step: 3060, Loss: 0.2214


Q_bc updating ......:  62%|██████▏   | 3110/5000 [00:12<00:08, 215.76it/s]

Step: 3080, Loss: 0.3054
Step: 3100, Loss: 0.2250
Step: 3120, Loss: 0.4325


Q_bc updating ......:  64%|██████▎   | 3183/5000 [00:13<00:07, 233.36it/s]

Step: 3140, Loss: 0.3861
Step: 3160, Loss: 0.2323
Step: 3180, Loss: 0.3429


Q_bc updating ......:  65%|██████▍   | 3234/5000 [00:13<00:08, 217.40it/s]

Step: 3200, Loss: 0.2450
Step: 3220, Loss: 0.1525


Q_bc updating ......:  66%|██████▌   | 3279/5000 [00:13<00:08, 214.46it/s]

Step: 3240, Loss: 0.2724
Step: 3260, Loss: 0.2408
Step: 3280, Loss: 0.1725


Q_bc updating ......:  67%|██████▋   | 3346/5000 [00:13<00:07, 216.89it/s]

Step: 3300, Loss: 0.2499
Step: 3320, Loss: 0.2347
Step: 3340, Loss: 0.2469


Q_bc updating ......:  68%|██████▊   | 3395/5000 [00:14<00:07, 228.31it/s]

Step: 3360, Loss: 0.2631
Step: 3380, Loss: 0.2008
Step: 3400, Loss: 0.1704


Q_bc updating ......:  69%|██████▉   | 3444/5000 [00:14<00:06, 228.31it/s]

Step: 3420, Loss: 0.1427
Step: 3440, Loss: 0.3482
Step: 3460, Loss: 0.1339


Q_bc updating ......:  70%|███████   | 3521/5000 [00:14<00:06, 230.40it/s]

Step: 3480, Loss: 0.3599
Step: 3500, Loss: 0.2116
Step: 3520, Loss: 0.4168


Q_bc updating ......:  71%|███████▏  | 3573/5000 [00:14<00:05, 242.27it/s]

Step: 3540, Loss: 0.2150
Step: 3560, Loss: 0.2720
Step: 3580, Loss: 0.2021


Q_bc updating ......:  73%|███████▎  | 3652/5000 [00:15<00:05, 246.19it/s]

Step: 3600, Loss: 0.2033
Step: 3620, Loss: 0.2852
Step: 3640, Loss: 0.3533


Q_bc updating ......:  74%|███████▍  | 3700/5000 [00:15<00:06, 206.87it/s]

Step: 3660, Loss: 0.2309
Step: 3680, Loss: 0.3443


Q_bc updating ......:  74%|███████▍  | 3722/5000 [00:15<00:06, 196.94it/s]

Step: 3700, Loss: 0.2527
Step: 3720, Loss: 0.1526
Step: 3740, Loss: 0.2734


Q_bc updating ......:  76%|███████▌  | 3796/5000 [00:15<00:05, 218.91it/s]

Step: 3760, Loss: 0.2282
Step: 3780, Loss: 0.3654
Step: 3800, Loss: 0.2826


Q_bc updating ......:  77%|███████▋  | 3858/5000 [00:16<00:06, 179.15it/s]

Step: 3820, Loss: 0.5019
Step: 3840, Loss: 0.2490


Q_bc updating ......:  78%|███████▊  | 3877/5000 [00:16<00:06, 175.68it/s]

Step: 3860, Loss: 0.2394
Step: 3880, Loss: 0.2420


Q_bc updating ......:  79%|███████▉  | 3945/5000 [00:16<00:05, 209.70it/s]

Step: 3900, Loss: 0.1740
Step: 3920, Loss: 0.1409
Step: 3940, Loss: 0.3101


Q_bc updating ......:  80%|███████▉  | 3992/5000 [00:16<00:04, 208.80it/s]

Step: 3960, Loss: 0.3379
Step: 3980, Loss: 0.2265
Step: 4000, Loss: 0.3007


Q_bc updating ......:  81%|████████▏ | 4070/5000 [00:17<00:03, 238.26it/s]

Step: 4020, Loss: 0.2836
Step: 4040, Loss: 0.1358
Step: 4060, Loss: 0.2539


Q_bc updating ......:  82%|████████▏ | 4119/5000 [00:17<00:03, 227.80it/s]

Step: 4080, Loss: 0.2935
Step: 4100, Loss: 0.2137
Step: 4120, Loss: 0.2322


Q_bc updating ......:  83%|████████▎ | 4166/5000 [00:17<00:03, 227.42it/s]

Step: 4140, Loss: 0.1842
Step: 4160, Loss: 0.2322
Step: 4180, Loss: 0.3285


Q_bc updating ......:  85%|████████▍ | 4237/5000 [00:18<00:03, 215.55it/s]

Step: 4200, Loss: 0.3001
Step: 4220, Loss: 0.2837
Step: 4240, Loss: 0.2506


Q_bc updating ......:  86%|████████▋ | 4315/5000 [00:18<00:02, 238.47it/s]

Step: 4260, Loss: 0.1440
Step: 4280, Loss: 0.2799
Step: 4300, Loss: 0.1301


Q_bc updating ......:  87%|████████▋ | 4364/5000 [00:18<00:02, 234.01it/s]

Step: 4320, Loss: 0.2773
Step: 4340, Loss: 0.2656
Step: 4360, Loss: 0.3121


Q_bc updating ......:  88%|████████▊ | 4412/5000 [00:18<00:02, 233.06it/s]

Step: 4380, Loss: 0.2679
Step: 4400, Loss: 0.1611
Step: 4420, Loss: 0.2293


Q_bc updating ......:  90%|████████▉ | 4488/5000 [00:19<00:02, 243.09it/s]

Step: 4440, Loss: 0.2332
Step: 4460, Loss: 0.2813
Step: 4480, Loss: 0.3241


Q_bc updating ......:  91%|█████████ | 4540/5000 [00:19<00:01, 231.30it/s]

Step: 4500, Loss: 0.1879
Step: 4520, Loss: 0.2248
Step: 4540, Loss: 0.2356


Q_bc updating ......:  92%|█████████▏| 4588/5000 [00:19<00:01, 231.87it/s]

Step: 4560, Loss: 0.2097
Step: 4580, Loss: 0.1580
Step: 4600, Loss: 0.1196


Q_bc updating ......:  93%|█████████▎| 4661/5000 [00:19<00:01, 234.21it/s]

Step: 4620, Loss: 0.2780
Step: 4640, Loss: 0.2522
Step: 4660, Loss: 0.2135


Q_bc updating ......:  94%|█████████▍| 4708/5000 [00:20<00:01, 225.26it/s]

Step: 4680, Loss: 0.1030
Step: 4700, Loss: 0.2575
Step: 4720, Loss: 0.3302


Q_bc updating ......:  96%|█████████▌| 4785/5000 [00:20<00:00, 238.70it/s]

Step: 4740, Loss: 0.1909
Step: 4760, Loss: 0.1911
Step: 4780, Loss: 0.2260


Q_bc updating ......:  97%|█████████▋| 4833/5000 [00:20<00:00, 217.92it/s]

Step: 4800, Loss: 0.2316
Step: 4820, Loss: 0.1601
Step: 4840, Loss: 0.1789


Q_bc updating ......:  98%|█████████▊| 4911/5000 [00:20<00:00, 242.60it/s]

Step: 4860, Loss: 0.2220
Step: 4880, Loss: 0.2553
Step: 4900, Loss: 0.2112


Q_bc updating ......:  99%|█████████▉| 4962/5000 [00:21<00:00, 243.09it/s]

Step: 4920, Loss: 0.1744
Step: 4940, Loss: 0.2275
Step: 4960, Loss: 0.2906


Q_bc updating ......: 100%|██████████| 5000/5000 [00:21<00:00, 234.87it/s]

Step: 4980, Loss: 0.3263
Q function parameters saved in logs\20241219\Q_bc.pt





In [9]:
# State-normalization statistics handed to offline_evaluate; (0, 1) is passed
# here, in line with is_state_norm=False in the config cell.
mean, std = 0., 1.

# bc training: keep the checkpoint with the best evaluation score, then
# reload it so later cells start from the best behaviour-cloned policy.
# NOTE(review): the recorded run of this cell stopped with
#   AttributeError: 'str' object has no attribute 'reset'
# Presumably offline_evaluate expects an environment object rather than the
# `env_name` string (the cell that builds the env is commented out) - confirm
# against bppo.py before relying on this cell.
best_bc_path = os.path.join(path, 'bc_best.pt')
if os.path.exists(best_bc_path):
    bc.load(best_bc_path)
else:
    # Start from -inf, not 0: with 0, a run whose scores never exceed zero
    # would never write bc_best.pt and the final bc.load() below would fail.
    best_bc_score = -np.inf
    for step in tqdm.tqdm(range(int(bc_steps)), desc='bc updating ......'):
        bc_loss = bc.update(replay_buffer)
        if step % int(log_freq) == 0:
            current_bc_score = bc.offline_evaluate(env_name, seed, mean, std)
            if current_bc_score > best_bc_score:
                best_bc_score = current_bc_score
                bc.save(best_bc_path)
                np.savetxt(os.path.join(path, 'best_bc.csv'), [best_bc_score], fmt='%f', delimiter=',')
            print(f"Step: {step}, Loss: {bc_loss:.4f}, Score: {current_bc_score:.4f}")
            logger.add_scalar('bc_loss', bc_loss, global_step=(step+1))
            logger.add_scalar('bc_score', current_bc_score, global_step=(step+1))
    # Keep the final weights too, then switch back to the best ones.
    bc.save(os.path.join(path, 'bc_last.pt'))
    bc.load(best_bc_path)



bc updating ......:   0%|          | 0/500 [00:00<?, ?it/s]


AttributeError: 'str' object has no attribute 'reset'

In [None]:
# bppo training: start from the best behaviour-cloned policy and keep the
# checkpoint with the highest evaluation score under this run's log directory.
bppo.load(best_bc_path)
best_bppo_path = os.path.join(path, current_time, 'bppo_best.pt')
Q = Q_bc # If advantage replacement, then Q_{\pi k}=Q_{\pi\beta}
best_bppo_score = bppo.offline_evaluate(env_name, seed, mean, std)
print('best_bppo_score:',best_bppo_score,'-------------------------')

# Clip-ratio and learning-rate decay are switched off after this step.
# NOTE(review): with bppo_steps=100 this threshold is never reached.
decay_stop_step = 200

# `tqdm` is imported as a module in this notebook, so the progress-bar class
# is `tqdm.tqdm`; calling the bare module raises TypeError.
for step in tqdm.tqdm(range(int(bppo_steps)), desc='bppo updating ......'):
    # Per-step flag instead of overwriting the config globals, so that
    # re-running this cell starts with the configured decay settings again.
    decay_active = step <= decay_stop_step
    bppo_loss = bppo.update(replay_buffer, Q, value,
                            is_clip_decay and decay_active,
                            is_bppo_lr_decay and decay_active)
    current_bppo_score = bppo.offline_evaluate(env_name, seed, mean, std) # J_{\pi k}
    if current_bppo_score > best_bppo_score:
        best_bppo_score = current_bppo_score
        print('best_bppo_score:',best_bppo_score,'-------------------------')
        bppo.save(best_bppo_path)
        np.savetxt(os.path.join(path, current_time, 'best_bppo.csv'), [best_bppo_score], fmt='%f', delimiter=',')
        if is_update_old_policy:
            bppo.set_old_policy() # Set the old policy to the current policy
    if is_offpolicy_update: # If not using advantage replacement, calculate Q_{\pi k} by Q-learning
        for _ in tqdm.tqdm(range(int(q_pi_steps)), desc='Q_pi updating ......'):
            Q_pi_loss = Q_pi.update(replay_buffer, bppo)
        Q = Q_pi
    print(f"Step: {step}, Loss: {bppo_loss:.4f}, Score: {current_bppo_score:.4f}")
    logger.add_scalar('bppo_loss', bppo_loss, global_step=(step+1))
    logger.add_scalar('bppo_score', current_bppo_score, global_step=(step+1))

logger.close()