In [1]:
import copy
import os
import pickle

import torchvision.transforms as transforms
from linformer import Linformer

from models.efficient import ViT as ViT
from models.efficientMem import ViT as ViTMem
from util.adv_tools import loader_to_data, data_to_loader, get_attack
from util.dataset_tools import get_cifar10, train, test, load_model, augment_set

In [14]:
# --- Run configuration -------------------------------------------------------
# True: call train(...) below, which is given model_path as its `path`
# argument (presumably where the trained weights are written -- confirm).
# False: only load the existing weights from model_path.
should_train = False
model_path = 'weights/efficientMem.pth'
device = 'cuda'

# Passed as the third argument of augment_set (presumably the number of
# augmented copies / sequence length per sample -- TODO confirm).
SEQ_LENGTH = 5
# The FGSM cell passes LAMBDA / 10 to get_attack as its attack-strength argument.
LAMBDA = 0.1

In [3]:
# CIFAR-10 loaders; images are resized to 224 so they match the ViT models'
# image_size=224 configured below.
train_loader, test_loader = get_cifar10(
    new_transforms=[transforms.Resize(224)],
    batch_size=128,
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Linformer encoder configuration handed to the ViT models below.
# seq_len: a 224x224 image cut into 32x32 patches gives (224 / 32)^2 = 7 * 7
# patch tokens, plus one cls token.
efficient_transformer = Linformer(
    seq_len=7 * 7 + 1,
    dim=128,
    depth=12,
    heads=8,
    k=64,
)

In [5]:
# Architecture hyper-parameters common to the software ViT and its memristor
# counterpart (ViTMem); both are later loaded from the same checkpoint.
VIT_KWARGS = dict(
    dim=128,
    image_size=224,
    patch_size=32,
    num_classes=10,
    channels=3,
)

model = ViT(transformer=efficient_transformer, **VIT_KWARGS).to(device)

# Each model must own its transformer. Passing the same Linformer instance to
# both would make them share one set of parameters, so any in-place change made
# through one model would silently leak into the other and contaminate the
# software-vs-memristor comparison.
model_mem = ViTMem(
    transformer=copy.deepcopy(efficient_transformer), **VIT_KWARGS
).to(device)

In [6]:
# Training is opt-in via `should_train`. Only the software model is trained
# here; the next cell loads weights from model_path into both models.
if should_train:
    train(
        model,
        train_loader,
        test_loader,
        path=model_path,
        epochs=200,
    )

In [7]:
# Load the same checkpoint into the software model first, then the memristor model.
model, model_mem = [load_model(net, model_path) for net in (model, model_mem)]

In [8]:
# Clean (unperturbed) test-set accuracy of the memristor model.
# `parellel` (sic) is the keyword spelling accepted by util.dataset_tools.test.
results_mem = test(model_mem, test_loader, parellel=True)
acc_mem, _ = results_mem
print(f'Acc Mem: {acc_mem}')

100%|██████████| 79/79 [3:50:59<00:00, 175.44s/it]  

Acc Mem: 48.23





In [9]:
# Clean (unperturbed) test-set accuracy of the software model, for comparison
# with the memristor figure above.
results_sw = test(model, test_loader, parellel=True)
acc, _ = results_sw
print(f'Acc: {acc}')

100%|██████████| 79/79 [00:06<00:00, 11.41it/s]

Acc: 53.37





In [8]:
# Pull the clean test set out of its loader; x_test / y_test are also the input
# of the FGSM attack cell below.
x_test, y_test = loader_to_data(test_loader)

# Augmented clean loader, built the same way as the adversarial one below.
# NOTE(review): test_loader_aug is not used by any visible cell.
aug_pair = augment_set(x_test, y_test, SEQ_LENGTH)
x_aug, y_aug = aug_pair
test_loader_aug = data_to_loader(x_aug, y_aug)

In [15]:
# FGSM examples are crafted against the software model (`model`) and then used
# to evaluate both the software and the memristor model. Crafting them is
# expensive, so the resulting loader is cached on disk.
ADV_FGSM_CACHE = "data/c10/adv_loader_fgsm.pkl"

# Regenerate when the weights were just (re)trained -- a cache crafted against
# the previous weights would be stale -- or when no cache exists yet.
# Delete the cache file to force regeneration after replacing model_path by hand.
if should_train or not os.path.exists(ADV_FGSM_CACHE):
    # LAMBDA / 10 is get_attack's attack-strength argument (presumably the FGSM
    # epsilon -- confirm in util.adv_tools).
    x_fgsm = get_attack(
        x_test, "FGSM", LAMBDA / 10, model=model, input_shape=(3, 224, 224)
    )
    x_fgsm_aug, y_fgsm_aug = augment_set(x_fgsm, y_test, SEQ_LENGTH)
    adv_loader_fgsm = data_to_loader(x_fgsm_aug, y_fgsm_aug)

    # open(..., "wb") fails if the cache directory does not exist yet.
    os.makedirs(os.path.dirname(ADV_FGSM_CACHE), exist_ok=True)
    with open(ADV_FGSM_CACHE, "wb") as f:
        pickle.dump(adv_loader_fgsm, f)
else:
    # NOTE: pickle.load can execute arbitrary code; only load cache files that
    # this notebook wrote itself.
    with open(ADV_FGSM_CACHE, "rb") as f:
        adv_loader_fgsm = pickle.load(f)

acc, _ = test(model, adv_loader_fgsm, "c10_fgsm_sw")
print("Accuracy on FGSM (SW):", acc)

acc, _ = test(model_mem, adv_loader_fgsm, "c10_fgsm_mem")
print("Accuracy on FGSM (Memristor):", acc)

100%|██████████| 782/782 [00:45<00:00, 17.35it/s]


Accuracy on FGSM (SW): 35.566


100%|██████████| 782/782 [41:55:25<00:00, 193.00s/it]   

Accuracy on FGSM (Memristor): 35.942



