In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [28]:
input_size = 28 * 28
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.001

In [6]:
train_dataset = torchvision.datasets.MNIST(root='./mnist_data/', 
                                          train=True, 
                                          transform=transforms.ToTensor(), 
                                          download=True)
test_dataset = torchvision.datasets.MNIST(root='./mnist_data/', 
                                         train=False, 
                                         transform=transforms.ToTensor())

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST\raw\train-images-idx3-ubyte.gz


100.1%

Extracting ./mnist_data/MNIST\raw\train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST\raw\train-labels-idx1-ubyte.gz


113.5%

Extracting ./mnist_data/MNIST\raw\train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST\raw\t10k-images-idx3-ubyte.gz


100.4%

Extracting ./mnist_data/MNIST\raw\t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST\raw\t10k-labels-idx1-ubyte.gz


180.4%

Extracting ./mnist_data/MNIST\raw\t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [8]:
train_loader = torch.utils.data.DataLoader(dataset = train_dataset, 
                                          batch_size = batch_size, 
                                          shuffle = True)
test_loader = torch.utils.data.DataLoader(dataset = test_dataset, 
                                         batch_size = batch_size, 
                                         shuffle = False)

In [84]:
model = nn.Linear(input_size, num_classes)

# Using CrossEntropyLoss, the label will be aotomatically converted to one-hot data.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [89]:
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.view(-1, 28*28)
        
        predicted = torch.sigmoid(model(images))
        
        loss = criterion(predicted, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i + 1) % 100 == 0:
            print('Epoch: [{}/{}]\tBatch: [{}\{}]\tLoss: {}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))

Epoch: [1/5]	Batch: [100\600]	Loss: 1.9844636917114258
Epoch: [1/5]	Batch: [200\600]	Loss: 2.003762722015381
Epoch: [1/5]	Batch: [300\600]	Loss: 1.9857181310653687
Epoch: [1/5]	Batch: [400\600]	Loss: 1.9906806945800781
Epoch: [1/5]	Batch: [500\600]	Loss: 1.967653751373291
Epoch: [1/5]	Batch: [600\600]	Loss: 1.9755593538284302
Epoch: [2/5]	Batch: [100\600]	Loss: 1.9709385633468628
Epoch: [2/5]	Batch: [200\600]	Loss: 1.9800925254821777
Epoch: [2/5]	Batch: [300\600]	Loss: 1.9560660123825073
Epoch: [2/5]	Batch: [400\600]	Loss: 1.9509189128875732
Epoch: [2/5]	Batch: [500\600]	Loss: 1.9749133586883545
Epoch: [2/5]	Batch: [600\600]	Loss: 1.9633018970489502
Epoch: [3/5]	Batch: [100\600]	Loss: 1.9403167963027954
Epoch: [3/5]	Batch: [200\600]	Loss: 1.9404596090316772
Epoch: [3/5]	Batch: [300\600]	Loss: 1.9458744525909424
Epoch: [3/5]	Batch: [400\600]	Loss: 1.9163146018981934
Epoch: [3/5]	Batch: [500\600]	Loss: 1.937270164489746
Epoch: [3/5]	Batch: [600\600]	Loss: 1.9292199611663818
Epoch: [4/5]	

In [90]:
with torch.no_grad():
    correct = 0
    total = len(test_loader)*batch_size
    for images, labels in test_loader:
        images = images.view(-1, 28*28)
        
        outputs = model(images)
        
        _, predicted = torch.max(outputs, dim=1)    # _和predicted分别代表了max()返回的10个数中最大值的具体值和索引，此处需要的只是索引
        
        correct += (predicted == labels).sum().item()
        
    print('Accuracy: {:.4f}'.format(correct/total))

Accuracy: 0.8298


In [91]:
torch.save(model.state_dict(), 'logistic_regression.ckpt')
checkpoint = torch.load('logistic_regression.ckpt')
model_ = nn.Linear(input_size, num_classes)
model_.load_state_dict(checkpoint)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [95]:
model.state_dict()

OrderedDict([('weight',
              tensor([[-0.0023,  0.0246,  0.0085,  ..., -0.0120,  0.0308,  0.0317],
                      [-0.0007, -0.0296, -0.0176,  ..., -0.0350, -0.0251, -0.0017],
                      [ 0.0123, -0.0306,  0.0226,  ...,  0.0105,  0.0078,  0.0039],
                      ...,
                      [ 0.0338, -0.0147,  0.0056,  ..., -0.0021,  0.0070,  0.0329],
                      [-0.0119,  0.0274, -0.0181,  ..., -0.0190,  0.0324,  0.0019],
                      [-0.0084, -0.0049, -0.0236,  ..., -0.0333,  0.0138,  0.0224]])),
             ('bias',
              tensor([-0.0565,  0.0325, -0.0299,  0.0027, -0.0180, -0.0196,  0.0054,  0.0079,
                      -0.0363, -0.0244]))])