In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import math
import torchvision.transforms as trans
import torchvision.datasets as dsets
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
import math
import matplotlib.pyplot as plt
%matplotlib inline
device = torch.device('cpu')

In [3]:
# MNIST train/test splits converted to float tensors by ToTensor(); downloaded on first run.
trainset, testset = (
  dsets.MNIST(root='../data/mnist', train=is_train, transform=trans.ToTensor(), download=True)
  for is_train in (True, False)
)

# Identical loader settings for both splits; only the training loader shuffles.
train_loader, test_loader = (
  DataLoader(ds, batch_size=500, shuffle=ds is trainset, num_workers=4)
  for ds in (trainset, testset)
)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to ../data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/mnist/MNIST/raw
Processing...
Done!




In [20]:
class FNN(nn.Module):
  """Fully connected classifier: in_features -> h1 -> h2 -> n_classes.

  The defaults reproduce the original 784-500-300-10 MNIST network. The layer
  attribute names (fc1, fc2, fc3) are unchanged, so state_dict keys stay
  compatible with previously saved weights.

  Args:
    in_features: number of input values per sample after flattening.
    h1: width of the first hidden layer.
    h2: width of the second hidden layer.
    n_classes: number of output logits.
  """
  def __init__(self, in_features=784, h1=500, h2=300, n_classes=10):
    super().__init__()
    self.fc1 = nn.Linear(in_features,h1)
    self.fc2 = nn.Linear(h1,h2)
    self.fc3 = nn.Linear(h2,n_classes)

  def forward(self,x):
    """Flatten `x` into rows of `fc1.in_features` values and return raw logits.

    `reshape` is used instead of `view` so that non-contiguous inputs
    (e.g. transposed image batches) are accepted as well.
    No softmax is applied; pair with a logit-based loss such as CrossEntropyLoss.
    """
    o = x.reshape(-1,self.fc1.in_features)
    o = F.relu(self.fc1(o))
    o = F.relu(self.fc2(o))
    return self.fc3(o)

# Smoke test: a random batch shaped like MNIST images, (200, 1, 28, 28),
# must map to 200 rows of 10 logits. The previously unused `x` is now checked.
x = torch.rand(200,1,28,28)
net = FNN()
with torch.no_grad():  # shape check only; no autograd graph needed
  assert net(x).shape == (200, 10)
print(net)

FNN(
  (fc1): Linear(in_features=784, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=300, bias=True)
  (fc3): Linear(in_features=300, out_features=10, bias=True)
)


In [0]:
def eval(model,criterion,dataloader,device=None):
  """Compute the average loss and classification accuracy of `model` over `dataloader`.

  NOTE: this name shadows Python's builtin `eval`. It is kept so existing calls
  such as `eval(net, criterion, test_loader)` keep working.

  Args:
    model: module mapping an input batch to class logits of shape (batch, n_classes).
    criterion: loss called as criterion(logit, by). Assumed to use mean
      reduction (the nn.CrossEntropyLoss default) -- TODO confirm if another
      criterion is passed.
    dataloader: iterable of (inputs, integer labels) batches.
    device: where batches are moved. Defaults to the device of the model's
      first parameter, or CPU if the model has no parameters.

  Returns:
    (loss, accuracy) as Python floats. Both are averaged per sample, not per
    batch, so a short final batch is not over-weighted.

  Raises:
    ZeroDivisionError: if the dataloader yields no samples (same as before).
  """
  if device is None:
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

  was_training = model.training
  model.eval()  # disable dropout / use BatchNorm running stats, if present
  total_loss, total_correct, total_count = 0.0, 0, 0
  try:
    with torch.no_grad():  # evaluation needs no autograd graph
      for bx,by in dataloader:
        bx,by = bx.to(device),by.to(device)
        logit = model(bx)
        n = bx.size(0)
        # criterion returns a batch mean; re-weight it by the batch size
        total_loss += criterion(logit,by).item() * n
        total_correct += (logit.argmax(dim=1) == by).sum().item()
        total_count += n
  finally:
    model.train(was_training)  # restore the caller's train/eval mode

  return total_loss / total_count, total_correct / total_count
   

In [23]:
# Train FNN with Adam and cross-entropy. MultiStepLR halves the learning rate
# once 15 and again once 25 scheduler steps (one per epoch) have been taken.
torch.manual_seed(0)  # reproducible weight init and shuffling on Restart & Run All
lr=1e-2
net = FNN().to(device)
optimizer = torch.optim.Adam(net.parameters(),lr=lr)
scheduler = lr_scheduler.MultiStepLR(optimizer,milestones=[15,25],gamma=0.5)
nepoch = 30
criterion = nn.CrossEntropyLoss()

for epoch in range(nepoch):
  net.train()  # explicit: an earlier evaluation may have left the model in eval mode
  time_st = time.perf_counter()  # monotonic clock for interval timing
  loss = 0
  for bx,by in train_loader:
    bx,by = bx.to(device),by.to(device)
    optimizer.zero_grad()
    logit = net(bx)
    error = criterion(logit,by)
    error.backward()
    optimizer.step()
    loss += error.item()
  time_ed = time.perf_counter()
  loss /= len(train_loader)  # mean of per-batch mean losses
  scheduler.step()
  print("epoch:%d,time=%.3f,loss=%.5f"%(epoch+1,time_ed-time_st,loss))


epoch:1,time=6.620,loss=0.27836
epoch:2,time=6.616,loss=0.10067
epoch:3,time=6.581,loss=0.06853
epoch:4,time=6.594,loss=0.05369
epoch:5,time=6.657,loss=0.04556
epoch:6,time=6.807,loss=0.04568
epoch:7,time=6.993,loss=0.04188
epoch:8,time=6.892,loss=0.03590
epoch:9,time=6.920,loss=0.03631
epoch:10,time=6.941,loss=0.03908
epoch:11,time=6.947,loss=0.03378
epoch:12,time=6.981,loss=0.02619
epoch:13,time=6.980,loss=0.02474
epoch:14,time=7.078,loss=0.03449
epoch:15,time=7.057,loss=0.03987
epoch:16,time=7.063,loss=0.01350
epoch:17,time=7.131,loss=0.00267
epoch:18,time=7.211,loss=0.00114
epoch:19,time=7.258,loss=0.00037
epoch:20,time=7.270,loss=0.00015
epoch:21,time=7.274,loss=0.00010
epoch:22,time=7.276,loss=0.00008
epoch:23,time=7.323,loss=0.00006
epoch:24,time=7.307,loss=0.00005
epoch:25,time=7.323,loss=0.00005
epoch:26,time=7.320,loss=0.00004
epoch:27,time=7.350,loss=0.00004
epoch:28,time=7.357,loss=0.00004
epoch:29,time=7.394,loss=0.00003
epoch:30,time=7.359,loss=0.00003


In [24]:
# Held-out performance of the trained network on the MNIST test split.
loss, acc = eval(net, criterion, test_loader)
report = "loss=%.5f,accuracy=%.3f" % (loss, acc)
print(report)

loss=0.12092,accuracy=0.985


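## 参数初始化 (Parameter initialization)

The cells below show three equivalent ways to apply He (Kaiming) normal initialization, which draws weights from $\mathcal{N}(0,\ 2/\text{fan\_in})$ and sets biases to zero:

- by hand over `named_parameters()`;
- with the `nn.init` helpers;
- by iterating over `net.modules()`.

Each one overwrites the trained weights of `net` in place, so the network must be retrained before it is evaluated again.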

In [0]:
# He (Kaiming) normal init by hand: weights ~ N(0, 2/fan_in), biases = 0.
# In-place ops run under no_grad (the documented replacement for .data) so
# autograd does not record the initialization.
with torch.no_grad():
  for name,param in net.named_parameters():
    if 'weight' in name:
      n_o,n_i = param.size()  # nn.Linear weight is (out_features, in_features)
      param.normal_(0,math.sqrt(2/n_i))
    if 'bias' in name:
      param.zero_()

In [0]:
# Same initialization through torch.nn.init: kaiming_normal_ with its defaults
# for weights, zeros for biases.
for name,param in net.named_parameters():
  if 'weight' in name:
    nn.init.kaiming_normal_(param)
  if 'bias' in name:
    nn.init.zeros_(param)

In [0]:
# All modules of the network; the first entry is `net` itself.
modules = [m for m in net.modules()]

In [29]:
# Third variant: walk the submodules (modules[0] is `net` itself, hence [1:]).
for layer in modules[1:]:
  print(layer)
  # Only nn.Linear layers have the 2-D weight kaiming_normal_ requires. Skip
  # anything else (activations, containers) instead of failing on a missing
  # attribute.
  if not isinstance(layer, nn.Linear):
    continue
  nn.init.kaiming_normal_(layer.weight)
  if layer.bias is not None:  # Linear(..., bias=False) has no bias tensor
    nn.init.zeros_(layer.bias)

Linear(in_features=784, out_features=500, bias=True)
Linear(in_features=500, out_features=300, bias=True)
Linear(in_features=300, out_features=10, bias=True)
