## README
自定义 ViT 模型：受限于硬件、数据集及时间，只复现了模型架构，并做了少量训练；超参数完全按照 ViT 论文中 Base 模型的设置。

In [None]:
from going_modular import data_setup, model_builder, engine, utils
from torchvision import transforms
import torch
from torch import nn
import torchvision
from torchinfo import summary
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"RUNNING ON {DEVICE} ")
BATCH_SIZE = 1024



# dataloader
# One preprocessing pipeline shared by both splits:
# resize every image to 224x224, then convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
    train_dir="data/pizza_steak_sushi/train/",
    test_dir="data/pizza_steak_sushi/test/",
    train_transform=transform,
    test_transform=transform,
    batch_size=BATCH_SIZE,
)

# Bare expression: shown as the cell output in the notebook.
class_names

: 

In [None]:
# Pull one batch from the training loader and keep its first sample for inspection.
image_batch, label_batch = next(iter(train_dataloader))
image, label = image_batch[0], label_batch[0]
# Bare expression: the notebook displays the sample's shape and its label
# (the label is later used as an index into class_names).
image.shape, label

In [None]:
# test image is correct
# Reorder the dimensions before plotting; presumably this turns the
# channels-first tensor into the channels-last layout plt.imshow expects — confirm.
image_permuted = image.permute(1,2,0)
plt.imshow(image_permuted)
# Title the plot with the class name looked up from the sample's label.
plt.title(class_names[label])
plt.axis(False)

In [None]:
# define patches
# Split the sample image into a grid of square patches and draw each patch
# in its own subplot, laid out in the same row/column order as in the image.
image_size = 224
patch_size = 16
# Validate before any grid arithmetic so the integer division below is exact.
assert image_size % patch_size == 0, "Image size must be divisible by patch size"
# Number of patches along ONE side of the image (the whole grid therefore
# holds num_of_patches ** 2 patches). Integer division keeps this an int so
# it can be used directly as a row/column count.
num_of_patches = image_size // patch_size
# plot
fig, axs = plt.subplots(
    nrows=num_of_patches,
    ncols=num_of_patches,
    figsize=(patch_size, patch_size),
    squeeze=False,  # always return a 2-D array of axes, even for a 1x1 grid
)

for idx_r, patch_height in enumerate(range(0, image_size, patch_size)):
    for idx_c, patch_width in enumerate(range(0, image_size, patch_size)):
        ax = axs[idx_r, idx_c]
        # Slice rows first, then columns, keeping every channel of image_permuted.
        ax.imshow(
            image_permuted[
                patch_height : patch_height + patch_size,
                patch_width : patch_width + patch_size,
                :,
            ]
        )
        ax.set_xticks([])
        ax.set_yticks([])
        # 1-based row/column numbers; label_outer() keeps them only on the
        # outer edge of the grid.
        ax.set_ylabel(idx_r + 1)
        ax.set_xlabel(idx_c + 1)
        ax.label_outer()

In [None]:
# model define
# Build the ViT with the defaults from model_builder, then print a
# layer-by-layer summary for a single-image input.
model_vit = model_builder.ViT()

print(f"model name is [{model_vit.__class__.__name__}]")
summary(
    model=model_vit,
    input_size=(1, 3, 224, 224),  # (batch_size, color_channels, height, width)
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)


In [None]:
# writer
# Experiment logger that is later passed to engine.train; presumably a
# TensorBoard SummaryWriter — confirm in utils.create_summary_writer.
# The model name is taken from the model's class name.
writer = utils.create_summary_writer(
    experiment_name="VIT", model_name=model_vit.__class__.__name__
)

In [None]:
from going_modular import engine

# Optimisation settings for this run.
LEARNING_RATE = 1e-4
EPOCHS = 150

# Cross-entropy loss, with Adam over every parameter of the ViT.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_vit.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))

# Rebind model_vit to the torch.compile wrapper: from this point on the name
# refers to the compiled module, not the plain ViT instance.
model_vit = torch.compile(model=model_vit)

# Train for EPOCHS epochs and keep whatever engine.train returns in `results`
# (it is passed to the loss-curve plot afterwards).
results = engine.train(
    model=model_vit,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=EPOCHS,
    device=DEVICE,
    writer=writer,
)

In [None]:
# save model
# Plot the curves recorded during training, then persist the trained model.
from going_modular import prediction

prediction.plot_loss_curves(results)

# model_vit was rebound to the torch.compile wrapper before training. That
# wrapper keeps the real network under `_orig_mod`, so its state_dict keys
# carry an "_orig_mod." prefix and will not load into a plain ViT. Unwrap it
# first; the getattr fallback leaves an uncompiled model untouched.
# NOTE(review): assumes utils.save_model serialises model.state_dict() — confirm.
model_to_save = getattr(model_vit, "_orig_mod", model_vit)
utils.save_model(model=model_to_save, target_dir="modelzoo", model_name="VitBaseTrainEnd2End.pth")