In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

# Wider line width for printed tensors (fewer wrapped lines in cell output).
torch.set_printoptions(linewidth=120)
# Explicitly switch autograd on globally; the training cells below call loss.backward().
torch.set_grad_enabled(True)

# TensorBoard logger used by the visualisation and training cells.
from torch.utils.tensorboard import SummaryWriter

# Record library versions next to the outputs so results can be tied to an environment.
print(torch.__version__)
print(torchvision.__version__)

1.5.1
0.6.1


In [2]:
def get_num_correct(preds, labels):
    """Count the rows of ``preds`` whose highest score sits at the index in ``labels``.

    Args:
        preds: 2-D tensor of per-class scores; the predicted class of each row
            is its argmax over dim 1.
        labels: tensor of integer class indices, one per row of ``preds``.

    Returns:
        The number of correct predictions as a plain Python ``int``.
    """
    predicted = preds.argmax(dim=1)
    matches = predicted == labels
    return int(matches.sum())

In [3]:
class Network(nn.Module):
    """Small CNN that maps single-channel images to 10 unnormalised class scores.

    The flattened feature size 12 * 4 * 4 fed into ``fc1`` is sized for 28x28
    inputs: 28 -(conv 5x5)-> 24 -(pool 2)-> 12 -(conv 5x5)-> 8 -(pool 2)-> 4.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Compute class logits for a batch of images.

        Args:
            t: batched input of shape (batch, 1, H, W); the layer sizes assume
               28x28 images.

        Returns:
            Tensor of shape (batch, 10) with raw logits. No softmax is applied
            because ``F.cross_entropy`` (used for training) expects
            unnormalised scores.
        """
        # (1) hidden conv layer: conv -> ReLU -> 2x2 max-pool
        t = self.conv1(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # (2) hidden conv layer: conv -> ReLU -> 2x2 max-pool
        t = self.conv2(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # (3) hidden linear layer. Flatten per sample, keeping dim 0 as the batch.
        # reshape(-1, 12*4*4) would silently regroup features into a different
        # batch size for mis-sized inputs; this fails loudly in fc1 instead.
        t = t.flatten(start_dim=1)
        t = F.relu(self.fc1(t))

        # (4) hidden linear layer
        t = F.relu(self.fc2(t))

        # (5) output layer: raw logits
        return self.out(t)

In [4]:
# FashionMNIST training split, downloaded into ./data/FashionMNIST on first use;
# every image is passed through a ToTensor transform (wrapped in Compose).
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
    ]),
)

In [5]:
# Unshuffled loader over the training split (batches of 100); the next cell
# takes its first batch as a preview.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=100,
)


In [6]:
# Log the first batch from train_loader as an image grid, plus the graph of a
# freshly initialised network, to a new TensorBoard run.
network = Network()
images, labels = next(iter(train_loader))
grid = torchvision.utils.make_grid(images)

# The context manager flushes and closes the writer even if graph tracing raises.
with SummaryWriter() as tb:
    tb.add_image('images', grid)
    tb.add_graph(network, images)

In [7]:
# Quick one-epoch training run, then log the trained network's graph.
torch.manual_seed(0)  # reproducible weight init and shuffle order

network = Network()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100, shuffle=True)
optimiser = optim.Adam(network.parameters(), lr=0.01)

for epoch in range(1):
    total_loss = 0
    total_correct = 0

    for images, labels in train_loader:
        preds = network(images)
        loss = F.cross_entropy(preds, labels)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        total_loss += loss.item()  # sum of per-batch mean losses
        total_correct += get_num_correct(preds, labels)

    print("epoch", epoch, "total_correct:", total_correct, "loss:", total_loss)

images, labels = next(iter(train_loader))
grid = torchvision.utils.make_grid(images)

# Use a fresh writer (its own run directory) rather than the `tb` closed in an
# earlier cell: a closed SummaryWriter silently reopens into its old log dir,
# which would mix these logs into the untrained-model run.
with SummaryWriter() as tb:
    tb.add_image('images', grid)
    tb.add_graph(network, images)

epoch 0 total_correct: 46621 loss: 351.3138353228569


In [8]:
# Ten-epoch training run; per-epoch metrics and parameter/gradient histograms
# go to a new TensorBoard run.
torch.manual_seed(0)  # reproducible weight init and shuffle order

network = Network()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100, shuffle=True)
optimiser = optim.Adam(network.parameters(), lr=0.01)

images, labels = next(iter(train_loader))
grid = torchvision.utils.make_grid(images)

# `with` guarantees the writer is flushed and closed even if training raises
# or the kernel is interrupted mid-run.
with SummaryWriter() as tb:
    tb.add_image('images', grid)
    tb.add_graph(network, images)

    for epoch in range(10):
        total_loss = 0
        total_correct = 0

        for images, labels in train_loader:
            preds = network(images)
            loss = F.cross_entropy(preds, labels)

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

            total_loss += loss.item()  # sum of per-batch mean losses
            total_correct += get_num_correct(preds, labels)

        tb.add_scalar('Loss', total_loss, epoch)
        tb.add_scalar('Number Correct', total_correct, epoch)
        tb.add_scalar('Accuracy', total_correct / len(train_set), epoch)

        # Histogram of every parameter and of its gradient from the epoch's last
        # batch. Tags follow the parameter names, so 'conv1.bias', 'conv1.weight'
        # and 'conv1.weight.grad' are still logged as before.
        for name, param in network.named_parameters():
            tb.add_histogram(name, param, epoch)
            tb.add_histogram(name + '.grad', param.grad, epoch)

        print("epoch", epoch, "total_correct:", total_correct, "loss:", total_loss)

epoch 0 total_correct: 46432 loss: 352.4348495006561
epoch 1 total_correct: 51148 loss: 238.2791469693184
epoch 2 total_correct: 51956 loss: 216.7793028652668
epoch 3 total_correct: 52355 loss: 207.47387935221195
epoch 4 total_correct: 52645 loss: 199.68736720085144
epoch 5 total_correct: 52902 loss: 194.29091149568558
epoch 6 total_correct: 52948 loss: 191.0817735120654
epoch 7 total_correct: 53168 loss: 185.4342116266489
epoch 8 total_correct: 53168 loss: 186.0882028415799
epoch 9 total_correct: 53348 loss: 182.46192450076342
