# In [2]:
from openai import OpenAI, OpenAIError
from datasets import load_dataset
import itertools
import json
import tiktoken
import pandas as pd
import random
from tqdm import tqdm
import time
import os
import re
import warnings
from collections import Counter

# In [None]:
def num_tokens_from_string(speech):
    """Count the tokens that ``speech`` encodes to with the gpt-4o-mini tokenizer."""
    tokenizer = tiktoken.encoding_for_model('gpt-4o-mini')
    return len(tokenizer.encode(speech))

    
def completion():
    """Send one chat request to gpt-4o-mini and return the text of the first choice.

    The system and user messages are taken from the names ``system`` and
    ``prompt``, which are not parameters of this function and therefore must
    be defined at module level before it is called.
    """
    api = OpenAI()
    chat_messages = [
        {'role': 'system', 'content': system},
        {'role': 'user', 'content': prompt},
    ]
    reply = api.chat.completions.create(
        model='gpt-4o-mini',
        messages=chat_messages,
        # response_format=response_format,
        temperature=0.1,
        max_tokens=1000,  # max length of response
    )
    first_choice = reply.choices[0]
    return first_choice.message.content


def get_key(path='/projappl/project_2011109/openai-apikey.txt'):
    """Read the OpenAI API key from a file and export it as ``OPENAI_API_KEY``.

    Args:
        path: Location of the text file holding the key. Defaults to the
            project key file that was previously hard-coded.

    Raises:
        OSError: If the key file cannot be opened or read.
    """
    with open(path, 'r', encoding='utf-8') as f:
        # Strip all surrounding whitespace, not just '\n': a key file saved
        # with CRLF line endings or trailing spaces would otherwise leave
        # stray characters in the key.
        os.environ['OPENAI_API_KEY'] = f.read().strip()


def calculate_cost(input_len, output_len):
    """Print token totals and the dollar cost of input, output, and both combined.

    Args:
        input_len: Iterable of per-request input token counts.
        output_len: Iterable of per-request output token counts.

    Each iterable is summed exactly once, so one-shot iterators (generators)
    give the same report as lists.
    """
    # Price per single token in dollars.
    input_price = 0.15/1_000_000
    output_price = 0.6 / 1_000_000

    total_input = sum(input_len)
    total_output = sum(output_len)
    input_cost = input_price * total_input
    output_cost = output_price * total_output

    print(f"Input tokens: {total_input} at a cost of ${input_cost}")
    print(f"Output tokens: {total_output} at a cost of ${output_cost}")
    print(f"Total cost: ${output_cost + input_cost}")
    print()


def format_input(round_num, doc):
    """Build the model input for ``doc`` in round ``round_num`` (not implemented yet).

    This is still a placeholder: every branch is empty, so the function
    returns ``None`` for any ``round_num``. The three branches are kept to mark
    where the per-round formatting is meant to go.

    Args:
        round_num: Index of the current round.
        doc: The document being processed.

    Returns:
        None, until the branches are filled in.
    """
    if round_num == 0:
        pass  # TODO: formatting for round 0
    elif round_num in (1, 3):
        pass  # TODO: formatting for rounds 1 and 3
    else:
        # This branch previously had no body at all, which is a SyntaxError
        # that prevented the file from being parsed.
        pass  # TODO: formatting for the remaining rounds

def main(start_from_index=0,
         stop_at_index=500,
         save_file='junk_classification_output.jsonl'):
    """Stream FineWeb documents, label them with the model, and save the results.

    For every processed document the function runs five generate rounds,
    updates the running list of junk labels, appends one JSON line to
    ``../results/<save_file>``, and rewrites ``../results/junk_labels.txt``.
    After every 100th document index it appends the label count to
    ``../results/number_of_labels.csv`` and prints the running cost.

    Args:
        start_from_index: Documents with a smaller stream index are skipped.
        stop_at_index: The loop ends after the document with this index has
            been processed (the bound is inclusive).
        save_file: File name inside ``../results/`` for the JSON-lines output.

    NOTE(review): ``generate``, ``extract_junk_labels``, ``chunk`` and
    ``was_split`` are used here but are not defined in this function; they
    must be provided at module level. ``chunk`` and ``was_split`` in particular
    look like they were meant to be produced per document -- confirm.
    """
    get_key()  # load OpenAI API key
    input_len = []
    output_len = []
    time_taken = []

    # junk_labels is rebound further down, which makes it a local name; it has
    # to exist before the first generate() call or Python raises
    # UnboundLocalError. NOTE(review): starting from an empty list is assumed.
    junk_labels = []

    docs = load_dataset('HuggingFaceFW/fineweb', name='sample-10BT', split='train', streaming=True)

    for doc_index, doc in tqdm(enumerate(docs)):
        if doc_index < start_from_index:
            continue
        start_time = time.time()

        # Collected per document so each saved record only holds its own lines.
        doc_output = []

        for round_num in range(5):
            # Format input.
            model_input = format_input(round_num, doc)

            # Generate response.
            output, full_prompt = generate(model_input, junk_labels)

            # Calculate input and output tokens to keep track of costs.
            input_len.append(num_tokens_from_string(full_prompt))
            output_len.append(num_tokens_from_string(output))

        # NOTE(review): only the output of the last round is used from here on.
        # Add generated junk labels to junk_labels list
        junk_labels = extract_junk_labels(junk_labels, output)

        for input_line, output_line in zip(chunk, output.splitlines()):
            doc_output.append({
                'line': input_line,
                # Remove the "Line X:" preamble
                'label': output_line.split(':')[1].strip().lower(),
                # whether the doc was split "manually".
                'split': was_split,
            })

        # Save output. ensure_ascii=False emits raw non-ASCII characters, so
        # the file encoding is pinned instead of relying on the locale.
        with open(f'../results/{save_file}', 'a', encoding='utf-8') as f:
            record = {'doc': doc, 'content': doc_output}
            f.write(json.dumps(record, ensure_ascii=False))
            f.write('\n')

        with open('../results/junk_labels.txt', 'w', encoding='utf-8') as f:
            for line in junk_labels:
                f.write(line)
                f.write('\n')

        # Keep track of time to get average time per document.
        end_time = time.time()
        time_taken.append(end_time - start_time)

        # Print cost every now and then while running to make sure we're not bleeding money.
        # Also print the junk labels and how many labels there are to keep an eye on them, too.
        if doc_index > 0 and doc_index % 100 == 0:
            with open('../results/number_of_labels.csv', 'a', encoding='utf-8') as f:
                f.write(f'{doc_index}, {len(junk_labels)}\n')
            calculate_cost(input_len, output_len)
            print(f'Junk labels: {junk_labels}')
            print(f'Number of labels: {len(junk_labels)}')
            print()

        if doc_index >= stop_at_index:
            break

    calculate_cost(input_len, output_len)