# RPT (Research Paper Tagger)

In [1]:
import os
import zipfile
import json
import random
from tqdm import tqdm
import plotly
import plotly.express as px
import plotly.graph_objects as go

import numpy as np
import pandas as pd

from helpers import tokenize_and_format, flat_accuracy

import torch
from transformers import BertForSequenceClassification, AdamW, BertConfig, get_linear_schedule_with_warmup

from sklearn.metrics import precision_recall_fscore_support, top_k_accuracy_score

In [2]:
random.seed(0)
np.random.seed(0)

torch.manual_seed(0)
torch.use_deterministic_algorithms(False)
# Confirm that the GPU is detected

assert torch.cuda.is_available()

# Get the GPU device name.
device_name = torch.cuda.get_device_name()
n_gpu = torch.cuda.device_count()
print(f"Found device: {device_name}, n_gpu: {n_gpu}")
device = torch.device("cuda")

Found device: NVIDIA GeForce RTX 2060 with Max-Q Design, n_gpu: 1


In [3]:
with open("Data/Raw data/training_data.jsonl", "r") as f:
    training_data = json.load(f)
    
with open("Data/Raw data/validation_data.jsonl", "r") as f:
    validation_data = json.load(f)
    
with open("Data/Raw data/test_data.jsonl", "r") as f:
    test_data = json.load(f)
    
with open("Data/Metadata/label_string_to_ID.jsonl", "r") as f:
    label_string_to_ID = json.load(f)
    
with open("Data/Metadata/label_ID_to_string.jsonl", "r") as f:
    label_ID_to_string = json.load(f)

### Predictions using only abstract

In [4]:
training_inputs = []
training_label_strings = []

validation_inputs = []
validation_label_strings = []

test_inputs = []
test_label_strings = []

for training_example in training_data:
    
    training_input = training_example[0][0] + '. ' + training_example[0][2]
    training_inputs.append(training_input)
    
    training_label_strings.append(training_example[1])
    
for validation_example in validation_data:
    
    validation_input = validation_example[0][0] + '. ' + validation_example[0][2]
    validation_inputs.append(validation_input)
    
    validation_label_strings.append(validation_example[1])
    
for test_example in test_data:
    
    test_input = test_example[0][0] + '. ' + test_example[0][2]
    test_inputs.append(test_input)
    
    test_label_strings.append(test_example[1])

In [5]:
max_seq_length = 332

training_input_ids, training_attention_masks = tokenize_and_format(training_inputs, max_seq_length)
validation_input_ids, validation_attention_masks = tokenize_and_format(validation_inputs, max_seq_length)
test_input_ids, test_attention_masks = tokenize_and_format(test_inputs, max_seq_length)

In [6]:
training_label_IDs = []
validation_label_IDs = []
test_label_IDs = []

for training_label_string in training_label_strings:
    training_label_IDs.append(label_string_to_ID[training_label_string])
    
for validation_label_string in validation_label_strings:
    validation_label_IDs.append(label_string_to_ID[validation_label_string])
    
for test_label_string in test_label_strings:
    test_label_IDs.append(label_string_to_ID[test_label_string])
    
    
# Convert the lists into tensors.
training_input_ids = torch.cat(training_input_ids, dim=0)
training_attention_masks = torch.cat(training_attention_masks, dim=0)
training_label_IDs = torch.tensor(training_label_IDs)

validation_input_ids = torch.cat(validation_input_ids, dim=0)
validation_attention_masks = torch.cat(validation_attention_masks, dim=0)
validation_label_IDs = torch.tensor(validation_label_IDs)

test_input_ids = torch.cat(test_input_ids, dim=0)
test_attention_masks = torch.cat(test_attention_masks, dim=0)
test_label_IDs = torch.tensor(test_label_IDs)

In [7]:
train_set = [(training_input_ids[i], training_attention_masks[i], training_label_IDs[i]) for i in range(len(training_inputs))]
val_set = [(validation_input_ids[i], validation_attention_masks[i], validation_label_IDs[i]) for i in range(len(validation_inputs))]
test_set = [(test_input_ids[i], test_attention_masks[i], test_label_IDs[i]) for i in range(len(test_inputs))]

#### Fine-tune the BERT model

In [10]:
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab.
    num_labels = 20, # The number of output labels.   
    output_attentions = False, # Whether the model returns attentions weights.
    output_hidden_states = False, # Whether the model returns all hidden-states.
)

# Tell pytorch to run this model on the GPU.
model.cuda()





hyperparameter_config_iter = 1

save_path = "Saved models/Hyperparameter configuration " + str(hyperparameter_config_iter)

if(os.path.exists(save_path)):
    raise Exception("ERROR! Hyperparameter config " + str(hyperparameter_config_iter))

else:
    os.makedirs(save_path)
    os.makedirs(save_path + "/Plots")


# Fine-tuning hyperparameters

batch_size = 4
full_batch_size = 32

no_of_steps = int(full_batch_size/batch_size)

optimizer = AdamW(model.parameters(),
                  lr = 5e-5, # args.learning_rate - default is 5e-5
                  eps = 1e-8 # args.adam_epsilon  - default is 1e-8
                )
epochs = 30

hyperparameter_dict = dict()
hyperparameter_dict['batch_size'] = full_batch_size
hyperparameter_dict['epochs'] = epochs

with open(save_path + "/Hyperparameters.json", 'w') as f:
    json.dump(hyperparameter_dict, f)

def save(model, optimizer, output_path):
    torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, output_path)

# function to get validation accuracy
def get_performance(data_set):
    # Put the model in evaluation mode
    model.eval()

    # Tracking variables 
    total_eval_accuracy = 0
    total_eval_loss = 0

    num_batches = int(len(data_set)/batch_size) + 1

    total_correct = 0

    for i in range(num_batches):

        end_index = min(batch_size * (i+1), len(data_set))

        batch = data_set[i*batch_size:end_index]

        if len(batch) == 0: continue

        input_id_tensors = torch.stack([data[0] for data in batch])
        input_mask_tensors = torch.stack([data[1] for data in batch])
        label_tensors = torch.stack([data[2] for data in batch])

        # Move tensors to the GPU
        b_input_ids = input_id_tensors.to(device)
        b_input_mask = input_mask_tensors.to(device)
        b_labels = label_tensors.to(device)

        # Tell pytorch not to bother with constructing the compute graph during
        # the forward pass, since this is only needed for backprop (training).
        with torch.no_grad():        

            # Forward pass, calculate logit predictions.
            outputs = model(b_input_ids,
                                    token_type_ids=None,
                                    attention_mask=b_input_mask,
                                    labels=b_labels)
            loss = outputs.loss
            logits = outputs.logits

            # Accumulate the validation loss.
            total_eval_loss += loss.item()

            # Move logits and labels to CPU
            logits = logits.detach().cpu().numpy()
            label_ids = b_labels.to('cpu').numpy()

            # Calculate the number of correctly labeled examples in batch
            pred_flat = np.argmax(logits, axis=1).flatten()
            labels_flat = label_ids.flatten()
            num_correct = np.sum(pred_flat == labels_flat)
            total_correct += num_correct
        
    # Report the final accuracy for this validation run.
    avg_val_accuracy = total_correct / len(data_set)
    return avg_val_accuracy



# training loop

max_val_acc = -1

metric_vs_epoch = dict()

epoch_list = []
training_loss_list = []
training_acc_list = []

val_acc_list = []

# For each epoch...
for epoch_i in range(0, epochs):
    
    epoch_list.append(epoch_i + 1)
    
    # Perform one full pass over the training set.

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    # Reset the total loss for this epoch.
    total_train_loss = 0

    # Put the model into training mode.
    model.train()

    # For each batch of training data...
    num_batches = int(len(train_set)/batch_size) + 1

    for i in tqdm(range(num_batches)):
        
        end_index = min(batch_size * (i+1), len(train_set))

        batch = train_set[i*batch_size:end_index]

        if len(batch) == 0: continue

        input_id_tensors = torch.stack([data[0] for data in batch])
        input_mask_tensors = torch.stack([data[1] for data in batch])
        label_tensors = torch.stack([data[2] for data in batch])

        # Move tensors to the GPU
        b_input_ids = input_id_tensors.to(device)
        b_input_mask = input_mask_tensors.to(device)
        b_labels = label_tensors.to(device)

        if(i%no_of_steps == 0):
            # Clear the previously calculated gradient
            model.zero_grad()

        # Perform a forward pass (evaluate the model on this training batch).
        outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask, 
                            labels=b_labels)
        loss = outputs.loss
        logits = outputs.logits

        total_train_loss += loss.item()

        # Perform a backward pass to calculate the gradients.
        loss.backward()

        if((i+1)%no_of_steps == 0):
            # Update parameters and take a step using the computed gradient.
            optimizer.step()
        
    # ========================================
    #               Validation
    # ========================================
    # After the completion of each training epoch, measure our performance on
    # our validation set. Implement this function in the cell above.
    
    training_acc = get_performance(train_set)
    val_acc = get_performance(val_set)
    
    print(f"Total loss: {total_train_loss}")
    print(f"Validation accuracy: {val_acc}")
    
    val_acc_list.append(val_acc)
    training_acc_list.append(training_acc)
    
    training_loss_list.append(total_train_loss)
    
    if(val_acc > max_val_acc):
        
        max_val_acc = val_acc
        
        model.save_pretrained(save_path + "/best validation accuracy model")
        save(model, optimizer, save_path + "/best validation accuracy.modelState")
    
    
print("")
print("Training complete!")            

metric_vs_epoch["Epochs"] = epoch_list
metric_vs_epoch["Training loss"] = training_loss_list
metric_vs_epoch["Training accuracy"] = training_acc_list
metric_vs_epoch["Validation accuracy"] = val_acc_list

with open(save_path + "/Plots/Plot data.json", 'w') as f:
    json.dump(metric_vs_epoch, f)

metric_vs_epoch_df = pd.DataFrame(metric_vs_epoch, columns = ["Epochs", "Training loss", "Training accuracy", "Validation accuracy"])

fig = px.line(metric_vs_epoch_df, x='Epochs', y="Training loss", title="Training loss vs epochs")
plotly.offline.plot(fig, filename = save_path + "/Plots/Training loss.html")

accuracy_vs_epoch = dict()
accuracy_vs_epoch["Epochs"] = epoch_list + epoch_list
accuracy_vs_epoch["Accuracy"] = training_acc_list + val_acc_list
accuracy_vs_epoch["Dataset"] = ["Training"]*len(training_acc_list) + ["Validation"]*len(val_acc_list)

accuracy_vs_epoch_df = pd.DataFrame(accuracy_vs_epoch, columns = ["Epochs", "Accuracy", "Dataset"])

fig = px.line(accuracy_vs_epoch_df, x='Epochs', y='Accuracy', color='Dataset', markers=True, title="Training/Validation accuracy vs epochs")
plotly.offline.plot(fig, filename = save_path + "/Plots/Accuracy.html")

with open(save_path + "/Best validation accuracy.txt", 'w') as f:
    f.write("Best validation accuracy: " + str(max_val_acc))

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at


Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]


Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:31<00:00,  3.81it/s]



Total loss: 951.603481054306
Validation accuracy: 0.3657142857142857
Total loss: 951.603481054306
Validation accuracy: 0.3657142857142857

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]


Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:28<00:00,  3.93it/s]



Total loss: 726.5376545786858
Validation accuracy: 0.5142857142857142
Total loss: 726.5376545786858
Validation accuracy: 0.5142857142857142

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]


Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:28<00:00,  3.90it/s]



Total loss: 506.66910749673843
Validation accuracy: 0.5657142857142857
Total loss: 506.66910749673843
Validation accuracy: 0.5657142857142857

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]


Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:29<00:00,  3.89it/s]



Total loss: 346.2321638315916
Validation accuracy: 0.5942857142857143
Total loss: 346.2321638315916
Validation accuracy: 0.5942857142857143

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]


Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:29<00:00,  3.88it/s]



Total loss: 254.2959809154272
Validation accuracy: 0.6228571428571429
Total loss: 254.2959809154272
Validation accuracy: 0.6228571428571429

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]


Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:29<00:00,  3.87it/s]



Total loss: 172.7014456242323
Validation accuracy: 0.6171428571428571

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 172.7014456242323
Validation accuracy: 0.6171428571428571

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:29<00:00,  3.87it/s]



Total loss: 103.61433965340257
Validation accuracy: 0.6457142857142857
Total loss: 103.61433965340257
Validation accuracy: 0.6457142857142857

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]


Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:29<00:00,  3.86it/s]



Total loss: 73.29466884583235
Validation accuracy: 0.64

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 73.29466884583235
Validation accuracy: 0.64

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:29<00:00,  3.86it/s]



Total loss: 51.6582138184458
Validation accuracy: 0.6171428571428571

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 51.6582138184458
Validation accuracy: 0.6171428571428571

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:29<00:00,  3.86it/s]



Total loss: 34.08404191862792
Validation accuracy: 0.6342857142857142

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 34.08404191862792
Validation accuracy: 0.6342857142857142

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:29<00:00,  3.86it/s]



Total loss: 23.732166479341686
Validation accuracy: 0.6228571428571429

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 23.732166479341686
Validation accuracy: 0.6228571428571429

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:29<00:00,  3.86it/s]



Total loss: 18.028556033037603
Validation accuracy: 0.6457142857142857

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 18.028556033037603
Validation accuracy: 0.6457142857142857

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:29<00:00,  3.86it/s]



Total loss: 14.819039726629853
Validation accuracy: 0.6457142857142857

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 14.819039726629853
Validation accuracy: 0.6457142857142857

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:29<00:00,  3.86it/s]



Total loss: 15.089112772606313
Validation accuracy: 0.64

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 15.089112772606313
Validation accuracy: 0.64

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:29<00:00,  3.86it/s]



Total loss: 13.307247541379184
Validation accuracy: 0.64

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 13.307247541379184
Validation accuracy: 0.64

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:29<00:00,  3.86it/s]



Total loss: 13.292059191968292
Validation accuracy: 0.6457142857142857

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 13.292059191968292
Validation accuracy: 0.6457142857142857

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:29<00:00,  3.86it/s]



Total loss: 13.528193207923323
Validation accuracy: 0.6457142857142857

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 13.528193207923323
Validation accuracy: 0.6457142857142857

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:30<00:00,  3.85it/s]



Total loss: 11.530293301213533
Validation accuracy: 0.6457142857142857

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 11.530293301213533
Validation accuracy: 0.6457142857142857

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:30<00:00,  3.85it/s]



Total loss: 13.549445481505245
Validation accuracy: 0.6342857142857142

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 13.549445481505245
Validation accuracy: 0.6342857142857142

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:30<00:00,  3.85it/s]



Total loss: 12.120482798432931
Validation accuracy: 0.64

Training...
Total loss: 12.120482798432931
Validation accuracy: 0.64

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:30<00:00,  3.85it/s]



Total loss: 11.768166802590713
Validation accuracy: 0.6457142857142857

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 11.768166802590713
Validation accuracy: 0.6457142857142857

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:30<00:00,  3.85it/s]



Total loss: 11.590393365127966
Validation accuracy: 0.64

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 11.590393365127966
Validation accuracy: 0.64

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:30<00:00,  3.85it/s]



Total loss: 10.766582627315074
Validation accuracy: 0.6457142857142857

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 10.766582627315074
Validation accuracy: 0.6457142857142857

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:30<00:00,  3.84it/s]



Total loss: 12.057971419068053
Validation accuracy: 0.6457142857142857

Training...
Total loss: 12.057971419068053
Validation accuracy: 0.6457142857142857

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:30<00:00,  3.84it/s]



Total loss: 11.059794471831992
Validation accuracy: 0.64

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 11.059794471831992
Validation accuracy: 0.64

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:30<00:00,  3.84it/s]



Total loss: 11.075509556103498
Validation accuracy: 0.6342857142857142

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 11.075509556103498
Validation accuracy: 0.6342857142857142

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:30<00:00,  3.84it/s]



Total loss: 12.423062508460134
Validation accuracy: 0.6342857142857142

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 12.423062508460134
Validation accuracy: 0.6342857142857142

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:30<00:00,  3.84it/s]



Total loss: 9.380897851195186
Validation accuracy: 0.6342857142857142

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 9.380897851195186
Validation accuracy: 0.6342857142857142

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:30<00:00,  3.85it/s]



Total loss: 12.25325699953828
Validation accuracy: 0.6342857142857142

Training...


  0%|                                                                                          | 0/347 [00:00<?, ?it/s]

Total loss: 12.25325699953828
Validation accuracy: 0.6342857142857142

Training...


100%|████████████████████████████████████████████████████████████████████████████████| 347/347 [01:30<00:00,  3.85it/s]



Total loss: 10.857569964369759
Validation accuracy: 0.6342857142857142

Training complete!
Total loss: 10.857569964369759
Validation accuracy: 0.6342857142857142

Training complete!


### Get the classification metrics for the best model

In [8]:
best_hyperparameter_configuration = "Hyperparameter configuration 2"

model = BertForSequenceClassification.from_pretrained(
    "Saved models/" + best_hyperparameter_configuration + "/best validation accuracy model/",
    local_files_only = True,
    output_hidden_states = False, # Whether the model returns all hidden-states.
)

model.cuda()


batch_size = 8

def get_outputs(data_set):
    # Put the model in evaluation mode
    model.eval()

    num_batches = int(len(data_set)/batch_size) + 1

    total_correct = 0
    
    outputs = []
    
    all_labels = []
    all_logits = []

    for i in range(num_batches):

        end_index = min(batch_size * (i+1), len(data_set))

        batch = data_set[i*batch_size:end_index]

        if len(batch) == 0: continue

        input_id_tensors = torch.stack([data[0] for data in batch])
        input_mask_tensors = torch.stack([data[1] for data in batch])
        label_tensors = torch.stack([data[2] for data in batch])

        # Move tensors to the GPU
        b_input_ids = input_id_tensors.to(device)
        b_input_mask = input_mask_tensors.to(device)
        b_labels = label_tensors.to(device)

        # Tell pytorch not to bother with constructing the compute graph during
        # the forward pass, since this is only needed for backprop (training).
        with torch.no_grad():        

            # Forward pass, calculate logit predictions.
            
            outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
            
            loss = outputs.loss
            logits = outputs.logits

            # Move logits and labels to CPU
            logits = logits.detach().cpu().numpy()
            label_ids = b_labels.to('cpu').numpy()

            # Calculate the number of correctly labeled examples in batch
            pred_flat = np.argmax(logits, axis=1).flatten()
            labels_flat = label_ids.flatten()
            
            all_labels.append(labels_flat)
            all_logits.append(logits)
            
    
    all_labels = np.concatenate(all_labels)
    all_logits = np.concatenate(all_logits)
    
    print("Top 1 accuracy: ", top_k_accuracy_score( y_true = all_labels, y_score = all_logits, k=1, labels = np.array(range(0, 20, 1)) ) )
    print("Top 3 accuracy: ", top_k_accuracy_score( y_true = all_labels, y_score = all_logits, k=3, labels = np.array(range(0, 20, 1)) ) )
    print("Top 5 accuracy: ", top_k_accuracy_score( y_true = all_labels, y_score = all_logits, k=5, labels = np.array(range(0, 20, 1)) ) )
    
    print("Precision, Recall, F-1:", precision_recall_fscore_support(y_true = all_labels, y_pred = np.argmax(all_logits, axis = 1), average='weighted'))
    
get_outputs(test_set)

Top 1 accuracy:  0.644808743169399
Top 3 accuracy:  0.7923497267759563
Top 5 accuracy:  0.8415300546448088
Precision, Recall, F-1: (0.7121894845077311, 0.644808743169399, 0.6407995012179304, None)
