# Load and test trained model

## Load libraries

In [1]:
import numpy as np
import torch

import torch.nn as nn
from enduro_lstm import *

In [2]:
device = conf_cuda(False)

Selected CPU


In [3]:
device

device(type='cpu')

In [4]:
device.type

'cpu'

## Set configurations

In [5]:
import os
dir_path = "best_models/chuncked_sequence/lstm/softmax/softmax_m45to50_f1to1020_H200_epoch10000" + "/"
arr = os.listdir(f'./{dir_path}')
for i in range(len(arr)):
    print(arr[i])

loss_file.txt
shin_chunked_m45to50_f1to1020_epoch10000_H200
shin_chunked_m45to50_f1to1020_epoch10000_H200.npz
train_loss.png
train_loss_arr.npz
_home_ryo_.local_lib_python3.8_site-packages_ipykernel_launcher.py (Ubuntu) 2021-05-14 17-51-39.mp4
_home_ryo_.local_lib_python3.8_site-packages_ipykernel_launcher.py (Ubuntu) 2021-05-14 17-53-00.mp4
_home_ryo_.local_lib_python3.8_site-packages_ipykernel_launcher.py (Ubuntu) 2021-05-14 17-54-19.mp4
_home_ryo_.local_lib_python3.8_site-packages_ipykernel_launcher.py (Ubuntu) 2021-05-14 17-55-42.mp4
_home_ryo_.local_lib_python3.8_site-packages_ipykernel_launcher.py (Ubuntu) 2021-05-14 17-57-01.mp4


In [6]:
model_path = dir_path + "shin_chunked_m45to50_f1to1020_epoch10000_H200"

In [7]:
start_match = 115
end_match = 115

hidden_neurons = 200
zigzag = False
is_softmax = True

start_frame = 1
end_frame = 31141

In [8]:
data_path = r"../1-generate/data/"

use_cuda = False
load_checkpoint = True

## Load trained model

In [11]:
if zigzag:
    output_size = 2
else:
    output_size = 9

In [12]:
def load_checkpoint(model, filename='checkpoint.pth.tar'):
    
    print("=> loading checkpoint '{}'".format(filename))
    checkpoint = torch.load(filename, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

    return model, checkpoint['optimizer']['state'][0]['step'], checkpoint['losslogger']

In [13]:
if load_checkpoint:
    model = LSTMModel(device=device, input_size=12000, output_size=output_size, hidden_dim=hidden_neurons, n_layers=1, is_softmax=is_softmax)
    model, last_epoch, last_logger = load_checkpoint(model, model_path)
else:
    model = Model(device=device, input_size=12000, output_size=output_size, hidden_dim=hidden_neurons, n_layers=1, is_softmax=is_softmax)
    model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

=> loading checkpoint 'best_models/chuncked_sequence/lstm/softmax/softmax_m45to50_f1to1020_H200_epoch10000/shin_chunked_m45to50_f1to1020_epoch10000_H200'


LSTMModel(
  (lstm): LSTM(12000, 200, batch_first=True)
  (fc): Linear(in_features=200, out_features=9, bias=True)
  (out): Softmax(dim=None)
)

In [14]:
ACTIONS_LIST = get_actions_list(zigzag=zigzag)

In [15]:
def load_npz(data_path, m):
    
    path = data_path + "match_" + str(m) + "/npz/"

    actions = np.load(path + 'actions.npz')
    frames = np.load(path + 'frames.npz')
    rewards = np.load(path + 'reward.npz')

    arr_actions = actions.f.arr_0
    arr_frames = frames.f.arr_0
    arr_rewards = rewards.f.arr_0

    print("Successfully loaded NPZ.")

    return arr_actions.shape[0], arr_frames, arr_actions, arr_rewards

In [14]:
num_of_frames_arr = []
frames_arr = []
actions_arr = []

for m in range(start_match, end_match + 1):
    
    num_of_frames, frames, actions, rewards = load_npz(data_path, m)
    frames = frames[start_frame - 1:end_frame]
    frames = frames.reshape(end_frame - start_frame + 1, 170, 120)
    frames = frames[:, 30:130, :]
    frames = frames.reshape(end_frame - start_frame + 1, 12000)
    actions = actions[start_frame - 1:end_frame]
    
    action_one_hot = [prepare_action_data(i, ACTIONS_LIST) for i in actions]
    actions = np.array(action_one_hot)
    actions = actions.reshape(len(actions), -1)
    
    frames_arr.append(frames)
    actions_arr.append(actions)
    num_of_frames_arr.append(end_frame - start_frame + 1) 

Successfully loaded NPZ.


In [15]:
X_train = np.array(frames_arr)/255
Y_train = np.array(actions_arr)

In [16]:
X_train.shape

(1, 31141, 12000)

## Prepare cell with trained model 

In [16]:
lstmcell = nn.LSTMCell(12000, hidden_neurons)
linear = nn.Linear(hidden_neurons, output_size)
if is_softmax:
    output = nn.Softmax()
else:
    output = nn.Sigmoid()

In [17]:
lstmcell.weight_ih = model.lstm.weight_ih_l0
lstmcell.weight_hh = model.lstm.weight_hh_l0
lstmcell.bias_hh = model.lstm.bias_hh_l0
lstmcell.bias_ih = model.lstm.bias_ih_l0
linear.weight = model.fc.weight
linear.bias = model.fc.bias

In [18]:
hx = torch.zeros(1, hidden_neurons)
cx = torch.zeros(1, hidden_neurons)

In [20]:
print(model.lstm.weight_ih_l0.shape)
print(model.lstm.weight_hh_l0.shape)
print(model.lstm.bias_ih_l0.shape)
print(model.lstm.bias_hh_l0.shape)
print(model.fc.weight.shape)
print(model.fc.bias.shape)

torch.Size([800, 12000])
torch.Size([800, 200])
torch.Size([800])
torch.Size([800])
torch.Size([9, 200])
torch.Size([9])


## Testing outputs of model

In [21]:
hx = torch.zeros(1, hidden_neurons)
cx = torch.zeros(1, hidden_neurons)
out_arr = []
# for i in range(end_frame - start_frame + 1):
for i in range(1000):
    step_input = X_train[0][i]
    step_input = step_input.reshape(1, -1)    
    
    step_input = torch.tensor(step_input).float()
    
    hx, cx = lstmcell(step_input, (hx, cx))
    out = linear(hx)
    out = output(out)
    print(i, ": ", out[0])
    out_arr.append(out)

  out = output(out)


0 :  tensor([6.8418e-01, 2.9785e-01, 4.9430e-03, 8.1617e-03, 3.9823e-06, 2.0405e-06,
        3.0957e-06, 1.2913e-05, 4.8460e-03], grad_fn=<SelectBackward>)
1 :  tensor([9.8802e-01, 1.0674e-02, 1.3638e-05, 1.8163e-05, 5.5483e-08, 2.9406e-08,
        5.3306e-08, 2.2173e-07, 1.2722e-03], grad_fn=<SelectBackward>)
2 :  tensor([9.9745e-01, 2.1933e-03, 8.0062e-06, 4.9566e-06, 1.8798e-08, 9.9165e-09,
        1.7840e-08, 7.6569e-08, 3.4048e-04], grad_fn=<SelectBackward>)
3 :  tensor([2.5975e-01, 7.3818e-01, 1.7826e-05, 4.5887e-05, 1.1962e-07, 6.6884e-08,
        1.1569e-07, 4.8121e-07, 2.0004e-03], grad_fn=<SelectBackward>)
4 :  tensor([7.2885e-02, 9.2627e-01, 1.0433e-05, 2.7642e-05, 6.0666e-08, 3.4381e-08,
        6.0850e-08, 2.5032e-07, 8.0601e-04], grad_fn=<SelectBackward>)
5 :  tensor([7.4926e-02, 9.2452e-01, 6.7705e-06, 2.1278e-05, 5.3433e-08, 3.0351e-08,
        5.2016e-08, 2.2219e-07, 5.2369e-04], grad_fn=<SelectBackward>)
6 :  tensor([9.8875e-02, 9.0044e-01, 8.4435e-06, 2.5795e-05, 6.1

63 :  tensor([4.6375e-04, 9.9890e-01, 1.8938e-04, 1.2120e-05, 1.4226e-08, 7.4203e-09,
        1.1622e-08, 5.0274e-08, 4.3592e-04], grad_fn=<SelectBackward>)
64 :  tensor([3.3992e-04, 9.9844e-01, 1.9877e-04, 2.9305e-05, 1.5997e-08, 8.4778e-09,
        1.3268e-08, 5.5906e-08, 9.9605e-04], grad_fn=<SelectBackward>)
65 :  tensor([5.0046e-04, 9.9805e-01, 1.7737e-04, 4.2858e-05, 1.9923e-08, 1.0615e-08,
        1.6562e-08, 6.9712e-08, 1.2281e-03], grad_fn=<SelectBackward>)
66 :  tensor([9.7385e-04, 9.9717e-01, 2.1112e-04, 4.4381e-05, 2.6330e-08, 1.3938e-08,
        2.2348e-08, 9.2566e-08, 1.6042e-03], grad_fn=<SelectBackward>)
67 :  tensor([4.7144e-04, 9.9787e-01, 1.8652e-04, 5.0615e-05, 2.0467e-08, 1.0951e-08,
        1.7087e-08, 7.1727e-08, 1.4201e-03], grad_fn=<SelectBackward>)
68 :  tensor([1.4283e-03, 9.9405e-01, 3.8870e-05, 1.0537e-04, 2.9269e-08, 1.6174e-08,
        2.6618e-08, 1.0503e-07, 4.3814e-03], grad_fn=<SelectBackward>)
69 :  tensor([5.4936e-03, 9.8922e-01, 1.4813e-05, 1.3382e-

137 :  tensor([1.4629e-02, 8.4140e-01, 1.3949e-01, 2.6842e-04, 1.7742e-07, 8.5090e-08,
        1.3775e-07, 5.8604e-07, 4.2094e-03], grad_fn=<SelectBackward>)
138 :  tensor([2.0329e-02, 8.6365e-01, 1.1294e-01, 1.0524e-04, 1.6794e-07, 7.7868e-08,
        1.2922e-07, 5.3018e-07, 2.9734e-03], grad_fn=<SelectBackward>)
139 :  tensor([2.0839e-02, 8.0549e-01, 1.6769e-01, 8.8322e-05, 1.8929e-07, 8.9292e-08,
        1.5093e-07, 6.1565e-07, 5.8899e-03], grad_fn=<SelectBackward>)
140 :  tensor([3.6909e-03, 8.5372e-01, 1.3486e-01, 6.8275e-05, 9.8925e-08, 4.7850e-08,
        7.7740e-08, 3.3099e-07, 7.6663e-03], grad_fn=<SelectBackward>)
141 :  tensor([3.6155e-02, 5.9273e-01, 3.5151e-01, 3.5516e-03, 4.8106e-07, 2.3981e-07,
        3.9953e-07, 1.6560e-06, 1.6045e-02], grad_fn=<SelectBackward>)
142 :  tensor([4.4894e-03, 8.1716e-02, 6.6097e-01, 1.8807e-02, 2.2297e-07, 1.0860e-07,
        1.7171e-07, 8.4642e-07, 2.3401e-01], grad_fn=<SelectBackward>)
143 :  tensor([2.7571e-04, 6.6882e-02, 9.0406e-01, 6

191 :  tensor([4.1953e-03, 4.1398e-03, 1.6421e-02, 9.7291e-01, 8.6507e-08, 4.5824e-08,
        6.3067e-08, 2.4594e-07, 2.3322e-03], grad_fn=<SelectBackward>)
192 :  tensor([4.5145e-01, 5.3226e-02, 8.5419e-02, 4.0309e-01, 9.6773e-07, 5.1592e-07,
        7.3453e-07, 2.6224e-06, 6.8108e-03], grad_fn=<SelectBackward>)
193 :  tensor([9.8226e-01, 5.2348e-03, 9.1522e-03, 3.2230e-03, 1.6223e-07, 7.9529e-08,
        1.2271e-07, 4.2752e-07, 1.2639e-04], grad_fn=<SelectBackward>)
194 :  tensor([9.8614e-01, 6.1758e-03, 6.6025e-03, 9.6381e-04, 1.4104e-07, 6.8178e-08,
        1.0839e-07, 3.5959e-07, 1.2188e-04], grad_fn=<SelectBackward>)
195 :  tensor([9.8366e-01, 5.9477e-03, 9.0794e-03, 1.1327e-03, 1.6403e-07, 7.9572e-08,
        1.2820e-07, 4.0781e-07, 1.7851e-04], grad_fn=<SelectBackward>)
196 :  tensor([7.9566e-01, 8.4235e-02, 2.4053e-02, 9.3851e-02, 7.5093e-07, 4.0298e-07,
        5.7794e-07, 2.1469e-06, 2.1971e-03], grad_fn=<SelectBackward>)
197 :  tensor([9.0611e-01, 1.0284e-02, 3.1717e-03, 7

245 :  tensor([4.0061e-02, 8.1365e-01, 1.3922e-01, 1.5313e-03, 4.5704e-07, 2.2529e-07,
        3.7734e-07, 1.6621e-06, 5.5307e-03], grad_fn=<SelectBackward>)
246 :  tensor([4.5393e-01, 4.6864e-01, 6.4093e-02, 5.3474e-04, 5.3007e-07, 2.6822e-07,
        4.3731e-07, 2.0641e-06, 1.2801e-02], grad_fn=<SelectBackward>)
247 :  tensor([2.6631e-01, 5.2882e-01, 1.6538e-01, 7.2810e-04, 7.6855e-07, 3.6924e-07,
        6.1790e-07, 2.6686e-06, 3.8759e-02], grad_fn=<SelectBackward>)
248 :  tensor([4.3082e-02, 2.4695e-01, 4.2600e-01, 2.6913e-03, 5.1852e-07, 2.5773e-07,
        4.0483e-07, 1.7997e-06, 2.8128e-01], grad_fn=<SelectBackward>)
249 :  tensor([2.1067e-02, 2.6836e-01, 5.9565e-02, 1.3993e-02, 3.9504e-07, 1.9659e-07,
        3.4266e-07, 1.3497e-06, 6.3701e-01], grad_fn=<SelectBackward>)
250 :  tensor([2.5414e-03, 5.1787e-01, 2.4817e-02, 3.7729e-02, 2.4805e-07, 1.2653e-07,
        2.1498e-07, 8.2117e-07, 4.1704e-01], grad_fn=<SelectBackward>)
251 :  tensor([1.4918e-04, 5.9204e-01, 2.8269e-01, 1

314 :  tensor([3.5439e-04, 9.6645e-01, 1.2670e-02, 2.9542e-03, 6.5732e-08, 3.3125e-08,
        5.3704e-08, 2.2874e-07, 1.7569e-02], grad_fn=<SelectBackward>)
315 :  tensor([2.8330e-03, 9.4135e-01, 3.6546e-02, 1.1303e-03, 1.0652e-07, 5.7692e-08,
        8.4729e-08, 3.8661e-07, 1.8142e-02], grad_fn=<SelectBackward>)
316 :  tensor([6.8013e-02, 7.7561e-01, 1.0029e-01, 3.5521e-03, 4.8333e-07, 2.5514e-07,
        3.9570e-07, 1.8411e-06, 5.2528e-02], grad_fn=<SelectBackward>)
317 :  tensor([4.2367e-02, 8.4743e-01, 1.0257e-01, 2.5063e-03, 4.2669e-07, 2.1079e-07,
        3.2983e-07, 1.4884e-06, 5.1249e-03], grad_fn=<SelectBackward>)
318 :  tensor([5.9058e-03, 6.3813e-01, 3.0409e-01, 7.6198e-03, 3.2689e-07, 1.6220e-07,
        2.4791e-07, 1.1399e-06, 4.4253e-02], grad_fn=<SelectBackward>)
319 :  tensor([2.0961e-03, 8.8391e-01, 1.0003e-01, 3.4111e-03, 1.4898e-07, 7.7201e-08,
        1.1972e-07, 5.5017e-07, 1.0555e-02], grad_fn=<SelectBackward>)
320 :  tensor([1.9436e-02, 8.8799e-01, 6.3845e-02, 5

379 :  tensor([3.2547e-02, 4.4423e-02, 8.5591e-01, 6.1651e-02, 4.3749e-07, 2.1763e-07,
        3.1099e-07, 1.1764e-06, 5.4686e-03], grad_fn=<SelectBackward>)
380 :  tensor([2.9974e-01, 7.3722e-02, 5.3306e-01, 8.6043e-02, 1.0811e-06, 5.3504e-07,
        7.8383e-07, 2.8885e-06, 7.4287e-03], grad_fn=<SelectBackward>)
381 :  tensor([1.9206e-01, 7.4908e-02, 2.8577e-01, 4.2169e-01, 1.0971e-06, 5.4973e-07,
        7.9879e-07, 3.1354e-06, 2.5571e-02], grad_fn=<SelectBackward>)
382 :  tensor([4.5776e-01, 1.3142e-01, 2.9926e-01, 1.1021e-01, 1.2213e-06, 6.2479e-07,
        8.7002e-07, 3.7281e-06, 1.3524e-03], grad_fn=<SelectBackward>)
383 :  tensor([4.6357e-01, 4.8630e-03, 4.2357e-01, 1.0744e-01, 5.3784e-07, 2.6793e-07,
        3.7095e-07, 1.6983e-06, 5.4990e-04], grad_fn=<SelectBackward>)
384 :  tensor([1.4692e-01, 2.4292e-02, 7.6408e-01, 6.4066e-02, 6.0035e-07, 2.9835e-07,
        4.1100e-07, 1.8415e-06, 6.4327e-04], grad_fn=<SelectBackward>)
385 :  tensor([2.1058e-01, 1.7770e-02, 7.4714e-01, 2

443 :  tensor([8.4794e-01, 1.3639e-02, 3.0388e-02, 1.0657e-01, 6.0720e-07, 2.9769e-07,
        4.3730e-07, 1.8477e-06, 1.4664e-03], grad_fn=<SelectBackward>)
444 :  tensor([8.4828e-01, 1.9237e-02, 1.9577e-02, 1.1227e-01, 6.8253e-07, 3.4400e-07,
        5.1298e-07, 1.8889e-06, 6.3200e-04], grad_fn=<SelectBackward>)
445 :  tensor([6.1724e-01, 1.7192e-01, 4.9295e-02, 1.5315e-01, 1.6074e-06, 8.4342e-07,
        1.1451e-06, 4.1330e-06, 8.3850e-03], grad_fn=<SelectBackward>)
446 :  tensor([8.2531e-01, 9.7673e-02, 6.6738e-03, 6.9768e-02, 7.7006e-07, 4.0722e-07,
        5.9305e-07, 2.0561e-06, 5.6975e-04], grad_fn=<SelectBackward>)
447 :  tensor([5.1808e-01, 4.6898e-01, 5.7742e-03, 7.0266e-03, 6.2847e-07, 3.2582e-07,
        4.9081e-07, 1.5373e-06, 1.3428e-04], grad_fn=<SelectBackward>)
448 :  tensor([4.0758e-01, 5.6410e-01, 1.7393e-02, 1.0845e-02, 7.1650e-07, 3.6908e-07,
        5.2686e-07, 1.7124e-06, 8.3165e-05], grad_fn=<SelectBackward>)
449 :  tensor([9.0327e-01, 4.1098e-02, 2.8989e-02, 2

504 :  tensor([9.8840e-01, 9.9177e-04, 3.2225e-05, 1.0537e-02, 5.3880e-08, 3.0648e-08,
        4.6313e-08, 1.8422e-07, 4.0026e-05], grad_fn=<SelectBackward>)
505 :  tensor([7.7086e-01, 6.9130e-02, 5.2839e-02, 1.0385e-01, 1.1774e-06, 6.1975e-07,
        9.0692e-07, 3.8916e-06, 3.3180e-03], grad_fn=<SelectBackward>)
506 :  tensor([2.0065e-01, 3.0347e-01, 2.5393e-01, 2.2236e-01, 1.5888e-06, 8.4184e-07,
        1.1680e-06, 4.8531e-06, 1.9589e-02], grad_fn=<SelectBackward>)
507 :  tensor([4.3810e-01, 5.1801e-01, 1.6914e-02, 2.0850e-02, 9.6929e-07, 5.1586e-07,
        7.3464e-07, 3.2041e-06, 6.1183e-03], grad_fn=<SelectBackward>)
508 :  tensor([7.7796e-01, 2.1731e-01, 1.4483e-03, 2.3443e-03, 4.1185e-07, 2.1945e-07,
        3.3950e-07, 1.4217e-06, 9.4230e-04], grad_fn=<SelectBackward>)
509 :  tensor([9.5623e-01, 3.9358e-02, 2.7028e-03, 1.5982e-03, 2.5798e-07, 1.3035e-07,
        1.9795e-07, 7.6548e-07, 1.0620e-04], grad_fn=<SelectBackward>)
510 :  tensor([8.8345e-01, 7.1544e-02, 5.2383e-03, 3

571 :  tensor([7.5551e-01, 5.6524e-02, 3.3207e-03, 1.8032e-01, 7.1500e-07, 3.8205e-07,
        5.4938e-07, 2.1268e-06, 4.3202e-03], grad_fn=<SelectBackward>)
572 :  tensor([8.1370e-01, 7.1345e-02, 4.9412e-02, 6.3581e-02, 8.2711e-07, 4.4845e-07,
        6.3526e-07, 2.5130e-06, 1.9606e-03], grad_fn=<SelectBackward>)
573 :  tensor([7.3956e-01, 9.9128e-03, 3.2095e-02, 2.1169e-01, 5.9578e-07, 3.1077e-07,
        4.6615e-07, 1.8802e-06, 6.7366e-03], grad_fn=<SelectBackward>)
574 :  tensor([4.8811e-01, 5.4909e-03, 4.2901e-02, 4.4025e-01, 6.3307e-07, 3.3791e-07,
        5.1347e-07, 2.0268e-06, 2.3250e-02], grad_fn=<SelectBackward>)
575 :  tensor([7.7715e-01, 6.3828e-04, 2.1149e-03, 2.0655e-01, 2.4718e-07, 1.2751e-07,
        1.9946e-07, 7.8600e-07, 1.3546e-02], grad_fn=<SelectBackward>)
576 :  tensor([1.4514e-01, 4.5517e-03, 3.2888e-02, 7.1847e-01, 4.7902e-07, 2.4394e-07,
        3.7075e-07, 1.3303e-06, 9.8939e-02], grad_fn=<SelectBackward>)
577 :  tensor([1.3775e-01, 3.8189e-03, 3.0911e-02, 7

641 :  tensor([9.9265e-01, 2.2000e-03, 2.5518e-05, 4.9857e-03, 5.4528e-08, 3.0766e-08,
        5.0127e-08, 1.6865e-07, 1.3941e-04], grad_fn=<SelectBackward>)
642 :  tensor([9.4723e-01, 4.0060e-02, 6.8211e-05, 1.2574e-02, 1.8724e-07, 1.0173e-07,
        1.6508e-07, 5.2875e-07, 6.7610e-05], grad_fn=<SelectBackward>)
643 :  tensor([6.7039e-01, 1.7892e-02, 1.3778e-03, 3.0041e-01, 4.9536e-07, 2.5563e-07,
        4.1577e-07, 1.3209e-06, 9.9223e-03], grad_fn=<SelectBackward>)
644 :  tensor([2.7489e-02, 2.8793e-03, 5.5179e-03, 9.4050e-01, 1.6323e-07, 8.6751e-08,
        1.2959e-07, 4.6260e-07, 2.3609e-02], grad_fn=<SelectBackward>)
645 :  tensor([1.3313e-01, 3.7030e-03, 1.0770e-02, 8.1125e-01, 3.1755e-07, 1.6218e-07,
        2.5214e-07, 8.7161e-07, 4.1150e-02], grad_fn=<SelectBackward>)
646 :  tensor([4.7241e-03, 1.0113e-03, 4.2782e-03, 9.8392e-01, 5.1592e-08, 2.7256e-08,
        3.8070e-08, 1.4942e-07, 6.0685e-03], grad_fn=<SelectBackward>)
647 :  tensor([6.5467e-02, 5.1092e-03, 8.3739e-03, 9

711 :  tensor([9.5520e-01, 1.5483e-04, 9.8546e-04, 4.2243e-02, 1.0738e-07, 5.3431e-08,
        8.7063e-08, 3.0454e-07, 1.4155e-03], grad_fn=<SelectBackward>)
712 :  tensor([7.4175e-01, 1.1072e-03, 4.8727e-03, 2.4656e-01, 2.9377e-07, 1.5110e-07,
        2.3354e-07, 8.1833e-07, 5.7081e-03], grad_fn=<SelectBackward>)
713 :  tensor([3.2553e-01, 1.4067e-03, 3.9501e-03, 6.5850e-01, 2.7799e-07, 1.4624e-07,
        2.2168e-07, 7.7645e-07, 1.0611e-02], grad_fn=<SelectBackward>)
714 :  tensor([6.3641e-01, 8.6555e-04, 2.9876e-03, 3.3928e-01, 2.8449e-07, 1.5137e-07,
        2.3163e-07, 7.7502e-07, 2.0452e-02], grad_fn=<SelectBackward>)
715 :  tensor([8.1336e-01, 1.4651e-03, 4.3493e-04, 1.8292e-01, 1.8892e-07, 9.6578e-08,
        1.4768e-07, 5.3316e-07, 1.8188e-03], grad_fn=<SelectBackward>)
716 :  tensor([5.4917e-01, 5.3550e-03, 8.7412e-04, 4.4253e-01, 3.3521e-07, 1.6437e-07,
        2.6214e-07, 8.7309e-07, 2.0758e-03], grad_fn=<SelectBackward>)
717 :  tensor([9.1865e-02, 1.1033e-02, 1.2703e-02, 8

764 :  tensor([5.1057e-01, 3.6353e-02, 7.6356e-02, 3.3362e-01, 1.0116e-06, 5.1551e-07,
        7.8667e-07, 2.8671e-06, 4.3090e-02], grad_fn=<SelectBackward>)
765 :  tensor([9.0065e-01, 1.9671e-03, 6.8061e-03, 8.4018e-02, 2.8008e-07, 1.5352e-07,
        2.3353e-07, 8.9900e-07, 6.5557e-03], grad_fn=<SelectBackward>)
766 :  tensor([7.2180e-01, 2.6701e-03, 4.5003e-02, 2.0944e-01, 5.4118e-07, 2.9211e-07,
        4.4931e-07, 1.7272e-06, 2.1079e-02], grad_fn=<SelectBackward>)
767 :  tensor([7.0587e-01, 1.6045e-02, 5.5284e-03, 2.6209e-01, 5.2461e-07, 2.9173e-07,
        4.2710e-07, 1.6567e-06, 1.0463e-02], grad_fn=<SelectBackward>)
768 :  tensor([8.8077e-01, 4.1996e-02, 1.1221e-03, 7.4487e-02, 4.0254e-07, 2.3102e-07,
        3.3888e-07, 1.2335e-06, 1.6207e-03], grad_fn=<SelectBackward>)
769 :  tensor([9.0926e-01, 6.5770e-02, 1.9021e-03, 2.2648e-02, 3.9461e-07, 2.1972e-07,
        3.4234e-07, 1.1697e-06, 4.1825e-04], grad_fn=<SelectBackward>)
770 :  tensor([9.6190e-01, 2.3915e-02, 4.8161e-03, 8

821 :  tensor([9.9108e-01, 6.0312e-04, 4.1158e-03, 4.1794e-03, 9.1635e-08, 4.5829e-08,
        6.8020e-08, 2.9379e-07, 2.1677e-05], grad_fn=<SelectBackward>)
822 :  tensor([9.9674e-01, 2.4436e-04, 3.2136e-04, 2.6644e-03, 4.6830e-08, 2.4822e-08,
        3.6961e-08, 1.4950e-07, 2.5241e-05], grad_fn=<SelectBackward>)
823 :  tensor([9.9329e-01, 1.6918e-04, 7.2694e-04, 5.7635e-03, 6.0377e-08, 3.0754e-08,
        4.6541e-08, 1.9373e-07, 5.1398e-05], grad_fn=<SelectBackward>)
824 :  tensor([9.9782e-01, 5.3641e-05, 8.8438e-04, 1.2144e-03, 3.5513e-08, 1.7703e-08,
        2.7541e-08, 1.1289e-07, 3.2329e-05], grad_fn=<SelectBackward>)
825 :  tensor([9.6296e-01, 6.1879e-04, 2.1317e-02, 1.4721e-02, 1.9453e-07, 9.4972e-08,
        1.4169e-07, 5.9874e-07, 3.8484e-04], grad_fn=<SelectBackward>)
826 :  tensor([9.7949e-01, 5.7674e-04, 9.0561e-03, 1.0239e-02, 1.6418e-07, 7.9609e-08,
        1.2202e-07, 5.0420e-07, 6.3595e-04], grad_fn=<SelectBackward>)
827 :  tensor([9.7168e-01, 1.3444e-03, 1.5508e-02, 1

880 :  tensor([9.2477e-01, 3.3274e-02, 9.0192e-04, 3.5271e-02, 4.1574e-07, 2.2575e-07,
        3.6243e-07, 1.1575e-06, 5.7813e-03], grad_fn=<SelectBackward>)
881 :  tensor([4.4710e-01, 2.4889e-01, 9.6713e-03, 2.8986e-01, 9.5362e-07, 5.3654e-07,
        7.6256e-07, 2.7650e-06, 4.4754e-03], grad_fn=<SelectBackward>)
882 :  tensor([3.8882e-01, 3.1636e-01, 6.9520e-03, 2.8201e-01, 9.5501e-07, 5.3448e-07,
        7.5970e-07, 2.7227e-06, 5.8575e-03], grad_fn=<SelectBackward>)
883 :  tensor([3.6020e-01, 3.3852e-01, 8.4353e-03, 2.8564e-01, 1.0078e-06, 5.6084e-07,
        7.9652e-07, 2.8591e-06, 7.2001e-03], grad_fn=<SelectBackward>)
884 :  tensor([9.5876e-01, 2.6902e-02, 3.0503e-03, 1.1128e-02, 2.6075e-07, 1.3092e-07,
        2.0043e-07, 7.8459e-07, 1.5672e-04], grad_fn=<SelectBackward>)
885 :  tensor([9.9933e-01, 5.4123e-04, 7.3216e-05, 5.1847e-05, 1.9019e-08, 9.3057e-09,
        1.5714e-08, 5.3283e-08, 7.7097e-06], grad_fn=<SelectBackward>)
886 :  tensor([9.9048e-01, 1.3444e-03, 7.2272e-03, 8

939 :  tensor([8.0995e-01, 1.3764e-02, 2.9137e-02, 1.2880e-01, 6.5912e-07, 3.3718e-07,
        5.2319e-07, 1.8276e-06, 1.8345e-02], grad_fn=<SelectBackward>)
940 :  tensor([2.3954e-01, 2.9437e-02, 4.6159e-01, 2.6075e-01, 9.1201e-07, 4.7015e-07,
        6.5222e-07, 2.5188e-06, 8.6852e-03], grad_fn=<SelectBackward>)
941 :  tensor([5.0294e-01, 2.5186e-01, 1.7625e-01, 6.8056e-02, 1.1688e-06, 6.1960e-07,
        8.3758e-07, 3.2569e-06, 8.9056e-04], grad_fn=<SelectBackward>)
942 :  tensor([9.8167e-01, 1.2369e-02, 2.1292e-04, 5.6822e-03, 1.2596e-07, 7.1095e-08,
        1.0656e-07, 3.9225e-07, 6.5611e-05], grad_fn=<SelectBackward>)
943 :  tensor([9.9948e-01, 4.4968e-04, 5.2619e-06, 6.4782e-05, 1.1066e-08, 5.9277e-09,
        1.0025e-08, 3.3679e-08, 1.2191e-06], grad_fn=<SelectBackward>)
944 :  tensor([9.9744e-01, 1.8385e-03, 5.3978e-06, 6.8486e-04, 2.9793e-08, 1.6678e-08,
        2.8726e-08, 9.1420e-08, 3.4028e-05], grad_fn=<SelectBackward>)
945 :  tensor([9.1662e-01, 5.2370e-02, 1.3057e-04, 3

998 :  tensor([9.1026e-01, 2.8993e-02, 1.4635e-03, 5.8747e-02, 3.2560e-07, 1.8247e-07,
        2.6666e-07, 9.6667e-07, 5.3517e-04], grad_fn=<SelectBackward>)
999 :  tensor([9.9932e-01, 3.5587e-05, 1.9261e-05, 5.9648e-04, 1.2911e-08, 6.6350e-09,
        1.0943e-08, 3.7519e-08, 2.5126e-05], grad_fn=<SelectBackward>)


In [22]:
Y_train = Y_train.reshape(-1, len(ACTIONS_LIST))

In [23]:
acertou = 0
errou = 0
# for i in range(end_frame - start_frame + 1):
for i in range(1000):
    if np.argmax(Y_train[i]) == torch.argmax(out_arr[i][0]):
        acertou += 1
    else:
        errou += 1 

In [24]:
acertou

405

In [25]:
errou

595

In [26]:
acertou/(acertou + errou)

0.405

## Play Gym Enduro

In [19]:
import gym
import time
from PIL import Image

In [20]:
if zigzag:
        
    ACTIONS = {
        "right": 2,
        "left": 3,
    }

else:

    ACTIONS = {
        "noop": 0,
        "accelerate": 1,
        "right": 2,
        "left": 3,
        "break": 4,
        "right_break": 5,
        "left_break": 6,
        "right_accelerate": 7,
        "left_accelerate": 8,
    }

In [21]:
y_min, y_max, x_min, x_max = 25+30, 195-40, 20, 140
shape_of_single_frame = (1, (y_max-y_min),(x_max-x_min))

In [22]:
sleep_time = 0.05

In [30]:
env = gym.make("Enduro-v0")
frame = env.reset()
reward, action, done, info = 0, 0, False, {'ale.lives': 0}

hx = torch.zeros(1, hidden_neurons)
cx = torch.zeros(1, hidden_neurons)

env.render()

action_list = []
reward_list = []

for _ in range(4459):
    
    time.sleep(sleep_time)
    env.render()
    
    frame = frame[y_min:y_max, x_min:x_max]

    frame = Image.fromarray(frame)
    frame = frame.convert("L")
    
    frame = np.asarray(frame)
    frame = frame.reshape(1, -1)
    frame = torch.tensor(frame)/255
    
    hx, cx = lstmcell(frame, (hx, cx))
    out = linear(hx)
    action = output(out)
    
    action = list(ACTIONS.values())[torch.argmax(action, axis=1)]
    frame, reward, done, info = env.step(action)
    print(reward)
    
    action_list.append(action)
    reward_list.append(reward)
    
np.savez_compressed(dir_path, np.array(action_list))
np.savez_compressed(dir_path , np.array(reward_list))


  action = output(out)


0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
-1.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
-1.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.

0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
0.0


0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0


In [31]:
np.array(reward_list).sum()

0.0

In [34]:
count = 0
for i in reward_list:
    if int(i) == -1:
        count+=1

In [35]:
count

31