In [2]:
# Pick up edits to the project's .py modules (utils, loaders, trainers) live,
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Seed the random generators once, before any episode sampling, data splitting or training.
# NOTE(review): which libraries get seeded (random / numpy / torch?) is defined in utils — confirm there.
from utils import seed_random_generators
seed_random_generators()

In [None]:
from pathlib import Path
from npz_loader import load_episodes

dataset_root = Path("dataset")

# Every game lives in a doubly nested directory: dataset/<game_id>/<game_id>.
# Holdout games are passed to the loaders separately from the main games.
HOLDOUT_GAME_IDS = [
    "BeamRiderNoFrameskip-v4",
    "BreakoutNoFrameskip-v4",
]
MAIN_GAME_IDS = [
    "EnduroNoFrameskip-v4",
    "MsPacmanNoFrameskip-v4",
    "PongNoFrameskip-v4",
    "QbertNoFrameskip-v4",
    "SeaquestNoFrameskip-v4",
    "SpaceInvadersNoFrameskip-v4",
]

holdout_game_dirs = [dataset_root.joinpath(game_id, game_id) for game_id in HOLDOUT_GAME_IDS]
main_game_dirs = [dataset_root.joinpath(game_id, game_id) for game_id in MAIN_GAME_IDS]

episodes = load_episodes(main_game_dirs, holdout_game_dirs)

KeyboardInterrupt: 

In [4]:
from utils import sample_list

# Fraction of the loaded episodes to keep; 1 keeps all of them.
# Lower it for quicker smoke-test runs.
EPISODE_SAMPLE_FRACTION = 1
sampled_episodes = sample_list(episodes, fraction=EPISODE_SAMPLE_FRACTION)

Sampled 1880 items (100.0% of 1880 total)


In [5]:
# NOTE(review): "epsiode" looks like a typo, but it must match the module's filename —
# rename the file and this import together if fixing it.
from epsiode_dataloader import make_train_val_dataloaders

# Timesteps of context per training sample. Changing it changes how many
# tokens/timesteps each step sees, so it stays fixed at 4 for every experiment
# except the dedicated window-size experiments.
TIMESTEP_WINDOW_SIZE = 4

main_bundle, holdout_bundle, bins = make_train_val_dataloaders(
    timestep_window_size=TIMESTEP_WINDOW_SIZE,
    train_frac=0.8,
    holdout_game_dirs=holdout_game_dirs,
    episodes=sampled_episodes,
)

In [6]:
# Root for all experiment outputs (checkpoints, plots); each experiment writes into its own subdirectory.
base_dir: Path = Path("output")

# Baseline

In [7]:
# Baseline hyperparameters, reportedly taken from Kenny's baseline tuning run.
# Experiments that leave the model unchanged (e.g. the freeze runs) should reuse
# these as-is. Experiments that change the model (e.g. CNN vs. patch encoder)
# should search for their own best parameters instead.
best_baseline_params = dict(
    lr=0.002226768831180977,
    emb_size=128,
    n_layers=2,
    n_heads=2,
    num_epochs=2,
)

# Freeze

In [8]:
from experiment_freeze import run_experiment_freeze
from mgdt_model import Freezeable

# Freeze the Transformer component. The architecture is unchanged from the
# baseline, so the baseline hyperparameters are reused directly.
# NOTE(review): confirm what run_experiment_freeze returns (stored in freeze_params).
freeze_params = run_experiment_freeze(
    title_prefix="Freeze Transformer",
    experiment_dir=base_dir.joinpath("freeze_transformer"),
    freeze_components=[Freezeable.Transformer],
    best_params=best_baseline_params,
    main_bundle=main_bundle,
    holdout_bundle=holdout_bundle,
    bins=bins,
)

Cleared 0 files from output\freeze_transformer


Epoch 1/2: 100%|██████████| 41734/41734 [20:56<00:00, 33.22it/s]   
Epoch 2/2: 100%|██████████| 41734/41734 [19:58<00:00, 34.82it/s]   
Finetune 1/2: 100%|██████████| 4249/4249 [03:54<00:00, 18.13it/s] 
Finetune 2/2: 100%|██████████| 4249/4249 [02:26<00:00, 29.02it/s]
                                                               

Model and stats saved to output\freeze_transformer\model_checkpoint.pt
Saved plot to output\freeze_transformer\model_freeze_transformer_-_main_losses_per_head.png
Saved plot to output\freeze_transformer\model_freeze_transformer_-_main_losses_combined.png
Saved plot to output\freeze_transformer\model_freeze_transformer_-_main_losses_ema_per_head.png
Saved plot to output\freeze_transformer\model_freeze_transformer_-_main_losses_ema_combined.png
Saved plot to output\freeze_transformer\model_freeze_transformer_-_holdout_losses_per_head.png
Saved plot to output\freeze_transformer\model_freeze_transformer_-_holdout_losses_combined.png
Saved plot to output\freeze_transformer\model_freeze_transformer_-_holdout_losses_ema_per_head.png
Saved plot to output\freeze_transformer\model_freeze_transformer_-_holdout_losses_ema_combined.png
Saved plot to output\freeze_transformer\comparison_freeze_transformer_-_comparison_main_vs_holdout.png
Saved plot to output\freeze_transformer\comparison_freeze_tran

In [9]:
from experiment_freeze import run_experiment_freeze
from mgdt_model import Freezeable

# Freeze only the observation encoder. The architecture is unchanged from the
# baseline, so, like the Freeze Transformer run, this uses the baseline
# hyperparameters directly.
# It does not use whatever the previous experiment cell returned. That keeps the
# two freeze runs configured identically and lets this cell run on its own.
_ = run_experiment_freeze(
    title_prefix="Freeze Obs Encoder",
    main_bundle=main_bundle,
    holdout_bundle=holdout_bundle,
    bins=bins,
    freeze_components=[Freezeable.ObsEncoder],
    experiment_dir=base_dir.joinpath("freeze_obs_encoder"),
    best_params=best_baseline_params,
)

Cleared 12 files from output\freeze_obs_encoder


Epoch 1/2: 100%|██████████| 41734/41734 [21:09<00:00, 32.86it/s]   
Epoch 2/2: 100%|██████████| 41734/41734 [19:58<00:00, 34.82it/s]   
Finetune 1/2: 100%|██████████| 4249/4249 [02:45<00:00, 25.73it/s]
Finetune 2/2: 100%|██████████| 4249/4249 [02:27<00:00, 28.76it/s]
                                                               

Model and stats saved to output\freeze_obs_encoder\model_checkpoint.pt
Saved plot to output\freeze_obs_encoder\model_freeze_obs_encoder_-_main_losses_per_head.png
Saved plot to output\freeze_obs_encoder\model_freeze_obs_encoder_-_main_losses_combined.png
Saved plot to output\freeze_obs_encoder\model_freeze_obs_encoder_-_main_losses_ema_per_head.png
Saved plot to output\freeze_obs_encoder\model_freeze_obs_encoder_-_main_losses_ema_combined.png
Saved plot to output\freeze_obs_encoder\model_freeze_obs_encoder_-_holdout_losses_per_head.png
Saved plot to output\freeze_obs_encoder\model_freeze_obs_encoder_-_holdout_losses_combined.png
Saved plot to output\freeze_obs_encoder\model_freeze_obs_encoder_-_holdout_losses_ema_per_head.png
Saved plot to output\freeze_obs_encoder\model_freeze_obs_encoder_-_holdout_losses_ema_combined.png
Saved plot to output\freeze_obs_encoder\comparison_freeze_obs_encoder_-_comparison_main_vs_holdout.png
Saved plot to output\freeze_obs_encoder\comparison_freeze_obs_

# CNN

In [10]:
from experiment_basic import run_experiment_basic
from mgdt_model_trainer import Encoder

# The CNN encoder changes the model, so it does not reuse the baseline
# hyperparameters. run_experiment_basic searches for its own (its output shows
# an Optuna study).
cnn_experiment_dir = base_dir.joinpath("cnn")
_ = run_experiment_basic(
    "CNN", main_bundle, holdout_bundle, bins, cnn_experiment_dir,
    encoder_type=Encoder.CNN,
)

[I 2025-12-09 10:16:01,778] A new study created in memory with name: no-name-a11a1905-54d3-41f9-b3fd-65eda4e2bf87


Trial params: {'lr': 7.312775732692473e-05, 'emb_size': 64, 'n_layers': 3, 'n_heads': 4, 'num_epochs': 3}


Epoch 1/3: 100%|██████████| 41734/41734 [14:44<00:00, 47.17it/s]
Epoch 2/3: 100%|██████████| 41734/41734 [14:57<00:00, 46.52it/s]
Epoch 3/3: 100%|██████████| 41734/41734 [14:45<00:00, 47.11it/s]
[I 2025-12-09 11:13:10,174] Trial 0 finished with value: 1.5751007795333862 and parameters: {'lr': 7.312775732692473e-05, 'emb_size': 64, 'n_layers': 3, 'n_heads': 4, 'num_epochs': 3}. Best is trial 0 with value: 1.5751007795333862.


Trial params: {'lr': 3.088368723207604e-05, 'emb_size': 64, 'n_layers': 6, 'n_heads': 4, 'num_epochs': 3}


Epoch 1/3: 100%|██████████| 41734/41734 [18:56<00:00, 36.71it/s]
Epoch 2/3: 100%|██████████| 41734/41734 [18:57<00:00, 36.67it/s]
Epoch 3/3: 100%|██████████| 41734/41734 [19:05<00:00, 36.45it/s]
[I 2025-12-09 12:23:06,355] Trial 1 finished with value: 1.7595354318618774 and parameters: {'lr': 3.088368723207604e-05, 'emb_size': 64, 'n_layers': 6, 'n_heads': 4, 'num_epochs': 3}. Best is trial 0 with value: 1.5751007795333862.


Trial params: {'lr': 0.006743403888083983, 'emb_size': 512, 'n_layers': 5, 'n_heads': 1, 'num_epochs': 5}


Epoch 1/5: 100%|██████████| 41734/41734 [24:44<00:00, 28.11it/s]  
Epoch 2/5: 100%|██████████| 41734/41734 [16:19<00:00, 42.59it/s]
Epoch 3/5: 100%|██████████| 41734/41734 [16:36<00:00, 41.87it/s]
Epoch 4/5: 100%|██████████| 41734/41734 [16:21<00:00, 42.53it/s]
Epoch 5/5: 100%|██████████| 41734/41734 [16:35<00:00, 41.90it/s]
[I 2025-12-09 14:17:41,939] Trial 2 finished with value: 7335.21240234375 and parameters: {'lr': 0.006743403888083983, 'emb_size': 512, 'n_layers': 5, 'n_heads': 1, 'num_epochs': 5}. Best is trial 0 with value: 1.5751007795333862.
[I 2025-12-09 14:17:41,940] Trial 3 pruned. 


Trial params: {'lr': 0.0004072739983235087, 'emb_size': 256, 'n_layers': 5, 'n_heads': 3, 'num_epochs': 4}
Trial params: {'lr': 3.821873872825288e-05, 'emb_size': 128, 'n_layers': 5, 'n_heads': 4, 'num_epochs': 1}


Epoch 1/1: 100%|██████████| 41734/41734 [17:11<00:00, 40.46it/s]
[I 2025-12-09 14:39:09,531] Trial 4 finished with value: 1.8784879446029663 and parameters: {'lr': 3.821873872825288e-05, 'emb_size': 128, 'n_layers': 5, 'n_heads': 4, 'num_epochs': 1}. Best is trial 0 with value: 1.5751007795333862.


Trial params: {'lr': 0.007190890680147319, 'emb_size': 64, 'n_layers': 5, 'n_heads': 1, 'num_epochs': 1}


Epoch 1/1: 100%|██████████| 41734/41734 [16:42<00:00, 41.65it/s]
[I 2025-12-09 15:00:06,369] Trial 5 finished with value: 1.7167880535125732 and parameters: {'lr': 0.007190890680147319, 'emb_size': 64, 'n_layers': 5, 'n_heads': 1, 'num_epochs': 1}. Best is trial 0 with value: 1.5751007795333862.


Trial params: {'lr': 0.0014038110531100418, 'emb_size': 256, 'n_layers': 4, 'n_heads': 4, 'num_epochs': 4}


Epoch 1/4: 100%|██████████| 41734/41734 [15:17<00:00, 45.51it/s]
Epoch 2/4: 100%|██████████| 41734/41734 [15:22<00:00, 45.22it/s]
Epoch 3/4: 100%|██████████| 41734/41734 [15:17<00:00, 45.49it/s]
Epoch 4/4: 100%|██████████| 41734/41734 [15:16<00:00, 45.54it/s]
[I 2025-12-09 16:18:27,868] Trial 6 finished with value: 1.341511845588684 and parameters: {'lr': 0.0014038110531100418, 'emb_size': 256, 'n_layers': 4, 'n_heads': 4, 'num_epochs': 4}. Best is trial 6 with value: 1.341511845588684.


Trial params: {'lr': 0.0034597371545648167, 'emb_size': 256, 'n_layers': 6, 'n_heads': 2, 'num_epochs': 2}


Epoch 1/2: 100%|██████████| 41734/41734 [18:47<00:00, 37.01it/s]
Epoch 2/2:  29%|██▉       | 12160/41734 [04:51<11:49, 41.68it/s]
[W 2025-12-09 16:47:16,125] Trial 7 failed with parameters: {'lr': 0.0034597371545648167, 'emb_size': 256, 'n_layers': 6, 'n_heads': 2, 'num_epochs': 2} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "c:\Users\idanc\local\projects\AtariDeepLearning\.venv_atari\Lib\site-packages\optuna\study\_optimize.py", line 205, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "c:\Users\idanc\local\projects\AtariDeepLearning\optuna_tuning.py", line 33, in objective
    model, main_train_stats, main_val_stats = train_mgdt(
                                              ^^^^^^^^^^^
  File "c:\Users\idanc\local\projects\AtariDeepLearning\mgdt_model_trainer.py", line 163, in train_mgdt
    total_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item()
                     

KeyboardInterrupt: 

# Window Sizes

In [None]:
# Window size 8: rebuild the loaders with a longer timestep context (the main
# loaders use 4), then train with the Patch encoder.
window_8_dir = base_dir.joinpath("window_size_8")
(
    main_bundle_window_8,
    holdout_bundle_window_8,
    bins_window_8,
) = make_train_val_dataloaders(
    timestep_window_size=8,
    train_frac=0.8,
    holdout_game_dirs=holdout_game_dirs,
    episodes=sampled_episodes,
)
_ = run_experiment_basic(
    "Window Size 8",
    main_bundle_window_8,
    holdout_bundle_window_8,
    bins_window_8,
    window_8_dir,
    encoder_type=Encoder.Patch,
)

In [None]:
# Window size 16: rebuild the loaders with a longer timestep context (the main
# loaders use 4), then train with the Patch encoder.
window_16_dir = base_dir.joinpath("window_size_16")
(
    main_bundle_window_16,
    holdout_bundle_window_16,
    bins_window_16,
) = make_train_val_dataloaders(
    timestep_window_size=16,
    train_frac=0.8,
    holdout_game_dirs=holdout_game_dirs,
    episodes=sampled_episodes,
)
_ = run_experiment_basic(
    "Window Size 16",
    main_bundle_window_16,
    holdout_bundle_window_16,
    bins_window_16,
    window_16_dir,
    encoder_type=Encoder.Patch,
)

In [None]:
# Window size 32: rebuild the loaders with a longer timestep context (the main
# loaders use 4), then train with the Patch encoder.
window_32_dir = base_dir.joinpath("window_size_32")
(
    main_bundle_window_32,
    holdout_bundle_window_32,
    bins_window_32,
) = make_train_val_dataloaders(
    timestep_window_size=32,
    train_frac=0.8,
    holdout_game_dirs=holdout_game_dirs,
    episodes=sampled_episodes,
)
_ = run_experiment_basic(
    "Window Size 32",
    main_bundle_window_32,
    holdout_bundle_window_32,
    bins_window_32,
    window_32_dir,
    encoder_type=Encoder.Patch,
)

# Patch

In [None]:
# This is the same as baseline, but just in case something goes wrong with kenny's run
# Set to run last in case this takes >1 day
from experiment_basic import run_experiment_basic
from mgdt_model_trainer import Encoder

# _ = run_experiment_basic(
#     "Patch",
#     main_bundle,
#     holdout_bundle,
#     bins,
#     base_dir.joinpath("patch"),
#     encoder_type=Encoder.Patch,
# )

# Comparison 
## *Keep this at bottom of notebook and add new experiments to it*

In [1]:
from utils import load_checkpoint
from mgdt_model_stats import ExperimentData
from pathlib import Path

base_dir = Path("output")

# Stat attributes copied from a saved checkpoint into ExperimentData, in this order.
_CHECKPOINT_STAT_FIELDS = (
    "main_train_stats",
    "main_val_stats",
    "holdout_train_stats",
    "holdout_val_stats",
)


def load_experiment_data(name: str, output_dir: Path) -> ExperimentData:
    """Load the checkpoint saved under ``output_dir`` and wrap its stats for comparison.

    Args:
        name: Display label for the experiment (shown in the comparison summary).
        output_dir: Experiment directory handed to ``load_checkpoint``.

    Returns:
        ExperimentData holding ``name`` plus the main/holdout train/val stats
        read from the checkpoint.
    """
    loaded = load_checkpoint(output_dir)
    stats = {field: getattr(loaded, field) for field in _CHECKPOINT_STAT_FIELDS}
    return ExperimentData(name=name, **stats)

In [2]:
# Load two finished runs so the F1 sanity check can inspect their stored stats.
_sanity_runs = (
    ("Freeze Transformer", "freeze_transformer"),
    ("Freeze Obs Encoder", "freeze_obs_encoder"),
)
test_data_1, test_data_2 = (
    load_experiment_data(label, base_dir / subdir) for label, subdir in _sanity_runs
)

Loaded checkpoint from output\freeze_transformer\model_checkpoint.pt
Loaded checkpoint from output\freeze_obs_encoder\model_checkpoint.pt


In [3]:
# Sanity check: verify F1 data is present
def check_f1_data(data: ExperimentData) -> None:
    """Print whether an experiment's holdout stats carry an ``action_f1`` metric.

    Presence is judged from the first recorded entry of each stats list. When the
    validation stats have F1, the min/max over all entries that actually hold an
    F1 value is printed as well.

    Args:
        data: Loaded experiment. Its ``holdout_train_stats`` and
            ``holdout_val_stats`` are read as sequences of dict-like entries.
    """
    print(f"Experiment: {data.name}")
    print("-" * 40)

    # Check training stats
    if data.holdout_train_stats:
        has_train_f1 = "action_f1" in data.holdout_train_stats[0]
        print(f"  Holdout train F1: {'✓' if has_train_f1 else '✗'}")
    else:
        print("  Holdout train stats: empty")

    # Check validation stats
    if data.holdout_val_stats:
        has_val_f1 = "action_f1" in data.holdout_val_stats[0]
        print(f"  Holdout val F1:   {'✓' if has_val_f1 else '✗'}")
        if has_val_f1:
            # Skip entries without an F1 value. .get() would yield None for them,
            # and min()/max() raise TypeError when comparing None with floats.
            f1_values = [
                s["action_f1"]
                for s in data.holdout_val_stats
                if s.get("action_f1") is not None
            ]
            if f1_values:
                print(f"    F1 range: {min(f1_values):.4f} - {max(f1_values):.4f}")
            else:
                print("    F1 range: no recorded values")
    else:
        print("  Holdout val stats: empty")

    print()

check_f1_data(test_data_1)
check_f1_data(test_data_2)

Experiment: Freeze Transformer
----------------------------------------
  Holdout train F1: ✗
  Holdout val F1:   ✓
    F1 range: 0.1146 - 1.0000

Experiment: Freeze Obs Encoder
----------------------------------------
  Holdout train F1: ✗
  Holdout val F1:   ✓
    F1 range: 0.1393 - 1.0000



In [5]:
from mgdt_model_stats import experiment_comparison

# (display name, subdirectory of base_dir) for every experiment to compare.
# Uncomment an entry once its run has written a model_checkpoint.pt.
comparison_specs = [
    ("Freeze Transformer", "freeze_transformer"),
    ("Freeze Obs Encoder", "freeze_obs_encoder"),
    # ("CNN", "cnn"),
    # ("Window Size 8", "window_size_8"),
    # ("Window Size 16", "window_size_16"),
    # ("Window Size 32", "window_size_32"),
    # ("Patch", "patch"),
]
experiments = [
    load_experiment_data(label, base_dir.joinpath(subdir))
    for label, subdir in comparison_specs
]

experiment_comparison(experiments, output_dir=base_dir.joinpath("experiment_comparison"), no_show=True)

Loaded checkpoint from output\freeze_transformer\model_checkpoint.pt
Loaded checkpoint from output\freeze_obs_encoder\model_checkpoint.pt
Saved plot to output\experiment_comparison\experiment_comparison_holdout_val_loss.png
Saved plot to output\experiment_comparison\experiment_comparison_holdout_val_f1.png
Saved plot to output\experiment_comparison\experiment_comparison_holdout_val_acc.png
Saved plot to output\experiment_comparison\experiment_comparison_holdout_train_acc.png
Saved plot to output\experiment_comparison\experiment_comparison_steps_to_acc.png
Saved plot to output\experiment_comparison\experiment_comparison_steps_to_f1.png

EXPERIMENT COMPARISON SUMMARY

Freeze Transformer:
----------------------------------------
  Total holdout training steps: 8498
  Final holdout val loss: 1.4502
  Final holdout val F1: 0.3939
  Final holdout val accuracy: 0.6408
  Steps to reach F1 thresholds (validation):
    F1 >= 0.3: 849 steps
    F1 >= 0.4: 849 steps
    F1 >= 0.5: not reached
    

{'steps_to_f1_threshold': {'Freeze Transformer': {0.3: 849,
   0.4: 849,
   0.5: None,
   0.6: None},
  'Freeze Obs Encoder': {0.3: 849, 0.4: 849, 0.5: 8494, 0.6: None}},
 'steps_to_acc_threshold': {'Freeze Transformer': {0.3: 1,
   0.4: 4,
   0.5: 4,
   0.6: 34,
   0.7: 34},
  'Freeze Obs Encoder': {0.3: 2, 0.4: 2, 0.5: 7, 0.6: 28, 0.7: 43}}}