# TODO:
### NOW:
- ~~enforce output format for gemini~~
- llama, gpt, ~~claude~~
   - ~~send concurrent calls to all models at once~~
- ~~add evaluation if there is a golden set for individual model~~
- aggregation strategy
   - multiclass classification: 
      - ~~majority vote~~, add a tie-breaking strategy
      - ~~Bayesian approach with GT~~
      - provide X labels per class
      - what if labels are not int?
- evaluate
   - bug in GLAD evaluation: make sure the labels are ints
- repeat the same for multi-label classification and NER

### LATER:
- secret management
- ~~update readme~~
- add images



### nice things to do:
- ~~add tqdm to asyncio calls~~
- ~~proper logging~~

# Load the data

In [1]:
from datasets import load_dataset
# Seed used for the reproducible dataset shuffle below.
seed = 42

In [2]:
# Prompt sent to each annotator model. Fill it with
# .format(description=..., datapoint=..., labels=...). The model reads this
# text verbatim, so the instructions are kept free of typos.
gemini_prompt_template = """
<data_description>
{description}
</data_description>
-----------

<context>
{datapoint}
</context>
------------

<labels>
{labels}
</labels>
------------

INSTRUCTION:
- familiarize yourself with the data using data_description
- read the context carefully. this is the data point you need to label.
- take your time and label the datapoint with the most appropriate option using the provided labels.
- return the result as a single label from the <labels>. Don't provide explanations
"""

In [3]:
# Full Yelp polarity training split: https://huggingface.co/datasets/yelp_polarity
dataset = load_dataset("yelp_polarity", split="train")

# Keep only a small, reproducible random subset (20 rows) while developing.
shuffled_dataset = dataset.shuffle(seed=seed)
dataset_sample = shuffled_dataset.select(range(20))

# User-provided description of the dataset, injected into every prompt.
DESCRIPTION = """
This is a dataset for binary sentiment classification.
It contains highly polar yelp reviews.
Negative polarity is class 0, and positive class 1.
"""

# Labels the model may choose from; they match the classes described above.
LABEL_SET = [0, 1]

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [4]:
# One filled-in prompt per sampled review text.
prompt = []
for review_text in dataset_sample["text"]:
    prompt.append(
        gemini_prompt_template.format(
            description=DESCRIPTION,
            datapoint=review_text,
            labels=LABEL_SET,
        )
    )
print(len(prompt))

20


# Annotate

In [5]:
from utils import Annotate

from datetime import datetime

# Date stamp in YYYYMMDD form, used in the annotation output file name.
today = datetime.now()
now = today.strftime("%Y%m%d")

# Same seed value as the one set at the top of the notebook.
seed = 42

In [6]:
ann = Annotate()

# Models to query. Every entry carries a trailing comma. Without it,
# uncommenting the next line would silently concatenate adjacent string
# literals (e.g. "palm" "gemini" -> "palmgemini") instead of adding a model.
models = [
    "palm",
    # "gemini",
    # "claude",
]

In [8]:
await ann.__palm(prompt=prompt[0])

AttributeError: 'Annotate' object has no attribute '__palm'

In [9]:
r = await ann.classification(prompt, models=models)

Creating tasks: 40it [00:00, 63119.70it/s]            
Gathering palm results:   0%|          | 0/20 [00:00<?, ?it/s]2024-05-21 19:20:48,923/Annotate[ERROR]: Error in __palm: 'TextGenerationModel' object has no attribute 'predcit'
2024-05-21 19:20:49,101/Annotate[ERROR]: Error in __palm: 'TextGenerationModel' object has no attribute 'predcit'
2024-05-21 19:20:49,272/Annotate[ERROR]: Error in __palm: 'TextGenerationModel' object has no attribute 'predcit'
2024-05-21 19:20:49,433/Annotate[ERROR]: Error in __palm: 'TextGenerationModel' object has no attribute 'predcit'
2024-05-21 19:20:49,611/Annotate[ERROR]: Error in __palm: 'TextGenerationModel' object has no attribute 'predcit'
2024-05-21 19:20:49,834/Annotate[ERROR]: Error in __palm: 'TextGenerationModel' object has no attribute 'predcit'
2024-05-21 19:20:50,015/Annotate[ERROR]: Error in __palm: 'TextGenerationModel' object has no attribute 'predcit'
2024-05-21 19:20:50,174/Annotate[ERROR]: Error in __palm: 'TextGenerationModel' objec

In [None]:
# Show the raw answers of each model that was actually queried. Hard-coding
# 'gemini'/'claude' raises KeyError whenever they are commented out of `models`.
{model: r[model] for model in models}

In [None]:
import json
from pathlib import Path

# Save the raw per-model annotations, stamped with the date in `now`.
output_path = Path("./data/output") / f"annotation_output__{now}.json"
# Create the output directory if it is missing, instead of failing with
# FileNotFoundError on a fresh checkout.
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("w", encoding="utf-8") as json_file:
    json.dump(r, json_file, indent=4)

## Aggregate

In [None]:
import json
import random

# NOTE: this loads a fixed earlier run (20240515), not the file saved above.
with open("./data/output/annotation_output__20240515.json", "r", encoding="utf-8") as file:
    all_results = json.load(file)

# Add a random "fake" annotator as a noisy baseline for the aggregators.
# A dedicated RNG seeded with `seed` makes the baseline, and therefore the
# aggregation and evaluation results, reproducible across runs.
fake_rng = random.Random(seed)
all_results["fake"] = [fake_rng.randint(0, 1) for _ in range(len(all_results["gemini"]))]

# Annotator names, derived from the dict so that y_labels[i] always names
# all_results[i] below, whatever models the loaded file contains.
y_labels = list(all_results.keys())

# Aggregators/evaluators take a plain list of per-annotator label lists.
all_results = list(all_results.values())

In [None]:
from utils import Aggregate

# Aggregator applied to the per-annotator label lists in `all_results`.
agg = Aggregate()

In [None]:
# Majority-vote aggregation (presumably one label per datapoint). Tie-breaking
# is still an open TODO at the top of the notebook. Note that this calls a
# private helper of Aggregate directly.
majority_vote = agg._get_majority_vote(all_results)

In [None]:
# GLAD aggregation over all annotators (private helper of Aggregate).
# NOTE(review): the TODO list reports a GLAD bug and asks to make sure labels
# are ints. The "fake" annotator produces ints, but confirm that the loaded
# model outputs are ints too.
glad_result = agg._glad(all_results)

In [None]:
# Inspect which fields the GLAD result exposes (its 'labels' field is filtered next).
glad_result.keys()

In [None]:
# Entries of the GLAD label output whose second field equals 1. These are
# presumably (datapoint, label) pairs with label 1 = positive; confirm against
# Aggregate._glad.
list(filter(lambda entry: entry[1] == 1, glad_result['labels']))

## evaluate

In [None]:
from utils import Evaluate

# Evaluator for the annotation results.
# NOTE(review): the name `eval` shadows the Python builtin `eval` for the rest
# of the session. Consider renaming it (e.g. `evaluator`) here and in the
# evaluation cell below, together.
eval = Evaluate()

In [None]:
# Evaluate the annotators using the majority-vote strategy, without plots.
# `y_labels` names each per-annotator list in `all_results`.
eval.classification(
    all_results,
    strategy="majority",
    visualize=False,
    y_labels=y_labels,
)

# Dev

In [None]:
# import numpy as np
# from sklearn.datasets import fetch_20newsgroups

# from sklearn.model_selection import train_test_split


# dataset = fetch_20newsgroups(subset='train',
#                               remove=('headers', 'footers', 'quotes'),
#                               )
# # take a small sample for dev purposes - Stratified sampling to maintain class distribution
# # Convert the target names to a numpy array
# target_names = np.array(dataset.target_names)

# _, _, y_train, y_test = train_test_split(
#     dataset.filenames, 
#     dataset.target, 
#     train_size=1000,  # Get 1000 samples
#     stratify=dataset.target,  # Ensure class distribution is preserved
#     random_state=seed # For reproducibility
# )

# # NOTE(review): this commented-out snippet would not run if uncommented:
# #   - `["""hello"""]` after `subset='train',` is a stray paste and a syntax error;
# #   - `categories=target_names[y_train]` does not select the 1000 stratified
# #     samples. It passes one (repeated) category name per sampled post, and
# #     fetch_20newsgroups then loads every post of those categories.
# # Now load the actual data for the selected samples
# dataset_sample= fetch_20newsgroups(
#     subset='train',["""hello"""]
#     remove=('headers', 'footers', 'quotes'),
#     categories=target_names[y_train]  # Only load categories in the sample
# )


# # user provided data description
# DESCRIPTION = """
# The 20 newsgroups dataset comprises around 18000 newsgroups posts on 20 topics. Here are the topics:
# ['alt.atheism',
#  'comp.graphics',
#  'comp.os.ms-windows.misc',
#  'comp.sys.ibm.pc.hardware',
#  'comp.sys.mac.hardware',
#  'comp.windows.x',
#  'misc.forsale',
#  'rec.autos',
#  'rec.motorcycles',
#  'rec.sport.baseball',
#  'rec.sport.hockey',
#  'sci.crypt',
#  'sci.electronics',
#  'sci.med',
#  'sci.space',
#  'soc.religion.christian',
#  'talk.politics.guns',
#  'talk.politics.mideast',
#  'talk.politics.misc',
#  'talk.religion.misc']
# """

# LABEL_SET = ['alt.atheism',
#  'comp.graphics',
#  'comp.os.ms-windows.misc',
#  'comp.sys.ibm.pc.hardware',
#  'comp.sys.mac.hardware',
#  'comp.windows.x',
#  'misc.forsale',
#  'rec.autos',
#  'rec.motorcycles',
#  'rec.sport.baseball',
#  'rec.sport.hockey',
#  'sci.crypt',
#  'sci.electronics',
#  'sci.med',
#  'sci.space',
#  'soc.religion.christian',
#  'talk.politics.guns',
#  'talk.politics.mideast',
#  'talk.politics.misc',
#  'talk.religion.misc']

In [None]:
from huggingface_hub import login, logout
# Authenticate with the Hugging Face Hub, presumably so the Gemma checkpoint
# below can be downloaded. No token is passed, so this presumably prompts
# interactively (see the "secret management" TODO).
login()

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Gemma 2B "-it" checkpoint, with weights loaded in bfloat16
# (2 bytes per parameter instead of float32's 4).
gemma_checkpoint = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(gemma_checkpoint)
model = AutoModelForCausalLM.from_pretrained(gemma_checkpoint, torch_dtype=torch.bfloat16)


In [None]:

# input_text = "Write me a poem about Machine Learning."
# input_ids = tokenizer(input_text, return_tensors="pt")

# outputs = model.generate(**input_ids)
# print(tokenizer.decode(outputs[0]))


In [None]:
# Tokenize all prompts as one padded batch of PyTorch tensors. Despite its
# name, `input_ids` holds the whole tokenizer output (a dict-like encoding that
# is unpacked into generate()), not just the token ids.
input_text = prompt
input_ids = tokenizer(
    input_text,
    return_tensors="pt",
    padding=True,
)


In [None]:
outputs = model.generate(**input_ids)
# For decoder-only models, generate() returns the (padded) prompt followed by
# the continuation. Slice off the prompt so only the model's answer is shown,
# and drop special tokens (pad/bos/eos) like the Mistral cell does.
prompt_length = input_ids["input_ids"].shape[1]
print(tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True))

In [None]:
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto"
)
# `max_new_tokens` is a generation setting, not a model/config constructor
# argument. Store it on the generation config, which generate() reads by
# default on every call.
model.generation_config.max_new_tokens = 2048

In [None]:
from transformers import AutoTokenizer

# Left padding keeps each prompt flush against the tokens a decoder-only model
# generates after it when prompts of different lengths are batched.
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    padding_side="left",
)
color_prompt = ["A list of colors: red, blue"]
model_inputs = tokenizer(color_prompt, return_tensors="pt")

In [None]:
tokenizer.pad_token = tokenizer.eos_token  # Most LLMs don't have a pad token by default
# Put the batch on the model's device (the model was loaded with
# device_map="auto"), so generate() does not receive CPU tensors for a model
# placed on an accelerator.
model_inputs = tokenizer(
    prompt, return_tensors="pt", padding=True
).to(model.device)
generated_ids = model.generate(**model_inputs)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

In [None]:
import vertexai
from vertexai.language_models import TextGenerationModel
from config import PALM_CONFIG

import os

# GCP project passed to vertexai.init. It is read from the environment so the
# notebook is not tied to one personal project; the original hard-coded value
# remains the fallback.
project_id = os.environ.get("GOOGLE_CLOUD_PROJECT", "amir-genai-bb")
vertexai.init(project=project_id, location="us-central1")


model = TextGenerationModel.from_pretrained("text-bison-32k@002")
# Smoke-test a single prompt before looping over all of them.
response = model.predict(
    prompt[0],
    **PALM_CONFIG["generation_config"],
)
print(f"Response from Model: {response.text}")

In [None]:
# Query PaLM sequentially, one prompt at a time, printing each answer.
for single_prompt in prompt:
    response = model.predict(single_prompt, **PALM_CONFIG["generation_config"])
    print(f"Response from Model: {response.text}")