# GEF: GPT-3.5-turbo zero-shot classification

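This notebook classifies every sentence of an uploaded dataset with GPT-3.5-turbo using zero-shot prompting: the model is given only the class descriptions from your prompt file, with no labelled examples.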

## OpenAI Privacy Policy
This notebook uses OpenAI's API, meaning that your data will be sent to the OpenAI servers.

For concerns about how your data will be handled, please read through the Privacy Policy [here](https://openai.com/policies/api-data-usage-policies).

In [None]:
from causation.utils import openai_apikey_input

# Ask for the OpenAI API key through the project helper before any model call is made.
# NOTE(review): presumably this renders an input widget and stores the key for the OpenAI client — confirm in causation.utils.
openai_apikey_input()

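# Upload Prompt

Upload a `.toml` prompt file. Each top-level table is treated as one class and must contain an `instruction` key describing that class. An optional `PREFIX` table may supply an `instruction` that is placed before the class list.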

In [None]:
from causation.utils import fileuploader 

# Request a '.toml' upload. `uploaded` is read in the next cell via uploaded.get('data').
# NOTE(review): `finput` is presumably the upload widget, displayed here as the cell output — confirm in causation.utils.
finput, uploaded = fileuploader('.toml')
finput

In [None]:
from langchain.prompts import load_prompt
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts.chat import HumanMessagePromptTemplate, ChatPromptTemplate
from pydantic import BaseModel, Field
from langchain.prompts import PromptTemplate
import toml
from pathlib import Path

# Locate the prompt file registered by the upload cell (stored under the 'data' key).
path = uploaded.get('data', None)
assert path is not None, "Did you upload your .toml file?"

path = Path(path)
# Compare the extension case-insensitively so a file named e.g. PROMPT.TOML is accepted too.
if path.suffix.lower() != '.toml':
    raise ValueError("path is not a toml file.")
# Parsed prompt configuration as a dict; `toml` is already imported at the top of this cell.
data = toml.load(path)

# An optional PREFIX table carries free-text instructions that go before the class list.
# It is popped so that every remaining top-level key is treated as a class.
if "PREFIX" in data.keys():
    prefix = data.pop("PREFIX")
    prefix_instructions = prefix.get('instruction', '')
else:
    prefix_instructions = ''

classes = list(data.keys())
if not classes:
    raise ValueError("No classes found in the prompt file.")

# Render one <class> element per class, failing fast on malformed entries so that a
# missing description is never sent to the model as the literal text "None".
instructions = []
for clz in classes:
    entry = data[clz]
    if not isinstance(entry, dict) or entry.get('instruction') is None:
        raise ValueError(
            f"Class '{clz}' in the prompt file must be a table with an 'instruction' key."
        )
    instructions.append(f"<class>\n{clz}: {entry['instruction']}</class>")

instruction = prefix_instructions + "\n\n" + f"""
The following are {len(classes)} classes with a description of each. 
These are XML delimited with <class> tags in the format: <class> Class: Description </class>.
Please classify each 'query' as one of the {len(classes)} classes.\n\n""" + '\n'.join(instructions) + "\n\n"

# The two placeholders below are the only variables the rest of the notebook fills in.
instruction += "\n\n{format_instructions}\nQuery: {query}"
template = PromptTemplate.from_template(instruction)

# User-supplied text is part of the template, so any single curly braces in it are parsed
# as extra template variables. Detect that here rather than failing on the first query.
unexpected_variables = set(template.input_variables) - {'format_instructions', 'query'}
if unexpected_variables:
    raise ValueError(
        f"The prompt file introduced unexpected template variables: {sorted(unexpected_variables)}. "
        "Curly braces in class names or instructions must be doubled, i.e. written as '{{' and '}}'."
    )

# Schema of the structured reply requested from the model; the PydanticOutputParser
# created later in this cell is built from this class.
# NOTE(review): documented with comments instead of a docstring on purpose — a class
# docstring may be copied into the schema that ends up in the prompt; confirm before adding one.
class ClassificationOutput(BaseModel):
    # The class label returned by the model, as a plain string.
    answer: str = Field(description="the classification")

# Parser that turns the model's reply into a ClassificationOutput instance.
parser = PydanticOutputParser(pydantic_object=ClassificationOutput)

# Pre-fill the format instructions so that only `query` remains to be supplied per sentence.
format_instructions = parser.get_format_instructions()
template = template.partial(format_instructions=format_instructions)

# Wrap the prompt as a single human message inside a chat prompt.
human = HumanMessagePromptTemplate(prompt=template)
chat = ChatPromptTemplate.from_messages(messages=[human])

print("Classes found: " + ", ".join(classes))
"Prompt set up complete. Please continue."

In [None]:
from langchain.chat_models import ChatOpenAI

# Model settings gathered in one place; the same values are written to the run's
# config file when the outputs are saved.
llm_settings = dict(
    model_name='gpt-3.5-turbo',
    n=1,                          # one completion per query
    temperature=0.0,
    model_kwargs={'top_p': 0.8},  # passed through model_kwargs rather than as a direct argument
)
llm = ChatOpenAI(**llm_settings)

"LLM set up complete. Please continue."

# Upload Dataset

Upload an `.xlsx` file. The sentences to classify are read from a column named `sentence`.

In [None]:
from causation.utils import fileuploader

# Request an '.xlsx' upload. `dataset` is read in the next cell via dataset.get('data').
# NOTE(review): `finput` is presumably the upload widget, displayed here as the cell output — confirm in causation.utils.
finput, dataset = fileuploader('.xlsx')
finput

In [None]:
import pandas as pd
# Load the uploaded spreadsheet into a DataFrame.
assert dataset.get('data'), "Did you upload your dataset?"
df = pd.read_excel(dataset.get('data'))
# The classification loop iterates over `df.sentence`; verify the column now so that a
# wrong spreadsheet is reported here with a readable message instead of failing later.
assert 'sentence' in df.columns, (
    "The dataset must contain a 'sentence' column. "
    f"Columns found: {', '.join(map(str, df.columns))}"
)
df.head(1)

In [None]:
from tqdm.auto import tqdm
import pandas as pd

# Classify every sentence with the LLM, collecting parsed answers and unparseable replies.
# A partial-results spreadsheet is written every `checkpointing` queries; a falsy value
# (0 or None) disables checkpoints.
checkpointing = 200
corrupted = []        # sentences whose model reply could not be parsed
classifications = []  # (sentence, classification) pairs
for i, sent in tqdm(enumerate(df.sentence), total=len(df)):
    messages = chat.format_prompt(query=sent).to_messages()
    results = llm.generate([messages])
    try:
        # Only the first generation of the single prompt is used.
        answer = parser.parse(results.generations[0][0].text).answer
        classifications.append((sent, answer))
    except Exception:
        # `Exception` instead of a bare `except:` so that KeyboardInterrupt / SystemExit
        # still stop the run rather than being recorded as corrupted output.
        print("Got corrupted llm output. These are added to an excel sheet so you can rerun these later.")
        corrupted.append(sent)

    if checkpointing and (i + 1) % checkpointing == 0:
        # NOTE(review): the 'cotsc' prefix looks copied from another notebook; the final
        # outputs use 'zeroshot'. Kept unchanged in case something relies on this name.
        path = f'./cotsc-outputs-checkpoint-{i + 1}.xlsx'
        pd.DataFrame(classifications, columns=['sentence', 'classification']).to_excel(path)
        print(f"Checkpointed at {i + 1} queries processed. Checkpoint file: {path}.")

f"Passed {len(classifications)}/{len(df)}. Please continue."

In [None]:
# Turn the lists collected by the classification loop into DataFrames for display and export.
results_df = pd.DataFrame(data=classifications, columns=['sentence', 'classification'])
corrupted_df = pd.DataFrame(data=corrupted, columns=['sentence'])
results_df

In [None]:
from pathlib import Path
from datetime import datetime
import srsly

# Create a fresh, timestamped output directory; mkdir fails if it already exists,
# so an earlier run's results are never overwritten.
now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir = Path(f"./.zeroshot-corpus-output-{now}")
output_dir.mkdir(exist_ok=False)

# Classification results.
results_df.to_excel(output_dir / 'zeroshot-output.xlsx')

# Model settings used for this run, stored next to the results.
config = dict(
    model=llm.model_name,
    temperature=llm.temperature,
    top_p=llm.model_kwargs.get('top_p', 'N/A'),
    n_completions=llm.n,
)
path = output_dir / 'zeroshot-config.json'
srsly.write_json(path, config)

# Sentences with unparseable replies are exported only when there are any.
if not corrupted_df.empty:
    corrupted_df.to_excel(output_dir / 'zeroshot-corrupted.xlsx')

In [None]:
# Files to bundle: the uploaded prompt, the uploaded dataset, and everything in the output directory.
file_names = [uploaded['data'], dataset['data'], *output_dir.glob("*")]
file_names

In [None]:
import zipfile
import os
from datetime import datetime
from pathlib import Path
import panel as pn

# Bundle all collected files into one zip, named with the timestamp of the output directory.
zfname = Path(now + '-zeroshot-corpus.zip')
with zipfile.ZipFile(zfname, mode='w') as archive:
    for member in file_names:
        # Store each file under its base name only, so the archive has a flat layout.
        archive.write(member, arcname=os.path.basename(member))
print(f"Saved as {zfname}.\nClick below to download.")

# Download widget for the zip, displayed as the cell output.
pn.widgets.FileDownload(file=str(zfname), filename=zfname.name)