In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchsummary import summary


In [10]:
class DDQNLidar(nn.Module):
    """Convolutional Q-network with separate value and advantage heads.

    Three convolutions reduce a ``depth x height x width`` observation
    (2 x 20 x 20) to a 256-element feature vector, three fully connected
    layers process it, and the value head ``V`` and advantage head ``A`` are
    combined as ``Q = V + A - mean(A)`` (dueling-style aggregation).
    """

    def __init__(self, action_size):
        """ Create Q-network
        Parameters
        ----------
        action_size: int
            number of actions
        """
        super().__init__()

        self.action_size = action_size

        # Observation geometry the convolution stack below is sized for.
        self.width = 20
        self.height = 20
        self.depth = 2

        # Shape comments are (channels x height x width) for a 20 x 20 input.
        self.c1 = nn.Conv2d(self.depth, 16, kernel_size=2, stride=2)  # 16 x 10 x 10
        self.c2 = nn.Conv2d(16, 64, kernel_size=2, stride=2)  # 64 x 5 x 5
        self.c3 = nn.Conv2d(64, 256, kernel_size=5, stride=1)  # 256 x 1 x 1

        # Collapses the 256 x 1 x 1 feature map to a 256-element vector.
        self.flat = nn.Flatten()

        # NOTE(review): fc2 alone holds 32768 * 8192 + 8192 = 268,443,648
        # parameters, i.e. almost the whole model - confirm these widths are
        # intended before training.
        self.fc1 = nn.Linear(256, 32768)
        self.fc2 = nn.Linear(32768, 8192)
        self.fc3 = nn.Linear(8192, 512)
        self.V = nn.Linear(512, 1)  # state value head
        self.A = nn.Linear(512, action_size)  # per-action advantage head

    def forward(self, observation):
        """ Forward pass to compute Q-values
        Parameters
        ----------
        observation: torch.Tensor or array-like (e.g. np.array)
            batch of state(s) in channel-first layout
            (batch, depth, height, width). Non-tensor input is converted to a
            tensor with the dtype and device of the network weights; no axis
            reordering or cropping is performed.
        Returns
        ----------
        torch.Tensor
            Q-values, shape (batch, action_size)
        """
        if not isinstance(observation, torch.Tensor):
            # Match the first layer's weights so the convolution accepts the
            # input regardless of the array's original dtype.
            weight = self.c1.weight
            observation = torch.as_tensor(
                observation, dtype=weight.dtype, device=weight.device
            )

        features = torch.relu(self.c1(observation))
        features = torch.relu(self.c2(features))
        features = torch.relu(self.c3(features))

        hidden = self.flat(features)
        hidden = torch.relu(self.fc1(hidden))
        hidden = torch.relu(self.fc2(hidden))
        hidden = torch.relu(self.fc3(hidden))

        value = self.V(hidden)  # (batch, 1)
        advantage = self.A(hidden)  # (batch, action_size)

        # Broadcasting stretches the (batch, 1) value and mean-advantage
        # columns across the action dimension, so no explicit expand is needed.
        return value + advantage - advantage.mean(dim=1, keepdim=True)

In [14]:

# Build the network for four discrete actions, then print a per-layer summary
# (output shapes and parameter counts) for one 2 x 20 x 20 observation on the CPU.
model = DDQNLidar(action_size=4)
summary(model, (2, 20, 20), device="cpu")


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 10, 10]             144
            Conv2d-2             [-1, 64, 5, 5]           4,160
            Conv2d-3            [-1, 256, 1, 1]         409,856
           Flatten-4                  [-1, 256]               0
            Linear-5                [-1, 32768]       8,421,376
            Linear-6                 [-1, 8192]     268,443,648
            Linear-7                  [-1, 512]       4,194,816
            Linear-8                    [-1, 1]             513
            Linear-9                    [-1, 4]           2,052
Total params: 281,476,565
Trainable params: 281,476,565
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.34
Params size (MB): 1073.75
Estimated Total Size (MB): 1074.10
-----------------------------