### Read the data

In [None]:
import os
import json
import torch 
import numpy as np

DATA_PATH = "leopard/data/json"

def get_categories():
    """Return the category directory names under ``DATA_PATH`` to evaluate.

    The ``restaurant`` and ``conll`` entries are excluded, and ``emotion``
    (when present) is moved to the end of the list so it is processed last.
    The remaining names are sorted so the processing order does not depend
    on the arbitrary order in which ``os.listdir`` reports entries.

    Returns:
        list[str]: category names, with ``emotion`` last if it exists.
    """
    excluded = {"restaurant", "conll"}
    last_category = "emotion"

    # Filtering (rather than list.remove) tolerates a data directory in which
    # one of the excluded categories is not present.
    categories = sorted(
        name for name in os.listdir(DATA_PATH) if name not in excluded
    )

    # move emotion to the end
    if last_category in categories:
        categories.remove(last_category)
        categories.append(last_category)
    return categories

def get_labelled_training_sentences(category, shot, episode):
    """Load the training sentences of one category for a given shot and episode.

    Every file in ``DATA_PATH/<category>/`` whose name ends with
    ``_<episode>_<shot>.json`` is read. For each record the ``[CLS]``,
    ``[SEP]`` and ``[PAD]`` markers are removed from ``processed_sent`` and
    the ``label`` is mapped to an integer index, assigned in order of first
    appearance.

    Args:
        category: name of the sub-directory of ``DATA_PATH`` to read.
        shot: shot count used in the file-name suffix.
        episode: episode number used in the file-name suffix.

    Returns:
        tuple: ``(sentences, labels, label_keys)`` where ``labels[i]`` is the
        integer index of ``sentences[i]`` and ``label_keys`` maps each
        original label to its index. All three are empty when no file matches.
    """
    sentences = []
    labels = []
    label_keys = {}
    data_path = DATA_PATH + "/" + category + "/"
    suffix = "_" + str(episode) + "_" + str(shot) + ".json"
    for file_name in os.listdir(data_path):
        if not file_name.endswith(suffix):
            continue
        # The context manager closes the file promptly; JSON text is UTF-8.
        with open(data_path + file_name, encoding="utf-8") as data_file:
            data = json.load(data_file)
        for record in data:
            processed_sentence = record['processed_sent']
            for marker in ('[CLS]', '[SEP]', '[PAD]'):
                processed_sentence = processed_sentence.replace(marker, '')
            sentences.append(processed_sentence)
            # convert categorical labels to numeric values: an unseen label
            # gets the next free index, a known label reuses its index
            labels.append(label_keys.setdefault(record['label'], len(label_keys)))
    return sentences, labels, label_keys

def get_labelled_test_sentences(category):
    """Load the evaluation sentences and their original labels for a category.

    Every file in ``DATA_PATH/<category>/`` whose name ends with
    ``_eval.json`` is read. The ``[CLS]``, ``[SEP]`` and ``[PAD]`` markers
    are removed from each ``processed_sent``; labels are returned exactly as
    stored in the file (they are not converted to indices here).

    Args:
        category: name of the sub-directory of ``DATA_PATH`` to read.

    Returns:
        tuple: ``(sentences, labels)`` as two parallel lists.
    """
    sentences = []
    labels = []
    data_path = DATA_PATH + "/" + category + "/"
    for file_name in os.listdir(data_path):
        if not file_name.endswith("_eval.json"):
            continue
        # The context manager closes the file promptly; JSON text is UTF-8.
        with open(data_path + file_name, encoding="utf-8") as data_file:
            data = json.load(data_file)
        for record in data:
            processed_sentence = record['processed_sent']
            for marker in ('[CLS]', '[SEP]', '[PAD]'):
                processed_sentence = processed_sentence.replace(marker, '')
            sentences.append(processed_sentence)
            labels.append(record['label'])
    return sentences, labels

### Create custom dataset objects

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Dataset that pairs each stored encoding with the label at the same index."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        """Return the total number of samples, i.e. the number of encodings."""
        return len(self.encodings)

    def __getitem__(self, index):
        """Return one sample as an ``(encoding, label)`` pair."""
        encoding = self.encodings[index]
        label = self.labels[index]
        return encoding, label

### Create training and test splits

In [None]:
import ast

# Shell command lines that invoke mt-dnn/extractor.py with single-input.txt as
# --finput and single-output.json as --foutput. The two commands differ only in
# the --checkpoint file they pass.
EXTRACT_FEATURES_COMMAND_BASE = "python mt-dnn/extractor.py --do_lower_case --finput mt-dnn/input_examples/single-input.txt --foutput mt-dnn/input_examples/single-output.json --bert_model bert-base-uncased --checkpoint mt-dnn/mt_dnn_models/mt_dnn_base_uncased.pt"
# NOTE(review): this command passes --bert_model bert-base-uncased together
# with the large checkpoint; confirm whether bert-large-uncased was intended.
# It is not referenced by any code visible in this notebook section.
EXTRACT_FEATURES_COMMAND_LARGE = "python mt-dnn/extractor.py --do_lower_case --finput mt-dnn/input_examples/single-input.txt --foutput mt-dnn/input_examples/single-output.json --bert_model bert-base-uncased --checkpoint mt-dnn/mt_dnn_models/mt_dnn_large_uncased.pt"

def get_labelled_training_data(category, shot, episode):
    """Encode the training sentences of one category, shot and episode.

    The sentences are written one per line to
    ``mt-dnn/input_examples/single-input.txt``, the extractor command
    ``EXTRACT_FEATURES_COMMAND_BASE`` is run, and the encodings are read back
    from ``mt-dnn/input_examples/single-output.json``.

    Args:
        category: category name passed to ``get_labelled_training_sentences``.
        shot: shot count passed to ``get_labelled_training_sentences``.
        episode: episode number passed to ``get_labelled_training_sentences``.

    Returns:
        tuple: ``(training_encodings, training_labels, label_keys)`` where
        ``training_encodings`` is a list of ``float32`` NumPy arrays. The
        encodings list is empty when there are no training sentences.

    Raises:
        RuntimeError: if the extractor command exits with a non-zero status.
    """
    sentences, training_labels, label_keys = get_labelled_training_sentences(category, shot, episode)
    if not sentences:
        # Nothing to encode: skip the extractor run so that encodings left in
        # the output file by a previous call cannot be picked up by mistake.
        return [], training_labels, label_keys

    # write all sentences to the input file
    with open("mt-dnn/input_examples/single-input.txt", 'w', encoding='utf-8') as writer:
        writer.write('\n'.join(sentences))

    # execute the command to get encodings; a failed run would otherwise leave
    # a stale (or missing) output file that is silently read below
    exit_status = os.system(EXTRACT_FEATURES_COMMAND_BASE)
    if exit_status != 0:
        raise RuntimeError(
            "feature extraction failed with exit status " + str(exit_status)
        )

    # fetch sentence encodings from the output file
    with open('mt-dnn/input_examples/single-output.json', 'r') as data_file:
        encodings_data = json.load(data_file)

    # Each entry stores its vector as a Python-literal string under key '11'
    # (presumably the index of the encoder layer - TODO confirm).
    training_encodings = [
        np.array(ast.literal_eval(encoding['11']), dtype=np.float32)
        for encoding in encodings_data
    ]
    return training_encodings, training_labels, label_keys

In [None]:
def get_labelled_test_data(category):
    """Encode the evaluation sentences of one category.

    The sentences are written one per line to
    ``mt-dnn/input_examples/single-input.txt``, the extractor command
    ``EXTRACT_FEATURES_COMMAND_BASE`` is run, and the encodings are read back
    from ``mt-dnn/input_examples/single-output.json``.

    Args:
        category: category name passed to ``get_labelled_test_sentences``.

    Returns:
        tuple: ``(test_encodings, test_labels)`` where ``test_encodings`` is a
        list of ``float32`` NumPy arrays and ``test_labels`` holds the labels
        as returned by ``get_labelled_test_sentences``. The encodings list is
        empty when there are no evaluation sentences.

    Raises:
        RuntimeError: if the extractor command exits with a non-zero status.
    """
    sentences, test_labels = get_labelled_test_sentences(category)
    if not sentences:
        # Nothing to encode: skip the extractor run so that encodings left in
        # the output file by a previous call cannot be picked up by mistake.
        return [], test_labels

    # write all sentences to the input file
    with open("mt-dnn/input_examples/single-input.txt", 'w', encoding='utf-8') as writer:
        writer.write('\n'.join(sentences))

    # execute the command to get encodings; a failed run would otherwise leave
    # a stale (or missing) output file that is silently read below
    exit_status = os.system(EXTRACT_FEATURES_COMMAND_BASE)
    if exit_status != 0:
        raise RuntimeError(
            "feature extraction failed with exit status " + str(exit_status)
        )

    # fetch sentence encodings from the output file
    with open('mt-dnn/input_examples/single-output.json', 'r') as data_file:
        encodings_data = json.load(data_file)

    # Each entry stores its vector as a Python-literal string under key '11'
    # (presumably the index of the encoder layer - TODO confirm).
    test_encodings = [
        np.array(ast.literal_eval(encoding['11']), dtype=np.float32)
        for encoding in encodings_data
    ]
    return test_encodings, test_labels

### Define the classifier

In [None]:
import torch.nn as nn

# Default width of one sentence encoding fed to the classifier
# (presumably the hidden size of the base encoder - TODO confirm).
INPUT_DIMS = 768


def get_model(output_dims, input_dims=None):
    """Build the classifier: a single linear layer.

    Args:
        output_dims: number of output units (one per class).
        input_dims: width of each input encoding. When omitted, the current
            value of ``INPUT_DIMS`` is used, so existing callers are
            unaffected; pass a different width for encodings of another size.

    Returns:
        nn.Linear: a freshly initialised ``input_dims -> output_dims`` layer.
    """
    if input_dims is None:
        # Resolved at call time so a later reassignment of INPUT_DIMS is honoured.
        input_dims = INPUT_DIMS
    return nn.Linear(input_dims, output_dims)

### Train the classifier

In [None]:
import torch.optim as optim

def train(model, trainloader, epochs, lr=2e-5):
    """Train ``model`` in place with cross-entropy loss and AdamW.

    Args:
        model: module whose parameters are optimised.
        trainloader: iterable yielding ``(inputs, labels)`` batches.
        epochs: number of full passes over ``trainloader``.
        lr: learning rate given to AdamW; defaults to the previously
            hard-coded value of ``2e-5``.

    Returns:
        None. The model's parameters are updated in place.
    """
    criterion = nn.CrossEntropyLoss()
    optimiser = optim.AdamW(model.parameters(), lr=lr)

    for _ in range(epochs):  # loop over the dataset multiple times
        for inputs, labels in trainloader:
            # zero the parameter gradients
            optimiser.zero_grad()
            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimiser.step()

### Test the classifier

In [None]:
def test(model, test_loader):
  correct = 0
  total = 0
  # since we're not training, we don't need to calculate the gradients for our outputs
  with torch.no_grad():
    for data in test_loader:
      encodings, labels = data
      # calculate outputs by running images through the network
      outputs = model(encodings)
      # the class with the highest energy is what we choose as prediction
      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()
  return correct, total

### Get the results for all episodes

In [None]:
# For every category: encode the test set once, then for each episode and
# shot size
#   - get the training set
#   - get a model from it and train the model
#   - check the model performance on the test set
# and finally report mean and standard deviation of the accuracy per shot.

import warnings

warnings.filterwarnings('ignore')

# Parameters passed to both DataLoaders
params = {'batch_size': 4,
          'shuffle': True,
          'num_workers': 0}

# number of training epochs for each shot size
epochs = {4: 100,
          8: 125,
          16: 150}

shots = [4, 8, 16]

display_stats = False

for category in get_categories():
    test_encodings, categorical_test_labels = get_labelled_test_data(category)
    accuracies = {shot: [] for shot in shots}
    for episode in range(3):
        for shot in shots:
            training_encodings, training_labels, label_keys = get_labelled_training_data(category, shot, episode)
            if not training_encodings:
                # no training data for this episode/shot combination
                continue

            # convert categorical attributes to numeric indices; this lookup
            # raises KeyError for a test label absent from the training split
            test_labels = [label_keys[label] for label in categorical_test_labels]

            # define the model: one output per label index (indices start at 0)
            classes = max(training_labels) + 1
            model = get_model(output_dims=classes)

            # create the dataloaders for training and test splits
            training_set = Dataset(training_encodings, training_labels)
            train_loader = torch.utils.data.DataLoader(training_set, **params)

            test_set = Dataset(test_encodings, test_labels)
            test_loader = torch.utils.data.DataLoader(test_set, **params)

            # train the model
            train(model, train_loader, epochs=epochs[shot])

            # test the model performance
            correct, total = test(model, test_loader)
            accuracy = correct/total

            if display_stats:
                print("The training split is", len(training_encodings))
                print("The test split is", len(test_encodings))
                print("The number of classes are", classes)
                print("For category", category, "and shot =", str(shot) + "...")
                print("Accuracy is", accuracy, "\n")
            accuracies[shot].append(accuracy)
    print("\n" + "For category " + category + "...")
    for shot in shots:
        print("The accuracy is", round(np.mean(accuracies[shot]), 4), "+-", round(np.std(accuracies[shot]), 4), "for shot =", shot)
    print("\n")