In [13]:
import gym, os
from itertools import count
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import matplotlib.pyplot as plt
import numpy as np
import sklearn
%load_ext autoreload
%autoreload 2

# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
class Actor(nn.Module):
    def __init__(self, state_size, action_size):
        super(Actor, self).__init__()
        self.linear1 = nn.Linear(state_size, 128)
        self.linear3 = nn.Linear(128, action_size)

    def forward(self, state):
        output = F.relu((self.linear1(state)))
        output = self.linear3(output)
        distribution = F.softmax(output, dim=1)
        return distribution


class Critic(nn.Module):
    def __init__(self, state_size):
        super(Critic, self).__init__()
        self.linear1 = nn.Linear(state_size, 128)
        self.linear3 = nn.Linear(128, 1)

    def forward(self, state):
        output = F.relu((self.linear1(state)))
        # output = F.relu(self.linear2(output))
        value = self.linear3(output)
        return value

def pick_action(actor, state):
    state = torch.from_numpy(state).float().unsqueeze(0)
    action_probs = actor(state)
    
    dist = Categorical(action_probs)
    action = dist.sample()

    return action.item(), dist.log_prob(action)

In [26]:
gamma = 0.99
num_episodes = 1000
num_steps = 10000
env = gym.make("MountainCar-v0")
# env = gym.envs.make("MountainCarContinuous-v0")

# env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

count_episode = range(1,num_episodes+1)
count_ca = []
count_tca = []

for k in range(1):
    actor = Actor(state_size=state_size, action_size=action_size)
    critic = Critic(state_size)

    
    count_actions = []
    total_count_actions = []
    total_a = 0

    for episode in range(num_episodes):
        actor_optim = optim.SGD(actor.parameters(),lr=0.001, momentum=0.1)
        critic_optim = optim.SGD(critic.parameters(), lr=0.001, momentum=0.1)
        state = env.reset()
        isTerminal = False
        score = 0
        count_a = 0
        
        # while isTerminal != True:
        for i in range(num_steps): 
            count_a += 1
            
            
            #Pick action
            action, log_prob = pick_action(actor, state=state)
            state_prime, reward, isTerminal, info = env.step(action)

            state_tensor = torch.from_numpy(state).float().unsqueeze(0)
            v_curr = critic(state_tensor)

            state_prime_tensor = torch.from_numpy(state_prime).float().unsqueeze(0)
            v_next = critic(state_prime_tensor)

            if state_prime[0] >= 0.5:
                print(f'Num episodes {episode}, num actions {count_a} {isTerminal}')
                v_next = torch.tensor([0]).float().unsqueeze(0)
            # if isTerminal:
            #     print(f'Num episodes {episode}, num actions {i} {isTerminal}')
            
            
            
            # if isTerminal:
            #     v_next = torch.tensor([0]).float().unsqueeze(0)

            if state_prime[0] >= 0.5:
                reward += 100
            td_target = reward + gamma * v_next
            td_error = reward + gamma*v_next.item()-v_curr.item()
            critic_loss = F.mse_loss(td_target,v_curr)
            actor_loss = -log_prob * td_error

            print(actor_loss)
            print(critic_loss)
            actor_optim.zero_grad()
            actor_loss.backward(retain_graph=True)
            actor_optim.step()

            
            # print(critic_loss)
            critic_optim.zero_grad()
            critic_loss.backward()
            critic_optim.step()


            state = state_prime

            # print(f'Actor loss is {actor_loss} and critic loss is {critic_loss}')
            if state_prime[0] >= 0.5:
                break
            # print(state)
        print(f'This is the count {count_a} for episode {episode}')
        count_actions.append(count_a)
        total_a += count_a
        total_count_actions.append(total_a)
        
    # torch.save(actor, f'actor{k}.pkl')
    # torch.save(critic, f'critic{k}.pkl')
    env.close()
    count_ca.append(count_actions)
    count_tca.append(total_count_actions)


avg_ca = np.array(count_ca)
avg_ca = np.average(count_ca, axis=0)
plt.figure()
plt.title('Count of Episodes vs Count of Actions')
plt.xlabel('Count of Episodes')
plt.ylabel('Count of Actions')
plt.plot(count_episode, avg_ca)
# plt.savefig('count_actions_ac.jpg')
plt.show()

avg_tca = np.array(count_tca)
avg_tca = np.average(count_tca, axis=0)
plt.figure()
plt.title('Total Actions vs Count of Episodes ')
plt.ylabel('Count of Episodes')
plt.xlabel('Total Count of Actions')
plt.plot(avg_tca, count_episode)
# plt.savefig('total_actions_ac.jpg')
plt.show()

tensor([-1.1182], grad_fn=<MulBackward0>)
tensor(1.0048, grad_fn=<MseLossBackward0>)
tensor([-1.0334], grad_fn=<MulBackward0>)
tensor(1.0044, grad_fn=<MseLossBackward0>)
tensor([-1.1400], grad_fn=<MulBackward0>)
tensor(1.0039, grad_fn=<MseLossBackward0>)
tensor([-1.0375], grad_fn=<MulBackward0>)
tensor(1.0042, grad_fn=<MseLossBackward0>)
tensor([-1.1461], grad_fn=<MulBackward0>)
tensor(1.0038, grad_fn=<MseLossBackward0>)
tensor([-1.0420], grad_fn=<MulBackward0>)
tensor(1.0042, grad_fn=<MseLossBackward0>)
tensor([-1.0507], grad_fn=<MulBackward0>)
tensor(1.0042, grad_fn=<MseLossBackward0>)
tensor([-1.0960], grad_fn=<MulBackward0>)
tensor(1.0046, grad_fn=<MseLossBackward0>)
tensor([-1.1429], grad_fn=<MulBackward0>)
tensor(1.0038, grad_fn=<MseLossBackward0>)
tensor([-1.0512], grad_fn=<MulBackward0>)
tensor(1.0041, grad_fn=<MseLossBackward0>)
tensor([-1.0952], grad_fn=<MulBackward0>)
tensor(1.0046, grad_fn=<MseLossBackward0>)
tensor([-1.0556], grad_fn=<MulBackward0>)
tensor(1.0042, grad_fn=

KeyboardInterrupt: 

In [24]:
env.action_space

Box([-1.], [1.], (1,), float32)

In [24]:
import gym
# env = gym.make('MountainCar-v0')
env.reset()

while not isTerminal:
    action = torch.argmax(actor(state))
    state_prime, reward, isTerminal, info = env.step(action.item())
    state_prime = torch.FloatTensor(state_prime)
    state = state_prime
    # env.step(env.action_space.sample())
    env.render()