In [1]:
from Runner import Trainer
from mobilevit import DynamicMobileViT
from mobilevit import MobileViT
import torch
import torchvision
import torchvision.transforms as transforms

In [2]:
# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling are
# repeatable across "Restart Kernel -> Run All" (GPU kernels may still be
# nondeterministic). Done before the assignments so the cell shows no output.
seed = 42
torch.manual_seed(seed)

# Shared hyper-parameters used by both the data pipeline and the model.
batch_size = 128  # mini-batch size for the train and test DataLoaders
pic_size = (32, 32)  # (height, width) every image is resized to before training

In [3]:
import torchvision.transforms as transforms
import torchvision

# Local CIFAR-100 location. download=False below, so the files must already exist here.
data_root = "../Data/cifar-100-python"

# Preprocessing:
# - resize to pic_size (CIFAR-100 is natively 32x32, so with the default pic_size
#   this is effectively a no-op; it only matters if pic_size is changed);
# - convert to a tensor;
# - normalise each RGB channel with (x - 0.5) / 0.5, mapping [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(pic_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Training split: reshuffled every epoch.
trainset = torchvision.datasets.CIFAR100(
    root=data_root, train=True, download=False, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2
)

# Test split: kept in a fixed order (shuffle=False) so evaluation passes are
# deterministic. An accuracy metric should not depend on sample order
# (assumes Trainer evaluates without order-dependent state - TODO confirm).
testset = torchvision.datasets.CIFAR100(
    root=data_root, train=False, download=False, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2
)

In [4]:
# MobileViT hyper-parameters, handed to Trainer as model_p (presumably forwarded
# to the MobileViT constructor - semantics are defined in mobilevit.MobileViT).
model_parameters = {
    # Must match the size the data pipeline produces; reuse pic_size so the
    # transform and the model cannot drift apart.
    "image_size": pic_size,
    # NOTE(review): presumably per-stage channel widths - confirm in mobilevit.py.
    "channels": [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
    "num_classes": 100,  # CIFAR-100 has 100 classes
    "expansion": 4,
    "kernel_size": 3,
    "patch_size": (1, 1),
    # NOTE(review): presumably transformer widths / depths per MobileViT block - confirm.
    "dims": [96, 120, 144],
    "L": [2, 4, 3],
}

num_epochs = 50
learning_rate = 0.02
# NOTE(review): presumably (step_size, gamma) of a step LR schedule, i.e. halve
# the learning rate every 20 epochs - confirm against Runner.Trainer.
lr_schedule = (20, 0.5)

trainer = Trainer(
    MobileViT,
    (trainloader, testloader),
    batch_size=batch_size,
    model_p=model_parameters,
    learning_rate=learning_rate,
    scheduler=lr_schedule,
)
# Cross-entropy loss for the 100-way classification task.
trainer.criterion = torch.nn.CrossEntropyLoss()

# Train, evaluating on the test set after every epoch. The displayed tuple
# matches the last entries of trainer.training_loss, trainer.training_accuracy
# and trainer.testing_accuracy.
trainer.run(num_epochs=num_epochs, evaluation_interval=1)

Epoch [1/50], Loss: 4.3956, Training Accuracy: 3.05%, Testing Accuracy: 3.84%, Evaluation Time: 0.49 minutes.
Epoch [2/50], Loss: 4.1163, Training Accuracy: 5.25%, Testing Accuracy: 6.25%, Evaluation Time: 0.95 minutes.
Epoch [3/50], Loss: 3.9821, Training Accuracy: 7.28%, Testing Accuracy: 8.83%, Evaluation Time: 1.36 minutes.
Epoch [4/50], Loss: 3.8208, Training Accuracy: 9.58%, Testing Accuracy: 9.91%, Evaluation Time: 1.83 minutes.
Epoch [5/50], Loss: 3.7839, Training Accuracy: 10.32%, Testing Accuracy: 11.17%, Evaluation Time: 2.31 minutes.
Epoch [6/50], Loss: 3.6639, Training Accuracy: 12.03%, Testing Accuracy: 14.03%, Evaluation Time: 2.78 minutes.
Epoch [7/50], Loss: 3.5413, Training Accuracy: 14.15%, Testing Accuracy: 14.65%, Evaluation Time: 3.22 minutes.
Epoch [8/50], Loss: 3.6840, Training Accuracy: 12.30%, Testing Accuracy: 12.11%, Evaluation Time: 3.69 minutes.
Epoch [9/50], Loss: 3.5088, Training Accuracy: 14.72%, Testing Accuracy: 16.33%, Evaluation Time: 4.17 minutes.


(2.1755184375, 0.4123, 0.3517)

In [5]:
# Training loss recorded by Trainer.run, one entry per epoch.
loss_history = trainer.training_loss
print(loss_history)

[4.3956206175231936, 4.116308553771972, 3.9821400183868407, 3.820771668167114, 3.7838640800476075, 3.6639460034179687, 3.5413423414611818, 3.6839945336151123, 3.508766156692505, 3.3942459772491453, 3.3635865324401855, 3.408719341583252, 3.219469340362549, 3.2865312580871584, 3.126504287109375, 3.040766638641357, 2.9816736534881594, 3.0661570531463624, 3.3541440618896483, 3.2579418755340575, 3.005250848007202, 2.961743271255493, 2.893730605392456, 2.791140783996582, 2.8953595093536375, 2.8291816774749754, 2.789451154632568, 2.69336881690979, 2.6475237115478514, 2.732503274612427, 2.706852803039551, 2.6739600845336913, 2.626226271286011, 2.6553423652648926, 2.5540249240112303, 2.477084315338135, 2.5997682550811767, 2.5564525764465333, 2.5445423939514162, 2.4513995276641847, 2.3571987911224364, 2.299244632797241, 2.2842054303741457, 2.315511266555786, 2.274372668991089, 2.238313023223877, 2.1991077896118165, 2.187763157196045, 2.1978079383850098, 2.1755184375]


In [6]:
# Training accuracy per epoch, stored as a fraction in [0, 1]
# (the epoch log printed during training shows the same values as percentages).
train_acc_history = trainer.training_accuracy
print(train_acc_history)

[0.03054, 0.05252, 0.07278, 0.09576, 0.10322, 0.12026, 0.14152, 0.12298, 0.1472, 0.17216, 0.17812, 0.16906, 0.20166, 0.19344, 0.21792, 0.2339, 0.24664, 0.23456, 0.18794, 0.19872, 0.24428, 0.25408, 0.26632, 0.28512, 0.26544, 0.28048, 0.28458, 0.30362, 0.3129, 0.2994, 0.30166, 0.31042, 0.3181, 0.31408, 0.3345, 0.34736, 0.32568, 0.3325, 0.33358, 0.3517, 0.37452, 0.38604, 0.38842, 0.3834, 0.39116, 0.40138, 0.40956, 0.4097, 0.4081, 0.4123]


In [7]:
# Test-set accuracy for each evaluation pass, stored as a fraction in [0, 1].
test_acc_history = trainer.testing_accuracy
print(test_acc_history)

[0.0384, 0.0625, 0.0883, 0.0991, 0.1117, 0.1403, 0.1465, 0.1211, 0.1633, 0.1829, 0.1728, 0.1908, 0.2036, 0.196, 0.2272, 0.2331, 0.25, 0.2201, 0.1286, 0.2023, 0.245, 0.2592, 0.2592, 0.2728, 0.2816, 0.2569, 0.2862, 0.2891, 0.2977, 0.2688, 0.2905, 0.2853, 0.2962, 0.2954, 0.3074, 0.3077, 0.2736, 0.2874, 0.3178, 0.3156, 0.3297, 0.3345, 0.3403, 0.3294, 0.3377, 0.3391, 0.3418, 0.3449, 0.3317, 0.3517]
