In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages followed by three linear layers.

    Both convolutions use 3x3 kernels without padding and each is followed by
    a ReLU and a 2x2 max-pool. The first linear layer is sized for 16 feature
    maps of 6x6, which is what these stages produce for 32x32 inputs
    (32 -> 30 -> 15 -> 13 -> 6 per spatial side); other input sizes will not
    match ``fc1``.

    Args:
        in_channels: Number of channels of the input images. Defaults to 1.
        num_classes: Number of scores produced per sample. Defaults to 10.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super(LeNet, self).__init__()
        # Layers are created in a fixed order (conv1, conv2, fc1, fc2, fc3) so
        # that seeded initialisation and state_dict keys stay stable.
        self.conv1 = nn.Conv2d(in_channels, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)

        # 16 channels * 6 * 6 spatial positions after the second pooling step.
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def num_flat_features(self, x):
        """Return the number of elements per sample of ``x``.

        This is the product of every dimension of ``x`` except the first;
        a tensor with a single dimension yields 1.
        """
        return x.size()[1:].numel()

    def forward(self, x):
        """Compute the unnormalised class scores for a batch of images.

        Args:
            x: Input tensor; the layer sizes assume shape
                ``(batch, in_channels, 32, 32)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` taken straight from the
            last linear layer (no softmax is applied here).
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything but the batch dimension. Unlike view(-1, n), this
        # keeps the batch size explicit, so it is not ambiguous for an empty
        # batch.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    
    

In [6]:
# Seed the global RNG so the initial weights and the random batch below are
# identical on every Restart & Run All.
torch.manual_seed(0)

net = LeNet()

# Smoke test: push a random batch of 16 single-channel 32x32 images through
# the untrained network.
img = torch.rand(16, 1, 32, 32)

out = net(img)

# One row per image, one score per class.
assert out.shape == (16, 10)

print(out)

tensor([[-0.0944,  0.0356, -0.0139,  0.0042,  0.0238, -0.0795,  0.0848,  0.1034,
          0.0847, -0.0732],
        [-0.1036,  0.0292, -0.0249,  0.0030,  0.0303, -0.0777,  0.0833,  0.1004,
          0.0838, -0.0733],
        [-0.1035,  0.0338, -0.0245,  0.0064,  0.0271, -0.0772,  0.0846,  0.1042,
          0.0829, -0.0718],
        [-0.1015,  0.0284, -0.0251,  0.0100,  0.0285, -0.0825,  0.0888,  0.1000,
          0.0909, -0.0692],
        [-0.1057,  0.0319, -0.0298,  0.0082,  0.0294, -0.0852,  0.0846,  0.0959,
          0.0776, -0.0718],
        [-0.1039,  0.0313, -0.0258,  0.0089,  0.0290, -0.0797,  0.0854,  0.0931,
          0.0832, -0.0686],
        [-0.1021,  0.0345, -0.0264,  0.0063,  0.0269, -0.0835,  0.0894,  0.0928,
          0.0767, -0.0730],
        [-0.1036,  0.0297, -0.0230,  0.0056,  0.0287, -0.0776,  0.0762,  0.0923,
          0.0861, -0.0726],
        [-0.0963,  0.0355, -0.0213, -0.0081,  0.0314, -0.0792,  0.0920,  0.1044,
          0.0928, -0.0640],
        [-0.1028,  

In [None]:
import torch.optim as optim
# Loss function for training.
criterion = nn.CrossEntropyLoss()

# Stochastic gradient descent with momentum over all parameters of `net`;
# `lr` is the learning rate (step size of each update).
optimizer = optim.SGD(
    net.parameters(),
    lr=0.001,
    momentum=0.9,
)