In [None]:
import sys
import os

# Project root: the parent of the current working directory
# (os.path.abspath('') resolves to the directory the notebook runs from).
PROJ_DIR = os.path.realpath(os.path.dirname(os.path.abspath('')))

# Make the project's `src` directory importable. The membership check keeps
# the cell idempotent: re-running it does not pile duplicate entries onto sys.path.
SRC_DIR = os.path.join(PROJ_DIR,'src')
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)
import xai_faithfulness_experiments_lib_edits as fl

# Locations of the pickled CMNIST train / test dictionaries.
DICT_PATH_TRAIN = os.path.join(PROJ_DIR, 'data', 'cmnist_train_dict.pickle')
DICT_PATH_TEST = os.path.join(PROJ_DIR, 'data', 'cmnist_test_dict.pickle')

In [None]:
# NOTE(review): `data` is not assigned in any cell visible above this one, so on a
# fresh kernel (Restart & Run All) this line would raise NameError unless `data`
# is defined somewhere outside this view — confirm, or remove this leftover probe.
data[0].keys()

In [None]:
import torch

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Using {device}')
BATCH_SIZE_TRAIN = 256
BATCH_SIZE_TEST = 256

# The dataset object itself is kept because the training cell reads len(train_set).
train_set = fl.CMNISTDataset(dict_file_path=DICT_PATH_TRAIN)
train_loader = fl.get_cmnist_train_loader(DICT_PATH_TRAIN, BATCH_SIZE_TRAIN)
# The test loader must use the test batch size; it previously reused
# BATCH_SIZE_TRAIN, which only worked because both constants are equal.
test_loader = fl.get_cmnist_test_loader(DICT_PATH_TEST, BATCH_SIZE_TEST)

In [None]:
from matplotlib import pyplot as plt
import numpy as np

# Position, inside the first training batch, of the sample to preview.
SAMPLE_NUM = 16

# Draw one batch from the training loader.
x_batch, y_batch = next(iter(train_loader))

# Move axis 0 to the last position before handing the array to imshow
# (assumes the sample is stored channel-first — TODO confirm).
sample_image = np.moveaxis(x_batch[SAMPLE_NUM].numpy(), 0, -1)
sample_label = y_batch[SAMPLE_NUM]

plt.imshow(sample_image)
plt.title(sample_label)
plt.show()

In [None]:
import torchvision

# Load the pre-trained ResNet18 model.
model = torchvision.models.resnet18(weights='DEFAULT')

# Choose between fine-tuning the whole network and training only the new head.
# With FREEZE_BACKBONE = False every pre-trained parameter stays trainable,
# which is what this cell has always done (it set requires_grad = True even
# though its comment claimed the layers were being frozen).
# Set FREEZE_BACKBONE = True to really freeze the pre-trained layers.
FREEZE_BACKBONE = False
for param in model.parameters():
    param.requires_grad = not FREEZE_BACKBONE

# Modify the last layer for MNIST: swap the final fully-connected layer for a
# fresh 10-way classifier. It is created after the loop above, so its
# parameters are trainable regardless of FREEZE_BACKBONE.
num_classes = 10
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

In [None]:
MODEL_EPOCHS= 2
MODEL_LR = 1.0e-2
MOMENTUM = 0.9

loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=MODEL_LR, momentum=MOMENTUM)
#optimizer = torch.optim.Adam(model.parameters(), lr=MODEL_LR)
#optimizer = torch.optim.RMSprop(model.parameters(), lr=MODEL_LR)

# One fixed test batch is used as a quick per-epoch progress check.
x_test_batch, y_test_batch = next(iter(test_loader))
x_test_batch = x_test_batch.to(device)
y_test_batch = y_test_batch.to(device)

# Ceiling division. The previous expression parsed as
# `(n // b + 1) if n % b > 0 else 0`, i.e. it reported 0 batches whenever the
# dataset size was an exact multiple of the batch size.
num_train_batches = (len(train_set) + BATCH_SIZE_TRAIN - 1) // BATCH_SIZE_TRAIN

for epoch in range(MODEL_EPOCHS):
    # Re-enter training mode each epoch, since the progress check below
    # switches the model to eval mode.
    model.train()
    batch_num = 0
    for x_batch, y_batch in train_loader:
        x_batch =  x_batch.to(device)
        y_batch =  y_batch.to(device)
        batch_num += 1
        optimizer.zero_grad()

        preds = model(x_batch)
        loss_value = loss(preds, y_batch)
        loss_value.backward()
        optimizer.step()
        print(f'Batch num:{batch_num}/{num_train_batches}\tLoss:{loss_value.item():.4f}\r')

    # Progress check in eval mode and without autograd: test data must not
    # update the BatchNorm running statistics, and no graph is needed here.
    model.eval()
    with torch.no_grad():
        test_preds = model(x_test_batch)
        accuracy = (test_preds.argmax(dim=1) == y_test_batch).float().mean()
    print(f'Epoch {epoch+1}/{MODEL_EPOCHS} - Loss: {loss_value.item()} - Test accuracy: {accuracy}')

# Final evaluation over the whole test loader.
model.eval()

test_hits = 0
num_elems = 0
with torch.no_grad():  # inference only; avoids building autograd graphs
    for x_batch, y_batch in test_loader:
        x_batch =  x_batch.to(device)
        y_batch =  y_batch.to(device)
        test_preds = model(x_batch)
        test_hits += (test_preds.argmax(dim=1) == y_batch).float().sum()
        num_elems += y_batch.shape[0]
print(test_hits / num_elems)

In [None]:
import json

MODELS_PATH = os.path.join(PROJ_DIR,'assets','models')
ACCURACIES_PATH = os.path.join(MODELS_PATH, 'model-accuracies.json')

# Make sure the output directory exists so a first run does not fail.
os.makedirs(MODELS_PATH, exist_ok=True)

# Save model
torch.save(model.state_dict(), os.path.join(MODELS_PATH, 'cmnist-resnet18.pth'))

# Load the existing accuracy registry, or start a new one if none was written yet.
if os.path.exists(ACCURACIES_PATH):
    with open(ACCURACIES_PATH) as fIn:
        models = json.load(fIn)
else:
    models = {}

# Record the accuracy measured over the full test loader (test_hits / num_elems
# from the evaluation loop) rather than `accuracy`, which only covers the single
# test batch checked at the end of the last epoch. float() yields a
# JSON-serialisable number whether the ratio is a tensor or a Python float.
models['cmnist-resnet18'] = float(test_hits / num_elems)
with open(ACCURACIES_PATH, 'w') as fOut:
    json.dump(models, fOut)