In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [4]:
# create a fully connected network

class NN(nn.Module):
    def __init__(self, input_size, num_classes):  #28*28 = 784
        super(NN, self).__init__()
        self.fc1 = nn.Linear(input_size, 50)
        self.fc2 = nn.Linear(50, num_classes)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
        

In [7]:
model = NN(784, 10)
x = torch.randn(64, 784)
model(x).shape

torch.Size([64, 10])

In [8]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [36]:
# Hyperparameter
input_size = 784
num_classes = 10
learning_rate = 0.001
batch_size = 64
num_epochs = 10

In [12]:
# load data
train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(),download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size= batch_size, shuffle=True)

test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(),download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size= batch_size, shuffle=True)

In [13]:
model = NN(input_size=input_size, num_classes=num_classes).to(device)

In [15]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = learning_rate)

In [18]:
[i for i in model.parameters()]

[Parameter containing:
 tensor([[ 0.0151, -0.0277,  0.0110,  ..., -0.0267, -0.0296, -0.0188],
         [-0.0087,  0.0309,  0.0021,  ...,  0.0169, -0.0242,  0.0082],
         [-0.0008,  0.0060,  0.0250,  ..., -0.0267,  0.0128, -0.0055],
         ...,
         [-0.0299, -0.0311,  0.0224,  ..., -0.0184, -0.0111, -0.0140],
         [-0.0053, -0.0293,  0.0072,  ...,  0.0348,  0.0321, -0.0349],
         [ 0.0037,  0.0350,  0.0193,  ...,  0.0212, -0.0357, -0.0235]],
        device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([-0.0311,  0.0214, -0.0041,  0.0308,  0.0009, -0.0212,  0.0087, -0.0104,
          0.0273,  0.0264, -0.0272, -0.0072,  0.0196,  0.0209, -0.0131, -0.0119,
         -0.0246,  0.0113, -0.0122, -0.0092, -0.0065,  0.0002,  0.0127, -0.0313,
          0.0340, -0.0157, -0.0124, -0.0264,  0.0243, -0.0155, -0.0062, -0.0005,
         -0.0210, -0.0116,  0.0342,  0.0005, -0.0173,  0.0287, -0.0239, -0.0218,
         -0.0168,  0.0032,  0.0131, -0.0348, -0.0016,  0.0190,

In [37]:
# Train Network

for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(train_loader):
        # Data to cuda if possible
        data = data.to(device=device)
        targets = targets.to(device=device)
        # print(data.shape)
        # print(targets.shape)
        # break

        # Get to correct shape
        data = data.reshape(data.shape[0],-1)

        # forward
        scores = model(data)
        loss = criterion(scores, targets)

        # backward
        optimizer.zero_grad()
        loss.backward()

        # gradient descent or adm step
        optimizer.step()
          

In [33]:
def check_accuracy(loader, model):
    if loader.dataset.train:
        print('Checking accuracy on training data')
    else:
        print('Checking accuracy on test data')
    num_correct = 0
    num_samples = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)
            x = x.reshape(x.shape[0], -1)
            
            scores = model(x)
            # 64*10
            _, prediction = scores.max(1)
            num_correct += (prediction == y).sum()
            num_samples += prediction.size(0)

        print(f'Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}')
    model.train()
    

In [38]:
check_accuracy(train_loader, model)

Checking accuracy on training data
Got 59147/60000 with accuracy 98.58


In [39]:
check_accuracy(test_loader, model)

Checking accuracy on test data
Got 9707/10000 with accuracy 97.07
