In [None]:
from models.efficientMem import ViT as ViTMem
from models.efficient import ViT as ViT
from linformer import Linformer
import torchvision.transforms as transforms

from util.dataset_tools import get_cifar10, train, test, save_model, load_model

In [None]:
import torch

# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Set to False to skip training and only evaluate the checkpoint at `model_path`.
should_train = True
# Checkpoint path handed to `train` and `load_model` in the cells below.
model_path = 'weights/efficientMem.pth'

In [None]:
# Build the CIFAR-10 train/test loaders, resizing images to 224x224 so they
# match the models' `image_size=224`. `new_transforms` is presumably appended
# to the dataset's default transform pipeline (TODO confirm in util.dataset_tools).
# NOTE: `Resize` takes the target size as ONE argument; a (h, w) tuple gives an
# exact 224x224 output. A second positional argument would be interpreted as
# `interpolation`, not as the width.
train_loader, test_loader = get_cifar10(
    batch_size=128, new_transforms=[transforms.Resize((224, 224))]
)

In [None]:
# Linformer encoder handed to the ViT models in the next cell.
# `seq_len` must equal the number of patch tokens plus one cls-token. With the
# models' image_size=224 and patch_size=32 that is (224 // 32) ** 2 = 7 * 7 = 49
# patches, + 1 = 50. `dim` must match the models' `dim` (128).
efficient_transformer = Linformer(
    seq_len=(224 // 32) ** 2 + 1,
    dim=128,
    heads=8,
    depth=12,
    k=64,
)

In [None]:
# Constructor arguments common to both ViT variants, kept in one place so the
# baseline and the memory variant are always built with the same configuration.
# - `dim` must match the Linformer's `dim` (128).
# - image_size / patch_size must agree with the Linformer's `seq_len`
#   ((224 // 32) ** 2 patches + 1 cls-token = 50).
# - num_classes=10: CIFAR-10 labels range over 0..9; a 2-way head would leave
#   labels 2..9 out of range for the classification loss.
vit_kwargs = {
    'dim': 128,
    'image_size': 224,
    'patch_size': 32,
    'num_classes': 10,
    'channels': 3,
}

# NOTE(review): both models receive the same Linformer instance, so they
# presumably share the transformer's parameters (loading weights into one also
# changes the other). That is harmless here because both load the same
# checkpoint, but pass `copy.deepcopy(efficient_transformer)` to one of them if
# the two must be independent.
model = ViT(transformer=efficient_transformer, **vit_kwargs).to(device)

model_mem = ViTMem(transformer=efficient_transformer, **vit_kwargs).to(device)

In [None]:
# Train only the baseline ViT. `train` is given `path=model_path`, presumably
# where it writes the checkpoint that the next cell loads into both models.
if should_train:
    train(
        model,
        train_loader,
        device=device,
        path=model_path,
        epochs=100,
    )

In [None]:
# Load the checkpoint at `model_path` into both models; `load_model` presumably
# returns the model it loaded into, which is rebound here.
# NOTE(review): the memory variant (ViTMem) is loaded from the same checkpoint
# that `train` was given for the baseline ViT, so the two architectures are
# assumed to have compatible state-dict keys — confirm against models.efficientMem.
model = load_model(model, model_path)
model_mem = load_model(model_mem, model_path)

In [None]:
# Evaluate the baseline ViT on the test loader. `test` returns a pair; the first
# element is reported as the accuracy and the second is unused here.
acc, _ = test(model, test_loader, device=device)
print(f'Acc: {acc}')

In [None]:
# Evaluate the memory variant (ViTMem) on the same test loader, for comparison
# with the baseline accuracy printed above; the second element of `test`'s
# result is unused here.
acc_mem, _ = test(model_mem, test_loader, device=device)
print(f'Acc Mem: {acc_mem}')