# Prepare Atari data for video model training.

In [4]:
import numpy as np
import os
import tqdm
import random
import shutil

dataset = 'dmc'

In [6]:
if dataset == 'dmc':
    DATA_DIR = '/shared/ale/datasets/distill/dmc_unfiltered/'
    SAVE_DIR = '/shared/ale/datasets/distill/VIPER_DATA/dmc/'
    NUM_EPS_COMPUTE_STATS = 2000
    MIN_EP_LEN = 549
    MAX_EPISODES_PER_TASK = 500
    REMOVE_FIRST_N_FRAMES = 50
    FILTER_PERCENTILE = 90
    TRAIN_SPLIT = 0.95

    TASK_MAP = {
        'dm_dmc_acrobot_swingup1678341981': 'acrobot_swingup',
        'dm_dmc_cartpole_balance1678145672': 'cartpole_balance',
        'dm_dmc_cartpole_swingup1678145708': 'cartpole_swingup',
        'dm_dmc_cheetah_run1678145762': 'cheetah_run',
        'dm_dmc_cup_catch1678145481': 'cup_catch',
        'dm_dmc_finger_spin1678145861': 'finger_spin',
        'dm_dmc_finger_turn_hard1678240045': 'finger_turn_hard',
        'dm_dmc_manipulator_bring_ball1678386540': 'manipulator_bring_ball',
        'dm_dmc_hopper_stand1678146024': 'hopper_stand',
        'dm_dmc_pendulum_swingup1678238477': 'pendulum_swingup',
        'dm_dmc_pointmass_easy1678146369': 'pointmass_easy',
        'dm_dmc_pointmass_hard1678341667': 'pointmass_hard',
        'dm_dmc_quadruped_run1678238524': 'quadruped_run',
        'dm_dmc_quadruped_walk1678146538': 'quadruped_walk',
        'dm_dmc_reacher_easy1678238931': 'reacher_easy',
        'dm_dmc_reacher_hard1678238982': 'reacher_hard',
        'dm_dmc_walker_walk1679516828': 'walker_walk'
    }
elif dataset == 'atari':
    DATA_DIR = '/shared/ale/datasets/distill/distill/'
    SAVE_DIR = '/shared/ale/datasets/distill/VIPER_DATA/atari/'
    NUM_EPS_COMPUTE_STATS = 0
    MIN_EP_LEN = 0
    MAX_EPISODES_PER_TASK = 500
    REMOVE_FIRST_N_FRAMES = 0
    FILTER_PERCENTILE = -1
    TRAIN_SPLIT = 0.95

    TASK_MAP = {
        'trajectories_assault_expert': 'assault',
        'trajectories_atari_atlantis_expert': 'atlantis',
        'trajectories_atari_defender_expert': 'defender',
        'trajectories_atari_freeway_expert': 'freeway',
        'trajectories_atari_kangaroo_expert': 'kangaroo',
        'trajectories_boxing_expert': 'boxing',
        'trajectories_pong_expert': 'pong',
        'trajectories_zaxxon_expert': 'zaxxon',
    }

    os.makedirs(SAVE_DIR, exist_ok=True)

In [7]:
for task in TASK_MAP:
    print('Processing task: {}'.format(task))
    episode_dir = os.path.join(DATA_DIR, task, "saved_episodes")
    episode_files = os.listdir(episode_dir)
    random.shuffle(episode_files)
    print(f'\tLoading episodes from {episode_dir}')
    print(f'\tFound {len(episode_files)} episode files')
    save_dir = os.path.join(SAVE_DIR, TASK_MAP[task])

    rewards = []
    for episode in tqdm.tqdm(episode_files[:NUM_EPS_COMPUTE_STATS]):
        with open(os.path.join(episode_dir, episode), "rb") as f:
            data = np.load(f)
            if len(data['reward']) < MIN_EP_LEN:
                continue
            reward = np.sum(data['reward'])
            rewards.append(reward)
    if FILTER_PERCENTILE == -1:
        rew_percentile = -float('inf')
    else:
        rew_percentile = np.percentile(rewards, FILTER_PERCENTILE)
        print(f'\t{FILTER_PERCENTILE}th percentile reward: {rew_percentile}')
    num_episodes_saved = 0
    pbar = tqdm.tqdm(total=MAX_EPISODES_PER_TASK)

    for episode in episode_files:
        if num_episodes_saved > MAX_EPISODES_PER_TASK:
            break
        with open(os.path.join(episode_dir, episode), "rb") as f:
            data = np.load(f)
            if len(data['reward']) < MIN_EP_LEN:
                continue
            reward = np.sum(data['reward'])
            try:
                if reward >= rew_percentile and len(data['reward']) >= MIN_EP_LEN:
                    if not os.path.exists(save_dir):
                        os.makedirs(save_dir)
                    np.savez_compressed(os.path.join(save_dir, episode), data['image'][REMOVE_FIRST_N_FRAMES:])
                    num_episodes_saved += 1
                    pbar.update(1)
            except:
                continue


Processing task: dm_dmc_acrobot_swingup1678341981
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_acrobot_swingup1678341981/saved_episodes
	Found 60812 episode files


100%|██████████| 2000/2000 [01:35<00:00, 20.99it/s]


	90th percentile reward: 730.997802734375


 68%|██████▊   | 338/500 [2:12:50<1:03:40, 23.58s/it]


Processing task: dm_dmc_cartpole_balance1678145672
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_cartpole_balance1678145672/saved_episodes
	Found 8292 episode files


100%|██████████| 2000/2000 [00:14<00:00, 136.22it/s]


	90th percentile reward: 997.0733642578125


501it [04:04,  2.05it/s]00:00<?, ?it/s]
501it [01:25,  3.03it/s]                         

Processing task: dm_dmc_cartpole_swingup1678145708
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_cartpole_swingup1678145708/saved_episodes
	Found 8164 episode files


100%|██████████| 2000/2000 [00:16<00:00, 124.62it/s]


	90th percentile reward: 861.6147705078125


501it [01:41,  4.95it/s]


Processing task: dm_dmc_cheetah_run1678145762
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_cheetah_run1678145762/saved_episodes
	Found 12116 episode files


100%|██████████| 2000/2000 [01:19<00:00, 25.26it/s]


	90th percentile reward: 918.5155883789063


501it [02:50,  2.93it/s]00:00<?, ?it/s]
501it [04:06,  1.14it/s]                         

Processing task: dm_dmc_cup_catch1678145481
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_cup_catch1678145481/saved_episodes
	Found 7012 episode files


100%|██████████| 2000/2000 [00:17<00:00, 115.51it/s]


	90th percentile reward: 974.0


501it [04:24,  1.90it/s]


Processing task: dm_dmc_finger_spin1678145861
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_finger_spin1678145861/saved_episodes
	Found 12028 episode files


100%|██████████| 2000/2000 [01:13<00:00, 27.26it/s]


	90th percentile reward: 983.0


501it [02:40,  3.13it/s]00:00<?, ?it/s]
501it [03:08,  2.17it/s]                         

Processing task: dm_dmc_finger_turn_hard1678240045
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_finger_turn_hard1678240045/saved_episodes
	Found 4768 episode files


100%|██████████| 2000/2000 [00:40<00:00, 49.03it/s]


	90th percentile reward: 986.0


501it [03:49,  2.18it/s]


Processing task: dm_dmc_manipulator_bring_ball1678386540
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_manipulator_bring_ball1678386540/saved_episodes
	Found 23324 episode files


100%|██████████| 2000/2000 [01:14<00:00, 26.83it/s]


	90th percentile reward: 242.0665390014653


501it [03:15,  2.57it/s]00:00<?, ?it/s]
501it [03:08,  1.83it/s]                         

Processing task: dm_dmc_hopper_stand1678146024
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_hopper_stand1678146024/saved_episodes
	Found 8248 episode files


100%|██████████| 2000/2000 [00:35<00:00, 55.86it/s]


	90th percentile reward: 956.372509765625


501it [03:44,  2.23it/s]


Processing task: dm_dmc_pendulum_swingup1678238477
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_pendulum_swingup1678238477/saved_episodes
	Found 5396 episode files


100%|██████████| 2000/2000 [00:02<00:00, 969.73it/s] 


	90th percentile reward: 940.0


501it [02:19,  3.59it/s]00:00<?, ?it/s]
501it [00:53,  8.75it/s]                         

Processing task: dm_dmc_pointmass_easy1678146369
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_pointmass_easy1678146369/saved_episodes
	Found 6760 episode files


100%|██████████| 2000/2000 [00:02<00:00, 857.69it/s]


	90th percentile reward: 937.4158935546875


501it [00:55,  9.01it/s]


Processing task: dm_dmc_pointmass_hard1678341667
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_pointmass_hard1678341667/saved_episodes
	Found 8432 episode files


100%|██████████| 2000/2000 [00:31<00:00, 63.59it/s] 


	90th percentile reward: 898.4111328125


501it [01:19,  6.33it/s]00:00<?, ?it/s]
501it [01:36,  3.17it/s]                         

Processing task: dm_dmc_quadruped_run1678238524
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_quadruped_run1678238524/saved_episodes
	Found 9816 episode files


100%|██████████| 2000/2000 [02:12<00:00, 15.05it/s]


	90th percentile reward: 950.9744995117187


501it [03:49,  2.18it/s]


Processing task: dm_dmc_quadruped_walk1678146538
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_quadruped_walk1678146538/saved_episodes
	Found 10916 episode files


100%|██████████| 2000/2000 [02:11<00:00, 15.22it/s]


	90th percentile reward: 974.7005432128907


501it [11:33,  1.38s/it]00:00<?, ?it/s]
501it [07:29,  1.24it/s]                         

Processing task: dm_dmc_reacher_easy1678238931
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_reacher_easy1678238931/saved_episodes
	Found 4812 episode files


100%|██████████| 2000/2000 [00:23<00:00, 83.44it/s] 


	90th percentile reward: 990.0


501it [07:53,  1.06it/s]


Processing task: dm_dmc_reacher_hard1678238982
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_reacher_hard1678238982/saved_episodes
	Found 4820 episode files


100%|██████████| 2000/2000 [00:17<00:00, 111.49it/s]


	90th percentile reward: 981.0


501it [01:38,  5.10it/s]00:00<?, ?it/s]
 94%|█████████▍| 469/500 [01:14<00:07,  3.93it/s]

Processing task: dm_dmc_walker_walk1679516828
	Loading episodes from /shared/ale/datasets/distill/dmc_unfiltered/dm_dmc_walker_walk1679516828/saved_episodes
	Found 523 episode files


100%|██████████| 523/523 [00:25<00:00, 20.90it/s]


	90th percentile reward: 750.7719116210938


 94%|█████████▍| 469/500 [01:39<00:06,  4.71it/s]


In [None]:
for task in TASK_MAP:
    episode_dir = os.path.join(SAVE_DIR, TASK_MAP[task])
    episode_files = os.listdir(episode_dir)
    random.shuffle(episode_files)
    train_files = episode_files[:int(TRAIN_SPLIT * len(episode_files))]
    test_files = episode_files[int(TRAIN_SPLIT * len(episode_files)):]

    train_save_dir = os.path.join(SAVE_DIR, TASK_MAP[task], 'train')
    test_save_dir = os.path.join(SAVE_DIR, TASK_MAP[task], 'test')
    os.makedirs(train_save_dir, exist_ok=True)
    os.makedirs(test_save_dir, exist_ok=True)

    for train_file in train_files:
        shutil.move(os.path.join(episode_dir, train_file), os.path.join(train_save_dir, train_file))
    
    for test_file in test_files:
        shutil.move(os.path.join(episode_dir, test_file), os.path.join(test_save_dir, test_file))