In [1]:
import megengine
import megengine.data as data
import megengine.data.transform as T
import megengine.module as M
import megengine.functional as F
import megengine.optimizer as optimizer
import megengine.jit as jit

In [2]:
class BasicBlock(M.Module):
    """Residual unit of ResNet18: two stacked 3x3 convolutions plus a shortcut."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main branch, stage 1: 3x3 conv followed by BN and ReLU.
        # This is the only main-branch conv that applies the block's stride.
        self.conv1 = M.ConvBnRelu2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
        )
        # Main branch, stage 2: 3x3 conv followed by BN only; the ReLU is
        # applied in forward() after the residual sum.
        self.conv2 = M.ConvBn2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # Shortcut branch: pass the input through untouched when channel count
        # and stride are unchanged; otherwise project it with a 1x1 ConvBn so
        # it can be added to the main branch output.
        needs_projection = in_channels != out_channels or stride != 1
        self.res_conn = (
            M.ConvBn2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=stride,
            )
            if needs_projection
            else M.Identity()
        )

    def forward(self, x):
        """Return relu(conv2(conv1(x)) + res_conn(x))."""
        main_branch = self.conv2(self.conv1(x))
        summed = main_branch + self.res_conn(x)
        return F.relu(summed)

In [3]:
class ResNet18(M.Module):
    """ResNet18 classifier with a 3x3 stem, 8 BasicBlocks and a 10-way linear head.

    The forward pass is written for Nx3x32x32 image batches (see forward()).
    """

    def __init__(self):
        # Initialise the M.Module base class before registering submodules,
        # matching what BasicBlock does. The original omitted this call, so
        # any state the base constructor sets up was never created.
        super(ResNet18, self).__init__()
        # Stem: 3x3 conv + BN + ReLU lifting the 3 input channels to 64.
        self.conv1 = M.ConvBnRelu2d(in_channels=3, out_channels=64,
                                    kernel_size=3, padding=1)
        # 8 BasicBlocks with 3 downsampling steps (stride=2): 8 x 2 = 16 conv layers.
        self.blocks = M.Sequential(
            BasicBlock(64,  64),
            BasicBlock(64,  64),
            BasicBlock(64,  128, stride=2),
            BasicBlock(128, 128),
            BasicBlock(128, 256, stride=2),
            BasicBlock(256, 256),
            BasicBlock(256, 512, stride=2),
            BasicBlock(512, 512),
        )
        # Fully connected classifier producing predictions for 10 classes.
        self.classifier = M.Sequential(
            M.Dropout(0.2),
            M.Linear(512, 10)
        )

    def forward(self, x):
        """Map a batch of images to per-class scores (last dimension is 10)."""
        # 1. Feature extraction: an Nx3x32x32 input yields an Nx512x4x4 tensor.
        x = self.conv1(x)
        x = self.blocks(x)
        # 2. 4x4 average pooling, then flatten everything after the batch axis.
        x = F.avg_pool2d(x, 4)
        x = F.flatten(x, 1)
        # 3. Classification head.
        x = self.classifier(x)
        return x

In [8]:
def training(root="/data", epochs=90, batch_size=64, base_lr=0.01,
             momentum=0.9, weight_decay=1e-4, log_interval=50,
             checkpoint_path="checkpoint.pkl"):
    """Train ResNet18 on the CIFAR10 training split, checkpointing every epoch.

    All arguments are optional; the defaults reproduce the original
    hard-coded configuration.

    Args:
        root: directory passed to the CIFAR10 dataset.
        epochs: number of passes over the training set.
        batch_size: mini-batch size (incomplete final batches are dropped).
        base_lr: learning rate at epoch 0; it decays linearly per epoch as
            base_lr * (1 - epoch / epochs).
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        log_interval: print loss/accuracy every this many steps.
        checkpoint_path: file the model state dict is saved to after each
            epoch (overwritten every time).
    """
    # CIFAR10 dataset wrapper shipped with megengine.
    dataset = data.dataset.CIFAR10(root=root, train=True)

    # Data pipeline: random batches with flip augmentation and scaling.
    dataloader = data.DataLoader(
        dataset,
        sampler=data.RandomSampler(dataset, batch_size=batch_size, drop_last=True),
        transform=T.Compose([
            T.RandomHorizontalFlip(),
            T.Normalize(mean=0., std=255.),  # f(x) = (x - mean) / std
            T.ToMode("CHW"),
        ])
    )

    # Network and the input placeholders that are refilled for every batch.
    model = ResNet18()
    image = megengine.tensor(dtype="float32")
    label = megengine.tensor(dtype="int32")

    # Optimizer. The constructor lr is overwritten by the per-epoch schedule
    # below before the first step, so it is set to base_lr here rather than
    # to a second, contradictory constant (the original passed 0.005).
    opt = optimizer.SGD(model.parameters(), lr=base_lr,
                        momentum=momentum, weight_decay=weight_decay)

    # Trace the training step so it runs as a static computation graph.
    @jit.trace
    def train_func(image, label):
        # Forward pass.
        loglikelihood = model(image)
        loss = F.cross_entropy_with_softmax(loglikelihood, label)
        accuracy = F.accuracy(loglikelihood, label)

        # Backward pass and parameter update.
        opt.zero_grad()
        opt.backward(loss)
        opt.step()
        return loss, accuracy

    for epoch in range(epochs):
        # One epoch == one full pass over the training dataset.
        # Linear decay from base_lr at epoch 0 towards 0.
        epoch_lr = base_lr * (1 - epoch / epochs)
        for param_group in opt.param_groups:
            param_group["lr"] = epoch_lr
        for i, batch_data in enumerate(dataloader):
            # Run a single iteration on this batch.
            image.set_value(batch_data[0])
            label.set_value(batch_data[1])

            loss, acc1 = train_func(image, label)

            if i % log_interval == 0:
                print("epoch", epoch, "step", i, "loss", loss, "acc@1", acc1)

        megengine.save(model.state_dict(), checkpoint_path)

In [9]:
# Start training. Progress is printed every 50 steps; the captured run below
# was stopped manually (KeyboardInterrupt) shortly after epoch 1 began.
training()

epoch 0 step 0 loss Tensor([2.5812]) acc@1 Tensor([0.0625])
epoch 0 step 50 loss Tensor([1.9665]) acc@1 Tensor([0.25])
epoch 0 step 100 loss Tensor([1.9619]) acc@1 Tensor([0.25])
epoch 0 step 150 loss Tensor([1.7975]) acc@1 Tensor([0.3281])
epoch 0 step 200 loss Tensor([1.7861]) acc@1 Tensor([0.4375])
epoch 0 step 250 loss Tensor([1.6038]) acc@1 Tensor([0.4062])
epoch 0 step 300 loss Tensor([1.8104]) acc@1 Tensor([0.375])
epoch 0 step 350 loss Tensor([1.8505]) acc@1 Tensor([0.4062])
epoch 0 step 400 loss Tensor([1.3248]) acc@1 Tensor([0.5469])
epoch 0 step 450 loss Tensor([1.4438]) acc@1 Tensor([0.5156])
epoch 0 step 500 loss Tensor([1.2161]) acc@1 Tensor([0.5938])
epoch 0 step 550 loss Tensor([1.1768]) acc@1 Tensor([0.5781])
epoch 0 step 600 loss Tensor([1.1995]) acc@1 Tensor([0.5938])
epoch 0 step 650 loss Tensor([1.08]) acc@1 Tensor([0.6406])
epoch 0 step 700 loss Tensor([0.9324]) acc@1 Tensor([0.6562])
epoch 0 step 750 loss Tensor([1.2374]) acc@1 Tensor([0.5781])
epoch 1 step 0 los

11 20:40:34[mgb] ERR caught exception from python callback: python exception:
  Traceback (most recent call last):
    File "/home/zhouyizhuang/.local/lib/python3.6/site-packages/megengine/_internal/mgb.py", line 3171, in call
    self._func(value[0])
    File "/home/zhouyizhuang/.local/lib/python3.6/site-packages/megengine/jit/__init__.py", line 214, in callback
    dest.set_value(value, share=False)
    File "/home/zhouyizhuang/.local/lib/python3.6/site-packages/megengine/core/tensor.py", line 280, in set_value
    self.__val.set_value(value, sync=sync, inplace=inplace, share=share)
    File "/home/zhouyizhuang/.local/lib/python3.6/site-packages/megengine/_internal/mgb.py", line 2291, in set_value
    self._copy_from_value_proxy(w)
    File "/home/zhouyizhuang/.local/lib/python3.6/site-packages/megengine/_internal/mgb.py", line 2185, in _copy_from_value_proxy
    return _mgb.SharedND__copy_from_value_proxy(self, value)
  KeyboardInterrupt[0m
[33m11 20:40:34[mgb] 

KeyboardInterrupt: 