In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import FloatTensor, LongTensor

import gym
np.random.seed(42)

from tqdm import tqdm_notebook as tqdm
%load_ext autoreload
%autoreload 2

In [2]:
env = gym.make('LunarLander-v2')
env.reset()

array([ 0.0023427 ,  1.4012029 ,  0.23728529, -0.43186894, -0.00270793,
       -0.05374873,  0.        ,  0.        ], dtype=float32)

In [3]:
action_to_i = {
    'nope':0,
    'left':1,
    'main':2,
    'right':3
}

In [4]:
class Lander(nn.Module):
    def __init__(self, env, num_env_actions, num_env_variables):
        super().__init__()
        self.env = env
        self.num_env_actions = num_env_actions
        self.num_env_variables = num_env_variables
        self.memoryX = np.zeros(1)
        self.memoryY = np.zeros(1)
        self.criterion = nn.MSELoss()
        self.optimizer = None
        
        self._dence_1 =nn.Sequential( nn.Linear(num_env_actions + num_env_variables, 512), nn.ReLU() )
        self._dence_2 =nn.Sequential( nn.Linear(512, 256), nn.ReLU() )
        self._dence_3 =nn.Sequential( nn.Linear(256, 256), nn.ReLU() )
        self._out = nn.Linear(256, 1)

    def forward(self, inputs):
        x = self._dence_1(inputs)
        x = self._dence_2(x)
        x = self._dence_3(x)
        x = self._out(x)
        return x
    
    def one_hot(self, action):
        result = np.zeros(self.num_env_actions)
        result[action] = 1
        return result

    def predict_total_reward(self, state, action):
        state_action = np.concatenate((state, self.one_hot(action)), axis = 0)
        total_reward = self.forward(FloatTensor(state_action.reshape(1, -1))).flatten()
        return total_reward

    def random_action(self):
        return self.env.action_space.sample()

    def normal_action(self, state):
        possible_action = np.array([self.predict_total_reward(state, i) for i in range(self.num_env_actions)])
        return np.argmax(possible_action)

    def get_action(self, state, explore = 0.5, observe = False):
        if observe:
            return self.random_action()
        prob = np.random.rand(1)
        if prob > explore:
            return self.random_action()
        return self.normal_action(state)

    def play_game(self, observe = False, explore = 0.5, gamma = 0.99, render = False):
        gameX = np.zeros(shape=(1,num_env_variables+num_env_actions))
        gameY = np.zeros(shape=(1,1))
        state = self.env.reset()
        done = False
        step = 0
        total_reward = 0
        while not done:
            action = self.get_action(state, explore, observe)
            if render:
                env.render()
            state_action = np.concatenate((state, self.one_hot(action)), axis = 0)
            state, reward, done, info = self.env.step(action)
            total_reward += reward
            if step == 0:
                gameX[0] = state_action
                gameY[0] = np.array([reward])
            else:
                gameX = np.vstack((gameX, state_action))
                gameY = np.vstack((gameY, np.array([reward])))
            step += 1
        for i in range(0,gameY.shape[0]):
            if i > 0:
                gameY[(gameY.shape[0]-1)-i][0] = gameY[(gameY.shape[0]-1)-i][0] + gamma * gameY[(gameY.shape[0]-1) - i + 1][0]
        return np.array(gameX), np.array(gameY), total_reward
    
    def update_memory(self, gameX, gameY):
        if self.memoryX.shape[0] == 1:
            self.memoryX = gameX
            self.memoryY = gameY
        else:
            self.memoryX = np.concatenate((self.memoryX, gameX), axis=0)
            self.memoryY = np.concatenate((self.memoryY, gameY), axis=0)
        max_memory = 6000
        if len(self.memoryX > max_memory):
            self.memoryX = self.memoryX[-max_memory:]
            self.memoryY = self.memoryY[-max_memory:]
            
    def train(self):
        if self.optimizer is None:
            self.optimizer = optim.Adam(self.parameters(), lr=1e-5)
        with torch.autograd.set_grad_enabled(True):
            for i in range(50):
                batch_ids = np.random.choice(len(self.memoryX), 256)
                batch = self.memoryX[batch_ids]
                target = self.memoryY[batch_ids]
                prediction = self.forward(FloatTensor(batch))
                loss = self.criterion(prediction, FloatTensor(target))
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
            
    
    def train_model(self, steps, observe=False, explore=0.5, gamma=0.99, render=False):
        with tqdm(total=steps) as pbar:
            for _ in range(steps):
                gameX, gameY, _ = self.play_game(observe, explore, gamma, render)
                self.update_memory(gameX, gameY)
                if not observe:
                    self.train()
                pbar.update()

In [5]:
num_env_variables = 8
num_env_actions = 4

In [6]:
model = Lander(env, num_env_actions, num_env_variables)

In [7]:
model.train_model(2000, observe=True)

HBox(children=(IntProgress(value=0, max=2000), HTML(value='')))




In [8]:
model.train_model(2000, explore = 0.5, render=False)

HBox(children=(IntProgress(value=0, max=2000), HTML(value='')))




In [9]:
model.train_model(2000, explore = 0.9, render=False)

HBox(children=(IntProgress(value=0, max=2000), HTML(value='')))




In [12]:
for i in range(10):
    game_x, game_y = model.play_game(explore = 1, render = True)
    print(game_y.flatten[-1])

[[  77.99316625]
 [  77.8222176 ]
 [  77.98608659]
 [  77.8018671 ]
 [  77.41614228]
 [  77.01890004]
 [  76.37387146]
 [  75.62257697]
 [  74.56352714]
 [  74.40146682]
 [  75.3158952 ]
 [  76.37917304]
 [  74.79270237]
 [  74.74272811]
 [  74.29158209]
 [  74.34006985]
 [  73.94841603]
 [  73.24766137]
 [  70.19988867]
 [  68.7726641 ]
 [  68.17214981]
 [  67.47498433]
 [  67.198371  ]
 [  66.5721523 ]
 [  66.67799268]
 [  66.33425535]
 [  64.66086537]
 [  62.0112798 ]
 [  61.5058051 ]
 [  61.40826106]
 [  57.86855043]
 [  58.0891107 ]
 [  60.89648447]
 [  58.62366635]
 [  58.56405579]
 [  58.0040083 ]
 [  60.59176645]
 [  55.87990089]
 [  58.22593195]
 [  56.13166773]
 [  58.40014849]
 [  56.40379961]
 [  58.43094868]
 [  56.06550336]
 [  51.03992995]
 [  52.96107232]
 [  54.62818645]
 [  50.72042773]
 [  52.27500263]
 [  46.54298494]
 [  47.77708181]
 [  49.31890993]
 [  45.04227935]
 [  46.36646223]
 [  41.16635551]
 [  42.24356243]
 [  43.25736541]
 [  40.17141086]
 [  36.5920178

 [-100.        ]]
[[226.32932255]
 [224.9010313 ]
 [223.43162174]
 [221.79084828]
 [219.9464963 ]
 [218.60891363]
 [217.89286627]
 [217.38456826]
 [216.92940803]
 [217.23069415]
 [217.64328763]
 [218.60515384]
 [219.95149696]
 [222.63668131]
 [226.32574122]
 [228.58382797]
 [232.39529188]
 [236.17211841]
 [238.84271979]
 [242.60353345]
 [246.00696721]
 [249.24930994]
 [252.26588665]
 [255.00906916]
 [257.49202457]
 [259.76216893]
 [258.72252779]
 [255.90407424]
 [257.98495844]
 [255.36442512]
 [257.37661357]
 [251.58920808]
 [253.03460147]
 [254.57766408]
 [256.75780528]
 [254.81478702]
 [256.73431354]
 [251.66732301]
 [253.18182254]
 [255.00743701]
 [253.15633752]
 [251.36558789]
 [253.0429458 ]
 [250.51229217]
 [246.09417547]
 [247.77350903]
 [249.52883777]
 [248.75258836]
 [244.77413869]
 [246.46558307]
 [248.06783628]
 [243.58275892]
 [245.32441549]
 [246.78615319]
 [245.7035043 ]
 [241.31492336]
 [243.09544822]
 [244.53982732]
 [239.4352036 ]
 [240.76439683]
 [238.56399061]
 [240.

 [-100.        ]]
[[198.33894228]
 [200.72806195]
 [203.6795022 ]
 [206.79155371]
 [210.28602172]
 [213.9589657 ]
 [217.75329064]
 [221.85391825]
 [225.17883109]
 [228.25331264]
 [231.90357133]
 [234.91396814]
 [237.6781422 ]
 [240.94818941]
 [243.76159743]
 [246.27676185]
 [249.05540546]
 [251.45676866]
 [253.47935112]
 [255.28932499]
 [256.79015272]
 [256.09071885]
 [251.4204793 ]
 [252.91046063]
 [248.93889796]
 [250.85289656]
 [252.15128548]
 [251.01627843]
 [249.88447504]
 [245.0231985 ]
 [246.40030723]
 [244.80574133]
 [244.08894481]
 [239.88958102]
 [241.59603369]
 [242.64334462]
 [240.14449982]
 [241.84558656]
 [242.83769799]
 [240.12963296]
 [239.01586503]
 [234.12998743]
 [235.76871748]
 [236.64630856]
 [232.3968459 ]
 [233.94699588]
 [234.63743509]
 [230.9859597 ]
 [231.599092  ]
 [227.70346253]
 [222.17050986]
 [223.65560012]
 [224.36771292]
 [220.09099867]
 [218.07968444]
 [219.61773742]
 [220.31279277]
 [215.99334348]
 [211.96191914]
 [210.61856262]
 [208.68765946]
 [205.

 [100.        ]]
[[240.45150522]
 [241.18204854]
 [242.6482365 ]
 [243.57569963]
 [244.85055307]
 [245.5561192 ]
 [246.61153716]
 [246.95476232]
 [247.95329528]
 [249.39568783]
 [250.58527034]
 [252.17625282]
 [251.41140664]
 [249.15796219]
 [246.39641568]
 [248.1144238 ]
 [244.56093827]
 [246.26558525]
 [242.69377014]
 [243.22301461]
 [242.47179603]
 [242.47845753]
 [241.00835434]
 [239.05905206]
 [235.88143743]
 [235.17884013]
 [231.51857827]
 [233.69968187]
 [230.81733235]
 [228.90247741]
 [227.64175453]
 [226.15975083]
 [228.71643379]
 [226.23211736]
 [221.55007247]
 [223.99479075]
 [222.6632909 ]
 [222.75377874]
 [222.06009645]
 [219.31519414]
 [221.53933375]
 [217.81307152]
 [219.90905074]
 [219.31299251]
 [218.32439496]
 [217.69582476]
 [216.21294563]
 [213.7636744 ]
 [211.14076167]
 [212.90148233]
 [208.9426562 ]
 [207.89332574]
 [206.54645919]
 [207.85242772]
 [204.68316092]
 [199.48181634]
 [200.77421639]
 [198.73085395]
 [194.64565761]
 [195.96577763]
 [193.24813487]
 [190.7

 [100.        ]]
