In [2]:
import gym
import numpy as np
import torch
import matplotlib.pyplot as plt
import time
from statistics import mean

In [3]:
from gym.wrappers import Monitor

In [4]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [5]:
import math
import copy
from torch.distributions import Categorical
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
max_episodes = 1000    # upper bound on the number of training episodes
max_timesteps = 250    # upper bound on the number of timesteps in one episode

gamma = 0.01           # discount factor
gamma1 = 0.99

epsilon = 0.2          # TODO: change to max(advantage)
dkl = 1                # TODO: change to the KL divergence between old and new policies

# Three independent, initially empty lists.
Q_r, Q_r1, a = [], [], []

# Zero-filled table with one row per timestep and two columns
# (presumably one column per action -- confirm against the code that fills it).
Q = np.zeros(shape=(max_timesteps, 2))

# Hyper-parameter, value taken from https://arxiv.org/pdf/1712.06567.pdf
mutation_power = 0.02

In [7]:
class CartPoleAI(nn.Module):
    """Small fully connected policy network.

    Architecture: Linear(4 -> 128) -> ReLU -> Linear(128 -> 2) -> Softmax over
    dim 1. Because the softmax is taken over dimension 1, the input must be
    2-D, i.e. shaped ``(batch, 4)``; each output row then holds two
    non-negative values that sum to 1.
    """

    def __init__(self):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(4, hidden_units, bias=True),
            nn.ReLU(),
            nn.Linear(hidden_units, 2, bias=True),
            nn.Softmax(dim=1),
        ]
        # Keep the attribute name ``fc`` and the layer order: they determine
        # the keys of ``state_dict()`` (fc.0.*, fc.2.*).
        self.fc = nn.Sequential(*layers)

    def forward(self, inputs):
        """Return the per-row output probabilities for ``inputs``."""
        return self.fc(inputs)

In [8]:
def init_weights(m):
    """Xavier-initialise a single ``nn.Linear`` or ``nn.Conv2d`` layer in place.

    The weight is filled with Xavier/Glorot uniform values and the bias (when
    the layer has one) is set to zero. Modules of any other type are left
    untouched, so calling this directly on a container such as
    ``nn.Sequential`` does nothing; use ``model.apply(init_weights)`` to visit
    every submodule.

    Args:
        m: the module to initialise.
    """
    # nn.Linear weight has shape [out_features, in_features];
    # nn.Linear bias has shape [out_features].
    #
    # nn.Conv2d weight has shape
    # [out_channels, in_channels // groups, kernel_h, kernel_w];
    # nn.Conv2d bias has shape [out_channels].
    if type(m) in (nn.Linear, nn.Conv2d):
        # ``xavier_uniform`` (no trailing underscore) is deprecated; the
        # in-place ``xavier_uniform_`` is the supported spelling.
        torch.nn.init.xavier_uniform_(m.weight)
        # Layers built with ``bias=False`` have ``m.bias is None``.
        if m.bias is not None:
            m.bias.data.fill_(0.00)

In [9]:
def behavioural_policy(agents):
    """Unimplemented placeholder.

    NOTE(review): the body returns the name ``agent``, which is neither the
    parameter (``agents``) nor a local variable. Calling this function
    therefore raises ``NameError`` unless a global named ``agent`` happens to
    exist, in which case that global is returned and ``agents`` is ignored.
    Presumably this is meant to pick a behaviour policy out of ``agents`` --
    confirm the intended logic before relying on it.
    """
    return agent

In [10]:
def KL_divergence(agent1,agent2):
    """Unimplemented placeholder.

    NOTE(review): the body returns the name ``KL``, which is never assigned in
    this function. Calling it raises ``NameError`` unless a global named ``KL``
    exists; neither ``agent1`` nor ``agent2`` is used. Presumably this is meant
    to compute the KL divergence between the two agents' policies -- confirm
    the intended logic before relying on it.
    """
    return KL

In [11]:
def return_random_agents(num_agents):
    """Create ``num_agents`` freshly initialised ``CartPoleAI`` networks.

    Every parameter of every agent has ``requires_grad`` switched off, and
    ``init_weights`` is applied to each submodule of the agent.

    Args:
        num_agents: how many agents to create.

    Returns:
        A list with ``num_agents`` ``CartPoleAI`` instances.
    """
    agents = []
    for _ in range(num_agents):
        agent = CartPoleAI()

        for param in agent.parameters():
            param.requires_grad = False

        # ``init_weights`` only acts on nn.Linear / nn.Conv2d modules, so
        # calling it on the agent itself (a CartPoleAI) is a silent no-op.
        # ``Module.apply`` walks every submodule, which is what actually
        # reaches the Linear layers inside ``agent.fc``.
        agent.apply(init_weights)
        agents.append(agent)

    return agents

In [12]:
def softmax(x):
    """Compute softmax values for each set of scores in x.

    The normalisation runs along axis 0, so for a 2-D input every column is
    treated as an independent set of scores.

    Args:
        x: array-like of scores.

    Returns:
        ``numpy.ndarray`` of the same shape as ``x`` whose entries along
        axis 0 are non-negative and sum to 1.
    """
    scores = np.asarray(x)
    if scores.size == 0:
        # Nothing to normalise; keep the (empty) shape and a float dtype.
        return np.exp(scores)
    # Subtracting the per-column maximum leaves the result mathematically
    # unchanged but keeps np.exp from overflowing to inf (and the ratio from
    # becoming NaN) when scores are large.
    exponentials = np.exp(scores - np.max(scores, axis=0))
    return exponentials / np.sum(exponentials, axis=0)

In [13]:
def get_output_probabilities(agent, number_of_iterations=1000, batch_size=128):
    """Estimate an agent's average output distribution on random inputs.

    For ``number_of_iterations`` rounds a batch of ``batch_size`` standard
    normal inputs with 4 features is fed to ``agent``; the outputs are
    averaged over every row of every batch. The resulting vector is used as a
    probability-style representation of the agent.

    Earlier revisions evaluated the whole batch but kept only its first row,
    throwing away the rest of the forward pass; all rows are now used.

    Args:
        agent: callable/module mapping a ``(batch, 4)`` float tensor to a
            ``(batch, k)`` tensor of probabilities.
        number_of_iterations: number of random batches to draw (>= 1).
        batch_size: rows per random batch (>= 1).

    Returns:
        List of ``k`` Python floats: the mean output per action.

    Raises:
        ValueError: if ``number_of_iterations`` or ``batch_size`` is < 1.
    """
    if number_of_iterations < 1:
        raise ValueError("number_of_iterations must be at least 1")
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    running_total = None
    with torch.no_grad():
        for _ in range(number_of_iterations):
            observations = torch.randn(batch_size, 4)
            # Accumulate in float64 so many small float32 means do not drift.
            batch_mean = agent(observations).double().mean(dim=0)
            if running_total is None:
                running_total = batch_mean
            else:
                running_total = running_total + batch_mean

    # This sampling provides a representation of the agent function in terms
    # of probabilities.
    val = (running_total / number_of_iterations).tolist()

    print(val)
    return val


In [75]:
def _kl_divergence(p_dist, q_dist):
    """Return KL(p || q) for two discrete distributions given as sequences."""
    total = 0
    for p, q in zip(p_dist, q_dist):
        if p == 0:
            # Standard convention: a zero-probability event contributes 0
            # (the naive 0 * log(0) would produce NaN).
            continue
        if q == 0:
            # p > 0 where q == 0: the divergence is infinite.
            return math.inf
        total += p * np.log(p / q)
    return total


def create_cluster(unassigned, probability_representation,cluster_size):
    """Greedily carve one cluster out of the still-unassigned agents.

    Every unassigned agent is tried as an anchor. For each anchor the
    unassigned agents are ranked by ``KL(anchor || other)`` and the
    ``cluster_size`` closest ones (the anchor's own zero divergence included)
    form a candidate cluster whose cost is the sum of those divergences. The
    cheapest candidate wins; ties go to the lowest anchor index.

    As long as at least one unassigned index is valid and ``cluster_size`` is
    positive, a non-empty cluster is returned even when every cost is infinite
    or NaN, so a caller looping until ``unassigned`` is empty always makes
    progress.

    Args:
        unassigned: list of agent indices not yet placed in a cluster. It is
            modified in place: the chosen members are removed.
        probability_representation: sequence where entry ``i`` is the
            probability vector representing agent ``i``.
        cluster_size: desired number of members; capped at
            ``len(unassigned)``.

    Returns:
        ``(created_cluster, unassigned)`` -- the member indices ordered by
        increasing divergence from the anchor, and the same ``unassigned``
        list object after removal.
    """
    cluster_size = max(0, min(cluster_size, len(unassigned)))

    # Set membership keeps the candidate scan linear instead of quadratic.
    pending = set(unassigned)
    candidates = [
        index
        for index in range(len(probability_representation))
        if index in pending
    ]

    created_cluster = []
    best_cost = math.inf
    have_best = False

    for anchor in candidates:
        anchor_probs = probability_representation[anchor]
        ranked = sorted(
            [_kl_divergence(anchor_probs, probability_representation[other]), other]
            for other in candidates
        )
        nearest = ranked[:cluster_size]

        cost = 0
        for divergence, _ in nearest:
            cost += divergence
        if math.isnan(cost):
            # NaN never compares smaller; treat it as the worst possible cost.
            cost = math.inf

        # ``not have_best`` guarantees the first anchor is accepted even when
        # its cost is infinite, so the result is never spuriously empty.
        if not have_best or cost < best_cost:
            have_best = True
            best_cost = cost
            created_cluster = [member for _, member in nearest]

    print("created cluster is ",created_cluster)
    for agent in created_cluster:
        unassigned.remove(agent)

    return created_cluster, unassigned

In [76]:
def clustering(agents,num_agents):
    """Group agents into clusters and map each one to its behaviour policy.

    The number of clusters is ``floor(sqrt(num_agents))`` and each cluster
    gets ``num_agents // number_of_clusters`` members (an arbitrary choice,
    treated as a hyper-parameter). Agents are represented by the vector
    returned from ``get_output_probabilities`` and grouped by repeatedly
    calling ``create_cluster`` until none is left.

    Args:
        agents: iterable of agents accepted by ``get_output_probabilities``.
        num_agents: number of agents (indices ``0 .. num_agents - 1``) to
            cluster.

    Returns:
        List ``root`` of length ``num_agents`` where ``root[i]`` is the index
        of the behaviour policy for target policy ``i`` (the first member of
        the cluster that contains ``i``). Empty when ``num_agents <= 0``.

    Raises:
        ValueError: if fewer than ``num_agents`` agents were supplied.
        RuntimeError: if ``create_cluster`` returns an empty cluster while
            agents are still unassigned (this would otherwise loop forever).
    """
    if num_agents <= 0:
        # sqrt(0) would give zero clusters and a division by zero below.
        return []

    number_of_clusters = int(math.sqrt(num_agents))
    cluster_size = num_agents // number_of_clusters

    print("forming clusers of size = ",cluster_size)

    probability_representation = [
        get_output_probabilities(agent) for agent in agents
    ]
    if num_agents > len(probability_representation):
        raise ValueError(
            "num_agents exceeds the number of agents supplied"
        )

    # Indices of policies that are not clustered yet.
    unassigned = list(range(num_agents))

    # root[i] is the behaviour policy for target policy i.
    root = [0] * num_agents

    while unassigned:
        created_cluster, unassigned = create_cluster(
            unassigned, probability_representation, cluster_size
        )
        if not created_cluster:
            raise RuntimeError(
                "create_cluster returned an empty cluster; cannot make progress"
            )
        # The 0th member is the behaviour policy for the whole cluster.
        behaviour_policy_index = created_cluster[0]
        for member in created_cluster:
            root[member] = behaviour_policy_index

    return root

In [78]:
# --- run configuration -------------------------------------------------------
game_actions = 2    # 2 actions possible: left or right
num_agents = 100    # N, the number of agents to initialise
generations = 1

# Gradients are not used anywhere below, so disable autograd globally.
torch.set_grad_enabled(False)

# Build the population, then compute each agent's behaviour-policy index.
agents = return_random_agents(num_agents)
clustered = clustering(agents, num_agents)

for generation in range(generations):
    print(clustered)
    

forming clusers of size =  10
[0.5341129644513131, 0.46588703279197213]
[0.44900100941956045, 0.5509989898204803]
[0.5291978529691697, 0.4708021501749754]
[0.6099790781736374, 0.3900209236443043]
[0.4831907924413681, 0.5168092064857482]
[0.4710996518135071, 0.5289003469645978]
[0.5059427039027214, 0.4940572942495346]
[0.5167091462016106, 0.4832908549010754]
[0.4988336196243763, 0.5011663830280304]
[0.4745454404354095, 0.5254545590877533]
[0.4337009465396404, 0.5662990520000458]
[0.5261946974098682, 0.47380530241131785]
[0.4460518990457058, 0.5539480990171433]
[0.5343110322058201, 0.4656889684051275]
[0.48761766517162325, 0.5123823351264]
[0.6307132095992565, 0.3692867883369327]
[0.4508833016604185, 0.5491166982650757]
[0.5663440747559071, 0.43365592388808727]
[0.5219256051778793, 0.47807439452409745]
[0.43879225304722785, 0.5612077490389347]
[0.5550357738137245, 0.444964226603508]
[0.494200632750988, 0.5057993659079075]
[0.4669625455737114, 0.5330374532341957]
[0.4939909536242485, 0.50