In [1]:
import json
import kaggle_environments
import numpy as np
from pathlib import Path
import requests
import torch

# Local imports
import vectorized_env as ve

Loading environment football failed: No module named 'gfootball'


In [5]:
def batch_convert_replays_to_s_a_r_d_s(replays, reward_type, obs_type, normalize_reward):
    """Re-simulate several recorded 'mab' replays at once and collect transitions.

    Every replay is loaded into a ``kaggle_environments`` 'mab' environment so the
    recorded per-step actions and thresholds can be read back. The recorded actions
    are then replayed, step by step, through one ``ve.KaggleMABEnvTorchVectorized``
    built with ``n_envs=len(replays)`` whose ``orig_thresholds`` are overwritten with
    the thresholds recorded at step 0. After each simulated step the simulator's
    thresholds are asserted to match the recorded ones.

    Args:
        replays: Non-empty sized sequence of replay dicts, each providing the
            'configuration', 'steps' and 'info' entries. All replays must contain
            the same number of steps, and at least two (the initial step plus one
            step with actions).
        reward_type: Forwarded unchanged to ``ve.KaggleMABEnvTorchVectorized``.
        obs_type: Forwarded unchanged to ``ve.KaggleMABEnvTorchVectorized``.
        normalize_reward: Forwarded unchanged to ``ve.KaggleMABEnvTorchVectorized``.

    Returns:
        A tuple ``(s, a, r, d, next_s)`` of tensors. Per-step results are
        concatenated along the first dimension; ``s`` and ``next_s`` are reshaped to
        ``(-1, *obs.shape[-2:])`` while ``a``, ``r`` and ``d`` are flattened to 1-D.
        ``d`` holds ones for entries of a step on which the simulator reported
        ``done`` and zeros otherwise.

    Raises:
        ValueError: If ``replays`` is empty or the replays contain fewer than two
            steps.
        AssertionError: If the replays differ in length, or if the simulated
            thresholds diverge from the recorded ones.
    """
    # Fail early with a clear message instead of an IndexError further down.
    if len(replays) == 0:
        raise ValueError('replays must contain at least one replay')

    kaggle_envs = [
        kaggle_environments.make(
            'mab',
            configuration=replay['configuration'],
            steps=replay['steps'],
            info=replay['info']
        )
        for replay in replays
    ]

    step_counts = [len(ke.steps) for ke in kaggle_envs]
    kaggle_n_steps = step_counts[0]
    assert all(n == kaggle_n_steps for n in step_counts), \
        'all replays must have the same number of steps'
    if kaggle_n_steps < 2:
        raise ValueError('replays must contain at least two steps to yield a transition')

    sim_env = ve.KaggleMABEnvTorchVectorized(
        n_envs=len(replays),
        reward_type=reward_type,
        obs_type=obs_type,
        normalize_reward=normalize_reward,
        env_device=torch.device('cpu'),
        out_device=torch.device('cpu')
    )

    # Thresholds are read from the first agent's observation of each step.
    initial_thresholds = torch.stack([
        torch.tensor(ke.steps[0][0]['observation']['thresholds']) for ke in kaggle_envs
    ]).float()

    # Step 0 only supplies the initial thresholds; its actions are never replayed,
    # so action/threshold collection starts at step 1.
    actions = []
    thresholds = []
    for step in range(1, kaggle_n_steps):
        actions.append(torch.stack([
            torch.tensor([agent['action'] for agent in ke.steps[step]]) for ke in kaggle_envs
        ]))
        thresholds.append(torch.stack([
            torch.tensor(ke.steps[step][0]['observation']['thresholds']) for ke in kaggle_envs
        ]).float())

    s_batch = []
    a_batch = []
    r_batch = []
    d_batch = []
    next_s_batch = []
    sim_env.reset()
    # Replace the freshly sampled thresholds with the recorded ones so that the
    # simulation reproduces the replay.
    sim_env.orig_thresholds = initial_thresholds
    s = sim_env.obs
    for i, (a, recorded_thresholds) in enumerate(zip(actions, thresholds)):
        # Snapshot the pre-step observation before stepping, in case the simulator
        # updates a previously returned observation tensor in place.
        s_batch.append(s.clone())
        next_s, r, done, _ = sim_env.step(a)
        a_batch.append(a.clone())
        r_batch.append(r.clone())
        d_batch.append(torch.ones(r.shape) if done else torch.zeros(r.shape))
        next_s_batch.append(next_s.clone())
        s = next_s
        assert torch.allclose(sim_env.thresholds.view(-1), recorded_thresholds.view(-1)), f'ERROR: {i}'

    return (torch.cat(s_batch).view(-1, *s.shape[-2:]),
            torch.cat(a_batch).view(-1),
            torch.cat(r_batch).view(-1),
            torch.cat(d_batch).view(-1),
            torch.cat(next_s_batch).view(-1, *s.shape[-2:]))

In [9]:
# Load a bounded sample of recorded episodes and convert them in one batch.
# NOTE: REPLAY_DIR is a machine-specific location; point it at your own episode dump.
REPLAY_DIR = Path('/home/isaiah/Downloads/episodes/')
MAX_REPLAYS = 100

# Directory listings come back in no guaranteed order, so sort before truncating
# to make sure the same files are selected on every run.
replays = []
for replay_filename in sorted(REPLAY_DIR.glob('*.json'))[:MAX_REPLAYS]:
    with open(replay_filename, 'r') as f:
        replays.append(json.load(f))

s_batch, a_batch, r_batch, d_batch, next_s_batch = batch_convert_replays_to_s_a_r_d_s(
    replays,
    reward_type=ve.EVERY_STEP_EV_ZEROSUM,
    obs_type=ve.SUMMED_OBS,
    normalize_reward=False,
)
s_batch.shape, a_batch.shape, r_batch.shape, d_batch.shape, next_s_batch.shape

(torch.Size([399800, 100, 3]),
 torch.Size([399800]),
 torch.Size([399800]),
 torch.Size([399800]),
 torch.Size([399800, 100, 3]))

## Convert one replay at a time

In [31]:
def convert_replay_to_s_a_r_d_s(replay, reward_type, obs_type, normalize_reward):
    """Re-simulate one recorded 'mab' replay and collect its transitions.

    The replay is loaded into a ``kaggle_environments`` 'mab' environment so the
    recorded per-step actions and thresholds can be read back. The recorded actions
    are then replayed, step by step, through a ``ve.KaggleMABEnvTorchVectorized``
    whose ``orig_thresholds`` are overwritten with the thresholds recorded at
    step 0. After each simulated step the simulator's thresholds are asserted to
    match the recorded ones.

    Args:
        replay: Replay dict providing the 'configuration', 'steps' and 'info'
            entries. 'steps' must hold at least two steps (the initial step plus
            one step with actions).
        reward_type: Forwarded unchanged to ``ve.KaggleMABEnvTorchVectorized``.
        obs_type: Forwarded unchanged to ``ve.KaggleMABEnvTorchVectorized``.
        normalize_reward: Forwarded unchanged to ``ve.KaggleMABEnvTorchVectorized``.

    Returns:
        A tuple ``(s, a, r, d, next_s)`` of tensors. Per-step results are
        concatenated along the first dimension; ``s`` and ``next_s`` are reshaped to
        ``(-1, *obs.shape[-2:])`` while ``a``, ``r`` and ``d`` are flattened to 1-D.
        ``d`` holds ones for entries of a step on which the simulator reported
        ``done`` and zeros otherwise.

    Raises:
        ValueError: If the replay contains fewer than two steps.
        AssertionError: If the simulated thresholds diverge from the recorded ones.
    """
    # Without at least one post-initial step there is nothing to concatenate, which
    # would otherwise surface as an opaque indexing or torch.cat error below.
    if len(replay['steps']) < 2:
        raise ValueError('replay must contain at least two steps to yield a transition')

    kaggle_env = kaggle_environments.make(
        'mab',
        configuration=replay['configuration'],
        steps=replay['steps'],
        info=replay['info']
    )
    sim_env = ve.KaggleMABEnvTorchVectorized(
        reward_type=reward_type,
        obs_type=obs_type,
        normalize_reward=normalize_reward,
        env_device=torch.device('cpu'),
        out_device=torch.device('cpu')
    )

    # Step 0 only supplies the initial thresholds, so replay data starts at step 1.
    # Thresholds are read from the first agent's observation of each step.
    actions = []
    thresholds = []
    for step_info in kaggle_env.steps[1:]:
        actions.append(torch.tensor([agent['action'] for agent in step_info]))
        thresholds.append(torch.tensor(step_info[0]['observation']['thresholds']).float())

    s_batch = []
    a_batch = []
    r_batch = []
    d_batch = []
    next_s_batch = []
    sim_env.reset()
    # Replace the freshly sampled thresholds with the recorded ones, keeping the
    # dtype and shape the simulator allocated during reset().
    sim_env.orig_thresholds = torch.tensor(
        kaggle_env.steps[0][0]['observation']['thresholds'],
        dtype=sim_env.orig_thresholds.dtype
    ).view(sim_env.orig_thresholds.shape)
    s = sim_env.obs
    for i, (a, recorded_thresholds) in enumerate(zip(actions, thresholds)):
        # Add a leading dimension of size 1 before handing the actions to the simulator.
        a = a.unsqueeze(0)
        # Snapshot the pre-step observation before stepping, in case the simulator
        # updates a previously returned observation tensor in place.
        s_batch.append(s.clone())
        next_s, r, done, _ = sim_env.step(a)
        a_batch.append(a.clone())
        r_batch.append(r.clone())
        d_batch.append(torch.ones(r.shape) if done else torch.zeros(r.shape))
        next_s_batch.append(next_s.clone())
        s = next_s
        assert torch.allclose(sim_env.thresholds.view(-1), recorded_thresholds.view(-1)), f'ERROR: {i}'

    return (torch.cat(s_batch).view(-1, *s.shape[-2:]),
            torch.cat(a_batch).view(-1),
            torch.cat(r_batch).view(-1),
            torch.cat(d_batch).view(-1),
            torch.cat(next_s_batch).view(-1, *s.shape[-2:]))