In [1]:
from Runner import Trainer
from mobilevit import DynamicMobileViT
from mobilevit import MobileViT
import torch
import torchvision
import torchvision.transforms as transforms

In [2]:
batch_size = 128

In [3]:
# 定义数据预处理操作：转换为张量并归一化
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

# 下载并加载训练集
trainset = torchvision.datasets.CIFAR100(
    root="../Data/cifar-100-python", train=True, download=False, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2
)

# 下载并加载测试集
testset = torchvision.datasets.CIFAR100(
    root="../Data/cifar-100-python", train=False, download=False, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2
)

In [4]:
def check_dataset_structure(dataset, name):
    print(f"Checking structure for {name}:")

    # Check the type of the dataset
    print(f"Type of {name}: {type(dataset)}")

    # If the dataset is a PyTorch Dataset, we can check its length and sample shapes
    if hasattr(dataset, "__len__") and hasattr(dataset, "__getitem__"):
        print(f"Length of {name}: {len(dataset)}")

        # Get a sample from the dataset
        sample_data, sample_label = dataset[0]

        # Check the shape of the sample data and label
        print(f"Sample data shape: {sample_data.shape}")
        print(f"Sample label: {sample_label}")

        # Check the type of the sample data and label
        print(f"Sample data type: {type(sample_data)}")
        print(f"Sample label type: {type(sample_label)}")
    else:
        print(
            f"{name} does not have standard PyTorch Dataset methods (__len__ and __getitem__)."
        )

    print("\n")


# Check the structure of trainset and testset
check_dataset_structure(trainloader, "trainset")
check_dataset_structure(testloader, "testset")

Checking structure for trainset:
Type of trainset: <class 'torch.utils.data.dataloader.DataLoader'>
trainset does not have standard PyTorch Dataset methods (__len__ and __getitem__).


Checking structure for testset:
Type of testset: <class 'torch.utils.data.dataloader.DataLoader'>
testset does not have standard PyTorch Dataset methods (__len__ and __getitem__).




In [5]:
model_parameters = {
    "image_size": (32, 32),
    "channels": [16, 16, 24, 24, 32, 32, 48, 48, 64, 64, 120],
    "num_classes": 100,
    "expansion": 1,
    "kernel_size": 3,
    "patch_size": (1, 1),
    "dims": [32, 24, 48],
    # "L": [2, 4, 3],
}
trainer = Trainer(
    MobileViT,
    (trainloader, testloader),
    batch_size=batch_size,
    model_p=model_parameters,
)
trainer.criterion = torch.nn.CrossEntropyLoss()
print(trainer.model)
# trainer.run(num_epochs=2, evaluation_interval=1)

MobileViT(
  (conv1): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): SiLU()
  )
  (mv2): ModuleList(
    (0): MV2Block(
      (conv): Sequential(
        (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU()
        (3): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): MV2Block(
      (conv): Sequential(
        (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU()
        (3): Conv2d(16, 24, kernel_size=(1, 1), 

In [6]:
trainer.run(num_epochs=2, evaluation_interval=1)

Epoch [1/2], Loss: 4.0028, Training Accuracy: 8.35%, Testing Accuracy: 14.14%, Evaluation Time: 0.44 minutes.
Epoch [2/2], Loss: 3.4717, Training Accuracy: 16.56%, Testing Accuracy: 18.86%, Evaluation Time: 1.68 minutes.


(3.471665129547119, 0.16558, 0.1886)

In [10]:
print(trainer.training_loss)

[4.002791945266724, 3.471665129547119]


In [11]:
print(trainer.training_accuracy)

[0.0835, 0.16558]


In [12]:
print(trainer.testing_accuracy)

[0.1414, 0.1886]


In [7]:
raise Exception("Stop here")

Exception: Stop here

In [None]:
model_parameters = {
    "image_size": (32, 32),
    "channels": [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320],
    "num_classes": 100,
    "expansion": 1,
    "kernel_size": 3,
    "patch_size": (1, 1),
    #'dims': [64, 40, 96],
    # "L": [2, 4, 3],
}
trainer = Trainer(
    DynamicMobileViT,
    (trainloader, testloader),
    batch_size=batch_size,
    model_p=model_parameters,
)
trainer.criterion = torch.nn.CrossEntropyLoss()
print(trainer.model)
# trainer.run(num_epochs=2, evaluation_interval=1)

DynamicMobileViT(
  (conv1): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): SiLU()
  )
  (mv2): ModuleList(
    (0): MV2Block(
      (conv): Sequential(
        (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU()
        (3): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): MV2Block(
      (conv): Sequential(
        (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU()
        (3): Conv2d(16, 24, kernel_size=(



In [None]:
trainer.run(num_epochs=2, evaluation_interval=1)

RuntimeError: Given groups=1, weight of size [48, 48, 3, 3], expected input[128, 80, 1, 1] to have 48 channels, but got 80 channels instead