## Contextual Calibration


Here we demonstrate how users can take advantage of contextual calibration to improve prompt labeler performance. We follow the strategies proposed by Zhao et al. (2021). To use calibration, simply call `Client.calibrate(Template, Voter)`.
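
To make the idea concrete, here is a minimal NumPy sketch of the diagonal calibration described by Zhao et al. (2021): the model's bias is estimated from its label probabilities on content-free inputs (such as "N/A"), and predictions are rescaled so that a content-free input would score uniformly. This is not alfred's implementation; the function names and all probability values below are hypothetical, and the sketch renormalizes by the sum instead of applying a softmax, which leaves the predicted label unchanged.

```python
import numpy as np

def fit_contextual_calibration(p_content_free):
    """Estimate W and b from label probabilities on content-free inputs.

    p_content_free: shape (n_inputs, n_labels), one row per content-free prompt.
    """
    p_cf = np.asarray(p_content_free, dtype=float).mean(axis=0)
    p_cf = p_cf / p_cf.sum()      # renormalize over the answer choices
    W = np.diag(1.0 / p_cf)       # W = diag(p_cf)^-1
    b = np.zeros_like(p_cf)       # b = 0
    return W, b

def apply_calibration(p, W, b):
    """Rescale a probability vector and renormalize it."""
    q = W @ np.asarray(p, dtype=float) + b
    return q / q.sum()

# Hypothetical probabilities for ("yes", "no") on two content-free prompts:
# the model leans towards "yes" even when there is nothing to classify.
p_cf = [[0.72, 0.28], [0.68, 0.32]]
W, b = fit_contextual_calibration(p_cf)

# Hypothetical prediction on a real comment.
p_example = [0.6, 0.4]
calibrated = apply_calibration(p_example, W, b)

print("before:", np.argmax(p_example), "after:", np.argmax(calibrated))
```

With these made-up numbers the uncalibrated prediction is "yes" (index 0), while the calibrated prediction flips to "no" (index 1), because 0.6 is below the 0.7 the model assigns to "yes" by default.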


References: 

Zhao, Z., Wallace, E., Feng, S., Klein, D., & Singh, S. (2021, July). Calibrate before use: Improving few-shot performance of language models. In International Conference on Machine Learning (pp. 12697-12706). PMLR.



#### Let's load a test dataset and use an alfred remote client


In [None]:
from alfred.data.wrench import WrenchBenchmarkDataset
from alfred.client import Client

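# Load the validation split of the WRENCH YouTube dataset
youtube_dev = WrenchBenchmarkDataset(
    dataset_name='youtube',
    split='valid',
    local_path="/data/Datasets/wrench/"
)

# Connect to a remote alfred server through an SSH tunnel;
# end_point and ssh_node are left blank here
t03b = Client(end_point="", ssh_tunnel=True, ssh_node="")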

#### Then we define a template that asks the LLM to decide whether the given instance references another channel or video

In [None]:
from alfred.template import StringTemplate
from alfred.voter import Voter

channel_reference_template = StringTemplate(
    template = """Does the following comment reference the speaker’s channel or video?\n\n[[text]]""",
    answer_choices = "yes ||| no",
)

yes_voter = Voter(
    label_map = {'yes': 1, 'no': 0},
    matching_fn = lambda x, y: x == y,
)
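
As a rough illustration of what `label_map` and `matching_fn` express, the following plain-Python sketch maps response strings to votes. It is not alfred's `Voter` implementation: the helper `vote_sketch`, the abstain value of `-1`, and the argument order `matching_fn(response, pattern)` are all assumptions made for this example.

```python
def vote_sketch(responses, label_map, matching_fn, abstain=-1):
    """Map each response to the label of the first matching pattern."""
    votes = []
    for response in responses:
        vote = abstain  # assumed fallback when no pattern matches
        for pattern, label in label_map.items():
            if matching_fn(response, pattern):
                vote = label
                break
        votes.append(vote)
    return votes

label_map = {'yes': 1, 'no': 0}
exact_match = lambda x, y: x == y

votes = vote_sketch(["yes", "no", "maybe"], label_map, exact_match)
print(votes)
```

Because the matching function here is exact equality, a response such as "maybe" matches neither key and falls back to the assumed abstain value.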

#### We can quickly evaluate performance by checking whether the responses agree with the ground truth

In [None]:
import numpy as np

channel_reference_prompts_dev = channel_reference_template.apply_to_dataset(youtube_dev)

dev_resp = t03b(channel_reference_prompts_dev)

votes = yes_voter.vote(dev_resp)
acc = np.mean(votes==np.array(youtube_dev.labels))

print(f"Acc before calibration: {acc}")

#### Now let's try contextual calibration

In [None]:
t03b.calibrate(channel_reference_template, voter=yes_voter)

calibrated_votes = yes_voter.vote(dev_resp)
calibrated_acc = np.mean(calibrated_votes==np.array(youtube_dev.labels))

print(f"Acc before calibration: {acc}, after calibration: {calibrated_acc}")