# Train Agents


In [1]:
import os
import torch
from tqdm import tqdm
import numpy as np
from stable_baselines3 import DQN
import warnings
import pandas as pd
import gc  # Import garbage collector module

# Suppress every Python warning for the rest of this kernel session.
# NOTE(review): this is a blanket filter, so deprecation and numerical warnings
# raised during training are hidden as well.
warnings.filterwarnings("ignore")

In [2]:
# Ensure the module is re-imported after changes: reload it in place so edits
# made to datasets/dataset_utils.py are picked up without restarting the kernel.
import importlib

import datasets.dataset_utils
importlib.reload(datasets.dataset_utils)

# The from-import must come after the reload so these names are bound to the
# reloaded module's objects rather than to the previously imported ones.
from datasets.dataset_utils import set_all_seeds, create_environment, load_dataset, preprocess_and_split, create_dataloaders, inspect_dataset_sample, load_and_prepare_dataset

In [4]:
# Ensure the agent modules are re-imported after changes: each module is
# reloaded in place, and its training entry point is imported afterwards so the
# name is bound to the reloaded implementation.
import importlib

# Behavioral Cloning (BC)
import agent_methods.behavioral_cloning_bc.bc_utils
importlib.reload(agent_methods.behavioral_cloning_bc.bc_utils)

from agent_methods.behavioral_cloning_bc.bc_utils import train_and_evaluate_BC

# Implicit Q-Learning (IQL)
import agent_methods.implicit_q_learning_iql.iql_utils
importlib.reload(agent_methods.implicit_q_learning_iql.iql_utils)

from agent_methods.implicit_q_learning_iql.iql_utils import train_and_evaluate_IQL

# Behavior Value Estimation (BVE)
import agent_methods.behavior_value_estimation_bve.bve_utils
importlib.reload(agent_methods.behavior_value_estimation_bve.bve_utils)

from agent_methods.behavior_value_estimation_bve.bve_utils import train_and_evaluate_BVE

In [5]:
# Experiment configuration shared by every cell below.
SEED = 12345  # base seed passed to set_all_seeds, the environment, the dataset loader and the trainers
ENV_ID = 'SeaquestNoFrameskip-v4'  # environment id passed to create_environment and to the trainers
EPOCHS = 10  # passed as `epochs` to each train_and_evaluate_* call
SEEDS = 3  # passed as `seeds` to each train_and_evaluate_* call (the logs show "Seed 1/3" .. "Seed 3/3")
BATCH_SIZE = 64  # passed as `batch_size` to load_and_prepare_dataset

In [6]:
# set seed for reproducibility
set_all_seeds(SEED)

# use the GPU when CUDA is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# initialize environment
# NOTE(review): `env` is not referenced by any later cell visible in this
# notebook (the training helpers receive `env_id` instead) — confirm it is needed.
env = create_environment(env_id=ENV_ID, seed=SEED)

Device: cuda


# Training all agents on: Beginner Dataset

### 0% Perturbation

In [None]:
# Load the Beginner dataset with 0% perturbation and build its dataloaders.
# The result is stored in the kernel-global `dataloaders`, which the BC / IQL /
# BVE training cells of this section read.
dataloaders = load_and_prepare_dataset(
    dataset_path='datasets/beginner_logs/seaquest_beginner_perturb0.pkl',
    batch_size=BATCH_SIZE,
    seed=SEED
)

# Show which entries were produced, as a quick sanity check.
print(dataloaders.keys())

=== Loading seaquest_beginner_perturb0 dataset ===
Preprocessing and splitting seaquest_beginner_perturb0 dataset...
Creating dataloaders for seaquest_beginner_perturb0...
✅ Dataloaders ready for: seaquest_beginner_perturb0
dict_keys(['train', 'test', 'tuning'])


## BC

In [None]:
%%time

# Train and evaluate the BC model on the Beginner dataset with 0% perturbation.
# Reads the kernel-global `dataloaders`, so the dataset-loading cell of this
# section must have been run in the current kernel first.
train_and_evaluate_BC(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb0',
    env_id=ENV_ID,
    seed=SEED
)

# Print a label for the timing report; the CPU/wall times themselves are
# emitted by the %%time cell magic, not by this print.
print("----- Execution time: BC - Beginner | Perturbation 0% -----")

Training BC on seaquest_beginner_perturb0
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [57:38<00:00, 345.80s/it]


Finished Training on seaquest_beginner_perturb0
    ➤ Avg Train Loss: -1.68118
    ➤ Avg Test Loss: -1.71870
    ➤ Avg Reward: 262.00
Model saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_beginner/perturb0/bc_model_perturb0.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [57:16<00:00, 343.67s/it]


Finished Training on seaquest_beginner_perturb0
    ➤ Avg Train Loss: -1.67643
    ➤ Avg Test Loss: -1.89362
    ➤ Avg Reward: 248.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [57:05<00:00, 342.55s/it]


Finished Training on seaquest_beginner_perturb0
    ➤ Avg Train Loss: -1.69835
    ➤ Avg Test Loss: -1.67948
    ➤ Avg Reward: 260.00
Return Stats saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_beginner/perturb0/stats_perturb0.pkl
----- Execution time: BC - Beginner | Perturbation 0% -----
CPU times: total: 3h 31min 37s
Wall time: 2h 52min 8s


## IQL

In [7]:
%%time

# Train and evaluate the IQL model on the Beginner dataset with 0% perturbation.
# Reads the kernel-global `dataloaders`, so the dataset-loading cell of this
# section must have been run in the current kernel first.
train_and_evaluate_IQL(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb0',
    env_id=ENV_ID,
    seed=SEED
)

# Print a label for the timing report; the CPU/wall times themselves are
# emitted by the %%time cell magic, not by this print.
print("----- Execution time: IQL - Beginner | Perturbation 0% -----") 

Training IQL on seaquest_beginner_perturb0
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:08:44<00:00, 412.45s/it]


Finished Training on seaquest_beginner_perturb0
    ➤ Avg Actor Loss: -0.46661
    ➤ Avg Critic1 Loss: -1.67099
    ➤ Avg Critic2 Loss: -1.63494
    ➤ Avg Value Loss: -2.27581
    ➤ Avg Test Loss: -0.39819
    ➤ Avg Reward: 210.27
Model saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_beginner/perturb0/iql_model_perturb0.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:05:57<00:00, 395.72s/it]


Finished Training on seaquest_beginner_perturb0
    ➤ Avg Actor Loss: -0.45320
    ➤ Avg Critic1 Loss: -1.70217
    ➤ Avg Critic2 Loss: -1.68298
    ➤ Avg Value Loss: -2.28281
    ➤ Avg Test Loss: -0.38992
    ➤ Avg Reward: 232.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:07:14<00:00, 403.49s/it]


Finished Training on seaquest_beginner_perturb0
    ➤ Avg Actor Loss: -0.50819
    ➤ Avg Critic1 Loss: -1.70487
    ➤ Avg Critic2 Loss: -1.66206
    ➤ Avg Value Loss: -2.32156
    ➤ Avg Test Loss: -0.38967
    ➤ Avg Reward: 222.00
Return Stats saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_beginner/perturb0/stats_perturb0.pkl
----- Execution time: IQL - Beginner | Perturbation 0% -----
CPU times: total: 4h 10s
Wall time: 3h 22min 9s


## BVE

In [7]:
%%time

# Train and evaluate the BVE model on the Beginner dataset with 0% perturbation.
# Reads the kernel-global `dataloaders`, so the dataset-loading cell of this
# section must have been run in the current kernel first.
train_and_evaluate_BVE(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb0',
    env_id=ENV_ID,
    seed=SEED
)

# Print a label for the timing report; the CPU/wall times themselves are
# emitted by the %%time cell magic, not by this print.
print("----- Execution time: BVE - Beginner | Perturbation 0% -----")

Training BVE on seaquest_beginner_perturb0
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:06:23<00:00, 398.39s/it]


Finished Training on seaquest_beginner_perturb0
    ➤ Avg Train Loss: -1.56911
    ➤ Avg Test Loss: -1.46587
    ➤ Avg Reward: 20.00
Saved model to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_beginner/perturb0/bve_model_perturb0.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:06:35<00:00, 399.51s/it]


Finished Training on seaquest_beginner_perturb0
    ➤ Avg Train Loss: -1.61088
    ➤ Avg Test Loss: -1.50807
    ➤ Avg Reward: 0.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:06:33<00:00, 399.30s/it]


Finished Training on seaquest_beginner_perturb0
    ➤ Avg Train Loss: -1.56141
    ➤ Avg Test Loss: -1.48893
    ➤ Avg Reward: 44.00
Saved stats to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_beginner/perturb0/stats_perturb0.pkl
----- Execution time: BVE - Beginner | Perturbation 0% -----
CPU times: total: 3h 55min 46s
Wall time: 3h 19min 41s


-----------------------------

### 5% Perturbation

In [7]:
# Load the Beginner dataset with 5% perturbation and build its dataloaders.
# This rebinds the kernel-global `dataloaders`, which the BC / IQL / BVE
# training cells of this section read.
dataloaders = load_and_prepare_dataset(
    dataset_path='datasets/beginner_logs/seaquest_beginner_perturb5.pkl',
    batch_size=BATCH_SIZE,
    seed=SEED
)

# Show which entries were produced, as a quick sanity check.
print(dataloaders.keys())

=== Loading seaquest_beginner_perturb5 dataset ===
Preprocessing and splitting seaquest_beginner_perturb5 dataset...
Creating dataloaders for seaquest_beginner_perturb5...
✅ Dataloaders ready for: seaquest_beginner_perturb5
dict_keys(['train', 'test', 'tuning'])


## BC

In [None]:
%%time

# Train and evaluate the BC model on the Beginner dataset with 5% perturbation.
# Reads the kernel-global `dataloaders`, so the dataset-loading cell of this
# section must have been run in the current kernel first.
train_and_evaluate_BC(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb5',
    env_id=ENV_ID,
    seed=SEED
)

# Print a label for the timing report; the CPU/wall times themselves are
# emitted by the %%time cell magic, not by this print.
print("----- Execution time: BC - Beginner | Perturbation 5% -----")

Training BC on seaquest_beginner_perturb5
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [55:00<00:00, 330.07s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Train Loss: -0.20687
    ➤ Avg Test Loss: -0.03602
    ➤ Avg Reward: 266.00
Model saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_beginner/perturb5/bc_model_perturb5.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [54:55<00:00, 329.50s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Train Loss: -0.19231
    ➤ Avg Test Loss: -0.03501
    ➤ Avg Reward: 258.40
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [54:33<00:00, 327.35s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Train Loss: -0.21610
    ➤ Avg Test Loss: -0.02573
    ➤ Avg Reward: 252.00
Return Stats saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_beginner/perturb5/stats_perturb5.pkl
----- Execution time: BC - Beginner | Perturbation 5% -----
CPU times: total: 3h 24min 4s
Wall time: 2h 44min 36s


## IQL

In [7]:
%%time

# Train and evaluate the IQL model on the Beginner dataset with 5% perturbation.
# Reads the kernel-global `dataloaders`, so the dataset-loading cell of this
# section must have been run in the current kernel first.
train_and_evaluate_IQL(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb5',
    env_id=ENV_ID,
    seed=SEED
)

# Print a label for the timing report; the CPU/wall times themselves are
# emitted by the %%time cell magic, not by this print.
print("----- Execution time: IQL - Beginner | Perturbation 5% -----") 

Training IQL on seaquest_beginner_perturb5
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:06:08<00:00, 396.81s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Actor Loss: -0.94095
    ➤ Avg Critic1 Loss: -2.42957
    ➤ Avg Critic2 Loss: -2.44215
    ➤ Avg Value Loss: -3.31282
    ➤ Avg Test Loss: 0.08829
    ➤ Avg Reward: 236.00
Model saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_beginner/perturb5/iql_model_perturb5.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:05:52<00:00, 395.26s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Actor Loss: -0.84263
    ➤ Avg Critic1 Loss: -2.41529
    ➤ Avg Critic2 Loss: -2.40645
    ➤ Avg Value Loss: -3.28561
    ➤ Avg Test Loss: 0.08608
    ➤ Avg Reward: 246.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:06:35<00:00, 399.56s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Actor Loss: -0.71577
    ➤ Avg Critic1 Loss: -2.36511
    ➤ Avg Critic2 Loss: -2.35986
    ➤ Avg Value Loss: -3.29700
    ➤ Avg Test Loss: 0.07744
    ➤ Avg Reward: 228.00
Return Stats saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_beginner/perturb5/stats_perturb5.pkl
----- Execution time: IQL - Beginner | Perturbation 5% -----
CPU times: total: 3h 56min 31s
Wall time: 3h 18min 48s


## BVE

In [7]:
%%time

# Train and evaluate the BVE model on the Beginner dataset with 5% perturbation.
# Reads the kernel-global `dataloaders`, so the dataset-loading cell of this
# section must have been run in the current kernel first.
train_and_evaluate_BVE(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb5',
    env_id=ENV_ID,
    seed=SEED
)

# Print a label for the timing report; the CPU/wall times themselves are
# emitted by the %%time cell magic, not by this print.
print("----- Execution time: BVE - Beginner | Perturbation 5% -----")

Training BVE on seaquest_beginner_perturb5
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:06:00<00:00, 396.09s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Train Loss: -2.26127
    ➤ Avg Test Loss: -2.25012
    ➤ Avg Reward: 0.00
Saved model to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_beginner/perturb5/bve_model_perturb5.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:08:11<00:00, 409.12s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Train Loss: -2.31078
    ➤ Avg Test Loss: -2.30306
    ➤ Avg Reward: 0.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:08:05<00:00, 408.55s/it]


Finished Training on seaquest_beginner_perturb5
    ➤ Avg Train Loss: -2.14594
    ➤ Avg Test Loss: -2.05492
    ➤ Avg Reward: 0.00
Saved stats to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_beginner/perturb5/stats_perturb5.pkl
----- Execution time: BVE - Beginner | Perturbation 5% -----
CPU times: total: 3h 56min 32s
Wall time: 3h 22min 28s


-----------------

### 10% Perturbation

In [6]:
# Load the Beginner dataset with 10% perturbation and build its dataloaders.
# This rebinds the kernel-global `dataloaders`, which the BC / IQL / BVE
# training cells of this section read.
dataloaders = load_and_prepare_dataset(
    dataset_path='datasets/beginner_logs/seaquest_beginner_perturb10.pkl',
    batch_size=BATCH_SIZE,
    seed=SEED
)

# Show which entries were produced, as a quick sanity check.
# NOTE(review): the output recorded for this cell lists a single key (the
# dataset name), while the 0% and 5% Beginner cells list 'train', 'test',
# 'tuning' — confirm which return shape the training helpers expect.
print(dataloaders.keys())

=== Loading seaquest_beginner_perturb10 dataset ===
Preprocessing and splitting seaquest_beginner_perturb10 dataset...
Creating dataloaders for seaquest_beginner_perturb10...
✅ Dataloaders ready for: seaquest_beginner_perturb10
dict_keys(['seaquest_beginner_perturb10'])


## BC

In [8]:
%%time

# Train and evaluate the BC model on the Beginner dataset with 10% perturbation.
# Reads the kernel-global `dataloaders`, so the dataset-loading cell of this
# section must have been run in the current kernel first.
train_and_evaluate_BC(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb10',
    env_id=ENV_ID,
    seed=SEED
)

# Print a label for the timing report; the CPU/wall times themselves are
# emitted by the %%time cell magic, not by this print.
print("----- Execution time: BC - Beginner | Perturbation 10% -----")

Training BC on seaquest_beginner_perturb10
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [56:55<00:00, 341.57s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Train Loss: -0.05753
    ➤ Avg Test Loss: 0.09269
    ➤ Avg Reward: 256.00
Model saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_beginner/perturb10/bc_model_perturb10.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [56:47<00:00, 340.71s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Train Loss: -0.06202
    ➤ Avg Test Loss: 0.09579
    ➤ Avg Reward: 256.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [56:00<00:00, 336.07s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Train Loss: -0.06067
    ➤ Avg Test Loss: 0.09879
    ➤ Avg Reward: 260.00
Return Stats saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_beginner/perturb10/stats_perturb10.pkl
----- Execution time: BC - Beginner | Perturbation 10% -----
CPU times: total: 3h 27min 33s
Wall time: 2h 49min 51s


## IQL

In [7]:
%%time

# Train and evaluate the IQL model on the Beginner dataset with 10% perturbation.
# Reads the kernel-global `dataloaders`, so the dataset-loading cell of this
# section must have been run in the current kernel first.
train_and_evaluate_IQL(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb10',
    env_id=ENV_ID,
    seed=SEED
)

# Print a label for the timing report; the CPU/wall times themselves are
# emitted by the %%time cell magic, not by this print.
print("----- Execution time: IQL - Beginner | Perturbation 10% -----") 

Training IQL on seaquest_beginner_perturb10
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:04:27<00:00, 386.79s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Actor Loss: -0.81557
    ➤ Avg Critic1 Loss: -2.40869
    ➤ Avg Critic2 Loss: -2.40088
    ➤ Avg Value Loss: -3.32256
    ➤ Avg Test Loss: 0.16868
    ➤ Avg Reward: 234.00
Model saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_beginner/perturb10/iql_model_perturb10.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:05:29<00:00, 392.90s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Actor Loss: -0.74422
    ➤ Avg Critic1 Loss: -2.38240
    ➤ Avg Critic2 Loss: -2.37378
    ➤ Avg Value Loss: -3.25866
    ➤ Avg Test Loss: 0.16652
    ➤ Avg Reward: 214.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:07:17<00:00, 403.72s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Actor Loss: -0.58591
    ➤ Avg Critic1 Loss: -2.31838
    ➤ Avg Critic2 Loss: -2.31469
    ➤ Avg Value Loss: -3.25519
    ➤ Avg Test Loss: 0.15643
    ➤ Avg Reward: 242.00
Return Stats saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_beginner/perturb10/stats_perturb10.pkl
----- Execution time: IQL - Beginner | Perturbation 10% -----
CPU times: total: 3h 53min 37s
Wall time: 3h 17min 27s


## BVE

In [7]:
%%time

# Train and evaluate the BVE model on the Beginner dataset with 10% perturbation.
# Reads the kernel-global `dataloaders`, so the dataset-loading cell of this
# section must have been run in the current kernel first.
train_and_evaluate_BVE(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb10',
    env_id=ENV_ID,
    seed=SEED
)

# Print a label for the timing report; the CPU/wall times themselves are
# emitted by the %%time cell magic, not by this print.
print("----- Execution time: BVE - Beginner | Perturbation 10% -----")

Training BVE on seaquest_beginner_perturb10
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:05:35<00:00, 393.60s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Train Loss: -2.23228
    ➤ Avg Test Loss: -2.08859
    ➤ Avg Reward: 0.00
Saved model to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_beginner/perturb10/bve_model_perturb10.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:06:16<00:00, 397.70s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Train Loss: -2.27569
    ➤ Avg Test Loss: -2.19120
    ➤ Avg Reward: 0.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:06:37<00:00, 399.75s/it]


Finished Training on seaquest_beginner_perturb10
    ➤ Avg Train Loss: -2.18774
    ➤ Avg Test Loss: -2.14127
    ➤ Avg Reward: 0.00
Saved stats to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_beginner/perturb10/stats_perturb10.pkl
----- Execution time: BVE - Beginner | Perturbation 10% -----
CPU times: total: 3h 52min 34s
Wall time: 3h 18min 40s


-----------------

### 20% Perturbation

In [6]:
# Load the Beginner dataset with 20% perturbation and build its dataloaders.
# This rebinds the kernel-global `dataloaders`, which the BC / IQL / BVE
# training cells of this section read.
dataloaders = load_and_prepare_dataset(
    dataset_path='datasets/beginner_logs/seaquest_beginner_perturb20.pkl',
    batch_size=BATCH_SIZE,
    seed=SEED
)

# Show which entries were produced, as a quick sanity check.
# NOTE(review): the output recorded for this cell lists a single key (the
# dataset name), while the 0% and 5% Beginner cells list 'train', 'test',
# 'tuning' — confirm which return shape the training helpers expect.
print(dataloaders.keys())

=== Loading seaquest_beginner_perturb20 dataset ===
Preprocessing and splitting seaquest_beginner_perturb20 dataset...
Creating dataloaders for seaquest_beginner_perturb20...
✅ Dataloaders ready for: seaquest_beginner_perturb20
dict_keys(['seaquest_beginner_perturb20'])


## BC

In [7]:
%%time

# Train and evaluate the BC model on the Beginner dataset with 20% perturbation.
# Reads the kernel-global `dataloaders`, so the dataset-loading cell of this
# section must have been run in the current kernel first.
train_and_evaluate_BC(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb20',
    env_id=ENV_ID,
    seed=SEED
)

# Print a label for the timing report; the CPU/wall times themselves are
# emitted by the %%time cell magic, not by this print.
print("----- Execution time: BC - Beginner | Perturbation 20% -----")

Training BC on seaquest_beginner_perturb20
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [52:42<00:00, 316.30s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Train Loss: 0.09098
    ➤ Avg Test Loss: 0.23155
    ➤ Avg Reward: 234.00
Model saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_beginner/perturb20/bc_model_perturb20.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:29:18<00:00, 535.88s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Train Loss: 0.09116
    ➤ Avg Test Loss: 0.23004
    ➤ Avg Reward: 252.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:00:53<00:00, 365.31s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Train Loss: 0.08593
    ➤ Avg Test Loss: 0.23636
    ➤ Avg Reward: 244.00
Return Stats saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_beginner/perturb20/stats_perturb20.pkl
----- Execution time: BC - Beginner | Perturbation 20% -----
CPU times: total: 3h 59min 17s
Wall time: 3h 23min 3s


## IQL

In [7]:
%%time

# Train and evaluate the IQL model on the Beginner dataset with 20% perturbation.
# Reads the kernel-global `dataloaders`, so the dataset-loading cell of this
# section must have been run in the current kernel first.
train_and_evaluate_IQL(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb20',
    env_id=ENV_ID,
    seed=SEED
)

# Print a label for the timing report; the CPU/wall times themselves are
# emitted by the %%time cell magic, not by this print.
print("----- Execution time: IQL - Beginner | Perturbation 20% -----") 

Training IQL on seaquest_beginner_perturb20
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:04:27<00:00, 386.77s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Actor Loss: -0.66980
    ➤ Avg Critic1 Loss: -2.38932
    ➤ Avg Critic2 Loss: -2.39880
    ➤ Avg Value Loss: -3.32326
    ➤ Avg Test Loss: 0.26007
    ➤ Avg Reward: 216.00
Model saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_beginner/perturb20/iql_model_perturb20.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:05:04<00:00, 390.45s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Actor Loss: -0.81070
    ➤ Avg Critic1 Loss: -2.44431
    ➤ Avg Critic2 Loss: -2.43619
    ➤ Avg Value Loss: -3.34523
    ➤ Avg Test Loss: 0.25729
    ➤ Avg Reward: 242.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:06:28<00:00, 398.82s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Actor Loss: -0.60481
    ➤ Avg Critic1 Loss: -2.37869
    ➤ Avg Critic2 Loss: -2.36903
    ➤ Avg Value Loss: -3.34906
    ➤ Avg Test Loss: 0.25173
    ➤ Avg Reward: 226.00
Return Stats saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_beginner/perturb20/stats_perturb20.pkl
----- Execution time: IQL - Beginner | Perturbation 20% -----
CPU times: total: 3h 52min 38s
Wall time: 3h 16min 13s


## BVE

In [7]:
%%time

# Train and evaluate the BVE model on the Beginner dataset with 20% perturbation.
# Reads the kernel-global `dataloaders`, so the dataset-loading cell of this
# section must have been run in the current kernel first.
train_and_evaluate_BVE(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_beginner_perturb20',
    env_id=ENV_ID,
    seed=SEED
)

# Print a label for the timing report; the CPU/wall times themselves are
# emitted by the %%time cell magic, not by this print.
print("----- Execution time: BVE - Beginner | Perturbation 20% -----")

Training BVE on seaquest_beginner_perturb20
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:07:06<00:00, 402.65s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Train Loss: -2.25355
    ➤ Avg Test Loss: -2.09426
    ➤ Avg Reward: 0.00
Saved model to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_beginner/perturb20/bve_model_perturb20.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:05:43<00:00, 394.36s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Train Loss: -2.22645
    ➤ Avg Test Loss: -2.06080
    ➤ Avg Reward: 0.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:06:24<00:00, 398.48s/it]


Finished Training on seaquest_beginner_perturb20
    ➤ Avg Train Loss: -2.14737
    ➤ Avg Test Loss: -1.97296
    ➤ Avg Reward: 0.00
Saved stats to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_beginner/perturb20/stats_perturb20.pkl
----- Execution time: BVE - Beginner | Perturbation 20% -----
CPU times: total: 3h 54min 18s
Wall time: 3h 19min 25s


------------------

# Training all agents on: Intermediate Dataset

### 0% Perturbation

In [7]:
# Load the Intermediate dataset with 0% perturbation and build its dataloaders.
# This rebinds the kernel-global `dataloaders`, which the BC / IQL / BVE
# training cells of this section read.
dataloaders = load_and_prepare_dataset(
    dataset_path='datasets/intermediate_logs/seaquest_intermediate_perturb0.pkl',
    batch_size=BATCH_SIZE,
    seed=SEED
)

# Show which entries were produced, as a quick sanity check.
# NOTE(review): the output recorded for this cell lists a single key (the
# dataset name), while the 0% and 5% Beginner cells list 'train', 'test',
# 'tuning' — confirm which return shape the training helpers expect.
print(dataloaders.keys())

=== Loading seaquest_intermediate_perturb0 dataset ===
Preprocessing and splitting seaquest_intermediate_perturb0 dataset...
Creating dataloaders for seaquest_intermediate_perturb0...
✅ Dataloaders ready for: seaquest_intermediate_perturb0
dict_keys(['seaquest_intermediate_perturb0'])


## BC

In [7]:
%%time

# Train and evaluate the BC model on the Intermediate dataset with 0% perturbation.
# Reads the kernel-global `dataloaders`, so the dataset-loading cell of this
# section must have been run in the current kernel first.
train_and_evaluate_BC(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_intermediate_perturb0',
    env_id=ENV_ID,
    seed=SEED
)

# Print a label for the timing report; the CPU/wall times themselves are
# emitted by the %%time cell magic, not by this print.
print("----- Execution time: BC - Intermediate | Perturbation 0% -----")

Training BC on seaquest_intermediate_perturb0
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [56:16<00:00, 337.68s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Train Loss: -3.03064
    ➤ Avg Test Loss: -3.34562
    ➤ Avg Reward: 358.00
Model saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_intermediate/perturb0/bc_model_perturb0.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [55:40<00:00, 334.09s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Train Loss: -2.99913
    ➤ Avg Test Loss: -3.30587
    ➤ Avg Reward: 344.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [55:52<00:00, 335.27s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Train Loss: -2.84021
    ➤ Avg Test Loss: -2.97219
    ➤ Avg Reward: 310.00
Return Stats saved to agent_methods/behavioral_cloning_bc/bc_logs/seaquest_intermediate/perturb0/stats_perturb0.pkl
----- Execution time: BC - Intermediate | Perturbation 0% -----
CPU times: total: 3h 27min 6s
Wall time: 2h 47min 57s


## IQL

In [7]:
%%time

# Train and evaluate the IQL model on the Intermediate dataset with 0% perturbation.
# Reads the kernel-global `dataloaders`, so the dataset-loading cell of this
# section must have been run in the current kernel first.
train_and_evaluate_IQL(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_intermediate_perturb0',
    env_id=ENV_ID,
    seed=SEED
)

# Print a label for the timing report; the CPU/wall times themselves are
# emitted by the %%time cell magic, not by this print.
print("----- Execution time: IQL - Intermediate | Perturbation 0% -----") 

Training IQL on seaquest_intermediate_perturb0
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:04:07<00:00, 384.78s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Actor Loss: 0.38233
    ➤ Avg Critic1 Loss: -1.10483
    ➤ Avg Critic2 Loss: -1.10999
    ➤ Avg Value Loss: -1.61832
    ➤ Avg Test Loss: -0.57688
    ➤ Avg Reward: 174.00
Model saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_intermediate/perturb0/iql_model_perturb0.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:03:51<00:00, 383.14s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Actor Loss: 0.49039
    ➤ Avg Critic1 Loss: -1.13186
    ➤ Avg Critic2 Loss: -1.12787
    ➤ Avg Value Loss: -1.58814
    ➤ Avg Test Loss: -0.62337
    ➤ Avg Reward: 156.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:06:35<00:00, 399.59s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Actor Loss: 0.27473
    ➤ Avg Critic1 Loss: -1.20798
    ➤ Avg Critic2 Loss: -1.13456
    ➤ Avg Value Loss: -1.64877
    ➤ Avg Test Loss: -0.60901
    ➤ Avg Reward: 164.00
Return Stats saved to agent_methods/implicit_q_learning_iql/iql_logs/seaquest_intermediate/perturb0/stats_perturb0.pkl
----- Execution time: IQL - Intermediate | Perturbation 0% -----
CPU times: total: 3h 52min
Wall time: 3h 14min 47s


## BVE

In [8]:
%%time

# Train and evaluate the BVE model on the Intermediate dataset with 0% perturbation.
# Reads the kernel-global `dataloaders`, so the dataset-loading cell of this
# section must have been run in the current kernel first.
train_and_evaluate_BVE(
    dataloaders=dataloaders,
    device=device,
    seeds=SEEDS,
    epochs=EPOCHS,
    dataset='seaquest_intermediate_perturb0',
    env_id=ENV_ID,
    seed=SEED
)

# Print a label for the timing report; the CPU/wall times themselves are
# emitted by the %%time cell magic, not by this print.
print("----- Execution time: BVE - Intermediate | Perturbation 0% -----")

Training BVE on seaquest_intermediate_perturb0
-- Starting Seed 1/3 --


Epochs: 100%|██████████| 10/10 [1:05:50<00:00, 395.01s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Train Loss: -1.05895
    ➤ Avg Test Loss: -0.98404
    ➤ Avg Reward: 90.00
Saved model to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_intermediate/perturb0/bve_model_perturb0.pth
-- Starting Seed 2/3 --


Epochs: 100%|██████████| 10/10 [1:17:28<00:00, 464.88s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Train Loss: -1.10359
    ➤ Avg Test Loss: -1.05455
    ➤ Avg Reward: 100.00
-- Starting Seed 3/3 --


Epochs: 100%|██████████| 10/10 [1:19:31<00:00, 477.19s/it]


Finished Training on seaquest_intermediate_perturb0
    ➤ Avg Train Loss: -1.12968
    ➤ Avg Test Loss: -1.04024
    ➤ Avg Reward: 100.00
Saved stats to agent_methods/behavior_value_estimation_bve/bve_logs/seaquest_intermediate/perturb0/stats_perturb0.pkl
----- Execution time: BVE - Intermediate | Perturbation 0% -----
CPU times: total: 4h 22min 45s
Wall time: 3h 43min


-----------------------------

### 5% Perturbation

## BC

## IQL

## BVE

-----------------

### 10% Perturbation

----------------

# Training all agents on: Expert Dataset

### 0% Perturbation