In [1]:
from Runner import Trainer
from mobilevit import DynamicMobileViT
from mobilevit import MobileViT
import torch
import torchvision
import torchvision.transforms as transforms

In [2]:
# Shared configuration for the data pipeline and the model.
batch_size = 128  # mini-batch size for both the train and test DataLoaders
pic_size = (32, 32)  # (height, width) passed to transforms.Resize

# Seed torch's RNGs (CPU and CUDA) so that DataLoader shuffling and, presumably,
# model weight initialisation are reproducible across Restart & Run All.
# NOTE(review): cuDNN kernels may still be nondeterministic on GPU.
SEED = 42
torch.manual_seed(SEED)

In [3]:
import torchvision.transforms as transforms
import torchvision

# Preprocessing: resize to `pic_size`, convert to a tensor, then map each RGB
# channel from [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): CIFAR images are natively 32x32, so Resize is a no-op while
# pic_size == (32, 32); it only matters if pic_size is changed in the config cell.
# NOTE(review): no training-time augmentation (random crop / flip) is applied;
# the large train/test accuracy gap in the results may benefit from adding it.
transform = transforms.Compose(
    [
        transforms.Resize(pic_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Dataset location. download=False: the files must already exist on disk.
DATA_ROOT = "../Data/cifar-100-python"
NUM_WORKERS = 2

# Training split — reshuffled every epoch.
trainset = torchvision.datasets.CIFAR100(
    root=DATA_ROOT, train=True, download=False, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS
)

# Test split — used only for evaluation, so it is iterated in a fixed order.
# Shuffling does not change the measured accuracy (assuming the Trainer evaluates
# in eval mode), and a fixed order keeps evaluation runs deterministic.
testset = torchvision.datasets.CIFAR100(
    root=DATA_ROOT, train=False, download=False, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS
)

In [4]:
# Keyword arguments forwarded to the MobileViT constructor via Trainer(model_p=...).
# NOTE(review): these keys are consumed by mobilevit.MobileViT, which is not
# visible here; confirm their exact semantics against that module.
model_parameters = {
    # Tied to the preprocessing config so the model's expected input size and
    # the Resize in the data cell cannot silently drift apart.
    "image_size": pic_size,
    "channels": [64, 64, 96, 96, 128, 128, 192, 192, 320, 320, 640],
    "num_classes": 100,  # CIFAR-100
    "expansion": 2,
    "kernel_size": 3,
    "patch_size": (1, 1),
    "dims": [144, 192, 240],
    "L": [4, 8, 6],
}
trainer = Trainer(
    MobileViT,
    (trainloader, testloader),
    batch_size=batch_size,
    model_p=model_parameters,
    # NOTE(review): the log shows training collapsing around epoch 40
    # (accuracy 15% -> 7%); this learning rate may be too high — consider
    # lowering it or adding warmup.
    learning_rate=0.02,
    scheduler=(100, 0.8),  # presumably (step_size, gamma) of a step LR schedule — confirm in Runner
)
trainer.criterion = torch.nn.CrossEntropyLoss()
# Last expression: the value returned by run() is displayed as the cell output.
trainer.run(num_epochs=1000, evaluation_interval=5)

Epoch [5/1000], Loss: 3.8832, Training Accuracy: 9.82%, Testing Accuracy: 8.46%, Evaluation Time: 4.22 minutes.
Epoch [10/1000], Loss: 3.7546, Training Accuracy: 11.96%, Testing Accuracy: 11.40%, Evaluation Time: 6.86 minutes.
Epoch [15/1000], Loss: 3.9077, Training Accuracy: 10.00%, Testing Accuracy: 12.33%, Evaluation Time: 9.52 minutes.
Epoch [20/1000], Loss: 3.8780, Training Accuracy: 10.37%, Testing Accuracy: 13.50%, Evaluation Time: 12.17 minutes.
Epoch [25/1000], Loss: 3.5532, Training Accuracy: 15.14%, Testing Accuracy: 13.18%, Evaluation Time: 14.79 minutes.
Epoch [30/1000], Loss: 3.5726, Training Accuracy: 15.02%, Testing Accuracy: 15.82%, Evaluation Time: 17.43 minutes.
Epoch [35/1000], Loss: 3.5701, Training Accuracy: 15.27%, Testing Accuracy: 15.17%, Evaluation Time: 20.06 minutes.
Epoch [40/1000], Loss: 4.1713, Training Accuracy: 6.72%, Testing Accuracy: 6.67%, Evaluation Time: 22.68 minutes.
Epoch [45/1000], Loss: 4.1522, Training Accuracy: 5.97%, Testing Accuracy: 7.40%

(1.3496312496948242, 0.61394, 0.2991)

In [5]:
# Training-loss history kept by the Trainer. Judging by the epoch log of the
# training cell, one entry appears to be recorded per evaluation interval.
loss_history = trainer.training_loss
print(loss_history)

[3.8831650550842287, 3.7545831340789797, 3.90771606552124, 3.8779685932159422, 3.5532317668914795, 3.5726242889404296, 3.570063620529175, 4.171268519897461, 4.152182073059082, 3.7047876470184327, 3.5359194845581055, 3.3163913114929198, 3.2224277948760984, 3.065159268798828, 3.0485261504364014, 2.9992130340576173, 3.1483382707214354, 3.296661267700195, 3.3740972773742675, 3.5634657752227783, 3.3389105000305177, 3.1057157781982423, 3.2769297146606444, 3.1441875242614747, 3.136159824676514, 3.2059114372253417, 2.9885128443908693, 3.031801535644531, 3.0572744145202635, 3.0912300513458253, 2.826985710372925, 2.923717773590088, 2.7774399655914306, 2.7475577796936035, 3.028362559814453, 3.1619902042388914, 2.804385614700317, 2.611591706542969, 2.755512269744873, 2.7616489183044433, 2.5803797372436525, 2.4473050688171387, 2.352517367706299, 2.6711506968688963, 2.2943871115875245, 2.413571512298584, 2.160355765762329, 2.0605505338287355, 3.0005962781524658, 2.772085107269287, 2.908381989440918,

In [6]:
# Training-accuracy history (fractions in [0, 1]) kept by the Trainer.
train_acc_history = trainer.training_accuracy
print(train_acc_history)

[0.09824, 0.11958, 0.1, 0.10374, 0.15138, 0.15022, 0.15274, 0.06716, 0.05974, 0.12294, 0.15958, 0.19448, 0.2064, 0.2382, 0.23944, 0.25406, 0.2264, 0.20362, 0.1899, 0.15202, 0.1903, 0.2335, 0.2008, 0.2246, 0.2296, 0.21658, 0.25248, 0.2466, 0.2403, 0.23678, 0.28424, 0.26832, 0.2929, 0.29826, 0.2454, 0.22082, 0.2917, 0.32486, 0.29858, 0.29664, 0.33724, 0.35922, 0.37894, 0.32218, 0.39228, 0.36566, 0.42216, 0.44558, 0.24952, 0.29474, 0.26714, 0.21974, 0.25382, 0.23404, 0.23738, 0.24006, 0.23126, 0.21762, 0.26894, 0.22076, 0.31182, 0.33882, 0.36782, 0.31546, 0.3553, 0.3537, 0.3759, 0.40074, 0.33392, 0.37244, 0.37392, 0.36984, 0.4126, 0.4254, 0.35202, 0.30998, 0.36508, 0.37752, 0.40576, 0.29054, 0.29106, 0.3199, 0.3414, 0.2994, 0.34984, 0.3727, 0.28462, 0.2986, 0.33874, 0.34682, 0.35946, 0.36924, 0.40272, 0.42692, 0.41312, 0.38596, 0.36848, 0.39806, 0.37362, 0.39088, 0.40672, 0.43882, 0.40048, 0.43954, 0.46788, 0.37824, 0.43894, 0.46766, 0.4653, 0.46368, 0.48382, 0.47742, 0.44868, 0.47782, 0.

In [7]:
# Test-accuracy history (fractions in [0, 1]) kept by the Trainer.
test_acc_history = trainer.testing_accuracy
print(test_acc_history)

[0.0846, 0.114, 0.1233, 0.135, 0.1318, 0.1582, 0.1517, 0.0667, 0.074, 0.1278, 0.0743, 0.1973, 0.2089, 0.2259, 0.2383, 0.1122, 0.2329, 0.2089, 0.1106, 0.1521, 0.1951, 0.2142, 0.2053, 0.2257, 0.2226, 0.1968, 0.2359, 0.2271, 0.2291, 0.2297, 0.2596, 0.25, 0.2653, 0.215, 0.2469, 0.2292, 0.2608, 0.2816, 0.2784, 0.2456, 0.2478, 0.2982, 0.3085, 0.2879, 0.3049, 0.3113, 0.3048, 0.3177, 0.2496, 0.263, 0.2516, 0.2269, 0.2488, 0.2241, 0.219, 0.2091, 0.2294, 0.2031, 0.2212, 0.2377, 0.2611, 0.2781, 0.2938, 0.2781, 0.2907, 0.2851, 0.29, 0.3056, 0.2795, 0.2966, 0.3031, 0.2962, 0.3088, 0.3195, 0.2969, 0.2649, 0.2918, 0.2776, 0.3088, 0.2246, 0.2518, 0.2703, 0.2815, 0.2498, 0.2892, 0.3009, 0.2554, 0.2725, 0.2811, 0.2833, 0.2865, 0.2944, 0.3047, 0.3125, 0.299, 0.2898, 0.2946, 0.2991, 0.3016, 0.2728, 0.3011, 0.3114, 0.2903, 0.3009, 0.3123, 0.2934, 0.3144, 0.3177, 0.3093, 0.3051, 0.3096, 0.3182, 0.2967, 0.306, 0.3095, 0.315, 0.3079, 0.313, 0.3025, 0.3088, 0.308, 0.3126, 0.3173, 0.3157, 0.3181, 0.3092, 0.3176