In [None]:
# Make the project's `src` directory importable, then load the project library
import os
import sys

# PROJ_DIR is two directory levels above the current working directory
_working_dir = os.path.abspath('')
PROJ_DIR = os.path.realpath(os.path.dirname(os.path.dirname(_working_dir)))
sys.path.append(os.path.join(PROJ_DIR, 'src'))

import gce_lib as ff

In [None]:
# Experiment configuration
DATASET_NAME = 'cifar'
MODEL_NAME = 'resnet50'
MODEL_LABEL_NUM = 100
batch_size = 64

# Select the compute device: first CUDA GPU when available, otherwise the CPU
import torch
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f'Using {device}')

# Load dataset
# NOTE(review): both loaders are created by ff.get_image_test_loader with identical
# arguments -- confirm whether train_loader was meant to use a training split.
train_loader = ff.get_image_test_loader(DATASET_NAME, batch_size, PROJ_DIR)
test_loader = ff.get_image_test_loader(DATASET_NAME, batch_size, PROJ_DIR)

# Pull the first batch out of train_loader for inspection
examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)

In [None]:
from torchvision.models import vgg16

# Instantiate the classifier through the project's wrapper; it is asked to output logits
network = ff.CIFARResnet50Wrapper(
    device=device,
    output_logits=True,
)

In [None]:
# Training configuration
MODEL_LR = 1e-4
MODEL_EPOCHS = 20
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(network.parameters(), lr=MODEL_LR)#, weight_decay=1e-3)

# One fixed batch from test_loader, reused after every optimisation step to monitor accuracy
x_test, y_test = next(iter(test_loader))
x_test = x_test.to(device)
y_test = y_test.to(device)


if MODEL_NAME == 'untrained':
    MODEL_EPOCHS = 0 # Untrained
for epoch in range(MODEL_EPOCHS):
    for x_train, y_train in train_loader:
        x_train = x_train.to(device)
        y_train = y_train.to(device)
        if 'ood' in MODEL_NAME:
            # NOTE(review): `masker` is not defined in the visible cells; it must exist
            # before running with an 'ood' MODEL_NAME, otherwise this raises NameError.
            x_train_tensor_masked = masker(x_train) # A different set of RandomMasks for each batch
        else:
            x_train_tensor_masked = x_train # Unmasked

        # Standard optimisation step
        optimizer.zero_grad()
        preds = network(x_train_tensor_masked)
        loss_value = loss(preds, y_train)
        loss_value.backward()
        optimizer.step()

        train_accuracy = (preds.argmax(dim=1) == y_train).float().mean()

        # Monitoring pass on the held-out batch. It runs in eval mode and without
        # autograd so that it cannot alter the model (e.g. normalisation-layer running
        # statistics, if the network has any) and does not build an unused graph.
        # The previous train/eval mode is restored afterwards.
        was_training = network.training
        network.eval()
        with torch.no_grad():
            test_preds = network(x_test)
        network.train(was_training)

        test_accuracy = (test_preds.argmax(dim=1) == y_test).float().mean()
        print(f'Epoch {epoch + 1}/{MODEL_EPOCHS} - Loss: {loss_value.item():.4f} - Train accuracy: {train_accuracy:.2f} - Test accuracy: {test_accuracy:.2f}')
        #if test_accuracy > 0.6: # Undertrained
        #    break

In [None]:
from tqdm import tqdm

# Accuracy over every batch of test_loader: correct predictions / number of samples
total_hits = 0
num_elems = 0
network.eval()
with torch.no_grad():  # inference only: no autograd graph is needed
    for x_test, y_test in tqdm(test_loader):
        x_test = x_test.to(device)
        y_test = y_test.to(device)
        test_preds = network(x_test)
        # .item() keeps the running count as a plain Python number
        total_hits += (test_preds.argmax(dim=1) == y_test).sum().item()
        num_elems += x_test.shape[0]

print(f'Test accuracy: {100*total_hits/num_elems:.2f}%')

In [None]:
# Persist the network's state_dict under <PROJ_DIR>/assets/models
save_path = os.path.join(
    PROJ_DIR,
    'assets',
    'models',
    f'{DATASET_NAME}-{MODEL_NAME}-mlp.pth',
)
torch.save(network.state_dict(), save_path)

In [None]:
# Rebuild `network` from the checkpoint file under <PROJ_DIR>/assets/models
weights_path = os.path.join(
    PROJ_DIR,
    'assets',
    'models',
    f'{DATASET_NAME}-{MODEL_NAME}-mlp.pth',
)
network = ff.load_pretrained_cifar_model(weights_path)

In [None]:
# Evaluate the network batch by batch over train_loader.
# `test_accuracy` accumulates the SUM of per-batch accuracies; the mean is that sum
# divided by `num_batches` (printed at the end).
# NOTE(review): a mean of per-batch means gives a smaller final batch the same weight
# as a full one -- confirm this is acceptable for the reported figure.
network.eval()
test_accuracy = 0
num_batches = 0
with torch.no_grad():  # inference only: no autograd graph is needed
    for batch_idx, (x_test, y_test) in enumerate(train_loader):
        x_test = x_test.to(device)
        y_test = y_test.to(device)
        num_batches += 1
        test_preds = network.forward(x_test)
        test_accuracy += (test_preds.argmax(dim=1) == y_test).float().mean()
print((test_accuracy.item()) / num_batches)

In [None]:
import json

# Record this model's accuracy in the shared model-accuracies.json file
MODELS_PATH = os.path.join(PROJ_DIR,'assets','models')
accuracies_file = os.path.join(MODELS_PATH, 'model-accuracies.json')
with open(accuracies_file) as fIn:
    models = json.load(fIn)
# `test_accuracy` is the sum of per-batch accuracies accumulated by the evaluation loop,
# so it is divided by `num_batches` to store the mean accuracy rather than the raw sum.
models[f'{DATASET_NAME}-{MODEL_NAME}'] = test_accuracy.item() / num_batches
with open(accuracies_file, 'w') as fOut:
    json.dump(models, fOut)