# TODO:
### NOW:
- ~~enforce output format for gemini~~
- llama, gpt, ~~claude~~
   - ~~send concurrent calls to all models at once~~
- ~~add evaluation if there is a golden set for individual model~~
- aggregation strategy
   - multiclass classification: 
      - ~~majority vote~~, add tie-breaking strategy
      - ~~Bayesian approach with GT~~
      - provide X labels per class
      - what if labels are not int?
- evaluate
   - bug in GLAD evaluation: make sure the labels are ints
- repeat the same for multi-label classification / NER

### LATER:
- secret management
- ~~update readme~~
- add images



### nice things to do:
- ~~add tqdm to asyncio calls~~
- ~~proper logging~~

# Load the data

In [None]:
from datasets import load_dataset
# Fixed RNG seed so the dataset shuffle/sample below is reproducible.
seed = 42

In [None]:
# Prompt template sent to each annotator model. Placeholders:
#   {description} - user-provided description of the dataset
#   {datapoint}   - the single example to label
#   {labels}      - the allowed label set (rendered with str())
gemini_prompt_template = """
<data_description>
{description}
</data_description>
-----------

<context>
{datapoint}
</context>
------------

<labels>
{labels}
</labels>
------------

INSTRUCTION:
- familiarize yourself with the data using data_description
- read the context carefully. this is the data point you need to label.
- take your time and label the datapoint with the most appropriate option using the provided labels.
- return the result as a single label from the <labels>. Don't provide explanations
"""

In [1]:
# Load the Yelp polarity training split: https://huggingface.co/datasets/yelp_polarity
dataset = load_dataset("yelp_polarity", split="train")

# Take a small, seeded random sample for development purposes.
dataset_sample = dataset.shuffle(seed=seed).select(range(20))

# User-provided data description; it is injected verbatim into every
# annotation prompt, so typos here reach the models.
DESCRIPTION = """
This is a dataset for binary sentiment classification.
It contains highly polar yelp reviews.
Negative polarity is class 0, and positive class 1.
"""

# Allowed labels: 0 = negative, 1 = positive (as stated in DESCRIPTION).
LABEL_SET = [0, 1]

NameError: name 'load_dataset' is not defined

In [None]:
# Render one annotation prompt per sampled review text.
sample_texts = dataset_sample["text"]
prompt = [
    gemini_prompt_template.format(
        description=DESCRIPTION,
        datapoint=text,
        labels=LABEL_SET,
    )
    for text in sample_texts
]
print(len(prompt))

# Annotate

In [None]:
from utils import Annotate

from datetime import datetime

# Date stamp (YYYYMMDD) used to tag the annotation output file written below.
now = f"{datetime.now():%Y%m%d}"

# Same seed as set earlier in the notebook (presumably repeated so this cell
# can run standalone).
seed = 42

In [None]:
ann = Annotate()

# Model names passed to the annotator; only "palm" is currently enabled.
models = [
    "palm",
    # "gemini",
    # "claude",
]

In [None]:
await ann.__palm(prompt=prompt[0])

In [None]:
# Annotate every prompt with each model listed in `models`.
# NOTE(review): `r` appears to be a mapping keyed by model name (it is indexed
# by model name and dumped to JSON further down) — confirm against utils.Annotate.
r = await ann.classification(prompt, models=models)

In [None]:
# Inspect the raw outputs of the models that were actually queried.
# Indexing fixed names such as r['gemini'] / r['claude'] raises KeyError
# whenever those models are commented out of `models` (currently only "palm").
{model_name: r[model_name] for model_name in models}

In [None]:
import json
from pathlib import Path

# Persist the raw annotations, tagged with today's date stamp (`now`).
output_path = Path("./data/output") / f"annotation_output__{now}.json"
# Create the output directory if needed; open() would otherwise raise
# FileNotFoundError when ./data/output does not exist yet.
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("w", encoding="utf-8") as json_file:
    json.dump(r, json_file, indent=4)

## Aggregate

In [None]:
import json
import random

# Load a previously saved annotation run (the file name pins a specific run date).
with open('./data/output/annotation_output__20240515.json', 'r', encoding="utf-8") as file:
    all_results = json.load(file)

if not all_results:
    raise ValueError("annotation file contains no annotator results")

# Add a synthetic "fake" annotator that picks 0 or 1 uniformly at random for
# every item. A seeded generator keeps reruns reproducible.
rng = random.Random(seed)
n_items = len(next(iter(all_results.values())))
all_results['fake'] = [rng.randint(0, 1) for _ in range(n_items)]

# Annotator names, derived from the data so they always match the order of the
# per-annotator label lists built below (dicts preserve insertion order).
y_labels = list(all_results)

# One list of labels per annotator, in the same order as y_labels.
all_results = list(all_results.values())

In [None]:
from utils import Aggregate

# Aggregator from utils; used below for majority vote and GLAD aggregation.
agg = Aggregate()

In [None]:
# Majority vote over the per-annotator label lists.
# NOTE(review): calls the private `_get_majority_vote` helper; per the TODO list,
# a tie-breaking strategy is still to be added.
majority_vote = agg._get_majority_vote(all_results)

In [None]:
# Aggregate with the GLAD strategy via the private `_glad` helper.
# NOTE(review): the TODO list mentions a GLAD bug unless labels are ints —
# confirm all_results holds ints before trusting this output.
glad_result = agg._glad(all_results)

In [None]:
# Inspect which fields the GLAD result exposes (its 'labels' entry is used below).
glad_result.keys()

In [None]:
# Keep the entries of glad_result['labels'] whose second field equals 1.
# NOTE(review): the layout of each entry is defined in utils.Aggregate._glad —
# confirm what x[1] holds (label vs. score) before relying on this filter.
[x for x in glad_result['labels'] if x[1] == 1]

## Evaluate

In [None]:
from utils import Evaluate

# Evaluator from utils.
# NOTE(review): this binding shadows the builtin `eval` for the rest of the
# notebook; consider renaming it (e.g. `evaluator`) together with its uses below.
eval = Evaluate()

In [None]:
# Evaluate the annotations with the "majority" strategy (no plots).
# y_labels names the annotator lists in all_results, in the same order.
eval.classification(all_results, strategy="majority", visualize=False, y_labels=y_labels)

# Dev

In [None]:
# import numpy as np
# from sklearn.datasets import fetch_20newsgroups

# from sklearn.model_selection import train_test_split


# dataset = fetch_20newsgroups(subset='train',
#                               remove=('headers', 'footers', 'quotes'),
#                               )
# # take a small sample for dev purposes - Stratified sampling to maintain class distribution
# # Convert the target names to a numpy array
# target_names = np.array(dataset.target_names)

# _, _, y_train, y_test = train_test_split(
#     dataset.filenames, 
#     dataset.target, 
#     train_size=1000,  # Get 1000 samples
#     stratify=dataset.target,  # Ensure class distribution is preserved
#     random_state=seed # For reproducibility
# )

# # Now load the actual data for the selected samples
# dataset_sample= fetch_20newsgroups(
#     subset='train',["""hello"""]
#     remove=('headers', 'footers', 'quotes'),
#     categories=target_names[y_train]  # Only load categories in the sample
# )


# # user provided data description
# DESCRIPTION = """
# The 20 newsgroups dataset comprises around 18000 newsgroups posts on 20 topics. Here are the topics:
# ['alt.atheism',
#  'comp.graphics',
#  'comp.os.ms-windows.misc',
#  'comp.sys.ibm.pc.hardware',
#  'comp.sys.mac.hardware',
#  'comp.windows.x',
#  'misc.forsale',
#  'rec.autos',
#  'rec.motorcycles',
#  'rec.sport.baseball',
#  'rec.sport.hockey',
#  'sci.crypt',
#  'sci.electronics',
#  'sci.med',
#  'sci.space',
#  'soc.religion.christian',
#  'talk.politics.guns',
#  'talk.politics.mideast',
#  'talk.politics.misc',
#  'talk.religion.misc']
# """

# LABEL_SET = ['alt.atheism',
#  'comp.graphics',
#  'comp.os.ms-windows.misc',
#  'comp.sys.ibm.pc.hardware',
#  'comp.sys.mac.hardware',
#  'comp.windows.x',
#  'misc.forsale',
#  'rec.autos',
#  'rec.motorcycles',
#  'rec.sport.baseball',
#  'rec.sport.hockey',
#  'sci.crypt',
#  'sci.electronics',
#  'sci.med',
#  'sci.space',
#  'soc.religion.christian',
#  'talk.politics.guns',
#  'talk.politics.mideast',
#  'talk.politics.misc',
#  'talk.religion.misc']

In [None]:
from transformers import AutoTokenizer, GemmaForCausalLM, pipeline


In [None]:
# Load the Gemma-2B checkpoint and its matching tokenizer from the same id.
GEMMA_CHECKPOINT = "google/gemma-2b"
model = GemmaForCausalLM.from_pretrained(GEMMA_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(GEMMA_CHECKPOINT)

In [None]:
# Text-generation pipeline over the same model/tokenizer, capped at 1024 new tokens.
# NOTE(review): `gen` is not used in the cells below.
gen = pipeline('text-generation', model=model, tokenizer=tokenizer, max_new_tokens=1024)


In [None]:
# Tokenize the first annotation prompt into PyTorch tensors.
inputs = tokenizer(prompt[0], return_tensors="pt")

In [None]:
# Single forward pass over the prompt (no generation).
# NOTE(review): runs with autograd enabled and the result is not used below;
# consider wrapping in torch.no_grad() or removing.
outputs = model(**inputs)

In [None]:
# Generate a continuation for the tokenized prompt.
# NOTE(review): no max_new_tokens is passed here (unlike `gen`, which uses 1024),
# so the model's default generation length applies — confirm it is long enough.
generated_ids = model.generate(**inputs)


In [None]:
# Decode the generated ids back to text, dropping special tokens.
# NOTE(review): for a causal LM the generated ids typically include the prompt
# tokens as well — strip the prompt before parsing a label.
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)