# Surprise similarity for text classification
Install the `surprise_similarity` package with:
 `pip install surprise_similarity`

## Data preparation
We'll use the `datasets` library to load the Yahoo! Answers dataset. Once the dataset is loaded, we select a subset of the training data to use for few-shot training. (We actually select several subsets, for statistical significance.) We then form the training and test datasets as lists of (input, target) tuples, which we use to train and evaluate our classifier.

In [1]:
# Seed Python's `random` module so the training-data shuffle is reproducible
import random
random.seed(666)
from datasets import load_dataset


# Maps each supported dataset name to its (input column, target column) pair.
dataset_input_target_mapping = dict(
    yahoo_answers_topics=("question_title", "topic"),
)

# get standardized train/test dataframes with columns "inputs" and "targets"
# and a list of all possible targets
def prepare_dataset(dataset_name="yahoo_answers_topics"):
    """Load a Hugging Face dataset as standardized train/test DataFrames.

    The dataset's input column (per ``dataset_input_target_mapping``) is
    renamed to ``"inputs"``, and a ``"targets"`` column holding the text
    label is added to both splits.

    Args:
        dataset_name: Key of ``dataset_input_target_mapping``; also passed to
            ``load_dataset``.

    Returns:
        ``(df_train, df_test, label_set)`` where ``label_set`` lists every
        class name in label-index order.

    Raises:
        KeyError: If ``dataset_name`` has no entry in
            ``dataset_input_target_mapping``. This is checked before loading,
            so an unsupported name fails without downloading anything.
    """
    import numbers

    if dataset_name not in dataset_input_target_mapping:
        raise KeyError(
            f"No input/target column mapping for dataset {dataset_name!r}; "
            "add it to dataset_input_target_mapping"
        )
    input_col, target_col = dataset_input_target_mapping[dataset_name]

    ds = load_dataset(dataset_name)
    df_train = ds["train"].to_pandas()
    df_test = ds["test"].to_pandas()
    df_train.rename(columns={input_col: "inputs"}, inplace=True)
    df_test.rename(columns={input_col: "inputs"}, inplace=True)

    # Assumes the target feature exposes ``.names`` (e.g. a ClassLabel) whose
    # position i is the text name of integer label i.
    int_to_text_target = dict(enumerate(ds["train"].features[target_col].names))
    label_set = list(int_to_text_target.values())

    def _to_text(label):
        # Integer class indices (Python *or* NumPy ints) become their names;
        # labels that are already text pass through unchanged.
        if isinstance(label, numbers.Integral):
            return int_to_text_target[label]
        return label

    df_train["targets"] = [_to_text(i) for i in df_train[target_col]]
    df_test["targets"] = [_to_text(i) for i in df_test[target_col]]
    return df_train, df_test, label_set

# get few-shot training and test data
def prepare_training_data(
    df_train,
    df_test,
    label_set,
    train_samples_per_label,
    random_seed,
    balanced_training=True,
):
    """Draw a few-shot training subset and pair up the full test split.

    Args:
        df_train: Training DataFrame with ``"inputs"`` and ``"targets"`` columns.
        df_test: Test DataFrame with ``"inputs"`` and ``"targets"`` columns.
        label_set: Labels to sample for; also fixes the total sample count.
        train_samples_per_label: Rows drawn per label (balanced) or the
            per-label budget multiplied by ``len(label_set)`` (unbalanced).
        random_seed: ``random_state`` passed to every ``DataFrame.sample`` call.
        balanced_training: If true, sample each label separately, in
            ``label_set`` order; otherwise sample uniformly from all rows.

    Returns:
        ``(training_pairs, test_pairs)``: lists of ``(input, target)`` tuples.
        ``test_pairs`` covers every row of ``df_test``.
    """

    def _as_pairs(frame):
        # (input, target) tuples in the frame's row order.
        return list(zip(frame["inputs"], frame["targets"]))

    if not balanced_training:
        n_total = train_samples_per_label * len(label_set)
        sampled = df_train.sample(n_total, random_state=random_seed)
        training_pairs = _as_pairs(sampled)
    else:
        training_pairs = []
        for target_label in label_set:
            rows_for_label = df_train.loc[df_train.targets == target_label]
            # pandas raises ValueError if a label has fewer rows than requested.
            drawn = rows_for_label.sample(
                train_samples_per_label, random_state=random_seed
            )
            training_pairs += _as_pairs(drawn)

    return training_pairs, _as_pairs(df_test)

  from .autonotebook import tqdm as notebook_tqdm


## Training
We will train a SurpriseSimilarity classifier on the Yahoo! Answers few-shot datasets and evaluate on the full test set using both F1 and accuracy as metrics.

In [2]:
from sklearn.metrics import f1_score, accuracy_score
from surprise_similarity import SurpriseSimilarity

# Use sklearn's f1_score and accuracy_score to evaluate the model
def f1(x, y):
    """Return the support-weighted F1 of predictions ``y`` against true labels ``x``.

    ``zero_division=0`` makes sklearn score ill-defined per-class
    precision/recall as 0.
    """
    weighted_f1 = f1_score(x, y, zero_division=0, average="weighted")
    return weighted_f1

# Train the model on the training data and evaluate on the test data, return f1 and accuracy
def train_and_run_on_testset(
    training_data,
    test_data,
):
    """Train a fresh SurpriseSimilarity classifier and score it on a test set.

    Args:
        training_data: List of ``(input_text, target_label)`` tuples.
        test_data: List of ``(input_text, target_label)`` tuples. Its distinct
            labels are the candidate classes offered to ``predict``.

    Returns:
        ``(f1_result, acc_result)``: weighted F1 and accuracy of the
        predictions against the test labels.
    """
    ss = SurpriseSimilarity()
    # Shuffle a copy with the module-level `random` RNG and disable the
    # library's own shuffle, so training order is reproducible.
    training_data = random.sample(training_data, len(training_data))
    ss.train(
        keys=[item[0] for item in training_data],
        queries=[item[1] for item in training_data],
        shuffle=False,  # defaults to true, use False here for better reproducibility
    )
    if ss.max_itns:
        print("Reached max iterations")
    print("Starting prediction on testset")
    test_keys = [item[0] for item in test_data]
    y_true = [item[1] for item in test_data]
    # dict.fromkeys keeps first-occurrence order. A plain set() would iterate
    # in an order that depends on per-process string-hash randomization,
    # making the candidate order differ between runs.
    candidate_labels = list(dict.fromkeys(y_true))
    predictions = ss.predict(keys=test_keys, queries=candidate_labels)
    f1_result = f1(y_true, predictions)
    acc_result = accuracy_score(y_true, predictions)
    return f1_result, acc_result

## Run experiments
We would like to measure the performance of the SurpriseSimilarity classifier as a function of the number of few-shot training examples. We would also like to know how this performance varies depending on which training examples are selected, so we run 5 experiments per training sample size to estimate 1-standard-deviation error bars.

This takes a while — for a quick experiment, reduce the maximum value in `train_samples_per_label_list` or reduce `n_runs_per_train_size`.

In [3]:
def execute_few_shot_experiment(train_samples_per_label_list=(3, 9, 27, 81, 243, 729),
                                n_runs_per_train_size=5,
                                dataset_name='yahoo_answers_topics',
                                balanced_training=True,
                                ):
    """Sweep few-shot training sizes, repeating each size several times.

    The default sizes are a tuple rather than a list so the shared default
    object can never be mutated between calls. Any iterable of ints
    (including a list) is accepted.

    Args:
        train_samples_per_label_list: Training examples per label to try.
        n_runs_per_train_size: Runs per size. Run ``r`` uses ``random_seed=r``
            for sampling, so the same seeds are reused across sizes.
        dataset_name: Dataset passed to ``prepare_dataset``.
        balanced_training: Forwarded to ``prepare_training_data``.

    Returns:
        Dict mapping each training size to ``{'f1': [...], 'acc': [...]}``,
        with one score per run.
    """
    df_train, df_test, label_set = prepare_dataset(dataset_name=dataset_name)

    results_per_train_size = {num: {'f1': [], 'acc': []} for num in train_samples_per_label_list}
    for train_samples_per_label in train_samples_per_label_list:
        print(f'Starting {n_runs_per_train_size} runs for {train_samples_per_label} training samples per label')
        for run_count in range(n_runs_per_train_size):
            print(f'Starting run {run_count}')
            train_io, test_io = prepare_training_data(df_train=df_train,
                                                      df_test=df_test,
                                                      label_set=label_set,
                                                      train_samples_per_label=train_samples_per_label,
                                                      random_seed=run_count,
                                                      balanced_training=balanced_training,
                                                      )
            f1_result, acc_result = train_and_run_on_testset(training_data=train_io,
                                                            test_data=test_io,
                                                            )
            results_per_train_size[train_samples_per_label]['f1'].append(f1_result)
            results_per_train_size[train_samples_per_label]['acc'].append(acc_result)
    return results_per_train_size

In [4]:
# Run the balanced sweep, stopping at 243 training examples per label.
balanced_few_shot_results = execute_few_shot_experiment(
    train_samples_per_label_list=[3, 9, 27, 81, 243],
)

Found cached dataset yahoo_answers_topics (/home/ubuntu/.cache/huggingface/datasets/yahoo_answers_topics/yahoo_answers_topics/1.0.0/0edb353eefe79d9245d7bd7cac5ae6af19530439da520d6dde1c206ee38f4439)
100%|██████████| 2/2 [00:00<00:00, 93.37it/s]


Starting 5 runs for 3 training samples per label
Starting run 0
Training on 300 examples...

Training time: 0:20min (9 iterations, F1: 0.9)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.31it/s]


Starting run 1
Training on 300 examples...

Training time: 0:16min (8 iterations, F1: 0.933)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.67it/s]


Starting run 2
Training on 300 examples...

Training time: 0:14min (8 iterations, F1: 0.9)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:27<00:00, 66.99it/s]


Starting run 3
Training on 300 examples...

Training time: 0:17min (9 iterations, F1: 0.967)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.47it/s]


Starting run 4
Training on 300 examples...

Training time: 0:18min (9 iterations, F1: 0.967)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.80it/s]


Starting 5 runs for 9 training samples per label
Starting run 0
Training on 900 examples...

Training time: 0:32min (5 iterations, F1: 0.911)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.83it/s]


Starting run 1
Training on 900 examples...

Training time: 0:31min (5 iterations, F1: 0.944)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.51it/s]


Starting run 2
Training on 900 examples...

Training time: 0:28min (5 iterations, F1: 0.911)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.81it/s]


Starting run 3
Training on 900 examples...

Training time: 0:30min (5 iterations, F1: 0.944)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.83it/s]


Starting run 4
Training on 900 examples...

Training time: 0:29min (5 iterations, F1: 0.933)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.73it/s]


Starting 5 runs for 27 training samples per label
Starting run 0
Training on 2700 examples...

Training time: 1:10min (4 iterations, F1: 0.967)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.95it/s]


Starting run 1
Training on 2700 examples...

Training time: 1:10min (4 iterations, F1: 0.941)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.76it/s]


Starting run 2
Training on 2700 examples...

Training time: 1:07min (4 iterations, F1: 0.97)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.63it/s]


Starting run 3
Training on 2700 examples...

Training time: 1:10min (4 iterations, F1: 0.981)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.96it/s]


Starting run 4
Training on 2700 examples...

Training time: 0:52min (3 iterations, F1: 0.911)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.86it/s]


Starting 5 runs for 81 training samples per label
Starting run 0
Training on 8100 examples...

Training time: 2:32min (3 iterations, F1: 0.957)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.60it/s]


Starting run 1
Training on 8100 examples...

Training time: 2:34min (3 iterations, F1: 0.972)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.88it/s]


Starting run 2
Training on 8100 examples...

Training time: 2:34min (3 iterations, F1: 0.968)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.65it/s]


Starting run 3
Training on 8100 examples...

Training time: 2:37min (3 iterations, F1: 0.984)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.65it/s]


Starting run 4
Training on 8100 examples...

Training time: 2:34min (3 iterations, F1: 0.965)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.61it/s]


Starting 5 runs for 243 training samples per label
Starting run 0
Training on 24280 examples...


Batches: 100%|██████████| 76/76 [00:01<00:00, 61.45it/s]
Batches: 100%|██████████| 76/76 [00:01<00:00, 68.27it/s]
Batches: 100%|██████████| 76/76 [00:01<00:00, 67.76it/s]



Training time: 5:06min (2 iterations, F1: 0.919)
Starting prediction on testset


Batches: 100%|██████████| 1874/1874 [00:28<00:00, 66.72it/s]


Starting run 1
Training on 24290 examples...


Batches: 100%|██████████| 76/76 [00:01<00:00, 66.53it/s]
Batches: 100%|██████████| 76/76 [00:01<00:00, 67.67it/s]
Batches: 100%|██████████| 76/76 [00:01<00:00, 67.49it/s]



Training time: 5:06min (2 iterations, F1: 0.93)
Starting prediction on testset


Batches: 100%|██████████| 1874/1874 [00:28<00:00, 66.59it/s]


Starting run 2
Training on 24290 examples...


Batches: 100%|██████████| 76/76 [00:01<00:00, 66.40it/s]
Batches: 100%|██████████| 76/76 [00:01<00:00, 67.47it/s]
Batches: 100%|██████████| 76/76 [00:01<00:00, 66.90it/s]



Training time: 5:07min (2 iterations, F1: 0.931)
Starting prediction on testset


Batches: 100%|██████████| 1875/1875 [00:28<00:00, 66.84it/s]


Starting run 3
Training on 24280 examples...


Batches: 100%|██████████| 76/76 [00:01<00:00, 66.51it/s]
Batches: 100%|██████████| 76/76 [00:01<00:00, 67.22it/s]
Batches: 100%|██████████| 76/76 [00:01<00:00, 66.96it/s]



Training time: 5:11min (2 iterations, F1: 0.923)
Starting prediction on testset


Batches: 100%|██████████| 1874/1874 [00:27<00:00, 67.10it/s]


Starting run 4
Training on 24300 examples...


Batches: 100%|██████████| 76/76 [00:01<00:00, 67.04it/s]
Batches: 100%|██████████| 76/76 [00:01<00:00, 68.35it/s]
Batches: 100%|██████████| 76/76 [00:01<00:00, 67.89it/s]



Training time: 5:06min (2 iterations, F1: 0.929)
Starting prediction on testset


Batches: 100%|██████████| 1874/1874 [00:27<00:00, 66.93it/s]


In [6]:
import numpy as np
# Collapse the per-run scores for each training size (in dict insertion order)
# into parallel lists of means and spreads.
# NOTE: np.std uses its default ddof=0 (population std across runs).
f1_means, f1_error_bars = [], []
acc_means, acc_error_bars = [], []
for k, v in balanced_few_shot_results.items():
    for metric, means, spreads in (('f1', f1_means, f1_error_bars),
                                   ('acc', acc_means, acc_error_bars)):
        means.append(np.mean(v[metric]))
        spreads.append(np.std(v[metric]))


In [23]:
# Print one row per training size: "mean±std" for F1 and accuracy,
# each rounded to 3 decimals.
header = ('training e.g. per label', '      F1      ', '      Acc      ')
print(*header)
print('-' * 60)
for i, k in enumerate(balanced_few_shot_results):
    size_col = f'         {k}' + ' ' * (7 - len(str(k)))
    f1_col = f'        {round(f1_means[i], 3)}\u00B1{round(f1_error_bars[i], 3)}'
    acc_col = f'        {round(acc_means[i], 3)}\u00B1{round(acc_error_bars[i], 3)}'
    print(size_col, f1_col, acc_col)

training e.g. per label       F1             Acc      
------------------------------------------------------------
         3               0.635±0.004         0.637±0.004
         9               0.667±0.003         0.668±0.003
         27              0.682±0.004         0.684±0.004
         81              0.688±0.002         0.689±0.002
         243             0.698±0.002         0.7±0.002
