# Setup

In [1]:
# Source dataset on the Hugging Face Hub: (dataset name, subset name)
dataset_name, dataset_subset = 'go_emotions', 'raw'

# Location of the persisted synthetic dataset (directory + parquet file name)
synthetic_dataset_dir = './synthetic_dataset/'
synthetic_dataset_name = 'synthetic_dataset.parquet'

# Save interval used by the main generation loop
batch_size = 10

## Common Imports

In [2]:
import pandas as pd
from datasets import load_dataset
from IPython.display import display
import os
from openai import OpenAI
import string

# Load the Dataset from Hugging Face

In [3]:
# Fetch the configured dataset/subset; the result is a DatasetDict keyed by split
orig_dataset = load_dataset(path=dataset_name, name=dataset_subset)
orig_dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'id', 'author', 'subreddit', 'link_id', 'parent_id', 'created_utc', 'rater_id', 'example_very_unclear', 'admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral'],
        num_rows: 211225
    })
})

In [4]:
# Convert the 'train' split to pandas and keep only the comment text plus the
# one-hot emotion columns. The author notes every record has
# example_very_unclear = False, so that flag is dropped along with the metadata.
metadata_columns = ['id', 'author', 'subreddit', 'link_id', 'parent_id', 'created_utc', 'rater_id', 'example_very_unclear']
orig_dataset = orig_dataset['train'].to_pandas().drop(columns=metadata_columns)

orig_dataset

Unnamed: 0,text,admiration,amusement,anger,annoyance,approval,caring,confusion,curiosity,desire,...,love,nervousness,optimism,pride,realization,relief,remorse,sadness,surprise,neutral
0,That game hurt.,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
1,>sexuality shouldn’t be a grouping category I...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,"You do right, if you don't care then fuck 'em!",0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
3,Man I love reddit.,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
4,"[NAME] was nowhere near them, he was by the Fa...",0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
211220,Everyone likes [NAME].,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
211221,Well when you’ve imported about a gazillion of...,0,0,0,0,0,1,0,0,0,...,0,0,0,0,0,0,0,0,0,0
211222,That looks amazing,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
211223,The FDA has plenty to criticize. But like here...,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [5]:
# Column order is ['text', <28 emotion columns>]; the emotions get ids 0..27
label_columns = orig_dataset.columns.tolist()

label_encoding = {name: idx for idx, name in enumerate(label_columns[1:])}
for name, idx in label_encoding.items():
    print(f"'{name}': {idx}")

# Reverse lookup: encoded id -> emotion name
inverse_encoding = {idx: name for name, idx in label_encoding.items()}

'admiration': 0
'amusement': 1
'anger': 2
'annoyance': 3
'approval': 4
'caring': 5
'confusion': 6
'curiosity': 7
'desire': 8
'disappointment': 9
'disapproval': 10
'disgust': 11
'embarrassment': 12
'excitement': 13
'fear': 14
'gratitude': 15
'grief': 16
'joy': 17
'love': 18
'nervousness': 19
'optimism': 20
'pride': 21
'realization': 22
'relief': 23
'remorse': 24
'sadness': 25
'surprise': 26
'neutral': 27


In [6]:
# Aggregate labels in a list column
orig_dataset.insert(1,'labels','')
orig_dataset['labels'] = orig_dataset[label_columns[1:]].values.tolist()
orig_dataset['labels'] = orig_dataset['labels'].apply(lambda t: [i for i, x in enumerate(t) if x])

# Remove unlabeled records
orig_dataset = orig_dataset[orig_dataset['labels'].map(lambda x: len(x)) > 0]

orig_dataset

Unnamed: 0,text,labels,admiration,amusement,anger,annoyance,approval,caring,confusion,curiosity,...,love,nervousness,optimism,pride,realization,relief,remorse,sadness,surprise,neutral
0,That game hurt.,[25],0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
2,"You do right, if you don't care then fuck 'em!",[27],0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
3,Man I love reddit.,[18],0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
4,"[NAME] was nowhere near them, he was by the Fa...",[27],0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
5,Right? Considering it’s such an important docu...,[15],0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
211219,"Well, I'm glad you're out of all that now. How...",[17],0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
211220,Everyone likes [NAME].,[18],0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
211221,Well when you’ve imported about a gazillion of...,[5],0,0,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,0
211222,That looks amazing,[0],1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


# Generating Synthetic Data
## Functions

In [7]:
def get_random_record(dataset, enc_target_label, neutral_label=27):
    """Draw one random source record to seed a synthetic comment.

    The record must not already carry the target label (so the model has to
    change the emotion) and must not carry ``neutral_label``.

    Args:
        dataset: DataFrame with a 'labels' column of encoded label lists.
        enc_target_label: Encoded id of the emotion being generated.
        neutral_label: Encoded id to exclude as well; defaults to 27, which is
            'neutral' in this notebook's encoding.

    Returns:
        A one-row DataFrame sampled uniformly from the eligible records.

    Raises:
        ValueError: If no record is eligible (the old retry loop spun forever).
    """
    # Filter once instead of re-filtering the whole frame on every retry.
    # Sampling uniformly from the eligible subset gives the same distribution
    # as rejecting neutral samples and drawing again.
    eligible = dataset['labels'].map(
        lambda labels: enc_target_label not in labels and neutral_label not in labels
    ).astype(bool)
    candidates = dataset[eligible]
    if candidates.empty:
        raise ValueError(
            f"No records without label {enc_target_label} or {neutral_label} to sample from"
        )
    return candidates.sample()

In [8]:
def generate_text_prompt(sample, enc_target_label, label_names=None):
    """Build the prompt asking the model to rewrite a comment toward a target emotion.

    Args:
        sample: One-row DataFrame with 'text' and 'labels' (list of encoded ids).
        enc_target_label: Encoded id of the emotion the new comment should show.
        label_names: Optional id -> emotion-name mapping; defaults to the
            notebook-level ``inverse_encoding``.

    Returns:
        The prompt string for the completion model.

    Raises:
        ValueError: If the sample has no labels to describe.

    Prints the source text and its emotion list as a progress log.
    """
    names = inverse_encoding if label_names is None else label_names

    sample_text = sample['text'].values[0]
    print(sample_text)

    source_emotions = [names[label] for label in sample['labels'].values[0]]
    if not source_emotions:
        raise ValueError("sample has no labels to describe")

    # English list form: "a", "a and b", "a, b, and c"
    if len(source_emotions) == 1:
        sample_label = source_emotions[0]
    elif len(source_emotions) == 2:
        sample_label = f"{source_emotions[0]} and {source_emotions[1]}"
    else:
        sample_label = ', '.join(source_emotions[:-1]) + ', and ' + source_emotions[-1]

    print(sample_label)

    # "the emotion joy" vs. "the emotions joy and fear"
    plural = '' if len(source_emotions) == 1 else 's'

    return (f"Using the comment \"{sample_text}\" which portrays the emotion{plural} {sample_label}, "
            f"generate a similar comment that instead portrays {names[enc_target_label]}")

In [9]:
# Normalizes raw completion text; see the docstring for the labels=True/False modes
def clean_response(text, labels=False):
    """Normalize a raw completion string.

    Always: replaces '\\n' with spaces, strips spaces and punctuation (other
    than quote characters) from the start, and strips spaces from the end.

    Args:
        text: Raw completion text.
        labels: If True, additionally remove all ``string.punctuation`` and
            lower-case the result (used for yes/no and label-list answers).
            If False, keep only the quoted portion when quotes are present:
            between the first and last '"' for an even count, or everything
            after the first '"' for an odd count. Prints "NO QUOTE RESPONSE"
            when there are no quotes.

    Returns:
        The cleaned string, which may be empty.
    """
    text = text.replace('\n', ' ')
    # lstrip/rstrip stop safely at an empty string. The former index-based
    # loops raised IndexError when a response was only spaces/punctuation.
    # Quote characters are deliberately not stripped so quoted text survives.
    text = text.lstrip(' !#$%&()*+,-./:;<=>?@[\\]^_`{|}~').rstrip(' ')

    if labels:
        return text.translate(str.maketrans('', '', string.punctuation)).lower()

    quote_count = text.count('"')
    if quote_count == 0:
        print("NO QUOTE RESPONSE")
    elif quote_count % 2 == 0:
        text = text[text.find('"') + 1:text.rfind('"')]
    else:
        text = text[text.find('"') + 1:]

    return text

In [10]:
def generate_label_prompts(enc_target_label, response_text, label_names=None):
    """Build the two verification prompts for a generated comment.

    Args:
        enc_target_label: Encoded id of the emotion the comment should show.
        response_text: The generated comment.
        label_names: Optional id -> emotion-name mapping; defaults to the
            notebook-level ``inverse_encoding``.

    Returns:
        (target_query, other_query): a yes/no question about the target
        emotion, and an open classification prompt listing every emotion.
    """
    names = inverse_encoding if label_names is None else label_names

    target_query = f"Does the comment \"{response_text}\" portray the emotion {names[enc_target_label]}? Limit your answer to yes or no."

    # List every emotion in id order. The previous label_columns[2:] slice
    # skipped 'admiration', because label_columns[0] is 'text', not 'labels'.
    valid_emotions = "[" + ', '.join(names[idx] for idx in sorted(names)) + "]"

    other_query = f"Classify the comment \"{response_text}\" by one or more emotions ONLY from the following list: {valid_emotions}"

    return target_query, other_query

In [11]:
def target_consensus(query_text):
    """Ask the model 5 times whether a comment shows the target emotion.

    Args:
        query_text: A yes/no prompt (see generate_label_prompts).

    Returns:
        True when at least 4 of the 5 sampled answers contain the word "yes".
        Otherwise False. When there is neither a yes-consensus nor at least
        2 "no" answers, the raw answers are printed for inspection.
    """
    response = client.completions.create(
        model="gpt-3.5-turbo-instruct",
        prompt=query_text,
        temperature=1,
        n=5,
        logit_bias={"1734": -100}   # same token bias the main loop uses to remove line breaks
    )

    responses = [clean_response(choice.text, labels=True) for choice in response.choices]

    # Compare whole words. A substring test would count 'no' inside 'not' or
    # 'know', and 'yes' inside 'eyes'.
    answer_words = [text.split() for text in responses]
    yes_votes = sum('yes' in words for words in answer_words)
    no_votes = sum('no' in words for words in answer_words)

    # Cutoff for a consensus is 4/5 yes votes
    if yes_votes >= 4:
        return True
    if no_votes < 2:
        # f-string: the former str + list concatenation raised TypeError here
        print(f"Unexpected responses: {responses}")
    return False

In [12]:
# Prompts LLM for emotion labels from a text. Returns the emotions named by at least 4 of 5 sampled responses
def get_labels(query_text):
    """Ask the model 5 times to classify a comment and keep the consensus labels.

    Args:
        query_text: A classification prompt (see generate_label_prompts).

    Returns:
        Emotion names (strings) that are valid labels and appear in at least
        4 of the 5 responses, in order of first appearance. May be empty.
    """
    response = client.completions.create(
        model="gpt-3.5-turbo-instruct",
        prompt=query_text,
        temperature=1,
        n=5,
        logit_bias={"1734": -100}
    )

    # clean_response(labels=True) already strips punctuation and lower-cases
    responses = [clean_response(choice.text, labels=True) for choice in response.choices]

    # label_encoding holds all emotion names. label_columns[0] is 'text', so
    # the former label_columns[2:] slice excluded 'admiration'.
    valid_labels = set(label_encoding)

    # One vote per response per word, so a word repeated inside a single
    # answer cannot reach the threshold on its own. A dict keeps
    # first-appearance order.
    votes = {}
    for text in responses:
        for word in dict.fromkeys(text.split()):
            votes[word] = votes.get(word, 0) + 1

    return [label for label, count in votes.items() if label in valid_labels and count >= 4]

## Initialization

In [13]:
client = OpenAI(api_key=os.environ.get('OPENAI_API_KEY'))

synthetic_dataset_path = os.path.join(synthetic_dataset_dir, synthetic_dataset_name)
os.makedirs(synthetic_dataset_dir, exist_ok=True)

# Resume from a previously saved run when one exists
try:
    synth_dataset = pd.read_parquet(path=synthetic_dataset_path)
except FileNotFoundError:
    synth_dataset = pd.DataFrame(columns = ['text', 'labels', 'source index', 'source labels', 'intended label'])

# Write once up front so an unwritable path fails here rather than mid-generation
synth_dataset.to_parquet(path=synthetic_dataset_path)

# Build a list of how imbalanced each class is
label_values = pd.Series([x for item in orig_dataset.labels for x in item]).value_counts()
label_values = label_values.drop(27)   # 'Neutral' is more of a lack of emotion than an emotion
label_imbalance_values = label_values.max() - label_values
label_imbalance_values.at[27] = 0

# Credit labels from already-saved synthetic records, using the same
# floor-at-zero decrement as the main loop. Otherwise a resumed run would
# regenerate data for imbalance it has already filled.
for saved_labels in synth_dataset['labels']:
    for name in saved_labels:
        enc_label = label_encoding[name]
        if label_imbalance_values[enc_label] != 0:
            label_imbalance_values[enc_label] -= 1

label_imbalance_values.sort_index(ascending=True)

0       489
1      8375
2      9536
3      4002
4         0
5     11621
6     10261
7      7928
8     13803
9      9151
10     6196
11    12319
12    15144
13    11991
14    14423
15     5995
16    16947
17     9637
18     9429
19    15810
20     8905
21    16318
22     8835
23    16331
24    15095
25    10862
26    12106
27        0
Name: count, dtype: int64

## Main Loop

In [14]:
# Generate synthetic comments for the most under-represented emotion until
# every imbalance count reaches zero, or the run is interrupted.
# NOTE(review): if the target emotion is never confirmed, its count never
# drops and this loop keeps retrying the same target.
enc_target_label = label_imbalance_values.idxmax()
target_label = inverse_encoding[enc_target_label]
synthetic_dataset_path = os.path.join(synthetic_dataset_dir, synthetic_dataset_name)

try:
    while label_imbalance_values.any():
        print(f"Target: {target_label} : {enc_target_label} : {label_imbalance_values[enc_target_label]}")

        sample = get_random_record(orig_dataset, enc_target_label)
        text_query = generate_text_prompt(sample, enc_target_label)

        # OpenAI InstructGPT
        response = client.completions.create(
            model="gpt-3.5-turbo-instruct",
            prompt=text_query,
            max_tokens=50,
            logit_bias={"1734": -100} # Remove line breaks
        )

        raw_text = response.choices[0].text
        response_text = clean_response(raw_text)    # InstructGPT response starts with \n\n

        # Skip this attempt instead of aborting the whole unattended run
        if not response_text:
            print(f"Entry wiped by clean_response(): {raw_text}")
            print()
            continue
        print("Synthetic Record Text: " + response_text)

        target_query, other_labels_query = generate_label_prompts(enc_target_label, response_text)

        # Does the generated text match the intended label?
        if target_consensus(target_query):
            consensus_labels = [inverse_encoding[enc_target_label]]
        else:
            consensus_labels = []

        synth_record_labels = get_labels(other_labels_query)

        if synth_record_labels:
            if consensus_labels and target_label not in synth_record_labels:
                consensus_labels += synth_record_labels
            else:
                consensus_labels = synth_record_labels

        print(f"Synthetic labels: {consensus_labels}")

        # Only save if a label is found
        if consensus_labels:
            source_index = sample.index.values[0]
            source_labels = list(map(inverse_encoding.get, orig_dataset.loc[source_index, 'labels']))
            synth_dataset.loc[len(synth_dataset)] = { 'text': response_text,
                                                      'labels': consensus_labels,
                                                      'source index': source_index,
                                                      'source labels': source_labels,
                                                      'intended label': inverse_encoding[enc_target_label]  }

            # Save every batch_size stored records. Counting loop iterations
            # skipped saves whenever the iteration on the boundary was rejected.
            if len(synth_dataset) % batch_size == 0:
                synth_dataset.to_parquet(path=synthetic_dataset_path)
                print(f"\nSaving data. {len(synth_dataset)} records.")

        # Update label values
        for label in consensus_labels:
            if label_imbalance_values[label_encoding[label]] != 0:
                label_imbalance_values[label_encoding[label]] -= 1

            # If the synthetic data label doesn't include the target label, the target is still the highest imbalance
            if label == target_label:
                enc_target_label = label_imbalance_values.idxmax()
                target_label = inverse_encoding[enc_target_label]

        print()
finally:
    # Persist everything generated so far, including on KeyboardInterrupt or an API error
    synth_dataset.to_parquet(path=synthetic_dataset_path)
    print(f"\nSaving data. {len(synth_dataset)} records.")

Target: grief : 16 : 16947
I’d tell her nothing, that freaking weirdo would just never see my kid. The End. Disgusting.
disgust
Synthetic Record Text: I wouldn't even know how to begin to explain the loss of my child to someone who could never understand it. They would never see the light, the love, and the beauty that my child brought into this world. It's a tragic shame
Synthetic labels: ['grief', 'sadness', 'love']

Target: grief : 16 : 16946
Be true to yourself. Your the only one who will be
optimism
Synthetic Record Text: It's okay to let yourself feel the pain. It's a reminder of how much you loved and cared. Take your time to grieve, and know that it's a necessary step towards healing.
Synthetic labels: ['grief', 'caring']

Target: grief : 16 : 16945
same i feel u
disappointment
Synthetic Record Text: Same, it's like a heavy weight on my heart.
Synthetic labels: ['grief', 'sadness']

Target: grief : 16 : 16944
It would definitely suck!!
annoyance
Synthetic Record Text: It would 

KeyboardInterrupt: 

In [None]:
# Display the synthetic records collected so far (text, labels, source index, source labels, intended label)
synth_dataset