# Setup
## Config Files

In [1]:
from config_files import dataset_config
from config_files import LLM_config

## Imports

In [2]:
import pandas as pd
import ollama
# from datasets import load_dataset

# Functions
## Original Dataset Loading

In [3]:
def load_training_dataset(dataset_details):
    """Read the training records of a locally stored dataset into a DataFrame.

    Only datasets whose ``location`` is ``"local"`` and whose ``filetype`` is
    ``"csv"`` or ``"tsv"`` are read; for anything else an empty DataFrame is
    returned. A pre-split dataset (``is_split`` truthy) is read from
    ``train_relpath``; an unsplit one is read from ``relpath``.
    """
    if dataset_details["location"] != "local":
        return pd.DataFrame([])

    path_key = "train_relpath" if dataset_details["is_split"] else "relpath"

    separator_by_filetype = {"csv": ",", "tsv": "\t"}
    separator = separator_by_filetype.get(dataset_details["filetype"])
    if separator is None:
        # Unsupported file type: nothing is read.
        return pd.DataFrame([])

    return pd.read_csv(dataset_details[path_key], sep=separator)

## Preprocessing

In [4]:
def preprocess_dataframe(dataset_details, dataframe):
    """Normalize ``dataframe`` in place and split off its unlabeled records.

    ``dataframe`` is MUTATED (callers rely on this):
      1. the columns listed in ``dataset_details["unused_columns"]`` are dropped,
      2. columns are renamed via ``dataset_details["remap_columns"]`` (the result
         must contain a ``label`` column),
      3. rows whose ``label`` equals ``dataset_details["unlabeled_label"]`` are removed.

    Returns:
        A new, independent DataFrame holding only the removed unlabeled rows,
        keeping their original index labels.
    """
    dataframe.drop(columns=dataset_details["unused_columns"], inplace=True)
    dataframe.rename(columns=dataset_details["remap_columns"], inplace=True)

    is_unlabeled = dataframe["label"] == dataset_details["unlabeled_label"]
    # Independent copy so the returned frame is unaffected by the in-place drop below.
    unlabeled_dataframe = dataframe[is_unlabeled].copy()
    # Drops by index label; assumes index labels are unique (true for the default
    # RangeIndex that read_csv produces when no index_col is given).
    dataframe.drop(index=unlabeled_dataframe.index, inplace=True)

    return unlabeled_dataframe

## Synthetic Dataset Loading

In [5]:
def load_synthetic_dataset(dataset_details):
    """Load previously generated synthetic records for ``dataset_details``.

    TODO: not implemented yet. ``dataset_details`` is currently ignored and an
    empty frame with the ``text`` and ``label`` columns is always returned, so
    downstream code sees "no synthetic records generated so far".
    """
    synthetic_columns = ["text", "label"]
    return pd.DataFrame(columns=synthetic_columns)
    

## Find Label Imbalance Counts

In [6]:
def find_label_imbalance_counts(df_original, df_synthetic):
    """Return, per label, how far its record count falls short of the most common label.

    Counts from ``df_original`` and ``df_synthetic`` are combined, so synthetic
    records already generated for a label reduce that label's deficit.

    Returns:
        A Series indexed by the labels of ``df_original`` (in its ``value_counts``
        order, most frequent first) with values ``<= 0``: ``0`` for the most
        frequent label(s), ``-k`` for a label that needs ``k`` more records.
        Labels that appear only in ``df_synthetic`` are ignored.
    """
    original_label_counts = df_original["label"].value_counts()
    # Align to the original labels (missing -> 0) so the index order and int dtype are kept.
    synthetic_label_counts = (
        df_synthetic["label"]
        .value_counts()
        .reindex(original_label_counts.index, fill_value=0)
    )

    combined_counts = original_label_counts + synthetic_label_counts
    return combined_counts - combined_counts.max()

## Get A Random Record

In [7]:
def get_random_record(dataset, target_label, random_state=None):
    """Draw one random record whose label does not contain ``target_label``.

    The drawn record serves as the source text the LLM rewrites toward
    ``target_label``, so records already carrying that label are excluded.

    Args:
        dataset: DataFrame with ``text`` and ``label`` columns.
        target_label: Label to exclude. Matched with ``in``, so for string labels
            this is a substring test and for list-like labels a membership test.
        random_state: Optional seed / random state forwarded to
            ``DataFrame.sample`` for reproducible draws. ``None`` (default) keeps
            the unseeded behavior.

    Returns:
        dict with keys ``"text"`` and ``"label"``.

    Raises:
        ValueError: if every record contains ``target_label``.
    """
    has_target = dataset["label"].apply(lambda label: target_label in label)
    candidates = dataset[~has_target]
    if candidates.empty:
        raise ValueError(f"No records without label {target_label!r} to sample from")

    record = candidates.sample(n=1, random_state=random_state)
    return {"text": record["text"].iloc[0], "label": record["label"].iloc[0]}

## Generate Text Prompt

In [12]:
def generate_text_prompt(dataset_details, target_label, original_record):
    """Build the LLM prompt that rewrites ``original_record`` toward ``target_label``.

    ``original_record`` is NOT modified.

    Args:
        dataset_details: Dataset config; reads ``label_type`` and ``text_source``
            (the noun naming the kind of text, interpolated as "a {text_source}").
        target_label: Emotion the generated text should portray instead.
        original_record: dict with ``"text"`` and ``"label"`` keys.

    Returns:
        The prompt string.
    """
    label_text = original_record["label"]
    if dataset_details["label_type"] == "single":
        # Leading space so the sentence reads "...portrays the emotion joy, ...".
        label_text = f" {label_text}"
    # NOTE(review): for other label types the label is inserted verbatim right
    # after "emotion", so it presumably arrives pre-formatted — confirm.

    text_source = dataset_details["text_source"]
    return (
        f"The following is a {text_source} with any names, hashtags, and URLs replaced with an all-caps generalized term.\n"
        f"\"{original_record['text']}\"\n"
        f"Using this {text_source}, which portrays the emotion{label_text}, generate a {text_source} about the same subject and similar in style that instead portrays {target_label}. Only give the generated {text_source}."
    )

## Prompt LLM

In [9]:
def generate_synthetic_text(llm_details, text_prompt):
    """Send ``text_prompt`` as a single user message to the configured LLM.

    Args:
        llm_details: LLM config; reads ``platform`` and ``model``.
        text_prompt: The user message content.

    Returns:
        The raw text content of the model's reply.

    Raises:
        ValueError: if ``llm_details["platform"]`` is not a supported platform
            (only ``"Ollama"`` is supported).
    """
    platform = llm_details["platform"]
    if platform != "Ollama":
        raise ValueError(f"Unsupported LLM platform: {platform!r}")

    response = ollama.chat(
        model=llm_details["model"],
        messages=[{"role": "user", "content": text_prompt}],
    )
    return response["message"]["content"]
        

## Parse LLM Response

In [10]:
def parse_response(response):
    """Extract the generated text from a raw LLM reply using line-based heuristics.

    * Trailing newlines are ignored.
    * A single-line reply is accepted only when wrapped in double quotes (the
      quotes are removed). Otherwise it is treated as a refusal (LLM guardrails)
      and the string ``"null"`` is returned.
    * In a multi-line reply, a lead-in line followed by a blank line is skipped.
      Otherwise the text starts after the first line break.
    * Only the first line of the remaining text is kept; anything after it
      (e.g. a rationale) is dropped. If that line is empty, ``"null"`` is returned.
    * One leading and one trailing double quote are removed.
    """
    # Trailing newlines carry no content; removing them keeps the slicing below in range.
    response = response.rstrip("\n")
    line_breaks = [pos for pos, char in enumerate(response) if char == "\n"]

    if not line_breaks:
        if response.startswith('"') and response.endswith('"'):  # Synthetic text given within quotes
            return response[1:-1]
        return "null"  # LLM guardrails may prevent a response from being generated

    # A response may have a lead-in line (followed by a blank line) before the generated text.
    if len(line_breaks) > 1 and line_breaks[1] == line_breaks[0] + 1:
        start = line_breaks[1] + 1
    else:
        start = line_breaks[0] + 1

    # Following the generated text may be a rationale; keep only the first line.
    generated_text = response[start:].split("\n", 1)[0]
    if not generated_text:
        return "null"

    # The text may be flanked by quotes.
    if generated_text.startswith('"'):
        generated_text = generated_text[1:]
    if generated_text.endswith('"'):
        generated_text = generated_text[:-1]

    return generated_text

# Test

In [21]:
# Load the first configured dataset and preview a reproducible sample of its labeled records.
dataset_details = dataset_config.dataset[0]

df_original_training = load_training_dataset(dataset_details)                           # Load real data
df_unlabeled_records = preprocess_dataframe(dataset_details, df_original_training)      # Homogenize and pull out any 'unlabeled' class records
df_synthetic_training = load_synthetic_dataset(dataset_details)                         # Load synthetic data

SAMPLE_SIZE = 10
SAMPLE_SEED = 42  # fixed seed so Restart & Run All shows the same preview rows
df_sample = df_original_training.sample(SAMPLE_SIZE, random_state=SAMPLE_SEED)

df_sample

Unnamed: 0,text,label
29,"USER That’s not deluded, that’s the comment of...",joy
4674,It's so beautiful seeing birds/bats being slau...,disgust
1525,I wish that bird would piss off flying about a...,anger
6333,possible in my life. And I am horribly limited...,sadness
4009,"Holy shit, you aren't going to believe this. T...",surprise
3873,"I think that such is the hysterical, passionat...",disgust
4767,Happy HASHTAG! Let us know what you are readin...,joy
6439,Any day Messi plays is a good day 😄😄 HASHTAG,joy
832,I liked that they filmed this episode of HASHT...,joy
3241,USER If HASHTAG had Asperger's syndrome she wo...,disgust


In [41]:
# Smoke-test the generation pipeline: rewrite each previewed record toward the
# label with the largest deficit and show the original next to the synthetic text.
llm_details = LLM_config.model["Llama3.1 8B instruct-q8"]

# Find how many of each record is needed to balance the dataset; idxmin() picks the
# label furthest below the most common one.
imbalance_counts = find_label_imbalance_counts(df_original_training, df_synthetic_training)
target_label = imbalance_counts.idxmin()

for _row_index, sample_row in df_sample.iterrows():
    record = {"text": sample_row.text, "label": sample_row.label}
    print("Original Text:", record["text"], sep="\n")

    text_prompt = generate_text_prompt(dataset_details, target_label, record)
    full_response = generate_synthetic_text(llm_details, text_prompt)
    synthetic_text = parse_response(full_response)

    print("Synthetic Text:", synthetic_text, sep="\n", end="\n\n")

Original Text:
USER That’s not deluded, that’s the comment of someone that believes in their club, I also think we’ll get through, I was impressed with Ajax, but with Son and a fully fit Sissoko, we can do it! 👍🏼 HASHTAG HASHTAG HASHTAG
Synthetic Text:
USER That’s deluded, that’s the comment of someone who’s not facing reality, I have my doubts we’ll get through, Ajax was impressive and with Son and Sissoko already injured, what if it all falls apart? 😨🤕 HASHTAG HASHTAG HASHTAG

Original Text:
It's so beautiful seeing birds/bats being slaughtered in midair by Eco-friendly wind blades. Tourists can also pick up some bird carcasses to make Ecologically friendly barbecues. URL HASHTAG HASHTAG HASHTAG HASHTAG HASHTAG HASHTAG
Synthetic Text:
I'm terrified thinking about the devastating impact of those wind turbines on local wildlife - the mere thought of them being sliced apart in mid-air is giving me nightmares. Tourists might get a thrill from taking home mangled remains, but I'll be slee

# Augment

In [None]:
# Augment every configured dataset with every configured LLM.
# NOTE: work in progress — generated text is only printed, not collected or saved yet.
for dataset_details in dataset_config.dataset:
    # Loading/preprocessing depends only on the dataset, so do it once per dataset, not once per model.
    df_original_training = load_training_dataset(dataset_details)                           # Load real data
    df_unlabeled_records = preprocess_dataframe(dataset_details, df_original_training)      # Homogenize and pull out any 'unlabeled' class records
    df_synthetic_training = load_synthetic_dataset(dataset_details)                         # Load synthetic data

    # LLM_config.model is looked up by model name in the Test section (a mapping), so iterate
    # over its values (the per-model detail dicts); iterating it directly yields the name strings.
    for llm_details in LLM_config.model.values():
        imbalance_counts = find_label_imbalance_counts(df_original_training, df_synthetic_training) # Find how many of each record is needed to balance the dataset

        while imbalance_counts.min() < 0:
            target_label = imbalance_counts.idxmin()
            random_record = get_random_record(df_original_training, target_label)

            text_prompt = generate_text_prompt(dataset_details, target_label, random_record)
            print("User Prompt:")
            print(f"{text_prompt}\n")

            full_response = generate_synthetic_text(llm_details, text_prompt)
            synthetic_text = parse_response(full_response)
            print("Synthetic Text:")
            print(f"{synthetic_text}\n")

            # TODO: store synthetic_text and update imbalance_counts; until then this break is
            # required because imbalance_counts never changes inside the loop.
            break