In [1]:
import torch
import matplotlib.pyplot as plt


In [2]:
#IMPORT MNIST DIGIT DATASET
from torchvision import datasets, transforms


In [None]:
# Preprocessing applied to every image: convert to a float tensor, then
# rescale with (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Training split of MNIST; downloaded into the 'data' directory if missing.
training_data = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transform,
)
# Held-out split of MNIST (train=False), preprocessed the same way.
test_data = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=transform,
)
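## TODO: split off a separate validation set (no split is done yet; the test set is reused for validation during training)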

In [None]:
## transforms.ToTensor() converts each image to a tensor,
## i.e. a multi-dimensional array of numbers with a shape of (1, 28, 28),
## and scales the pixel values in the image to the range of 0 to 1.
## These values are then mapped to the range of -1 to 1 by transforms.Normalize((0.5,), (0.5,)).
## That is done by subtracting 0.5 from each pixel value and then dividing by 0.5, i.e. (x - 0.5) / 0.5.
## The assumption here is that the mean of the pixel values is 0.5 and the standard deviation is 0.5;
## if that held exactly, subtracting 0.5 would make the mean 0 and dividing by 0.5 would make the standard deviation 1.
## Note that 0.5/0.5 are nominal values, not the measured MNIST statistics (commonly quoted as roughly 0.1307 and 0.3081),
## so in practice this step rescales the data to [-1, 1] rather than standardizing it exactly.

In [None]:
# Mini-batches of 64 training samples, reshuffled on every pass over the data.
train_loader = torch.utils.data.DataLoader(
    dataset=training_data,
    batch_size=64,
    shuffle=True,
)


In [None]:
# Mini-batches of 64 held-out samples, served in dataset order (no shuffling).
test_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=64,
)

In [None]:
# Sanity-check the data: look at one fixed sample and at one mini-batch
# drawn from the training loader.
dataiter = iter(train_loader)
demo_image_tensor, demo_image_target = train_loader.dataset[5]
print("--",demo_image_target)
# Use the built-in next() rather than calling the dunder method directly.
images, labels = next(dataiter)

print(images.shape)
print(labels.shape)

# Single sample; squeeze() drops the size-1 channel axis so imshow gets a 2-D array.
plt.imshow(demo_image_tensor.numpy().squeeze(), cmap='gray_r')

# Grid preview of the batch. The grid has 6 x 10 = 60 slots; never ask for
# more images than the batch actually contains.
figure = plt.figure()
num_of_images = min(60, len(images))
for index in range(num_of_images):
    # Subplot positions are 1-based while tensor indices are 0-based, so the
    # first image of the batch (index 0) is no longer skipped.
    plt.subplot(6, 10, index + 1)
    plt.axis('off')
    plt.imshow(images[index].numpy().squeeze(), cmap='gray_r')


In [None]:
class Net(torch.nn.Module):
    """Fully connected classifier with layer sizes 784 -> 128 -> 64 -> 10.

    The input is flattened from dimension 1 onward, passed through two
    sigmoid-activated hidden layers, and projected to 10 raw scores
    (no activation is applied to the final layer).
    """

    def __init__(self):
        super().__init__()
        # Sigmoid is the hidden-layer activation. The author's stated
        # rationale: the pixel values are normalized to the range -1 to 1,
        # and the negative values should not be lost.
        self.activation = torch.nn.Sigmoid()
        self.fc1 = torch.nn.Linear(in_features=28*28, out_features=128)
        self.fc2 = torch.nn.Linear(in_features=128, out_features=64)
        self.fc3 = torch.nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Return the 10 output scores for each sample in the batch ``x``."""
        # Keep the batch dimension and collapse everything else into one axis.
        flat = x.flatten(start_dim=1)
        hidden_a = self.activation(self.fc1(flat))
        hidden_b = self.activation(self.fc2(hidden_a))
        scores = self.fc3(hidden_b)
        return scores


In [None]:
# Create the model instance that the optimizer, the training loop and the
# checkpoint-saving cell below all operate on.
net = Net()


In [None]:
# Adam over all of the model's parameters, starting at a learning rate of 0.01.
optimizer = torch.optim.Adam(
    params=net.parameters(),
    lr=0.01,
)

In [None]:
# Step schedule for the optimizer's learning rate: step_size=5, gamma=0.1.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=5,
    gamma=0.1,
)
# Cross-entropy criterion, applied to the network's raw output scores.
loss_fn = torch.nn.CrossEntropyLoss()


In [None]:
#TRAINING
# Runs num_epochs epochs. Each epoch does one optimization pass over
# train_loader followed by one evaluation pass over test_loader, and records
# the mean per-batch loss of each pass for plotting afterwards.
num_epochs = 15
train_loss_history = list()
val_loss_history = list()

for epoch in range(num_epochs):
    # ---- training pass ----
    net.train()
    train_loss = 0
    train_correct = 0
    train_total = 0
    for inputs, targets in train_loader:
        optimizer.zero_grad()  # clear gradients left over from the previous batch
        outputs = net(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward() # this is where the model learns by backpropagating
        optimizer.step() # this is where the model optimizes its weights
        train_loss += loss.item()
        # Predicted class = index of the largest output score.
        train_correct += (outputs.argmax(1) == targets).sum().item()
        train_total += targets.size()[0]
    print(f'Epoch {epoch + 1} training accuracy: {((train_correct/train_total*1.0)*100)}% training loss: {train_loss/len(train_loader):.5f}')
    # Advance the learning-rate schedule once per epoch, after the optimizer updates.
    scheduler.step()
    train_loss_history.append(train_loss/len(train_loader))

    # ---- validation pass (uses test_loader; no separate validation split exists) ----
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    net.eval()
    # Evaluation needs no gradients: no_grad() stops autograd from recording
    # the forward pass, which saves memory and time.
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = net(inputs)
            loss = loss_fn(outputs, labels)

            # Same prediction rule as in the training pass.
            preds = outputs.argmax(1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size()[0]
            val_loss += loss.item()
    print(f'Epoch {epoch + 1} validation accuracy: {((val_correct/val_total*1.0)*100)}% validation loss: {val_loss/len(test_loader):.5f}')
    val_loss_history.append(val_loss/len(test_loader))

In [None]:
# Learning curves: one point per epoch for each recorded loss history.
# Axis labels and a title are set so the figure can be read on its own.
plt.plot(train_loss_history, label="Training Loss")
plt.plot(val_loss_history, label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Mean loss per batch")
plt.title("Training vs. validation loss")
plt.legend()
plt.show()

In [None]:
# Write the model's state dict (its parameters, not the module object itself)
# to 'mnist.pth' in the current working directory.
torch.save(obj=net.state_dict(), f='mnist.pth')