# Setup

In [1]:
dataset = "enISEAR"          # Key into dataset_config.dataset; see dataset_config.py for dataset options
subset = None                # Not referenced by any code visible in this notebook -- TODO confirm it is still needed

# Stopping rule read by find_label_imbalance_counts().
# Options: "original balance", "synthetic balance", integer value, None
augment_to = "original balance"
llm = "ChatGPT 4o-Mini"     # Key into LLM_config.model; see LLM_config.py for LLM options
keep_incorrect_but_consensus_labels = False    # If True, the augmenting loop keeps records whose consensus label differs from the targeted label
set_target_label = None      # If set, the augmenting loop always targets this label instead of the label with the lowest count in imbalance_counts

synthetic_dataset_relpath = "./synthetic_datasets/"    # Root directory under which the synthetic parquet files are read and written

## LLM Clients and Config Files

In [2]:
import ollama
from openai import OpenAI

In [3]:
from config_files import dataset_config
from config_files import LLM_config

## Imports

In [4]:
import pandas as pd
import numpy as np
# from datasets import load_dataset
import re
import os
import random

# Classes

In [5]:
class Record:
    """Plain container for one dataset row: its index, its text and its label(s)."""

    def __init__(self, index, text, labels):
        # Store the three values unchanged under same-named attributes.
        self.index, self.text, self.labels = index, text, labels

# Functions
## Original Dataset Loading

In [6]:
def load_real_dataset(dataset_details):
    """Read the original (real) dataset described by ``dataset_details``.

    Only ``location == "local"`` with a ``filetype`` of ``"csv"`` or ``"tsv"``
    is handled. Split datasets are read from ``train_relpath``, unsplit ones
    from ``abspath``. Any other combination yields an empty DataFrame.

    Returns:
        pandas.DataFrame: the loaded data, or an empty frame.
    """
    if dataset_details["location"] != "local":
        return pd.DataFrame([])

    # Split datasets point at their training file; unsplit ones at a single file.
    path_key = "train_relpath" if dataset_details["is_split"] else "abspath"
    filetype = dataset_details["filetype"]

    if filetype == "csv":
        return pd.read_csv(dataset_details[path_key])
    if filetype == "tsv":
        return pd.read_csv(dataset_details[path_key], sep="\t")

    # Unsupported file type: same empty result as a non-local dataset.
    return pd.DataFrame([])

## Preprocessing

In [7]:
def preprocess_dataframe(dataset_details, dataframe):
    """Homogenise ``dataframe`` in place; nothing is returned.

    Drops the columns listed in ``dataset_details["unused_columns"]``, renames
    columns according to ``dataset_details["remap_columns"]``, then removes
    every row whose ``labels`` value equals ``dataset_details["unlabeled_label"]``.
    """
    dataframe.drop(columns=dataset_details["unused_columns"], inplace=True)
    dataframe.rename(columns=dataset_details["remap_columns"], inplace=True)

    # Remove unlabeled records from the original dataframe.
    is_unlabeled = dataframe['labels'] == dataset_details["unlabeled_label"]
    unlabeled_rows = dataframe[is_unlabeled].index
    dataframe.drop(unlabeled_rows, inplace=True)

## Synthetic Dataset Loading

In [8]:
def load_synthetic_dataset(dataset_details, llm_details):
    """Load the saved synthetic dataset for this dataset/LLM pair.

    The file is ``<synthetic_dataset_relpath><dataset id>/<llm id>.parquet``,
    with ``:`` in the LLM id replaced by ``_``. When that file does not exist,
    an empty DataFrame with the synthetic-record columns is returned instead.
    """
    parquet_path = (
        synthetic_dataset_relpath
        + f"{dataset_details['id']}/"
        + llm_details['id'].replace(":", "_")
        + ".parquet"
    )

    try:
        loaded = pd.read_parquet(path=parquet_path)
    except FileNotFoundError:
        print("No synthetic dataset found. Creating an empty synthetic dataframe.")
        return pd.DataFrame(
            columns=['text', 'labels', 'all labels', 'source index', 'source label', 'intended label']
        )
    else:
        print("Synthetic dataset found.")
        return loaded

## Find Label Imbalance Counts

In [9]:
def find_label_imbalance_counts(df_original, dataset_details, llm_details):
    """Compute, per label, how many synthetic records are still needed.

    Values are deficits: a negative number means that many records are still
    missing for the label, zero or more means the label is satisfied. The
    strategy is selected by the notebook-level ``augment_to`` setting:

    * ``"original balance"``: bring every label up to the size of the largest
      class in ``df_original``, crediting already-saved synthetic records.
    * ``"synthetic balance"``: bring every label in ``label_list`` up to the
      largest synthetic class.
    * an integer ``n``: every label in ``label_list`` starts with a deficit
      of ``n``.
    * ``None``: every label starts with a very large deficit, i.e. keep
      generating until the loop is interrupted.

    Args:
        df_original: DataFrame of real records with a ``labels`` column.
        dataset_details: dataset config; ``label_list`` is read here.
        llm_details: LLM config, forwarded to ``load_synthetic_dataset``.

    Returns:
        pandas.Series indexed by label holding the deficits.

    Raises:
        ValueError: if ``augment_to`` is not one of the supported options.
    """
    original_label_counts = pd.Series(df_original.labels).value_counts()
    print(f"\nORIGINAL LABEL COUNTS:\n{original_label_counts}")

    # The saved synthetic data is only needed for its label counts.
    df_synthetic = load_synthetic_dataset(dataset_details, llm_details)
    synthetic_label_counts = pd.Series(df_synthetic.labels).value_counts()
    print(f"\nSYNTHETIC LABEL COUNTS:\n{synthetic_label_counts}")
    del df_synthetic

    if augment_to == "original balance":
        combined_label_counts = np.subtract(original_label_counts, original_label_counts.max())
        for label in synthetic_label_counts.index:
            combined_label_counts[label] += synthetic_label_counts.loc[label]

    elif augment_to == "synthetic balance":
        # Labels with no synthetic records yet must still appear, with a count of 0.
        for label in dataset_details['label_list']:
            if label not in synthetic_label_counts.index:
                label_row = pd.Series({label: 0})
                synthetic_label_counts = pd.concat([synthetic_label_counts, label_row])
        combined_label_counts = np.subtract(synthetic_label_counts, synthetic_label_counts.max())

    elif isinstance(augment_to, int) and not isinstance(augment_to, bool):
        # Use the configured target; this was previously hard-coded to 300.
        # NOTE(review): already-saved synthetic records are not credited in
        # this mode, matching the previous behaviour -- confirm that is intended.
        combined_label_counts = pd.Series(
            {label: -augment_to for label in dataset_details['label_list']},
            name="labels",
            dtype="int64",
        )

    elif augment_to is None:
        # Effectively unbounded: generate until the loop is cancelled.
        combined_label_counts = pd.Series(
            {label: -99999 for label in dataset_details['label_list']},
            name="labels",
            dtype="int64",
        )

    else:
        raise ValueError(
            f"Unsupported augment_to value: {augment_to!r}. Expected "
            "'original balance', 'synthetic balance', an integer, or None."
        )

    print(f"\nCOMBINED LABEL DEFICITS:\n{combined_label_counts}")

    return combined_label_counts

## Generate Text
### Get A Random Record

In [10]:
def get_random_record(dataset, target_label):
    """Pick one random row whose ``labels`` value does not contain ``target_label``.

    The row is wrapped in a ``Record`` carrying its index, text and labels.
    """
    # Mask out rows already carrying the target label, then sample from the rest.
    carries_target = dataset['labels'].apply(lambda row_labels: target_label in row_labels)
    candidates = dataset[~carries_target]
    picked = candidates.sample()

    return Record(picked.index[0], picked.text.values[0], picked.labels.values[0])

### Prompt

def build_text_prompt(dataset_details, target_label, original_record):   
    """Scratch pad of alternative text-generation prompt wordings.

    Each local below is one candidate prompt built from the source record's
    text and labels, the dataset's ``text_source`` and the ``target_label``.
    None of them is returned: the function ends with a bare ``return`` and so
    yields ``None``. A later definition of ``build_text_prompt(target_label,
    original_record)`` in this notebook rebinds the name, so this version is
    not the one in effect once both cells have run.

    NOTE(review): the f-strings reuse double quotes inside the replacement
    fields; that syntax is only accepted from Python 3.12 on (PEP 701).
    """
    # Earliest wording. Note there is no space between "emotion" and the label.
    prototype_prompt = f"Using the {dataset_details["text_source"]} \"{original_record.text}\" which portrays the emotion{original_record.labels}, generate a similar {dataset_details["text_source"]} that instead portrays {target_label}."
    
    # Wording for text used as-is.
    raw_prompt = f"The following is a {dataset_details["text_source"]} portraying {original_record.labels}. \"{original_record.text}\". Using this {dataset_details["text_source"]}, generate a {dataset_details["text_source"]} about the same subject and similar in style that instead portrays {target_label}. Only give the generated {dataset_details["text_source"]}."
    
    # Wording that tells the LLM the source text is already tokenized.
    tokenized_prompt = f"The following is a {dataset_details["text_source"]} with any usernames, names, hashtags, and URLs tokenized with an all-caps generalized term. \"{original_record.text}\". Using this {dataset_details["text_source"]}, which portrays the emotion {original_record.labels}, generate a {dataset_details["text_source"]} about the same subject and similar in style that instead portrays {target_label}. Only give the generated {dataset_details["text_source"]}."
    
    # Wording named after the enISEAR dataset.
    enISEAR_prompt = f"The following is a {dataset_details["text_source"]} portraying the emotion {original_record.labels}: \"{original_record.text}\". Using this {dataset_details["text_source"]}, create a {dataset_details["text_source"]} about the same subject and similar in style that instead portrays {target_label}. Only give the generated {dataset_details["text_source"]}."
    
    
    # Wording that asks the LLM to tokenize usernames, names, hashtags and URLs itself.
    llm_derived_prompt = f"Create a {dataset_details["text_source"]} portraying {target_label} similar to this {dataset_details["text_source"]} portraying {original_record.labels}: \"{original_record.text}\". Replace usernames, names, hashtags, and URLs with tokenized all-caps terms (e.g., USER, NAME, HASHTAG, URL). Do NOT use all-caps unless tokenizing as indicated. Do not explain your response."
    
    # Shorter "change this text" variant of the tokenizing prompt.
    sample_prompt = f"Change this {dataset_details["text_source"]} portraying {original_record.labels} to instead portray {target_label}: \"{original_record.text}\". Tokenize usernames, names, hashtags, and URLs with tokenized all-caps terms (e.g., USER, NAME, HASHTAG, URL). Do NOT use all-caps unless tokenizing as indicated."
    
    # Bare return: the prompts above are built but discarded.
    return 

In [11]:
def build_text_prompt(target_label, original_record):
    """Fill the dataset's text-generation prompt template for one source record.

    The template is ``prompt_config.prompt[dataset]['text']``, where ``dataset``
    is the notebook-level dataset name. The placeholders
    ``<original_record.labels>``, ``<original_record.text>`` and
    ``<target_label>`` are substituted, then runs of spaces are collapsed to a
    single space.

    Args:
        target_label: label the generated text should portray.
        original_record: object exposing ``labels`` and ``text`` strings.

    Returns:
        str: the finished prompt.
    """
    from config_files import prompt_config
    text_prompt = prompt_config.prompt[dataset]['text']

    text_prompt = text_prompt.replace("<original_record.labels>", original_record.labels)
    text_prompt = text_prompt.replace("<original_record.text>", original_record.text)
    text_prompt = text_prompt.replace("<target_label>", target_label)

    # re.sub returns a new string; the result used to be discarded, so the
    # whitespace collapse never took effect.
    text_prompt = re.sub(' +', ' ', text_prompt)

    return text_prompt

### Prompt LLM for Synthetic Record

In [12]:
def generate_synthetic_text(llm_details, dataset_details, target_label, text_prompt):
    """Send ``text_prompt`` to the configured LLM and return the generated text.

    * ``"Ollama"``: the reply is pre-seeded with an assistant message ending in
      an opening quote, and the continuation is cleaned by ``parse_text_response``.
    * ``"OpenAI"``: the first choice's message content is returned unparsed.

    Any other platform falls through and returns ``None``.
    """
    platform = llm_details["platform"]

    if platform == "Ollama":
        model_id = llm_details["id"]
        # Starts the LLM's response, preventing guardrails from censoring fear or anger based responses
        primer = f"Here is a similar {dataset_details['text_source']} portraying {target_label}:\n\n\""
        conversation = [
            {"role": "user", "content": text_prompt},
            {"role": "assistant", "content": primer},
        ]
        reply = ollama.chat(model=model_id, messages=conversation)
        return parse_text_response(reply["message"]["content"])

    if platform == "OpenAI":
        client = OpenAI(api_key=os.environ.get('OPENAI_API_KEY'))
        completion = client.chat.completions.create(
            model=llm_details["id"],
            messages=[{"role": "user", "content": text_prompt}],
            max_tokens=50,
        )
        return completion.choices[0].message.content

    return None

### Parse Synthetic Text Response

In [13]:
def parse_text_response(response):
    """Trim an LLM continuation down to the generated text itself.

    Two clean-ups are applied in order:

    1. Everything from the first blank line (``"\\n\\n"``) onward is dropped,
       since that is where the LLM puts any explanation of its answer.
    2. The text is cut at its last double quote, because the reply was
       started with an opening quote and usually ends with a closing one.

    A coloured notice is printed when a blank line is found and when no
    quote is found.

    Args:
        response: raw text returned by the LLM.

    Returns:
        str: the trimmed text.
    """
    CRED = '\33[91m'
    CBLU = '\33[34m'
    CEND = '\33[0m'

    # If the LLM is going to explain the rationale for its response, it will be after a couple line breaks.
    blank_line_at = response.find('\n\n')
    if blank_line_at != -1:
        print(f"\n{CRED}Newline Response:\n\t{response}{CEND}")
        # Cut at the blank line itself. Cutting at the first single newline,
        # as before, also discarded generated text that spans several lines.
        response = response[:blank_line_at]

    # The synthetic tweet is given a quote to start, so it will often have a quote to end.
    closing_quote_at = response.rfind('"')
    if closing_quote_at != -1:
        response = response[:closing_quote_at]
    else:
        print(f"\n{CBLU}No Quote Response:\n\t{response}{CEND}")

    return response

## Generate Label
### Prompt

def build_label_prompt(dataset_details, synthetic_text):
    """Scratch pad of per-dataset labeling prompts.

    Three candidate prompts are built from ``dataset_details['text_source']``,
    ``dataset_details['label_type']`` and ``synthetic_text``, but only the
    ``emoevent`` one is returned, whatever dataset is selected. A later
    definition of ``build_label_prompt(synthetic_text)`` in this notebook
    rebinds the name, so this version is not the one in effect once both
    cells have run.

    NOTE(review): the second and third f-strings reuse double quotes inside
    the replacement fields; that syntax is only accepted from Python 3.12 on
    (PEP 701).
    """
    # Six-label prompt with synonym hints for each label; this is the one returned.
    emoevent = \
        f"""Classify the {dataset_details['text_source']} \"{synthetic_text}\" by the single most represented {dataset_details['label_type']} ONLY from the following list:\n1. Anger (also includes annoyance, rage)\n2. Disgust (also includes disinterest, dislike, loathing)\n3. Fear (also includes apprehension, anxiety, terror)\n4. Joy (also includes serenity, ecstasy)\n5. Sadness (also includes pensiveness, grief)\n6. Surprise (also includes distraction, amazement)\nGive only the label."""
    
    # Seven-label prompt (unused). The text contains a stray double quote after
    # "list:", and the list lines keep their leading indentation because they
    # sit inside the triple-quoted string.
    enISEAR = f"""Classify the {dataset_details["text_source"]} \"{synthetic_text}\" by the single most represented {dataset_details["label_type"]} ONLY from the following list:"
    1. Anger
    2. Disgust
    3. Fear
    4. Guilt
    5. Joy
    6. Sadness
    7. Shame
    Give only the label."""
    
    # Six-label upper-case prompt (unused). Same stray quote and indentation
    # as the enISEAR prompt.
    stack_overflow = \
    f"""Classify the {dataset_details["text_source"]} \"{synthetic_text}\" by the single most represented {dataset_details["label_type"]} ONLY from the following list:"
    1. ANGER
    2. FEAR
    3. JOY
    4. LOVE
    5. SADNESS
    6. SURPRISE
    Give only the label."""

    return emoevent

In [14]:
def build_label_prompt(synthetic_text):
    """Fill the dataset's labeling prompt template with the text to classify.

    The template is ``prompt_config.prompt[dataset]['labels']``, where
    ``dataset`` is the notebook-level dataset name. ``<synthetic_text>`` is
    substituted, then runs of spaces are collapsed to a single space.

    Args:
        synthetic_text: the generated text to be labeled.

    Returns:
        str: the finished prompt.
    """
    from config_files import prompt_config

    labels_prompt = prompt_config.prompt[dataset]['labels']
    labels_prompt = labels_prompt.replace("<synthetic_text>", synthetic_text)

    # re.sub returns a new string; the result used to be discarded, so the
    # whitespace collapse never took effect.
    labels_prompt = re.sub(' +', ' ', labels_prompt)

    return labels_prompt

### Prompt LLM for Label

In [15]:
def generate_synthetic_label(dataset_details, llm_details, label_prompt):
    """Ask the configured LLM to label a text.

    * ``"Ollama"``: one request; the reply content is returned as a string.
    * ``"OpenAI"``: one request with ``n = dataset_details["num_labelers"]``
      choices; a list with each choice's content is returned.

    Any other platform falls through and returns ``None``.
    """
    platform = llm_details["platform"]

    if platform == "Ollama":
        reply = ollama.chat(
            model=llm_details["id"],
            messages=[{"role": "user", "content": label_prompt}],
        )
        return reply["message"]["content"]

    if platform == "OpenAI":
        client = OpenAI(api_key=os.environ.get('OPENAI_API_KEY'))
        completion = client.chat.completions.create(
            model=llm_details["id"],
            messages=[{"role": "user", "content": label_prompt}],
            n=dataset_details["num_labelers"],
            max_tokens=30,
        )
        # One entry per labeler, in the order the API returned them.
        return [choice.message.content for choice in completion.choices]

    return None

### Parse Label Response

In [16]:
def parse_label_response(response, dataset_details):
    """Map a free-text LLM answer onto one of the dataset's labels.

    The label names are tried first, in ``label_list`` order, using a
    case-insensitive substring match. If none matches, the 1-based position
    numbers are tried in ascending order; a number only counts when it stands
    alone, so "12" is not read as "1" or "2".

    Args:
        response: the LLM's answer; ``None`` is treated as "no label".
        dataset_details: dataset config; ``label_list`` is read here.

    Returns:
        The matching entry of ``label_list``, or ``None`` when nothing matches
        (a red "NO LABEL FOUND" notice is printed in that case).
    """
    label_list = dataset_details["label_list"]

    if response is not None:
        lowered = response.lower()
        for label in label_list:
            if label.lower() in lowered:
                return label

        # Label name not found, look for ID number. The digit guards on both
        # sides stop a longer number from matching one of its own digits.
        for i in range(1, len(label_list) + 1):
            if re.search(rf"(?<!\d){i}(?!\d)", response):
                return label_list[i-1]

    # Label not found
    CRED = '\33[91m'
    CEND = '\33[0m'
    print(f"{CRED}NO LABEL FOUND:{CEND} {response}")
    return None

### Label Record

In [17]:
def get_label(dataset_details, llm_details, text):
    """Collect labeler votes for ``text`` and derive a consensus label.

    Votes come from ``num_labelers`` separate Ollama requests, or from the
    choices of one OpenAI request. Each parsed vote is echoed on a single
    comma-separated "Labels:" line. For ``label_format == "single"`` a label
    becomes the consensus when it has at least ``num_consensus`` votes; the
    ``"multi"`` format is not implemented.

    Returns:
        tuple: ``(labels, consensus_label)``. ``labels`` is the list of parsed
        votes and ``consensus_label`` is ``None`` when no consensus is reached.
    """
    label_prompt = build_label_prompt(text)

    labels = []

    def record_vote(position, raw_response):
        # Parse one labeler response, store it, and echo it on the "Labels:" line.
        labels.append(parse_label_response(raw_response, dataset_details))
        if position > 0:
            print(", ", end="")
        print(f"{labels[position]}", end="")

    print("Labels: ", end="")
    platform = llm_details["platform"]

    if platform == "Ollama":
        # One request per labeler.
        for position in range(dataset_details["num_labelers"]):
            record_vote(position, generate_synthetic_label(dataset_details, llm_details, label_prompt))

    elif platform == "OpenAI":
        # A single request returns every labeler's answer.
        answers = generate_synthetic_label(dataset_details, llm_details, label_prompt)
        for position, answer in enumerate(answers):
            record_vote(position, answer)

    consensus_label = None

    if dataset_details["label_format"] == "single":
        # Single label dataset
        for potential_label in dataset_details["label_list"]:
            if labels.count(potential_label) >= dataset_details["num_consensus"]:
                consensus_label = potential_label
                print(f"\n\tConsensus: {consensus_label}")
    # "multi": to be implemented if used with any multilabel datasets

    return labels, consensus_label

### Generating a synthetic record

In [18]:
def generate_synthetic_record(real_dataset, dataset_details, llm_details, target_label):
    """Create one synthetic record aimed at ``target_label``.

    A random real record that does not carry the target label is used as the
    seed for the text prompt; the generated text is then labeled by the LLM.

    Returns:
        dict with the keys ``text``, ``labels`` (the consensus label, which may
        differ from the target or be ``None``), ``all labels``,
        ``source index``, ``source label`` and ``intended label``.
    """
    # Seed: a random record that does not already carry the target label.
    seed = get_random_record(real_dataset, target_label)
    seed_summary = "\n".join([
        "\nRandom Record:",
        f"\tIndex: {seed.index}",
        f"\tText: {seed.text}",
        f"\tLabel: {seed.labels}",
    ])
    print(seed_summary)

    # Build the prompt and ask the LLM for the synthetic text.
    prompt = build_text_prompt(target_label, seed)
    synthetic_text = generate_synthetic_text(llm_details, dataset_details, target_label, prompt)
    print(f"\nSynthetic Text:\n\t{synthetic_text}\n")

    # Label the generated text; the result may differ from the target label.
    votes, consensus = get_label(dataset_details, llm_details, synthetic_text)

    record = {
        "text": synthetic_text,
        "labels": consensus,
        "all labels": votes,
        "source index": seed.index,
        "source label": seed.labels,
        "intended label": target_label,
    }
    return record

### Saving the synthetic dataset

In [19]:
def save_dataset(dataset_details, llm_details, working_data):
    """Append ``working_data`` to the saved synthetic dataset on disk.

    The previously saved records are loaded (an empty frame if there are
    none), ``working_data`` is concatenated after them with a fresh index, and
    the result is written to
    ``<synthetic_dataset_relpath><dataset id>/<llm id>.parquet``, with ``:``
    in the LLM id replaced by ``_``.

    Args:
        dataset_details: dataset config; ``id`` names the sub-directory.
        llm_details: LLM config; ``id`` names the file.
        working_data: DataFrame of new synthetic records to append.
    """
    directory = synthetic_dataset_relpath + f"{dataset_details['id']}/"
    filename = llm_details['id'].replace(":", "_") + ".parquet"

    old_data = load_synthetic_dataset(dataset_details, llm_details)
    new_data = pd.concat([old_data, working_data], ignore_index=True)

    # Create the target directory up front. The previous "write, and on any
    # OSError call makedirs" approach turned unrelated write failures into a
    # FileExistsError whenever the directory was already there.
    os.makedirs(directory, exist_ok=True)
    new_data.to_parquet(path=directory+filename)

    print("+ Synthetic dataset saved!")

# Augment
## Setup

In [20]:
# Look up the configuration entries for the dataset and LLM chosen in the setup cell.
dataset_metadata = dataset_config.dataset[dataset]
llm_metadata = LLM_config.model[llm]

# Load real data
df_real_data = load_real_dataset(dataset_metadata)
# Shown before preprocessing, so this view still has the raw columns;
# preprocess_dataframe() below then changes df_real_data in place.
display(df_real_data)

# Homogenize and remove any 'unlabeled' class records
preprocess_dataframe(dataset_metadata, df_real_data)      

# Find how many of each record is needed to balance the dataset
# (negative values are the number of records still missing for a label).
imbalance_counts = find_label_imbalance_counts(df_real_data, dataset_metadata, llm_metadata)

Unnamed: 0,Sentence_id,Prior_Emotion,Sentence,Temporal_Distance,Intensity,Duration,Gender,City,Country,Worker_id,Time,Anger,Disgust,Fear,Guilt,Joy,Sadness,Shame
0,271,Fear,"I felt ... when my 2 year old broke her leg, a...",Y,Vi,Dom,Ml,Bristol,GBR,87,11/28/2018 0:58:52,0,0,0,1,0,3,1
1,597,Shame,I felt ... one Christmas as one of our patient...,Y,I,Dom,Fl,Dulwich,GBR,86,11/26/2018 6:52:02,1,0,0,4,0,0,0
2,282,Guilt,I felt ... because I could not help a friend w...,M,Mi,Dom,Fl,Linlithgow,GBR,83,11/21/2018 18:45:00,0,0,0,4,0,1,0
3,171,Disgust,I felt ... when I read that hunters had killed...,Y,Mi,H,Ml,Bristol,GBR,87,11/28/2018 0:55:11,3,0,0,0,0,2,0
4,509,Sadness,I felt ... when my Gran passed away.,Y,Vi,Dom,Fl,Stoke-on-trent,GBR,92,11/26/2018 9:23:38,0,0,0,0,0,5,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
996,579,Shame,I felt ... that the neighbours in the small vi...,Y,Vi,Dom,Fl,Dulwich,GBR,86,11/25/2018 16:32:23,1,1,0,0,0,0,3
997,593,Shame,I feel ... because I behave in a way that I am...,W,Mi,H,Fl,Tunbridge Wells,GBR,122,11/24/2018 11:11:15,0,0,0,3,0,0,2
998,605,Shame,I felt ... because I fell over in public.,W,N,Fm,Fl,Sheffield,GBR,56,11/26/2018 17:28:12,0,0,0,0,0,0,5
999,606,Shame,I felt ... giving a cheque to the managing age...,W,I,H,Fl,Shepherds Bush,GBR,90,11/26/2018 21:26:45,0,0,0,2,0,0,3



ORIGINAL LABEL COUNTS:
labels
Fear       143
Shame      143
Guilt      143
Disgust    143
Sadness    143
Anger      143
Joy        143
Name: count, dtype: int64
No synthetic dataset found. Creating an empty synthetic dataframe.

SYNTHETIC LABEL COUNTS:
Series([], Name: count, dtype: int64)

COMBINED LABEL DEFICITS:
labels
Fear       0
Shame      0
Guilt      0
Disgust    0
Sadness    0
Anger      0
Joy        0
Name: count, dtype: int64


## Augmenting Loop

In [21]:
# Number of accepted records buffered in memory between saves to disk.
SAVE_BATCH_SIZE = 10
batch_count = 0

# In-memory buffer of accepted synthetic records not yet written to disk.
df_synthetic = pd.DataFrame(columns = ['text', 'labels', 'all labels', 'source index', 'source label', 'intended label'])

while True:   # Get lesser classes up to the size of the largest class
    # When a label is forced via set_target_label, only that label is ever
    # targeted, so the loop must stop on that label's deficit. Waiting for
    # every label to reach zero could never finish in that case.
    remaining_deficit = imbalance_counts[set_target_label] if set_target_label else imbalance_counts.min()
    if not remaining_deficit < 0:
        break

    target_label = set_target_label if set_target_label else imbalance_counts.idxmin()

    print("----------------------------------------------------------------------------")
    print(f"Targeting: {target_label} ({imbalance_counts[target_label]})")
    synthetic_record = generate_synthetic_record(df_real_data,
                                                 dataset_metadata,
                                                 llm_metadata,
                                                 target_label)

    # Records without a consensus label are dropped silently and retried.
    if synthetic_record["labels"]:
        # Maybe synthetic records where the labels don't match the intended label are muddying the labels?
        if synthetic_record["labels"] == target_label or keep_incorrect_but_consensus_labels:
            df_synthetic.loc[len(df_synthetic)] = synthetic_record

            imbalance_counts[synthetic_record["labels"]] += 1    # Update imbalance counts
            batch_count += 1

            if batch_count == SAVE_BATCH_SIZE:
                save_dataset(dataset_metadata, llm_metadata, df_synthetic)
                df_synthetic = df_synthetic.iloc[0:0]   # Clear dataframe contents
                batch_count = 0
                print(imbalance_counts)
        else:
            print("RECORD LABEL != TARGET LABEL. DISCARDING.")

save_dataset(dataset_metadata, llm_metadata, df_synthetic)  # Final save if target balance is reached
# Empty the buffer after the final save so that a later save of df_synthetic
# (e.g. the self-save cell below) cannot append the same records twice.
df_synthetic = df_synthetic.iloc[0:0]
batch_count = 0

No synthetic dataset found. Creating an empty synthetic dataframe.
+ Synthetic dataset saved!


In [22]:
# Self-save if augment loop is canceled.
save_dataset(dataset_metadata, llm_metadata, df_synthetic)
# Clear the buffer once it is on disk; otherwise running this cell again
# would append the same records to the saved dataset a second time.
df_synthetic = df_synthetic.iloc[0:0]

Synthetic dataset found.
+ Synthetic dataset saved!


In [23]:
# Find how many of each record is needed to balance the dataset.
# This reloads the saved synthetic dataset from disk, so the printed counts
# reflect what has been saved, and rebinds imbalance_counts to the new result.
imbalance_counts = find_label_imbalance_counts(df_real_data, dataset_metadata, llm_metadata)


ORIGINAL LABEL COUNTS:
labels
Fear       143
Shame      143
Guilt      143
Disgust    143
Sadness    143
Anger      143
Joy        143
Name: count, dtype: int64
Synthetic dataset found.

SYNTHETIC LABEL COUNTS:
Series([], Name: count, dtype: int64)

COMBINED LABEL DEFICITS:
labels
Fear       0
Shame      0
Guilt      0
Disgust    0
Sadness    0
Anger      0
Joy        0
Name: count, dtype: int64
