In [1]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import random

In [2]:
training_epochs = 10
batch_size = 100
seed = 777
learning_rate = 3e-1

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

random.seed(seed)
torch.manual_seed(seed)
if device == 'cuda':
    torch.cuda.manual_seed_all(seed)

In [4]:
mnist_train = dsets.MNIST(root='MNIST_data/',
                          train=True,
                          transform=transforms.ToTensor(),
                          download=True)

mnist_test = dsets.MNIST(root='MNIST_data/',
                         train=False,
                         transform=transforms.ToTensor(),
                         download=True)

data_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)

In [5]:
linear = torch.nn.Linear(784, 10, bias=True).to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(linear.parameters(), lr=learning_rate)

In [6]:
best_acc = -1
for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = len(data_loader)

    for batch, (X, Y) in enumerate(data_loader, 1):
        X = X.view(-1, 28 * 28).to(device)
        Y = Y.to(device)

        optimizer.zero_grad()
        hypothesis = linear(X)
        cost = criterion(hypothesis, Y)
        cost.backward()
        optimizer.step()

        avg_cost += cost / total_batch

        if batch % (total_batch//10) == 0:
            with torch.no_grad():
                ranges = list(range(len(mnist_train)))
                incidence = random.sample(ranges, len(mnist_test))
                X_train = mnist_train.data[incidence].view(-1, 28 * 28).float().to(device)
                Y_train = mnist_train.targets[incidence].to(device)
                
                X_test = mnist_test.data.view(-1, 28 * 28).float().to(device)
                Y_test = mnist_test.targets.to(device)

                prediction = linear(X_train)
                correct_prediction = torch.argmax(prediction, 1) == Y_train
                train_accuracy = correct_prediction.float().mean()

                prediction = linear(X_test)
                correct_prediction = torch.argmax(prediction, 1) == Y_test
                test_accuracy = correct_prediction.float().mean()

                if best_acc < test_accuracy.item():
                    best_acc = test_accuracy.item()

            print('Epoch: {:02d}\tBatch: {:03d}\tcost = {:.9f}\tTrain_Acc: {:.2f}\tTest_Acc: {:.2f}\t\tBest_Test_Acc: {:.2f}'.format(epoch + 1, batch, cost.item(), train_accuracy.item()*100, test_accuracy.item()*100, best_acc*100))
    print('='*120)
print('Learning finished')

Epoch: 01	Batch: 060	cost = 0.706622005	Train_Acc: 86.36	Test_Acc: 87.31		Best_Test_Acc: 87.31
Epoch: 01	Batch: 120	cost = 0.463007241	Train_Acc: 88.44	Test_Acc: 89.19		Best_Test_Acc: 89.19
Epoch: 01	Batch: 180	cost = 0.319952726	Train_Acc: 88.86	Test_Acc: 89.65		Best_Test_Acc: 89.65
Epoch: 01	Batch: 240	cost = 0.474801898	Train_Acc: 88.71	Test_Acc: 89.57		Best_Test_Acc: 89.65
Epoch: 01	Batch: 300	cost = 0.515292048	Train_Acc: 89.40	Test_Acc: 90.13		Best_Test_Acc: 90.13
Epoch: 01	Batch: 360	cost = 0.470868409	Train_Acc: 89.20	Test_Acc: 89.90		Best_Test_Acc: 90.13
Epoch: 01	Batch: 420	cost = 0.303796589	Train_Acc: 89.22	Test_Acc: 90.35		Best_Test_Acc: 90.35
Epoch: 01	Batch: 480	cost = 0.385777444	Train_Acc: 88.48	Test_Acc: 89.58		Best_Test_Acc: 90.35
Epoch: 01	Batch: 540	cost = 0.433931261	Train_Acc: 89.54	Test_Acc: 89.97		Best_Test_Acc: 90.35
Epoch: 01	Batch: 600	cost = 0.182896808	Train_Acc: 88.57	Test_Acc: 89.91		Best_Test_Acc: 90.35
Epoch: 02	Batch: 060	cost = 0.605541706	Train_Acc: