# TODO:
### NOW:
- ~~enforce output format for gemini~~
- llama, gpt, ~~claude~~
   - send concurrent calls to all models at once
- ~~add evaluation if there is a golden set for individual model~~
- aggregation strategy
   - multiclass classification: 
      - ~~majority vote~~, add tie-breaking strategy
      - ~~Bayesian approach with GT~~
      - provide X labels per class
- repeat the same for multi-label classification and NER

### LATER:
- secret management
- ~~update readme~~
- add images



### nice things to do:
- add tqdm to asyncio calls
- ~~proper logging~~

# Annotate

In [None]:
from utils import Annotate
# Hugging Face `datasets` exposes `load_dataset` (called in the dataset cell); it has no `load_set`.
from datasets import load_dataset

seed = 42  # fixed seed so the shuffled development sample is reproducible

In [None]:
# Prompt sent to each model for one datapoint; fill it with
# str.format(description=..., datapoint=..., labels=...).
gemini_prompt_template = """
<data_description>
{description}
</data_description>
-----------

<context>
{datapoint}
</context>
------------

<labels>
{labels}
</labels>
------------

INSTRUCTION:
- familiarize yourself with the data using data_description
- read the context carefully. this is the data point you need to label.
- take your time and label the datapoint with the most appropriate option using the provided labels.
- return the result as a single label from the <labels>. Don't provide explanations
"""

In [None]:
# User-provided description of the data; it is injected into every prompt.
DESCRIPTION = """
This is a dataset for binary sentiment classification.
It contains highly polar yelp reviews.
Negative polarity is class 0, and positive class 1.
"""

LABEL_SET = [0, 1]

# Source: https://huggingface.co/datasets/yelp_polarity
dataset = load_dataset("yelp_polarity", split="train")

# Small, reproducible sample for development.
dataset_sample = dataset.shuffle(seed=seed).select(range(100))

In [None]:
# Build one filled-in prompt for each of the first 20 sampled reviews.
prompt = []
for review_text in dataset_sample["text"][:20]:
    prompt.append(
        gemini_prompt_template.format(
            description=DESCRIPTION,
            datapoint=review_text,
            labels=LABEL_SET,
        )
    )
print(len(prompt))

In [None]:
# Models whose predictions are collected, keyed by these names.
VALID_MODELS = ["gemini", "claude"]
ann = Annotate()

In [None]:
d = {}
for m in VALID_MODELS:
    d[m] = await ann.classification(prompt, model=m)

In [None]:
import json
from pathlib import Path

# Persist the per-model predictions in `d` to disk as JSON.
sample_path = Path("./data/output/20_sample.json")
sample_path.parent.mkdir(parents=True, exist_ok=True)  # open(..., "w") fails if the folder is missing
with open(sample_path, "w", encoding="utf-8") as json_file:
    json.dump(d, json_file, indent=4)

In [None]:
# all_results = [d["gemini"], d["claude"]]


## Aggregate

In [None]:
# Toy votes for exercising the aggregators: one row per annotator (ordered as in
# y_labels), one column per datapoint.
y_labels = ["gemini", "claude", "fake"]
all_results = [
    [int(bit) for bit in votes]
    for votes in (
        "10011110000110000100",  # gemini
        "10011110000110000100",  # claude (identical to gemini)
        "00110011100001110110",  # fake
    )
]

In [None]:
from utils import Aggregate

In [None]:
# Aggregator from the project's utils module (its API is not visible in this notebook).
agg = Aggregate()

In [None]:
# Presumably a per-datapoint majority vote over the annotator rows (implemented in utils).
# NOTE(review): `_get_majority_vote` is a private (underscore) method; prefer a public
# entry point of Aggregate if one exists.
agg._get_majority_vote(all_results)

In [None]:
# Presumably GLAD aggregation of the annotator rows (implemented in utils).
# NOTE(review): `_glad` is a private (underscore) method; prefer a public entry point
# of Aggregate if one exists.
agg._glad(all_results)

## evaluate

In [None]:
# Toy votes: row r holds annotator y_labels[r]'s class for each of the 20 datapoints.
y_labels = ["gemini", "claude", "fake"]
all_results = [
    list(map(int, row))
    for row in ("10011110000110000100", "10011110000110000100", "00110011100001110110")
]

In [None]:
from utils import Evaluate

# NOTE(review): binding the name `eval` shadows Python's built-in eval() for the rest of
# the notebook session; consider a name like `evaluator` (and update the call that uses it).
eval = Evaluate()

In [None]:
# Evaluate the annotator rows with the "majority" strategy; y_labels presumably names the rows.
# NOTE(review): the exact meaning of strategy/visualize is defined in utils.Evaluate (not visible here).
eval.classification(all_results, strategy="majority", visualize=True, y_labels=y_labels)

# Dev

In [None]:
# Toy votes used while developing GLAD: one row per annotator, one column per datapoint.
y_labels = ["gemini", "claude", "fake"]
all_results = [
    # gemini
    [1, 0, 0, 1, 1, 1, 1, 0, 0, 0,
     0, 1, 1, 0, 0, 0, 0, 1, 0, 0],
    # claude
    [1, 0, 0, 1, 1, 1, 1, 0, 0, 0,
     0, 1, 1, 0, 0, 0, 0, 1, 0, 0],
    # fake
    [0, 0, 1, 1, 0, 0, 1, 1, 1, 0,
     0, 0, 0, 1, 1, 1, 0, 1, 1, 0],
]


# list(map(list, zip(*all_results))) 
# # task, labeler, label
# result = [(j, i, sublist[i]) 
#           for j, sublist in enumerate(all_results) 
#           for i in range(len(sublist))]

In [None]:
import logging
import numpy as np
import scipy as sp
# EM stops once the relative change in Q falls to or below this value.
THRESHOLD = 1.0e-5

# Module-level GLAD logger; assigned by init_logger().
logger = None

def Dataset(**kwargs):
    """Return the given keyword arguments as a plain dict (lightweight record for GLAD state)."""
    return dict(kwargs)

def load_data(arr, numClasses=None):
    """Build the GLAD working dict from a labeler-major matrix of votes.

    Args:
        arr: one row per labeler, one column per task; ``arr[i][j]`` is labeler i's
            vote for task j as a class index in ``0 .. numClasses-1``, or ``-1`` when
            the labeler gave no vote for that task. All rows must have equal length.
        numClasses: number of classes; defaults to ``max(vote) + 1`` (at least 2).

    Returns:
        dict with counts, uniform priors and ``labels`` of shape (numTasks, numLabelers)
        in GLAD's encoding: 0 = no response, ``k + 1`` = class k. The arrays ``probZ``,
        ``alpha`` and ``beta`` are allocated but left uninitialised.

    Raises:
        ValueError: if ``arr`` is empty or ragged, contains a vote below -1, or (when
            ``numClasses`` is given) a vote >= ``numClasses``.
    """
    votes = np.array(arr, dtype=int)  # raises ValueError for ragged rows
    if votes.ndim != 2 or votes.size == 0:
        raise ValueError("arr must be a non-empty (numLabelers, numTasks) matrix of votes")
    if votes.min() < -1:
        raise ValueError("votes must be class indices >= 0, or -1 for a missing vote")
    if numClasses is None:
        # Use every labeler's votes, not just the first row, and never fewer than 2
        # classes (the likelihood divides by numClasses - 1).
        numClasses = max(int(votes.max()) + 1, 2)
    elif votes.max() >= numClasses:
        raise ValueError("a vote lies outside 0 .. numClasses-1")

    priorZ = np.full(numClasses, 1 / numClasses)
    assert np.isclose(priorZ.sum(), 1), 'Incorrect priorZ given'

    numLabelers, numTasks = votes.shape
    data = {
        "numLabels": int((votes >= 0).sum()),
        "numLabelers": numLabelers,
        "numTasks": numTasks,
        "numClasses": numClasses,
        "priorZ": priorZ,
        # Transpose to (numTasks, numLabelers) and shift by one: EStep/computeQ/gradientQ
        # treat 0 as "no response" and k + 1 as class k.
        "labels": votes.T + 1,
        "priorAlpha": np.ones(numLabelers),
        "priorBeta": np.ones(numTasks),
        "probZ": np.empty((numTasks, numClasses)),
        "beta": np.empty(numTasks),
        "alpha": np.empty(numLabelers),
    }
    return data


def init_logger():
    """Configure root logging format and bind the module-level ``GLAD`` logger at DEBUG."""
    global logger
    fmt = '%(asctime)s/%(name)s[%(levelname)s]: %(message)s'
    logging.basicConfig(format=fmt)
    logger = logging.getLogger('GLAD')
    logger.setLevel(logging.DEBUG)


def sigmoid(x):
    """Logistic function 1 / (1 + e^(-x)), applied elementwise."""
    denom = 1.0 + np.exp(-x)
    return 1.0 / denom

def logsigmoid(x):
    """Numerically stable log(sigmoid(x)) = -log(1 + e^(-x)), elementwise."""
    softplus_neg = np.logaddexp(0, -x)
    return -softplus_neg


def EM(data, max_iter=1000, verbose=None):
    """Fit GLAD parameters in place by expectation-maximisation.

    Initialises alpha/beta from their priors and probZ from priorZ, then alternates
    EStep / MStep until the relative change in Q is at most THRESHOLD or ``max_iter``
    M-steps have run.

    Args:
        data: GLAD working dict (as built by load_data); updated in place.
        max_iter: upper bound on M-steps, so a non-converging run still terminates.
        verbose: log each iteration; ``None`` falls back to the module-level ``verbose``
            flag (set by ``main``), or False when that flag was never set.

    Returns:
        Number of M-steps performed.
    """
    if verbose is None:
        verbose = globals().get("verbose", False)
    # Same logger object that init_logger configures; works even if it was never called.
    log = logging.getLogger('GLAD')

    def relativeChange(new, old):
        # Relative change in Q; fall back to the absolute change when old == 0.
        if old == 0:
            return abs(new - old)
        return abs((new - old) / old)

    data["alpha"] = data["priorAlpha"].copy()
    data["beta"] = data["priorBeta"].copy()
    data["probZ"][:] = data["priorZ"][:]

    EStep(data)
    lastQ = computeQ(data)
    MStep(data)
    Q = computeQ(data)
    counter = 1
    while counter < max_iter and relativeChange(Q, lastQ) > THRESHOLD:
        if verbose:
            log.info('EM: iter={}'.format(counter))
        lastQ = Q
        EStep(data)
        MStep(data)
        Q = computeQ(data)
        counter += 1
    return counter

def calcLogProbL(item, *args):
    """Log-likelihood of one task's observed votes under a hypothesised true class.

    Args:
        item: row ``[j, ab_j0, ab_j1, ...]`` — task index j followed by
            alpha_i * exp(beta_j) for every labeler i.
        args[0]: boolean mask indexed [task, labeler], True where the vote equals the
            hypothesised class.
        args[1]: boolean mask indexed [task, labeler], True where no vote was given.
        args[-1]: GLAD data dict; only ``numClasses`` is read.

    Returns:
        sum over matching votes of log sigmoid(ab) plus, over the other answered votes,
        log(1 - sigmoid(ab)) - log(numClasses - 1).
    """
    data = args[-1]

    j = int(item[0])
    delta = args[0][j]
    noResp = args[1][j]
    oneMinusDelta = (~delta) & (~noResp)

    exponents = item[1:]

    # log sigmoid(x) = -log(1 + e^-x) and log(1 - sigmoid(x)) = -log(1 + e^x), computed stably.
    correct = (-np.logaddexp(0, -exponents[delta])).sum()
    wrong = (-np.logaddexp(0, exponents[oneMinusDelta]) - np.log(float(data["numClasses"] - 1))).sum()

    return correct + wrong

def EStep(data):
    """E-step: recompute the posterior over each task's true class, in place.

    For task j and class k this evaluates
        log priorZ_k + sum_i log P(L_ij | Z_j = k, alpha_i, beta_j)
    where a vote for k contributes log sigmoid(alpha_i * e^beta_j), a vote for any other
    class contributes log(1 - sigmoid(alpha_i * e^beta_j)) - log(K - 1), and a missing
    vote (label 0) contributes nothing. The normalised result is stored in
    ``data["probZ"]`` with shape (numTasks, numClasses).
    """
    labels = data["labels"]
    numClasses = data["numClasses"]

    # ab[j, i] = alpha_i * exp(beta_j)
    ab = np.exp(data["beta"])[:, np.newaxis] * data["alpha"]
    logCorrect = -np.logaddexp(0, -ab)                                   # log sigmoid(ab)
    logWrong = -np.logaddexp(0, ab) - np.log(float(numClasses - 1))      # log(1 - sigmoid(ab)) - log(K-1)
    responded = labels != 0

    # Start from the log prior (and add to it) so a non-uniform priorZ is respected.
    logProbZ = np.tile(np.log(data["priorZ"]), (data["numTasks"], 1))
    for k in range(numClasses):
        delta = labels == k + 1
        wrong = responded & ~delta
        logProbZ[:, k] += np.where(delta, logCorrect, 0.0).sum(axis=1)
        logProbZ[:, k] += np.where(wrong, logWrong, 0.0).sum(axis=1)

    # Shift each row by its max before exponentiating so very negative log-likelihoods
    # do not all underflow to 0 (which would give 0/0 = NaN after normalising).
    logProbZ -= logProbZ.max(axis=1, keepdims=True)
    probZ = np.exp(logProbZ)
    data["probZ"] = probZ / probZ.sum(axis=1, keepdims=True)
    assert not np.any(np.isnan(data["probZ"])), 'Invalid Value [EStep]'



def df(x, *args):
    """Gradient of ``f``: the negated (dQ/dalpha, dQ/dbeta) at packed parameters ``x``.

    ``args[0]`` is the GLAD data dict. It is not modified: the parameters are unpacked
    into a fresh dict holding only the fields gradientQ needs.
    """
    source = args[0]
    fields = ("labels", "numLabels", "numLabelers", "numTasks", "numClasses",
              "priorAlpha", "priorBeta", "priorZ", "probZ")
    trial = {name: source[name] for name in fields}
    unpackX(x, trial)
    gradAlpha, gradBeta = gradientQ(trial)
    return np.concatenate((-gradAlpha, -gradBeta))


def f(x, *args):
    """Objective minimised by MStep: -Q evaluated at packed parameters ``x``.

    ``args[0]`` is the GLAD data dict; it is left untouched because ``x`` is unpacked
    into a separate dict that copies only the fields computeQ reads.
    """
    source = args[0]
    trial = {name: source[name]
             for name in ("labels", "numLabels", "numLabelers", "numTasks", "numClasses",
                          "priorAlpha", "priorBeta", "priorZ", "probZ")}
    unpackX(x, trial)
    negQ = -computeQ(trial)
    return negQ

def MStep(data, verbose=None):
    """M-step: maximise Q over (alpha, beta) with conjugate gradients, updating ``data`` in place.

    Minimises ``f`` (negated Q) using ``df`` as its Jacobian, starting from the current
    parameters, for at most 25 CG iterations.

    Args:
        data: GLAD working dict; ``alpha`` and ``beta`` are replaced with the optimum found.
        verbose: forwarded to the optimizer's ``disp`` option; ``None`` falls back to the
            module-level ``verbose`` flag (False when that flag was never set).
    """
    # Explicit submodule import: a bare `import scipy` may not expose `scipy.optimize`
    # on older SciPy releases.
    import scipy.optimize

    if verbose is None:
        verbose = globals().get("verbose", False)
    initial_params = packX(data)
    params = scipy.optimize.minimize(fun=f, x0=initial_params, args=(data,), method='CG',
                                     jac=df, tol=0.01,
                                     options={'maxiter': 25, 'disp': verbose})
    unpackX(params.x, data)


def computeQ(data):
    """Return the EM auxiliary function Q for the current parameters, including log priors.

    Q = sum_{j,k} P(Z_j=k) log priorZ_k
        + sum_{j,i answered} sum_k P(Z_j=k) log P(L_ij | Z_j=k)
        + sum_i log N(alpha_i - priorAlpha_i; 0, 1)
        + sum_j log N(beta_j - priorBeta_j; 0, 1)
    where a vote for k has log-probability log sigmoid(alpha_i e^beta_j) and a vote for
    another class log(1 - sigmoid(alpha_i e^beta_j)) - log(K - 1). Label 0 means no vote.

    Returns:
        Q as a float; -inf instead of NaN.
    """
    # Explicit submodule import: a bare `import scipy` may not expose `scipy.stats`
    # on older SciPy releases.
    import scipy.stats

    Q = (data["probZ"] * np.log(data["priorZ"])).sum()

    # ab[j, i] = alpha_i * exp(beta_j)
    ab = np.exp(data["beta"])[:, np.newaxis] * data["alpha"]

    logSigma = -np.logaddexp(0, -ab)  # log sigmoid(ab), overflow-free
    idxna = np.isnan(logSigma)
    if np.any(idxna):
        # getLogger works even when init_logger() was never called (module `logger` is None).
        logging.getLogger('GLAD').warning('an invalid value was assigned to np.log [computeQ]')
        logSigma[idxna] = ab[idxna]

    logOneMinusSigma = -np.logaddexp(0, ab) - np.log(float(data["numClasses"] - 1))
    idxna = np.isnan(logOneMinusSigma)
    if np.any(idxna):
        logging.getLogger('GLAD').warning('an invalid value was assigned to np.log [computeQ]')
        logOneMinusSigma[idxna] = -ab[idxna]

    for k in range(data["numClasses"]):
        weight = data["probZ"][:, k][:, np.newaxis]
        delta = data["labels"] == k + 1
        oneMinusDelta = (data["labels"] != k + 1) & (data["labels"] != 0)
        Q += (weight * logSigma)[delta].sum()
        Q += (weight * logOneMinusSigma)[oneMinusDelta].sum()

    # logpdf stays finite where pdf underflows to 0 (whose log would be -inf).
    Q += scipy.stats.norm.logpdf(data["alpha"] - data["priorAlpha"]).sum()
    Q += scipy.stats.norm.logpdf(data["beta"] - data["priorBeta"]).sum()

    if np.isnan(Q):
        return -np.inf
    return Q


def gradientQ(data):
    """Gradient of Q (see computeQ) with respect to alpha and beta.

    With ab[j, i] = alpha_i * exp(beta_j), each answered vote contributes
    P(Z_j = k) * (delta_ijk - sigmoid(ab[j, i])) * d(ab[j, i])/d(param), where delta_ijk
    is 1 when the vote equals class k. Missing votes (label 0) contribute nothing.
    The Gaussian priors add -(alpha - priorAlpha) and -(beta - priorBeta).

    Args:
        data: GLAD working dict; reads labels, numClasses, probZ, alpha, beta,
            priorAlpha and priorBeta.

    Returns:
        (dQdAlpha, dQdBeta), shaped like ``alpha`` and ``beta``.
    """
    from scipy.special import expit  # logistic sigmoid without overflow warnings

    dQdAlpha = -(data["alpha"] - data["priorAlpha"])
    dQdBeta = -(data["beta"] - data["priorBeta"])

    expBeta = np.exp(data["beta"])
    ab = expBeta[:, np.newaxis] * data["alpha"]
    sigma = expit(ab)
    sigma[np.isnan(sigma)] = 0
    responded = data["labels"] != 0

    for k in range(data["numClasses"]):
        delta = (data["labels"] == k + 1).astype(float)
        # 1 - sigma for a vote for k, -sigma for a vote for another class, 0 if no vote.
        residual = np.where(responded, delta - sigma, 0.0)
        weighted = data["probZ"][:, k][:, np.newaxis] * residual
        # Chain rule: d(ab)/d(alpha_i) = exp(beta_j); d(ab)/d(beta_j) = alpha_i * exp(beta_j).
        dQdAlpha += (weighted * expBeta[:, np.newaxis]).sum(axis=0)
        dQdBeta += (weighted * data["alpha"]).sum(axis=1) * expBeta

    return dQdAlpha, dQdBeta


def dAlpha(item, *args):
    """Per-labeler contribution to dQ/dalpha for one class (np.apply_along_axis helper).

    Args:
        item: ``[i, s_0, s_1, ...]`` — labeler index i followed by
            sigmoid(alpha_i * exp(beta_j)) for every task j.
        args[0]: boolean mask indexed [task, labeler], True where the vote equals the class.
        args[1]: boolean mask indexed [task, labeler], True where no vote was given.
        args[2]: posterior probability of the class for each task.
        args[3]: GLAD data dict; only ``beta`` is read.

    NOTE(review): gradientQ computes the gradient in vectorised form and does not call
    this helper.
    """
    labeler = int(item[0])
    sig = item[1:]

    votedK = args[0][:, labeler]
    votedOther = ~(votedK | args[1][:, labeler])
    posterior = args[2]
    expBeta = np.exp(args[3]["beta"])

    gain = posterior[votedK] * expBeta[votedK] * (1 - sig[votedK])
    loss = posterior[votedOther] * expBeta[votedOther] * (-sig[votedOther])
    return gain.sum() + loss.sum()


def dBeta(item, *args):
    """Per-task contribution to dQ/dbeta for one class, before the exp(beta_j) chain factor.

    Args:
        item: ``[j, s_0, s_1, ...]`` — task index j followed by
            sigmoid(alpha_i * exp(beta_j)) for every labeler i.
        args[0]: boolean mask indexed [task, labeler], True where the vote equals the class.
        args[1]: boolean mask indexed [task, labeler], True where no vote was given.
        args[2]: posterior probability of the class for each task.
        args[3]: GLAD data dict; only ``alpha`` is read.

    NOTE(review): gradientQ computes the gradient in vectorised form and does not call
    this helper.
    """
    task = int(item[0])
    sig = item[1:]

    votedK = args[0][task]
    votedOther = ~args[1][task] & ~votedK
    posterior = args[2][task]
    alpha = args[3]["alpha"]

    gain = posterior * alpha[votedK] * (1 - sig[votedK])
    loss = posterior * alpha[votedOther] * (-sig[votedOther])
    return gain.sum() + loss.sum()


def packX(data):
    """Flatten (alpha, beta) into one new 1-D parameter vector for scipy.optimize."""
    return np.concatenate([data["alpha"], data["beta"]])


def unpackX(x, data):
    """Split packed vector ``x`` back into ``data["alpha"]`` (first numLabelers entries) and ``data["beta"]`` (rest), as copies."""
    split = data["numLabelers"]
    data["alpha"], data["beta"] = x[:split].copy(), x[split:].copy()


def output(data):
    """Collect the fitted GLAD estimates, each row prefixed with its index.

    Returns a dict with:
        "alpha":  rows of [labeler index, alpha]
        "beta":   rows of [task index, exp(beta)]
        "probZ":  rows of [task index, P(Z=0), P(Z=1), ...]
        "labels": rows of [task index, most probable class]
    """
    taskIdx = np.arange(data["numTasks"])
    labelerIdx = np.arange(data["numLabelers"])
    posterior = data["probZ"]
    return {
        "alpha": np.column_stack((labelerIdx, data["alpha"])),
        "beta": np.column_stack((taskIdx, np.exp(data["beta"]))),
        "probZ": np.column_stack((taskIdx, posterior)),
        "labels": np.column_stack((taskIdx, posterior.argmax(axis=1))),
    }


def main(arr=None, *, _format_output=output):
    """Run GLAD on a labeler-major vote matrix and return the formatted estimates.

    Args:
        arr: votes with one row per labeler (see load_data); defaults to the
            notebook-global ``all_results``.
        _format_output: result formatter. It is bound to ``output`` when this cell runs,
            so rebinding the global name ``output`` (e.g. ``output = main()``) does not
            break later calls to main().

    Returns:
        dict with "alpha", "beta", "probZ" and "labels" arrays.
    """
    global debug, verbose
    init_logger()

    # Module-level flags read by EM/MStep.
    debug = False
    verbose = False

    data = load_data(all_results if arr is None else arr)

    EM(data)
    return _format_output(data)



In [None]:
# Fit GLAD on all_results and keep the formatted estimates.
# NOTE(review): this rebinds the name `output`, which the GLAD cell also uses for its
# result-formatting function; a distinct name (e.g. `glad_result`) would avoid the clash.
output = main()

In [None]:
# Rows of [task index, estimated class].
output['labels']