In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
torch.__version__

'1.0.1.post2'

In [2]:
BATCH_SIZE=512 #大概需要2G的显存
EPOCHS=30 # 总共训练批次
DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu") # 让torch判断是否使用GPU
train_loader=torch.utils.data.DataLoader(datasets.MNIST('MNIST',train=True,download=True,
                                                        transform=transforms.Compose([
                                                            transforms.ToTensor(),
                                                            transforms.Normalize((0.1307,),(0.3081,))
                                                        ])),
                                        batch_size=BATCH_SIZE,shuffle=True)
test_loader=torch.utils.data.DataLoader(datasets.MNIST('MNIST',train=False,transform=transforms.Compose([
                                                          transforms.ToTensor(),
                                                          transforms.Normalize((0.1307,),(0.3081,))
                                                      ])),
                                       batch_size=BATCH_SIZE,shuffle=True)

In [3]:
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=nn.Conv2d(1,10,5)
        self.conv2=nn.Conv2d(10,20,3)
        self.fc1=nn.Linear(20*10*10,500)
        self.fc2=nn.Linear(500,10)
    def forward(self,x):
        in_size=x.size(0) # 多少个数据
        out=self.conv1(x) # 24
        out=F.relu(out)
        out=F.max_pool2d(out,2,2) #12
        out=self.conv2(out) #10
        out=F.relu(out)
        out=out.view(in_size,-1)
        out=self.fc1(out)
        out=F.relu(out)
        out=self.fc2(out)
        out=F.log_softmax(out,dim=1)
        return out

In [4]:
model=ConvNet().to(DEVICE)
optimizer=optim.Adam(model.parameters())

In [10]:
def train(model,device,train_loader,optimizer,epoch):
    model.train()
    for batch_idx,(data,target) in enumerate(train_loader):
        data,target=data.to(device),target.to(device)
        optimizer.zero_grad()
        output=model(data)
        loss=F.nll_loss(output,target)  # 损失函数NLLLoss的输入是一个对数概率向量和一个目标标签. 它不会为我们计算对数概率，适合最后一层是log_softmax的网络.
        loss.backward()
        optimizer.step()
        if(batch_idx+1)%30==0:
            print("Train Epoch:{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
            epoch,batch_idx*len(data),len(train_loader.dataset),
            100.*batch_idx/len(train_loader),loss.item()))

In [5]:
def test(model,device,test_loader):
    model.eval()
    test_loss=0
    correct=0
    with torch.no_grad():
        for data,target in test_loader:
            data,target=data.to(device),target.to(device)
            output=model(data)
            test_loss += F.nll_loss(output,target,reduction="sum").item() # 将一批损失相加
            pred=output.max(1,keepdim=True)[1] # 找到概率最大的下标
            correct +=pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss:{:.4f},Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss,correct,len(test_loader.dataset),
    100.*correct/len(test_loader.dataset)))

In [11]:
for epoch in range(1,EPOCHS+1):
    train(model,DEVICE,train_loader,optimizer,epoch)
    test(model,DEVICE,test_loader)


Test set: Average loss:0.0391,Accuracy: 9905/10000 (99%)


Test set: Average loss:0.0395,Accuracy: 9913/10000 (99%)


Test set: Average loss:0.0408,Accuracy: 9903/10000 (99%)


Test set: Average loss:0.0427,Accuracy: 9908/10000 (99%)


Test set: Average loss:0.0336,Accuracy: 9924/10000 (99%)


Test set: Average loss:0.0307,Accuracy: 9926/10000 (99%)


Test set: Average loss:0.0311,Accuracy: 9929/10000 (99%)


Test set: Average loss:0.0314,Accuracy: 9929/10000 (99%)


Test set: Average loss:0.0317,Accuracy: 9929/10000 (99%)


Test set: Average loss:0.0321,Accuracy: 9930/10000 (99%)


Test set: Average loss:0.0325,Accuracy: 9930/10000 (99%)


Test set: Average loss:0.0327,Accuracy: 9929/10000 (99%)


Test set: Average loss:0.0331,Accuracy: 9931/10000 (99%)


Test set: Average loss:0.0333,Accuracy: 9930/10000 (99%)


Test set: Average loss:0.0337,Accuracy: 9929/10000 (99%)


Test set: Average loss:0.0339,Accuracy: 9929/10000 (99%)


Test set: Average loss:0.0342,Accuracy: 9929/10000 (99%

In [16]:
torch.save(model.state_dict(),'model/model.pkl')

In [5]:
model.load_state_dict(torch.load('model/model.pkl'))

In [8]:
test(model,DEVICE,test_loader)


Test set: Average loss:0.0388,Accuracy: 9906/10000 (99%)



In [13]:
# 保存模型
torch.save({
    'epoch':epoch+1,
    'state_dict':model.state_dict(),
    'best_result':'99%'
},'model/model.pkl') # pickle 腌菜，腌制食品

In [6]:
# 加载模型
checkpoint=torch.load('model/model.pkl')
start_epoch=checkpoint['epoch']
best_result=checkpoint['best_result']
model.load_state_dict(checkpoint['state_dict'])

In [7]:
test(model,DEVICE,test_loader)


Test set: Average loss:0.0372,Accuracy: 9930/10000 (99%)



In [8]:
best_result

'99%'

In [10]:
params=model.state_dict()
for k,v in params.items():
    print(k) # 打印网络中的变量名
    print(params['conv1.weight']) # 打印conv1的weights
    print(params['conv1.bias'])

conv1.weight
tensor([[[[ 0.0473, -0.1242, -0.1858, -0.0667,  0.2683],
          [ 0.0933,  0.1241,  0.1764,  0.1326,  0.0780],
          [-0.1828,  0.0375,  0.1603, -0.0414, -0.3911],
          [ 0.0498, -0.0743,  0.0710, -0.0566, -0.1369],
          [-0.1982,  0.0006,  0.0075,  0.1487,  0.1845]]],


        [[[-0.2472, -0.3023, -0.1454,  0.2907,  0.0410],
          [-0.3977, -0.1565,  0.2553,  0.2942, -0.0680],
          [-0.3819, -0.1973,  0.3458,  0.0591, -0.0357],
          [-0.2645, -0.2302,  0.1210, -0.0348,  0.2715],
          [-0.0647, -0.0889, -0.1129, -0.0182,  0.1430]]],


        [[[ 0.1846,  0.2150,  0.0590, -0.0621, -0.1213],
          [-0.0629,  0.0778,  0.2212,  0.1107,  0.2403],
          [ 0.0016, -0.0563,  0.3046,  0.0389,  0.1994],
          [-0.1984, -0.0028, -0.1706,  0.0874,  0.0724],
          [-0.3062, -0.3310, -0.3172, -0.0660, -0.3313]]],


        [[[-0.1240, -0.1753, -0.3122, -0.2659, -0.1283],
          [-0.2056, -0.3021, -0.1776, -0.1498, -0.1474],
      