In [1]:
import pandas as pd
from datasets import load_dataset
import os
from openai import OpenAI

# Import dataset from HuggingFace

In [2]:
orig_dataset = load_dataset('go_emotions', 'raw')

orig_dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'id', 'author', 'subreddit', 'link_id', 'parent_id', 'created_utc', 'rater_id', 'example_very_unclear', 'admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral'],
        num_rows: 211225
    })
})

In [3]:
orig_dataset = orig_dataset['train'].to_pandas()

# Remove unnecessary columns.
# All records have example_very_unclear = False
orig_dataset = orig_dataset.drop(['id', 'author', 'subreddit', 'link_id', 'parent_id', 'created_utc', 'rater_id', 'example_very_unclear'], axis=1)

orig_dataset

Unnamed: 0,text,admiration,amusement,anger,annoyance,approval,caring,confusion,curiosity,desire,...,love,nervousness,optimism,pride,realization,relief,remorse,sadness,surprise,neutral
0,That game hurt.,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
1,>sexuality shouldn’t be a grouping category I...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,"You do right, if you don't care then fuck 'em!",0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
3,Man I love reddit.,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
4,"[NAME] was nowhere near them, he was by the Fa...",0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
211220,Everyone likes [NAME].,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
211221,Well when you’ve imported about a gazillion of...,0,0,0,0,0,1,0,0,0,...,0,0,0,0,0,0,0,0,0,0
211222,That looks amazing,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
211223,The FDA has plenty to criticize. But like here...,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [4]:
label_columns = orig_dataset.columns.tolist()
# From go_emotions README
label_encoding = {  '0': 'admiration',      '1': 'amusement',   '2': 'anger',           '3': 'annoyance',   '4': 'approval',    
                    '5': 'caring',          '6': 'confusion',   '7': 'curiosity',       '8': 'desire',      '9': 'disappointment',  
                    '10': 'disapproval',    '11': 'disgust',    '12': 'embarrassment',  '13': 'excitement', '14': 'fear',   
                    '15': 'gratitude',      '16': 'grief',      '17': 'joy',            '18': 'love',       '19': 'nervousness',    
                    '20': 'optimism',       '21': 'pride',      '22': 'realization',    '23': 'relief',     '24': 'remorse',    
                    '25': 'sadness',        '26': 'surprise',   '27': 'neutral' }

inverse_encoding = {}
for key, value in label_encoding.items():
    inverse_encoding[value] = key

In [5]:
# Create a column 'labels' with a list of label values
orig_dataset.insert(1,'labels','')
orig_dataset['labels'] = orig_dataset[label_columns[1:]].values.tolist()
orig_dataset['labels'] = orig_dataset['labels'].apply(lambda t: [i for i, x in enumerate(t) if x])

# Remove unlabeled records
orig_dataset = orig_dataset[orig_dataset['labels'].map(lambda d: len(d)) > 0]

orig_dataset

Unnamed: 0,text,labels,admiration,amusement,anger,annoyance,approval,caring,confusion,curiosity,...,love,nervousness,optimism,pride,realization,relief,remorse,sadness,surprise,neutral
0,That game hurt.,[25],0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
2,"You do right, if you don't care then fuck 'em!",[27],0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
3,Man I love reddit.,[18],0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
4,"[NAME] was nowhere near them, he was by the Fa...",[27],0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
5,Right? Considering it’s such an important docu...,[15],0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
211219,"Well, I'm glad you're out of all that now. How...",[17],0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
211220,Everyone likes [NAME].,[18],0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
211221,Well when you’ve imported about a gazillion of...,[5],0,0,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,0
211222,That looks amazing,[0],1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


# Augment

In [6]:
# Get the number of records needed to balance each label
label_values = pd.Series([x for item in orig_dataset.labels for x in item]).value_counts()
label_values.drop(27, inplace=True)   # 'Neutral' is more of a lack of emotion than an emotion
balance_label_values = label_values.max() - label_values

balance_label_values.sort_index(ascending=True)

0       489
1      8375
2      9536
3      4002
4         0
5     11621
6     10261
7      7928
8     13803
9      9151
10     6196
11    12319
12    15144
13    11991
14    14423
15     5995
16    16947
17     9637
18     9429
19    15810
20     8905
21    16318
22     8835
23    16331
24    15095
25    10862
26    12106
Name: count, dtype: int64

In [8]:
target_label = balance_label_values.idxmax()
print(target_label)

# Temp remove target labeled records and get a random record from remaining dataset
sample = orig_dataset[~orig_dataset['labels'].apply(lambda x: target_label in x)].sample()

sample

16


Unnamed: 0,text,labels,admiration,amusement,anger,annoyance,approval,caring,confusion,curiosity,...,love,nervousness,optimism,pride,realization,relief,remorse,sadness,surprise,neutral
109046,Yeah I'm not arguing there isnt recreational p...,[10],0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [9]:
# Generate prompt
sample_text = sample['text'].values[0]
print(sample_text)
print(str(sample['labels'].values[0]))
sample_label = label_encoding[str(sample['labels'].values[0][0])]
if len(sample['labels'].values[0]) == 1:
    sample_label = ' ' + sample_label
    
else:
    sample_label = 's ' + sample_label
    
    if len(sample['labels'].values[0]) == 2:
        sample_label += " and " + label_encoding[str(sample['labels'].values[0][1])]
        
    elif len(sample['labels'].values[0]) > 2:
        for label in sample['labels'].values[0][1:]:
            if label != sample['labels'].values[0][-1]:
                sample_label += ', ' + label_encoding[str(label)]
                
            else:
                sample_label += ', and ' + label_encoding[str(label)]       
                
print(sample_label)
query = f"This comment, \"{sample_text}\" portrays the emotion{sample_label}. Based on the topic of this comment, generate a new comment that would portray {label_encoding[str(target_label)]}."
query

Yeah I'm not arguing there isnt recreational potential though it can certainly mess you up. Doesn't help most products with dxm have other active ingredients.
[10]
 disapproval


'This comment, "Yeah I\'m not arguing there isnt recreational potential though it can certainly mess you up. Doesn\'t help most products with dxm have other active ingredients." portrays the emotion disapproval. Based on the topic of this comment, generate a new comment that would portray grief.'

In [10]:
client = OpenAI(api_key=os.environ.get('OPENAI_API_KEY'))
print(os.environ.get('OPENAI_API_KEY'))

sk-proj-TxnpcCj0mnUYuK1uTL4qT3BlbkFJOPZkjyBGea6YlqkKf7Pr


In [11]:
# OpenAI InstructGPT
response = client.completions.create(
    model="gpt-3.5-turbo-instruct",
    prompt=query
)

print(response)

RateLimitError: Error code: 429 - {'error': {'message': 'You exceeded your current quota, please check your plan and billing details. For more information on this error, read the docs: https://platform.openai.com/docs/guides/error-codes/api-errors.', 'type': 'insufficient_quota', 'param': None, 'code': 'insufficient_quota'}}

In [19]:
# Build new records until balance is achieved 
while balance_label_values.min():
    goal_label = balance_label_values.idxmax()
    
    
    # Start with the most minor class
    for index, value in balance_label_values.sort_values(ascending=False).items():
        print(f'{index} {value}')

16 16947
23 16331
21 16318
19 15810
12 15144
24 15095
14 14423
8 13803
11 12319
26 12106
13 11991
5 11621
25 10862
6 10261
17 9637
2 9536
18 9429
9 9151
20 8905
22 8835
1 8375
7 7928
10 6196
15 5995
3 4002
0 489
4 0
