In [66]:
import gym
import random
import math
import numpy as np
import matplotlib

import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
import matplotlib.pyplot as plt

class NN(nn.Module):
    """Five-layer MLP mapping a 4-feature vector to a single scalar.

    Layer widths are 4 -> 128 -> 128 -> 128 -> 128 -> 1. The first three
    hidden layers use sigmoid activations, the fourth uses ReLU, and the
    output layer is linear (unbounded). Leading batch dimensions are kept:
    an input of shape ``(..., 4)`` yields an output of shape ``(..., 1)``.
    In this script one instance is created per action and its output is
    used as that action's value estimate.
    """

    def __init__(self):
        super().__init__()
        # Created in this exact order so random initialisation and the
        # state_dict keys (a2b, b2c, ...) stay the same.
        self.a2b = nn.Linear(4, 128)
        self.b2c = nn.Linear(128, 128)
        self.c2d = nn.Linear(128, 128)
        self.d2e = nn.Linear(128, 128)
        self.e2f = nn.Linear(128, 1)

    def forward(self, data):
        """Return the network output for ``data`` of shape ``(..., 4)``."""
        out = data
        for layer in (self.a2b, self.b2c, self.c2d):
            out = torch.sigmoid(layer(out))
        out = torch.relu(self.d2e(out))
        return self.e2f(out)
    
# One value network per action index (0 and 1); action0 is constructed
# first, then action1, so random initialisation order is unchanged.
action0, action1 = NN(), NN()

# NOTE(review): gym registers CartPole-v0 with a 200-step time limit, while
# the training loop's constants (500 steps, a "perfect" score of 500) look
# tuned for CartPole-v1 -- confirm which version is intended.
env = gym.make('CartPole-v0')

def sigmoid(x):
    """Return the logistic sigmoid 1 / (1 + e^-x) of a scalar ``x``.

    Inputs outside [-30, 30] are clamped to exactly 0 or 1. This also keeps
    ``math.exp`` from overflowing for very negative ``x``.

    The stray debug ``print`` that fired for every |x| > 2 has been removed;
    it flooded stdout once this function was vectorized over arrays.
    """
    if x < -30:
        return 0
    if x > 30:
        return 1
    return 1.0 / (1 + math.exp(-x))
def sigmoid_der(x):
    """Return the derivative of the logistic sigmoid at ``x``.

    Uses the identity sigma'(x) = sigma(x) * (1 - sigma(x)). ``sigmoid`` is
    evaluated once and reused instead of being called twice.
    """
    s = sigmoid(x)
    return s * (1 - s)

# Element-wise versions of the scalar helpers so they accept NumPy arrays.
# sigmoid_der resolves the global name ``sigmoid`` at call time, so after this
# rebinding it calls the vectorized sigmoid on each scalar element.
sigmoid = np.vectorize(sigmoid)
# Vectorize the derivative itself -- previously this wrapped ``sigmoid``,
# which silently made sigmoid_der return sigmoid values.
sigmoid_der = np.vectorize(sigmoid_der)

def arrange_state(state):
    """Convert a raw environment observation into a float32 torch tensor.

    The value networks' ``nn.Linear`` layers hold float32 weights, so the
    observation is cast explicitly. The previous ``torch.tensor(list(state))``
    inherited float64 from a float64 observation, which then failed in the
    forward pass. No normalisation is applied.

    Parameters
    ----------
    state : array-like
        Observation as returned by ``env.reset()`` / ``env.step()``.

    Returns
    -------
    torch.Tensor
        A new float32 tensor with the same shape as ``state``. It is a copy
        and never shares memory with ``state``.
    """
    return torch.tensor(np.asarray(state, dtype=np.float32))

# Hyper-parameters: SGD step size, discount factor, and the initial
# exploration rate (multiplied by 0.99 at the start of every episode).
alpha, gamma, epsilon = 0.001, 0.95, 1

# Each action's network gets its own buffer of (state, target) pairs; a
# buffer is turned into one SGD step once it holds this many entries.
minibatch_size = 10
minibatch0, minibatch1 = [], []

# Total (undiscounted) reward of every finished episode, in order.
Q_values = []

def _train_on_minibatch(net, batch):
    """Take one plain SGD step of ``net`` towards the stored TD targets.

    ``batch`` is a list of ``(state, target)`` pairs collected for the action
    that ``net`` scores. It is emptied in place before the step is taken.
    Weights are clamped to [-100, 100] after the update.
    """
    states = torch.stack([s for s, _ in batch]).float()
    goals = torch.stack([g for _, g in batch]).float()
    batch.clear()

    loss = nn.MSELoss()(net(states), goals)
    net.zero_grad()
    loss.backward()
    with torch.no_grad():
        for param in net.parameters():
            param.sub_(param.grad * alpha)
            param.clamp_(-100, 100)


max_steps = 500
# Episodes are also cut short by the environment's own time limit (gym keeps
# it on env.spec); an episode that reaches the cap counts as "perfect".
time_limit = getattr(env.spec, "max_episode_steps", None) or max_steps
solved_score = min(max_steps, time_limit)
solved_window = 50

for i_episode in range(1000):
    state = arrange_state(env.reset())
    Q = 0  # total (undiscounted) reward collected in this episode
    epsilon *= 0.99

    for _ in range(max_steps):
        # Epsilon-greedy action choice; ties go to action 1.
        if random.random() < epsilon:
            next_action = random.randint(0, 1)
        else:
            with torch.no_grad():
                next_action = 0 if action0(state) > action1(state) else 1

        next_state, reward, done, info = env.step(next_action)
        next_state = arrange_state(next_state)
        Q += reward

        # A time-limit cut-off is not a true terminal state, so it still
        # bootstraps; a real terminal state has no future value.
        terminal = done and not info.get("TimeLimit.truncated", False)
        if terminal:
            goal = torch.tensor([float(reward)])
        else:
            # The TD target is a constant: no gradient may flow into it.
            with torch.no_grad():
                goal = reward + gamma * torch.max(action0(next_state),
                                                  action1(next_state))

        if next_action == 0:
            minibatch0.append((state, goal))
            if len(minibatch0) >= minibatch_size:
                _train_on_minibatch(action0, minibatch0)
        else:
            minibatch1.append((state, goal))
            if len(minibatch1) >= minibatch_size:
                _train_on_minibatch(action1, minibatch1)

        state = next_state
        # Stepping past the end of an episode yields meaningless transitions.
        if done:
            break

    print(f"Round: {i_episode}, Q: {Q}")
    Q_values.append(Q)

    # Stop once the last `solved_window` episodes (including this one) all
    # reached the step cap.
    recent = Q_values[-solved_window:]
    if len(recent) == solved_window and all(q >= solved_score for q in recent):
        break

env.close()
    
# Learning curve: total reward of each episode, numbered from 1.
episode_numbers = range(1, len(Q_values) + 1)
plt.plot(episode_numbers, Q_values)
plt.show()


Round: 0, Q: 14.0
Round: 1, Q: 45.0
Round: 2, Q: 11.0
Round: 3, Q: 20.0
Round: 4, Q: 31.0
Round: 5, Q: 31.0
Round: 6, Q: 22.0
Round: 7, Q: 19.0
Round: 8, Q: 30.0
Round: 9, Q: 18.0
Round: 10, Q: 21.0
Round: 11, Q: 11.0
Round: 12, Q: 21.0
Round: 13, Q: 14.0
Round: 14, Q: 40.0
Round: 15, Q: 15.0
Round: 16, Q: 19.0
Round: 17, Q: 41.0
Round: 18, Q: 41.0
Round: 19, Q: 13.0
Round: 20, Q: 12.0
Round: 21, Q: 16.0
Round: 22, Q: 21.0
Round: 23, Q: 12.0
Round: 24, Q: 20.0
Round: 25, Q: 16.0
Round: 26, Q: 12.0
Round: 27, Q: 22.0
Round: 28, Q: 18.0
Round: 29, Q: 11.0
Round: 30, Q: 17.0
Round: 31, Q: 11.0
Round: 32, Q: 20.0
Round: 33, Q: 11.0
Round: 34, Q: 15.0
Round: 35, Q: 11.0
Round: 36, Q: 44.0
Round: 37, Q: 32.0
Round: 38, Q: 30.0
Round: 39, Q: 18.0
Round: 40, Q: 28.0
Round: 41, Q: 24.0
Round: 42, Q: 73.0
Round: 43, Q: 25.0
Round: 44, Q: 14.0
Round: 45, Q: 21.0
Round: 46, Q: 18.0
Round: 47, Q: 11.0
Round: 48, Q: 22.0
Round: 49, Q: 24.0
Round: 50, Q: 12.0
Round: 51, Q: 12.0
Round: 52, Q: 13.0
Rou

KeyboardInterrupt: 