In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.functional import normalize
from itertools import count
from torch.autograd import Variable
from environments.base_environment import OneHotEnv

In [14]:
class AverageRewardMDP_Agent(OneHotEnv):
    """Small tabular MDP bundled with an alternating-gradient solver.

    Environment (9 states, 2 actions, deterministic transitions stored in ``P``):

    * every state ``s`` moves to ``s + 1`` under both actions, except
    * the last state wraps around to state 0,
    * state 4 moves to state 0 instead of state 5,
    * action 1 in state 0 moves to state 5 instead of state 1.

    Rewards (stored in ``R``): 1 for the transition 0 -> 1 under action 0 and
    2 for the transition last -> 0 under either action.

    Solver parameters, each with its own Adam optimizer (lr=0.001):

    * ``policy``    -- (num_states, num_actions) logits; ``train`` applies one
      softmax over *all* state-action pairs of the flattened tensor.
    * ``T``         -- (num_states,) logits, softmax-normalised over states.
    * ``potential`` -- (num_states,) values multiplying the two constraint terms.

    NOTE(review): ``OneHotEnv.__init__`` is never called here -- confirm the
    base class does not require initialisation.
    """

    def __init__(self):
        self.states = []
        self.actions = ["up", "down", "left", "right"]
        self.lr = 0.2
        self.exp_rate = 0.3

        self.num_states = 9
        self.num_actions = 2

        # Env params
        self.current_state = 0
        self.reward_scale_factor = None
        self.P = torch.zeros((self.num_states, self.num_actions, self.num_states))
        self.R = torch.zeros((self.num_states, self.num_actions, self.num_states))

        last_state = self.num_states - 1
        # Chain: node i -> node i + 1 under every action.
        for s in range(last_state):
            self.P[s, :, s + 1] = 1
        # The last node wraps around to node 0.
        self.P[last_state, :, 0] = 1
        # Node 4 is rewired: it goes to node 0 instead of node 5.
        self.P[4, :, 5] = 0
        self.P[4, :, 0] = 1
        # Action 1 in node 0 leads to node 5 instead of node 1.
        self.P[0, 1, 1] = 0
        self.P[0, 1, 5] = 1
        # Rewards for going from 0 to 1 (action 0 only) and from the last node to 0.
        self.R[0, 0, 1] = 1
        self.R[last_state, :, 0] = 2

        self.random_seed = 2023
        self.rand_generator = np.random.RandomState(self.random_seed)

        # The start state is drawn from the seeded generator; current_state
        # above stays 0 regardless of this draw.
        self.start_state = self.rand_generator.choice(self.num_states)
        self.reward_obs_term = [0.0, None, False]

        # Expected immediate reward per (state, action): sum over next states of R * P.
        self.R_matrix = (self.R * self.P).sum(dim=2)

        # From here on P is stored flat: row s * num_actions + a holds the
        # next-state distribution of (s, a).
        self.P = self.P.reshape(self.num_states * self.num_actions, self.num_states)

        # Agent params
        self.policy = torch.zeros(self.num_states, self.num_actions, requires_grad=True)
        self.policy_optimizer = torch.optim.Adam([self.policy], lr=0.001)

        self.T = torch.zeros(self.num_states, requires_grad=True)
        self.T_optimizer = torch.optim.Adam([self.T], lr=0.001)

        self.potential = torch.zeros(self.num_states, requires_grad=True)
        self.potential_optimizer = torch.optim.Adam([self.potential], lr=0.001)

        self.sm = nn.Softmax(dim=0)   # used on 1-D tensors (flattened policy, T)
        self.psm = nn.Softmax(dim=1)  # per-state softmax over actions

    def _flow_terms(self):
        """Return ``(occupancy, t_part, p_part)`` for the current parameters.

        * ``occupancy`` -- softmax over all flattened state-action logits.
        * ``t_part``    -- ``potential * softmax(T)``, one entry per state.
        * ``p_part``    -- ``(P @ potential) * occupancy``, one entry per
          state-action pair.
        """
        occupancy = self.sm(self.policy.flatten())
        t_part = self.potential * self.sm(self.T)
        next_potential = torch.mm(self.P, self.potential.reshape(self.num_states, 1)).flatten()
        p_part = next_potential * occupancy
        return occupancy, t_part, p_part

    def train(self, rounds=10, log_every=1000):
        """Run ``rounds`` alternating gradient rounds.

        Each round performs
        1. one Adam step on ``policy`` and ``T`` (``potential`` frozen) that
           minimises ``-sum(R_matrix * occupancy) + sum(t_part) - sum(p_part)``;
        2. one Adam step on ``potential`` (``policy`` and ``T`` frozen) that
           minimises ``-(sum(t_part) - sum(p_part))``.

        Args:
            rounds: number of alternating rounds.
            log_every: print both losses whenever ``step % log_every == 0``;
                a falsy value disables printing.

        Side effect: if at least one round ran, ``policy`` and ``T`` are left
        with ``requires_grad=False`` and ``potential`` with ``requires_grad=True``.

        NOTE(review): ``T`` is an independent parameter; nothing in this method
        ties ``softmax(T)`` to the state marginal of ``occupancy`` -- confirm
        that this is intended.
        """
        for step in range(rounds):
            # Phase 1: update policy and T, potential held fixed.
            self.policy.requires_grad_(True)
            self.T.requires_grad_(True)
            self.potential.requires_grad_(False)

            occupancy, t_part, p_part = self._flow_terms()
            objective = self.R_matrix.flatten() * occupancy
            # An entropy term over `occupancy` used to be computed here but was
            # never added to the loss, so it is omitted.
            policy_loss = -objective.sum() + t_part.sum() - p_part.sum()

            self.policy_optimizer.zero_grad()
            self.T_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()
            self.T_optimizer.step()

            # Phase 2: update potential, policy and T held fixed.
            self.policy.requires_grad_(False)
            self.T.requires_grad_(False)
            self.potential.requires_grad_(True)

            _, t_part, p_part = self._flow_terms()
            f_loss = -(t_part.sum() - p_part.sum())

            self.potential_optimizer.zero_grad()
            f_loss.backward()
            self.potential_optimizer.step()

            if log_every and step % log_every == 0:
                print('Policy loss: ',policy_loss.item())
                print('Potentials loss: ', f_loss.item())
                print('----------------------------------')

In [12]:
# Build the combined environment/agent and run 1000 alternating training rounds.
agent = AverageRewardMDP_Agent()
agent.train(rounds=1000)

Policy loss:  0.2777777910232544
Potentials loss:  -0.0
----------------------------------
Policy loss:  0.08594746142625809
Potentials loss:  -0.009422238916158676
----------------------------------
Policy loss:  0.0379662811756134
Potentials loss:  -0.0010499954223632812
----------------------------------
Policy loss:  -0.0038937460631132126
Potentials loss:  0.02148124761879444
----------------------------------
Policy loss:  -0.007359161972999573
Potentials loss:  0.01724955439567566
----------------------------------
Policy loss:  -0.0012164413928985596
Potentials loss:  0.008312709629535675
----------------------------------
Policy loss:  0.032965198159217834
Potentials loss:  -0.02733531966805458
----------------------------------
Policy loss:  -0.03129586577415466
Potentials loss:  0.035596683621406555
----------------------------------
Policy loss:  -0.010031193494796753
Potentials loss:  0.012690022587776184
----------------------------------
Policy loss:  0.0554366409778595


In [13]:
# The results without using f and P in the objective.
# Displayed, in order: the raw policy logits, the per-state softmax over
# actions (psm, dim=1), and the single softmax over all flattened
# state-action logits (sm, dim=0).
agent.policy, agent.psm(agent.policy), agent.sm(agent.policy.flatten())

(tensor([[-3.6785,  1.9867],
         [ 1.9600,  1.9600],
         [ 1.9600,  1.9600],
         [ 1.9600,  1.9600],
         [ 1.0222,  1.0222],
         [ 1.9600,  1.9600],
         [ 1.9600,  1.9600],
         [ 1.9600,  1.9600],
         [-4.1230, -4.1230]]),
 tensor([[0.0035, 0.9965],
         [0.5000, 0.5000],
         [0.5000, 0.5000],
         [0.5000, 0.5000],
         [0.5000, 0.5000],
         [0.5000, 0.5000],
         [0.5000, 0.5000],
         [0.5000, 0.5000],
         [0.5000, 0.5000]]),
 tensor([0.0003, 0.0743, 0.0724, 0.0724, 0.0724, 0.0724, 0.0724, 0.0724, 0.0283,
         0.0283, 0.0724, 0.0724, 0.0724, 0.0724, 0.0724, 0.0724, 0.0002, 0.0002]))

In [271]:
# Take ten uniformly random actions and print what the environment returns.
for _ in range(10):
    action = np.random.choice(agent.num_actions)
    obs = agent.env_step(action)
    reward = obs[0].item()
    next_state = np.argmax(obs[1])  # obs[1] is arg-maxed to recover the state index
    print('action:',action, '  reward:',reward, '  next_state:',next_state)

action: 1   reward: 0.0   next_state: 5
action: 1   reward: 0.0   next_state: 6
action: 1   reward: 0.0   next_state: 7
action: 0   reward: 0.0   next_state: 8
action: 0   reward: 2.0   next_state: 0
action: 1   reward: 0.0   next_state: 5
action: 0   reward: 0.0   next_state: 6
action: 0   reward: 0.0   next_state: 7
action: 1   reward: 0.0   next_state: 8
action: 0   reward: 2.0   next_state: 0


In [275]:
# Raw T parameters; train() passes them through a softmax over states before use.
agent.T

tensor([ 1.5785, -1.6616, 11.7198, 11.7198, 11.7198, -1.6616, 11.7198, 11.7198,
        11.7198])

In [274]:
# Per-state potential values, updated in the second phase of each train() round.
agent.potential

tensor([76.7428, 48.8321, 47.9522, 47.9522, 47.9522, 48.8321, 47.9522, 47.9522,
        47.9522], requires_grad=True)