In [2]:
import megengine
import megengine.data as data
import megengine.data.transform as T
import megengine.module as M
import megengine.functional as F
import megengine.optimizer as optimizer
import megengine.jit as jit
import megengine.distributed as dist
import multiprocessing as mp

# Bare attribute lookups: nothing is assigned or called, so these two lines
# only confirm that both names resolve in the installed MegEngine.
# NOTE(review): neither name is used by the cells visible below — likely
# leftover exploration; confirm before removing.
F.sync_batch_norm
M.SyncBatchNorm

In [3]:
class BasicBlock(M.Module):
    """Residual unit of ResNet18: two 3x3 convolutions plus a shortcut branch."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main branch, stage 1: convolution followed by BN and ReLU.
        # This is the only 3x3 convolution that applies the block's stride.
        self.conv1 = M.ConvBnRelu2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
        )

        # Main branch, stage 2: convolution followed by BN only; the final
        # ReLU is applied in forward() after the residual addition.
        self.conv2 = M.ConvBn2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        # Shortcut branch: when the channel count changes or the block
        # downsamples, a strided 1x1 ConvBn maps the input to the output's
        # form; otherwise the input is passed through unchanged.
        needs_projection = in_channels != out_channels or stride != 1
        if needs_projection:
            self.res_conn = M.ConvBn2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=stride,
            )
        else:
            self.res_conn = M.Identity()

    def forward(self, x):
        """Return ReLU(conv2(conv1(x)) + res_conn(x))."""
        main_path = self.conv2(self.conv1(x))
        return F.relu(main_path + self.res_conn(x))

In [4]:
class ResNet18(M.Module):
    """ResNet18-style classifier producing 10 class scores per image.

    The network is a stem ConvBnRelu, eight ``BasicBlock`` units (three of
    them downsampling with ``stride=2``), a 4x4 average pooling and a
    Dropout + Linear classifier head.
    """

    def __init__(self):
        # Initialise the base Module before assigning any sub-module so the
        # base-class state is set up (matches what BasicBlock already does).
        super(ResNet18, self).__init__()

        # Stem: 3 input channels -> 64 feature channels, no downsampling.
        self.conv1 = M.ConvBnRelu2d(in_channels=3, out_channels=64,
                                    kernel_size=3, padding=1)
        # 8 BasicBlocks, 3 downsampling steps (stride=2),
        # 8 x 2 = 16 convolution layers in total.
        self.blocks = M.Sequential(
            BasicBlock(64,  64),
            BasicBlock(64,  64),
            BasicBlock(64,  128, stride=2),
            BasicBlock(128, 128),
            BasicBlock(128, 256, stride=2),
            BasicBlock(256, 256),
            BasicBlock(256, 512, stride=2),
            BasicBlock(512, 512),
        )
        # Fully connected classifier; the output dimension is the 10-class
        # prediction.
        self.classifier = M.Sequential(
            M.Dropout(0.2),
            M.Linear(512, 10)
        )

    def forward(self, x):
        """Map a batch of images to per-class scores.

        With Nx3x32x32 input the feature extractor yields an Nx512x4x4
        tensor (three stride-2 stages: 32 -> 16 -> 8 -> 4), which the 4x4
        average pooling collapses before the classifier.
        """
        # 1. Feature extraction.
        x = self.conv1(x)
        x = self.blocks(x)
        # 2. 4x4 average pooling, then flatten everything after the batch axis.
        x = F.avg_pool2d(x, 4)
        x = F.flatten(x, 1)
        # 3. Classification.
        x = self.classifier(x)
        return x

In [12]:
def training(root="/data", batch_size=64, epochs=90, lr=0.005, log_interval=50):
    """Train ResNet18 on the CIFAR10 training split and print progress.

    All parameters default to the values that were previously hard-coded,
    so calling ``training()`` with no arguments behaves as before.

    Args:
        root: directory passed to ``data.dataset.CIFAR10`` as its ``root``.
        batch_size: number of samples per batch drawn by the sampler.
        epochs: number of full passes over the training set.
        lr: learning rate of the SGD optimizer.
        log_interval: print loss/accuracy every ``log_interval`` steps.
    """
    # CIFAR10 dataset bundled with megengine.
    dataset = data.dataset.CIFAR10(root=root, train=True)

    # Input pipeline: random batches (incomplete last batch dropped),
    # random horizontal flip, scaling, and conversion to CHW layout.
    dataloader = data.DataLoader(
        dataset,
        sampler=data.RandomSampler(dataset, batch_size=batch_size, drop_last=True),
        transform=T.Compose([
            T.RandomHorizontalFlip(),
            T.Normalize(mean=0., std=255.),  # f(x) = (x - mean) / std
            T.ToMode("CHW"),
        ])
    )

    # Network and the input tensors that are refilled on every iteration.
    model = ResNet18()
    image = megengine.tensor(dtype="float32")
    label = megengine.tensor(dtype="int32")

    # Optimizer over all network parameters.
    opt = optimizer.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)

    # Build a static computation graph to get the best performance.
    @jit.trace
    def train_func(image, label):
        # Forward pass: the model output is fed to a loss that applies
        # softmax itself, so it is treated as raw (pre-softmax) scores.
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label)
        accuracy = F.accuracy(logits, label)

        # Backward pass and parameter update.
        opt.zero_grad()
        opt.backward(loss)
        opt.step()
        return loss, accuracy

    for epoch in range(epochs):
        # One epoch == one pass over the training dataset.
        for i, batch_data in enumerate(dataloader):
            # One iteration: load the batch into the input tensors and step.
            image.set_value(batch_data[0])
            label.set_value(batch_data[1])

            loss, acc1 = train_func(image, label)

            if i % log_interval == 0:
                print("epoch", epoch, "step", i, "loss", loss, "acc@1", acc1)

In [15]:
def worker(rank, world_size=8, master_port=2233):
    """Entry point of one training process.

    Args:
        rank: index of this process; also used as the device index, so
            process ``rank`` trains on device ``rank``.
        world_size: total number of participating processes (equal to the
            number of GPUs used).
        master_port: port used for communication, 0-65535; it must not be
            occupied by another process.
    """
    # Every child process has to initialise the distributed process group.
    dist.init_process_group(
        master_ip="localhost",    # master node address; localhost suffices on a single machine
        master_port=master_port,
        world_size=world_size,
        rank=rank,
        dev=rank,
    )
    print("init process", rank)
    # Start training.
    training()


def main(world_size=8, master_port=2233):
    """Launch ``world_size`` worker processes on this machine and wait for them.

    The same ``world_size`` is used both for the number of processes started
    and for the value each worker passes to the process group, so the two
    can no longer drift apart.

    Args:
        world_size: number of worker processes to start (one per GPU).
        master_port: communication port forwarded to every worker.

    Raises:
        RuntimeError: if any worker process exits with a non-zero exit code.
    """
    # See the Python documentation for details on multiprocessing.
    procs = []
    for rank in range(world_size):
        p = mp.Process(target=worker, args=(rank, world_size, master_port))
        p.start()
        procs.append(p)

    # Wait for every worker so the caller only returns once training is over
    # and child processes are reaped instead of being left unattended.
    for p in procs:
        p.join()

    # Surface worker failures instead of letting them pass silently.
    failed = [rank for rank, p in enumerate(procs) if p.exitcode != 0]
    if failed:
        raise RuntimeError(
            "worker process(es) exited abnormally, rank(s): {}".format(failed)
        )

In [16]:
# Launch the worker processes; the text that follows this cell is the
# interleaved output printed by those workers.
main()

init process 0
init process 1
init process 2
init process 3
init process 4
init process 5
init process 6
init process 7
epoch 0 step 0 loss Tensor([2.5735], device=gpu6:0) acc@1 Tensor([0.0781], device=gpu6:0)
epoch 0 step 0 loss Tensor([2.4922], device=gpu1:0) acc@1 Tensor([0.0625], device=gpu1:0)
epoch 0 step 0 loss Tensor([2.5264], device=gpu3:0) acc@1 Tensor([0.1406], device=gpu3:0)
epoch 0 step 0 loss Tensor([2.7726], device=gpu5:0) acc@1 Tensor([0.], device=gpu5:0)
epoch 0 step 0 loss Tensor([2.6011], device=gpu2:0) acc@1 Tensor([0.0938], device=gpu2:0)
epoch 0 step 0 loss Tensor([2.5178], device=gpu7:0) acc@1 Tensor([0.0312], device=gpu7:0)
epoch 0 step 0 loss Tensor([2.554], device=gpu4:0) acc@1 Tensor([0.1094], device=gpu4:0)
epoch 0 step 0 loss Tensor([2.5615], device=gpu0:0) acc@1 Tensor([0.1406], device=gpu0:0)
epoch 0 step 50 loss Tensor([1.7417], device=gpu6:0) acc@1 Tensor([0.4219], device=gpu6:0)
epoch 0 step 50 loss Tensor([1.521], device=gpu5:0) acc@1 Tensor([0.5312],