In [10]:
import torch
from torch import functional as F
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from visdom import Visdom

In [11]:
batch_size = 128
train_loader = torch.utils.data.DataLoader(datasets.MNIST(
    root='./data/',
    train=True,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])),
    batch_size=batch_size,
    shuffle=True)

In [12]:
test_loader = torch.utils.data.DataLoader(datasets.MNIST(
    root='./data/',
    train=False,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])),
    batch_size=batch_size,
    shuffle=False)

In [13]:
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.module = nn.Sequential(nn.Linear(784, 256), nn.ReLU(inplace=True),
                                    nn.Linear(256, 64), nn.ReLU(inplace=True),
                                    nn.Linear(64, 10), nn.ReLU(inplace=True))

    def forward(self, x):
        x = self.module(x)
        return x

In [14]:
# viz = Visdom()
# viz.line([0.], [0.], win='train_loss', opts=dict(title='train loss'))
# viz.line([[0.0, 0.0]], [0.],
#          win='test',
#          opts=dict(title='test loss&acc.', legend=['loss', 'acc.']))


viz = Visdom()
viz.line([0.],[0.0],win = 'train_loss', opts=dict(title='train_loss'))
viz.line([0.],[0.0],win='test_acc', opts=dict(title='test_acc'))

net = MLP()
optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)
cel = nn.CrossEntropyLoss()
global_step = 0
for epoch in range(100):

    for step, (x, y) in enumerate(train_loader):
        x = x.reshape(-1, 784)
        logit = net(x)
        loss = cel(logit, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        viz.line([loss.item()], [global_step],
                 win='train_loss',
                 update='append')
        if step % 100 == 0:
            print(epoch, step, 'loss:', loss.item())

    correct = 0
    total_cnt = 0
    for data, target in test_loader:
        # 摊平成shape=[样本数,784]的形状
        data = data.reshape(-1, 28 * 28)
        logits = net(data)
#         test_loss = cel(logits, target)
        #             test_loss += CEL(logits, target).item()
        # 得到的预测值输出是一个10个分量的概率,在第2个维度上取max
        # logits.data是一个shape=[batch_size,10]的Tensor
        # 注意Tensor.max(dim=1)是在这个Tensor的1号维度上求最大值
        # 得到一个含有两个元素的元组,这两个元素都是shape=[batch_size]的Tensor
        # 第一个Tensor里面存的都是最大值的值,第二个Tensor里面存的是对应的索引
        # 这里要取索引,所以取了这个tuple的第二个元素
        #         print(type(logits.data), logits.data.shape,type(logits.data.max(dim=1)))
        pred = torch.softmax(logits, dim=1)
        pred = torch.argmax(pred, dim=1)
        #         pred = logits.data.max(dim=1)[1]
        # 对应位置相等则对应位置为True,这里用sum()即记录了True的数量
        correct += torch.eq(pred, target).float().sum().item()
        #             correct = pred.eq(y.data).sum()
        #             correct_cnt+=correct
        total_cnt += data.shape[0]
    acc = correct / total_cnt * 100.0
    viz.line([acc],[global_step], win='test_acc',update='append')
    viz.images(data.view(-1,1,28,28), win='x')
    viz.text(str(pred.detach().cpu().numpy()), win='pred')
#     viz.line([[test_loss.item(), acc]], [global_step],
#              win='test',
#              update='append')
    print(epoch, 'acc : {:.2f}'.format(acc))

Setting up a new session...


0 0 loss: 2.2973036766052246
0 100 loss: 2.2841978073120117
0 200 loss: 2.262608051300049
0 300 loss: 2.2666118144989014
0 400 loss: 2.255567789077759
0 acc : 39.39
1 0 loss: 2.2496254444122314
1 100 loss: 2.2269506454467773
1 200 loss: 2.1947216987609863
1 300 loss: 2.150509834289551
1 400 loss: 2.1462271213531494
1 acc : 55.60
2 0 loss: 2.1200404167175293
2 100 loss: 2.0836169719696045
2 200 loss: 2.070380926132202
2 300 loss: 2.0236196517944336
2 400 loss: 1.9937987327575684
2 acc : 58.44
3 0 loss: 1.9546140432357788
3 100 loss: 1.8760817050933838
3 200 loss: 1.8718754053115845
3 300 loss: 1.709803581237793
3 400 loss: 1.7052797079086304
3 acc : 67.26
4 0 loss: 1.7146868705749512
4 100 loss: 1.5388858318328857
4 200 loss: 1.5955249071121216
4 300 loss: 1.5448435544967651
4 400 loss: 1.4483821392059326
4 acc : 71.18
5 0 loss: 1.4001269340515137
5 100 loss: 1.128952980041504
5 200 loss: 1.189605474472046
5 300 loss: 1.1683175563812256
5 400 loss: 0.9862909913063049
5 acc : 74.70
6 0 l

KeyboardInterrupt: 