In [1]:
# 这里我们做一个更强的训练
!wget -c "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
!tar -zxvf "cifar-10-binary.tar.gz"
!mv "cifar-10-batches-bin" "./datasets/cifar-10-batches-bin"

--2022-11-16 17:54:04--  https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
Resolving www.cs.toronto.edu (www.cs.toronto.edu)... 128.100.3.30
Connecting to www.cs.toronto.edu (www.cs.toronto.edu)|128.100.3.30|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 170052171 (162M) [application/x-gzip]
Saving to: ‘cifar-10-binary.tar.gz’


2022-11-16 17:54:36 (5.25 MB/s) - ‘cifar-10-binary.tar.gz’ saved [170052171/170052171]

cifar-10-batches-bin/
cifar-10-batches-bin/data_batch_1.bin
cifar-10-batches-bin/batches.meta.txt
cifar-10-batches-bin/data_batch_3.bin
cifar-10-batches-bin/data_batch_4.bin
cifar-10-batches-bin/test_batch.bin
cifar-10-batches-bin/readme.html
cifar-10-batches-bin/data_batch_5.bin
cifar-10-batches-bin/data_batch_2.bin


In [1]:
import mindspore as ms
import mindspore.dataset as ds
from mindspore.dataset import vision
from mindspore.dataset.transforms.transforms import TypeCast
import mindspore.dataset.engine as de # 数据增强引擎。

def create_cifar_dataset(dataset_path, do_train, batch_size=32, image_size=(224, 224), rank_size=1, rank_id=0):
    """Build a batched CIFAR-10 input pipeline for training or evaluation.

    Args:
        dataset_path: Directory handed to ``ds.Cifar10Dataset``.
        do_train: When truthy, samples are shuffled, AutoAugment (CIFAR10
            policy) is applied, and an incomplete final batch is dropped.
        batch_size: Number of samples per batch.
        image_size: Target size passed to ``vision.Resize``.
        rank_size: Forwarded to ``ds.Cifar10Dataset`` as ``num_shards``.
        rank_id: Forwarded to ``ds.Cifar10Dataset`` as ``shard_id``.

    Returns:
        The dataset after the label cast, the image transforms and batching.
    """
    source = ds.Cifar10Dataset(
        dataset_path,
        shuffle=do_train,
        num_shards=rank_size,
        shard_id=rank_id,
    )

    # Augment only while training. The original author notes this policy was
    # found experimentally; RandomCrop((32, 32), (4, 4, 4, 4)) followed by
    # RandomHorizontalFlip(prob=0.5) was the earlier, now disabled, alternative.
    augment_ops = [vision.AutoAugment(vision.AutoAugmentPolicy.CIFAR10)] if do_train else []

    # Preprocessing applied in both modes: resize, scale by 1/255,
    # per-channel normalization, then HWC -> CHW layout.
    preprocess_ops = [
        vision.Resize(image_size),
        vision.Rescale(1.0 / 255.0, 0.0),
        vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        vision.HWC2CHW(),
    ]
    image_ops = augment_ops + preprocess_ops

    # Labels are cast to int32 before the image transforms are attached.
    pipeline = source.map(operations=TypeCast(ms.int32), input_columns="label")
    pipeline = pipeline.map(operations=image_ops, input_columns="image")

    # Group samples into batches; the remainder is dropped only for training.
    return pipeline.batch(batch_size, drop_remainder=do_train)
# Folder holding the extracted CIFAR-10 binary batches.
path = 'datasets/cifar-10-batches-bin'
# One pipeline per mode, both built with the default batch size and image size.
train_loader = create_cifar_dataset(path, do_train=True)
test_loader = create_cifar_dataset(path, do_train=False)

In [None]:
# ms.set_context(mode=ms.context.GRAPH_MODE, device_target="GPU")

In [7]:
# import sys
# sys.argv = []

In [6]:
# import mindspore_hub as mshub
# # model = "mindspore/1.9/res2net50_cifar10" # 这个模型也有，不过就没有意思了
# model = "mindspore/1.9/resnet50_imagenet2012" # 我们体验一下迁移学习
# # model = "mindspore/1.6/googlenet_cifar10"
# network = mshub.load(model, include_top=False, activation="Sigmoid", num_classes=10)
# network.set_train(False) # 这个API比pytorch直观很多

In [8]:
# Fetch the resnet.py model script into the working directory.
# `-N` (timestamping) re-downloads only when the remote file is newer than the local copy.
!wget -N https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/source-codes/resnet.py

--2022-11-16 18:13:53--  https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/source-codes/resnet.py
Resolving obs.dualstack.cn-north-4.myhuaweicloud.com (obs.dualstack.cn-north-4.myhuaweicloud.com)... 121.36.121.131, 121.36.121.130, 2407:c080:170f:fffb:1:1:2:5, ...
Connecting to obs.dualstack.cn-north-4.myhuaweicloud.com (obs.dualstack.cn-north-4.myhuaweicloud.com)|121.36.121.131|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9521 (9.3K) [binary/octet-stream]
Saving to: ‘resnet.py’


2022-11-16 18:13:53 (960 KB/s) - ‘resnet.py’ saved [9521/9521]



In [10]:
# Import the ResNet-50 constructor from the local resnet.py script.
from resnet import resnet50
# Ten output classes for CIFAR-10; batch_size=32 equals the default batch size of
# create_cifar_dataset. NOTE(review): how resnet50 uses batch_size is defined in
# resnet.py, which is not visible here; confirm the two must stay in sync.
network = resnet50(batch_size=32, num_classes=10)

In [11]:
import mindspore.nn as nn
from mindspore.nn import SoftmaxCrossEntropyWithLogits
# The name of this loss function is a bit clearer than its torch counterpart.
# sparse=True: labels are class indices; the loss is averaged ("mean").
ls = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
# Momentum optimizer over only the parameters that require gradients.
opt = nn.Momentum(
    [param for param in network.get_parameters() if param.requires_grad],
    learning_rate=0.01,
    momentum=0.9,
)

In [18]:
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore import load_checkpoint, load_param_into_net
import os
from mindspore import Model
# Bundle network, loss and optimizer; 'acc' is the metric requested for evaluation.
model = Model(network, loss_fn=ls, optimizer=opt, metrics={'acc'})

# Number of batches produced by the training pipeline.
steps_per_epoch = train_loader.get_dataset_size()

# Checkpoint every `steps_per_epoch` steps (i.e. once per pass over the
# training pipeline), with keep_checkpoint_max set to 16.
config_ck = ms.CheckpointConfig(
    save_checkpoint_steps=steps_per_epoch,
    keep_checkpoint_max=16,
)
ckpt_cb = ms.ModelCheckpoint(
    prefix='CIFAR-10-resnet50',
    directory='./checkpoint/ms',
    config=config_ck,
)

# Print the loss once per step.
locc_monitor = ms.LossMonitor(per_print_times=1)
# Timing callback, given the number of steps in one epoch.
time_monitor = TimeMonitor(data_size=steps_per_epoch)


In [16]:
from mindspore.train.callback import SummaryCollector
# Summary callback writing into ./summary_dir.
# NOTE(review): collect_freq=1 presumably collects on every step; confirm against the SummaryCollector docs.
summary_collector = SummaryCollector(summary_dir='./summary_dir', collect_freq=1)


In [19]:
# Train for a single epoch with the summary, checkpoint, loss and timing callbacks attached.
# dataset_sink_mode is explicitly disabled for this run.
model.train(epoch=1, 
            train_dataset=train_loader, 
            callbacks=[summary_collector, ckpt_cb, locc_monitor, time_monitor], 
            dataset_sink_mode=False)



epoch: 1 step: 1, loss is 2.291393280029297
epoch: 1 step: 2, loss is 2.2892565727233887
epoch: 1 step: 3, loss is 2.274559736251831
epoch: 1 step: 4, loss is 2.3019986152648926
epoch: 1 step: 5, loss is 2.327470302581787
epoch: 1 step: 6, loss is 2.2708733081817627
epoch: 1 step: 7, loss is 2.3094005584716797
epoch: 1 step: 8, loss is 2.2858121395111084
epoch: 1 step: 9, loss is 2.287644624710083
epoch: 1 step: 10, loss is 2.2976865768432617
epoch: 1 step: 11, loss is 2.2875921726226807
epoch: 1 step: 12, loss is 2.2821764945983887
epoch: 1 step: 13, loss is 2.2654058933258057
epoch: 1 step: 14, loss is 2.325120210647583
epoch: 1 step: 15, loss is 2.300588369369507
epoch: 1 step: 16, loss is 2.2634737491607666
epoch: 1 step: 17, loss is 2.3121283054351807
epoch: 1 step: 18, loss is 2.3303170204162598
epoch: 1 step: 19, loss is 2.341153144836426
epoch: 1 step: 20, loss is 2.308833599090576
epoch: 1 step: 21, loss is 2.3046936988830566
epoch: 1 step: 22, loss is 2.303579568862915
epoch:

: 

: 