In [None]:
def BoltzmanPolicy(preds, temp):
    """Sample one action from a temperature-scaled softmax over ``preds``.

    Args:
        preds: 2-D tensor of action preferences; the softmax is taken over
            dim 1 and row 0 is used for the returned probability. The sample
            is converted with ``.item()``, so a single row is expected.
        temp: Softmax temperature that ``preds`` is divided by.

    Returns:
        Tuple ``(action, prob)``: the sampled action index as a Python int
        and the probability of that action as a 0-dim tensor.
    """
    scaled_preds = preds / temp
    probs = scaled_preds.softmax(dim=1)
    sampler = torch.distributions.Categorical(probs=probs)
    a = sampler.sample().item()
    return a, probs[0, a]


In [None]:
def reinforce(episodes, lr, gamma, T, decay, decay_rate, end_temp):
    """Train a Boltzmann (softmax) policy with Monte-Carlo REINFORCE.

    One gradient step is taken per episode on the loss
    ``-sum_t gamma**t * G_t * log pi(a_t | s_t)``.

    Args:
        episodes: Number of training episodes.
        lr: Learning rate for the Adam optimizer.
        gamma: Discount factor.
        T: Initial softmax temperature.
        decay: If truthy, anneal the temperature linearly from ``T``
            towards ``end_temp`` over the ``episodes`` episodes.
        decay_rate: Unused; kept so existing callers keep working.
        end_temp: Temperature the linear schedule approaches when ``decay``
            is set.

    Returns:
        List with the undiscounted total reward of each episode.
    """
    rewards = []
    # Inputs are moved to `device` below, so the network must live there too.
    q_network = QNetwork(state_dim, action_dim).to(device)
    optimizer = optim.Adam(q_network.parameters(), lr=lr)
    temp = T
    for i in range(episodes):
        if decay:
            temp = T - ((T - end_temp) * i / episodes)
        if i % 100 == 0:
            print("EPISODE#", i)
        total_reward = 0
        step_rewards = []
        step_probs = []
        state, info = env.reset()
        done = False
        truncated = False
        step = 0
        # Stop on termination, truncation, or the 1000-step safety cap.
        while not (done or truncated) and step < 1000:
            state_t = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
            action, prob = BoltzmanPolicy(q_network(state_t), temp)
            state, reward, done, truncated, _ = env.step(action)
            total_reward += reward
            step_rewards.append(float(reward))
            # Keep the probability of *this* step's action (with its graph).
            step_probs.append(prob)
            step += 1

        # G_t = sum_{k >= t} gamma**(k - t) * r_k, accumulated back to front
        # so the whole episode costs O(n) instead of O(n^2).
        returns = [0.0] * len(step_rewards)
        running_return = 0.0
        for t in reversed(range(len(step_rewards))):
            running_return = step_rewards[t] + gamma * running_return
            returns[t] = running_return

        # The optimizer minimises, so the policy-gradient objective is negated.
        loss = 0
        for t, (G, step_prob) in enumerate(zip(returns, step_probs)):
            log_prob = torch.log(torch.clamp(step_prob, min=1e-8))
            loss = loss - (gamma ** t) * G * log_prob
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        rewards.append(total_reward)
    return rewards

In [None]:
def A2C(episodes, gamma, policy_lr, value_lr, T, decay, decay_rate, end_temp):
    """Train a one-step advantage actor-critic agent with a Boltzmann policy.

    Per-step actor and critic losses are accumulated over an episode,
    averaged over the number of steps, and applied in one update per episode.

    Args:
        episodes: Number of training episodes.
        gamma: Discount factor.
        policy_lr: Learning rate for the actor's Adam optimizer.
        value_lr: Learning rate for the critic's Adam optimizer.
        T: Initial softmax temperature.
        decay: If truthy, anneal the temperature linearly from ``T``
            towards ``end_temp`` over the ``episodes`` episodes.
        decay_rate: Unused; kept so existing callers keep working.
        end_temp: Temperature the linear schedule approaches when ``decay``
            is set.

    Returns:
        List with the undiscounted total reward of each episode.
    """
    # Inputs are moved to `device` below, so both networks must live there too.
    policy_network = QNetwork(state_dim, action_dim).to(device)
    value_network = QNetwork(state_dim, 1).to(device)
    optimizer_actor = optim.Adam(policy_network.parameters(), lr=policy_lr)
    optimizer_value = optim.Adam(value_network.parameters(), lr=value_lr)
    rewards = []
    temp = T
    for i in range(episodes):
        if decay:
            temp = T - ((T - end_temp) * i / episodes)
        if i % 100 == 0:
            print("EPISODE#", i)
        total_reward = 0
        state, info = env.reset()
        state_t = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
        done = False
        truncated = False
        step = 0
        policy_loss = 0
        value_loss = 0
        # Stop on termination, truncation, or the 1000-step safety cap.
        while not (done or truncated) and step < 1000:
            preds = policy_network(state_t)
            # Sampling needs no gradient; the log-probability below does.
            action, _ = BoltzmanPolicy(preds.detach(), temp)
            next_state, reward, done, truncated, _ = env.step(action)
            total_reward += reward
            next_state_t = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0).to(device)

            value = value_network(state_t)
            with torch.no_grad():
                # A terminal state has no future value; a truncated one does,
                # so only `done` disables bootstrapping.
                if done:
                    next_value = torch.zeros_like(value)
                else:
                    next_value = value_network(next_state_t)
            td_error = reward + gamma * next_value - value

            # Same temperature as the sampling distribution; log_softmax
            # avoids the overflow of an explicit exp/sum.
            log_prob = torch.log_softmax(preds[0] / temp, dim=0)[action]
            policy_loss = policy_loss - log_prob * td_error.detach()
            value_loss = value_loss + 0.5 * td_error ** 2
            state_t = next_state_t
            step += 1
        mean_policy_loss = policy_loss / step
        mean_value_loss = value_loss / step
        optimizer_actor.zero_grad()
        optimizer_value.zero_grad()
        mean_policy_loss.backward()
        mean_value_loss.backward()
        optimizer_actor.step()
        optimizer_value.step()
        rewards.append(total_reward)
    return rewards


In [None]:
def run_experiment2(lr, T, decay, decay_rate, gamma, episodes, seeds):
    """Train A2C and REINFORCE over several seeds and plot mean reward curves.

    For each run index ``i`` in ``range(seeds)`` the torch and NumPy RNGs are
    seeded with ``i``, both agents are trained, and each reward curve is
    pickled to its own file. The mean curves with a shaded band of one
    standard error are then saved to ``plot.png`` and shown.

    Args:
        lr: Learning rate passed to both agents (as both ``policy_lr`` and
            ``value_lr`` for A2C).
        T: Initial softmax temperature passed to both agents.
        decay: Temperature-decay flag passed to both agents.
        decay_rate: Passed to both agents as ``decay_rate`` and also as
            ``end_temp``.
        gamma: Discount factor.
        episodes: Number of training episodes per run.
        seeds: Number of independent runs.
    """
    # Size the buffers from the `episodes` argument so they always match the
    # length of the reward lists returned by the training functions.
    a2c_learning_rewards = np.zeros((seeds, episodes))
    reinforce_rewards = np.zeros((seeds, episodes))
    for i in range(seeds):
        print(f'Run {i+1}/{seeds}')
        torch.manual_seed(i)
        np.random.seed(i)
        a2c_learning_rewards[i] = A2C(episodes=episodes, gamma=gamma, policy_lr=lr, value_lr=lr, T=T, decay=decay, decay_rate=decay_rate, end_temp=decay_rate)
        reinforce_rewards[i] = reinforce(episodes=episodes, lr=lr, gamma=gamma, T=T, decay=decay, decay_rate=decay_rate, end_temp=decay_rate)
        # Context managers make sure each pickle file is flushed and closed.
        with open(f'RUN#{i}_ASST_a2c_learning_rewards_{lr}_{T}_{decay}_{decay_rate}.pkl', 'wb') as a2c_file:
            pickle.dump(a2c_learning_rewards[i], a2c_file)
        with open(f'RUN#{i}_ASSOT_reinforce_rewards_{lr}_{T}_{decay}_{decay_rate}.pkl', 'wb') as reinforce_file:
            pickle.dump(reinforce_rewards[i], reinforce_file)
    # Standard error of the mean across runs.
    a2c_learning_rewards_mean = a2c_learning_rewards.mean(axis=0)
    a2c_learning_rewards_std = a2c_learning_rewards.std(axis=0)/math.sqrt(seeds)
    reinforce_rewards_mean = reinforce_rewards.mean(axis=0)
    reinforce_rewards_std = reinforce_rewards.std(axis=0)/math.sqrt(seeds)
    episode_axis = range(episodes)
    plt.plot(a2c_learning_rewards_mean, label='A2C', color='green')
    plt.fill_between(episode_axis, a2c_learning_rewards_mean - a2c_learning_rewards_std, a2c_learning_rewards_mean + a2c_learning_rewards_std, color='green', alpha=0.2)
    plt.plot(reinforce_rewards_mean, label='Reinforce', color='red')
    plt.fill_between(episode_axis, reinforce_rewards_mean - reinforce_rewards_std, reinforce_rewards_mean + reinforce_rewards_std, color='red', alpha=0.2)
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.legend()
    plt.savefig('plot.png')
    plt.show()

In [None]:
import pickle

# A2C vs. REINFORCE comparison. Each entry of T is read below by index and
# forwarded to run_experiment2 as T, decay and decay_rate respectively.
T = [[2, False, 1]]
lr = [1e-4]
SEEDS = 3
for stuff in T:
    run_experiment2(
        lr=lr[0],
        T=stuff[0],
        decay=stuff[1],
        decay_rate=stuff[2],
        gamma=0.99,
        episodes=1000,
        seeds=SEEDS,
    )

In [None]:
import pickle

# Grid sweep: run_experiment is called once per (epsilon, lr, replay_buffer)
# combination, with epsilon varying slowest and replay_buffer fastest.
epsilons = [0.25, 0.125, 0.0625]
lrs = [1/4, 1/8, 1/16]
# Each pair is passed to run_experiment as its second and third arguments
# (presumably batch size and buffer capacity -- confirm against run_experiment).
replay_buffers = [(1, 1), (32, 1000000)]
for epsilon, lr, replay_buffer in [
    (eps, rate, buf)
    for eps in epsilons
    for rate in lrs
    for buf in replay_buffers
]:
    run_experiment(lr, replay_buffer[0], replay_buffer[1], epsilon)
    time.sleep(1)
    # Archive the plot under a name describing this configuration, then
    # remove any plot.png left behind; pause one second after each command.
    for command in (
        f'mv plot.png plots/{env_name}_{epsilon}_{lr}_{replay_buffer[0]}.png',
        'rm plot.png',
    ):
        os.system(command)
        time.sleep(1)