In [1]:
import copy

import torch
import torchvision.transforms as transforms
from linformer import Linformer

from models.efficient import ViT as ViT
from models.efficientMem import ViT as ViTMem
from util.dataset_tools import get_cifar10, train, test, load_model

In [2]:
# Run configuration -- everything a reader is likely to tune lives here.
SEED = 0
# Seed torch's global RNG before any model or data loader is built, so weight
# initialisation (and, presumably, torch-driven shuffling/augmentation inside
# get_cifar10's loaders) is repeatable across Restart & Run All.
torch.manual_seed(SEED)

# Prefer the GPU, but fall back to CPU instead of failing at `.to(device)` on a
# machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
should_train = True  # False: skip the training cell and evaluate the checkpoint already at `model_path`
model_path = 'weights/efficientMem.pth'

In [None]:
# Resize every CIFAR-10 image to 224 px so inputs match the image_size=224 the
# ViT models below are constructed with (224 / 32 px patches -> 7x7 grid).
upsample_to_vit_input = transforms.Resize(224)
train_loader, test_loader = get_cifar10(
    new_transforms=[upsample_to_vit_input],
    batch_size=128,
)

In [4]:
# The token sequence is one token per 32x32 patch of a 224x224 image (7x7 = 49)
# plus a single cls token; these sizes mirror image_size/patch_size given to the
# ViT models below and must be kept in sync with them.
num_patches = (224 // 32) ** 2
efficient_transformer = Linformer(
    seq_len=num_patches + 1,
    dim=128,
    depth=12,
    heads=8,
    k=64,  # presumably the projected key/value length along the sequence axis -- confirm in linformer docs
)

In [5]:
# Constructor arguments shared by both variants, so the baseline and the
# memory-efficient model are guaranteed to have the same architecture.
vit_kwargs = dict(
    dim=128,
    image_size=224,
    patch_size=32,
    num_classes=10,
    channels=3,
)

# Baseline ViT built around the Linformer `efficient_transformer`.
model = ViT(transformer=efficient_transformer, **vit_kwargs).to(device)

# Memory-efficient variant. It gets its own deep copy of the transformer: handing
# the same `efficient_transformer` object to both models would make them share one
# set of transformer weights (assuming ViT stores it as a submodule), so training or
# loading one model would silently modify the other. Its weights are restored from
# the checkpoint later anyway.
model_mem = ViTMem(transformer=copy.deepcopy(efficient_transformer), **vit_kwargs).to(device)

In [None]:
# Train the baseline ViT only when requested; the checkpoint written to
# `model_path` is what the later cells load into both models.
# NOTE(review): per the logged output, checkpoints are selected on the *test*
#   split, so the accuracy reported later is measured on the same data used for
#   model selection (optimistically biased).
# NOTE(review): the log also shows saves that are not improvements over the best
#   accuracy so far (epoch 27 saves 52.88% after epoch 24 reached 53.30%), and the
#   "improved from" value is the previous epoch's accuracy -- the checkpoint
#   selection in `train` looks wrong; verify in util.dataset_tools.
# NOTE(review): test loss rises steadily after roughly epoch 10 while training loss
#   goes to ~0, which suggests 200 epochs is far more than this setup benefits from.
num_epochs = 200
if should_train:
    train(
        model,
        train_loader,
        test_loader,
        epochs=num_epochs,
        path=model_path,
    )

Using 3 GPUs
Epoch: 1 Loss: 2.397412

100%|██████████| 79/79 [00:07<00:00, 10.55it/s]


Epoch: 1 Loss: 0.018548 Test Loss: 0.019032 Test Accuracy: 10.00%
Accuracy improved from 0.00% to 10.00%. Saving model.
Epoch: 2 Loss: 2.027050

100%|██████████| 79/79 [00:07<00:00, 10.86it/s]


Epoch: 2 Loss: 0.017576 Test Loss: 0.016241 Test Accuracy: 21.01%
Accuracy improved from 10.00% to 21.01%. Saving model.
Epoch: 3 Loss: 1.863458

100%|██████████| 79/79 [00:07<00:00, 11.06it/s]


Epoch: 3 Loss: 0.015087 Test Loss: 0.015382 Test Accuracy: 29.60%
Accuracy improved from 21.01% to 29.60%. Saving model.
Epoch: 4 Loss: 1.724913

100%|██████████| 79/79 [00:07<00:00, 11.12it/s]


Epoch: 4 Loss: 0.013531 Test Loss: 0.013260 Test Accuracy: 39.40%
Accuracy improved from 29.60% to 39.40%. Saving model.
Epoch: 5 Loss: 1.413961

100%|██████████| 79/79 [00:07<00:00, 10.84it/s]


Epoch: 5 Loss: 0.012438 Test Loss: 0.012245 Test Accuracy: 44.42%
Accuracy improved from 39.40% to 44.42%. Saving model.
Epoch: 6 Loss: 1.385124

100%|██████████| 79/79 [00:06<00:00, 11.36it/s]


Epoch: 6 Loss: 0.011553 Test Loss: 0.011823 Test Accuracy: 47.44%
Accuracy improved from 44.42% to 47.44%. Saving model.
Epoch: 7 Loss: 1.176718

100%|██████████| 79/79 [00:07<00:00, 11.26it/s]


Epoch: 7 Loss: 0.010721 Test Loss: 0.011483 Test Accuracy: 48.12%
Epoch: 8 Loss: 1.179898

100%|██████████| 79/79 [00:07<00:00, 10.79it/s]


Epoch: 8 Loss: 0.010208 Test Loss: 0.011209 Test Accuracy: 49.86%
Accuracy improved from 48.12% to 49.86%. Saving model.
Epoch: 9 Loss: 1.172237

100%|██████████| 79/79 [00:07<00:00, 11.02it/s]


Epoch: 9 Loss: 0.009768 Test Loss: 0.011230 Test Accuracy: 49.59%
Epoch: 10 Loss: 1.182268

100%|██████████| 79/79 [00:07<00:00, 10.73it/s]


Epoch: 10 Loss: 0.009342 Test Loss: 0.010868 Test Accuracy: 51.57%
Accuracy improved from 49.59% to 51.57%. Saving model.
Epoch: 11 Loss: 1.169676

100%|██████████| 79/79 [00:07<00:00, 11.02it/s]


Epoch: 11 Loss: 0.008954 Test Loss: 0.011126 Test Accuracy: 50.51%
Epoch: 12 Loss: 1.244816

100%|██████████| 79/79 [00:07<00:00, 10.67it/s]


Epoch: 12 Loss: 0.008297 Test Loss: 0.011211 Test Accuracy: 50.99%
Epoch: 13 Loss: 1.106345

100%|██████████| 79/79 [00:07<00:00, 11.09it/s]


Epoch: 13 Loss: 0.007910 Test Loss: 0.011114 Test Accuracy: 52.09%
Accuracy improved from 50.99% to 52.09%. Saving model.
Epoch: 14 Loss: 1.229300

100%|██████████| 79/79 [00:07<00:00, 10.88it/s]


Epoch: 14 Loss: 0.007553 Test Loss: 0.011165 Test Accuracy: 51.60%
Epoch: 15 Loss: 0.879707

100%|██████████| 79/79 [00:07<00:00, 11.27it/s]


Epoch: 15 Loss: 0.007160 Test Loss: 0.011453 Test Accuracy: 51.68%
Epoch: 16 Loss: 0.920472

100%|██████████| 79/79 [00:07<00:00, 11.23it/s]


Epoch: 16 Loss: 0.006774 Test Loss: 0.011464 Test Accuracy: 52.42%
Epoch: 17 Loss: 0.793221

100%|██████████| 79/79 [00:07<00:00, 10.83it/s]


Epoch: 17 Loss: 0.005961 Test Loss: 0.011749 Test Accuracy: 52.93%
Epoch: 18 Loss: 0.818740

100%|██████████| 79/79 [00:07<00:00, 10.92it/s]


Epoch: 18 Loss: 0.005585 Test Loss: 0.012171 Test Accuracy: 52.56%
Epoch: 19 Loss: 0.775679

100%|██████████| 79/79 [00:07<00:00, 10.90it/s]


Epoch: 19 Loss: 0.005255 Test Loss: 0.012381 Test Accuracy: 52.52%
Epoch: 20 Loss: 0.868895

100%|██████████| 79/79 [00:07<00:00, 11.09it/s]


Epoch: 20 Loss: 0.004793 Test Loss: 0.012907 Test Accuracy: 52.17%
Epoch: 21 Loss: 0.646193

100%|██████████| 79/79 [00:07<00:00, 11.09it/s]


Epoch: 21 Loss: 0.004455 Test Loss: 0.014256 Test Accuracy: 51.94%
Epoch: 22 Loss: 0.461640

100%|██████████| 79/79 [00:07<00:00, 11.04it/s]


Epoch: 22 Loss: 0.003511 Test Loss: 0.014324 Test Accuracy: 52.64%
Epoch: 23 Loss: 0.381621

100%|██████████| 79/79 [00:07<00:00, 11.02it/s]


Epoch: 23 Loss: 0.003203 Test Loss: 0.015056 Test Accuracy: 51.90%
Epoch: 24 Loss: 0.314167

100%|██████████| 79/79 [00:07<00:00, 10.98it/s]


Epoch: 24 Loss: 0.002888 Test Loss: 0.015961 Test Accuracy: 53.30%
Accuracy improved from 51.90% to 53.30%. Saving model.
Epoch: 25 Loss: 0.391096

100%|██████████| 79/79 [00:07<00:00, 10.83it/s]


Epoch: 25 Loss: 0.002684 Test Loss: 0.016344 Test Accuracy: 52.01%
Epoch: 26 Loss: 0.320958

100%|██████████| 79/79 [00:07<00:00, 10.85it/s]


Epoch: 26 Loss: 0.002392 Test Loss: 0.017683 Test Accuracy: 50.42%
Epoch: 27 Loss: 0.217278

100%|██████████| 79/79 [00:07<00:00, 10.87it/s]


Epoch: 27 Loss: 0.001541 Test Loss: 0.018958 Test Accuracy: 52.88%
Accuracy improved from 50.42% to 52.88%. Saving model.
Epoch: 28 Loss: 0.109320

100%|██████████| 79/79 [00:07<00:00, 11.03it/s]


Epoch: 28 Loss: 0.001276 Test Loss: 0.019435 Test Accuracy: 53.04%
Epoch: 29 Loss: 0.267086

100%|██████████| 79/79 [00:07<00:00, 11.10it/s]


Epoch: 29 Loss: 0.001285 Test Loss: 0.020716 Test Accuracy: 52.05%
Epoch: 30 Loss: 0.142288

100%|██████████| 79/79 [00:06<00:00, 11.37it/s]


Epoch: 30 Loss: 0.001211 Test Loss: 0.020584 Test Accuracy: 52.40%
Epoch: 31 Loss: 0.152778

100%|██████████| 79/79 [00:07<00:00, 10.99it/s]


Epoch: 31 Loss: 0.001023 Test Loss: 0.021869 Test Accuracy: 51.60%
Epoch: 32 Loss: 0.067727

100%|██████████| 79/79 [00:07<00:00, 11.09it/s]


Epoch: 32 Loss: 0.000466 Test Loss: 0.022885 Test Accuracy: 52.81%
Accuracy improved from 51.60% to 52.81%. Saving model.
Epoch: 33 Loss: 0.031397

100%|██████████| 79/79 [00:07<00:00, 10.92it/s]


Epoch: 33 Loss: 0.000230 Test Loss: 0.024111 Test Accuracy: 52.79%
Epoch: 34 Loss: 0.046444

100%|██████████| 79/79 [00:07<00:00, 10.81it/s]


Epoch: 34 Loss: 0.000152 Test Loss: 0.024707 Test Accuracy: 52.34%
Epoch: 35 Loss: 0.015871

100%|██████████| 79/79 [00:07<00:00, 10.83it/s]


Epoch: 35 Loss: 0.000118 Test Loss: 0.025549 Test Accuracy: 53.33%
Epoch: 36 Loss: 0.094544

100%|██████████| 79/79 [00:07<00:00, 11.01it/s]


Epoch: 36 Loss: 0.000223 Test Loss: 0.026301 Test Accuracy: 51.62%
Epoch: 37 Loss: 0.003474

100%|██████████| 79/79 [00:07<00:00, 10.75it/s]


Epoch: 37 Loss: 0.000169 Test Loss: 0.026313 Test Accuracy: 52.59%
Epoch: 38 Loss: 0.001736

100%|██████████| 79/79 [00:07<00:00, 10.90it/s]


Epoch: 38 Loss: 0.000036 Test Loss: 0.026662 Test Accuracy: 52.88%
Epoch: 39 Loss: 0.001963

100%|██████████| 79/79 [00:07<00:00, 10.87it/s]


Epoch: 39 Loss: 0.000017 Test Loss: 0.027005 Test Accuracy: 53.04%
Epoch: 40 Loss: 0.001810

100%|██████████| 79/79 [00:07<00:00, 11.00it/s]


Epoch: 40 Loss: 0.000013 Test Loss: 0.027323 Test Accuracy: 53.09%
Epoch: 41 Loss: 0.000797

100%|██████████| 79/79 [00:07<00:00, 11.08it/s]


Epoch: 41 Loss: 0.000010 Test Loss: 0.027603 Test Accuracy: 52.97%
Epoch: 42 Loss: 0.000980

100%|██████████| 79/79 [00:07<00:00, 10.83it/s]


Epoch: 42 Loss: 0.000008 Test Loss: 0.027785 Test Accuracy: 52.95%
Epoch: 43 Loss: 0.000879

100%|██████████| 79/79 [00:07<00:00, 10.80it/s]


Epoch: 43 Loss: 0.000007 Test Loss: 0.027935 Test Accuracy: 52.95%
Epoch: 44 Loss: 0.000633

100%|██████████| 79/79 [00:07<00:00, 11.26it/s]


Epoch: 44 Loss: 0.000007 Test Loss: 0.028080 Test Accuracy: 52.99%
Epoch: 45 Loss: 0.000813

100%|██████████| 79/79 [00:07<00:00, 11.06it/s]


Epoch: 45 Loss: 0.000006 Test Loss: 0.028225 Test Accuracy: 52.97%
Epoch: 46 Loss: 0.000538

100%|██████████| 79/79 [00:07<00:00, 10.96it/s]


Epoch: 46 Loss: 0.000006 Test Loss: 0.028348 Test Accuracy: 52.98%
Epoch: 47 Loss: 0.000805

100%|██████████| 79/79 [00:07<00:00, 10.96it/s]


Epoch: 47 Loss: 0.000006 Test Loss: 0.028446 Test Accuracy: 52.95%
Epoch: 48 Loss: 0.000716

KeyboardInterrupt: 

In [None]:
# Restore the same checkpoint file (`model_path`) into both the baseline and the
# memory-efficient model, in that order, so they can be compared on identical weights.
model, model_mem = [load_model(net, model_path) for net in (model, model_mem)]

In [None]:
# Evaluate the baseline ViT on the test loader. test() returns a pair whose first
# item is the accuracy; the second is discarded.
# NOTE(review): binding `_` shadows IPython's last-output variable.
baseline_result = test(model, test_loader, device=device)
acc, _ = baseline_result
print(f'Acc: {acc}')

In [None]:
# Evaluate the memory-efficient ViT on the same test loader. Both models were
# loaded from the same checkpoint, so if ViTMem is a drop-in reimplementation
# this should match the baseline accuracy -- presumably the point of the
# comparison; confirm.
# NOTE(review): binding `_` shadows IPython's last-output variable.
mem_result = test(model_mem, test_loader, device=device)
acc_mem, _ = mem_result
print(f'Acc Mem: {acc_mem}')