In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F 

In [0]:
class CartPoleNet(nn.Module):
  """CNN mapping a stack of image frames to a single vector of scores.

  Each image along dim 0 of the input passes through four
  conv(3x3, no padding) -> ReLU -> max-pool(2) stages. The flattened
  per-image features are then *summed* over dim 0, and the result goes
  through four fully connected layers. The whole input stack therefore yields
  ONE output vector with no batch dimension.

  Args:
    in_channels: channels per input image (default 3, e.g. RGB).
    num_outputs: length of the output vector (default 4).
    input_size: height/width of the (square) input images. It is only used to
      size ``fc1``. The default of 128 reproduces the original
      6 * 6 * 64 = 2304 input features.

  Raises:
    ValueError: if ``input_size`` is too small to survive the four
      conv/pool stages.
  """

  def __init__(self, in_channels=3, num_outputs=4, input_size=128):
    super(CartPoleNet, self).__init__()
    self.conv1 = nn.Conv2d(in_channels, 6, 3)
    self.conv2 = nn.Conv2d(6, 16, 3)
    self.conv3 = nn.Conv2d(16, 32, 3)
    self.conv4 = nn.Conv2d(32, 64, 3)
    # fc1 consumes the flattened conv4 feature map of one image. Its width
    # depends on the input resolution, so derive it instead of hard-coding
    # 6 * 6 * 64 (which is only right for 128x128 inputs).
    side = self._conv_output_size(input_size)
    self.fc1 = nn.Linear(side * side * self.conv4.out_channels, 120)
    self.fc2 = nn.Linear(120, 100)
    self.fc3 = nn.Linear(100, 50)
    self.fc4 = nn.Linear(50, num_outputs)

  @staticmethod
  def _conv_output_size(input_size, num_stages=4):
    """Return the spatial side length left after the conv/pool stages."""
    size = input_size
    for _ in range(num_stages):
      # A 3x3 conv without padding trims 2 pixels; max-pool(2) halves (floor).
      size = (size - 2) // 2
      if size < 1:
        raise ValueError(
            f"input_size={input_size} is too small for {num_stages} "
            "conv(3x3) + max-pool(2) stages")
    return size

  def forward(self, x):
    """Score a stack of images.

    Args:
      x: image tensor, presumably (N, in_channels, input_size, input_size).
        A single unbatched (in_channels, H, W) frame is treated as N = 1.

    Returns:
      Tensor of shape (num_outputs,). There is no batch dimension, because the
      N per-image feature vectors are summed before the FC layers.
    """
    if x.dim() == 3:
      # Conv2d accepts unbatched input, but the flatten/sum below would then
      # treat the conv channels as the batch and crash in fc1.
      x = x.unsqueeze(0)
    x = F.max_pool2d(F.relu(self.conv1(x)), 2)
    x = F.max_pool2d(F.relu(self.conv2(x)), 2)
    x = F.max_pool2d(F.relu(self.conv3(x)), 2)
    x = F.max_pool2d(F.relu(self.conv4(x)), 2)
    # Flatten each image's features, then sum across the N images so the
    # whole stack produces a single feature vector for the FC layers.
    x = x.view(-1, self.num_flat_features(x))
    x = torch.sum(x, dim=0)
    # fully connected layers
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))
    x = self.fc4(x)
    return x

  def num_flat_features(self, x):
    """Return the number of elements per sample, i.e. the product of x.size()[1:]."""
    size = x.size()[1:]  # all dimensions except the batch dimension
    num_features = 1
    for s in size:
      num_features *= s
    return num_features


In [15]:
# Seed before constructing so the randomly initialised weights (and every
# output derived from them) are reproducible on Restart & Run All.
torch.manual_seed(0)
net = CartPoleNet()
net  # bare last expression: Jupyter renders the module repr

CartPoleNet(
  (conv1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=2304, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=100, bias=True)
  (fc3): Linear(in_features=100, out_features=50, bias=True)
  (fc4): Linear(in_features=50, out_features=4, bias=True)
)


In [18]:
# Smoke test: a stack of 4 random 128x128 RGB frames must collapse into a
# single output vector (the network sums over the stack dimension).
frames = torch.randn(4, 3, 128, 128)  # not `input`, which shadows the builtin
out = net(frames)
assert out.shape == (net.fc4.out_features,), out.shape
out

tensor([ 0.0170,  0.0631,  0.1101, -0.1387], grad_fn=<AddBackward0>)
