In [73]:
import gym
import numpy as np
import random
from collections import Counter, deque
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.nn.utils as nn_utils
import torch.nn.functional as F
import torch.optim as optim

In [74]:
# Hyperparameters
batch_size = 32

gamma = 0.99

eps_start=1.0
eps_decay = 0.995
eps_min = 0.1      # Minimal exploration rate (epsilon-greedy)

num_rounds = 500
num_episodes = 50
learning_limit = 100
replay_limit = 1000  # Number of steps until starting replay
weight_update = 1000 # Number of steps until updating the target weights


In [75]:
#create model
class Model(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(Model, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(input_shape[0]*input_shape[1], 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions)
            
        )

    def forward(self, x):
        # flatten the observation space Box to linear tensor
        tensor_array = torch.from_numpy(state)
        x_flat = torch.flatten(tensor_array).to(torch.float32)
        return self.net(x_flat)

In [76]:
class SchedulerEnv(gym.Env):

    def __init__(self):
        
        #starting parameters
        num_gps = 100
        num_slots = 32
        num_pre_booked = 750
        to_book = [2,1,2,2,1,1,1,3,3,1,2,1,3,2,1,1,2,1,3,2,3,2]
        num_to_book = len(to_book)
        agent_pos = [0,0]
        reward_decay = 0.95
        
        #set parameters for the day
        self.num_gps = num_gps
        self.num_slots = num_slots
        self.num_pre_booked = num_pre_booked
        self.to_book = to_book
        self.num_to_book = num_to_book
        self.diary_slots = num_gps*num_slots
        self.agent_pos = agent_pos
        self.reward_decay = reward_decay

        #set action space to move around the grid
        self.action_space = gym.spaces.Discrete(4) #up, down, left, right
        
        #set observation space 
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(self.num_slots, self.num_gps), dtype=np.int32)
   
    #creates daily diary for each gp, randomly populates prebooked appointments and resets parameters
    def reset(self):

        #creates zero filled dataframe with row per time slot and column per gp
        self.state = np.zeros((self.num_slots, self.num_gps),dtype=float)

        #randomly enters a 1 for each pre booked appointments
        pre_booked = self.num_pre_booked
        while pre_booked>0:
            pre_booked -= 1
            self.state[np.random.randint(self.num_slots), np.random.randint(self.num_gps)] = 1
            
        #randomly sets the agent start space
        self.agent_pos = [np.random.randint(self.num_slots), np.random.randint(self.num_gps)]

        #resets parameters for new episode
        self.done = False
        self.reward = 0
        self.appt_idx = 0
        self.decay_steps = 1
        
        #print('starting state', self.state.sum(), self.state)

        return self.state
    
    #calculates new position of the agent based on the action
    def move_agent(self, action):

        #set boundaries for the grid
        max_row = self.num_slots - 1
        max_col = self.num_gps - 1

        #setting new co-ordinates for the agent
        new_row = self.agent_pos[0]
        new_col = self.agent_pos[1]

        #calculate what the new position may be based on the action without going out the grid
        if action == 0:
            #print('up')
            new_row = max(self.agent_pos[0] - 1, 0)
        if action == 1:
            #print('down')
            new_row = min(self.agent_pos[0] + 1, max_row)
        if action == 2:
            #print('left')
            new_col = max(self.agent_pos[1] - 1, 0)
        if action == 3:
            #print('right')
            new_col = min(self.agent_pos[1] + 1, max_col)

        new_pos = [new_row, new_col]
        #print('new pos', new_pos)

        return new_pos

    #checks if we can look to book appointment starting here
    def check_bookable(self):
        return self.state[self.agent_pos[0], self.agent_pos[1]] == 0.0
    
    #action if we can't book the appointment
    def invalid_booking(self):
        #print('cant book')
        self.decay_steps += 1
        self.reward = -1
        
    #action if we can book the appointment
    def valid_booking(self):
        #print('go ahead and book')
        self.appt_idx += 1
        self.decay_steps = 1
        self.reward = 1
    
    #checks if the appointment fits
    def check_and_book(self):
        
        max_row = self.num_slots - 1
        cells_to_check = self.to_book[self.appt_idx]
        
        if cells_to_check==1:
            #print('good to check for single')
            if self.state[self.agent_pos[0], self.agent_pos[1]] == 0:
                self.state[self.agent_pos[0], self.agent_pos[1]] = 1
                self.valid_booking()
            else:
                #print('single taken')
                self.invalid_booking()

        if cells_to_check==2:
            #check we're not at the bottom of the grid
            if self.agent_pos[0]<max_row:
                #check the next cells is also 0.0
                #print('good to check for double')
                if self.state[self.agent_pos[0], self.agent_pos[1]] == 0 and \
                self.state[(self.agent_pos[0]+1), self.agent_pos[1]] == 0:
                    self.state[self.agent_pos[0], self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+1), self.agent_pos[1]] = 1
                    self.valid_booking()
                    self.agent_pos = [(self.agent_pos[0]+1), self.agent_pos[1]]
                    #print('after booking', self.agent_pos)
                else:
                    #print('double taken')
                    self.invalid_booking()
            else:
                #print('not for double')
                self.invalid_booking()
                
        if cells_to_check==3:
            #check we're not at the bottom of the grid
            if self.agent_pos[0]+1<max_row:
                #print('good to check for treble')
                if self.state[self.agent_pos[0], self.agent_pos[1]] == 0 and \
                self.state[(self.agent_pos[0]+1), self.agent_pos[1]] == 0 \
                 and self.state[(self.agent_pos[0]+2), self.agent_pos[1]] == 0:
                    self.state[self.agent_pos[0], self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+1), self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+2), self.agent_pos[1]] = 1
                    self.valid_booking()
                    self.agent_pos = [(self.agent_pos[0]+2), self.agent_pos[1]]
                else:
                    #print('treble taken')
                    self.invalid_booking()
            else:
                #print('not for treble')
                self.invalid_booking()
                
        if cells_to_check==4:
            #check we're not at the bottom of the grid
            if self.agent_pos[0]+2<max_row:
                #check the next cells is also 0.0
                #print('good for quad')
                if self.state[self.agent_pos[0], self.agent_pos[1]] == 0 and \
                self.state[(self.agent_pos[0]+1), self.agent_pos[1]] == 0 \
                 and self.state[(self.agent_pos[0]+2), self.agent_pos[1]] == 0 and \
                self.state[(self.agent_pos[0]+3), self.agent_pos[1]] == 0:
                    self.state[self.agent_pos[0], self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+1), self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+2), self.agent_pos[1]] = 1
                    self.state[(self.agent_pos[0]+3), self.agent_pos[1]] = 1
                    self.valid_booking()
                    self.agent_pos = [(self.agent_pos[0]+3), self.agent_pos[1]]
                else:
                    #print('quad taken')
                    self.invalid_booking()
            else:
                #print('not for quad')
                self.invalid_booking()

        next_state = self.state

        return next_state

    def step(self, action):

        #get new position of agent based on action
        new_agent_pos = self.move_agent(action)
        #print('new and old pos', new_agent_pos, self.agent_pos)
        
        #if the agent is stuck on an edge then move to a new position
        if new_agent_pos == self.agent_pos:
            self.agent_pos = [np.random.randint(self.num_slots), np.random.randint(self.num_gps)]
            #print('here1', self.agent_pos)
        else:
            self.agent_pos = new_agent_pos
            #print('here2', self.agent_pos)
        
        #print('trying to book', self.to_book, self.appt_idx)
        
        #check if it's possible to book then book
        if self.check_bookable():
            #print('checked here')
            self.state = self.check_and_book()
        else:
            #print('not bookable')
            self.invalid_booking()
        
        #work out if episode complete
        if self.appt_idx == len(self.to_book):
            #print('all booked')
            self.done = True
  
        #work out rewards
        #self.reward = (1 - (self.reward_decay**self.decay_steps))
        
        #print('step', self.decay_steps, self.reward)
        #print('end step')

        info = {}
        return self.state, self.reward, self.done, info

In [None]:
#device = "cuda"
device = "cpu"

env = SchedulerEnv()

#start writing to tensorboard
writer = SummaryWriter(comment="Scheduler DQN")

#create the current network and target network
policy_model = Model(env.observation_space.shape, env.action_space.n).to(device)

target_model = Model(env.observation_space.shape, env.action_space.n).to(device)
target_model.load_state_dict(policy_model.state_dict())

optimizer = optim.Adam(policy_model.parameters(), lr=0.001, eps=1e-3)

epsilon = eps_start          # Exploration rate    
replay_buffer = deque(maxlen=1000)

step_idx = 0

for i in range(num_rounds):
    #change this for while not true once it works
    episode_reward = 0
    for j in range(num_episodes):
        
        step_idx += 1
        state = env.reset()

        #epsilon for epsilon greedy strategy  
        if epsilon > eps_min:
            epsilon *= eps_decay
            
        check = policy_model(state)
        print(check)

        # Select and perform an action
        if step_idx > learning_limit:
            if np.random.rand() > epsilon:
                action = torch.argmax(policy_model(state))
                print('action', action)
        else:
            action = np.random.randint(env.action_space.n)

        next_state, reward, done, _ = env.step(action)
        reward = torch.tensor([reward], device=device)
        episode_reward += reward

        # Store other info in replay buffer
        replay_buffer.append((state, action, reward, next_state, done))

        # Move to the next state
        state = next_state

        if done:
            break
        
    writer.add_scalar('episode_reward', episode_reward, step_idx)

    #once we're ready to learn then start learning with mini batches
    if len(replay_buffer) == replay_limit:
        minibatch = random.sample(replay_buffer, batch_size)

        for state, action, reward, next_state, done in minibatch:    
            optimizer.zero_grad()
            #pass state to policy to get qval from policy
            pred_qval = policy_model(state)

            #pass next state to target policy to get next set of qvals (future gains)
            if not done:
                next_qval = (reward + (gamma * max(target_model(next_state)))).detach()
                #next_qval = next_qval.detach()
            else:
                next_qval = reward    

            loss = F.mse_loss(pred_qval, next_qval)
            loss.backward()

            optimizer.step()


    # Update the target network, copying all weights and biases in DQN
    # Periodically update the target network by Q network to target Q network
    if i % weight_update == 0:
        # Update weights of target
        target_model.load_state_dict(policy_model.state_dict())

writer.close()

tensor([-0.1323, -0.0116, -0.0279, -0.0259], grad_fn=<AddBackward0>)
tensor([-0.1001, -0.0004,  0.0048,  0.0088], grad_fn=<AddBackward0>)
tensor([-0.1278,  0.0127, -0.0290, -0.0396], grad_fn=<AddBackward0>)
tensor([-0.1070,  0.0859, -0.0587,  0.1053], grad_fn=<AddBackward0>)
tensor([-0.0536,  0.0559, -0.0342,  0.0547], grad_fn=<AddBackward0>)
tensor([-0.0697,  0.0889, -0.0185, -0.0010], grad_fn=<AddBackward0>)
tensor([-0.1062,  0.0476, -0.0823,  0.0509], grad_fn=<AddBackward0>)
tensor([-0.0571,  0.0251, -0.0451,  0.0399], grad_fn=<AddBackward0>)
tensor([-0.0908,  0.0519,  0.0093,  0.0104], grad_fn=<AddBackward0>)
tensor([-0.0134,  0.0710, -0.0485,  0.0069], grad_fn=<AddBackward0>)
tensor([-0.0828,  0.0261, -0.0319,  0.0101], grad_fn=<AddBackward0>)
tensor([-0.1079, -0.0410,  0.0284, -0.0630], grad_fn=<AddBackward0>)
tensor([-0.0697,  0.0142, -0.0296,  0.0265], grad_fn=<AddBackward0>)
tensor([-0.0919, -0.0143, -0.0560,  0.0286], grad_fn=<AddBackward0>)
tensor([-0.1306,  0.0322, -0.0546,

tensor([-0.0714,  0.0212, -0.0137,  0.0188], grad_fn=<AddBackward0>)
tensor([-0.0585, -0.0114,  0.0598, -0.0159], grad_fn=<AddBackward0>)
tensor([-0.0337,  0.0897, -0.0249,  0.0747], grad_fn=<AddBackward0>)
tensor([-0.0473,  0.0021, -0.0265,  0.0282], grad_fn=<AddBackward0>)
tensor([-0.0998,  0.0056, -0.0511,  0.0748], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0389, -0.0046, -0.0299,  0.0431], grad_fn=<AddBackward0>)
tensor([-0.0273,  0.0456, -0.0136,  0.0433], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0909,  0.0590, -0.0130,  0.0301], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0674,  0.0234, -0.0474, -0.0179], grad_fn=<AddBackward0>)
action tensor(1)
tensor([ 0.0047,  0.0184, -0.0136, -0.0213], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0367,  0.0589,  0.0383,  0.0329], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0774,  0.0086, -0.0125,  0.0342], grad_fn=<AddBackward0>)
tensor([-0.0655,  0.0386, -0.0167,  0.0556], grad_fn=<AddBackward0>)
a

tensor([-0.0792,  0.0360, -0.0445,  0.0118], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0745,  0.0419, -0.0124,  0.0642], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0952,  0.0361,  0.0028,  0.0520], grad_fn=<AddBackward0>)
tensor([ 0.0107, -0.0004, -0.0084, -0.0050], grad_fn=<AddBackward0>)
action tensor(0)
tensor([-0.0388,  0.0115, -0.0055, -0.0282], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0923, -0.0188,  0.0087,  0.0467], grad_fn=<AddBackward0>)
tensor([-0.1019, -0.0359, -0.0226,  0.0493], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0569,  0.0454, -0.0266,  0.0255], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0549, -0.0173,  0.0426,  0.0422], grad_fn=<AddBackward0>)
tensor([-0.0577, -0.0362, -0.0142,  0.0011], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0973,  0.0306, -0.0218,  0.0257], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0749,  0.0281,  0.0190, -0.0356], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0534,  0

tensor([-0.1177, -0.0312, -0.0444,  0.0272], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0893,  0.0231,  0.0196,  0.0094], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0528,  0.0341, -0.0090,  0.0497], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.1289, -0.0127,  0.0217, -0.0580], grad_fn=<AddBackward0>)
action tensor(2)
tensor([-0.0469,  0.0593, -0.0403, -0.0299], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.1332, -0.0067, -0.0273,  0.0143], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.1489,  0.0354,  0.0533,  0.0029], grad_fn=<AddBackward0>)
action tensor(2)
tensor([-0.1181,  0.0169, -0.0103,  0.0051], grad_fn=<AddBackward0>)
action tensor(1)
tensor([ 0.0138,  0.0641, -0.0538,  0.0635], grad_fn=<AddBackward0>)
tensor([-0.0567,  0.0294, -0.0389,  0.0544], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0715,  0.0340,  0.0159,  0.0169], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.1199,  0.0295, -0.0508,  0.0224], grad_fn=<AddBackward0>)
ac

tensor([ 0.0348, -0.0096,  0.0331, -0.0069], grad_fn=<AddBackward0>)
tensor([-0.0245, -0.0212, -0.0276,  0.0048], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0451,  0.0545, -0.0299,  0.0410], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.1061, -0.0009, -0.0067, -0.0013], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0677,  0.0634, -0.0510,  0.0099], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.1409, -0.0209, -0.0441,  0.0182], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0541, -0.0278, -0.0624, -0.0158], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0870,  0.0811,  0.0373,  0.0703], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0620,  0.0297, -0.0238,  0.0751], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0611,  0.0120, -0.0069,  0.0106], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.1032,  0.0235,  0.0030,  0.0808], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.1066,  0.0397, -0.0232, -0.0271], grad_fn=<AddBackward0>)
te

tensor([-0.0791, -0.0255,  0.0099, -0.0217], grad_fn=<AddBackward0>)
action tensor(2)
tensor([-0.0474,  0.0184, -0.0285, -0.0314], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0430,  0.0578, -0.0176,  0.0160], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0606,  0.0051,  0.0401,  0.0642], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.1019,  0.0524, -0.0207,  0.0372], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0639,  0.0651, -0.0546,  0.0408], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0722, -0.0031, -0.0238,  0.0613], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.1450,  0.0116, -0.0601,  0.0113], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0514, -0.0081, -0.0470,  0.0111], grad_fn=<AddBackward0>)
tensor([-0.0827,  0.0332, -0.0854, -0.0042], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0523,  0.0663, -0.0289, -0.0491], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.0119, 0.0517, 0.0380, 0.0301], grad_fn=<AddBackward0>)
action

tensor([-0.0780,  0.0066, -0.0279,  0.0356], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0514,  0.0503, -0.0382,  0.0461], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.1138,  0.0184, -0.0197,  0.0164], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0469,  0.0130, -0.0025,  0.0244], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0482,  0.0302, -0.0342,  0.0277], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.1037, -0.0187,  0.0078,  0.0002], grad_fn=<AddBackward0>)
action tensor(2)
tensor([-0.0747,  0.0882, -0.0134,  0.0285], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0733,  0.0416, -0.0373,  0.0208], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0256,  0.0746, -0.0342,  0.0340], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0390,  0.0775, -0.0563,  0.0466], grad_fn=<AddBackward0>)
tensor([-0.0292,  0.0312,  0.0268, -0.0308], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0423,  0.0013, -0.0222,  0.0066], grad_fn=<AddBackward0>)
ac

tensor([-0.1185,  0.0118, -0.0140,  0.0129], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0506,  0.0239,  0.0594, -0.0211], grad_fn=<AddBackward0>)
action tensor(2)
tensor([-0.0725,  0.0033, -0.0214,  0.0044], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.1114,  0.0498,  0.0410, -0.0112], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.1137,  0.0109,  0.0216,  0.0360], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0426,  0.0557,  0.0013,  0.0037], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.1125,  0.0249, -0.0872,  0.0384], grad_fn=<AddBackward0>)
tensor([-0.1429,  0.0397, -0.0418,  0.0418], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.1268,  0.0036, -0.0519,  0.0688], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.1012, -0.0387, -0.0261,  0.0104], grad_fn=<AddBackward0>)
action tensor(3)
tensor([ 0.0229,  0.0714, -0.0263,  0.0284], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0320,  0.0946, -0.0088, -0.0131], grad_fn=<AddBackward0>)
te

tensor([-0.0694,  0.0150,  0.0381,  0.0141], grad_fn=<AddBackward0>)
action tensor(2)
tensor([-0.0771, -0.0012,  0.0314,  0.0064], grad_fn=<AddBackward0>)
action tensor(2)
tensor([-0.0737, -0.0163,  0.0535, -0.0407], grad_fn=<AddBackward0>)
action tensor(2)
tensor([-0.0406,  0.0275, -0.1331, -0.0014], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0321,  0.0411, -0.0502,  0.0014], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0803,  0.0179,  0.0139, -0.0171], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0607,  0.0480, -0.0410,  0.0997], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0984, -0.0248, -0.0126,  0.0023], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0276,  0.0525, -0.0489, -0.0835], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.1152,  0.0198, -0.0764,  0.0708], grad_fn=<AddBackward0>)
action tensor(3)
tensor([-0.0175,  0.0443, -0.0157,  0.0395], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0281, -0.0407,  0.0089, -0.0183], grad_fn=<

tensor([-0.0301,  0.0375,  0.0516, -0.0251], grad_fn=<AddBackward0>)
action tensor(2)
tensor([-0.0739,  0.0341, -0.0358,  0.0047], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.2923, 0.2608, 0.2556, 0.2162], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.1963, 0.2029, 0.2120, 0.1706], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.3046, 0.2787, 0.2432, 0.1822], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.2648, 0.2354, 0.2096, 0.2421], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.2536, 0.2250, 0.1671, 0.2205], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.2462, 0.2791, 0.1801, 0.2208], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.2276, 0.2404, 0.2053, 0.1943], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.3039, 0.2856, 0.3022, 0.1842], grad_fn=<AddBackward0>)




action tensor(0)
tensor([0.2923, 0.2870, 0.2135, 0.2187], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.2709, 0.2284, 0.2421, 0.2207], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.2451, 0.2830, 0.2719, 0.2015], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.2934, 0.2483, 0.2159, 0.2311], grad_fn=<AddBackward0>)
tensor([0.2693, 0.2564, 0.2459, 0.2102], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.3179, 0.2266, 0.2647, 0.2524], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.3687, 0.3205, 0.2680, 0.2505], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.2385, 0.2273, 0.1985, 0.1518], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.2715, 0.2732, 0.1803, 0.1970], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.3043, 0.2935, 0.1979, 0.2346], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.3108, 0.2719, 0.2088, 0.2524], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.2843, 0.2633, 0.1696, 0.2408], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.2580, 

tensor([0.0778, 0.1510, 0.1461, 0.1651], grad_fn=<AddBackward0>)
action tensor(3)
tensor([0.0496, 0.0954, 0.1511, 0.1164], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.1762, 0.2126, 0.2017, 0.1902], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.1432, 0.1755, 0.2026, 0.2012], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.1176, 0.1538, 0.1183, 0.2196], grad_fn=<AddBackward0>)
action tensor(3)
tensor([0.1642, 0.1755, 0.1680, 0.1775], grad_fn=<AddBackward0>)
action tensor(3)
tensor([0.0845, 0.1601, 0.1762, 0.1938], grad_fn=<AddBackward0>)
action tensor(3)
tensor([0.1367, 0.1689, 0.1907, 0.1508], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.1199, 0.1598, 0.1762, 0.2018], grad_fn=<AddBackward0>)
action tensor(3)
tensor([0.1381, 0.1717, 0.1395, 0.1565], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.2143, 0.2369, 0.2388, 0.2478], grad_fn=<AddBackward0>)
action tensor(3)
tensor([0.1400, 0.1833, 0.1916, 0.1423], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.0883, 

tensor([0.0866, 0.1871, 0.2495, 0.1246], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.1713, 0.2469, 0.2559, 0.2020], grad_fn=<AddBackward0>)
tensor([0.1977, 0.2549, 0.2459, 0.2127], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.1880, 0.2121, 0.2304, 0.1732], grad_fn=<AddBackward0>)
tensor([0.1268, 0.1948, 0.2314, 0.1414], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.1105, 0.2330, 0.2297, 0.1518], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.1160, 0.1868, 0.2303, 0.1488], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.1348, 0.2316, 0.2690, 0.1728], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.0969, 0.2141, 0.2389, 0.1674], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.1859, 0.2319, 0.2342, 0.2015], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.1231, 0.2159, 0.2825, 0.1991], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.1317, 0.2046, 0.2487, 0.1468], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.1614, 0.2365, 0.2462, 0.1908], grad_fn=<

tensor([0.0424, 0.1587, 0.1385, 0.0751], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.0743, 0.1444, 0.1240, 0.0932], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.0219, 0.1357, 0.1333, 0.0491], grad_fn=<AddBackward0>)
tensor([0.0634, 0.1946, 0.1960, 0.1064], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.0598, 0.1490, 0.1466, 0.1115], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.5687, 0.3769, 0.5607, 0.4552], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.5230, 0.3686, 0.5031, 0.4643], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.5879, 0.4196, 0.5804, 0.4948], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.5741, 0.3863, 0.5866, 0.5319], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.5866, 0.3833, 0.4085, 0.4514], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.4471, 0.3238, 0.3683, 0.3267], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.4945, 0.3616, 0.4354, 0.4104], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.5529, 0.4102, 0.4289, 0

tensor([0.2231, 0.2581, 0.3211, 0.2586], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.2060, 0.2394, 0.2759, 0.2581], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.1418, 0.2369, 0.2708, 0.2499], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.2706, 0.2741, 0.3488, 0.3343], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.1972, 0.2624, 0.3181, 0.3138], grad_fn=<AddBackward0>)
tensor([0.1406, 0.2036, 0.1893, 0.2126], grad_fn=<AddBackward0>)
tensor([0.3759, 0.3728, 0.3995, 0.4217], grad_fn=<AddBackward0>)
action tensor(3)
tensor([0.2274, 0.2823, 0.3065, 0.2840], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.2437, 0.2800, 0.3535, 0.3151], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.2646, 0.2870, 0.3052, 0.2911], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.1633, 0.2190, 0.3018, 0.2082], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.1350, 0.1872, 0.2711, 0.2118], grad_fn=<AddBackward0>)
tensor([0.1200, 0.2061, 0.2483, 0.1692], grad_fn=<AddBackward0>)
ac

tensor([0.2018, 0.1445, 0.1224, 0.0802], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.1569, 0.1036, 0.0932, 0.0375], grad_fn=<AddBackward0>)
tensor([0.1415, 0.0689, 0.0729, 0.0189], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.1363, 0.0914, 0.0971, 0.0259], grad_fn=<AddBackward0>)
action tensor(0)
tensor([ 0.1132, -0.0045,  0.0076,  0.0147], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.1658, 0.0765, 0.0688, 0.0410], grad_fn=<AddBackward0>)
tensor([0.1379, 0.0577, 0.0632, 0.0184], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.2607, 0.1697, 0.1502, 0.1297], grad_fn=<AddBackward0>)
action tensor(0)
tensor([ 0.1096,  0.0387,  0.0098, -0.0223], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.1308, 0.0789, 0.1195, 0.0337], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.1584, 0.1081, 0.1213, 0.0376], grad_fn=<AddBackward0>)
action tensor(0)
tensor([ 0.1789,  0.0440,  0.0515, -0.0192], grad_fn=<AddBackward0>)
tensor([0.1219, 0.0917, 0.0942, 0.0396], grad_fn=<AddBa

tensor([0.1095, 0.2237, 0.2562, 0.1984], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.0752, 0.2154, 0.2026, 0.1231], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.3653, 0.4070, 0.3798, 0.3778], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.0222, 0.1454, 0.2161, 0.0941], grad_fn=<AddBackward0>)
action tensor(2)
tensor([-0.0806,  0.0729,  0.1262,  0.0251], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.2502, 0.3182, 0.2816, 0.3091], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.0246, 0.1634, 0.1538, 0.0570], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0569,  0.1258,  0.0950,  0.0288], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.1104, 0.2050, 0.2851, 0.1596], grad_fn=<AddBackward0>)
action tensor(2)
tensor([-0.0207,  0.1379,  0.1743,  0.1148], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.0116, 0.1511, 0.1966, 0.0936], grad_fn=<AddBackward0>)
tensor([0.2351, 0.2895, 0.3410, 0.2526], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.1580, 0.237

tensor([0.1094, 0.1346, 0.1212, 0.1041], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.2462, 0.2483, 0.1672, 0.2146], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.0278, 0.0998, 0.0365, 0.0708], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.1151, 0.1560, 0.0955, 0.1005], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.0319, 0.1186, 0.0198, 0.0908], grad_fn=<AddBackward0>)
action tensor(1)
tensor([-0.0248,  0.0745,  0.0549,  0.0546], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.0918, 0.1548, 0.1481, 0.0783], grad_fn=<AddBackward0>)
tensor([0.0652, 0.1272, 0.1560, 0.0664], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.1194, 0.1610, 0.1298, 0.0828], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.1741, 0.2100, 0.1726, 0.1692], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.1833, 0.2048, 0.1675, 0.1483], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.1840, 0.2247, 0.1959, 0.1520], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.2209, 0.2331, 0.212

tensor([0.5452, 0.5036, 0.4691, 0.5664], grad_fn=<AddBackward0>)
tensor([0.4137, 0.4309, 0.4082, 0.4439], grad_fn=<AddBackward0>)
action tensor(3)
tensor([0.5274, 0.5159, 0.5056, 0.5554], grad_fn=<AddBackward0>)
action tensor(3)
tensor([0.6120, 0.5623, 0.5499, 0.6286], grad_fn=<AddBackward0>)
action tensor(3)
tensor([0.6060, 0.5811, 0.5486, 0.6336], grad_fn=<AddBackward0>)
action tensor(3)
tensor([0.5754, 0.5943, 0.6131, 0.5716], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.3414, 0.3785, 0.3902, 0.3322], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.4972, 0.5110, 0.5304, 0.4861], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.5375, 0.5400, 0.5246, 0.5376], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.5332, 0.5402, 0.5588, 0.5269], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.6220, 0.6504, 0.6843, 0.6165], grad_fn=<AddBackward0>)
tensor([0.5879, 0.6193, 0.6241, 0.6100], grad_fn=<AddBackward0>)
action tensor(2)
tensor([0.4806, 0.5350, 0.5256, 0.4989], grad_fn=<

tensor([0.2247, 0.2495, 0.2178, 0.1864], grad_fn=<AddBackward0>)
tensor([0.2992, 0.3083, 0.2602, 0.2991], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.2874, 0.2846, 0.2409, 0.2552], grad_fn=<AddBackward0>)
tensor([0.4308, 0.4120, 0.3502, 0.4232], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.2471, 0.2581, 0.2340, 0.2174], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.3193, 0.3127, 0.2605, 0.2795], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.2349, 0.2505, 0.2135, 0.1880], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.1972, 0.2350, 0.2000, 0.1771], grad_fn=<AddBackward0>)
action tensor(1)
tensor([0.3401, 0.3238, 0.2811, 0.3525], grad_fn=<AddBackward0>)
action tensor(3)
tensor([0.3367, 0.3261, 0.2525, 0.3127], grad_fn=<AddBackward0>)
action tensor(0)
tensor([0.3889, 0.3683, 0.3099, 0.3940], grad_fn=<AddBackward0>)
action tensor(3)
tensor([0.3149, 0.3151, 0.2420, 0.2920], grad_fn=<AddBackward0>)
tensor([0.4008, 0.3895, 0.3137, 0.4061], grad_fn=<AddBackward0>)
ac

In [None]:
check

In [None]:
torch.argmax(check)

In [None]:
np.random.randint(5, size=(2, 4))