#### Defining the loss function (based on the Bellman equation)

In [None]:
def calc_loss(batch, net, gamma=0.99, device="cuda"):
    """Compute the one-step Bellman (temporal-difference) MSE loss for a batch.

    The prediction is the value ``net`` assigns to the action actually taken
    in each state. The target is ``reward + gamma * max_a net(next_state)[a]``,
    with the bootstrap term forced to zero where ``dones == 1``.

    Args:
        batch: Tuple ``(states, actions, rewards, dones, next_states)`` of
            tensors. ``actions`` indexes axis 1 of ``net(states)``; ``dones``
            is compared against 1 to build the terminal mask.
        net: Callable mapping a batch of states to per-action values.
        gamma: Discount factor applied to the next-state value.
        device: Kept for backward compatibility only. The rewards are moved
            to the device the network output lives on, so the loss can be
            computed on CPU as well as on GPU without passing this argument.

    Returns:
        Scalar tensor holding the mean squared error between predicted and
        target action values. Gradients flow through the prediction only.
    """
    states, actions, rewards, dones, next_states = batch

    # Value predicted for the action that was actually taken in each state.
    state_action_values = net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

    # The bootstrap target is treated as a constant: building it under
    # no_grad avoids recording an autograd graph that would be thrown away.
    with torch.no_grad():
        next_state_values = net(next_states).max(1)[0]
        # Terminal transitions have no future return to bootstrap from.
        next_state_values[dones == 1] = 0.0

    # Follow the network's device instead of a hard-coded one so the sum
    # below never mixes tensors from different devices.
    rewards = rewards.to(next_state_values.device)
    expected_state_action_values = next_state_values * gamma + rewards
    return nn.MSELoss()(state_action_values, expected_state_action_values)

#### Definition of the neural network

In [None]:
class Net_2(nn.Module):
    """Two convolution layers followed by a three-layer fully connected head.

    ``forward`` reshapes its input to ``(batch, in_channels, rows, cols)``,
    runs the convolution stack, flattens the result and returns one value
    per action, i.e. a tensor of shape ``(batch, n_actions)``.
    """

    # Layer hyper-parameters. Only the first two entries of ``channels``,
    # ``kernels`` and ``strides`` are read when building the conv stack;
    # ``linears`` is not read anywhere in this class.
    channels = [16, 32, 64]
    kernels = [3, 3, 3]
    strides = [1, 1, 1]
    linears = [250, 40]
    in_channels = 1

    def __init__(self, maze_size, n_actions, rows, cols):
        """Build the convolution stack and the linear head.

        Args:
            maze_size: Number used to size the hidden linear layers
                (``int(maze_size * 1.5)`` and ``int(maze_size / 2)`` units).
            n_actions: Width of the output layer.
            rows: Height the input is reshaped to in ``forward``.
            cols: Width the input is reshaped to in ``forward``.
        """
        super().__init__()
        self.rows = rows
        self.cols = cols

        # Two Conv2d + PReLU stages, created in order so that parameter
        # initialisation and state_dict indices stay the same.
        conv_layers = []
        fan_in = self.in_channels
        for stage in range(2):
            conv_layers.append(
                nn.Conv2d(
                    in_channels=fan_in,
                    out_channels=self.channels[stage],
                    kernel_size=self.kernels[stage],
                    stride=self.strides[stage],
                )
            )
            conv_layers.append(nn.PReLU())
            fan_in = self.channels[stage]
        self.conv = nn.Sequential(*conv_layers)

        # The flattened conv output size depends on rows/cols, so measure it.
        flat_features = self.get_conv_size(rows, cols)
        wide_units = int(maze_size * 1.5)
        narrow_units = int(maze_size / 2)

        self.linear = nn.Sequential(
            nn.Linear(flat_features, wide_units),
            nn.PReLU(),
            nn.Linear(wide_units, narrow_units),
            nn.PReLU(),
            nn.Linear(narrow_units, n_actions),
        )

    def forward(self, x):
        """Return per-action outputs of shape ``(len(x), n_actions)``."""
        batch = len(x)
        grid = x.view(batch, self.in_channels, self.rows, self.cols)
        features = self.conv(grid)
        flat = features.view(batch, -1)
        return self.linear(flat)

    def get_conv_size(self, x, y):
        """Return the number of elements the conv stack emits for one ``x`` by ``y`` input."""
        probe = torch.zeros(1, self.in_channels, x, y)
        conv_shape = self.conv(probe).size()
        return int(np.prod(conv_shape))