<font color=black size=5 face=雅黑>**一. MNIST数据集下载**</font>

In [1]:
import torch 
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
# Seed PyTorch's global random number generator before anything else runs.
torch.manual_seed(1111)

# Hyper-parameters. train()/evaluate() view every batch as
# (-1, sequence_length, input_size) before feeding it to the model.
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 20
learning_rate = 1e-4

# MNIST datasets: only the training split requests a download.
train_dataset = dsets.MNIST('data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = dsets.MNIST('data/', train=False, transform=transforms.ToTensor())

# Input pipeline: the training loader shuffles, the test loader keeps a fixed order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Files already downloaded


<font color=black size=5 face=雅黑>**二. 定义RNN模型**</font>

In [2]:
class RNNModel(nn.Module):
    """Stacked LSTM followed by a linear classification layer.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        num_classes: Size of the output vector produced by the linear layer.
        bias: Whether the linear layer has a bias term (the LSTM keeps its
            own default).
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes, bias=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # The LSTM is created without batch_first, so it consumes
        # sequence-major input; forward() swaps the first two axes for it.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        self.fc = nn.Linear(hidden_size, num_classes, bias=bias)

    def forward(self, x):
        """Return the linear layer applied to the LSTM output of the last step.

        The first two axes of ``x`` are swapped before the LSTM is called.
        """
        sequence_major = torch.transpose(x, 0, 1)
        lstm_out, _state = self.lstm(sequence_major)
        last_step = lstm_out[-1]
        return self.fc(last_step)

# Build the model from the hyper-parameters defined in the first cell.
rnn = RNNModel(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_classes=num_classes,
    bias=True,
)

<font color=black size=5 face=雅黑>**三. 损失函数、优化器与训练/评估函数**</font>

In [3]:
# Loss function and optimizer shared by train() and evaluate().
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=rnn.parameters(), lr=learning_rate)

def accuracy(preds, y):
    """Return the fraction of samples whose highest-scoring class equals ``y``.

    Args:
        preds: Tensor of per-class scores; the class axis is dimension 1.
        y: Tensor of target class indices, one per sample.

    Returns:
        A scalar float tensor holding ``correct / number_of_samples``.
    """
    # Index of the largest score along the class axis is the predicted class.
    _, predicted = torch.max(preds, dim=1)
    correct = (predicted == y).float().sum()
    # Divide by the real batch size rather than a hard-coded 100, so the
    # result stays correct for any batch size (including a short last batch).
    return correct / y.shape[0]
    
def train():
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``rnn``, ``criterion`` and ``optimizer``.

    Returns:
        A ``(mean_loss, mean_accuracy)`` tuple, each averaged over the number
        of batches in ``train_loader``.
    """
    loss_total = 0
    acc_total = 0
    for images, labels in train_loader:
        # Present each sample as sequence_length steps of input_size features.
        batch = images.view(-1, sequence_length, input_size)

        # Forward pass, backward pass, parameter update.
        optimizer.zero_grad()
        logits = rnn(batch)
        loss = criterion(logits, labels)
        batch_acc = accuracy(logits, labels)
        loss.backward()
        optimizer.step()

        acc_total += batch_acc.item()
        loss_total += loss.item()

    num_batches = len(train_loader)
    return loss_total / num_batches, acc_total / num_batches

def evaluate():
    """Measure loss and accuracy of ``rnn`` over ``test_loader``.

    No gradients are computed and no parameters are updated. The model is
    switched to evaluation mode for the duration of the pass and its previous
    training/evaluation mode is restored afterwards, so a following call to
    ``train()`` is unaffected.

    Returns:
        A ``(mean_loss, mean_accuracy)`` tuple, each averaged over the number
        of batches in ``test_loader``.
    """
    loss_total = 0
    acc_total = 0

    was_training = rnn.training
    rnn.eval()
    try:
        # Forward pass only: disabling autograd avoids building a graph that
        # is never used for a backward pass.
        with torch.no_grad():
            for images, labels in test_loader:
                batch = images.view(-1, sequence_length, input_size)
                logits = rnn(batch)
                loss_total += criterion(logits, labels).item()
                acc_total += accuracy(logits, labels).item()
    finally:
        # Put the model back into whatever mode the caller had it in.
        rnn.train(was_training)

    num_batches = len(test_loader)
    return loss_total / num_batches, acc_total / num_batches

<font color=black size=5 face=雅黑>**四.正常的训练**</font>

In [4]:
# Training loop kept as a string literal rather than live code: nothing below
# executes, and the cell merely echoes the text as its output.
# Presumably this is the plain (pre-quantization) training run — confirm.
# NOTE(review): the quoted code mixes indentation levels, so it would need
# re-indenting before it could be run.
'''num_epochs =5
   for epoch in range(num_epochs):
        train_loss,tran_acc = train()
        evaluate_loss,evaluate_acc = evaluate()
        print(f'Epoch: {epoch+1:02}')
        print(f'\tTrain Loss: {train_loss*100:.3f} | Train Acc: {tran_acc*100:.2f}%')
        print(f'\tEvaluate Loss: {evaluate_loss*100:.3f} | Evaluate Acc: {evaluate_acc*100:.2f}%')'''

"num_epochs =5\n   for epoch in range(num_epochs):\n        train_loss,tran_acc = train()\n        evaluate_loss,evaluate_acc = evaluate()\n        print(f'Epoch: {epoch+1:02}')\n        print(f'\tTrain Loss: {train_loss*100:.3f} | Train Acc: {tran_acc*100:.2f}%')\n        print(f'\tEvaluate Loss: {evaluate_loss*100:.3f} | Evaluate Acc: {evaluate_acc*100:.2f}%')"

<font color=black size=5 face=雅黑>**五. 量化感知训练(QAT)准备**</font>

In [5]:
# Attach the default quantization-aware-training qconfig for the 'fbgemm'
# backend to the model, then ask PyTorch to insert fake-quantization modules.
# NOTE(review): prepare_qat is documented with inplace=False as its default, in
# which case it returns a prepared copy. The return value is not assigned here
# (it is only echoed as the cell output), so the global `rnn` used by
# train()/evaluate() presumably stays un-prepared — confirm. Before switching
# to inplace=True, verify that a FakeQuantize hook attached to the LSTM can
# handle the LSTM's tuple output.
rnn.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(rnn)

  reduce_range will be deprecated in a future release of PyTorch."


RNNModel(
  (lstm): LSTM(
    28, 128, num_layers=2
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1,         scale=tensor([1.]), zero_point=tensor([0])
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (fc): Linear(
    in_features=128, out_features=10, bias=True
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, ch_axis=0,         scale=tensor([1.]), zero_point=tensor([0])
      (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
    )
    (activation_post_process): FakeQuantize(
      fake_qu

<font color=black size=5 face=雅黑>**六. 量化感知训练**</font>

In [6]:
# Train and evaluate for a reduced number of epochs, reporting after each one.
num_epochs = 5
for epoch in range(num_epochs):
    train_loss, tran_acc = train()
    evaluate_loss, evaluate_acc = evaluate()
    # Values are multiplied by 100 purely for display.
    print(
        f'Epoch: {epoch+1:02}',
        f'\tTrain Loss: {train_loss*100:.3f} | Train Acc: {tran_acc*100:.2f}%',
        f'\tEvaluate Loss: {evaluate_loss*100:.3f} | Evaluate Acc: {evaluate_acc*100:.2f}%',
        sep='\n',
    )

Epoch: 01
	Train Loss: 128.822 | Train Acc: 56.82%
	Evaluate Loss: 50.556 | Evaluate Acc: 84.93%
Epoch: 02
	Train Loss: 38.553 | Train Acc: 88.80%
	Evaluate Loss: 30.220 | Evaluate Acc: 91.10%
Epoch: 03
	Train Loss: 25.486 | Train Acc: 92.56%
	Evaluate Loss: 21.035 | Evaluate Acc: 93.88%
Epoch: 04
	Train Loss: 19.743 | Train Acc: 94.04%
	Evaluate Loss: 17.760 | Evaluate Acc: 94.81%
Epoch: 05
	Train Loss: 16.679 | Train Acc: 95.02%
	Evaluate Loss: 14.681 | Evaluate Acc: 95.69%
