In [1]:
from Runner import Trainer
from mobilevit import DynamicMobileViT
from mobilevit import MobileViT
import torch
import torchvision
import torchvision.transforms as transforms

In [2]:
# Notebook-wide configuration. Later cells read `batch_size` and `pic_size`.

# Fixed seed for PyTorch's default random number generator, so that anything
# drawing from it (e.g. DataLoader shuffling order) starts from the same state
# on every Restart & Run All. The value itself is arbitrary.
# NOTE(review): this does not by itself make GPU kernels deterministic.
SEED = 42
torch.manual_seed(SEED)

batch_size = 128  # samples per mini-batch, used by both DataLoaders and the Trainer
pic_size = (64, 64)  # target size handed to transforms.Resize and to the model config

In [3]:
import torchvision.transforms as transforms
import torchvision

# Preprocessing shared by the training and test splits: resize every image to
# `pic_size`, convert it to a tensor, then normalise each of the three channels
# with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(pic_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Both CIFAR-100 splits are read from the same local directory. Nothing is
# downloaded here (download=False), so the files must already be present.
data_root = "../Data/cifar-100-python"
trainset, testset = (
    torchvision.datasets.CIFAR100(
        root=data_root, train=use_train_split, download=False, transform=transform
    )
    for use_train_split in (True, False)
)

# The loaders differ only in shuffling: training batches are shuffled,
# test batches are served in dataset order.
loader_options = {"batch_size": batch_size, "num_workers": 2}
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_options)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_options)

In [4]:
# Model hyper-parameters, handed to Trainer through its `model_p` argument.
image_side = pic_size[0]
model_parameters = {
    "image_size": pic_size,
    "channels": [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
    "num_classes": 100,  # CIFAR-100 has 100 target classes
    "expansion": image_side // 2,  # half of the first entry of pic_size
    "kernel_size": 3,
    "patch_size": (2, 2),
    "dims": [96, 120, 144],
    "L": [8, 8, 8],
}

# The trainer receives the model class (not an instance) together with the
# (train, test) loader pair.
loaders = (trainloader, testloader)
trainer = Trainer(
    MobileViT,
    loaders,
    batch_size=batch_size,
    model_p=model_parameters,
)

# The loss function is assigned on the trainer after construction.
trainer.criterion = torch.nn.CrossEntropyLoss()

# Keep this call as the final expression of the cell so the notebook displays
# its return value; the recorded output is a 3-tuple.
trainer.run(num_epochs=50, evaluation_interval=1)

Epoch [1/50], Loss: 4.2620, Training Accuracy: 4.09%, Testing Accuracy: 2.88%, Evaluation Time: 1.35 minutes.
Epoch [2/50], Loss: 4.0250, Training Accuracy: 6.88%, Testing Accuracy: 10.44%, Evaluation Time: 2.68 minutes.
Epoch [3/50], Loss: 3.6815, Training Accuracy: 12.27%, Testing Accuracy: 14.02%, Evaluation Time: 4.02 minutes.
Epoch [4/50], Loss: 3.3931, Training Accuracy: 17.90%, Testing Accuracy: 19.27%, Evaluation Time: 5.36 minutes.
Epoch [5/50], Loss: 3.2315, Training Accuracy: 20.71%, Testing Accuracy: 22.51%, Evaluation Time: 6.70 minutes.
Epoch [6/50], Loss: 3.0250, Training Accuracy: 24.25%, Testing Accuracy: 25.47%, Evaluation Time: 8.03 minutes.
Epoch [7/50], Loss: 2.8279, Training Accuracy: 28.56%, Testing Accuracy: 28.40%, Evaluation Time: 9.37 minutes.
Epoch [8/50], Loss: 2.6504, Training Accuracy: 32.04%, Testing Accuracy: 30.46%, Evaluation Time: 11.94 minutes.
Epoch [9/50], Loss: 2.6486, Training Accuracy: 31.87%, Testing Accuracy: 31.00%, Evaluation Time: 13.27 mi

(1.6460073776626587, 0.54268, 0.3532)

In [6]:
# Print the loss history stored on the trainer after the run above.
# In the recorded output this is a plain list of floats whose last entry
# equals the first element of the tuple returned by trainer.run(...).
print(trainer.training_loss)

[4.261985183563232, 4.025037463378906, 3.681514261627197, 3.3931316424560545, 3.231485662460327, 3.0249885208129883, 2.827874513092041, 2.6503642908477785, 2.648610160369873, 2.4293954519653322, 2.311944264602661, 2.245968692321777, 2.1073405327606203, 2.0413344165802, 1.9821628901672363, 1.9001611084365844, 1.8611690036010742, 1.8262063994979858, 1.7774003897857666, 1.759181618270874, 1.7390322598648071, 1.7132100914764403, 1.7025542046356201, 1.695543721923828, 1.681934633102417, 1.6752265090179443, 1.6681104083251954, 1.6664517053222656, 1.659771476249695, 1.658633309288025, 1.653932202758789, 1.6552530449676515, 1.655748970375061, 1.6498652701187133, 1.6511372784805298, 1.6479934735107422, 1.6505160472869873, 1.6474957574081421, 1.6539987649536132, 1.6494747262191773, 1.6419018747329712, 1.6495826654434205, 1.6492099010849, 1.6460682270050049, 1.6478718793869018, 1.6510162377166748, 1.6468734450531006, 1.6476946060180664, 1.6479098006820678, 1.6460073776626587]


In [5]:
# Print the training-accuracy history stored on the trainer.
# The recorded values are fractions (e.g. 0.54268), not percentages.
print(trainer.training_accuracy)

[0.04094, 0.0688, 0.1227, 0.17904, 0.20706, 0.24248, 0.2856, 0.32042, 0.31874, 0.36576, 0.39004, 0.40624, 0.43484, 0.45034, 0.4628, 0.48358, 0.49136, 0.49816, 0.51194, 0.51364, 0.5202, 0.52706, 0.52874, 0.52932, 0.5333, 0.5346, 0.5379, 0.53894, 0.53774, 0.53914, 0.54052, 0.53944, 0.54192, 0.54198, 0.54012, 0.5411, 0.54168, 0.54136, 0.54078, 0.54058, 0.54404, 0.53998, 0.54032, 0.54148, 0.54288, 0.5414, 0.54184, 0.54124, 0.54186, 0.54268]


In [7]:
# Print the test-accuracy history stored on the trainer.
# The recorded values are fractions (e.g. 0.3532), not percentages.
print(trainer.testing_accuracy)

[0.0288, 0.1044, 0.1402, 0.1927, 0.2251, 0.2547, 0.284, 0.3046, 0.31, 0.33, 0.3337, 0.3352, 0.3484, 0.3467, 0.3534, 0.3498, 0.3553, 0.3546, 0.3542, 0.3534, 0.3535, 0.354, 0.3556, 0.3547, 0.3529, 0.3533, 0.3526, 0.3527, 0.3531, 0.3533, 0.3532, 0.3531, 0.3527, 0.3533, 0.3536, 0.3534, 0.3536, 0.3533, 0.3529, 0.353, 0.3533, 0.3533, 0.353, 0.3532, 0.3531, 0.3531, 0.3532, 0.3534, 0.3532, 0.3532]
