In [1]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt

### Calculate normalized score

In [2]:
import pandas as pd

def calculate_normalized_score(raw_scores, score_random, score_expert):
    """Rescale a raw score so that random play maps to 0 and expert play to 100.

    Args:
        raw_scores: Raw score(s) to normalize.
        score_random: Reference score that should map to 0.
        score_expert: Reference score that should map to 100.

    Returns:
        ``100 * (raw_scores - score_random) / (score_expert - score_random)``.
    """
    gain_over_random = raw_scores - score_random
    expert_gap = score_expert - score_random
    fraction_of_gap = gain_over_random / expert_gap
    return 100 * fraction_of_gap

def process_experiment_results(base_path, experiment_prefix, game, seeds=("123", "132", "321")):
    """Aggregate the best-epoch normalized score of one experiment across seeds.

    For every seed, ``{base_path}{experiment_prefix}{seed}/summary.csv`` is
    loaded, the highest ``evaluation/eval_return`` value in it is selected and
    normalized with ``calculate_normalized_score`` using the random and expert
    reference scores of ``game``.

    Args:
        base_path: Directory prefix of the experiment outputs. It is
            concatenated verbatim, so it must end with a path separator.
        experiment_prefix: Run-directory name without the seed suffix.
        game: Key into the reference-score tables below (e.g. ``'Hero'``).
        seeds: Seed suffixes to aggregate over. Defaults to the three seeds
            this function has always used.

    Returns:
        A ``pd.Series`` indexed by ``'mean'`` and ``'std'`` computed over the
        per-seed normalized scores, or ``None`` when no seed had any data.

    Raises:
        KeyError: If ``game`` has no entry in the reference-score tables.
        FileNotFoundError: If a seed's ``summary.csv`` does not exist
            (propagated from ``pd.read_csv``).
    """
    # Game specific random and expert scores
    score_random = {
        'Breakout': 1.7, 'Qbert': 163.9, 'Pong': -20.7, 'Seaquest': 68.4,
        'Hero': 1027.0, 'KungFuMaster': 258.5, 'Alien': 227.8, 'RoadRunner': 11.5,
        'BattleZone': 2360.0, 'BankHeist': 14.0, 'FishingDerby': -92.0, 'Zaxxon': 32.0,
        'MsPacman': 307.3, 'SpaceInvaders': 148,
    }
    score_expert = {
        'Breakout': 30.5, 'Qbert': 13455.0, 'Pong': 14.6, 'Seaquest': 42054.7,
        'Hero': 30826.4, 'KungFuMaster': 22736.3, 'Alien': 7127.7, 'RoadRunner': 7845.0,
        'BattleZone': 37187.5, 'BankHeist': 753.0, 'FishingDerby': -39.0, 'Zaxxon': 9173.0,
        'MsPacman': 6951.6, 'SpaceInvaders': 1669.0,
    }

    # Loop-invariant lookups: fail fast on an unknown game before any file I/O.
    random_reference = score_random[game]
    expert_reference = score_expert[game]

    return_column = 'evaluation/eval_return'
    normalized_scores = []

    for seed in seeds:
        file_path = f"{base_path}{experiment_prefix}{seed}/summary.csv"

        # Parse each summary exactly once. A zero-byte file makes read_csv
        # raise EmptyDataError; treat it like a header-only file (no data).
        try:
            data = pd.read_csv(file_path)
        except pd.errors.EmptyDataError:
            data = pd.DataFrame()

        if data.empty:
            print(f"Warning: No data in {file_path}")
            continue

        # Use the best epoch, i.e. the row with the highest evaluation return.
        best_epoch = data[return_column].idxmax()
        raw_score = data.loc[best_epoch, return_column]

        # Alternative kept for reference: average over all epochs instead.
        # raw_score = data[return_column].mean()

        # Calculate normalized score
        normalized_scores.append(
            calculate_normalized_score(raw_score, random_reference, expert_reference)
        )

    if not normalized_scores:
        return None
    return pd.Series(normalized_scores).agg(['mean', 'std'])


### Action fusion

In [3]:
# Normalized-score summary for the runs stored under `atari_10_new`.
base_path = "~/msc-project/atari/output/atari_10_new/"

# (run-directory prefix, game key) pairs to report.
experiments = [
    ("dmamba_hero", "Hero"),
    ("dtrans_hero", "Hero"),
    ("dmamba_kungfumaster", "KungFuMaster"),
    ("dtrans_kungfumaster", "KungFuMaster"),
]

for experiment_prefix, game in experiments:
    results = process_experiment_results(base_path, experiment_prefix, game)
    model_name = "Decision Mamba" if "dmamba" in experiment_prefix else "Decision Transformer"
    print(f"Game: {game} \nModel: {model_name}")
    if results is None:
        # process_experiment_results returns None when no seed had any data;
        # calling .round() on it would raise AttributeError.
        print("No results available")
    else:
        print(results.round(2))  # Round to 2 decimal places
    print()

Game: Hero 
Model: Decision Mamba
mean    7.26
std     1.21
dtype: float64

Game: Hero 
Model: Decision Transformer
mean    16.06
std      0.74
dtype: float64

Game: KungFuMaster 
Model: Decision Mamba
mean    2.79
std     0.33
dtype: float64

Game: KungFuMaster 
Model: Decision Transformer
mean    1.55
std     0.86
dtype: float64



In [5]:
# Normalized-score summary for the runs stored under `atari_10_1p_reverse`.
base_path = "~/msc-project/atari/output/atari_10_1p_reverse/"

# (run-directory prefix, game key) pairs to report.
experiments = [
    ("dmamba_hero", "Hero"),
    ("dtrans_hero", "Hero"),
]

for experiment_prefix, game in experiments:
    results = process_experiment_results(base_path, experiment_prefix, game)
    model_name = "Decision Mamba" if "dmamba" in experiment_prefix else "Decision Transformer"
    print(f"Game: {game} \nModel: {model_name}")
    if results is None:
        # process_experiment_results returns None when no seed had any data;
        # calling .round() on it would raise AttributeError.
        print("No results available")
    else:
        print(results.round(2))  # Round to 2 decimal places
    print()

Game: Hero 
Model: Decision Mamba
mean    7.26
std     1.21
dtype: float64

Game: Hero 
Model: Decision Transformer
mean    16.06
std      0.74
dtype: float64



In [9]:
# Normalized-score summary for the runs stored under `atari_10`.
base_path = "~/msc-project/atari/output/atari_10/"

# (run-directory prefix, game key) pairs to report; commented-out entries are
# currently excluded from this report.
experiments = [
    # ("dmamba_breakout", "Breakout"),
    # ("dtrans_breakout", "Breakout"),
    # ("dmamba_qbert", "Qbert"),
    # ("dtrans_qbert", "Qbert"),
    # ("dmamba_hero", "Hero"),
    # ("dtrans_hero", "Hero"),
    # ("dmamba_kungfumaster", "KungFuMaster"),
    # ("dtrans_kungfumaster", "KungFuMaster"),
    # ("dmamba_pong", "Pong"),
    # ("dtrans_pong", "Pong"),
    # ("dmamba_seaquest", "Seaquest"),
    # ("dtrans_seaquest", "Seaquest"),
    ("dmamba_alien", "Alien"),
    ("dtrans_alien", "Alien"),
    ("dmamba_roadrunner", "RoadRunner"),
    ("dtrans_roadrunner", "RoadRunner"),
    ("dmamba_battlezone", "BattleZone"),
    ("dtrans_battlezone", "BattleZone"),
    ("dmamba_bankheist", "BankHeist"),
    ("dtrans_bankheist", "BankHeist"),
]

for experiment_prefix, game in experiments:
    results = process_experiment_results(base_path, experiment_prefix, game)
    model_name = "Decision Mamba" if "dmamba" in experiment_prefix else "Decision Transformer"
    print(f"Game: {game} \nModel: {model_name}")
    if results is None:
        # process_experiment_results returns None when no seed had any data;
        # calling .round() on it would raise AttributeError.
        print("No results available")
    else:
        print(results.round(2))  # Round to 2 decimal places
    print()


Game: Alien 
Model: Decision Mamba
mean    11.23
std      2.78
dtype: float64

Game: Alien 
Model: Decision Transformer
mean    13.29
std      0.00
dtype: float64

Game: RoadRunner 
Model: Decision Mamba
mean    26.28
std      9.93
dtype: float64

Game: RoadRunner 
Model: Decision Transformer
mean    20.53
std      3.79
dtype: float64

Game: BattleZone 
Model: Decision Mamba
mean    8.01
std     1.02
dtype: float64

Game: BattleZone 
Model: Decision Transformer
mean    11.89
std      4.87
dtype: float64

Game: BankHeist 
Model: Decision Mamba
mean   -0.41
std     0.19
dtype: float64

Game: BankHeist 
Model: Decision Transformer
mean    0.34
std     0.10
dtype: float64



In [32]:
# Normalized-score summary for the runs stored under `atari_30`.
base_path = "~/msc-project/atari/output/atari_30/"

# (run-directory prefix, game key) pairs to report; commented-out entries are
# currently excluded from this report.
experiments = [
    ("dmamba_breakout", "Breakout"),
    ("dtrans_breakout", "Breakout"),
    ("dmamba_qbert", "Qbert"),
    ("dtrans_qbert", "Qbert"),
    ("dmamba_hero", "Hero"),
    ("dtrans_hero", "Hero"),
    ("dmamba_kungfumaster", "KungFuMaster"),
    ("dtrans_kungfumaster", "KungFuMaster"),
    # ("dmamba_pong", "Pong"),
    # ("dtrans_pong", "Pong"),
    # ("dmamba_seaquest", "Seaquest"),
    # ("dtrans_seaquest", "Seaquest"),
]

for experiment_prefix, game in experiments:
    results = process_experiment_results(base_path, experiment_prefix, game)
    model_name = "Decision Mamba" if "dmamba" in experiment_prefix else "Decision Transformer"
    print(f"Game: {game} \nModel: {model_name}")
    if results is None:
        # process_experiment_results returns None when no seed had any data;
        # calling .round() on it would raise AttributeError.
        print("No results available")
    else:
        print(results.round(2))  # Round to 2 decimal places
    print()


Game: Breakout 
Model: Decision Mamba
mean    390.16
std      50.43
dtype: float64

Game: Breakout 
Model: Decision Transformer
mean    274.65
std      70.69
dtype: float64

Game: Qbert 
Model: Decision Mamba
mean    23.07
std     10.61
dtype: float64

Game: Qbert 
Model: Decision Transformer
mean    12.05
std     11.95
dtype: float64

Game: Hero 
Model: Decision Mamba
mean    6.84
std     0.28
dtype: float64

Game: Hero 
Model: Decision Transformer
mean    30.12
std      4.24
dtype: float64

Game: KungFuMaster 
Model: Decision Mamba
mean    7.42
std     0.69
dtype: float64

Game: KungFuMaster 
Model: Decision Transformer
mean    10.49
std      6.49
dtype: float64



In [33]:
# Normalized-score summary for the runs stored under `atari_50`.
base_path = "~/msc-project/atari/output/atari_50/"

# (run-directory prefix, game key) pairs to report; commented-out entries are
# currently excluded from this report.
experiments = [
    ("dmamba_breakout", "Breakout"),
    ("dtrans_breakout", "Breakout"),
    ("dmamba_qbert", "Qbert"),
    ("dtrans_qbert", "Qbert"),
    ("dmamba_hero", "Hero"),
    ("dtrans_hero", "Hero"),
    ("dmamba_kungfumaster", "KungFuMaster"),
    ("dtrans_kungfumaster", "KungFuMaster"),
    # ("dmamba_pong", "Pong"),
    # ("dtrans_pong", "Pong"),
    # ("dmamba_seaquest", "Seaquest"),
    # ("dtrans_seaquest", "Seaquest"),
]

for experiment_prefix, game in experiments:
    results = process_experiment_results(base_path, experiment_prefix, game)
    model_name = "Decision Mamba" if "dmamba" in experiment_prefix else "Decision Transformer"
    print(f"Game: {game} \nModel: {model_name}")
    if results is None:
        # process_experiment_results returns None when no seed had any data;
        # calling .round() on it would raise AttributeError.
        print("No results available")
    else:
        print(results.round(2))  # Round to 2 decimal places
    print()


Game: Breakout 
Model: Decision Mamba
mean    343.63
std     106.93
dtype: float64

Game: Breakout 
Model: Decision Transformer
mean    206.48
std      33.53
dtype: float64

Game: Qbert 
Model: Decision Mamba
mean    22.73
std      2.31
dtype: float64

Game: Qbert 
Model: Decision Transformer
mean    11.48
std      6.11
dtype: float64

Game: Hero 
Model: Decision Mamba
mean    7.58
std     1.58
dtype: float64

Game: Hero 
Model: Decision Transformer
mean    29.31
std      8.40
dtype: float64

Game: KungFuMaster 
Model: Decision Mamba
mean    5.15
std     3.51
dtype: float64

Game: KungFuMaster 
Model: Decision Transformer
mean    11.77
std     10.34
dtype: float64



In [34]:
# Normalized-score summary for the runs stored under `atari_10_fused`.
base_path = "~/msc-project/atari/output/atari_10_fused/"

# (run-directory prefix, game key) pairs to report.
experiments = [
    ("dmamba_hero", "Hero"),
    ("dtrans_hero", "Hero"),
    ("dmamba_kungfumaster", "KungFuMaster"),
    ("dtrans_kungfumaster", "KungFuMaster"),
]

for experiment_prefix, game in experiments:
    results = process_experiment_results(base_path, experiment_prefix, game)
    model_name = "Decision Mamba" if "dmamba" in experiment_prefix else "Decision Transformer"
    print(f"Game: {game} \nModel: {model_name}")
    if results is None:
        # process_experiment_results returns None when no seed had any data;
        # calling .round() on it would raise AttributeError.
        print("No results available")
    else:
        print(results.round(2))  # Round to 2 decimal places
    print()


Game: Hero 
Model: Decision Mamba
mean    7.94
std     1.34
dtype: float64

Game: Hero 
Model: Decision Transformer
mean    26.80
std      3.94
dtype: float64

Game: KungFuMaster 
Model: Decision Mamba
mean    2.35
std     1.51
dtype: float64

Game: KungFuMaster 
Model: Decision Transformer
mean    4.68
std     0.60
dtype: float64



In [35]:
# Normalized-score summary for the runs stored under `atari_10_simplest`.
base_path = "~/msc-project/atari/output/atari_10_simplest/"

# (run-directory prefix, game key) pairs to report.
experiments = [
    ("dmamba_hero", "Hero"),
    ("dtrans_hero", "Hero"),
    ("dmamba_kungfumaster", "KungFuMaster"),
    ("dtrans_kungfumaster", "KungFuMaster"),
]

for experiment_prefix, game in experiments:
    results = process_experiment_results(base_path, experiment_prefix, game)
    model_name = "Decision Mamba" if "dmamba" in experiment_prefix else "Decision Transformer"
    print(f"Game: {game} \nModel: {model_name}")
    if results is None:
        # process_experiment_results returns None when no seed had any data;
        # calling .round() on it would raise AttributeError.
        print("No results available")
    else:
        print(results.round(2))  # Round to 2 decimal places
    print()


Game: Hero 
Model: Decision Mamba
mean   -0.28
std     0.40
dtype: float64

Game: Hero 
Model: Decision Transformer
mean   -0.61
std     0.18
dtype: float64

Game: KungFuMaster 
Model: Decision Mamba
mean    5.54
std     1.45
dtype: float64

Game: KungFuMaster 
Model: Decision Transformer
mean    5.55
std     1.68
dtype: float64

