# TODO:
### NOW:
- ~~enforce output format for gemini~~
- llama, gpt, ~~claude~~
   - ~~send concurrent calls to all models at once~~
- ~~add evaluation if there is a golden set for individual model~~
- aggregation strategy
   - multiclass classification: 
      - ~~majority vote~~, add tie breaking strategy
      - ~~Bayesian approach with GT~~
      - provide X labels per class
      - what if labels are not int?
- evaluate
   - bug in glad evaluation. make sure the labels are int
- repeat the same thing for multi-label/NER

### LATER:
- secret management
- ~~update readme~~
- add images



### nice things to do:
- add tqdm to asyncio calls
- ~~proper logging~~

# Annotate

In [1]:
from utils import Annotate
from datasets import load_dataset

# Fixed random seed, used when shuffling the dataset so the sample is reproducible.
seed = 42

In [2]:
# Prompt template filled in with str.format(); it has three placeholders:
# {description}, {datapoint} and {labels}.
# Despite the name, the prompts built from it are passed to every model in
# VALID_MODELS, not only to Gemini.
# Spelling mistakes in the instruction text ("familirize", "dadatapoint") are
# fixed here because this text is sent verbatim to the models.
gemini_prompt_template = """
<data_description>
{description}
</data_description>
-----------

<context>
{datapoint}
</context>
------------

<labels>
{labels}
</labels>
------------

INSTRUCTION:
- familiarize yourself with the data using data_description
- read the context carefully. this is the data point you need to label.
- take your time and label the datapoint with the most appropriate option using the provided labels.
- return the result as a single label from the <labels>. Don't provide explanations
"""

In [3]:
# Training split of the Yelp polarity dataset:
# https://huggingface.co/datasets/yelp_polarity
dataset = load_dataset("yelp_polarity", split="train")

# For development purposes, shuffle with the fixed seed and keep only the
# first 100 rows of the shuffled data.
dataset_sample = (
    dataset
    .shuffle(seed=seed)
    .select(range(100))
)

# User-provided data description; it is substituted into the {description}
# placeholder of the prompt template, so the models read it verbatim.
# The garbled word "rgenerate_funceviews" (a stray paste) is restored to "reviews".
DESCRIPTION = """
This is a dataset for binary sentiment classification.
It contains highly polar yelp reviews.
Negative polarity is class 0, and positive class 1.
"""

# Allowed labels, matching the class ids named in DESCRIPTION.
LABEL_SET = [0, 1]

In [4]:
# Render one prompt per review text, using only the first 20 texts of the sample.
prompt = [
    gemini_prompt_template.format(
        description=DESCRIPTION,
        datapoint=review_text,
        labels=LABEL_SET,
    )
    for review_text in dataset_sample["text"][:20]
]
print(len(prompt))

20


In [5]:
# Create the Annotate helper imported from utils; its classification()
# method is awaited in a later cell.
ann = Annotate()

# Model names handed to ann.classification() through its `models=` argument.
VALID_MODELS = ["gemini", "claude"]

In [6]:
# Top-level await (notebook cell): run classification on the prompts for the
# models in VALID_MODELS. The result `r` is indexed by model name below.
r = await ann.classification(prompt, models=VALID_MODELS)

In [7]:
# Display the per-model predictions side by side for a quick visual comparison.
r['gemini'], r['claude']

([1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0],
 [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0])

In [None]:
import json

# Persist the per-model predictions returned by ann.classification().
# This cell previously dumped `d`, which is never assigned in any preceding
# cell and would raise NameError; the classification results live in `r`.
with open("./data/output/20_sample.json", "w", encoding="utf-8") as json_file:
    json.dump(r, json_file, indent=4)

In [None]:
# all_results = [d["gemini"], d["claude"]]


## Aggregate

In [None]:
# Hard-coded annotator outputs for trying out aggregation without calling the models.
# Presumably one row of all_results per name in y_labels, in the same order — confirm.
y_labels = [
    "gemini",
    "claude",
    "fake",
]
all_results = [
    [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0],
    [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0],
]

In [None]:
from utils import Aggregate

In [None]:
# Create the Aggregate helper imported from utils.
agg = Aggregate()

In [None]:
# NOTE(review): calls an underscore-prefixed (private) method directly.
# Presumably a majority vote across the annotator rows — confirm in utils.Aggregate.
agg._get_majority_vote(all_results)

In [None]:
# NOTE(review): calls an underscore-prefixed (private) method directly.
# Presumably the "glad" aggregation strategy also selectable in
# Evaluate.classification(strategy="glad") — confirm in utils.Aggregate.
agg._glad(all_results)

## Evaluate

In [None]:
# Hard-coded annotator outputs used as input for the evaluation cell.
# Presumably one row of all_results per name in y_labels, in the same order — confirm.
y_labels = [
    "gemini",
    "claude",
    "fake",
]
all_results = [
    [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0],
    [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0],
]

In [None]:
from utils import Evaluate

# Create the Evaluate helper imported from utils.
# NOTE(review): the name `eval` shadows Python's builtin eval() for the rest of
# the session. Consider renaming (e.g. `evaluator`) together with every cell
# that uses it.
eval = Evaluate()

In [None]:
# Evaluate the annotator rows with the "glad" strategy and visualization enabled.
# y_labels is passed alongside, presumably as display names for the rows —
# confirm in utils.Evaluate.
eval.classification(all_results, strategy="glad", visualize=True, y_labels=y_labels)

# Dev

In [None]:
# Hard-coded annotator outputs for scratch/dev experiments.
# Presumably one row of all_results per name in y_labels, in the same order — confirm.
y_labels = [
    "gemini",
    "claude",
    "fake",
]
all_results = [
    [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0],
    [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0],
]


In [None]:
# Keep only the annotator names that are not one of the two real models.
[name for name in y_labels if name not in ("claude", "gemini")]

In [None]:
# import numpy as np
# from sklearn.datasets import fetch_20newsgroups
# from sklearn.model_selection import train_test_split


# dataset = fetch_20newsgroups(subset='train',
#                               remove=('headers', 'footers', 'quotes'),
#                               )
# # take a small sample for dev purposes - Stratified sampling to maintain class distribution
# # Convert the target names to a numpy array
# target_names = np.array(dataset.target_names)

# _, _, y_train, y_test = train_test_split(
#     dataset.filenames, 
#     dataset.target, 
#     train_size=1000,  # Get 1000 samples
#     stratify=dataset.target,  # Ensure class distribution is preserved
#     random_state=seed # For reproducibility
# )

# # Now load the actual data for the selected samples
# dataset_sample= fetch_20newsgroups(
#     subset='train',
#     remove=('headers', 'footers', 'quotes'),
#     categories=target_names[y_train]  # Only load categories in the sample
# )


# # user provided data description
# DESCRIPTION = """
# The 20 newsgroups dataset comprises around 18000 newsgroups posts on 20 topics. Here are the topics:
# ['alt.atheism',
#  'comp.graphics',
#  'comp.os.ms-windows.misc',
#  'comp.sys.ibm.pc.hardware',
#  'comp.sys.mac.hardware',
#  'comp.windows.x',
#  'misc.forsale',
#  'rec.autos',
#  'rec.motorcycles',
#  'rec.sport.baseball',
#  'rec.sport.hockey',
#  'sci.crypt',
#  'sci.electronics',
#  'sci.med',
#  'sci.space',
#  'soc.religion.christian',
#  'talk.politics.guns',
#  'talk.politics.mideast',
#  'talk.politics.misc',
#  'talk.religion.misc']
# """

# LABEL_SET = ['alt.atheism',
#  'comp.graphics',
#  'comp.os.ms-windows.misc',
#  'comp.sys.ibm.pc.hardware',
#  'comp.sys.mac.hardware',
#  'comp.windows.x',
#  'misc.forsale',
#  'rec.autos',
#  'rec.motorcycles',
#  'rec.sport.baseball',
#  'rec.sport.hockey',
#  'sci.crypt',
#  'sci.electronics',
#  'sci.med',
#  'sci.space',
#  'soc.religion.christian',
#  'talk.politics.guns',
#  'talk.politics.mideast',
#  'talk.politics.misc',
#  'talk.religion.misc']