In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# get environment variable: OPENAI_API_KEY
from dotenv import load_dotenv
load_dotenv() 

True

In [3]:
import tiktoken
import openai

In [4]:
# find tokens for 'positive' and 'negative'
enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
enc.encode("positive"), enc.encode("negative") 

([31587], [43324])

In [5]:
# download dataset from hf datasets https://github.com/huggingface/datasets
# !pip install datasets 
from datasets import load_dataset

dataset = load_dataset("sst2")

Found cached dataset sst2 (/home/survival/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5)


  0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
val_data = dataset["validation"]
val_data

Dataset({
    features: ['idx', 'sentence', 'label'],
    num_rows: 872
})

In [7]:
# sentimentid2sentimentlabel = {"0": "negative", "1": "positive"}
sents, labels = [], []
for idx in range(len(val_data)):
    sents.append(val_data[idx]["sentence"])
    label = val_data[idx]["label"]
    if label == 0:
        labels.append('negative')
    else:
        labels.append('positive')

In [8]:
sent = sents[0]
label = labels[0]
sent, label

("it 's a charming and often affecting journey . ", 'positive')

In [9]:
from textwrap import wrap

In [10]:
repr(sent)

'"it \'s a charming and often affecting journey . "'

In [11]:
enc.encode(sent)

[275, 364, 82, 264, 35509, 323, 3629, 28987, 11879, 662, 220]

In [12]:
wrapped_sent = wrap(sent, replace_whitespace=True, drop_whitespace=True)[0]
repr(wrapped_sent)

'"it \'s a charming and often affecting journey ."'

In [13]:
enc.encode(wrapped_sent)

[275, 364, 82, 264, 35509, 323, 3629, 28987, 11879, 662]

In [14]:
wrapped_sents = [wrap(sent, replace_whitespace=True, drop_whitespace=True)[0] for sent in sents]

In [15]:
def my_classifier(text: str) -> str:
    return openai.ChatCompletion.create(
    model='gpt-3.5-turbo',
    messages=[
        {
            'role': 'system',
            'content': 'Given this text, what is the sentiment conveyed? Is it positive or negative?'
        },
        {
        'role':'user',
        'content':text
    }],
    logit_bias={
        '31587': 100,    # 31587 is the token for `positive`
        '43324': 100     # 43324 is the token for `negative`
    },
    max_tokens=1,
    temperature=0
).choices[0].message.content

In [16]:
my_classifier(wrapped_sent)

'positive'

In [17]:
labels[0]

'positive'

In [18]:
import time

In [19]:
preds = []
for i, wrapped_sent in enumerate(wrapped_sents):
    try:
        preds.append(my_classifier(wrapped_sent))
    except:
        preds.append(np.nan)
        time.sleep(5)

In [23]:
len(preds), len(labels)

(872, 872)

In [34]:
import pandas as pd

In [37]:
missing = [i for i, pred in enumerate(preds) if pd.isna(pred)]
missing

[10, 22, 242, 304, 471]

In [45]:
missing_preds = []
for idx in missing:
    try:
        missing_preds.append(my_classifier(wrapped_sents[idx]))
    except:
        missing_preds.append(np.nan)

In [46]:
missing_preds

['positive', 'negative', 'negative', 'positive', 'negative']

In [50]:
for i, idx in enumerate(missing):
    preds[idx] = missing_preds[i]

In [52]:
import numpy as np

np.mean([labels[i] == preds[i] for i in range(len(preds))])

0.8360091743119266

In [55]:
results = pd.DataFrame()
results["sentence"] = sents
results["label"] = labels
results["predictions"] = preds

In [60]:
results.head(10)

Unnamed: 0,sentence,label,predictions
0,it 's a charming and often affecting journey .,positive,positive
1,unflinchingly bleak and desperate,negative,negative
2,allows us to hope that nolan is poised to emba...,positive,positive
3,"the acting , costumes , music , cinematography...",positive,positive
4,"it 's slow -- very , very slow .",negative,negative
5,although laced with humor and a few fanciful t...,positive,negative
6,a sometimes tedious film .,negative,negative
7,or doing last year 's taxes with your ex-wife .,negative,negative
8,you do n't have to know about music to appreci...,positive,positive
9,"in exactly 89 minutes , most of which passed a...",negative,negative


In [74]:
results.to_csv("./outputs/SST2_dataset_results.csv", index=False)