In [17]:
import pandas as pd
import numpy as np
import pickle
from tqdm import tqdm

import cv2
from pgmpy.models import BayesianNetwork
from pgmpy.estimators import MaximumLikelihoodEstimator

from ale_py import ALEInterface

### Data and Model Preparation

In [2]:
# Data Loading
data_dir = './_data/'
train_filenames = ['2022-04-22 19:56:07_0_cont']
test_filenames = ['2022-04-22 19:56:07_0_cont']

# Prepare Train Data
df_l = []
for f in train_filenames:
    filepath = data_dir + f + '.csv'
    df_l.append(pd.read_csv(filepath, index_col=None, header=0))
df_train = pd.concat(df_l, axis=0, ignore_index=True)

# Prepare Test Data
df_l = []
for f in test_filenames:
    filepath = data_dir + f + '.csv'
    df_l.append(pd.read_csv(filepath, index_col=None, header=0))
df_test = pd.concat(df_l, axis=0, ignore_index=True)

In [23]:
print(len(df_train))
df_train.head()

2072


Unnamed: 0,paddle_position,paddle_state,previous_paddle_position,ball_position,ball_state,previous_ball_position,keypress,previous_keypress,pickle_path
0,[ 0 56],right,"[0, 0]","[0, 0]",stagnant,"[0, 0]",0,0,./_data/2022-04-22 19:56:07_pickle/0
1,[ 0 56],stagnant,[ 0 56],"[0, 0]",stagnant,"[0, 0]",0,0,./_data/2022-04-22 19:56:07_pickle/1
2,[ 0 56],stagnant,[ 0 56],"[0, 0]",stagnant,"[0, 0]",0,0,./_data/2022-04-22 19:56:07_pickle/2
3,[ 0 56],stagnant,[ 0 56],"[0, 0]",stagnant,"[0, 0]",0,0,./_data/2022-04-22 19:56:07_pickle/3
4,[ 0 56],stagnant,[ 0 56],"[0, 0]",stagnant,"[0, 0]",0,0,./_data/2022-04-22 19:56:07_pickle/4


In [4]:
# Construct PGM
model_struct = BayesianNetwork([('ball_state','keypress'), ('paddle_state', 'keypress')])


### Train

In [5]:
# Training using MLE
mle = MaximumLikelihoodEstimator(model=model_struct, data=df_train)

In [6]:
# Print Learned Parameters
print(mle.estimate_cpd(node="keypress"))

+--------------+-----+------------------------+
| ball_state   | ... | ball_state(stagnant)   |
+--------------+-----+------------------------+
| paddle_state | ... | paddle_state(stagnant) |
+--------------+-----+------------------------+
| keypress(0)  | ... | 0.908315565031983      |
+--------------+-----+------------------------+
| keypress(1)  | ... | 0.015991471215351813   |
+--------------+-----+------------------------+
| keypress(2)  | ... | 0.029850746268656716   |
+--------------+-----+------------------------+
| keypress(3)  | ... | 0.04584221748400853    |
+--------------+-----+------------------------+


### Test

In [7]:
# Helper Functions
def test_state_generation(ball_state, paddle_state):
    if (ball_state+ ',' + paddle_state) == 'left,left': 
        state1 = 0
        state2 = 0
    if (ball_state+ ',' + paddle_state) == 'left,right': 
        state1 = 0
        state2 = 1
    if (ball_state+ ',' + paddle_state) == 'left,stagnant':
        state1 = 0
        state2 = 2
    if (ball_state+ ',' + paddle_state) == 'right,left': 
        state1 = 1
        state2 = 0 
    if (ball_state+ ',' + paddle_state) == 'right,right': 
        state1 = 1
        state2 = 1
    if (ball_state+ ',' + paddle_state) == 'right,stagnant': 
        state1 = 1
        state2 = 2
    if (ball_state+ ',' + paddle_state) == 'stagnant,left': 
        state1 = 2
        state2 = 0
    if (ball_state+ ',' + paddle_state) == 'stagnant,right': 
        state1 = 2
        state2 = 1
    if (ball_state+ ',' + paddle_state) == 'stagnant,stagnant': 
        state1 = 2
        state2 = 2
    return state1, state2

#### Test Accuracy of Prediction

In [8]:
# Test Setup
pos = 0
neg = 0
model_data = mle.estimate_cpd(node="keypress").values

# Main Test Loop
for i, row in df_test.iterrows():
    ball_state = row['ball_state']
    paddle_state = row['paddle_state']
    keypress = row['keypress']

    state1, state2 = test_state_generation(ball_state, paddle_state)

    choices = [model_data[i][state1][state2] for i in range (4)]
    best_choice = max(choices)

    predicted = choices.index(best_choice)
    if predicted == 2: predicted = 3
    if predicted == 3: predicted = 4  

    if predicted == keypress: 
        pos += 1
    else: neg += 1

acc = pos/(pos+neg)
print(f'Accuracy : {acc}')
print(f'Pos      : {pos}')
print(f'Neg      : {neg}')

Accuracy : 0.8103281853281853
Pos      : 1679
Neg      : 393


#### Test Branching Predictions

In [13]:
# Setup ALE
ale = ALEInterface()
ale.setInt("random_seed", 42)
ale.setFloat("repeat_action_probability", 0)

# Load ROM File
rom_file = "./rom/breakout.bin"
ale.loadROM(rom_file)

# Get the list of legal actions
legal_actions = ale.getLegalActionSet()
minimal_actions = ale.getMinimalActionSet()

# Reset and Prep Screen 
(screen_width,screen_height) = ale.getScreenDims()
screen_data = np.zeros((screen_width,screen_height,3),dtype=np.uint8)
ale.getScreenRGB(screen_data)

A.L.E: Arcade Learning Environment (version 0.7.4+069f8bd)
[Powered by Stella]
Game console created:
  ROM file:  ./rom/breakout.bin
  Cart Name: Breakout - Breakaway IV (1978) (Atari)
  Cart MD5:  f34f08e5eb96e500e851a80be3277a56
  Display Format:  AUTO-DETECT ==> NTSC
  ROM Size:        2048
  Bankswitch Type: AUTO-DETECT ==> 2K

Running ROM file...
Random seed is 42


In [14]:
# Helper Functions
def generate_game_object_position(frame, all_black_pixels_ball):
    # Get ball and paddle states from the frame
    colored_region = frame[57:57 + 36, 43:146]
    ball_frame = frame[107:107 + 82, 43:146]
    paddle_frame = frame[190:190 + 3, 43:146]
    red_pixels_ball = np.argwhere(cv2.inRange(ball_frame, (1, 1, 1), (255, 255, 255)))
    red_pixels_paddle = np.argwhere(cv2.inRange(paddle_frame, (1, 1, 1), (255, 255, 255)))

    if len(red_pixels_ball) != 0:
        ball_position = red_pixels_ball[0]
    else:
        black_pixels_ball = np.argwhere(cv2.inRange(colored_region, (0, 0, 0), (0, 0, 0)))
        values = []

        for j in range(len(all_black_pixels_ball)):
            values = all_black_pixels_ball[j]

        red_pixels_ball = [i for i in black_pixels_ball if i not in values]
        try:
            ball_position = red_pixels_ball[0]
        except:
            ball_position = [0, 0]
        all_black_pixels_ball.append(black_pixels_ball)
    try:
        paddle_position = red_pixels_paddle[0]
    except:
        paddle_position = [0, 0]

    return paddle_position, ball_position

def generate_game_state(objects_position, previous_objects_position):
    paddle_position = objects_position[0]
    ball_position = objects_position[1]

    previous_paddle_position = previous_objects_position[0]
    previous_ball_position = previous_objects_position[1]

    if previous_paddle_position[1] < paddle_position[1]:
        paddle_state = "right"
    elif previous_paddle_position[1] > paddle_position[1]:
        paddle_state = "left"
    else:
        paddle_state = "stagnant"

    if previous_ball_position[1] < ball_position[1]:
        ball_state = "right"
    elif previous_ball_position[1] > ball_position[1]:
        ball_state = "left"
    else:
        ball_state = "stagnant"

    return paddle_state, ball_state

def create_three_timestep_dataset(df):
    rand_idx = np.random.randint(0, len(df_test)-3)
    return df.iloc[rand_idx], df.iloc[rand_idx+1], df.iloc[rand_idx+2]

def convert_keypress(kp):
    if kp == 3: kp = 2
    if kp == 4: kp = 3

    return kp

In [20]:
num_test = 10 #int(len(df_test)/3)
threshold = 0.2
inputs = [1, 2, 3]

pos = 0
neg = 0

total_frames_count = []

for _ in tqdm(range(num_test)):
    df_1, df_2, df_3 = create_three_timestep_dataset(df_test)
    kp_2 = convert_keypress(df_2['keypress'])
    kp_3 = convert_keypress(df_3['keypress'])

    prob_2 = []
    prob_3 = []

    frames_count = 0

    for key_2 in inputs:
        with open(df_1['pickle_path'], 'rb') as f:
            state_1 = pickle.load(f)
        ale.restoreState(state_1)

        a = minimal_actions[key_2]
        _ = ale.act(a);
        frame = ale.getScreenRGB()

        paddle_position_2, ball_position_2 = generate_game_object_position(frame, [])
        print(paddle_position_2)
        print(df_1['paddle_position'])
        paddle_state, ball_state = generate_game_state([df_1['paddle_position'], df_1['ball_position']], 
                                                    [paddle_position_2, ball_position_2])
        test_state1, test_state2 = test_state_generation(paddle_state, ball_state)
        prob_2.append(model_data[key_2][test_state1][test_state2])

        state_2 = ale.cloneState(include_rng=True)

        prob_3_buff = []

        for key_3 in inputs:

            ale.restoreState(state_2)

            a = minimal_actions[key_3]
            _ = ale.act(a);
            frame = ale.getScreenRGB()

            paddle_position_3, ball_position_3 = generate_game_object_position(frame, [])
            paddle_state, ball_state = generate_game_state([paddle_position_3, ball_position_3],
                                                    [paddle_position_2, ball_position_2])
            test_state1, test_state2 = test_state_generation(paddle_state, ball_state)
            prob_3_buff.append(model_data[key_3][test_state1][test_state2])
        
        prob_3.append(prob_3_buff)

    if prob_2[kp_2 - 1] >= threshold:
        if prob_3[kp_2 - 1][kp_3 - 1] > threshold:
            pos += 1
        else:
            neg +=1
    else:
        neg += 1

    for i in inputs:
        if prob_2[i] >= threshold:
            for j in inputs:
                if prob_3[i][j] >= threshold:
                    frames_count+=1

    total_frames_count.append(frames_count)

acc = pos/(pos+neg)
print(f'Accuracy : {acc}')
print(f'Pos      : {pos}')
print(f'Neg      : {neg}')
print(f'Avg Frame: {np.mean(np.array(total_frames_count))}')
    


  0%|          | 0/10 [00:00<?, ?it/s]

[ 0 69]
[ 0 68]





TypeError: '<' not supported between instances of 'numpy.ndarray' and 'str'