In [1]:
import os
# When the notebook is started from Conformal-Sparsemax/notebooks, move up one level
# (presumably the repository root, where later cells expect `conformal_sparsemax`,
# `results/` and `models/` to be reachable via relative paths).
# Compare path components rather than a '/'-joined substring so the check works on
# Windows too and does not misfire for e.g. 'notebooks_old' or 'notebooks/sub'.
if os.path.normpath(os.getcwd()).split(os.sep)[-2:] == ['Conformal-Sparsemax', 'notebooks']:
    os.chdir(os.path.dirname(os.getcwd()))

In [12]:
from conformal_sparsemax.classifier import CNN, CNN_CIFAR, get_data,train,evaluate

In [13]:
import torch
from torch import nn
from sklearn.metrics import f1_score

In [14]:
from entmax.losses import SparsemaxLoss, Entmax15Loss

In [17]:
# --- Experiment configuration ---
loss = 'sparsemax'  # one of: 'sparsemax', 'entmax15', 'softmax'
dataset = 'CIFAR100'  # one of: 'CIFAR100', 'MNIST'
SEED = 0  # seeds torch's global RNG so repeated runs are comparable

# Seed before building the model (weight init) and loading data (presumably any
# torch-based splitting/shuffling inside get_data — TODO confirm).
torch.manual_seed(SEED)

# Pick the training criterion; fail fast on an unknown name instead of letting
# `criterion` be undefined until train() is reached.
if loss == 'sparsemax':
    criterion = SparsemaxLoss()
elif loss == 'entmax15':
    criterion = Entmax15Loss()
elif loss == 'softmax':
    # NLLLoss expects log-probabilities; assumes the model applies log_softmax
    # when built with loss='softmax' — TODO confirm in conformal_sparsemax.classifier.
    criterion = torch.nn.NLLLoss()
else:
    raise ValueError(f'Wrong loss name: {loss!r}')

# Build the model before loading data so a bad dataset name fails before any download.
if dataset == 'CIFAR100':
    model = CNN_CIFAR(loss)
elif dataset == 'MNIST':
    model = CNN(loss, n_classes=10, input_size=28, channels=1)
else:
    raise ValueError('Wrong dataset name')

# NOTE(review): 0.2 and 16 are presumably the dev-split fraction and batch size —
# confirm against get_data's signature.
train_dataloader, dev_dataloader, test_dataloader, _ = get_data(0.2, 16, dataset=dataset)

model, train_history, val_history, f1_history = train(model,
                                                      train_dataloader,
                                                      dev_dataloader,
                                                      criterion,
                                                      epochs=50,
                                                      patience=3)

_, predicted_labels, true_labels, test_loss = evaluate(model,
                                                       test_dataloader,
                                                       criterion)

f1 = f1_score(true_labels, predicted_labels, average='weighted')

print(f'Test loss: {test_loss:.3f}')
print(f'Test f1: {f1:.3f}')

Files already downloaded and verified
Files already downloaded and verified
-- Epoch 1 --


2450it [01:10, 34.69it/s]


train_loss: 0.500
val_loss: 0.498
val_f1: 0.027
-- Epoch 2 --


847it [00:25, 33.25it/s]


KeyboardInterrupt: 

In [7]:
# Bundle the per-epoch histories returned by `train` for serialization.
results = dict(
    train_history=train_history,
    val_history=val_history,
    f1_history=f1_history,
)

In [8]:
import json

# open() does not create missing parent directories, so make sure results/ exists.
os.makedirs('results', exist_ok=True)
with open(f'results/{dataset}_{loss}_results.json', 'w') as f:
    # NOTE(review): assumes the histories hold JSON-native values (e.g. Python floats);
    # numpy/torch scalars would need converting first.
    json.dump(results, f)

In [9]:
# Persist only the learned weights (state_dict). torch.save does not create missing
# parent directories, so make sure models/ exists first.
os.makedirs('models', exist_ok=True)
torch.save(model.state_dict(), f'models/{dataset}_{loss}.pth')