In [None]:
# Load library
import sys
import os
# Project root = two directory levels above the current working directory
# (os.path.abspath('') expands to the cwd), with symlinks resolved.
cwd = os.path.abspath('')
PROJ_DIR = os.path.realpath(os.path.dirname(os.path.dirname(cwd)))

# Make modules under <PROJ_DIR>/src importable from this notebook.
src_dir = os.path.join(PROJ_DIR, 'src')
sys.path.append(src_dir)

import gce_lib as ff

In [None]:
DATASET_NAME = 'mnist'
MODEL_NAME = 'mlp'

# Load dataset
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Using {device}')
import torchvision

batch_size = 256

MNIST_PATH = os.path.join(PROJ_DIR, 'data', 'mnist')


def _make_mnist_loader(train):
    """Return a shuffled DataLoader over one MNIST split.

    Args:
        train: True for the training split, False for the test split.

    The dataset is downloaded into MNIST_PATH if it is not already there.
    Images are converted to tensors and normalized with the constants below
    (presumably the MNIST pixel mean / std -- confirm).
    """
    preprocessing = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    split = torchvision.datasets.MNIST(
        MNIST_PATH, train=train, download=True, transform=preprocessing)
    return torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=True)


train_loader = _make_mnist_loader(train=True)
test_loader = _make_mnist_loader(train=False)

# Draw a single batch from the test split for inspection.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

In [None]:
# Declare classifier: build it first, then place it on the selected device.
network = ff.MNISTClassifier()
network = network.to(device)

In [None]:
# Train the classifier with Adam on a binary cross-entropy loss computed
# against one-hot labels, logging metrics after every optimisation step.
MODEL_LR = 1e-3
MODEL_EPOCHS = 10
MODEL_LABEL_NUM = 10

# NOTE(review): BCELoss expects values in [0, 1]; presumably the classifier
# ends in a sigmoid/softmax -- confirm against ff.MNISTClassifier.
loss = torch.nn.BCELoss()
optimizer = torch.optim.Adam(network.parameters(), lr=MODEL_LR)#, weight_decay=1e-3)

# A single test batch, drawn once, is reused for progress monitoring.
_, (x_test, y_test) = next(enumerate(test_loader))
x_test = x_test.to(device)
y_test = y_test.to(device)

# Make sure training starts in training mode even if this cell is re-run
# after the model has been switched to eval mode.
network.train()

for epoch in range(MODEL_EPOCHS):
    for batch_idx, (x_train, y_train) in enumerate(train_loader):
        x_train = x_train.to(device)
        y_train = y_train.to(device)
        x_train_tensor_masked = x_train # Unmasked
        optimizer.zero_grad()

        preds = network(x_train_tensor_masked)
        # One-hot float targets, created on the same device as the labels.
        label_onehot = torch.nn.functional.one_hot(y_train, MODEL_LABEL_NUM).float()
        loss_value = loss(preds, label_onehot)
        loss_value.backward()
        optimizer.step()

        train_accuracy = (preds.argmax(dim=1) == y_train).float().mean()

        # Monitoring pass: no autograd graph is needed, and eval mode keeps
        # dropout / batch-norm layers (if the model has any) from distorting
        # the metric or absorbing test statistics. Training mode is restored
        # right after.
        network.eval()
        with torch.no_grad():
            test_preds = network(x_test)
        network.train()

        test_accuracy = (test_preds.argmax(dim=1) == y_test).float().mean()
        print(f'Epoch {epoch + 1}/{MODEL_EPOCHS} - Loss: {loss_value.item():.4f} - Train accuracy: {train_accuracy:.2f} - Test accuracy: {test_accuracy:.2f}')
        #if test_accuracy > 0.6: # Undertrained
        #    break

In [None]:
# Save model
# The weights go to a TEMP-prefixed file; note that this is not the same
# filename that the loading cell reads.
temp_weights_path = os.path.join(
    PROJ_DIR, 'assets', 'models', f'TEMP{DATASET_NAME}-{MODEL_NAME}-mlp.pth')
trained_weights = network.state_dict()
torch.save(trained_weights, temp_weights_path)

In [None]:
# Replace the in-memory network with the pretrained one stored under assets/models.
pretrained_weights_path = os.path.join(
    PROJ_DIR, 'assets', 'models', f'{DATASET_NAME}-{MODEL_NAME}-mlp.pth')
network = ff.load_pretrained_mnist_model(pretrained_weights_path)

In [None]:
# Evaluate the loaded model on the held-out test split and report the
# per-sample accuracy.
#
# The accuracy is measured on test_loader and computed as
# (correct predictions) / (samples seen), so a shorter final batch is not
# over-weighted the way an average of per-batch means would be.

# Inputs are moved to `device` below, so the model has to live there too
# (a no-op if the loader already placed it on that device).
network.to(device)
network.eval()

num_correct = torch.zeros((), dtype=torch.long, device=device)
num_samples = 0
with torch.no_grad():  # inference only: no autograd graph needed
    for x_test, y_test in test_loader:
        x_test = x_test.to(device)
        y_test = y_test.to(device)
        test_preds = network(x_test)
        num_correct += (test_preds.argmax(dim=1) == y_test).sum()
        num_samples += y_test.shape[0]

# Kept as a 0-dim tensor on purpose: the cell that records the accuracy
# reads it through test_accuracy.item().
test_accuracy = num_correct.float() / num_samples
print(test_accuracy.item())

In [None]:
# Record this model's accuracy in assets/models/model-accuracies.json,
# keeping any entries that are already stored there.
import json

MODELS_PATH = os.path.join(PROJ_DIR,'assets','models')
accuracies_path = os.path.join(MODELS_PATH, 'model-accuracies.json')

# On a first run the registry file may not exist yet: start from an empty
# mapping instead of failing with FileNotFoundError.
try:
    with open(accuracies_path) as fIn:
        models = json.load(fIn)
except FileNotFoundError:
    models = {}

# NOTE(review): the value is stored exactly as test_accuracy.item() returns
# it -- confirm the evaluation cell leaves a normalized accuracy (not a
# running sum over batches) in test_accuracy.
models[f'{DATASET_NAME}-{MODEL_NAME}'] = test_accuracy.item()

with open(accuracies_path, 'w') as fOut:
    json.dump(models, fOut)