# Setup
## Config Files

In [1]:
from config_files import dataset_config
from config_files import LLM_config

## Imports

In [2]:
import pandas as pd
import numpy as np
import ollama
# from datasets import load_dataset
import re
import os
import random

# Classes

In [3]:
class Record:
    """One row taken from the real dataset.

    Attributes:
        index: index label of the row in the DataFrame it was taken from.
        text:  the row's text.
        label: the row's label value, stored as-is.
    """

    def __init__(self, index, text, label):
        self.index, self.text, self.label = index, text, label
        
class SyntheticRecord:
    """An LLM-generated record together with information about how it was produced."""

    def __init__(self, text, label, labels, source_index, source_label, target_label):
        # Generated text and its final label.
        self.text = text
        self.label = label
        # Individual labels, stored as given (presumably one per labeler call — confirm at call site).
        self.labels = labels
        # Index and label of the real record the text was derived from.
        self.source_index = source_index
        self.source_label = source_label
        # Label the generation step was aiming for.
        self.target_label = target_label

    def to_dataframe(self):
        """Return this record as one row keyed by the synthetic-dataset column names.

        Despite its name, this returns a plain ``dict`` (one row), not a DataFrame.
        """
        column_names = ("text", "label", "all labels",
                        "source index", "source label", "intended label")
        row_values = (self.text, self.label, self.labels,
                      self.source_index, self.source_label, self.target_label)
        return dict(zip(column_names, row_values))

# Functions
## Original Dataset Loading

In [4]:
def load_real_dataset(dataset_details):
    """Read the original (real) dataset described by a dataset config entry.

    Only local CSV/TSV files are handled. For split datasets the training file
    (``train_relpath``) is read; otherwise ``relpath``.

    Args:
        dataset_details: dict with ``location``, ``is_split``, ``filetype`` and
            the relevant path key.

    Returns:
        The loaded DataFrame, or an empty DataFrame when the location is not
        ``"local"`` or the file type is neither ``"csv"`` nor ``"tsv"``.
    """
    separator_for = {"csv": ",", "tsv": "\t"}

    if dataset_details["location"] != "local":
        return pd.DataFrame([])

    path_key = "train_relpath" if dataset_details["is_split"] else "relpath"
    separator = separator_for.get(dataset_details["filetype"])
    if separator is None:
        return pd.DataFrame([])

    return pd.read_csv(dataset_details[path_key], sep=separator)

## Preprocessing

In [5]:
def preprocess_dataframe(dataset_details, dataframe):
    """Homogenise ``dataframe`` in place and drop its unlabeled records.

    Steps: drop ``unused_columns``, rename columns via ``remap_columns``, then
    remove every row whose ``label`` equals ``unlabeled_label``. The surviving
    rows keep their original index labels.

    Args:
        dataset_details: dataset config entry with the three keys above.
        dataframe: DataFrame to modify in place.

    Returns:
        None; ``dataframe`` itself is mutated.
    """
    dataframe.drop(columns=dataset_details["unused_columns"], inplace=True)
    dataframe.rename(columns=dataset_details["remap_columns"], inplace=True)

    is_unlabeled = dataframe["label"] == dataset_details["unlabeled_label"]
    dataframe.drop(index=dataframe.index[is_unlabeled], inplace=True)

## Synthetic Dataset Loading

In [6]:
def load_synthetic_dataset(dataset_details, llm_details):
    """Load the saved synthetic dataset for a dataset/LLM pair, or an empty one.

    The file lives at ``./synthetic_datasets/<dataset id>/<llm id>.parquet``,
    with any ``:`` in the LLM id replaced by ``_``.

    Args:
        dataset_details: dataset config entry; ``id`` is read.
        llm_details: LLM config entry; ``id`` is read.

    Returns:
        The stored DataFrame if the file exists; otherwise an empty DataFrame
        with the synthetic-dataset columns.
    """
    directory = f"./synthetic_datasets/{dataset_details['id']}/"
    filename = llm_details['id'].replace(":", "_") + ".parquet"
    path = directory + filename

    if os.path.exists(path):
        synthetic_dataset = pd.read_parquet(path=path)
        print("Synthetic dataset found.")
        return synthetic_dataset

    print("No synthetic dataset found. Creating an empty synthetic dataframe.")
    # Single- and multi-label datasets currently share one schema. Building it
    # unconditionally also avoids the old UnboundLocalError for any other
    # ``label_format`` value.
    return pd.DataFrame(columns=['text', 'label', 'all labels', 'source index', 'source label', 'intended label'])

## Find Label Imbalance Counts

In [7]:
def find_label_imbalance_counts(df_original, dataset_details, llm_details):
    """Compute how far each real-data label is from the largest original class.

    Synthetic records already saved for this dataset/LLM pair are counted
    towards their label.

    Args:
        df_original: preprocessed real dataset with a ``label`` column.
        dataset_details: dataset config entry (forwarded to ``load_synthetic_dataset``).
        llm_details: LLM config entry (forwarded to ``load_synthetic_dataset``).

    Returns:
        Integer Series indexed by the original labels, in ``value_counts``
        order. A negative value is the number of records that label still
        lacks; zero or positive means it has reached the original maximum.
    """
    df_synthetic = load_synthetic_dataset(dataset_details, llm_details)

    original_label_counts = df_original['label'].value_counts()
    synthetic_label_counts = df_synthetic['label'].value_counts()
    print(f"\nORIGINAL LABEL COUNTS: {original_label_counts}")
    print(f"\nSYNTHETIC LABEL COUNTS: {synthetic_label_counts}")

    del df_synthetic        # Only needed for the label counts

    # Labels that are not classes of the real data have no deficit to fill.
    # Adding them via ``combined[label] += ...`` raised KeyError, so report
    # and ignore them instead.
    unknown_labels = synthetic_label_counts.index.difference(original_label_counts.index)
    if len(unknown_labels):
        print(f"\nIgnoring synthetic labels not present in the original data: {list(unknown_labels)}")

    deficits = original_label_counts - original_label_counts.max()
    # Reindexing keeps the original order and int dtype; absent labels count 0.
    combined_label_counts = deficits + synthetic_label_counts.reindex(deficits.index, fill_value=0)
    print(f"\nCOMBINED LABEL COUNTS: {combined_label_counts}")

    return combined_label_counts

## Generate Text
### Get A Random Record

In [8]:
def get_random_record(dataset, target_label, random_state=None):
    """Sample one real record that does not already carry ``target_label``.

    Args:
        dataset: DataFrame with ``text`` and ``label`` columns.
        target_label: label the synthetic record should portray; records that
            already have it are excluded from sampling.
        random_state: optional seed or RNG forwarded to ``DataFrame.sample``
            for reproducible sampling. The default ``None`` matches the
            previous unseeded behavior.

    Returns:
        A ``Record`` with the sampled row's index label, text and label.

    Raises:
        ValueError: if every record in ``dataset`` carries ``target_label``.
    """
    def _carries_target(label):
        # Single-label rows hold a plain string. Compare exactly, because ``in``
        # on a string is a substring test (e.g. "joy" in "joyful" is True).
        if isinstance(label, str):
            return label == target_label
        # Multi-label rows: membership in the collection of labels.
        return target_label in label

    candidates = dataset[~dataset['label'].map(_carries_target)]
    if candidates.empty:
        raise ValueError(f"No records without label {target_label!r} are available to sample.")

    row = candidates.sample(n=1, random_state=random_state)
    return Record(row.index[0], row['text'].iloc[0], row['label'].iloc[0])

### Prompt

In [9]:
def build_text_prompt(dataset_details, target_label, original_record):
    """Build the prompt asking the LLM to rewrite a real record toward ``target_label``.

    Args:
        dataset_details: dataset config entry; only ``text_source`` is read.
        target_label: label the generated text should portray.
        original_record: object with ``text`` and ``label`` attributes.

    Returns:
        The prompt string.
    """
    source = dataset_details['text_source']
    # NOTE(review): the wording says "emotion" regardless of
    # dataset_details['label_type']; adjust if used with non-emotion datasets.
    return (
        f"The following is a {source} with any usernames, names, hashtags, and URLs "
        f"tokenized with an all-caps generalized term. \"{original_record.text}\". "
        f"Using this {source}, which portrays the emotion {original_record.label}, "
        f"generate a {source} about the same subject and similar in style that "
        f"instead portrays {target_label}. Only give the generated {source}."
    )

### Prompt LLM for Synthetic Record

In [10]:
def generate_synthetic_text(llm_details, dataset_details, target_label, text_prompt):
    """Prompt the configured LLM for a synthetic text and return the cleaned result.

    Args:
        llm_details: LLM config entry; ``platform`` and ``id`` are read.
        dataset_details: dataset config entry; ``text_source`` is read.
        target_label: label the generated text should portray.
        text_prompt: prompt built by ``build_text_prompt``.

    Returns:
        The model's reply after ``parse_text_response`` cleanup.

    Raises:
        ValueError: if ``llm_details["platform"]`` is not supported (only
            "Ollama" is). Previously this fell through to an UnboundLocalError.
    """
    if llm_details["platform"] != "Ollama":
        raise ValueError(f"Unsupported LLM platform: {llm_details['platform']!r}")

    response = ollama.chat(
        model=llm_details["id"],
        messages=[
            {"role": "user", "content": text_prompt},
            # Starts the LLM's response (ending in an opening quote), preventing
            # guardrails from censoring fear or anger based responses.
            {"role": "assistant",
             "content": f"Here is a similar {dataset_details['text_source']} portraying {target_label}:\n\n\""},
        ])

    return parse_text_response(response["message"]["content"])

### Parse Synthetic Text Response

In [11]:
def parse_text_response(response):
    """Strip surrounding chatter from an LLM's generated-text reply.

    If the reply contains a blank line (``"\\n\\n"``), it is logged and cut
    down to its first line. The result is then truncated at its last double
    quote; if there is none, it is logged and returned unchanged.

    Args:
        response: raw reply content.

    Returns:
        The cleaned text.
    """
    red, blue, reset = '\33[91m', '\33[34m', '\33[0m'

    # An explanation of the response, if any, follows a couple of line breaks.
    if '\n\n' in response:
        print(f"\n{red}Newline Response:\n\t{response}{reset}")
        response = response.split('\n', 1)[0]

    # The assistant turn was started with an opening quote, so the text
    # usually ends with a closing one.
    closing_quote = response.rfind('"')
    if closing_quote == -1:
        print(f"\n{blue}No Quote Response:\n\t{response}{reset}")
    else:
        response = response[:closing_quote]

    return response

## Generate Label
### Prompt

In [12]:
def generate_label_prompt(dataset_details, synthetic_record):
    """Build the prompt asking the LLM to classify a piece of text.

    Args:
        dataset_details: dataset config entry; ``text_source`` and
            ``label_type`` are read.
        synthetic_record: the text to classify (interpolated verbatim).

    Returns:
        A prompt listing six numbered emotion options and asking for the label only.
    """
    # NOTE(review): these options are hard-coded rather than taken from
    # dataset_details["label_list"]. parse_label_response maps a bare option
    # number back through label_list, so both must list labels in the same order.
    options = (
        "Anger (also includes annoyance, rage)",
        "Disgust (also includes disinterest, dislike, loathing)",
        "Fear (also includes apprehension, anxiety, terror)",
        "Joy (also includes serenity, ecstasy)",
        "Sadness (also includes pensiveness, grief)",
        "Surprise (also includes distraction, amazement)",
    )
    header = (f"Classify the {dataset_details['text_source']} \"{synthetic_record}\" "
              f"by the single most represented {dataset_details['label_type']} ONLY from the following list:\n")
    numbered_options = "".join(f"{number}. {option}\n" for number, option in enumerate(options, start=1))
    return header + numbered_options + "Give only the label."

### Prompt LLM for Label

In [13]:
def generate_synthetic_label(llm_details, label_prompt):
    """Send a classification prompt to the configured LLM and return its raw reply.

    Args:
        llm_details: LLM config entry; ``platform`` and ``id`` are read.
        label_prompt: prompt built by ``generate_label_prompt``.

    Returns:
        The unparsed content of the model's reply.

    Raises:
        ValueError: if ``llm_details["platform"]`` is not supported (only
            "Ollama" is). Previously this fell through to an UnboundLocalError.
    """
    if llm_details["platform"] != "Ollama":
        raise ValueError(f"Unsupported LLM platform: {llm_details['platform']!r}")

    response = ollama.chat(
        model=llm_details["id"],
        messages=[
            {"role": "user", "content": label_prompt},
        ])

    return response["message"]["content"]

### Parse Label Response

In [14]:
def parse_label_response(response, dataset_details):
    """Map an LLM classification reply to one of the dataset's labels.

    Two passes are made over ``label_list``:
    1. The first label whose name occurs (case-insensitively, as a substring)
       in the reply.
    2. Otherwise, the first label whose 1-based position occurs as a
       substring in the reply.

    Args:
        response: raw reply text.
        dataset_details: dataset config entry; ``label_list`` is read.

    Returns:
        The matched label, or ``None`` (after printing a warning) if neither
        pass matches.
    """
    label_list = dataset_details["label_list"]
    lowered_response = response.lower()

    named_match = next((label for label in label_list if label.lower() in lowered_response), None)
    if named_match is not None:
        return named_match

    for position, label in enumerate(label_list, start=1):
        if str(position) in response:
            return label

    red, reset = '\33[91m', '\33[0m'
    print(f"{red}NO LABEL FOUND:{reset} {response}")
    return None

### Label Record

In [15]:
def get_label(dataset_details, llm_details, text):
    """Label ``text`` by querying the LLM several times and taking a consensus.

    The same classification prompt is sent ``num_labelers`` times. Every
    parsed vote is kept, including ``None`` when a reply could not be parsed.

    Args:
        dataset_details: dataset config entry. Reads ``num_labelers``,
            ``label_format``, ``label_list`` and ``num_consensus``, plus
            whatever the prompt/parse helpers read.
        llm_details: LLM config entry forwarded to ``generate_synthetic_label``.
        text: the text to classify.

    Returns:
        A tuple ``(labels, consensus_label)``:
        - ``labels``: all votes in order.
        - ``consensus_label``: the label with at least ``num_consensus``
          votes. If several qualify, the most-voted wins, with ties broken by
          ``label_list`` order. Otherwise ``None``.
        Multi-label datasets are not implemented and always yield ``None``.
    """
    label_prompt = generate_label_prompt(dataset_details, text)

    labels = []
    print("Labels: ", end="")
    for i in range(dataset_details["num_labelers"]):
        label_response = generate_synthetic_label(llm_details, label_prompt)
        labels.append(parse_label_response(label_response, dataset_details))

        if i > 0:
            print(", ", end="")
        print(f"{labels[i]}", end="")

    consensus_label = None

    if dataset_details["label_format"] == "single":
        # Single label dataset. Tally votes in label_list order; max() keeps
        # the first of equal counts, so at most one consensus is chosen.
        vote_counts = {label: labels.count(label) for label in dataset_details["label_list"]}
        top_label = max(vote_counts, key=vote_counts.get, default=None)
        if top_label is not None and vote_counts[top_label] >= dataset_details["num_consensus"]:
            consensus_label = top_label

    elif dataset_details["label_format"] == "multi":
        # To be implemented if used with any multilabel datasets
        pass

    if consensus_label is None:
        # Terminate the "Labels: ..." line so following output starts on a new line.
        print()
    else:
        print(f"\n\tConsensus: {consensus_label}")

    return labels, consensus_label

### Generate a Synthetic Record

In [16]:
def generate_synthetic_record(real_dataset, dataset_details, llm_details, target_label):
    """Create one synthetic record aimed at ``target_label`` and label it.

    Pipeline: sample a real record lacking ``target_label``, ask the LLM to
    rewrite it toward ``target_label``, then have the LLM label the result.

    Args:
        real_dataset: preprocessed real DataFrame to sample from.
        dataset_details: dataset config entry.
        llm_details: LLM config entry.
        target_label: label the rewrite should portray.

    Returns:
        A dict keyed by the synthetic-dataset column names. ``"label"`` is the
        consensus label (possibly ``None``, and possibly different from
        ``target_label``).
    """
    # Seed text: a real record that does not already carry the target label.
    source = get_random_record(real_dataset, target_label)
    print("\nRandom Record:\n"
          f"\tIndex: {source.index}\n"
          f"\tText: {source.text}\n"
          f"\tLabel: {source.label}")

    prompt = build_text_prompt(dataset_details, target_label, source)
    generated = generate_synthetic_text(llm_details, dataset_details, target_label, prompt)
    print(f"\nSynthetic Text:\n\t{generated}\n")

    # Independently label the generated text; the result may differ from target_label.
    votes, consensus = get_label(dataset_details, llm_details, generated)

    record = {}
    record["text"] = generated
    record["label"] = consensus
    record["all labels"] = votes
    record["source index"] = source.index
    record["source label"] = source.label
    record["intended label"] = target_label
    return record

### Save the Synthetic Dataset

In [17]:
def save_dataset(dataset_details, llm_details, working_data):
    """Append ``working_data`` to the stored synthetic dataset and rewrite the file.

    The existing file (if any) is loaded via ``load_synthetic_dataset``,
    concatenated with ``working_data`` and written back as parquet to
    ``./synthetic_datasets/<dataset id>/<llm id>.parquet``, with ``:`` in the
    LLM id replaced by ``_``.

    Args:
        dataset_details: dataset config entry; ``id`` is read.
        llm_details: LLM config entry; ``id`` is read.
        working_data: DataFrame of new synthetic rows to append.
    """
    directory = f"./synthetic_datasets/{dataset_details['id']}/"
    filename = llm_details['id'].replace(":", "_") + ".parquet"

    old_data = load_synthetic_dataset(dataset_details, llm_details)
    new_data = pd.concat([old_data, working_data], ignore_index=True)

    # Ensure the directory exists up front. Retrying after any OSError hid
    # unrelated write failures (permissions, full disk) behind the
    # FileExistsError that os.makedirs raises for an existing directory.
    os.makedirs(directory, exist_ok=True)
    new_data.to_parquet(path=directory + filename)

    print("+ Synthetic dataset saved!")

# Augment
## Setup

In [21]:
# Which dataset to augment, and which LLM generates and labels the synthetic records.
DATASET_NAME = "EmoEvent (English)"
LLM_NAME = "Llama3.1 8B instruct-q8"

dataset_metadata = dataset_config.dataset[DATASET_NAME]
llm_metadata = LLM_config.model[LLM_NAME]

# Read the real data, then (in place) normalise its columns and drop 'unlabeled' records.
df_real_data = load_real_dataset(dataset_metadata)
preprocess_dataframe(dataset_metadata, df_real_data)

# Per-label gap to the largest original class, counting synthetic records already saved to disk.
imbalance_counts = find_label_imbalance_counts(df_real_data, dataset_metadata, llm_metadata)

Synthetic dataset found.

ORIGINAL LABEL COUNTS: label
joy         2039
disgust      765
sadness      416
anger        392
surprise     235
fear         151
Name: count, dtype: int64

SYNTHETIC LABEL COUNTS: label
anger       1957
fear        1891
surprise    1806
sadness     1624
disgust     1276
joy          164
Name: count, dtype: int64

COMBINED LABEL COUNTS: label
joy         164
disgust       2
sadness       1
anger       310
surprise      2
fear          3
Name: count, dtype: int64


## Augmenting Loop

In [19]:
SAVE_BATCH_SIZE = 10    # accepted synthetic records buffered in memory between saves
batch_count = 0

# Buffer of accepted-but-unsaved synthetic records.
df_synthetic = pd.DataFrame(columns = ['text', 'label', 'all labels', 'source index', 'source label', 'intended label'])

try:
    while imbalance_counts.min() < 0:   # Get lesser classes up to the size of the largest class
        target_label = imbalance_counts.idxmin()    # label with the largest remaining deficit

        print("----------------------------------------------------------------------------")
        print(f"Targeting: {target_label}")
        synthetic_record = generate_synthetic_record(df_real_data,
                                                     dataset_metadata,
                                                     llm_metadata,
                                                     target_label)

        new_label = synthetic_record["label"]
        if not new_label:
            continue    # the labelers reached no consensus; discard the record

        if new_label not in imbalance_counts.index:
            # Not a class of the real data: counting it would raise KeyError.
            print(f"Skipping record labeled '{new_label}', which is not a class in the original data.")
            continue

        df_synthetic.loc[len(df_synthetic)] = synthetic_record

        imbalance_counts[new_label] += 1    # consensus label, which may differ from target_label
        batch_count += 1

        if batch_count == SAVE_BATCH_SIZE:
            save_dataset(dataset_metadata, llm_metadata, df_synthetic)
            df_synthetic = df_synthetic.iloc[0:0]   # Clear dataframe contents
            batch_count = 0
            print(imbalance_counts)
finally:
    # Runs when the target balance is reached AND when the loop is interrupted
    # (e.g. KeyboardInterrupt), so a partial batch is never lost.
    save_dataset(dataset_metadata, llm_metadata, df_synthetic)
    # Empty the buffer so a later manual save cannot append these rows again.
    df_synthetic = df_synthetic.iloc[0:0]
    batch_count = 0

----------------------------------------------------------------------------
Targeting: sadness

Random Record:
	Index: 6199
	Text: HASHTAG minors please note potty language is used for puncuation!  URL
	Label: disgust

Synthetic Text:
	SO SAD why do some shows use strong language? especially when kids are watching! URL

Labels: sadness, sadness, sadness
	Consensus: sadness
----------------------------------------------------------------------------
Targeting: anger

Random Record:
	Index: 2816
	Text: 6 Books Celebrating Differences and Kindness  URL HASHTAG  URL
	Label: joy

Synthetic Text:
	5 Books Exposing Injustice and Outrage URL HASHTAG URL

Labels: anger, anger, anger
	Consensus: anger
----------------------------------------------------------------------------
Targeting: surprise

Random Record:
	Index: 1433
	Text: USER Send 5 US soldiers for every half a Cuban. End this. Free HASHTAG.
	Label: anger

Synthetic Text:
	USER Just heard there's a proposal to deploy 5 US soldiers fo

In [20]:
# Manual save (e.g. after interrupting the augmenting loop): appends the
# buffered rows in df_synthetic to the stored synthetic dataset.
save_dataset(dataset_metadata, llm_metadata, df_synthetic)
# Empty the buffer so re-running this cell cannot append the same rows twice.
df_synthetic = df_synthetic.iloc[0:0]

Synthetic dataset found.
+ Synthetic dataset saved!
