In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

torch.manual_seed(1)
epochs = 1
batch_size = 10
lr = 0.001

In [14]:
train_data.data.shape

torch.Size([60000, 28, 28])

In [15]:
train_data.targets.shape

torch.Size([60000])

In [6]:
train_data = torchvision.datasets.MNIST(root='./mnist/', 
                                        train=True, 
                                        transform=torchvision.transforms.ToTensor(),
                                        download=True)
train_amount = 50 # 60000
train_x = Variable(torch.unsqueeze(train_data.data, dim=1)).type(torch.FloatTensor)[:train_amount]
train_y = train_data.targets[:train_amount]
full_train_x = Variable(torch.unsqueeze(train_data.data, dim=1)).type(torch.FloatTensor)
full_train_y = train_data.targets
train_loader = Data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)


test_amount = 10000 # 10000
test_data = torchvision.datasets.MNIST(root='./mnist/', 
                                        train=False, 
                                        transform=torchvision.transforms.ToTensor(),
                                        download=True)
test_x = Variable(torch.unsqueeze(test_data.data, dim=1)).type(torch.FloatTensor)[:test_amount]
test_y = test_data.targets[:test_amount]

print(train_x.shape, test_x.shape)

torch.Size([10, 1, 28, 28]) torch.Size([10000, 1, 28, 28])


In [7]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.dense = nn.Sequential(
            nn.Linear(784, 500),
            nn.ReLU(),
            nn.Linear(500, 100),
            nn.ReLU(),
            nn.Linear(100, 50),
            nn.ReLU(),
            nn.Linear(50, 10)
        )
    
    def forward(self, x):
        output = self.dense(x)
        return output
    
cnn = CNN()
print(cnn)

CNN(
  (dense): Sequential(
    (0): Linear(in_features=784, out_features=500, bias=True)
    (1): ReLU()
    (2): Linear(in_features=500, out_features=100, bias=True)
    (3): ReLU()
    (4): Linear(in_features=100, out_features=50, bias=True)
    (5): ReLU()
    (6): Linear(in_features=50, out_features=10, bias=True)
  )
)


In [11]:
cnn = CNN()

optimizer = torch.optim.Adam(cnn.parameters(), lr=lr)
loss_func = nn.CrossEntropyLoss()

for epoch in range(epochs):
    for step, (x, y) in enumerate(train_loader):
        b_x = Variable(x.view(-1, 28 * 28))
        b_y = Variable(y)
        output = cnn(b_x)
        loss = loss_func(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if step % 10 == 0:
            test_output = cnn(test_x.view(-1, 28 * 28))
            test_pred_y = torch.max(test_output, 1)[1].data.squeeze()
            test_accuracy =  (test_pred_y == test_y).sum().item() / float(test_y.size(0))

            output = cnn(train_x.view(-1, 28 * 28))
            train_pred_y = torch.max(output, 1)[1].data.squeeze()
            train_accuracy = (train_pred_y == train_y).sum().item() / float(train_x.size(0))
            
            output = cnn(full_train_x.view(-1, 28 * 28))
            full_train_pred_y = torch.max(output, 1)[1].data.squeeze()
            full_train_accuracy = (full_train_pred_y == full_train_y).sum().item() / float(full_train_x.size(0))
            
            print('Epoch: {} | train loss: {:.2f}| train acc: {:.2f}| full train acc: {:.2f}| test acc: {:.2f}'.format(epoch, float(loss.data), train_accuracy, full_train_accuracy, test_accuracy))

Epoch: 0 | train loss: 2.34| train acc: 0.20| full train acc: 0.30| test acc: 0.30
Epoch: 0 | train loss: 2.26| train acc: 0.20| full train acc: 0.28| test acc: 0.28
Epoch: 0 | train loss: 2.09| train acc: 0.40| full train acc: 0.44| test acc: 0.45
Epoch: 0 | train loss: 1.85| train acc: 0.60| full train acc: 0.53| test acc: 0.52
Epoch: 0 | train loss: 1.17| train acc: 0.70| full train acc: 0.55| test acc: 0.56
Epoch: 0 | train loss: 0.77| train acc: 0.90| full train acc: 0.59| test acc: 0.60
Epoch: 0 | train loss: 0.75| train acc: 0.80| full train acc: 0.72| test acc: 0.72
Epoch: 0 | train loss: 0.54| train acc: 0.80| full train acc: 0.71| test acc: 0.72
Epoch: 0 | train loss: 1.43| train acc: 0.70| full train acc: 0.74| test acc: 0.74
Epoch: 0 | train loss: 0.95| train acc: 0.80| full train acc: 0.70| test acc: 0.71
Epoch: 0 | train loss: 0.64| train acc: 0.80| full train acc: 0.76| test acc: 0.77
Epoch: 0 | train loss: 0.94| train acc: 0.80| full train acc: 0.71| test acc: 0.72
Epoc

KeyboardInterrupt: 