In [1]:
import torch, torchvision
from tqdm import tqdm

In [2]:
class ConvNet(torch.nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        
        self.conv_layers = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=20, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(in_channels=20, out_channels=40, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.lin_layers = torch.nn.Sequential(
            torch.nn.Linear(in_features=40*7*7, out_features=32),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=32, out_features=10),
            torch.nn.Softmax(dim=1)
        )
        
    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(-1, 40*7*7)
        x = self.lin_layers(x)
        return x

In [3]:
model = ConvNet()
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
criterion = torch.nn.CrossEntropyLoss()

In [4]:
xforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])  # Try with no normalization for now
train_data = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=xforms)
test_data = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=xforms)
batch_size=64
train_loader = torch.utils.data.DataLoader(train_data,
                                          batch_size=batch_size,
                                          shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=batch_size,
                                          shuffle=True)

In [9]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
images, labels = next(iter(train_loader))
grid = torchvision.utils.make_grid(images)
writer.add_image('images', grid, 0)

In [10]:
epochs = 15
loss_history = []
for epoch in tqdm(range(epochs)):
    epoch_loss = 0
    val_loss = 0
    for data, labels in train_loader:

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    writer.add_scalar(tag="loss", scalar_value=epoch_loss/len(train_loader), global_step=epoch)
    
writer.close()


  0%|          | 0/15 [00:00<?, ?it/s][A
  7%|▋         | 1/15 [00:44<10:19, 44.24s/it][A
 13%|█▎        | 2/15 [01:28<09:33, 44.15s/it][A
 20%|██        | 3/15 [02:12<08:49, 44.11s/it][A
 27%|██▋       | 4/15 [02:56<08:05, 44.16s/it][A
 33%|███▎      | 5/15 [03:40<07:20, 44.08s/it][A
 40%|████      | 6/15 [04:24<06:37, 44.12s/it][A
 47%|████▋     | 7/15 [05:07<05:49, 43.68s/it][A
 53%|█████▎    | 8/15 [05:49<05:03, 43.38s/it][A
 60%|██████    | 9/15 [06:33<04:20, 43.48s/it][A
 67%|██████▋   | 10/15 [07:16<03:35, 43.18s/it][A
 73%|███████▎  | 11/15 [07:59<02:52, 43.12s/it][A
 80%|████████  | 12/15 [08:41<02:08, 42.98s/it][A
 87%|████████▋ | 13/15 [09:24<01:26, 43.04s/it][A
 93%|█████████▎| 14/15 [10:07<00:42, 42.99s/it][A
100%|██████████| 15/15 [10:49<00:00, 42.48s/it][A

In [7]:
correct_count = 0
model.eval()
with torch.no_grad():
    for image, label in test_loader:
        output = model(image)
        
        pred = output.argmax(dim=1, keepdim=True)
        correct_count += pred.eq(label.view_as(pred)).sum().item()
        
print(f"Correct/Total: {correct_count}/{len(test_loader.dataset)}")
print(f"Accuracy: {correct_count/len(test_loader.dataset):.4f}")

Correct/Total: 9860/10000
Accuracy: 0.9860
