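# Train Agents

This notebook trains offline agents on logged `SeaquestNoFrameskip-v4` datasets (beginner, intermediate and expert). The agents are behavioral cloning (BC), implicit Q-learning (IQL) and behavior value estimation (BVE). Each dataset is used at several perturbation levels (0%, 5%, 10%, 20%), and each agent is trained over `SEEDS` seeds.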


In [1]:
import os
import torch
from tqdm import tqdm
import numpy as np
from stable_baselines3 import DQN
import warnings
import pandas as pd
import gc  # Import garbage collector module

warnings.filterwarnings("ignore")

In [2]:
# Re-import the dataset helpers so that edits to datasets/dataset_utils.py are
# picked up without restarting the kernel, then bind the helper functions.
import importlib

import datasets.dataset_utils
importlib.reload(datasets.dataset_utils)

from datasets.dataset_utils import (
    create_dataloaders,
    create_environment,
    load_and_prepare_dataset,
    load_dataset,
    preprocess_and_split,
    set_all_seeds,
)

In [3]:
# Re-import each offline-RL method's utility module so that edits made since
# the kernel started are picked up, then bind its train-and-evaluate entry point.
# NOTE(review): the import -> reload -> from-import sequence is kept per module
# and in this order deliberately; if these modules import one another (not
# visible here), reordering could leave one holding a stale copy of another.
# `importlib` is imported again so this cell can also be run on its own.
import importlib

# Behavioral Cloning (BC)
import offline_rl_models.behavioral_cloning_bc.bc_utils
importlib.reload(offline_rl_models.behavioral_cloning_bc.bc_utils)

from offline_rl_models.behavioral_cloning_bc.bc_utils import train_and_evaluate_BC

# Implicit Q-Learning (IQL)
import offline_rl_models.implicit_q_learning_iql.iql_utils
importlib.reload(offline_rl_models.implicit_q_learning_iql.iql_utils)

from offline_rl_models.implicit_q_learning_iql.iql_utils import train_and_evaluate_IQL

# Behavior Value Estimation (BVE)
import offline_rl_models.behavior_value_estimation_bve.bve_utils
importlib.reload(offline_rl_models.behavior_value_estimation_bve.bve_utils)

from offline_rl_models.behavior_value_estimation_bve.bve_utils import train_and_evaluate_BVE

In [4]:
# ---- Experiment configuration (shared by every section below) ----
SEED = 12_345                      # base seed for seeding, env creation, data splits and training
ENV_ID = "SeaquestNoFrameskip-v4"  # Atari environment id passed to env creation and the trainers
EPOCHS = 10                        # training epochs per seed
SEEDS = 3                          # NUMBER of training seeds per agent/dataset (a count, not a seed)
BATCH_SIZE = 64                    # minibatch size for the dataloaders

In [5]:
# Seed the random number generators (via the project's set_all_seeds helper)
# so that runs are reproducible.
set_all_seeds(SEED)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Initialize the environment.
# NOTE(review): the training cells below receive `env_id`, not this `env`.
env = create_environment(env_id=ENV_ID, seed=SEED)

Device: cuda


# Training all agents on: Beginner Dataset

### 0% Perturbation

In [6]:
# Load, preprocess/split and wrap the beginner dataset (0% perturbation) in
# dataloaders. This replaces any `dataloaders` left over from another section.
dataloaders = load_and_prepare_dataset(
    dataset_path='datasets/beginner_logs/seaquest_beginner_perturb0.pkl',
    batch_size=BATCH_SIZE,
    seed=SEED
)

# Show which dataset key(s) the loaders hold, as the other loading cells do.
print(dataloaders.keys())

=== Loading seaquest_beginner_perturb0 dataset ===
Preprocessing and splitting seaquest_beginner_perturb0 dataset...
Creating dataloaders for seaquest_beginner_perturb0...
Dataloaders ready for: seaquest_beginner_perturb0


## BC

In [None]:
%%time

# Train and evaluate the BC model on the Beginner dataset with 0% perturbation.
# `dataloaders` is reassigned by every section's loading cell, so fail fast with
# a clear message if it does not hold this dataset (e.g. cells run out of order).
assert 'seaquest_beginner_perturb0' in dataloaders, (
    f"expected dataloaders for 'seaquest_beginner_perturb0', got {list(dataloaders)}; "
    "re-run this section's dataset-loading cell first"
)
train_and_evaluate_BC(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb0',
    env_id=ENV_ID,
    seed=SEED
)

# Label for the timing report; the %%time magic prints CPU/wall time after the cell.
print("----- Execution time: BC - Beginner | Perturbation 0% -----")

Training BC on seaquest_beginner_perturb0
-- Starting Seed 1/3 --


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

## IQL

In [7]:
%%time

# Train and evaluate the IQL model on the Beginner dataset with 0% perturbation.
# `dataloaders` is reassigned by every section's loading cell, so fail fast with
# a clear message if it does not hold this dataset (e.g. cells run out of order).
assert 'seaquest_beginner_perturb0' in dataloaders, (
    f"expected dataloaders for 'seaquest_beginner_perturb0', got {list(dataloaders)}; "
    "re-run this section's dataset-loading cell first"
)
train_and_evaluate_IQL(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb0',
    env_id=ENV_ID,
    seed=SEED
)

# Label for the timing report; the %%time magic prints CPU/wall time after the cell.
print("----- Execution time: IQL - Beginner | Perturbation 0% -----")

Training IQL on seaquest_beginner_perturb0
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:04:33<00:00, 387.38s/it]


Finished Training on seaquest_beginner_perturb0
    ➤ Avg Actor Loss: -0.83623
    ➤ Avg Critic1 Loss: -2.64333
    ➤ Avg Critic2 Loss: -2.65399
    ➤ Avg Value Loss: -3.38895
    ➤ Avg Test Loss: 0.04530
    ➤ Avg Reward: 188.00
Model saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_beginner/perturb0/iql_model_perturb0.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:04:09<00:00, 384.91s/it]


Finished Training on seaquest_beginner_perturb0
    ➤ Avg Actor Loss: -0.85013
    ➤ Avg Critic1 Loss: -2.67549
    ➤ Avg Critic2 Loss: -2.65668
    ➤ Avg Value Loss: -3.38000
    ➤ Avg Test Loss: 0.03654
    ➤ Avg Reward: 148.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:10:25<00:00, 422.57s/it]


Finished Training on seaquest_beginner_perturb0
    ➤ Avg Actor Loss: -0.65777
    ➤ Avg Critic1 Loss: -2.59554
    ➤ Avg Critic2 Loss: -2.58297
    ➤ Avg Value Loss: -3.36736
    ➤ Avg Test Loss: 0.03317
    ➤ Avg Reward: 156.00
Return Stats saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_beginner/perturb0/stats_perturb0.pkl
----- Execution time: IQL - Beginner | Perturbation 0% -----
CPU times: total: 3h 55min 8s
Wall time: 3h 19min 21s


## BVE

In [7]:
%%time

# Train and evaluate the BVE model on the Beginner dataset with 0% perturbation.
# `dataloaders` is reassigned by every section's loading cell, so fail fast with
# a clear message if it does not hold this dataset (e.g. cells run out of order).
assert 'seaquest_beginner_perturb0' in dataloaders, (
    f"expected dataloaders for 'seaquest_beginner_perturb0', got {list(dataloaders)}; "
    "re-run this section's dataset-loading cell first"
)
train_and_evaluate_BVE(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb0',
    env_id=ENV_ID,
    seed=SEED
)

# Label for the timing report; the %%time magic prints CPU/wall time after the cell.
print("----- Execution time: BVE - Beginner | Perturbation 0% -----")

Training BVE on seaquest_beginner_perturb0
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:08:19<00:00, 409.99s/it]


Finished Training on seaquest_beginner_perturb0
    ➤ Avg Train Loss: -2.47953
    ➤ Avg Test Loss: -2.35064
    ➤ Avg Reward: 80.00
Saved model to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_beginner/perturb0/bve_model_perturb0.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:08:50<00:00, 413.05s/it]


Finished Training on seaquest_beginner_perturb0
    ➤ Avg Train Loss: -2.47857
    ➤ Avg Test Loss: -2.45743
    ➤ Avg Reward: 118.40
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:09:26<00:00, 416.69s/it]


Finished Training on seaquest_beginner_perturb0
    ➤ Avg Train Loss: -2.44482
    ➤ Avg Test Loss: -2.45266
    ➤ Avg Reward: 90.00
Saved stats to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_beginner/perturb0/stats_perturb0.pkl
----- Execution time: BVE - Beginner | Perturbation 0% -----
CPU times: total: 4h 1min 45s
Wall time: 3h 26min 47s


-----------------------------

### 5% Perturbation

In [6]:
# Beginner dataset, 5% perturbation: build fresh dataloaders (replacing the
# previous section's `dataloaders`) and show which dataset key(s) they hold.
dataloaders = load_and_prepare_dataset(
    seed=SEED,
    batch_size=BATCH_SIZE,
    dataset_path='datasets/beginner_logs/seaquest_beginner_perturb5.pkl',
)
print(dataloaders.keys())

=== Loading seaquest_beginner_perturb5 dataset ===
Preprocessing and splitting seaquest_beginner_perturb5 dataset...
Creating dataloaders for seaquest_beginner_perturb5...
Dataloaders ready for: seaquest_beginner_perturb5
dict_keys(['seaquest_beginner_perturb5'])


## BC

In [7]:
%%time

# Train and evaluate the BC model on the Beginner dataset with 5% perturbation.
# `dataloaders` is reassigned by every section's loading cell, so fail fast with
# a clear message if it does not hold this dataset (e.g. cells run out of order).
assert 'seaquest_beginner_perturb5' in dataloaders, (
    f"expected dataloaders for 'seaquest_beginner_perturb5', got {list(dataloaders)}; "
    "re-run this section's dataset-loading cell first"
)
train_and_evaluate_BC(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb5',
    env_id=ENV_ID,
    seed=SEED
)

# Label for the timing report; the %%time magic prints CPU/wall time after the cell.
print("----- Execution time: BC - Beginner | Perturbation 5% -----")

Training BC on seaquest_beginner_perturb5
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [52:22<00:00, 314.29s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Train Loss: -0.19064
    ➤ Avg Test Loss: 0.08850
    ➤ Avg Reward: 214.00
Model saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_beginner/perturb5/bc_model_perturb5.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [51:51<00:00, 311.17s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Train Loss: -0.15716
    ➤ Avg Test Loss: 0.07163
    ➤ Avg Reward: 218.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [51:46<00:00, 310.68s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Train Loss: -0.17560
    ➤ Avg Test Loss: 0.07710
    ➤ Avg Reward: 220.00
Return Stats saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_beginner/perturb5/stats_perturb5.pkl
----- Execution time: BC - Beginner | Perturbation 5% -----
CPU times: total: 3h 14min 55s
Wall time: 2h 36min 8s


## IQL

In [7]:
%%time

# Train and evaluate the IQL model on the Beginner dataset with 5% perturbation.
# `dataloaders` is reassigned by every section's loading cell, so fail fast with
# a clear message if it does not hold this dataset (e.g. cells run out of order).
assert 'seaquest_beginner_perturb5' in dataloaders, (
    f"expected dataloaders for 'seaquest_beginner_perturb5', got {list(dataloaders)}; "
    "re-run this section's dataset-loading cell first"
)
train_and_evaluate_IQL(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb5',
    env_id=ENV_ID,
    seed=SEED
)

# Label for the timing report; the %%time magic prints CPU/wall time after the cell.
print("----- Execution time: IQL - Beginner | Perturbation 5% -----")

Training IQL on seaquest_beginner_perturb5
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:04:18<00:00, 385.87s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Actor Loss: -0.47796
    ➤ Avg Critic1 Loss: -2.53296
    ➤ Avg Critic2 Loss: -2.55300
    ➤ Avg Value Loss: -3.32596
    ➤ Avg Test Loss: 0.14116
    ➤ Avg Reward: 194.00
Model saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_beginner/perturb5/iql_model_perturb5.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:03:53<00:00, 383.37s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Actor Loss: -0.67254
    ➤ Avg Critic1 Loss: -2.64086
    ➤ Avg Critic2 Loss: -2.62093
    ➤ Avg Value Loss: -3.37172
    ➤ Avg Test Loss: 0.14403
    ➤ Avg Reward: 142.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:06:43<00:00, 400.38s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Actor Loss: -0.56072
    ➤ Avg Critic1 Loss: -2.58448
    ➤ Avg Critic2 Loss: -2.57717
    ➤ Avg Value Loss: -3.37373
    ➤ Avg Test Loss: 0.13988
    ➤ Avg Reward: 188.00
Return Stats saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_beginner/perturb5/stats_perturb5.pkl
----- Execution time: IQL - Beginner | Perturbation 5% -----
CPU times: total: 3h 51min 43s
Wall time: 3h 15min 9s


## BVE

In [7]:
%%time

# Train and evaluate the BVE model on the Beginner dataset with 5% perturbation.
# `dataloaders` is reassigned by every section's loading cell, so fail fast with
# a clear message if it does not hold this dataset (e.g. cells run out of order).
assert 'seaquest_beginner_perturb5' in dataloaders, (
    f"expected dataloaders for 'seaquest_beginner_perturb5', got {list(dataloaders)}; "
    "re-run this section's dataset-loading cell first"
)
train_and_evaluate_BVE(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb5',
    env_id=ENV_ID,
    seed=SEED
)

# Label for the timing report; the %%time magic prints CPU/wall time after the cell.
print("----- Execution time: BVE - Beginner | Perturbation 5% -----")

Training BVE on seaquest_beginner_perturb5
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:06:12<00:00, 397.22s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Train Loss: -2.50455
    ➤ Avg Test Loss: -2.37970
    ➤ Avg Reward: 74.00
Saved model to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_beginner/perturb5/bve_model_perturb5.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:07:57<00:00, 407.80s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Train Loss: -2.49340
    ➤ Avg Test Loss: -2.44997
    ➤ Avg Reward: 88.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:08:08<00:00, 408.85s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Train Loss: -2.39372
    ➤ Avg Test Loss: -2.32510
    ➤ Avg Reward: 56.00
Saved stats to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_beginner/perturb5/stats_perturb5.pkl
----- Execution time: BVE - Beginner | Perturbation 5% -----
CPU times: total: 3h 59min 49s
Wall time: 3h 22min 27s


-----------------

### 10% Perturbation

In [6]:
# Beginner dataset, 10% perturbation: build fresh dataloaders (replacing the
# previous section's `dataloaders`) and show which dataset key(s) they hold.
dataloaders = load_and_prepare_dataset(
    seed=SEED,
    batch_size=BATCH_SIZE,
    dataset_path='datasets/beginner_logs/seaquest_beginner_perturb10.pkl',
)
print(dataloaders.keys())

=== Loading seaquest_beginner_perturb10 dataset ===
Preprocessing and splitting seaquest_beginner_perturb10 dataset...
Creating dataloaders for seaquest_beginner_perturb10...
Dataloaders ready for: seaquest_beginner_perturb10
dict_keys(['seaquest_beginner_perturb10'])


## BC

In [7]:
%%time

# Train and evaluate the BC model on the Beginner dataset with 10% perturbation.
# `dataloaders` is reassigned by every section's loading cell, so fail fast with
# a clear message if it does not hold this dataset (e.g. cells run out of order).
assert 'seaquest_beginner_perturb10' in dataloaders, (
    f"expected dataloaders for 'seaquest_beginner_perturb10', got {list(dataloaders)}; "
    "re-run this section's dataset-loading cell first"
)
train_and_evaluate_BC(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb10',
    env_id=ENV_ID,
    seed=SEED
)

# Label for the timing report; the %%time magic prints CPU/wall time after the cell.
print("----- Execution time: BC - Beginner | Perturbation 10% -----")

Training BC on seaquest_beginner_perturb10
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [58:46<00:00, 352.65s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Train Loss: -0.08343
    ➤ Avg Test Loss: 0.17054
    ➤ Avg Reward: 204.00
Model saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_beginner/perturb10/bc_model_perturb10.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [59:23<00:00, 356.30s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Train Loss: -0.06545
    ➤ Avg Test Loss: 0.15933
    ➤ Avg Reward: 224.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:00:02<00:00, 360.29s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Train Loss: -0.05826
    ➤ Avg Test Loss: 0.16045
    ➤ Avg Reward: 220.00
Return Stats saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_beginner/perturb10/stats_perturb10.pkl
----- Execution time: BC - Beginner | Perturbation 10% -----
CPU times: total: 3h 39min 37s
Wall time: 2h 58min 20s


## IQL

In [7]:
%%time

# Train and evaluate the IQL model on the Beginner dataset with 10% perturbation.
# `dataloaders` is reassigned by every section's loading cell, so fail fast with
# a clear message if it does not hold this dataset (e.g. cells run out of order).
assert 'seaquest_beginner_perturb10' in dataloaders, (
    f"expected dataloaders for 'seaquest_beginner_perturb10', got {list(dataloaders)}; "
    "re-run this section's dataset-loading cell first"
)
train_and_evaluate_IQL(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb10',
    env_id=ENV_ID,
    seed=SEED
)

# Label for the timing report; the %%time magic prints CPU/wall time after the cell.
print("----- Execution time: IQL - Beginner | Perturbation 10% -----")

Training IQL on seaquest_beginner_perturb10
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:06:14<00:00, 397.49s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Actor Loss: -0.43965
    ➤ Avg Critic1 Loss: -2.51346
    ➤ Avg Critic2 Loss: -2.53032
    ➤ Avg Value Loss: -3.36091
    ➤ Avg Test Loss: 0.19188
    ➤ Avg Reward: 170.00
Model saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_beginner/perturb10/iql_model_perturb10.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:08:07<00:00, 408.71s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Actor Loss: -0.65075
    ➤ Avg Critic1 Loss: -2.62577
    ➤ Avg Critic2 Loss: -2.61171
    ➤ Avg Value Loss: -3.39306
    ➤ Avg Test Loss: 0.19203
    ➤ Avg Reward: 174.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:09:09<00:00, 414.97s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Actor Loss: -0.54737
    ➤ Avg Critic1 Loss: -2.57961
    ➤ Avg Critic2 Loss: -2.57128
    ➤ Avg Value Loss: -3.39913
    ➤ Avg Test Loss: 0.18910
    ➤ Avg Reward: 166.00
Return Stats saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_beginner/perturb10/stats_perturb10.pkl
----- Execution time: IQL - Beginner | Perturbation 10% -----
CPU times: total: 4h 4min
Wall time: 3h 23min 44s


## BVE

In [7]:
%%time

# Train and evaluate the BVE model on the Beginner dataset with 10% perturbation.
# `dataloaders` is reassigned by every section's loading cell, so fail fast with
# a clear message if it does not hold this dataset (e.g. cells run out of order).
assert 'seaquest_beginner_perturb10' in dataloaders, (
    f"expected dataloaders for 'seaquest_beginner_perturb10', got {list(dataloaders)}; "
    "re-run this section's dataset-loading cell first"
)
train_and_evaluate_BVE(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb10',
    env_id=ENV_ID,
    seed=SEED
)

# Label for the timing report; the %%time magic prints CPU/wall time after the cell.
print("----- Execution time: BVE - Beginner | Perturbation 10% -----")

Training BVE on seaquest_beginner_perturb10
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:07:51<00:00, 407.18s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Train Loss: -2.42478
    ➤ Avg Test Loss: -2.48109
    ➤ Avg Reward: 18.00
Saved model to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_beginner/perturb10/bve_model_perturb10.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:08:28<00:00, 410.90s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Train Loss: -2.47520
    ➤ Avg Test Loss: -2.48756
    ➤ Avg Reward: 66.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:08:41<00:00, 412.15s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Train Loss: -2.42990
    ➤ Avg Test Loss: -2.45287
    ➤ Avg Reward: 54.00
Saved stats to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_beginner/perturb10/stats_perturb10.pkl
----- Execution time: BVE - Beginner | Perturbation 10% -----
CPU times: total: 4h 1min 14s
Wall time: 3h 25min 12s


-----------------

### 20% Perturbation

In [6]:
# Beginner dataset, 20% perturbation: build fresh dataloaders (replacing the
# previous section's `dataloaders`) and show which dataset key(s) they hold.
dataloaders = load_and_prepare_dataset(
    seed=SEED,
    batch_size=BATCH_SIZE,
    dataset_path='datasets/beginner_logs/seaquest_beginner_perturb20.pkl',
)
print(dataloaders.keys())

=== Loading seaquest_beginner_perturb20 dataset ===
Preprocessing and splitting seaquest_beginner_perturb20 dataset...
Creating dataloaders for seaquest_beginner_perturb20...
Dataloaders ready for: seaquest_beginner_perturb20
dict_keys(['seaquest_beginner_perturb20'])


## BC

In [7]:
%%time

# Train and evaluate the BC model on the Beginner dataset with 20% perturbation.
# `dataloaders` is reassigned by every section's loading cell, so fail fast with
# a clear message if it does not hold this dataset (e.g. cells run out of order).
assert 'seaquest_beginner_perturb20' in dataloaders, (
    f"expected dataloaders for 'seaquest_beginner_perturb20', got {list(dataloaders)}; "
    "re-run this section's dataset-loading cell first"
)
train_and_evaluate_BC(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb20',
    env_id=ENV_ID,
    seed=SEED
)

# Label for the timing report; the %%time magic prints CPU/wall time after the cell.
print("----- Execution time: BC - Beginner | Perturbation 20% -----")

Training BC on seaquest_beginner_perturb20
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [53:42<00:00, 322.24s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Train Loss: 0.06716
    ➤ Avg Test Loss: 0.28116
    ➤ Avg Reward: 194.00
Model saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_beginner/perturb20/bc_model_perturb20.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [53:48<00:00, 322.84s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Train Loss: 0.07951
    ➤ Avg Test Loss: 0.26962
    ➤ Avg Reward: 222.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [53:16<00:00, 319.65s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Train Loss: 0.06541
    ➤ Avg Test Loss: 0.28003
    ➤ Avg Reward: 228.00
Return Stats saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_beginner/perturb20/stats_perturb20.pkl
----- Execution time: BC - Beginner | Perturbation 20% -----
CPU times: total: 3h 20min 56s
Wall time: 2h 40min 54s


## IQL

In [7]:
%%time

# Train and evaluate the IQL model on the Beginner dataset with 20% perturbation.
# `dataloaders` is reassigned by every section's loading cell, so fail fast with
# a clear message if it does not hold this dataset (e.g. cells run out of order).
assert 'seaquest_beginner_perturb20' in dataloaders, (
    f"expected dataloaders for 'seaquest_beginner_perturb20', got {list(dataloaders)}; "
    "re-run this section's dataset-loading cell first"
)
train_and_evaluate_IQL(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb20',
    env_id=ENV_ID,
    seed=SEED
)

# Label for the timing report; the %%time magic prints CPU/wall time after the cell.
print("----- Execution time: IQL - Beginner | Perturbation 20% -----")

Training IQL on seaquest_beginner_perturb20
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:07:58<00:00, 407.86s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Actor Loss: -0.67840
    ➤ Avg Critic1 Loss: -2.67207
    ➤ Avg Critic2 Loss: -2.68481
    ➤ Avg Value Loss: -3.41514
    ➤ Avg Test Loss: 0.27123
    ➤ Avg Reward: 174.00
Model saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_beginner/perturb20/iql_model_perturb20.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:35:18<00:00, 571.85s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Actor Loss: -0.65608
    ➤ Avg Critic1 Loss: -2.70958
    ➤ Avg Critic2 Loss: -2.69465
    ➤ Avg Value Loss: -3.43669
    ➤ Avg Test Loss: 0.27219
    ➤ Avg Reward: 142.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:30:02<00:00, 540.30s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Actor Loss: -0.44251
    ➤ Avg Critic1 Loss: -2.61685
    ➤ Avg Critic2 Loss: -2.60457
    ➤ Avg Value Loss: -3.40170
    ➤ Avg Test Loss: 0.26973
    ➤ Avg Reward: 172.00
Return Stats saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_beginner/perturb20/stats_perturb20.pkl
----- Execution time: IQL - Beginner | Perturbation 20% -----
CPU times: total: 4h 52min 14s
Wall time: 4h 13min 39s


## BVE

In [7]:
%%time

# Train and evaluate the BVE model on the Beginner dataset with 20% perturbation.
# `dataloaders` is reassigned by every section's loading cell, so fail fast with
# a clear message if it does not hold this dataset (e.g. cells run out of order).
assert 'seaquest_beginner_perturb20' in dataloaders, (
    f"expected dataloaders for 'seaquest_beginner_perturb20', got {list(dataloaders)}; "
    "re-run this section's dataset-loading cell first"
)
train_and_evaluate_BVE(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb20',
    env_id=ENV_ID,
    seed=SEED
)

# Label for the timing report; the %%time magic prints CPU/wall time after the cell.
print("----- Execution time: BVE - Beginner | Perturbation 20% -----")

Training BVE on seaquest_beginner_perturb20
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:07:11<00:00, 403.15s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Train Loss: -2.53825
    ➤ Avg Test Loss: -2.43279
    ➤ Avg Reward: 104.00
Saved model to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_beginner/perturb20/bve_model_perturb20.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:08:21<00:00, 410.12s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Train Loss: -2.52542
    ➤ Avg Test Loss: -2.41370
    ➤ Avg Reward: 118.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:08:48<00:00, 412.85s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Train Loss: -2.49181
    ➤ Avg Test Loss: -2.43264
    ➤ Avg Reward: 112.00
Saved stats to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_beginner/perturb20/stats_perturb20.pkl
----- Execution time: BVE - Beginner | Perturbation 20% -----
CPU times: total: 4h 16s
Wall time: 3h 24min 31s


------------------

# Training all agents on: Intermediate Dataset

### 0% Perturbation

In [7]:
# Intermediate dataset, 0% perturbation: build fresh dataloaders (replacing the
# previous section's `dataloaders`) and show which dataset key(s) they hold.
dataloaders = load_and_prepare_dataset(
    seed=SEED,
    batch_size=BATCH_SIZE,
    dataset_path='datasets/intermediate_logs/seaquest_intermediate_perturb0.pkl',
)
print(dataloaders.keys())

=== Loading seaquest_intermediate_perturb0 dataset ===
Preprocessing and splitting seaquest_intermediate_perturb0 dataset...
Creating dataloaders for seaquest_intermediate_perturb0...
✅ Dataloaders ready for: seaquest_intermediate_perturb0
dict_keys(['seaquest_intermediate_perturb0'])


## BC

In [7]:
%%time

# Train and evaluate the BC model on the Intermediate dataset with 0% perturbation.
# `dataloaders` is reassigned by every section's loading cell, so fail fast with
# a clear message if it does not hold this dataset (e.g. cells run out of order).
assert 'seaquest_intermediate_perturb0' in dataloaders, (
    f"expected dataloaders for 'seaquest_intermediate_perturb0', got {list(dataloaders)}; "
    "re-run this section's dataset-loading cell first"
)
train_and_evaluate_BC(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_intermediate_perturb0',
    env_id=ENV_ID,
    seed=SEED
)

# Label for the timing report; the %%time magic prints CPU/wall time after the cell.
print("----- Execution time: BC - Intermediate | Perturbation 0% -----")

Training BC on seaquest_intermediate_perturb0
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [56:16<00:00, 337.68s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Train Loss: -3.03064
    ➤ Avg Test Loss: -3.34562
    ➤ Avg Reward: 358.00
Model saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_intermediate/perturb0/bc_model_perturb0.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [55:40<00:00, 334.09s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Train Loss: -2.99913
    ➤ Avg Test Loss: -3.30587
    ➤ Avg Reward: 344.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [55:52<00:00, 335.27s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Train Loss: -2.84021
    ➤ Avg Test Loss: -2.97219
    ➤ Avg Reward: 310.00
Return Stats saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_intermediate/perturb0/stats_perturb0.pkl
----- Execution time: BC - Intermediate | Perturbation 0% -----
CPU times: total: 3h 27min 6s
Wall time: 2h 47min 57s


## IQL

In [7]:
%%time

# Train and evaluate the IQL model on the Intermediate dataset with 0% perturbation.
# `dataloaders` is reassigned by every section's loading cell, so fail fast with
# a clear message if it does not hold this dataset (e.g. cells run out of order).
assert 'seaquest_intermediate_perturb0' in dataloaders, (
    f"expected dataloaders for 'seaquest_intermediate_perturb0', got {list(dataloaders)}; "
    "re-run this section's dataset-loading cell first"
)
train_and_evaluate_IQL(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_intermediate_perturb0',
    env_id=ENV_ID,
    seed=SEED
)

# Label for the timing report; the %%time magic prints CPU/wall time after the cell.
print("----- Execution time: IQL - Intermediate | Perturbation 0% -----")

Training IQL on seaquest_intermediate_perturb0
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:04:07<00:00, 384.78s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Actor Loss: 0.38233
    ➤ Avg Critic1 Loss: -1.10483
    ➤ Avg Critic2 Loss: -1.10999
    ➤ Avg Value Loss: -1.61832
    ➤ Avg Test Loss: -0.57688
    ➤ Avg Reward: 174.00
Model saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_intermediate/perturb0/iql_model_perturb0.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:03:51<00:00, 383.14s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Actor Loss: 0.49039
    ➤ Avg Critic1 Loss: -1.13186
    ➤ Avg Critic2 Loss: -1.12787
    ➤ Avg Value Loss: -1.58814
    ➤ Avg Test Loss: -0.62337
    ➤ Avg Reward: 156.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:06:35<00:00, 399.59s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Actor Loss: 0.27473
    ➤ Avg Critic1 Loss: -1.20798
    ➤ Avg Critic2 Loss: -1.13456
    ➤ Avg Value Loss: -1.64877
    ➤ Avg Test Loss: -0.60901
    ➤ Avg Reward: 164.00
Return Stats saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_intermediate/perturb0/stats_perturb0.pkl
----- Execution time: IQL - Intermediate | Perturbation 0% -----
CPU times: total: 3h 52min
Wall time: 3h 14min 47s


## BVE

In [8]:
%%time

# Train and evaluate the BVE model on the Intermediate dataset with 0% perturbation.
# `dataloaders` is reassigned by every section's loading cell, so fail fast with
# a clear message if it does not hold this dataset (e.g. cells run out of order).
assert 'seaquest_intermediate_perturb0' in dataloaders, (
    f"expected dataloaders for 'seaquest_intermediate_perturb0', got {list(dataloaders)}; "
    "re-run this section's dataset-loading cell first"
)
train_and_evaluate_BVE(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_intermediate_perturb0',
    env_id=ENV_ID,
    seed=SEED
)

# Label for the timing report; the %%time magic prints CPU/wall time after the cell.
print("----- Execution time: BVE - Intermediate | Perturbation 0% -----")

Training BVE on seaquest_intermediate_perturb0
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:05:50<00:00, 395.01s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Train Loss: -1.05895
    ➤ Avg Test Loss: -0.98404
    ➤ Avg Reward: 90.00
Saved model to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_intermediate/perturb0/bve_model_perturb0.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:17:28<00:00, 464.88s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Train Loss: -1.10359
    ➤ Avg Test Loss: -1.05455
    ➤ Avg Reward: 100.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:19:31<00:00, 477.19s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Train Loss: -1.12968
    ➤ Avg Test Loss: -1.04024
    ➤ Avg Reward: 100.00
Saved stats to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_intermediate/perturb0/stats_perturb0.pkl
----- Execution time: BVE - Intermediate | Perturbation 0% -----
CPU times: total: 4h 22min 45s
Wall time: 3h 43min


-----------------------------

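### 5% Perturbation

**TODO:** not run yet. This section still needs a dataset-loading cell and BC / IQL / BVE training cells.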

## BC

## IQL

## BVE

-----------------

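### 10% Perturbation

**TODO:** not run yet. This section still needs a dataset-loading cell and BC / IQL / BVE training cells.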

----------------

# Training all agents on: Expert Dataset

### 0% Perturbation