In [None]:
import torch, torchvision
from tqdm import tqdm

In [None]:
class ConvNet(torch.nn.Module):
    """Small two-block CNN that classifies 1x28x28 images (MNIST) into 10 classes.

    ``forward`` returns raw, unnormalized class scores (logits) of shape
    ``(batch, 10)``. No softmax is applied on purpose:
    ``torch.nn.CrossEntropyLoss`` applies log-softmax internally and expects
    logits. Stacking another softmax in front of it squashes the scores and
    weakens the gradients. Use ``logits.softmax(dim=1)`` if probabilities are
    needed. ``argmax`` predictions are identical either way.
    """

    def __init__(self):
        super().__init__()

        # Each block: 3x3 conv (padding=1 keeps H/W) -> ReLU -> 2x2 max-pool (halves H/W).
        # For a 28x28 input: 28 -> 14 -> 7, hence the 40*7*7 flattened features below.
        self.conv_layers = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=20, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(in_channels=20, out_channels=40, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Classifier head. It ends in a plain Linear layer (logits); see the class docstring.
        self.lin_layers = torch.nn.Sequential(
            torch.nn.Linear(in_features=40 * 7 * 7, out_features=32),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=32, out_features=10),
        )

    def forward(self, x):
        """Map a batch of images ``(N, 1, 28, 28)`` to class logits ``(N, 10)``."""
        x = self.conv_layers(x)
        # Flatten everything except the batch dim. Unlike view(-1, 1960), a wrongly
        # sized input cannot silently be folded into extra "batch" rows.
        x = torch.flatten(x, start_dim=1)
        return self.lin_layers(x)

In [None]:
SEED = 0  # seeds torch's global RNG so weight initialization is reproducible across runs
LEARNING_RATE = 0.003

torch.manual_seed(SEED)
model = ConvNet()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# CrossEntropyLoss applies log-softmax internally and expects unnormalized logits.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# ToTensor only: pixels are scaled to [0, 1] floats but not standardized (no Normalize) for now.
xforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_data = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=xforms)
test_data = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=xforms)

batch_size = 64
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=batch_size,
                                           shuffle=True)
# Accuracy does not depend on sample order, so the test set is iterated in a fixed order.
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=batch_size,
                                          shuffle=False)

In [None]:
from torch.utils.tensorboard import SummaryWriter

# Open a TensorBoard writer (default log directory) and log one training batch
# tiled into a single image grid, so the model inputs can be eyeballed.
writer = SummaryWriter()
images, labels = next(iter(train_loader))
grid = torchvision.utils.make_grid(images)
writer.add_image(tag='images', img_tensor=grid, global_step=0)

In [None]:
epochs = 15
loss_history = []  # mean training loss for each epoch, in order

# The evaluation cell calls model.eval(); restore training mode so re-running this
# cell after evaluation trains in the correct mode.
model.train()
for epoch in tqdm(range(epochs)):
    epoch_loss = 0.0
    for data, labels in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # Average of the per-batch mean losses over this epoch.
    mean_loss = epoch_loss / len(train_loader)
    loss_history.append(mean_loss)
    writer.add_scalar(tag="loss", scalar_value=mean_loss, global_step=epoch)

writer.close()

In [None]:
# Measure classification accuracy on the held-out test split.
model.eval()  # switch layers to inference mode before scoring
correct_count = 0
with torch.no_grad():  # no autograd graph needed for evaluation
    for image, label in test_loader:
        output = model(image)
        pred = output.argmax(dim=1, keepdim=True)  # predicted class index, shape (N, 1)
        matches = pred.eq(label.view_as(pred))
        correct_count += matches.sum().item()

total_count = len(test_loader.dataset)
print(f"Correct/Total: {correct_count}/{total_count}")
print(f"Accuracy: {correct_count/total_count:.4f}")