In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.nn.utils.rnn as rnn_utils
import numpy as np

# parameters for training
GRAD_CLIP              = 30.0 # to clip gradients shown by tensorboard at 900 (30^2)
RNN_SIZE               = 128
GOAL_REPR_SIZE         = 12

# def normalized_columns_initializer(std=1.0):
#     def _initializer(shape, dtype=None):
#         out = torch.randn(*shape).type(torch.FloatTensor)
#         out *= std / (out**2).sum(0, keepdim=True).sqrt()
#         return out
#     return _initializer

class ACNet(nn.Module):
    def __init__(self, a_size, batch_size, trainer, learning_rate, TRAINING, GRID_SIZE):
        super(ACNet, self).__init__()
        # self.inputs = torch.nn.Parameter(torch.empty(1, 4, GRID_SIZE, GRID_SIZE))
        # self.water_res = torch.nn.Parameter(torch.empty(1, 1))
        # self.myinput = None  # You may need to transpose self.inputs here
        # self.inputs = torch.zeros(batch_size, 4, GRID_SIZE, GRID_SIZE)
        # self.water_res = torch.zeros(batch_size, 1)
        # self.myinput = self.inputs
        self.inputs = torch.zeros(batch_size, 4, GRID_SIZE, GRID_SIZE)
        self.water_res = torch.zeros(batch_size, 1)


        # Define ACNet layers
        # 4 maps for each agent
        self.conv1 = nn.Conv2d(4, RNN_SIZE // 4, kernel_size=3, stride=1, padding=1)
        self.conv1a = nn.Conv2d(RNN_SIZE // 4, RNN_SIZE // 4, kernel_size=3, stride=1, padding=1)
        self.conv1b = nn.Conv2d(RNN_SIZE // 4, RNN_SIZE // 4, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(RNN_SIZE // 4, RNN_SIZE // 2, kernel_size=3, stride=1, padding=1)
        self.conv2a = nn.Conv2d(RNN_SIZE // 2, RNN_SIZE // 2, kernel_size=3, stride=1, padding=1)
        self.conv2b = nn.Conv2d(RNN_SIZE // 2, RNN_SIZE // 2, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(RNN_SIZE // 2, RNN_SIZE - GOAL_REPR_SIZE, kernel_size=2, stride=1)
        self.flat_size = (RNN_SIZE - GOAL_REPR_SIZE) * GRID_SIZE * GRID_SIZE  # Update this size based on your GRID_SIZE
        self.flat_size = batch_size
        self.fc0 = nn.Linear(1,GOAL_REPR_SIZE)
        self.fc1 = nn.Linear(RNN_SIZE, RNN_SIZE)
        self.fc2 = nn.Linear(RNN_SIZE, RNN_SIZE)
        self.fc3 = nn.Linear(RNN_SIZE, RNN_SIZE)
        # LSTM cell
        self.lstm= nn.LSTM(RNN_SIZE, RNN_SIZE)
        # Policy and value head
        self.fc_policy = nn.Linear(RNN_SIZE, a_size)
        self.fc_value = nn.Linear(RNN_SIZE, 1)
        
        # Define initial and current LSTM states
        c_init = torch.zeros(1, RNN_SIZE, dtype=torch.float32)
        h_init = torch.zeros(1, RNN_SIZE, dtype=torch.float32)
        self.state_init = [c_init, h_init]
        self.state_in = (torch.zeros(1, RNN_SIZE, dtype=torch.float32), torch.zeros(1, RNN_SIZE, dtype=torch.float32)) # (cell_state, hidden_state)

        # self.policy, self.value, self.state_out, _ = self._build_net(self.myinput, self.water_res, RNN_SIZE, a_size)
        # self.policy, self.value, self.state_out, _ = self._build_net(self.myinput, self.water_res, a_size)
        self.policy, self.value, self.state_out, _ = self._build_net(self.inputs, self.water_res, a_size)

        if TRAINING:
            self.actions = torch.zeros(batch_size, dtype=torch.int64)
            # print(self.actions.shape)
            self.actions_onehot = F.one_hot(self.actions, a_size).type(torch.float32)
            # print(self.actions_onehot.shape)
            self.target_v = torch.zeros(batch_size, dtype=torch.float32)
            self.advantages = torch.zeros(batch_size, dtype=torch.float32)
            self.responsible_outputs = torch.sum(self.policy * self.actions_onehot, dim=1)

            # Loss Functions
            self.value_loss = 0.5 * torch.sum((self.target_v - self.value.view(-1))**2)
            self.entropy = -0.01 * torch.sum(self.policy * torch.log(torch.clamp(self.policy, 1e-10, 1.0)))
            self.policy_loss = -torch.sum(torch.log(torch.clamp(self.responsible_outputs, 1e-15, 1.0)) * self.advantages)
            self.loss = self.value_loss + self.policy_loss - self.entropy

            # Get gradients from local network using local losses and
            # normalize the gradients using clipping
            trainable_vars = list(self.parameters())
            self.gradients = torch.autograd.grad(self.loss, trainable_vars, create_graph=True)
            # self.var_norms = torch.norm(trainable_vars)
            self.var_norms = torch.norm(torch.cat([v.view(-1) for v in trainable_vars]))
            self.grad_norms = torch.nn.utils.clip_grad_norm_(self.gradients, GRAD_CLIP)
            self.apply_grads = trainer(self.parameters(), lr=learning_rate)
        
        # Initialize model weights
        self.apply(weights_init)

        print("QAQ! The network is working!")

    def _build_net(self, inputs, water_res, a_size):
        conv1 = F.relu(self.conv1(inputs))
        conv1a = F.relu(self.conv1a(conv1))
        conv1b = F.relu(self.conv1b(conv1a))
        pool1 = self.pool1(conv1b)

        conv2 = F.relu(self.conv2(pool1))
        conv2a = F.relu(self.conv2a(conv2))
        conv2b = F.relu(self.conv2b(conv2a))
        pool2 = self.pool2(conv2b)

        conv3 = self.conv3(pool2)
        print(conv3.shape)

        flat = torch.flatten(conv3, 1)
        # flat = F.relu(conv3.view(-1, self.flat_size))
        water_layer = F.relu(self.fc0(water_res))
        print(flat.shape)
        print(water_layer.shape)
        hidden_input = torch.cat([flat, water_layer], dim=1)
        print(hidden_input.shape)

        h1 = F.relu(self.fc1(hidden_input))
        h2 = F.relu(self.fc2(h1))
        self.h3 = F.relu(self.fc3(h2 + hidden_input))

        # # Recurrent network for temporal dependencies
        # lstm_cell = nn.LSTMCell(RNN_SIZE, RNN_SIZE)
        # rnn_in = h3.unsqueeze(0)
        # step_size = inputs.size(0)
        # lstm_out, state_out = lstm_cell(rnn_in, self.state_in)
        # lstm_c, lstm_h = state_out
        # state_out = (lstm_c[:1, :], lstm_h[:1, :])
        # rnn_out = lstm_out.view(-1, RNN_SIZE)



        # rnn_in = h3.unsqueeze(0)
        # lstm_cell = nn.LSTMCell(input_size=RNN_SIZE, hidden_size=RNN_SIZE)
        # packed_sequence = rnn_utils.pack_padded_sequence(rnn_in, lengths=[sequence_length], enforce_sorted=False)
        # lstm_outputs, lstm_state = lstm_cell(packed_sequence)
        # unpacked_sequence, _ = rnn_utils.pad_packed_sequence(lstm_outputs)
        # lstm_c, lstm_h = lstm_state
        # state_out = (lstm_c[:1, :], lstm_h[:1, :])
        # rnn_out = unpacked_sequence.view(batch_size, -1)



        # # Get the size of the input to the LSTMCell
        # input_size = self.h3.size(1)
        # lstm_cell = nn.LSTMCell(input_size=input_size, hidden_size=RNN_SIZE)
        # c_init, h_init = self.state_in
        # lstm_cell.bias_ih.data.fill_(0)  # assuming bias_ih is the bias for input gate
        # lstm_cell.bias_hh.data.fill_(0)  # assuming bias_hh is the bias for hidden gate
        # lstm_cell.weight_ih.data = torch.nn.init.xavier_normal_(lstm_cell.weight_ih.data)
        # lstm_cell.weight_hh.data = torch.nn.init.orthogonal_(lstm_cell.weight_hh.data)
        # lstm_out, new_state = lstm_cell(rnn_in, self.state_in)
        # # Extract new LSTM state
        # lstm_c, lstm_h = new_state
        # state_out = (lstm_c.unsqueeze(0), lstm_h.unsqueeze(0))  # unsqueeze to add batch dimension
        # # Reshape the LSTM output
        # rnn_out = lstm_out.view(-1, RNN_SIZE)


        # # Assuming RNN_SIZE is also the hidden size for LSTM
        # rnn_in = self.h3.unsqueeze(0)  # Add a time dimension
        # step_size = inputs.size(0)
        # state_in = (torch.zeros(1, RNN_SIZE), torch.zeros(1, RNN_SIZE))
        # # Perform dynamic LSTM operation
        # lstm_outputs = []
        # lstm_state = state_in
        # for t in range(step_size):
        #     lstm_state = self.lstm_cell(rnn_in[t], lstm_state)
        #     lstm_outputs.append(lstm_state[0])
        # # Convert the list of LSTM outputs to a PyTorch tensor
        # lstm_outputs = torch.stack(lstm_outputs, dim=0)

        # # Extract LSTM cell and hidden states
        # lstm_c, lstm_h = lstm_state
        # state_out = (lstm_c[:1, :], lstm_h[:1, :])

        # # Reshape the LSTM outputs
        # rnn_out = lstm_outputs.view(-1, RNN_SIZE)



        # # Recurrent network for temporal dependencies
        # lstm_cell = nn.LSTMCell(input_size=RNN_SIZE, hidden_size=RNN_SIZE)
        # rnn_in = self.h3.unsqueeze(0)
        # print(rnn_in.shape)
        # # step_size = inputs.size(0)
        # # packed_sequence = torch.nn.utils.rnn.pack_padded_sequence(rnn_in, [step_size], enforce_sorted=False)
        # # lstm_outputs, lstm_state = lstm_cell(packed_sequence, self.state_in)
        # # lstm_outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_outputs)
        # # lstm_c, lstm_h = lstm_state
        # # state_out = (lstm_c[:1, :], lstm_h[:1, :])
        # # rnn_out = lstm_outputs.view(-1, RNN_SIZE)
        # lstm_out, lstm_state = lstm_cell(rnn_in, self.state_in)
        # lstm_c, lstm_h = lstm_state
        # # state_out = (lstm_c.unsqueeze(0), lstm_h.unsqueeze(0))
        # state_out = (lstm_c[:1, :], lstm_h[:1, :])
        # # rnn_out = lstm_out.squeeze(0)
        # rnn_out = lstm_out.view(-1, RNN_SIZE)


        rnn_in = self.h3.unsqueeze(0)
        # TODO: Modify time step HERE.
        # sequence_length = 1
        # rnn_in = self.h3.unsqueeze(0).unsqueeze(0).expand(sequence_length, -1, -1)
        lstm_out, lstm_state = self.lstm(rnn_in)
        lstm_c, lstm_h = lstm_state
        state_out = (lstm_c[:1, :], lstm_h[:1, :])
        # rnn_out = lstm_out.squeeze(0)
        rnn_out = lstm_out.view(-1, RNN_SIZE)


        policy_layer = self.fc_policy(rnn_out)
        policy = F.softmax(policy_layer, dim=1)
        policy_sig = torch.sigmoid(policy_layer)
        value = self.fc_value(rnn_out)

        return policy, value, state_out, policy_sig
    
# Function for weights initialization
def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        init.xavier_uniform_(m.weight.data)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.LSTM):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.constant_(param.data, 0)

In [6]:
# action size: 4 direction movements, stop, 4 direction spraying with short/long range, and go back to water supply station.
a_size = 4 + 1 + 4 + 1
trainer = torch.optim.SGD
TRAINING = True
GRID_SIZE = 11
learning_rate=1e-4
# agent_num
agent_num = 3

# __init__(self, scope, a, trainer, TRAINING, GRID_SIZE)
net = ACNet(a_size, agent_num, trainer, learning_rate, TRAINING, GRID_SIZE)

torch.Size([3, 116, 1, 1])
torch.Size([3, 116])
torch.Size([3, 12])
torch.Size([3, 128])
QAQ! The network is working!
