In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import torch as torch
import torch.nn as nn
import random
from IPython.display import clear_output
from enum import Enum, auto
import yaml
from types import SimpleNamespace as SN

In [2]:
# Load the QMIX algorithm configuration from 'qmix.yaml' (path is relative to
# the kernel's working directory).
# NOTE(review): yaml.load with FullLoader can build more than plain data types;
# if these files are not fully trusted, yaml.safe_load is the safer choice.
with open('qmix.yaml', 'r') as f:
    qmix_config = yaml.load(f, Loader=yaml.FullLoader)
# Load the environment configuration from 'env.yaml' the same way.
# Neither qmix_config nor env_config is referenced by the other cells shown here.
with open('env.yaml', 'r') as f:
    env_config = yaml.load(f, Loader=yaml.FullLoader)

In [3]:
class Actions(Enum):
    """Discrete moves an agent can make on the grid.

    ``NO_OP`` is pinned to 0 and ``auto()`` continues from there, so the
    member values are 0..4 and can be used directly as tensor indices.
    """

    NO_OP = 0
    MOVE_UP = auto()
    MOVE_DOWN = auto()
    MOVE_LEFT = auto()
    MOVE_RIGHT = auto()

    @property
    def delta(self):
        """Return the ``(row, col)`` offset this action applies to a position."""
        # Built inside the property: a dict assigned in the class body would
        # itself be turned into an enum member.
        offsets = {
            Actions.NO_OP: (0, 0),
            Actions.MOVE_UP: (-1, 0),
            Actions.MOVE_DOWN: (1, 0),
            Actions.MOVE_LEFT: (0, -1),
            Actions.MOVE_RIGHT: (0, 1),
        }
        return offsets.get(self)

    @property
    def one_hot(self):
        """Return a 1-D float tensor of length ``len(Actions)`` with a 1 at ``self.value``."""
        encoding = torch.zeros(len(Actions))
        encoding[self.value] = 1
        return encoding.float()

In [9]:
class GridEnv():
    """A rows x cols grid holding one goal cell, one block cell and one agent.

    Positions are ``(row, col)`` pairs. The three entities are placed on
    distinct, randomly chosen cells when the environment is constructed.
    """

    def __init__(self, rows, cols, n_agents, state_shape):
        """Record the grid dimensions, derive the model input sizes and place the entities.

        Args:
            rows: number of grid rows.
            cols: number of grid columns.
            n_agents: stored and used to size ``obs_shape``; placement itself
                always creates a single agent.
            state_shape: length of one entity's one-hot position encoding.
                NOTE(review): get_global_state encodes each entity with
                rows*cols entries, so this presumably should equal rows*cols.
        """
        self.rows, self.cols = rows, cols
        self.n_agents = n_agents

        self.state_shape = state_shape
        # One state_shape-sized slot per agent, plus one each for goal and block.
        self.obs_shape = (n_agents + 2) * self.state_shape
        self.action_shape = len(Actions)
        # The agent network also receives the one-hot of its previous action.
        self.agent_in_shape = self.obs_shape + self.action_shape

        self.block = self.goal = self.agent = None

        self.populate_grid()

    def get_global_state(self):
        """Return the flattened one-hot cell indices of goal, block and agent, in that order."""
        entities = (self.goal, self.block, self.agent.pos)
        # Row-major flat index of each entity's cell.
        cell_ids = [row * self.cols + col for row, col in entities]
        encoded = nn.functional.one_hot(torch.tensor(cell_ids), self.rows * self.cols)
        return torch.flatten(encoded)

    def populate_grid(self):
        """Draw three distinct cells at random and assign them to goal, block and agent."""
        flat_cells = np.random.choice(self.rows * self.cols, 1 + 2, replace=False)
        # divmod by the column count turns a flat index back into (row, col).
        goal_cell, block_cell, agent_cell = (divmod(cell, self.cols) for cell in flat_cells)

        self.goal = goal_cell
        self.block = block_cell
        self.agent = Agent(self, agent_cell)

    def vizualize_grid(self):
        """Render the grid as text: '.' empty, 'G' goal, 'B' block, 'A' agent."""
        canvas = [["."] * self.cols for _ in range(self.rows)]

        # Later markers overwrite earlier ones if two entities share a cell.
        placements = ((self.goal, "G"), (self.block, "B"), (self.agent.pos, "A"))
        for (row, col), symbol in placements:
            canvas[row][col] = symbol

        return '\n'.join(' '.join(line) for line in canvas)

    def __repr__(self):
        """Use the text rendering of the grid as the object's repr."""
        return self.vizualize_grid()

In [11]:
class AgentModel(nn.Module):
    """Recurrent Q-network: linear -> ReLU -> GRU cell -> linear.

    Args:
        input_shape: number of input features fed to the first linear layer.
        embed_dim: width of the embedding and of the GRU hidden state.
        num_actions: number of Q-values produced per input row.
    """

    def __init__(self, input_shape, embed_dim, num_actions):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_actions = num_actions
        self.linear1 = nn.Linear(input_shape, embed_dim)
        self.rnn = nn.GRUCell(embed_dim, embed_dim)
        self.linear2 = nn.Linear(embed_dim, num_actions)

    def init_hidden(self):
        """Return a zero hidden state of shape (1, embed_dim) matching linear1's dtype and device."""
        return self.linear1.weight.new_zeros(1, self.embed_dim)

    def forward(self, inputs, hidden_in):
        """Run one recurrent step.

        Args:
            inputs: tensor whose last dimension is ``input_shape``.
            hidden_in: previous hidden state; reshaped to (-1, embed_dim).

        Returns:
            Tuple ``(q_values, hidden_out)``.
        """
        embedded = torch.relu(self.linear1(inputs))
        previous_hidden = hidden_in.reshape(-1, self.embed_dim)
        hidden_out = self.rnn(embedded, previous_hidden)
        return self.linear2(hidden_out), hidden_out

In [41]:
class Agent():
    """Epsilon-greedy agent that moves on a grid and can push the grid's block.

    The agent owns a recurrent Q-network (``local_model``) and, on every
    ``step``, records the network input (``agent_in``) and a target Q-vector
    (``target_qvalues``) that a training routine can consume.
    """

    def __init__(self, grid, pos):
        """
        Args:
            grid: environment exposing ``rows``, ``cols``, ``block``, ``goal``,
                ``agent_in_shape``, ``get_global_state()`` and ``vizualize_grid()``.
            pos: initial ``(row, col)`` position of the agent.
        """
        # Keep positions as plain int tuples so equality checks are ordinary
        # tuple comparisons rather than element-wise array comparisons.
        self.pos = self._as_cell(pos)
        self.grid = grid
        self.embed_dim = 64
        self.hidden_states = torch.zeros((self.embed_dim,))
        self.local_model = AgentModel(self.grid.agent_in_shape, self.embed_dim, len(Actions))
        self.prev_action = Actions.NO_OP
        self.epsilon = 0.1  # probability of taking a uniformly random action
        self.alpha = 0.01   # used both as optimizer learning rate and TD step size
        # NOTE(review): discount factors are usually close to 1; confirm 0.01 is intended.
        self.gamma = 0.01
        self.loss = nn.HuberLoss()
        self.optimizer = torch.optim.Adam([{'params': self.local_model.parameters(), 'lr': self.alpha}])
        self.agent_in = None
        self.target_qvalues = None

    @staticmethod
    def _as_cell(position):
        """Convert any 2-element position (tuple, list, array) to a tuple of ints."""
        return tuple(int(v) for v in position)

    def get_reward(self):
        """Return 0 (and announce it) when the block sits on the goal, otherwise -1."""
        if self._as_cell(self.grid.block) == self._as_cell(self.grid.goal):
            print("Goal Reached!")
            return 0
        else:
            return -1

    def get_correct_qvalue(self, pred_max_value, reward, current_max_value):
        """Return ``q + alpha * (reward + gamma * next_max - q)`` for ``q = pred_max_value``."""
        q_value = pred_max_value + self.alpha * (reward + self.gamma * current_max_value - pred_max_value)
        return q_value

    def take_action(self, action):
        """Apply ``action`` to the grid, wrapping around the edges.

        If the destination cell holds the block, the block is pushed one cell
        in the same direction and the agent stays where it is; otherwise the
        agent moves to the destination cell.
        """
        d_row, d_col = action.delta
        rows, cols = self.grid.rows, self.grid.cols

        # Work on fresh tuples: never alias or mutate self.pos / grid.block in place.
        row, col = self._as_cell(self.pos)
        destination = ((row + d_row) % rows, (col + d_col) % cols)
        block = self._as_cell(self.grid.block)

        if destination == block:
            self.grid.block = ((block[0] + d_row) % rows, (block[1] + d_col) % cols)
        else:
            self.pos = destination

    def step(self, vizualize):
        """Choose an action epsilon-greedily, apply it and record the TD target.

        Args:
            vizualize: when truthy, clear the cell output and print the grid first.

        Side effects: updates ``agent_in``, ``hidden_states``, ``prev_action``,
        ``target_qvalues``, and the grid through ``take_action``.
        """
        if vizualize:
            clear_output(wait=True)
            print(self.grid.vizualize_grid())

        current_state = self.grid.get_global_state()
        self.agent_in = torch.unsqueeze(torch.cat((current_state, self.prev_action.one_hot)), 0)
        pred_qvalues, self.hidden_states = self.local_model(self.agent_in, self.hidden_states)

        if random.random() < self.epsilon:
            action = random.choice(list(Actions))
        else:
            action = Actions(int(torch.argmax(pred_qvalues.detach())))

        self.take_action(action)
        self.prev_action = action
        future_state = self.grid.get_global_state()
        future_in = torch.unsqueeze(torch.cat((future_state, self.prev_action.one_hot)), 0)

        reward = self.get_reward()
        # The bootstrap value is a constant target, so no graph is needed for it.
        with torch.no_grad():
            future_qvalues, _ = self.local_model(future_in, self.hidden_states)

        # The reward and next state were produced by the action actually taken,
        # so that is the entry whose Q-value gets the TD target.
        taken = action.value
        target_qvalue = self.get_correct_qvalue(
            pred_qvalues[0, taken].detach(), reward, future_qvalues.max()
        )
        # Build the target on a detached copy; writing into pred_qvalues itself
        # would modify a tensor that is still part of the autograd graph.
        targets = pred_qvalues.detach().clone()
        targets[0, taken] = target_qvalue
        self.target_qvalues = targets

In [42]:
class ReplayBuffer:
    """Placeholder for an experience replay buffer; it currently stores nothing."""

    def __init__(self):
        """Create the buffer object; no attributes are set yet."""
        return None

In [59]:
class QLearner():
    """Runs episodes by repeatedly stepping the grid's agent."""

    def __init__(self, grid, batch):
        """
        Args:
            grid: environment exposing ``block``, ``goal`` and ``agent``.
            batch: stored on the instance; not read by any method defined here.
        """
        self.grid = grid
        self.batch = batch

    @staticmethod
    def _same_cell(first, second):
        """Compare two positions by value, whether they are tuples, lists or arrays."""
        # Comparing an ndarray with a tuple directly yields an element-wise
        # array whose truth value is ambiguous; tuples compare cleanly.
        return tuple(first) == tuple(second)

    def run_episode(self, n_timesteps, vizualize=True):
        """Step the agent until the block is on the goal or ``n_timesteps`` steps have run.

        Args:
            n_timesteps: maximum number of agent steps for this episode.
            vizualize: forwarded to ``agent.step``; defaults to True, which is
                what this method previously always passed.

        The step index is printed after every step.
        """
        t = 0
        while t < n_timesteps and not self._same_cell(self.grid.block, self.grid.goal):
            self.grid.agent.step(vizualize)
            print(t)
            t += 1

    # Draft of a training step, kept for reference (not active code):
    # def train(self):
    #     torch.autograd.set_detect_anomaly(True)
    #     for inputs, targets in zip([self.agent_in], [self.target_qvalues]):  
    #         self.optimizer.zero_grad()  # Zero the gradients
    #         outputs, _ = g.agent.local_model(inputs, self.hidden_states)  # Forward pass
    #         loss = self.loss(outputs, targets)  # Compute loss
    #         loss.backward()  # Backward pass
    #         self.optimizer.step()  # Update weights

In [60]:
# 6 x 4 grid with a single agent; each entity is one-hot encoded over all 24 cells.
g = GridEnv(rows=6, cols=4, n_agents=1, state_shape=6 * 4)

In [61]:
# No batch is supplied yet; the learner only drives the environment's agent.
qlearner = QLearner(grid=g, batch=None)

In [64]:
# Run one episode, capped at 1000 agent steps.
qlearner.run_episode(n_timesteps=1000)

. . . .
. B G .
A . . .
. . . .
. . . .
. . . .
999


In [None]:
class AgentController:
    """Placeholder for a controller coordinating agents; it has no behavior yet."""

    def __init__(self):
        """Create the controller object; no attributes are set yet."""
        return None

In [None]:
# class EpisodeBatch():
#     def __init__(self):
#         pass
    
#     def return_batch(self):
#         return 0

#     # def max_t_filled(self):
#     #     return torch.sum(self.data.transition_data.["filled"], 1).max(0)[0]

In [None]:
# class ReplayBuffer(EpisodeBatch):
#     def __init__(self, buffer_size):
#         self.buffer_size = buffer_size
#         self.episodes_in_buffer = 0

#     def insert_episode_batch(self): 
#         pass

#     def can_sample(self, batch_size):
#         return self.episodes_in_buffer >= batch_size

#     def sample(self, batch_size):
#         assert self.can_sample(batch_size)
        
#         if self.episodes_in_buffer == batch_size: # return one batch
#             return self[:batch_size]
#         else: # sample uniformly
#             n_sampled_episodes = np.random.choice(self.episodes_in_batch, batch_size, replace=False)
#             return self[:n_sampled_episodes]

In [None]:
# class EpisodeRunner():
#     def __init__(self, data):
#         self.batcher = EpisodeBatch()

#         if data is not None:
#             self.data = data
#         else:
#             data = SN()

    
#     def get_batch(self):
#         batch = self.batcher.return_batch()
#         return batch

#     def run_episode(self, training_mode):
#         pass

In [None]:
# class Runner(): 
#     def __init__(self, num_episodes):
#         self.rows = 6
#         self.cols = 4
#         self.n_agents = 1
#         # self.batch = torch.tensor(g.get_global_state() + torch.unsqueeze(torch.tensor_size == 1))
#         self.n_actions = len(Actions)
#         self.state_shape = self.rows*self.cols

#         self.batch_size = 1
        
#         self.env = GridEnv(self.rows, self.cols, self.n_agents, self.state_shape)
#         self.replay_buffer = ReplayBuffer(self.batch_size)
#         self.agent_controller = AgentController()
#         self.learner = QLearner()
#         self.episode_runner = EpisodeRunner(None)

#         self.max_timesteps = 20
#         self.current_t = 0

#     def train(self):
#         episode = 0
#         while self.current_t <= self.max_timesteps:
#             # run for one episode
#             episode = self.episode_runner.run_episode(training_mode=True)
#             #insert episode into replay buffer
#             self.replay_buffer.insert_episode_batch()

#             if self.replay_buffer.can_sample(self.batch_size):
#                 episode_sample = self.replay_buffer.sample(self.batch_size)

#                 max_eps_t = episode_sample.max_t_filled() 
#             break

In [None]:
# r = Runner(10)

In [None]:
# r.train()