In [1]:
import gym

import numpy as np
import numpy.random as rnd

import matplotlib.pyplot as plt

import sklearn
import sklearn.cluster

In [5]:
def permutations(elems, T):
    def recurse(xs):
        if len(xs[0]) >= T:
            return xs
        else:
            return recurse([x+[e] for x in xs for e in elems]) # + xs
            
    return recurse([[e] for e in elems])

class RndOptionPolicy():
    def __init__(self, n_actions, n_time_steps):
        self.options = np.array(permutations(range(n_actions), n_time_steps))
        
    def __call__(self, x):
        rnd_idx = rnd.choice(np.arange(self.options.shape[0]))
        return self.options[rnd_idx]
     
class OptionEnvWrapper():
    def __init__(self, env):
        self.env = env
        
    def step(self, actions):
        R = 0
        for a in actions:
            s, r, done, info = self.env.step(a)
            R += r
        return s, R, done, info

    def reset(self):
        return self.env.reset()

In [13]:
env = gym.make('Acrobot-v1')
rnd_policy = lambda obs: env.action_space.sample()
op = RndOptionPolicy(env.action_space.n, 12)
env = OptionEnvWrapper(env)
len(op.options), op(2)

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


(531441, array([0, 0, 2, 2, 1, 0, 1, 0, 2, 0, 2, 1]))

In [20]:
def play_episode(env, player, T=5):
    # reset
    s = env.reset()
    R = 0
    done = False
    pairs = []
    
    # play an episode
    while not done:

        a = player(s)
        new_s, r, done, info = env.step(a)
        R += r
        
        pairs.append((s, a, new_s))
        s = new_s
            
    return pairs

In [22]:
play_episode(env, op)

[(array([ 0.99998542, -0.00540047,  0.99940696,  0.0344345 ,  0.07235695,
         -0.00344042]),
  array([1, 1, 1, 0, 0, 2, 2, 2, 2, 0, 2, 2]),
  array([ 0.98904334, -0.14762542,  0.93069285,  0.36580162,  0.28681797,
         -0.2624945 ])),
 (array([ 0.98904334, -0.14762542,  0.93069285,  0.36580162,  0.28681797,
         -0.2624945 ]),
  array([2, 2, 2, 0, 1, 0, 2, 0, 2, 2, 1, 0]),
  array([ 0.96781963, -0.25164493,  0.97168708,  0.23627147,  0.12950684,
         -0.14954337])),
 (array([ 0.96781963, -0.25164493,  0.97168708,  0.23627147,  0.12950684,
         -0.14954337]),
  array([1, 2, 0, 2, 0, 1, 2, 0, 0, 2, 0, 1]),
  array([ 0.99051499, -0.13740472,  0.98697457,  0.16087632,  0.14798114,
         -0.75304686])),
 (array([ 0.99051499, -0.13740472,  0.98697457,  0.16087632,  0.14798114,
         -0.75304686]),
  array([2, 1, 0, 0, 0, 2, 1, 0, 1, 2, 0, 2]),
  array([ 0.98780209, -0.15571458,  0.87502773,  0.4840728 ,  0.18868197,
         -0.08509837])),
 (array([ 0.98780209, -0

In [27]:
def get_pairs():
    pairs = play_episode(env,op)
    pairs = tuple(zip(*pairs))
    return tuple([np.vstack(p) for p in pairs])

In [28]:
def get_n(n):
    pairs = [get_pairs() for i in range(n)]
    pairs = tuple(zip(*pairs))
    return tuple([np.vstack(p) for p in pairs])

In [30]:
get_n(3)[0].shape, get_n(3)[1].shape, get_n(3)[2].shape

((126, 6), (126, 12), (126, 6))

In [None]:
kmeans = sklearn.cluster.MiniBatchKMeans(n_clusters=32)

In [None]:
for _ in range(20):
    states = []
    p, s = get_n(50)
    kmeans.partial_fit(s)

In [None]:
kmeans.cluster_centers_.shape

In [None]:
plt.hist(kmeans.predict(s))