In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
input_size = 784
hidden_size = 500
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.001

In [4]:
train_dataset = torchvision.datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root='data', train=False, transform=transforms.ToTensor())

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data\MNIST\raw\train-images-idx3-ubyte.gz


9913344it [00:00, 14632633.21it/s]                             


Extracting ../../data\MNIST\raw\train-images-idx3-ubyte.gz to ../../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data\MNIST\raw\train-labels-idx1-ubyte.gz


29696it [00:00, 632119.30it/s]           


Extracting ../../data\MNIST\raw\train-labels-idx1-ubyte.gz to ../../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data\MNIST\raw\t10k-images-idx3-ubyte.gz


1649664it [00:00, 4298107.50it/s]                            


Extracting ../../data\MNIST\raw\t10k-images-idx3-ubyte.gz to ../../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data\MNIST\raw\t10k-labels-idx1-ubyte.gz


5120it [00:00, ?it/s]                   

Extracting ../../data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../../data\MNIST\raw






In [5]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle =True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = batch_size, shuffle =False)

In [6]:
class NeuralNet(nn.Module):
  def __init__(self, input_size, hidden_size, num_classes):
    super(NeuralNet, self).__init__()
    self.fc1 = nn.Linear(input_size, hidden_size)
    self.relu = nn.ReLU()
    self.fc2 = nn.Linear(hidden_size, num_classes)
  def forward(self, x):
    out = self.fc1(x)
    out = self.relu(out)
    out = self.fc2(out)
    return out

In [7]:
model = NeuralNet(input_size=input_size, hidden_size=hidden_size, num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [9]:
total_step = len(train_loader)
for epoch in range(num_epochs):
  for i, (images, labels) in enumerate(train_loader):
    images = images.reshape(-1, 28*28).to(device)
    labels = labels.to(device)

    outputs = model(images)
    loss = criterion(outputs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer

    if (i+1)%100 ==0:
      print('Epoch [{}/{}], step [{}/{}], Loss:{:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))

with torch.no_grad():
  correct = 0
  total = 0
  for images, labels in test_loader:
    images = images.reshape(-1, 28*28).to(device)
    labels = labels.to(device)
    outputs = model(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
  print('Accuracy of the network on the 10000 test images:{}%'.format(100*correct/total))

Epoch [1/5], step [100/600], Loss:2.2901
Epoch [1/5], step [200/600], Loss:2.3060
Epoch [1/5], step [300/600], Loss:2.3134
Epoch [1/5], step [400/600], Loss:2.3134
Epoch [1/5], step [500/600], Loss:2.2984
Epoch [1/5], step [600/600], Loss:2.2928
Epoch [2/5], step [100/600], Loss:2.3076
Epoch [2/5], step [200/600], Loss:2.3027
Epoch [2/5], step [300/600], Loss:2.3050
Epoch [2/5], step [400/600], Loss:2.2969
Epoch [2/5], step [500/600], Loss:2.3133
Epoch [2/5], step [600/600], Loss:2.2940
Epoch [3/5], step [100/600], Loss:2.3069
Epoch [3/5], step [200/600], Loss:2.3109
Epoch [3/5], step [300/600], Loss:2.2938
Epoch [3/5], step [400/600], Loss:2.3090
Epoch [3/5], step [500/600], Loss:2.2982
Epoch [3/5], step [600/600], Loss:2.3011
Epoch [4/5], step [100/600], Loss:2.3149
Epoch [4/5], step [200/600], Loss:2.2947
Epoch [4/5], step [300/600], Loss:2.2984
Epoch [4/5], step [400/600], Loss:2.3064
Epoch [4/5], step [500/600], Loss:2.3041
Epoch [4/5], step [600/600], Loss:2.3047
Epoch [5/5], ste