In [114]:
import random
import time

# Imports

import pandas as pd
import numpy as np
import random
import seaborn as sns
import matplotlib.pyplot as plt

sns.set()

In [3]:
# Importing gym and its other stuff
import gym
from gym import logger as gymlogger
gymlogger.set_level(40) # err only

env = gym.make("CartPole-v1")

In [95]:
### LOGGING OBJECT

class Logger:
    def __init__(self, filename='./data.csv'):
        self.filename = filename

    def __parseStats(self, stats) -> str:
        print(stats)
        obsv = stats[0]
        reward = stats[1]
        done = stats[2]
        info = stats[3]

        obsvCSV = ','.join([x.astype(str) for x in obsv])
        return f"{obsvCSV},{reward},{done},{info}"

    def log(self, stats, action, totalReward) -> None:
        # open file
        self.__parseStats(stats=stats)
        with open(self.filename, 'a') as file:
            file.write(f"{self.__parseStats(stats=stats)},{totalReward},{action}\n")
            pass


In [124]:
### Agent OBJECT

class Agent:
    def __init__(self, environment, logger:Logger) -> None:
        self.env = environment
        self.logger = logger
        self.isRunning = True
        self.curReward = 0

    def __reset(self):
        startState = self.env.reset()
        self.isRunning = True
        self.curReward = 0
        return startState

    def policy(self) -> int:
        return random.randint(0,1)

    def run(self) -> None:
        self.__reset()
        while self.isRunning:
            action = self.policy() # determining action
            stats = self.env.step(action) # execute action

            # update state
            self.isRunning = stats[2]==False
            self.curReward += stats[1] if self.isRunning else 0

            # log information etc
            self.logger.log(stats=stats, action=action, totalReward=self.curReward)

In [126]:
for _ in range(2):
    agt = Agent(gym.make("CartPole-v1"), Logger())
    agt.run()

(array([-0.01118731,  0.1855544 ,  0.03622197, -0.3023817 ], dtype=float32), 1.0, False, False, {})
(array([-0.01118731,  0.1855544 ,  0.03622197, -0.3023817 ], dtype=float32), 1.0, False, False, {})
(array([-0.00747623,  0.38014188,  0.03017434, -0.5834245 ], dtype=float32), 1.0, False, False, {})
(array([-0.00747623,  0.38014188,  0.03017434, -0.5834245 ], dtype=float32), 1.0, False, False, {})
(array([ 1.2661106e-04,  5.7482839e-01,  1.8505851e-02, -8.6645144e-01],
      dtype=float32), 1.0, False, False, {})
(array([ 1.2661106e-04,  5.7482839e-01,  1.8505851e-02, -8.6645144e-01],
      dtype=float32), 1.0, False, False, {})
(array([ 0.01162318,  0.37945956,  0.00117682, -0.56800795], dtype=float32), 1.0, False, False, {})
(array([ 0.01162318,  0.37945956,  0.00117682, -0.56800795], dtype=float32), 1.0, False, False, {})
(array([ 0.01921237,  0.574565  , -0.01018334, -0.8603199 ], dtype=float32), 1.0, False, False, {})
(array([ 0.01921237,  0.574565  , -0.01018334, -0.8603199 ], dty

# Analysing data

Q-learning requires discrete states from the environment, however, the Cartpole problem provides states that are continuous.


We can approach this by creating discrete steps in the environment. While this can be done through trial & error by making intuitive guesses on the different states, a more calculative approach will be to use machine learning to create clusters ("buckets") to be used as our steps.

In [128]:
### Importing data

rawData = pd.read_csv('./data.csv')
rawData

Unnamed: 0,cartPos,cartVel,poleAngle,poleVel,reward,done,idkwhatthisis,totalReward,action
0,0.024236,0.191275,0.003320,-0.287982,1.0,False,False,1.0,1
1,0.028062,-0.003895,-0.002439,0.005746,1.0,False,False,2.0,0
2,0.027984,0.191262,-0.002324,-0.287706,1.0,False,False,3.0,1
3,0.031809,-0.003826,-0.008079,0.004243,1.0,False,False,4.0,0
4,0.031733,0.191410,-0.007994,-0.290978,1.0,False,False,5.0,1
...,...,...,...,...,...,...,...,...,...
4893,0.176216,0.581834,-0.135048,-0.912569,1.0,False,False,33.0,1
4894,0.187853,0.388772,-0.153299,-0.665197,1.0,False,False,34.0,0
4895,0.195628,0.585656,-0.166603,-1.001953,1.0,False,False,35.0,1
4896,0.207341,0.782565,-0.186642,-1.341982,1.0,False,False,36.0,1


## Clustering Approach

We can consider a few different buckets that can be derived from our data.
* Clustering purely from observations.
* Clustering based on observations and action done.

Furthermore, there are a handful of methods we can use to cluster our data for prediction.
* k-means (requires configured number of clusters)
* Affinity propagation
* Mean-shift

In [None]:
obsvData =