In [1]:
import mindspore.dataset as ds
import numpy as np
from mindspore import nn, Model
from mindspore.train.callback import LossMonitor
from mindspore.dataset import vision


class PetImageClassifier:
    def __init__(self, image_size=227, batch_size=8, learning_rate=0.001, momentum=0.9, epochs=10):
        self.image_size = image_size
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.epochs = epochs

        # Define dataset transformations
        self.mean = [0.5 * 255] * 3
        self.std = [0.5 * 255] * 3
        self.transforms = [
            vision.Resize((self.image_size, self.image_size)),
            vision.Normalize(mean=self.mean, std=self.std),
            vision.HWC2CHW()
        ]

        # Load train dataset
        self.train_dataset = ds.ImageFolderDataset(
            dataset_dir='./dataset/PetImages/train',
            decode=True).map(
            operations=self.transforms, num_parallel_workers=1).batch(self.batch_size)
        self.train_dataset, self.val_dataset = self.train_dataset.split([0.9, 0.1])

        # Load test dataset
        self.test_dataset = ds.ImageFolderDataset(
            dataset_dir='./dataset/PetImages/eval',
            decode=True).map(
            operations=self.transforms, num_parallel_workers=1).batch(self.batch_size)

        self.net = AlexNet()
        self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        self.opt = nn.Momentum(self.net.trainable_params(), learning_rate=self.learning_rate, momentum=self.momentum)
        self.model = Model(self.net, loss_fn=self.loss, optimizer=self.opt, metrics={'accuracy'})

    def train(self):
        self.model.train(
            epoch=self.epochs,
            train_dataset=self.train_dataset,
            callbacks=[LossMonitor()],
            dataset_sink_mode=True)

    def test(self):
        accuracy = self.model.eval(self.val_dataset, dataset_sink_mode=False)
        print(f'Test accuracy: {accuracy}')

    def predict(self):
        predictions = []
        for data in self.test_dataset.create_dict_iterator():
            inputs = data['image']
            output = self.model.predict(inputs)
            pred = np.argmax(output.asnumpy(), axis=1)
            predictions.append(pred)
        print(predictions)


class AlexNet(nn.Cell):
    def __init__(self):
        super(AlexNet, self).__init__()
        self.features = nn.SequentialCell(
            nn.Conv2d(3, 96, 11, stride=4, pad_mode="valid"),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(96, 256, 5, pad_mode="same"),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, 3, pad_mode="same"),
            nn.ReLU(),
            nn.Conv2d(384, 384, 3, pad_mode="same"),
            nn.ReLU(),
            nn.Conv2d(384, 256, 3, pad_mode="same"),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Flatten()
        )
        self.classifier = nn.SequentialCell(
            nn.Dense(6 * 6 * 256, 4096),
            nn.ReLU(),
            nn.Dense(4096, 4096),
            nn.ReLU(),
            nn.Dense(4096, 100)
        )

    def construct(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


if __name__ == '__main__':
    # Create an instance of PetImageClassifier and train the model
    pet_image_classifier = PetImageClassifier()
    pet_image_classifier.train()

    # Test the trained model
    pet_image_classifier.test()

    # Make predictions on test dataset
    pet_image_classifier.predict()




epoch: 1 step: 1, loss is 4.604930877685547
epoch: 1 step: 2, loss is 4.604280471801758
epoch: 1 step: 3, loss is 4.603142738342285
epoch: 1 step: 4, loss is 4.60131311416626
epoch: 1 step: 5, loss is 4.599898815155029
epoch: 1 step: 6, loss is 4.5969085693359375
epoch: 1 step: 7, loss is 4.594311714172363
epoch: 1 step: 8, loss is 4.591119766235352
epoch: 1 step: 9, loss is 4.588165760040283
epoch: 1 step: 10, loss is 4.584085464477539
epoch: 1 step: 11, loss is 4.580217361450195
epoch: 1 step: 12, loss is 4.576235771179199
epoch: 1 step: 13, loss is 4.572023391723633
epoch: 1 step: 14, loss is 4.567639350891113
epoch: 1 step: 15, loss is 4.563488483428955
epoch: 1 step: 16, loss is 4.558172225952148
epoch: 1 step: 17, loss is 4.554418563842773
epoch: 1 step: 18, loss is 4.549302101135254
epoch: 1 step: 19, loss is 4.543544769287109
epoch: 1 step: 20, loss is 4.540306091308594
epoch: 1 step: 21, loss is 4.534407615661621
epoch: 1 step: 22, loss is 4.528200626373291
epoch: 1 step: 23, 



epoch: 2 step: 1, loss is 0.4324832856655121
epoch: 2 step: 2, loss is 0.807800829410553
epoch: 2 step: 3, loss is 0.5668818950653076
epoch: 2 step: 4, loss is 0.7725943922996521
epoch: 2 step: 5, loss is 0.42559313774108887
epoch: 2 step: 6, loss is 0.5771293044090271
epoch: 2 step: 7, loss is 0.6826063990592957
epoch: 2 step: 8, loss is 0.6983305215835571
epoch: 2 step: 9, loss is 0.6381983757019043
epoch: 2 step: 10, loss is 0.8669359087944031
epoch: 2 step: 11, loss is 0.783919632434845
epoch: 2 step: 12, loss is 0.6718325018882751
epoch: 2 step: 13, loss is 0.7330284118652344
epoch: 2 step: 14, loss is 0.686238169670105
epoch: 2 step: 15, loss is 0.9054510593414307
epoch: 2 step: 16, loss is 0.6128048896789551
epoch: 2 step: 17, loss is 0.8461189270019531
epoch: 2 step: 18, loss is 0.7585234642028809
epoch: 2 step: 19, loss is 0.9070937037467957
epoch: 2 step: 20, loss is 0.8438024520874023
epoch: 2 step: 21, loss is 0.5411402583122253
epoch: 2 step: 22, loss is 0.6051280498504639



epoch: 3 step: 1, loss is 0.564555823802948
epoch: 3 step: 2, loss is 0.5017753839492798
epoch: 3 step: 3, loss is 0.5629329681396484
epoch: 3 step: 4, loss is 0.5949744582176208
epoch: 3 step: 5, loss is 0.6332755088806152
epoch: 3 step: 6, loss is 0.43098896741867065
epoch: 3 step: 7, loss is 0.5895817279815674
epoch: 3 step: 8, loss is 0.6216392517089844
epoch: 3 step: 9, loss is 0.4797360897064209
epoch: 3 step: 10, loss is 0.6390798091888428
epoch: 3 step: 11, loss is 0.5568848848342896
epoch: 3 step: 12, loss is 0.5227073431015015
epoch: 3 step: 13, loss is 0.5993086695671082
epoch: 3 step: 14, loss is 0.7245907783508301
epoch: 3 step: 15, loss is 0.7104002237319946
epoch: 3 step: 16, loss is 0.44309481978416443
epoch: 3 step: 17, loss is 0.7737090587615967
epoch: 3 step: 18, loss is 0.6110953092575073
epoch: 3 step: 19, loss is 0.5854219794273376
epoch: 3 step: 20, loss is 0.45382559299468994
epoch: 3 step: 21, loss is 0.7689433097839355
epoch: 3 step: 22, loss is 0.389621973037



epoch: 4 step: 1, loss is 0.47063493728637695
epoch: 4 step: 2, loss is 0.586329996585846
epoch: 4 step: 3, loss is 0.5602282881736755
epoch: 4 step: 4, loss is 0.47704917192459106
epoch: 4 step: 5, loss is 0.634671151638031
epoch: 4 step: 6, loss is 0.4510873258113861
epoch: 4 step: 7, loss is 0.6007012128829956
epoch: 4 step: 8, loss is 0.24606826901435852
epoch: 4 step: 9, loss is 0.4211300313472748
epoch: 4 step: 10, loss is 0.6304262280464172
epoch: 4 step: 11, loss is 0.8232211470603943
epoch: 4 step: 12, loss is 0.45671093463897705
epoch: 4 step: 13, loss is 0.5684467554092407
epoch: 4 step: 14, loss is 0.4865541458129883
epoch: 4 step: 15, loss is 0.5513104796409607
epoch: 4 step: 16, loss is 0.5090092420578003
epoch: 4 step: 17, loss is 0.5836883187294006
epoch: 4 step: 18, loss is 0.5037916898727417
epoch: 4 step: 19, loss is 0.2364293932914734
epoch: 4 step: 20, loss is 0.45762065052986145
epoch: 4 step: 21, loss is 0.38240018486976624
epoch: 4 step: 22, loss is 0.7304754257



epoch: 4 step: 2500, loss is 0.2798921465873718
epoch: 4 step: 2501, loss is 0.4671650230884552
epoch: 4 step: 2502, loss is 0.8840817809104919
epoch: 4 step: 2503, loss is 0.3174862563610077
epoch: 4 step: 2504, loss is 0.42528170347213745
epoch: 4 step: 2505, loss is 0.5682119727134705
epoch: 4 step: 2506, loss is 0.3455196022987366
epoch: 4 step: 2507, loss is 0.16962619125843048
epoch: 4 step: 2508, loss is 0.32621893286705017
epoch: 5 step: 1, loss is 0.4449954628944397
epoch: 5 step: 2, loss is 0.33611810207366943
epoch: 5 step: 3, loss is 0.11877581477165222
epoch: 5 step: 4, loss is 0.18464501202106476
epoch: 5 step: 5, loss is 0.43471643328666687
epoch: 5 step: 6, loss is 0.2369561493396759
epoch: 5 step: 7, loss is 0.09789130836725235
epoch: 5 step: 8, loss is 0.8517168164253235
epoch: 5 step: 9, loss is 0.7678530812263489
epoch: 5 step: 10, loss is 0.24687789380550385
epoch: 5 step: 11, loss is 0.5042166113853455
epoch: 5 step: 12, loss is 0.4754415452480316
epoch: 5 step: 1



epoch: 6 step: 1, loss is 0.37908801436424255
epoch: 6 step: 2, loss is 0.6397701501846313
epoch: 6 step: 3, loss is 0.1932392120361328
epoch: 6 step: 4, loss is 0.5709540247917175
epoch: 6 step: 5, loss is 0.3407653272151947
epoch: 6 step: 6, loss is 0.32193055748939514
epoch: 6 step: 7, loss is 0.28332170844078064
epoch: 6 step: 8, loss is 0.29434865713119507
epoch: 6 step: 9, loss is 0.7570362091064453
epoch: 6 step: 10, loss is 0.3674922287464142
epoch: 6 step: 11, loss is 0.6332653164863586
epoch: 6 step: 12, loss is 0.36102285981178284
epoch: 6 step: 13, loss is 0.7055620551109314
epoch: 6 step: 14, loss is 0.4602992534637451
epoch: 6 step: 15, loss is 0.4771195352077484
epoch: 6 step: 16, loss is 0.2933010160923004
epoch: 6 step: 17, loss is 0.26521360874176025
epoch: 6 step: 18, loss is 0.7057885527610779
epoch: 6 step: 19, loss is 0.2394437938928604
epoch: 6 step: 20, loss is 0.39358535408973694
epoch: 6 step: 21, loss is 0.3931790590286255
epoch: 6 step: 22, loss is 0.4631404



epoch: 7 step: 1, loss is 0.23267200589179993
epoch: 7 step: 2, loss is 0.17417100071907043
epoch: 7 step: 3, loss is 0.5129176378250122
epoch: 7 step: 4, loss is 0.14134100079536438
epoch: 7 step: 5, loss is 0.4039674401283264
epoch: 7 step: 6, loss is 0.7248643636703491
epoch: 7 step: 7, loss is 0.24529534578323364
epoch: 7 step: 8, loss is 0.2589629590511322
epoch: 7 step: 9, loss is 0.6925811171531677
epoch: 7 step: 10, loss is 0.08429653942584991
epoch: 7 step: 11, loss is 0.37940770387649536
epoch: 7 step: 12, loss is 0.1323212832212448
epoch: 7 step: 13, loss is 0.2803479731082916
epoch: 7 step: 14, loss is 0.34223324060440063
epoch: 7 step: 15, loss is 0.2244202345609665
epoch: 7 step: 16, loss is 0.17576289176940918
epoch: 7 step: 17, loss is 0.5389060974121094
epoch: 7 step: 18, loss is 0.3325822055339813
epoch: 7 step: 19, loss is 0.16016550362110138
epoch: 7 step: 20, loss is 0.5231564044952393
epoch: 7 step: 21, loss is 0.37519779801368713
epoch: 7 step: 22, loss is 0.4968



epoch: 8 step: 1, loss is 0.20588819682598114
epoch: 8 step: 2, loss is 0.1065378189086914
epoch: 8 step: 3, loss is 0.18660204112529755
epoch: 8 step: 4, loss is 0.26820358633995056
epoch: 8 step: 5, loss is 0.21716108918190002
epoch: 8 step: 6, loss is 0.2191799134016037
epoch: 8 step: 7, loss is 0.31219449639320374
epoch: 8 step: 8, loss is 0.09910470992326736
epoch: 8 step: 9, loss is 0.05361131578683853
epoch: 8 step: 10, loss is 0.1383945196866989
epoch: 8 step: 11, loss is 0.05303046852350235
epoch: 8 step: 12, loss is 0.40637341141700745
epoch: 8 step: 13, loss is 0.4043847620487213
epoch: 8 step: 14, loss is 0.09819456934928894
epoch: 8 step: 15, loss is 0.30940601229667664
epoch: 8 step: 16, loss is 0.4853968620300293
epoch: 8 step: 17, loss is 0.3347378969192505
epoch: 8 step: 18, loss is 0.05933932960033417
epoch: 8 step: 19, loss is 0.09843938052654266
epoch: 8 step: 20, loss is 0.5284938812255859
epoch: 8 step: 21, loss is 0.20718331634998322
epoch: 8 step: 22, loss is 0.



epoch: 9 step: 1, loss is 0.17572495341300964
epoch: 9 step: 2, loss is 0.2603878378868103
epoch: 9 step: 3, loss is 0.1716996282339096
epoch: 9 step: 4, loss is 0.5208238363265991
epoch: 9 step: 5, loss is 0.26201507449150085
epoch: 9 step: 6, loss is 0.09356184303760529
epoch: 9 step: 7, loss is 0.07166824489831924
epoch: 9 step: 8, loss is 0.3219011723995209
epoch: 9 step: 9, loss is 0.10997441411018372
epoch: 9 step: 10, loss is 0.15733066201210022
epoch: 9 step: 11, loss is 0.08393912017345428
epoch: 9 step: 12, loss is 0.04652966186404228
epoch: 9 step: 13, loss is 0.13936394453048706
epoch: 9 step: 14, loss is 0.08156430721282959
epoch: 9 step: 15, loss is 0.14917954802513123
epoch: 9 step: 16, loss is 0.32112953066825867
epoch: 9 step: 17, loss is 0.04240585118532181
epoch: 9 step: 18, loss is 0.3009236752986908
epoch: 9 step: 19, loss is 0.200850710272789
epoch: 9 step: 20, loss is 0.04676799848675728
epoch: 9 step: 21, loss is 0.20418818295001984
epoch: 9 step: 22, loss is 0.



epoch: 10 step: 1, loss is 0.02297615446150303
epoch: 10 step: 2, loss is 0.004764561075717211
epoch: 10 step: 3, loss is 0.3689391016960144
epoch: 10 step: 4, loss is 0.2755938470363617
epoch: 10 step: 5, loss is 0.2133324295282364
epoch: 10 step: 6, loss is 0.4298229217529297
epoch: 10 step: 7, loss is 0.06851629167795181
epoch: 10 step: 8, loss is 0.4676697850227356
epoch: 10 step: 9, loss is 0.210661843419075
epoch: 10 step: 10, loss is 0.057994499802589417
epoch: 10 step: 11, loss is 0.527441680431366
epoch: 10 step: 12, loss is 0.14853563904762268
epoch: 10 step: 13, loss is 0.4333086907863617
epoch: 10 step: 14, loss is 0.025780828669667244
epoch: 10 step: 15, loss is 0.1065327376127243
epoch: 10 step: 16, loss is 0.1896851658821106
epoch: 10 step: 17, loss is 0.2075122892856598
epoch: 10 step: 18, loss is 0.06383974105119705
epoch: 10 step: 19, loss is 0.4537537395954132
epoch: 10 step: 20, loss is 0.5119455456733704
epoch: 10 step: 21, loss is 0.22385498881340027
epoch: 10 ste



Test accuracy: {'accuracy': 0.9507168458781362}
[array([0, 0, 0, 0, 0, 1, 0, 1]), array([1, 1, 1, 0, 0, 0, 1, 1]), array([1, 0, 1, 1, 0, 0, 0, 0]), array([1, 0, 0, 1, 1, 1, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0]), array([0, 1, 1, 1, 1, 0, 0, 1]), array([1, 1, 0, 0, 1, 0, 0, 1]), array([0, 0, 0, 0, 1, 0, 0, 0]), array([0, 1, 1, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 1, 1, 0, 0]), array([1, 1, 1, 1, 0, 0, 0, 1]), array([1, 1, 1, 1, 0, 0, 0, 0]), array([1, 1, 1, 0, 1, 0, 0, 1]), array([1, 1, 0, 0, 1, 0, 1, 1]), array([1, 0, 0, 0, 1, 1, 0, 0]), array([0, 0, 0, 0, 1, 1, 0, 1]), array([0, 1, 0, 0, 0, 0, 1, 0]), array([1, 1, 1, 0, 0, 1, 1, 1]), array([0, 1, 1, 1, 0, 0, 0, 1]), array([0, 0, 0, 0, 1, 0, 1, 0]), array([1, 0, 1, 1, 0, 1, 1, 0]), array([0, 1, 0, 0, 0, 1, 0, 0]), array([0, 1, 1, 0, 0, 0, 0, 1]), array([0, 1, 1, 0, 1, 1, 1, 1]), array([1, 0, 1, 1, 0, 0, 1, 0]), array([0, 0, 0, 1, 1, 1, 0, 1]), array([0, 1, 1, 0, 0, 1, 1, 1]), array([1, 0, 1, 0, 1, 0, 1, 0]), array([0, 1, 1, 0, 1, 0, 0,