## 1.0 Introduction

In this notebook we explore few-shot in-context learning. Randomly selected samples from the training dataset are added to the prompt as examples, and the model is then asked to classify a text from the validation dataset.

The training and validation CSV files are the same as those used in all other notebooks in this repository.

We use the GPT-3.5 Turbo model with a 16K-token context window, which is more than large enough to hold the generated prompt including all added training examples.
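A quick back-of-the-envelope check shows why 16K tokens is plenty. This is a sketch under stated assumptions, not a measurement: it assumes every text is capped at 192 words (the `MAX_WORDS` value used later in this notebook), that Dutch news text averages roughly 2 tokens per word, and that the instructions add about 300 tokens. For exact counts, OpenAI's `tiktoken` library can tokenize the actual prompt.

```python
# Rough upper-bound estimate of the prompt size for the few-shot prompt.
# ASSUMPTIONS: texts are capped at 192 words, ~2 tokens per Dutch word,
# and ~300 tokens of instruction overhead. None of these were measured here.
CONTEXT_WINDOW = 16_385  # context size of gpt-3.5-turbo-1106 in tokens


def estimate_prompt_tokens(n_texts: int, max_words: int,
                           tokens_per_word: float = 2.0,
                           overhead_tokens: int = 300) -> int:
    """Return an upper-bound estimate of the prompt length in tokens."""
    return int(n_texts * max_words * tokens_per_word) + overhead_tokens


# 3 examples per label x 2 labels + 1 text to classify = 7 texts
estimate = estimate_prompt_tokens(n_texts=2 * 3 + 1, max_words=192)
print(estimate, estimate < CONTEXT_WINDOW)
```

Even if the tokens-per-word assumption were off by a factor of two, the prompt would still fit comfortably.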

!Note: The code in this notebook is based on the new OpenAI Python API (`openai` package version 1.x or higher).
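Because the 1.x client (`from openai import OpenAI`) is not available in older 0.x releases, it can help to check the installed version up front. A minimal sketch using only the standard library's `importlib.metadata`; the helper name `is_openai_v1_or_newer` is hypothetical:

```python
from importlib.metadata import PackageNotFoundError, version


def is_openai_v1_or_newer(version_string: str) -> bool:
    """Return True if a version string such as '1.3.5' has major version >= 1."""
    major = int(version_string.split(".")[0])
    return major >= 1


try:
    installed = version("openai")  # reads the installed package metadata
    print(installed, is_openai_v1_or_newer(installed))
except PackageNotFoundError:
    print("The openai package is not installed")
```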

In [1]:
# Import Modules
import os
import numpy as np
import pandas as pd
import time
from sklearn.metrics import classification_report
from tqdm.notebook import tqdm

# OpenAI
from openai import OpenAI

## 2.0 Load Datasets

We will load the training and validation CSV files that were generated earlier with the notebook 'Prepare_Train_and_Validation_Datasets.ipynb'.

The validation dataset is used in the same way as in all other notebooks. The training dataset, however, is used to repeatedly draw random samples that serve as examples in the prompt.

In [2]:
# Load Datasets
train_df = pd.read_csv('./data/train_df.csv')
val_df = pd.read_csv('./data/val_df.csv')

# Summary
print(train_df.shape)
print(val_df.shape)

(3069, 11)
(1559, 11)


Let's review a small subset of the training data...

In [3]:
# Summary
train_df.head()

| Unnamed: 0 | id | title | text | mainSection | published_at | publisher | partisan | url | text_wordcount | max_words_text | labels |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 10706318 | Ogen als schoteltjes bij de Tachtigjarige Oorlog | Ogen als schoteltjes bij de Tachtigjarige Oorl... | /home | 2018-10-07 | trouw | True | www.trouw.nl/home/ogen-als-schoteltjes-bij-de-... | 539 | Ogen als schoteltjes bij de Tachtigjarige Oorl... | 1 |
| 1 | 12633805 | Geen beeld, maar een monument voor Mandela in ... | Geen beeld, maar een monument voor Mandela in ... | /amsterdam | 2019-05-10 | parool | True | www.parool.nl/amsterdam/geen-beeld-maar-een-mo... | 662 | Geen beeld, maar een monument voor Mandela in ... | 1 |
| 2 | 7140125 | Hoe ga je een onveilige arbeidscultuur zoals i... | Hoe ga je een onveilige arbeidscultuur zoals i... | / | 2017-04-18 | trouw | True |  | 494 | Hoe ga je een onveilige arbeidscultuur zoals i... | 1 |
| 3 | 4490774 | Wetenschappers ontdekken lichtgevende discokikker | Wetenschappers ontdekken lichtgevende discokik... | / | 2017-03-14 | trouw | True |  | 291 | Wetenschappers ontdekken lichtgevende discokik... | 1 |
| 4 | 10592180 | Meer fouten kabinet bij steun aan strijdgroepe... | Meer fouten kabinet bij steun aan strijdgroepe... | /home | 2018-09-11 | trouw | True | www.trouw.nl/home/meer-fouten-kabinet-bij-steu... | 471 | Meer fouten kabinet bij steun aan strijdgroepe... | 1 |


...and also the validation data...

In [4]:
# Summary
val_df.head()

| Unnamed: 0 | id | title | text | mainSection | published_at | publisher | partisan | url | text_wordcount | max_words_text | labels |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 9266995 | Verdachte dodelijke steekpartijen Maastricht l... | Verdachte dodelijke steekpartijen Maastricht l... | /nieuws | 2017-12-18 | ad | False | www.ad.nl/binnenland/verdachte-dodelijke-steek... | 188 | Verdachte dodelijke steekpartijen Maastricht l... | 0 |
| 1 | 4130077 | Honderden arrestaties bij acties tegen mensen ... | Honderden arrestaties bij acties tegen mensen ... | /nieuws | 2017-02-11 | ad | False | www.ad.nl/buitenland/honderden-arrestaties-bij... | 122 | Honderden arrestaties bij acties tegen mensen ... | 0 |
| 2 | 11147268 | Waarom de 'oudejaarsbonus' voor de jongeren va... | Waarom de 'oudejaarsbonus' voor de jongeren va... | /home | 2019-01-20 | trouw | True | www.trouw.nl/home/waarom-de-oudejaarsbonus-voo... | 262 | Waarom de 'oudejaarsbonus' voor de jongeren va... | 1 |
| 3 | 10749100 | Klaar voor de verdediging | Klaar voor de verdedigingOver ruim een week be... | /nieuws | 2018-10-16 | ad | False | www.ad.nl/binnenland/klaar-voor-de-verdediging... | 411 | Klaar voor de verdedigingOver ruim een week be... | 0 |
| 4 | 10700707 | Windvlaag grijpt springmatras en doodt 2-jarig... | Windvlaag grijpt springmatras en doodt 2-jarig... | /nieuws | 2018-10-05 | ad | False | www.ad.nl/buitenland/windvlaag-grijpt-springma... | 286 | Windvlaag grijpt springmatras en doodt 2-jarig... | 0 |


For in-context learning we further subsample the validation set to 100 samples each of Politiek (label 1) and Neutraal (label 0).

I ran many small experiments with different prompts and different sets of in-context examples. Across all experiments, the accuracy on the validation set varied between 40% and around 60%.
The dataset is clearly not well suited for few-shot classification.

Since this is one of my many side projects, I won't run up a nice credit card bill by validating on all validation samples ;-)
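The subsampling cell below draws a new random subset on every run, so results are not exactly reproducible. A minimal stdlib sketch of a seeded, balanced subsample (the helper `balanced_subsample` and the toy rows are hypothetical stand-ins for `val_df`); in pandas, the same effect comes from passing `random_state` to the existing `sample` calls, e.g. `sample(n = 100, random_state = 42)`.

```python
import random
from collections import Counter


def balanced_subsample(rows, label_key="labels", n_per_label=100, seed=42):
    """Draw n_per_label rows for every label, then shuffle the result.

    Using a dedicated random.Random(seed) makes the subset reproducible."""
    rng = random.Random(seed)
    by_label = {}
    for row in rows:
        by_label.setdefault(row[label_key], []).append(row)
    sample = []
    for label in sorted(by_label):
        sample.extend(rng.sample(by_label[label], n_per_label))
    rng.shuffle(sample)  # mix the labels, like .sample(frac = 1.0) in pandas
    return sample


# Toy stand-in for val_df: 300 neutral (0) and 250 political (1) rows
toy_rows = ([{"id": i, "labels": 0} for i in range(300)]
            + [{"id": i, "labels": 1} for i in range(300, 550)])
subset = balanced_subsample(toy_rows, n_per_label=100)
print(Counter(row["labels"] for row in subset))
```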

In [5]:
# Subsample Validation Dataset
label0_val_df = val_df[val_df.labels == 0].sample(n = 100)
label1_val_df = val_df[val_df.labels == 1].sample(n = 100)

# Concatenate
small_val_df = pd.concat([label0_val_df, label1_val_df]).sample(frac = 1.0)

# Summary
small_val_df.labels.value_counts()

labels
1    100
0    100
Name: count, dtype: int64

## 3.0 In-context Learning with OpenAI GPT-3.5 Model

In this section we determine the classification accuracy on the validation dataset when using few-shot in-context learning.

In each request to classify a sample text from the validation dataset, we provide 3 examples for each of the 2 labels.

The examples for the 2 labels will be sampled at random from the training dataset.

### 3.1 Settings and Support Functions

In [6]:
# Constants
MAX_WORDS = 192
N_SAMPLES = 3

# OpenAI API Key
client = OpenAI(api_key = os.environ["OPENAI_API_KEY"])

# Set Model
openai_model = "gpt-3.5-turbo-1106"

The following function randomly samples, from the training dataset, the examples that are put in the prompt (`N_SAMPLES` per label).

In [7]:
def get_incontext_samples(train_df):
    politic_samples = train_df[train_df.labels == 1].sample(n = N_SAMPLES)['max_words_text'].tolist()
    neutral_samples = train_df[train_df.labels == 0].sample(n = N_SAMPLES)['max_words_text'].tolist()

    return politic_samples, neutral_samples

### 3.2 In-context learning prompt

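Next is the function that creates the prompt. The system message (in Dutch) tells the model that it is a newspaper editor who judges whether an article is political or neutral, that it may answer only with one of the two labels 'Politiek' or 'Neutraal', and that each example below is marked with '# Voorbeeld: ' ("example") followed by the text, with 'Label: ' giving its label. The user message then lists the examples for the 2 labels.

In the final part the text to be classified is added, and the model is again asked ("Beoordeel de nu volgende tekst", "judge the following text") to respond with only one of the two allowed labels.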

In [8]:
def create_prompt(item_text, politic_samples, neutral_samples):
    # Base Prompt
    prompt_text = [
    {"role": "system", 
     "content": 
"""
Je bent redacteur bij een krant.
Je beoordeeld een krantenartikel of het politiek of neutraal is. 
Je mag de beoordeling alleen aangeven met 1 van de 2 volgende labels:
Politiek
Neutraal

Hieronder staan een aantal voorbeelden van Politieke en Neutrale teksten aangegeven met '# Voorbeeld: ' en vervolgens de tekst.
Eronder staat met 'Label: ' aangegeven of het Politiek of Neutraal is.

"""
     }, 
    {"role": "user", "content": ""}]

    # Set Text and Label
    prompt_text[1]['content'] = f"""
\n\n# Voorbeeld: {politic_samples[0]}
\nLabel: Politiek
\n\n# Voorbeeld: {neutral_samples[0]}
\nLabel: Neutraal 
\n\n# Voorbeeld: {politic_samples[1]}
\nLabel: Politiek
\n\n# Voorbeeld: {neutral_samples[1]}
\nLabel: Neutraal
\n\n# Voorbeeld: {politic_samples[2]}
\nLabel: Politiek
\n\n# Voorbeeld: {neutral_samples[2]}
\nLabel: Neutraal

Beoordeel de nu volgende tekst. 
Antwoord met alleen de labels Politiek of Neutraal.
\n{item_text}
Label: 
"""
    
    return prompt_text

### 3.3 Validation

In this part we loop through the validation samples and have the base (not fine-tuned) GPT-3.5 model determine the classification label based on the examples provided in the prompt.
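The loop below maps the reply to a label with exact string comparison, so a reply such as "Politiek." or " Neutraal\n" is treated as invalid and scored as wrong. Whether that happened in this run is not known. A more tolerant parser could look like the sketch below; `parse_label` is a hypothetical helper, and it also handles `None`, since `message.content` may be empty in the 1.x client.

```python
import re

LABELS = {"politiek": 1, "neutraal": 0}


def parse_label(response_text):
    """Map a model reply to 1 (Politiek), 0 (Neutraal) or None if unrecognised.

    Tolerates surrounding whitespace, case differences, a leading 'Label:'
    and trailing punctuation or quotes."""
    if response_text is None:
        return None
    cleaned = response_text.strip().lower()
    cleaned = re.sub(r"^label:\s*", "", cleaned)  # drop an echoed 'Label:' prefix
    cleaned = cleaned.strip(" .!\"'")
    return LABELS.get(cleaned)


print(parse_label(" Neutraal.\n"), parse_label("Label: Politiek"))
```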

In [9]:
# Placeholders
gt_labels, pred_labels = [], []

# Inference on Validation Dataframe
for index, row in tqdm(small_val_df.iterrows(), total = small_val_df.shape[0]):
    text = row['max_words_text']
    label = row['labels']
    gt_labels.append(label)

    # Get In Context Samples
    politics_samples, neutral_samples = get_incontext_samples(train_df)
    
    # Create Prompt
    prompt = create_prompt(text, politics_samples, neutral_samples)

    # Call the API for inference with the base (not fine-tuned) model.
    # Simple error handling: earlier this summer I occasionally experienced errors with the model not responding or not being available, and a simple wait period was enough to get it working again.
    # Recently I have had no such problems, but I keep the error handling in place.
    try:
        completion = client.chat.completions.create(model = openai_model,
                                                    messages = prompt, 
                                                    temperature = 0.0)
    except Exception as err:
        print(err)
        time.sleep(120.0)
        completion = client.chat.completions.create(model = openai_model,
                                                    messages = prompt,
                                                    temperature = 0.0)

    # Determine the label from the OpenAI chat completion response
    predicted_label = -1
    prediction_text = completion.choices[0].message.content
    if prediction_text == 'Politiek':
        predicted_label = 1
    if prediction_text == 'Neutraal':
        predicted_label = 0

    # Count the prediction as wrong if the response was neither 'Politiek' nor 'Neutraal'.
    if predicted_label == -1:
        predicted_label = 1 - label

    # Store Result
    pred_labels.append(predicted_label)
    #print(f'Index: {index}   GT: {label}   Pred: {predicted_label}')

    # Delay...To stay within the limits for tokens/requests per minute
    time.sleep(0.1)

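The loop retries a failed request exactly once, after a fixed 120-second wait, and a second failure stops the whole run. A common alternative is retrying with exponential backoff and jitter, sketched below as the hypothetical helper `call_with_retries`. The 1.x `OpenAI` client also accepts a `max_retries` argument for its built-in retries.

```python
import random
import time


def call_with_retries(fn, max_attempts=5, base_delay=1.0, max_delay=60.0,
                      sleep=time.sleep):
    """Call fn(); after each failure wait base_delay * 2**attempt (capped at
    max_delay) plus up to 50% random jitter, and re-raise the last exception
    once max_attempts calls have failed."""
    for attempt in range(max_attempts):
        try:
            return fn()
        except Exception as err:
            if attempt == max_attempts - 1:
                raise
            delay = min(max_delay, base_delay * 2 ** attempt)
            delay += random.uniform(0, delay / 2)  # jitter spreads out retries
            print(f"Attempt {attempt + 1} failed ({err}); retrying in {delay:.1f}s")
            sleep(delay)
```

In the loop it would wrap the API call, for example `completion = call_with_retries(lambda: client.chat.completions.create(model = openai_model, messages = prompt, temperature = 0.0))`.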

### 3.4 Classification Results

Below are the classification results for the in-context learning approach.

With an accuracy of 56% on the balanced validation subset, only slightly above the 50% expected from random guessing, in-context learning clearly performs poorly on this particular dataset.

This is far below the 85% to 90% validation accuracy we could achieve by fine-tuning an OpenAI GPT-3.5 (or any other) model.
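How far is 56% from guessing? On the 200-sample subset, 56% means 112 correct predictions. A one-sided exact binomial test against a fair coin (p = 0.5) gives a p-value of roughly 0.05, so the result is only borderline distinguishable from chance. Since the in-context examples are drawn at random per request, a repeated run would also give a somewhat different accuracy. A stdlib sketch:

```python
from math import comb


def binomial_p_value_at_least(k, n, p=0.5):
    """One-sided P(X >= k) for X ~ Binomial(n, p)."""
    return sum(comb(n, i) * p**i * (1 - p)**(n - i) for i in range(k, n + 1))


n_correct, n_total = 112, 200  # 56% accuracy on the 200-sample subset
p_value = binomial_p_value_at_least(n_correct, n_total)
print(f"P(>= {n_correct}/{n_total} correct by coin flip) = {p_value:.3f}")
```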

In [10]:
# Classification Results
print(classification_report(gt_labels, pred_labels, 
                            target_names = ['Neutraal', 'Politiek'], 
                            digits = 3))

              precision    recall  f1-score   support

    Neutraal      0.583     0.420     0.488       100
    Politiek      0.547     0.700     0.614       100

    accuracy                          0.560       200
   macro avg      0.565     0.560     0.551       200
weighted avg      0.565     0.560     0.551       200
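Because each class has a support of 100, the recall values fix the confusion matrix: 42 of 100 Neutraal articles and 70 of 100 Politiek articles were labelled correctly. The model therefore predicted Politiek 128 times out of 200, which suggests a bias toward that label. The sketch below recomputes the report's per-class numbers from those counts. With the notebook's variables, `sklearn.metrics.confusion_matrix(gt_labels, pred_labels)` would produce the same matrix directly.

```python
# Confusion counts implied by the report (recall x support per class).
# Rows = ground truth, columns = prediction, order [Neutraal, Politiek].
confusion = [[42, 58],   # 100 Neutraal articles: 42 correct, 58 called Politiek
             [30, 70]]   # 100 Politiek articles: 70 correct, 30 called Neutraal


def per_class_metrics(cm):
    """Return (precision, recall, f1) per class, rounded to 3 decimals."""
    n_classes = len(cm)
    metrics = []
    for c in range(n_classes):
        tp = cm[c][c]
        predicted = sum(cm[r][c] for r in range(n_classes))  # column sum
        actual = sum(cm[c])                                   # row sum
        precision = tp / predicted
        recall = tp / actual
        f1 = 2 * precision * recall / (precision + recall)
        metrics.append((round(precision, 3), round(recall, 3), round(f1, 3)))
    return metrics


accuracy = (confusion[0][0] + confusion[1][1]) / sum(map(sum, confusion))
for name, m in zip(["Neutraal", "Politiek"], per_class_metrics(confusion)):
    print(name, m)
print("accuracy", accuracy)
```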

