In [38]:
from openai import OpenAI, api_key
import huggingface_hub
import os
import torch
import torch.distributed as dist
import time
import prompts
import json
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import sys
from pydantic import BaseModel
from random import shuffle
from collections import defaultdict

def get_key():
    """
    Fetches the OpenAI API key from environment variables.
    """
    api_key = os.environ["OPENAI_API_KEY"]


def calculate_doc_similarity(original, rewrite):
    """
    Calculates cosine similarities between the original document and the rewrites.

    Args:
        original (str): The original document.
        rewrites (list): List of the rewritten documents.

    Returns:
        dict: Dictionary of rewritten documents and their similarity scores.
    """
    model = SentenceTransformer("all-MiniLM-L6-v2")
    original_embedding = model.encode([original])
    rewrite_embedding = model.encode([rewrite])

    # Compute cosine similarities
    similarity = model.similarity(original_embedding, rewrite_embedding)

    # Return similarity between documents
    return round(float(similarity), 4)


def get_response_format(stage):
    """
    Returns the appropriate response format based on the stage of execution.
    
    Args:
        stage (str): The stage of execution the program is in.

    Returns:
        ResponseFormat: A Pydantic model for formatting responses.
    """
    if stage == 'initial':
        class ResponseFormat(BaseModel):
            general: list[str]
            specific: list[str]
    elif stage == 'rewrite':
        class ResponseFormat(BaseModel):
            document: str
    else:
        class ResponseFormat(BaseModel):
            differences: str
            general: list[str]
            specific: list[str]

    return ResponseFormat


def print_execution_time(start_time, end_time):
    """
    Prints the total time taken to process data.

    Args:
        start_time (float): Start time of the process.
        end_time (float): End time of the process.
    """
    print('='*20)
    print(f'Total time taken to generate answers: {end_time - start_time}')
    print('='*20)


def format_prompt(stage, original=None, rewritten=None, general=None, specific=None, vocab=None):
    """
    Formats the prompt based on the current processing stage.

    Args:
        stage (str): The processing stage ('initial', 'rewrite', or 'revise').
        original (str): The original document text.
        rewritten (str, optional): The rewritten text if available. Defaults to None.
        general (list, optional): List of general descriptors if available. Defaults to None.
        specific (list, optional): List of specific descriptors if available. Defaults to None.
        vocab (list, optional): Vocabulary to include in the prompt. Defaults to None.

    Returns:
        str: The formatted prompt based on the stage and inputs.
    """
    if stage == 'initial':
        message = prompts.initial_prompt(original, vocab)
    elif stage == 'rewrite':
        message = prompts.rewrite_prompt(general, specific)
    else:
        message = prompts.revise_keyphrases_prompt(original, rewritten, general, specific, vocab)
    return message


def generate(client, message, stage):  
    """
    Sends a prompt to the OpenAI client and retrieves the response.

    Args:
        client (OpenAI): OpenAI API client instance.
        message (str): The prompt message.
        stage (str): Current processing stage.

    Returns:
        dict: Parsed JSON response from the model.
    """
    response_format = get_response_format(stage)
    
    response = client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=message,
        response_format=response_format,
        )
    
    return response.choices[0].message.parsed.json()


def load_documents():
    """
    Loads and streams documents from a specified dataset.

    Returns:
        Dataset: A streaming dataset split to be processed in training mode.
    """
    return load_dataset('HuggingFaceFW/fineweb',
                        name='sample-10BT',
                        split='train',
                        streaming=True)


def save_descs(general, specific, file_id):
    """
    Saves general and specific descriptors to a text file.

    Args:
        general (list): General descriptors.
        specific (list): Specific descriptors.
        file_id (str): Identifier for the output file.
    """
    with open(f'../results/desciptors_{file_id}.txt', 'a') as f:
        f.write('======================\n')
        f.write('General:\n')
        for item in general:
            f.write(f'{item}\n')
        f.write('----------------------\n')
        f.write('Specific:\n')
        for item in specific:
            f.write(f'{item}\n')
        f.write('======================\n')


def initial_stage(document, vocab, stage, client):
    """
    Generates initial descriptors for a given document.

    Args:
        document (str): Document text.
        vocab (list): Vocabulary for descriptor generation.
        stage (str): Current processing stage.
        client (OpenAI): OpenAI client instance.

    Returns:
        tuple: General and specific descriptors.
    """
    if len(vocab) == 0:
        vocab = "The list of general descriptors is currently empty."
    else:
        vocab = '\n'.join(vocab)
    
    prompt = format_prompt(stage=stage, original=document, vocab=vocab)
    output = json.loads(generate(client, prompt, stage))
    general = output['general']
    specific = output['specific']

    print('Initial prompt:')
    print(prompt)
    print()
    
    return general, specific


def rewrite_stage(stage, general, specific, client):
    """
    Rewrites a document based on provided descriptors.

    Args:
        stage (str): Current processing stage.
        general (list): General descriptors.
        specific (list): Specific descriptors.
        client (OpenAI): OpenAI client instance.

    Returns:
        str: Rewritten document.
    """
    prompt = format_prompt(stage=stage,
                           general=general,
                           specific=specific)
    output = json.loads(generate(client, prompt, stage))
    return output['document']
    

def save_rewrite(rewritten, file_id):
    """
    Saves rewritten document to a text file.

    Args:
        rewritten (str): Rewritten document.
        file_id (str): Identifier for the output file.
    """
    with open(f'../results/rewritten_docs_{file_id}.txt', 'a') as f:
        f.write(f'{rewritten}\n')
        f.write('===========================\n')


def revise_stage(stage, document, rewritten, general, specific, vocab, client):
    """
    Revises descriptors based on a rewritten document.

    Args:
        stage (str): Current processing stage.
        document (str): Original document text.
        rewritten (str): Rewritten document.
        general (list): General descriptors.
        specific (list): Specific descriptors.
        vocab (list): Vocabulary for descriptor generation.
        client (OpenAI): OpenAI client instance.

    Returns:
        tuple: Revised general and specific descriptors.
    """
    vocab = '\n'.join(vocab)
    prompt = format_prompt(stage=stage,
                           original=document,
                           rewritten=rewritten,
                           general=general,
                           specific=specific,
                           vocab=vocab)
    output = json.loads(generate(client, prompt, stage))
    general = output['general']
    specific = output['specific']

    return general, specific


def save_best_results(document, rewrites, general, specific, similarity_scores, run_id, print_results=False):
    """
    Saves the best results (highest similarity) among multiple rewrites.

    Args:
        document (str): Original document.
        rewrites (list): List of rewritten documents.
        general (list): General descriptors for each rewrite.
        specific (list): Specific descriptors for each rewrite.
        similarity_scores (list): Similarity scores for each rewrite.
        run_id (str): Run identifier.

    Returns:
        list: Best general descriptors.
    """
    best_index = similarity_scores.index(max(similarity_scores))
    results = {
        'document': document,
        'rewrite': rewrites[best_index],
        'similarity': similarity_scores[best_index],
        'general_descriptors': general[best_index],
        'specific_descriptors': specific[best_index],
    }
    if print_results:
        print('======================')
        print('BEST RESULTS:')
        for key, value in results.items():
            print(key)
            print(value)
            print()
        print('======================')
    with open(f'../results/descriptors_{run_id}.jsonl', 'a') as f:
        f.write(json.dumps(results, ensure_ascii=False))
        f.write('\n')

    return general[best_index]


def initialise_descriptor_vocab(use_previous_descriptors, path):
    """
    Initializes the descriptor vocabulary.

    Args:
        use_previous_descriptors (bool): Whether to load previous descriptors.
        path (str): Path to the previous descriptors file.

    Returns:
        defaultdict: Initialized descriptor vocabulary.
    """

    descriptors = defaultdict(int)
    
    if use_previous_descriptors:
        print('use_previous_descriptors=True')
        print('Set this to False if you want to start with an empty dictionary.')
        try:
            with open(path, 'r') as f:
                file = f.readlines()
                for line in file:
                    line = line.strip().split('\t')
                    desc, freq = line
                    descriptors[desc] = int(freq)
            return descriptors
        except FileNotFoundError:
            print('No previous descriptors found. Defaulting to empty dictionary.')
            return descriptors
    else:
        return descriptors


def save_descriptors(vocab, path):
    """
    Saves the current descriptor vocabulary to a file.

    Args:
        vocab (defauldict): Dict of descriptors and their frequency.
        path (str): Path to save the vocabulary.
    """
    with open(path, 'w') as f:
        for desc, freq in vocab:
            f.write(f"{desc}\t{freq}\n")


def return_top_descriptors(descriptor_counts_sorted):
    return [desc[0] for desc in descriptor_counts_sorted][:100]


def main(start_at_index=0, stop_at_index=100, use_previous_descriptors=False, descriptor_path=None, run_id='run1'):
    """
    Main function to set up the model, generate responses, and save the results.

    - Initializes cache directory and sets up the LLM.
    - Iterates through each document in the dataset, generating responses for each stage.
    - Collects and saves results.
    """
    
    cache_dir = "/scratch/project_2011109/otto/LLM_data_labelling/hf_cache"
    get_key()
    
    client = OpenAI()
    data = load_documents()

    descriptor_counts = initialise_descriptor_vocab(use_previous_descriptors, descriptor_path)
    # Keep the top 100 general descriptors. These will be given to the model as possible options.
    descriptor_counts_sorted = sorted(descriptor_counts.items(), key=lambda item: item[1], reverse=True)
    descriptor_vocab = return_top_descriptors(descriptor_counts_sorted)
    
    for i, line in enumerate(data):
        if i < start_at_index:
            continue
        print('General descriptor vocab:')
        print(descriptor_vocab)
        print('Num:', len(descriptor_vocab))
        print()
        
        file_id = f'{run_id}_doc{i}'
        general_descriptor_lists = []
        specific_descriptor_lists = []
        rewrites = []
        doc_similarities = []
        document = line['text']
        
        # Generate initial descriptors for document
        stage = 'initial'
        general_descriptors, specific_descriptors = initial_stage(document, descriptor_vocab, stage, client)
        general_descriptor_lists.append(general_descriptors)
        specific_descriptor_lists.append(specific_descriptors)
        #save_descs(general_descriptors, specific_descriptors, file_id)

        for _ in range(5):
            # Rewrite doc based on the descriptors
            stage = 'rewrite'
            rewritten = rewrite_stage(stage,
                                      general_descriptors,
                                      specific_descriptors,
                                      client)
            rewrites.append(rewritten)
            #save_rewrite(rewritten, file_id)

            # Evaluate rewrite and revise descriptors
            stage = 'revise'
            general_descriptors, specific_descriptors = revise_stage(stage,
                                                                     document,
                                                                     rewritten,
                                                                     general_descriptors,
                                                                     specific_descriptors,
                                                                     descriptor_vocab,
                                                                     client)
            general_descriptor_lists.append(general_descriptors)
            specific_descriptor_lists.append(specific_descriptors)
            #save_descs(general_descriptors, specific_descriptors, file_id)

            doc_similarities.append(calculate_doc_similarity(document, rewritten))

        # Save best result based on similarity score between original and rewrite
        # Return the best general descriptors
        best_descriptors = save_best_results(document,
                                             rewrites,
                                             general_descriptor_lists,
                                             specific_descriptor_lists,
                                             doc_similarities,
                                             run_id)

        # Update descriptor counts and save
        for desc in best_descriptors:
            descriptor_counts[desc] += 1
            
        # Sort descriptors by their frequency and save
        descriptor_counts_sorted = sorted(descriptor_counts.items(), key=lambda item: item[1], reverse=True)
        save_descriptors(descriptor_counts_sorted, descriptor_path)
        
        # Keep the 100 most common general descriptors. These will be given to the model as possible options.
        descriptor_vocab = return_top_descriptors(descriptor_counts_sorted)

        # Stop at given index
        if i >= stop_at_index:
            break

In [39]:
if __name__ == '__main__':
    main(start_at_index=0,
         stop_at_index=19,
         use_previous_descriptors=False,
         descriptor_path='../results/descriptor_vocab5.tsv',
         run_id='test5')

Resolving data files:   0%|          | 0/23781 [00:00<?, ?it/s]

General descriptor vocab:
[]
Num: 0

Initial prompt:
[{'role': 'system', 'content': '\n##Instruction:\n\nYou will be given a document. Your task is to create a comprehensive list of descriptors—words or phrases that distill the meaning, tone, style, genre, topics, and other characteristics of the document. Do not focus solely on the topics; also include phrases that describe the tone and style. Order the descriptors as they appear in the document.\n\n##Requirements:\n\n1. Generate Two Types of Descriptors:\n    - General Descriptors: General Descriptors: Describe the document in aspects including, but not limited to, style, tone, genre, topic, domain, length, language, quality, etc. They should be general enough so they could likely be applied to other, hypothetical documents.\n    - Specific Descriptors: Describe minute details specific to this document such as individual words and phrases, emphasis, structure, etc.\n\n2. Descriptor Details:\n    - Descriptors can be single words or m

In [4]:
# Now let's investigate the results

results = []
with open('../results/descriptors_test1.jsonl', 'r') as f:
    file = f.readlines()
    for line in file:
        results.append(json.loads(line))

In [5]:
# Calculate character-level compression rate

for doc in results:
    document = doc['document']
    descriptors = "".join(doc['general_descriptors'] + ['specific_descriptors'])
    print('Compression rate:')
    print(f'{round((1-(len(descriptors)/len(document)))*100, 2)}%')

Compression rate:
66.91%
Compression rate:
92.68%
Compression rate:
88.14%
Compression rate:
85.85%
Compression rate:
77.44%
Compression rate:
75.47%
Compression rate:
64.31%
Compression rate:
-20.38%
Compression rate:
83.75%
Compression rate:
81.75%


In [6]:
import pandas as pd
df = pd.read_json('../results/descriptors_test1.jsonl', lines=True)
df

Unnamed: 0,document,rewrite,similarity,general_descriptors,specific_descriptors
0,|Viewing Single Post From: Spoilers for the We...,**Viewing Single Post** \n**Spoilers for the ...,0.849,"[Informal tone, Casual style, Fan commentary, ...","[Viewing Single Post, Spoilers for the Week of..."
1,"*sigh* Fundamentalist community, let me pass o...",*sigh* It seems we must once again delve into ...,0.604,"[Critical and confrontational tone, Engaging a...","[*sigh*, Advice to the fundamentalist communit..."
2,A novel two-step immunotherapy approach has sh...,**Innovations in Advanced Oncology: A Focus on...,0.7356,[Clinical and research-related tone with a foc...,[Two-Step Immunotherapy Approach: focuses on a...
3,Free the Cans! Working Together to Reduce Wast...,**Embracing Our Community: Waste Reduction and...,0.821,"[Informal and engaging tone, Persuasive writin...","[Reference to bizarre sharing habits, Society'..."
4,"ORLANDO, Fla. — While the Rapid Recall Exchang...",### Industry Insights on the Rapid Recall Exch...,0.7954,[Technical and informative tone that emphasize...,[References key events and settings like the U...
5,"September 28, 2010\n2010 Season - Bowman pulls...",# 2010 Season - Bowman pulls down CCIW honor\n...,0.9258,[Sports journalism tone focusing on collegiate...,"[Date: September 28, 2010, Headline: 2010 Seas..."
6,Kraft Foods has taken the Cadbury chocolate br...,**FOR IMMEDIATE RELEASE** \n**Kraft Foods Tak...,0.9189,[Corporate press release style focusing on pro...,[Kraft Foods taking Cadbury brand in a new dir...
7,You must be a registered member to view this p...,Welcome to our online community! \n\nTo fully ...,0.764,[Friendly and approachable writing style with ...,[Direct call to action: 'You must be a registe...
8,|Facility Type:||Full Service Restaurant|\n|In...,**Inspection Report: Full Service Restaurant**...,0.8885,[Technical and formal writing style with an em...,[Facility type specified as Full Service Resta...
9,News of the Week\nBarrie Spring Studio Tour\nA...,### 🎨 Upcoming Events at Jill Price Studios! \...,0.7113,[Informative and promotional tone with a perso...,"[Event announcement with dates and times., Inv..."


In [27]:
d = defaultdict(int)

d = {"Informal and engaging tone":2, "Fan commentary style":1, "Moderate length":4, "Entertainment domain":6}

sorted(d.items(), key=lambda item: item[1], reverse=True)[:10]

[('Entertainment domain', 6),
 ('Moderate length', 4),
 ('Informal and engaging tone', 2),
 ('Fan commentary style', 1)]

In [28]:
l = [1]

l[:10]

[1]