# Training a single batch

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

torch.set_printoptions(linewidth=120)  # display options for output
torch.set_grad_enabled(True)  # on by default; stated explicitly for clarity

# Seed PyTorch's global RNG so the network's random initial weights, and hence
# the loss / accuracy values shown below, are reproducible on Restart & Run All.
SEED = 50
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator's repr

<torch.autograd.grad_mode.set_grad_enabled at 0x1f2c25e1cc0>

In [2]:
# Record library versions so outputs can be matched to the environment that produced them.
print(torch.__version__, torchvision.__version__, sep="\n")

1.1.0
0.3.0


In [3]:
def get_num_correct(preds, labels):
    """Count how many predictions agree with their labels.

    Args:
        preds: score tensor whose class dimension is dim 1 (e.g. ``(N, C)`` logits).
        labels: tensor of target class indices, compared element-wise with the
            per-row argmax of ``preds``.

    Returns:
        The number of matching rows as a Python ``int``.
    """
    predicted_classes = torch.argmax(preds, dim=1)
    hits = predicted_classes == labels
    return int(hits.sum())

In [4]:
class Network(nn.Module):
    """Small CNN: two conv/pool stages followed by three fully connected layers.

    Takes 1-channel images and produces 10 class scores. ``fc1`` expects
    ``12 * 4 * 4`` features, which is what a 28x28 input yields:
    28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Map a batch of images to unnormalised class scores (logits).

        Args:
            t: image batch of shape ``(N, 1, H, W)``; the linear layers are
               sized for ``H = W = 28``.

        Returns:
            Tensor of shape ``(N, 10)`` holding raw logits.

        Raises:
            RuntimeError: if the spatial size does not produce ``12 * 4 * 4``
                features per sample (raised by ``fc1``).
        """
        # (1) input layer: identity, t is used as given

        # (2) hidden conv layer
        t = F.relu(self.conv1(t))
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # (3) hidden conv layer
        t = F.relu(self.conv2(t))
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # (4) hidden linear layer
        # Flatten per sample while keeping the batch dimension. Unlike
        # reshape(-1, 12 * 4 * 4), this cannot silently regroup one sample's
        # features into several rows when the input is not 28x28; fc1 raises a
        # shape error instead.
        t = t.flatten(start_dim=1)
        t = F.relu(self.fc1(t))

        # (5) hidden linear layer
        t = F.relu(self.fc2(t))

        # (6) output layer: return raw logits. No softmax here, because
        # F.cross_entropy applies log_softmax internally.
        t = self.out(t)

        return t

In [5]:
# FashionMNIST training split, downloaded on first run; ToTensor converts each
# PIL image to a float tensor.
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [6]:
# Instantiate a fresh, untrained network. Its parameters come from PyTorch's
# default random layer initialisation, which is why the first loss below (~2.30)
# is close to ln(10), the loss of a uniform guess over the 10 classes.
network = Network()

In [7]:
# Mini-batches of 100 samples. shuffle is left at its default (False), so the
# batch taken below is always the first 100 training samples.
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=100)
batch = next(iter(train_loader))  # Get batch
images, labels = batch  # inputs and their class indices

# Calculating the Loss

In [8]:
preds = network(images)  # Pass batch through the network -> raw logits
loss = F.cross_entropy(input=preds, target=labels)  # Calculating the loss
loss.item()

2.3028502464294434

# Calculating the Gradients

In [9]:
# No backward pass has run yet, so the gradient attribute is still None.
# print() is used because Jupyter does not echo a bare None expression.
print(network.conv1.weight.grad)

None


In [10]:
# Backpropagate: populates .grad on the network parameters that contributed to
# `loss` (e.g. network.conv1.weight, inspected in the next cell).
loss.backward()

In [11]:
# The gradient has the same shape as the conv1 weight it belongs to.
network.conv1.weight.grad.size()

torch.Size([6, 1, 5, 5])

# Updating the Weights

In [12]:
# Adam over all of the network's parameters with a learning rate of 0.01.
optimizer = optim.Adam(params=network.parameters(), lr=1e-2)

In [13]:
# Loss before the optimizer step (same value as computed above).
float(loss)

2.3028502464294434

In [14]:
# Correct predictions in the 100-sample batch before the weight update.
get_num_correct(preds=preds, labels=labels)

10

In [15]:
optimizer.step()  # Updating the weights using the gradients from loss.backward()
# PyTorch accumulates gradients across backward() calls; clear them now so any
# later backward() (e.g. on the recomputed loss) starts from zero instead of
# adding to these stale values.
optimizer.zero_grad()

In [16]:
# Re-run the forward pass with the updated weights and recompute the loss
# on the same batch.
preds = network(images)
loss = F.cross_entropy(input=preds, target=labels)

In [17]:
# Loss after one Adam step; it should be a little lower than before.
float(loss)

2.2853760719299316

In [18]:
# Correct predictions in the same batch after the weight update.
get_num_correct(preds=preds, labels=labels)

11