# TODO:
### NOW:
- ~~enforce output format for gemini~~
- llama, gpt, ~~claude~~
   - send concurrent calls to all models at once
- ~~add evaluation if there is a golden set for individual model~~
- aggregation strategy and tie breaking
   - multiclass classification: ~~majority vote~~, baysian approach with GT
      - provide X labeles per class
- repeat the same thing for multi-label/ner

### LATER:
- secret management
- update readme
- add images



### nice things to do:
- add tqdm to asyncio calls
- proper logging

# Annotate

In [None]:
from utils import Annotate
from datasets import load_set

seed =42

In [None]:
gemini_prompt_template = """
<data_description>
{description}
</data_description>
-----------

<context>
{datapoint}
</context>
------------

<labels>
{labels}
</labels>
------------

INSTRUCTION:
- familirize yourself with the data using data_description
- read the context carefully. this is the data point you need to label.
- take your time and label the dadatapoint with the most appropriate option using the provided labels.
- return the result as a single label from the <labels>. Don't provide explanations
"""

In [None]:
dataset = load_dataset("yelp_polarity", split="train") # https://huggingface.co/datasets/yelp_polarity

# take a small sample for dev purposes
dataset_sample = dataset.shuffle(seed=seed).select(range(100))

# user provided data description
DESCRIPTION = """
This is a dataset for binary sentiment classification.
It contains highly polar yelp reviews.
Negative polarity is class 0, and positive class 1.
"""

LABEL_SET = [0, 1] 

In [None]:
prompt = [gemini_prompt_template.format(description= DESCRIPTION,
                                        datapoint=x,
                                        labels=LABEL_SET) for x in dataset_sample["text"][:20]]
print(len(prompt))

In [None]:
ann = Annotate()

VALID_MODELS = ["gemini", "claude"]

In [None]:
d = {}
for m in VALID_MODELS:
    d[m] = await ann.classification(prompt, model=m)

In [None]:
import json
with open("./data/output/20_sample.json", "w") as json_file:
    json.dump(d, json_file, indent=4)

In [None]:
# all_results = [d["gemini"], d["claude"]]
y_labels = ["gemini", "claude"]
all_results = [[1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0],
               [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0]]

## evaluate

In [None]:
from utils import Evaluate

eval = Evaluate()

In [None]:
eval.classification(all_results, strategy="majority", visualize=True, y_labels=y_labels)

# Dev

In [None]:
import scipy as sp
import numpy as np

In [None]:
y_labels = ["gemini", "claude", "fake"]
all_results = [[1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0],
               [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0],
               [0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0]]


# task, labeler, label
result = [(j, i, sublist[i]) 
          for j, sublist in enumerate(all_results) 
          for i in range(len(sublist))]

In [17]:
import argparse
import logging
import numpy as np
import scipy as sp
import scipy.stats
import scipy.optimize

THRESHOLD = 1e-5

logger = None


class Dataset(object):
    def __init__(self, labels=None, numLabels=-1, numLabelers=-1, numTasks=-1, numClasses=-1,
                 priorAlpha=None, priorBeta=None, priorZ=None,
                 alpha=None, beta=None, probZ=None):
        self.labels = labels
        self.numLabels = numLabels
        self.numLabelers = numLabelers
        self.numTasks = numTasks
        self.numClasses = numClasses
        self.priorAlpha = priorAlpha
        self.priorBeta = priorBeta
        self.priorZ = priorZ
        self.alpha = alpha
        self.beta = beta
        self.probZ = probZ


def init_logger():
    global logger
    logger = logging.getLogger('GLAD')
    logger.setLevel(logging.DEBUG)
    log_fmt = '%(asctime)s/%(name)s[%(levelname)s]: %(message)s'
    logging.basicConfig(format=log_fmt)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def logsigmoid(x):
    return - np.log(1 + np.exp(-x))


def load_data(filename):
    data = Dataset()
    with open(filename) as f:
        header = f.readline().split()
        data.numLabels = int(header[0])
        data.numLabelers = int(header[1])
        data.numTasks = int(header[2])
        data.numClasses = int(header[3])
        data.priorZ = np.array([float(x) for x in header[4:]])
        assert len(data.priorZ) == data.numClasses, 'Incorrect input header'
        assert data.priorZ.sum() == 1, 'Incorrect priorZ given'

        data.labels = np.zeros((data.numTasks, data.numLabelers))
        for line in f:
            task, labeler, label = map(int, line.split())
            data.labels[task][labeler] = label + 1

    data.priorAlpha = np.ones(data.numLabelers)
    data.priorBeta = np.ones(data.numTasks)
    data.probZ = np.empty((data.numTasks, data.numClasses))
    data.beta = np.empty(data.numTasks)
    data.alpha = np.empty(data.numLabelers)

    return data


def EM(data):
    data.alpha = data.priorAlpha.copy()
    data.beta = data.priorBeta.copy()
    data.probZ[:] = data.priorZ[:]

    print(data.probZ[:])

    EStep(data)
    lastQ = computeQ(data)
    MStep(data)
    Q = computeQ(data)
    counter = 1
    while abs((Q - lastQ) / lastQ) > THRESHOLD:
        if verbose:
            logger.info('EM: iter={}'.format(counter))
        lastQ = Q
        EStep(data)
        MStep(data)
        Q = computeQ(data)
        counter += 1

def calcLogProbL(item, *args):
    data = args[-1]
    print(data.alpha, data.labels)

    j = int(item[0])
    delta = args[0][j]
    noResp = args[1][j]
    oneMinusDelta = (~delta) & (~noResp)

    exponents = item[1:]

    correct = logsigmoid(exponents[delta]).sum()
    wrong = (logsigmoid(-exponents[oneMinusDelta]) - np.log(float(data.numClasses - 1))).sum()

    return correct + wrong

def EStep(data):
    data.probZ = np.tile(np.log(data.priorZ), data.numTasks).reshape(data.numTasks, data.numClasses)

    ab = np.dot(np.array([np.exp(data.beta)]).T, np.array([data.alpha]))
    ab = np.c_[np.arange(data.numTasks), ab]
    for k in range(data.numClasses):
        data.probZ[:, k] = np.apply_along_axis(calcLogProbL, 1, ab,
                                               (data.labels == k + 1),
                                               (data.labels == 0),
                                               data)  # Pass data as an additional argument

    data.probZ = np.exp(data.probZ)
    s = data.probZ.sum(axis=1)
    data.probZ = (data.probZ.T / s).T
    assert not np.any(np.isnan(data.probZ)), 'Invalid Value [EStep]'
    assert not np.any(np.isinf(data.probZ)), 'Invalid Value [EStep]'



def df(x, *args):
    data = args[0]
    d = Dataset(labels=data.labels, numLabels=data.numLabels, numLabelers=data.numLabelers,
                numTasks=data.numTasks, numClasses=data.numClasses,
                priorAlpha=data.priorAlpha, priorBeta=data.priorBeta,
                priorZ=data.priorZ, probZ=data.probZ)
    unpackX(x, d)
    dQdAlpha, dQdBeta = gradientQ(d)
    return np.r_[-dQdAlpha, -dQdBeta]


def f(x, *args):
    u"""Return the value of the objective function
    """
    data = args[0]
    d = Dataset(labels=data.labels, numLabels=data.numLabels, numLabelers=data.numLabelers,
                numTasks=data.numTasks, numClasses=data.numClasses,
                priorAlpha=data.priorAlpha, priorBeta=data.priorBeta,
                priorZ=data.priorZ, probZ=data.probZ)
    unpackX(x, d)
    return - computeQ(d)

def MStep(data):
    initial_params = packX(data)
    params = sp.optimize.minimize(fun=f, x0=initial_params, args=(data,), method='CG',
                                  jac=df, tol=0.01,
                                  options={'maxiter': 25, 'disp': verbose})
    unpackX(params.x, data)


def computeQ(data):
    Q = 0
    Q += (data.probZ * np.log(data.priorZ)).sum()

    ab = np.dot(np.array([np.exp(data.beta)]).T, np.array([data.alpha]))

    logSigma = logsigmoid(ab)
    idxna = np.isnan(logSigma)
    if np.any(idxna):
        logger.warning('an invalid value was assigned to np.log [computeQ]')
        logSigma[idxna] = ab[idxna]

    logOneMinusSigma = logsigmoid(-ab) - np.log(float(data.numClasses - 1))
    idxna = np.isnan(logOneMinusSigma)
    if np.any(idxna):
        logger.warning('an invalid value was assigned to np.log [computeQ]')
        logOneMinusSigma[idxna] = -ab[idxna]

    for k in range(data.numClasses):
        delta = (data.labels == k + 1)
        Q += (data.probZ[:, k] * logSigma.T).T[delta].sum()
        oneMinusDelta = (data.labels != k + 1) & (data.labels != 0)
        Q += (data.probZ[:, k] * logOneMinusSigma.T).T[oneMinusDelta].sum()

    Q += np.log(sp.stats.norm.pdf(data.alpha - data.priorAlpha)).sum()
    Q += np.log(sp.stats.norm.pdf(data.beta - data.priorBeta)).sum()

    if np.isnan(Q):
        return -np.inf
    return Q


def gradientQ(data):
    dQdAlpha = - (data.alpha - data.priorAlpha)
    dQdBeta = - (data.beta - data.priorBeta)

    ab = np.dot(np.array([np.exp(data.beta)]).T, np.array([data.alpha]))
    sigma = sigmoid(ab)
    sigma[np.isnan(sigma)] = 0

    labelersIdx = np.arange(data.numLabelers).reshape((1, data.numLabelers))
    sigma = np.r_[labelersIdx, sigma]
    sigma = np.c_[np.arange(-1, data.numTasks), sigma]

    for k in range(data.numClasses):
        dQdAlpha += np.apply_along_axis(dAlpha, 0, sigma[:, 1:],
                                         (data.labels == k + 1),
                                         (data.labels == 0),
                                         data.probZ[:, k],
                                         data)

        dQdBeta += np.apply_along_axis(dBeta, 1, sigma[1:],
                                        (data.labels == k + 1),
                                        (data.labels == 0),
                                        data.probZ[:, k],
                                        data) * np.exp(data.beta)

    return dQdAlpha, dQdBeta


def dAlpha(item, *args):
    i = int(item[0])
    sigma_ab = item[1:]

    delta = args[0][:, i]
    noResp = args[1][:, i]
    oneMinusDelta = (~delta) & (~noResp)

    probZ = args[2]

    data = args[3] 

    correct = probZ[delta] * np.exp(data.beta[delta]) * (1 - sigma_ab[delta])
    wrong = probZ[oneMinusDelta] * np.exp(data.beta[oneMinusDelta]) * (-sigma_ab[oneMinusDelta])

    return correct.sum() + wrong.sum()


def dBeta(item, *args):
    j = int(item[0])
    sigma_ab = item[1:]

    delta = args[0][j]
    noResp = args[1][j]
    oneMinusDelta = (~delta) & (~noResp)

    probZ = args[2][j]
    data = args[3] 

    correct = probZ * data.alpha[delta] * (1 - sigma_ab[delta])
    wrong = probZ * data.alpha[oneMinusDelta] * (-sigma_ab[oneMinusDelta])

    return correct.sum() + wrong.sum()


def packX(data):
    return np.r_[data.alpha.copy(), data.beta.copy()]


def unpackX(x, data):
    data.alpha = x[:data.numLabelers].copy()
    data.beta = x[data.numLabelers:].copy()





def output(data):
    alpha = np.c_[np.arange(data.numLabelers), data.alpha]
    np.savetxt('data/alpha.csv', alpha, fmt=['%d', '%.5f'], delimiter=',', header='id,alpha')
    beta = np.c_[np.arange(data.numTasks), np.exp(data.beta)]
    np.savetxt('data/beta.csv', beta, fmt=['%d', '%.5f'], delimiter=',', header='id,beta')
    probZ = np.c_[np.arange(data.numTasks), data.probZ]
    np.savetxt(fname='data/probZ.csv',
               X=probZ,
               fmt=['%d'] + (['%.5f'] * data.numClasses),
               delimiter=',',
               header='id,' + ','.join(['z' + str(k) for k in range(data.numClasses)]))
    label = np.c_[np.arange(data.numTasks), np.argmax(data.probZ, axis=1)]
    np.savetxt('data/label_glad.csv', label, fmt=['%d', '%d'], delimiter=',', header='id,label')


def main():
    global debug, verbose
    init_logger()

    debug = True
    verbose = True

    data = load_data("./data/data.txt")

    EM(data)
    output(data)
    return



In [18]:
main()

[[0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 ...
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] [[2. 2. 1. ... 2. 2. 2.]
 [2. 2. 2. ... 2. 2. 2.]
 [1. 2. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 2. 2. 2.]
 [1. 1. 2. ... 2. 2. 2.]
 [2. 2. 1. ... 1. 1. 1.]]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] [[2. 2. 1. ... 2. 2. 2.]
 [2. 2. 2. ... 2. 2. 2.]
 [1. 2. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 2. 2. 2.]
 [1. 1. 2. ... 2. 2. 2.]
 [2. 2. 1. ... 1. 1. 1.]]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] [[2. 2. 1. ... 2. 2. 2.]
 [2. 2. 2. ... 2. 2. 2.]
 [1. 2. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 2. 2. 2.]
 [1. 1. 2. ... 2. 2. 2.]
 [2. 2. 1. ... 1. 1. 1.]]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] [[2. 2. 1. ... 2. 2. 2.]
 [2. 2. 2. ... 2. 2. 2.]
 [1. 2. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 2. 2. 2.]
 [1. 1. 2. ... 2. 2. 2.]
 [2. 2. 1. ... 1. 1. 1.]]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1

2024-05-10 20:45:24,697/GLAD[INFO]: EM: iter=1


         Current function value: 11044.207398
         Iterations: 25
         Function evaluations: 45
         Gradient evaluations: 45
[0.33315453 0.31685869 0.38449305 0.3616166  0.37478013 0.34995843
 0.4033808  0.32692041 0.31437528 0.32074866 0.38337972 0.35536283
 0.38209194 0.3668237  0.34369716 0.3987627  1.4188851  1.53420844
 1.63548769 1.53787155] [[2. 2. 1. ... 2. 2. 2.]
 [2. 2. 2. ... 2. 2. 2.]
 [1. 2. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 2. 2. 2.]
 [1. 1. 2. ... 2. 2. 2.]
 [2. 2. 1. ... 1. 1. 1.]]
[0.33315453 0.31685869 0.38449305 0.3616166  0.37478013 0.34995843
 0.4033808  0.32692041 0.31437528 0.32074866 0.38337972 0.35536283
 0.38209194 0.3668237  0.34369716 0.3987627  1.4188851  1.53420844
 1.63548769 1.53787155] [[2. 2. 1. ... 2. 2. 2.]
 [2. 2. 2. ... 2. 2. 2.]
 [1. 2. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 2. 2. 2.]
 [1. 1. 2. ... 2. 2. 2.]
 [2. 2. 1. ... 1. 1. 1.]]
[0.33315453 0.31685869 0.38449305 0.3616166  0.37478013 0.34995843
 0.4033808  0.32692041 0.31437528

2024-05-10 20:45:27,797/GLAD[INFO]: EM: iter=2


         Current function value: 10902.987650
         Iterations: 25
         Function evaluations: 47
         Gradient evaluations: 47
[0.32073701 0.30276542 0.37209338 0.35090743 0.36254488 0.34245948
 0.39845283 0.3209832  0.30589133 0.3128972  0.37495734 0.35048986
 0.37240874 0.35168433 0.3275315  0.38372637 1.87790673 2.28650951
 2.24318221 2.10767697] [[2. 2. 1. ... 2. 2. 2.]
 [2. 2. 2. ... 2. 2. 2.]
 [1. 2. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 2. 2. 2.]
 [1. 1. 2. ... 2. 2. 2.]
 [2. 2. 1. ... 1. 1. 1.]]
[0.32073701 0.30276542 0.37209338 0.35090743 0.36254488 0.34245948
 0.39845283 0.3209832  0.30589133 0.3128972  0.37495734 0.35048986
 0.37240874 0.35168433 0.3275315  0.38372637 1.87790673 2.28650951
 2.24318221 2.10767697] [[2. 2. 1. ... 2. 2. 2.]
 [2. 2. 2. ... 2. 2. 2.]
 [1. 2. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 2. 2. 2.]
 [1. 1. 2. ... 2. 2. 2.]
 [2. 2. 1. ... 1. 1. 1.]]
[0.32073701 0.30276542 0.37209338 0.35090743 0.36254488 0.34245948
 0.39845283 0.3209832  0.30589133

2024-05-10 20:45:31,388/GLAD[INFO]: EM: iter=3


         Current function value: 10865.455154
         Iterations: 25
         Function evaluations: 58
         Gradient evaluations: 58
[0.32046683 0.30019498 0.36732822 0.34673829 0.35855277 0.33817406
 0.39583792 0.31942817 0.30486342 0.31220822 0.37084842 0.34645088
 0.36884822 0.34787566 0.32422973 0.38109941 1.97491727 2.56735432
 2.44041408 2.23366336] [[2. 2. 1. ... 2. 2. 2.]
 [2. 2. 2. ... 2. 2. 2.]
 [1. 2. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 2. 2. 2.]
 [1. 1. 2. ... 2. 2. 2.]
 [2. 2. 1. ... 1. 1. 1.]]
[0.32046683 0.30019498 0.36732822 0.34673829 0.35855277 0.33817406
 0.39583792 0.31942817 0.30486342 0.31220822 0.37084842 0.34645088
 0.36884822 0.34787566 0.32422973 0.38109941 1.97491727 2.56735432
 2.44041408 2.23366336] [[2. 2. 1. ... 2. 2. 2.]
 [2. 2. 2. ... 2. 2. 2.]
 [1. 2. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 2. 2. 2.]
 [1. 1. 2. ... 2. 2. 2.]
 [2. 2. 1. ... 1. 1. 1.]]
[0.32046683 0.30019498 0.36732822 0.34673829 0.35855277 0.33817406
 0.39583792 0.31942817 0.30486342

2024-05-10 20:45:34,940/GLAD[INFO]: EM: iter=4


         Current function value: 10862.565973
         Iterations: 25
         Function evaluations: 56
         Gradient evaluations: 56
[0.31907252 0.29917075 0.36556224 0.34522168 0.35696385 0.33705377
 0.39387344 0.3182799  0.3037414  0.31107868 0.36947447 0.34506435
 0.36729954 0.34646971 0.32271211 0.37958932 1.98331349 2.5893038
 2.44912574 2.24588229] [[2. 2. 1. ... 2. 2. 2.]
 [2. 2. 2. ... 2. 2. 2.]
 [1. 2. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 2. 2. 2.]
 [1. 1. 2. ... 2. 2. 2.]
 [2. 2. 1. ... 1. 1. 1.]]
[0.31907252 0.29917075 0.36556224 0.34522168 0.35696385 0.33705377
 0.39387344 0.3182799  0.3037414  0.31107868 0.36947447 0.34506435
 0.36729954 0.34646971 0.32271211 0.37958932 1.98331349 2.5893038
 2.44912574 2.24588229] [[2. 2. 1. ... 2. 2. 2.]
 [2. 2. 2. ... 2. 2. 2.]
 [1. 2. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 2. 2. 2.]
 [1. 1. 2. ... 2. 2. 2.]
 [2. 2. 1. ... 1. 1. 1.]]
[0.31907252 0.29917075 0.36556224 0.34522168 0.35696385 0.33705377
 0.39387344 0.3182799  0.3037414  0

2024-05-10 20:45:38,288/GLAD[INFO]: EM: iter=5


         Current function value: 10862.240912
         Iterations: 25
         Function evaluations: 49
         Gradient evaluations: 49
[0.31869739 0.29876349 0.36510734 0.34481735 0.35654707 0.33665307
 0.3933869  0.31791212 0.3033782  0.31069231 0.36901368 0.34465687
 0.36684171 0.34605365 0.32224112 0.37912008 1.97752824 2.58927525
 2.44535258 2.24605259] [[2. 2. 1. ... 2. 2. 2.]
 [2. 2. 2. ... 2. 2. 2.]
 [1. 2. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 2. 2. 2.]
 [1. 1. 2. ... 2. 2. 2.]
 [2. 2. 1. ... 1. 1. 1.]]
[0.31869739 0.29876349 0.36510734 0.34481735 0.35654707 0.33665307
 0.3933869  0.31791212 0.3033782  0.31069231 0.36901368 0.34465687
 0.36684171 0.34605365 0.32224112 0.37912008 1.97752824 2.58927525
 2.44535258 2.24605259] [[2. 2. 1. ... 2. 2. 2.]
 [2. 2. 2. ... 2. 2. 2.]
 [1. 2. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 2. 2. 2.]
 [1. 1. 2. ... 2. 2. 2.]
 [2. 2. 1. ... 1. 1. 1.]]
[0.31869739 0.29876349 0.36510734 0.34481735 0.35654707 0.33665307
 0.3933869  0.31791212 0.3033782 