In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are reloaded
# before each cell executes and edits to the project's .py files take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
# Seed the random number generators up front through the project helper.
# NOTE(review): which libraries are seeded (and with what value) is decided inside
# utils.seed_random_generators — confirm it covers every RNG the experiments rely on.
from utils import seed_random_generators
seed_random_generators()

In [3]:
from pathlib import Path
from npz_loader import load_episodes

dataset_root = Path("dataset")

# Every game directory has the form dataset/<env_id>/<env_id>, i.e. the environment id
# appears twice in the path.
holdout_game_dirs = [
    dataset_root / env_id / env_id
    for env_id in (
        "BeamRiderNoFrameskip-v4",
        "BreakoutNoFrameskip-v4",
    )
]

main_game_dirs = [
    dataset_root / env_id / env_id
    for env_id in (
        "EnduroNoFrameskip-v4",
        "MsPacmanNoFrameskip-v4",
        "PongNoFrameskip-v4",
        "QbertNoFrameskip-v4",
        "SeaquestNoFrameskip-v4",
        "SpaceInvadersNoFrameskip-v4",
    )
]

# Main games are passed first, holdout games second.
episodes = load_episodes(main_game_dirs, holdout_game_dirs)

Loaded 1880 episodes


In [4]:
from utils import sample_list
# fraction=1 keeps every loaded episode (the cell output reports 100% sampled).
# Presumably a smaller fraction subsamples for quicker iteration — confirm in utils.sample_list.
sampled_episodes = sample_list(episodes, fraction=1)

Sampled 1880 items (100.0% of 1880 total)


In [5]:
# NOTE(review): the module name is spelled "epsiode"; it presumably matches the file on disk,
# so rename the file and this import together if the typo is ever fixed.
from epsiode_dataloader import make_train_val_dataloaders

# Shared dataloaders for the experiments below: a main-games bundle, a holdout-games bundle
# and `bins`, all of which are handed to the experiment runners in later cells.
main_bundle, holdout_bundle, bins = make_train_val_dataloaders(
    episodes=sampled_episodes,
    holdout_game_dirs=holdout_game_dirs,
    train_frac=0.8,

    # The window size is worth experimenting with, but changing it means a training step no longer
    # covers the same number of tokens/timesteps, so step counts stop being comparable across runs.
    # Keep it fixed for every experiment except the ones that specifically study the window size.
    timestep_window_size=4, 
)

In [6]:
# Root output directory; every experiment below uses its own subdirectory of it,
# and the comparison cell at the bottom loads checkpoints back from those subdirectories.
base_dir = Path("output")

# Baseline

In [7]:
# Kenny has already started the baseline runs; we may just reuse those instead of rerunning them here.
# Experiments that leave the original model unchanged (like the freeze experiments below) should reuse the best params from the baseline,
# while experiments that change the model (e.g. patch vs. CNN encoder) should search for their own best params.

# Freeze

In [None]:
from experiment_freeze import run_experiment_freeze
from mgdt_model import Freezeable

# Freeze experiment with Freezeable.Transformer frozen, using the shared dataloaders.
# No best_params is passed, and the cell's log output shows a new study with sampled trial
# params, so this call appears to run its own hyperparameter search.
# The return value is kept because the "Freeze Obs Encoder" cell passes it as best_params.
freeze_params = run_experiment_freeze(
    title_prefix="Freeze Transformer",
    main_bundle=main_bundle,
    holdout_bundle=holdout_bundle,
    bins=bins,
    freeze_components=[Freezeable.Transformer],
    experiment_dir=base_dir.joinpath("freeze_transformer"),
)  # TODO: switch to taking the best params from the baseline run once they are available

[I 2025-12-08 23:17:32,436] A new study created in memory with name: no-name-40fa0255-b8b1-4f58-aba7-545f3274c1ef


Cleared 0 files from output\freeze_transformer
Trial params: {'lr': 4.105152517595741e-05, 'emb_size': 128, 'n_layers': 5, 'n_heads': 2, 'num_epochs': 4}


In [None]:
# Freeze experiment with Freezeable.ObsEncoder frozen instead of the transformer.
# It reuses freeze_params returned by the "Freeze Transformer" cell as best_params,
# so that cell must have completed in this kernel before this one is run.
# The return value is assigned to `_` because it is not used afterwards.
_ = run_experiment_freeze(
    title_prefix="Freeze Obs Encoder",
    main_bundle=main_bundle,
    holdout_bundle=holdout_bundle,
    bins=bins,
    freeze_components=[Freezeable.ObsEncoder],
    experiment_dir=base_dir.joinpath("freeze_obs_encoder"),
    best_params=freeze_params,
)

# CNN

In [None]:
# These two imports are also relied on by the "Window Size" cells below, which do not re-import them.
from experiment_basic import run_experiment_basic
from mgdt_model_trainer import Encoder

# Basic experiment on the shared (window size 4) dataloaders with Encoder.CNN selected.
# base_dir/"cnn" is the directory the comparison cell later loads this run's checkpoint from.
_ = run_experiment_basic(
    "CNN",
    main_bundle,
    holdout_bundle,
    bins,
    base_dir.joinpath("cnn"),
    encoder_type=Encoder.CNN,
)

# Window Sizes

In [None]:
# Rebuild the dataloaders with timestep_window_size=8 (the shared bundles use 4);
# every other argument matches the shared dataloader cell.
# NOTE(review): make_train_val_dataloaders is called again here, so comparing against the other
# runs assumes it produces the same train/val split on every call — confirm.
main_bundle_window_8, holdout_bundle_window_8, bins_window_8 = make_train_val_dataloaders(
    episodes=sampled_episodes,
    holdout_game_dirs=holdout_game_dirs,
    train_frac=0.8,
    timestep_window_size=8, 
)
# Basic experiment with Encoder.Patch on the window-size-8 dataloaders.
# Relies on run_experiment_basic and Encoder imported in the CNN cell.
_ = run_experiment_basic(
    "Window Size 8",
    main_bundle_window_8,
    holdout_bundle_window_8,
    bins_window_8,
    base_dir.joinpath("window_size_8"),
    encoder_type=Encoder.Patch,
)

In [None]:
# Rebuild the dataloaders with timestep_window_size=16 (the shared bundles use 4);
# every other argument matches the shared dataloader cell.
# NOTE(review): make_train_val_dataloaders is called again here, so comparing against the other
# runs assumes it produces the same train/val split on every call — confirm.
main_bundle_window_16, holdout_bundle_window_16, bins_window_16 = make_train_val_dataloaders(
    episodes=sampled_episodes,
    holdout_game_dirs=holdout_game_dirs,
    train_frac=0.8,
    timestep_window_size=16, 
)
# Basic experiment with Encoder.Patch on the window-size-16 dataloaders.
# Relies on run_experiment_basic and Encoder imported in the CNN cell.
_ = run_experiment_basic(
    "Window Size 16",
    main_bundle_window_16,
    holdout_bundle_window_16,
    bins_window_16,
    base_dir.joinpath("window_size_16"),
    encoder_type=Encoder.Patch,
)

In [None]:
# Rebuild the dataloaders with timestep_window_size=32 (the shared bundles use 4);
# every other argument matches the shared dataloader cell.
# NOTE(review): make_train_val_dataloaders is called again here, so comparing against the other
# runs assumes it produces the same train/val split on every call — confirm.
main_bundle_window_32, holdout_bundle_window_32, bins_window_32 = make_train_val_dataloaders(
    episodes=sampled_episodes,
    holdout_game_dirs=holdout_game_dirs,
    train_frac=0.8,
    timestep_window_size=32, 
)
# Basic experiment with Encoder.Patch on the window-size-32 dataloaders.
# Relies on run_experiment_basic and Encoder imported in the CNN cell.
_ = run_experiment_basic(
    "Window Size 32",
    main_bundle_window_32,
    holdout_bundle_window_32,
    bins_window_32,
    base_dir.joinpath("window_size_32"),
    encoder_type=Encoder.Patch,
)

# Patch

In [None]:
# This is the same configuration as the baseline; it is kept as a fallback in case something goes wrong with Kenny's run.
# It is placed last among the experiments in case it takes more than a day to finish.
# The imports repeat those in the CNN cell (presumably so this cell can be run on its own).
from experiment_basic import run_experiment_basic
from mgdt_model_trainer import Encoder

# Basic experiment on the shared (window size 4) dataloaders with Encoder.Patch selected.
_ = run_experiment_basic(
    "Patch",
    main_bundle,
    holdout_bundle,
    bins,
    base_dir.joinpath("patch"),
    encoder_type=Encoder.Patch,
)

# Comparison 
## *Keep this at bottom of notebook and add new experiments to it*

In [None]:
from utils import load_checkpoint
from mgdt_model_stats import ExperimentData

def load_experiment_data(name: str, output_dir: Path) -> ExperimentData:
    """Build an ``ExperimentData`` record from the checkpoint found via ``output_dir``.

    Args:
        name: Label stored on the returned record.
        output_dir: Location handed to ``load_checkpoint``.

    Returns:
        An ``ExperimentData`` holding ``name`` together with the checkpoint's
        main/holdout train and validation stats.
    """
    ckpt = load_checkpoint(output_dir)
    # Attributes copied one-to-one from the checkpoint onto the record, in this order.
    stat_fields = (
        "main_train_stats",
        "main_val_stats",
        "holdout_train_stats",
        "holdout_val_stats",
    )
    stats = {field: getattr(ckpt, field) for field in stat_fields}
    return ExperimentData(name=name, **stats)

In [None]:
from mgdt_model_stats import experiment_comparison

# One (display name, subdirectory of base_dir) pair per experiment to compare.
# Add a pair here whenever a new experiment is added to the notebook.
experiments = [
    load_experiment_data(label, base_dir / subdir)
    for label, subdir in (
        ("Freeze Transformer", "freeze_transformer"),
        ("Freeze Obs Encoder", "freeze_obs_encoder"),
        ("CNN", "cnn"),
        ("Window Size 8", "window_size_8"),
        ("Window Size 16", "window_size_16"),
        ("Window Size 32", "window_size_32"),
        ("Patch", "patch"),
    )
]

experiment_comparison(experiments, output_dir=base_dir / "experiment_comparison")