In [26]:
import os
import re
import json
import torch
import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import f1_score, classification_report, confusion_matrix
# from transformers import pipeline
from transformers import AutoModelForCausalLM, AutoTokenizer
import outlines # see https://dottxt-ai.github.io/outlines/latest/features/models/transformers/
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

from data_loader import load_dataset, format_input_with_context


In [12]:
MODEL_NAME = 'meta-llama/Llama-3.2-3B-Instruct'
BATCH_SIZE = 12
USE_CTX = False # whether to use [CTX] parts of threads (too slow for True)
MAX_COT_TOKENS = 100 # max n tokens for cot response

RESULTS_DIR = "./results/prompting/"
os.makedirs(RESULTS_DIR, exist_ok=True)

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# load data
train_df, dev_df, test_df = load_dataset()
use_df = dev_df # df to test on

VALID_STANCES = ['SUPPORT', 'DENY', 'QUERY', 'COMMENT']

# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

Loading cached data from saved_data/datasets.pkl...


In [35]:
# load model

RELOAD_MODEL = False

# output types for outlines
from typing import Literal
StanceLabel = Literal["SUPPORT", "DENY", "QUERY", "COMMENT"]
COT_REGEX = outlines.types.Regex(r".*Label: (SUPPORT|DENY|QUERY|COMMENT)")

if 'model' not in globals() or RELOAD_MODEL:
    hf_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        dtype=torch.float16,
        device_map='auto',
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    
    # use outlines
    model = outlines.from_transformers(hf_model, tokenizer)
    
    # get rid of annoying warnings re pad_token
    model.tokenizer.pad_token = tokenizer.eos_token
    model.tokenizer.padding_side = "left"
    model.model.generation_config.pad_token_id = model.tokenizer.pad_token_id


# if 'pipe' not in globals() or RELOAD_MODEL: # rerunning takes up more mem
#     pipe = pipeline(
#         "text-generation", 
#         MODEL_NAME, 
#         device_map='auto',
#         dtype=torch.float16,
#     )
#     pipe.tokenizer.pad_token = pipe.tokenizer.eos_token
#     pipe.tokenizer.padding_side = "left"

# print(pipe.model.config)

print(MODEL_NAME)
print(BATCH_SIZE)

meta-llama/Llama-3.2-3B-Instruct
12


In [36]:
# prompts
# see https://huggingface.co/docs/transformers/en/tasks/prompting

PERSONA = """You are an expert in rumour stance analysis on twitter."""
INSTRUCTION = """Your task is to classify the stance of the [TARGET] tweet towards the veracity of the rumour in the [SRC] tweet."""
INPUT_FORMAT = """\
    The input will be provided as a single string containing labeled segments:
    "[SRC] ... [PARENT] ... [TARGET] ...". (Note: [PARENT] is omitted if [TARGET] replies directly to [SRC])
    """
# LABEL_DEFNS = """\
#     Classification Labels:
#     - SUPPORT: The reply supports the veracity of the source claim
#     - DENY: The reply denies the veracity of the source claim
#     - QUERY: The reply asks for additional evidence in relation to the veracity of the source claim
#     - COMMENT: The reply makes their own comment without a clear contribution to assessing the veracity of the source claim
#     """
LABEL_DEFNS = """\
    Classification Labels:
    - SUPPORT: The reply supports veracity of the source claim
    - DENY: The reply denies the veracity of the source claim
    - QUERY: The reply asks for additional evidence in relation to the veracity of the source claim
    - COMMENT: The reply makes their own comment without a clear contribution to assessing the veracity of the source claim
    """

OUTPUT_FORMAT = """\
    Respond with ONLY one word: SUPPORT, DENY, QUERY, or COMMENT.
    """

SYS_PROMPT = f"""\
{PERSONA}
{INSTRUCTION}
{INPUT_FORMAT}
{LABEL_DEFNS}
{OUTPUT_FORMAT}
"""

USER_PROMPT_TEMPLATE = """\
Text: {thread_context}

Task: Classify the stance of [TARGET] towards the veracity of the rumour in [SRC].
"""

def build_user_prompt(thread_context):
    return USER_PROMPT_TEMPLATE.format(thread_context=thread_context)


def build_zero_shot_messages(thread_context):
    return [
        {"role": "system", "content": SYS_PROMPT},
        {"role": "user", "content": build_user_prompt(thread_context)},
    ]


def build_few_shot_messages(thread_context, examples=None):
    messages = [{"role": "system", "content": SYS_PROMPT}]
    
    if examples:
        for ex in examples:
            messages.append({"role": "user", "content": build_user_prompt(ex['source'])})
            messages.append({"role": "assistant", "content": ex['label']})
    
    messages.append({"role": "user", "content": build_user_prompt(thread_context)})
    return messages


In [37]:
# COT prompts 

# handselected ids for reasoning
COT_FEW_SHOT_IDS = {
    'support': '524967134339022848', # @TheKirkness Radio Canada tweeting same. must be true :-(
    'deny': '544292581950357504', # @JSchoenberger7 @TheAnonMessage not an Isis flag. Just an Islamic one. Stop spreading false rumors.
    'query': '544281192632053761', # @mscott Are your reporters 100% sure it is an ISIS flag? Cause that is what is being reported. #facts
    'comment': '552804023389392896', # @thei100 @Independent these fuckers thinking its a GTA heist mission
}

COT_EXAMPLES = {
    'support': "1. Stance: takes position on veracity\n2. \"must be true\" → affirms claim\nLabel: SUPPORT",
    'deny': "1. Stance: challenges veracity\n2. \"not an Isis flag\", \"false rumors\" → rejects claim\nLabel: DENY",
    'query': "1. Stance: engages with veracity\n2. \"Are your reporters 100% sure?\" → requests evidence\nLabel: QUERY",
    'comment': "1. Stance: no veracity assessment\n2. Offers opinion/reaction only\nLabel: COMMENT",
}


COT_OUTPUT_FORMAT = """\
Think BRIEFLY step-by-step (max 2-3 short lines).
End with "Label: " followed by ONLY one word: SUPPORT, DENY, QUERY, or COMMENT.
"""

COT_SYS_PROMPT = f"""\
{PERSONA}
{INSTRUCTION}
{INPUT_FORMAT}
{LABEL_DEFNS}
{COT_OUTPUT_FORMAT}
"""

def build_cot_messages(thread_context, examples=None):
    messages = [{"role": "system", "content": COT_SYS_PROMPT}]
    
    if examples:
        for ex in examples:
            messages.append({"role": "user", "content": build_user_prompt(ex['source'])})
            messages.append({"role": "assistant", "content": ex['reasoning']})
    
    messages.append({"role": "user", "content": build_user_prompt(thread_context)})
    return messages

USER_COT_PROMPT_TEMPLATE = """\
Text: {thread_context}

Let's go through this step-by-step:
"""

def build_cot_user_prompt(thread_context):
    return USER_COT_PROMPT_TEMPLATE.format(thread_context=thread_context)


In [38]:
# def parse_stance_response(text):
#     text = text.upper().strip()
#     for label in VALID_STANCES:
#         if re.search(rf'\b{label}\b', text):
#             return label.lower()
#     return None

def parse_cot_label(text):
    """Extract the stance label from CoT output (after 'Label: ')."""
    match = re.search(r'Label:\s*(SUPPORT|DENY|QUERY|COMMENT)', text, re.IGNORECASE)
    if match:
        return match.group(1).lower()
    return None

In [39]:
# def generate_response(pipe, messages): 
#     output = pipe(
#         messages, 
#         max_new_tokens=10,
#         do_sample=False, # deterministic output
#         pad_token_id=pipe.tokenizer.eos_token_id, # prevent warnings
#         )
#     return output[0]["generated_text"][-1]["content"].strip()



# LLM --
def plot_confusion_matrix(cm, mode, save_path=None):
    """Plot confusion matrix."""
    labels = ['support', 'deny', 'query', 'comment']
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'Confusion Matrix ({mode})')
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    
    plt.show()


def save_results(results, model_key, mode):
    """Save evaluation results to JSON."""
    output = {
        'model': model_key,
        'mode': mode,
        'macro_f1': results['macro_f1'],
        'per_class_f1': results['per_class_f1'],
        'predictions': results['predictions'],
        'true_labels': results['true_labels'],
        'raw_responses': results['raw_responses']
    }
    
    filename = f"{RESULTS_DIR}{model_key}_{mode}_results.json"
    with open(filename, 'w') as f:
        json.dump(output, f, indent=2)
    print(f"Saved: {filename}")

In [40]:
# few shot examples
support_example = train_df[train_df['label_text']=='support'].iloc[0]
deny_example = train_df[train_df['label_text']=='deny'].iloc[0]
query_example = train_df[train_df['label_text']=='query'].iloc[0]
comment_example = train_df[train_df['label_text']=='comment'].iloc[0]


def get_few_shot_examples(df, n_per_class=1):
    """Select stratified random examples for few-shot prompting."""
    examples = []
    
    for label in ['support', 'deny', 'query', 'comment']:
        class_df = df[df['label_text'] == label]
        samples = class_df.sample(n=min(n_per_class, len(class_df)))
        
        for _, row in samples.iterrows():
            examples.append({
                'source': format_input_with_context(row, df, use_features=False, use_context=USE_CTX),
                'label': label.upper()
            })
    
    return examples

In [41]:
def evaluate_prompting(model, df, mode="zero-shot", examples=None, batch_size=BATCH_SIZE, verbose=True):
#     from datasets import Dataset
    
#     if mode == "cot":
#         max_new_tokens = MAX_COT_TOKENS
#     else:
#         max_new_tokens = 10
    
#     # Prepare all messages upfront
#     all_messages = []
#     for _, row in df.iterrows():
#         input_text = format_input_with_context(row, df, use_features=False, use_context=USE_CTX)
#         if mode == "zero-shot":
#             messages = build_zero_shot_messages(input_text)
#         elif mode == "few-shot":
#             messages = build_few_shot_messages(input_text, examples)
#         elif mode == "cot":
#             messages = build_cot_messages(input_text, examples)
#         else:
#             raise ValueError(f"Unrecognised mode: {mode}. Choose from 'zero-shot', 'few-shot', or 'cot'.")
#         all_messages.append(messages)
    
#     # Batched inference using Dataset
#     dataset = Dataset.from_dict({"messages": all_messages})
    
#     raw_responses = []
#     for i in tqdm(range(0, len(all_messages), batch_size), desc=f"Evaluating ({mode})"):
#         batch = all_messages[i:i+batch_size]
#         outputs = pipe(batch, max_new_tokens=max_new_tokens, pad_token_id=pipe.tokenizer.eos_token_id)
#         for out in outputs:
#             raw_responses.append(out[0]["generated_text"][-1]["content"].strip())
    
#     # Parse responses
#     predictions = [parse_stance_response(r) for r in raw_responses]

#     # count number of None (errors), print and replace with comment
#     error_idxs = [i for i, p in enumerate(predictions) if p is None]
#     num_errors = len(error_idxs)  
    
#     if error_idxs:
#         print(f"Warning: {num_errors} errors in predictions. Replacing with 'comment'.")
#         for idx in error_idxs:
#             problem_row = df.iloc[idx]['text']
#             raw_text = raw_responses[idx]
            
#             print(f"Index: {idx}")
#             print(f"Raw Model Output: '{raw_text}'")
#             # Adjust 'text' to whatever column contains your input content
#             print(f"Input Data: {problem_row}") 
#             print("-" * 30)
    
#     if num_errors > 0:
#         predictions = [p if p is not None else 'comment' for p in predictions]


    if mode == "cot":
        output_type = COT_REGEX
        max_new_tokens = MAX_COT_TOKENS
    else:
        output_type = StanceLabel
        max_new_tokens = 10
    
    # Prepare all prompts
    all_prompts = []
    for _, row in df.iterrows():
        input_text = format_input_with_context(row, df, use_features=False, use_context=USE_CTX)
        if mode == "zero-shot":
            messages = build_zero_shot_messages(input_text)
        elif mode == "few-shot":
            messages = build_few_shot_messages(input_text, examples)
        elif mode == "cot":
            messages = build_cot_messages(input_text, examples)
        else:
            raise ValueError(f"Unrecognised mode: {mode}. Choose from 'zero-shot', 'few-shot', or 'cot'.")
        
        # Convert messages to prompt string
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        all_prompts.append(prompt)
    
    # Generate with constraints (one at a time - outlines doesn't batch well with constraints)
    raw_responses = []
    predictions = []
    
    for prompt in tqdm(all_prompts, desc=f"Evaluating ({mode})"):
        response = model(prompt, output_type, max_new_tokens=max_new_tokens)
        raw_responses.append(response)
        
        # Parse based on mode
        if mode == "cot":
            # Extract label from CoT response
            pred = parse_cot_label(response)
        else:
            # Direct label output (already constrained)
            pred = response.lower() if response in VALID_STANCES else None
        
        predictions.append(pred)
    
    # Handle any parsing errors (should be rare with constraints)
    error_idxs = [i for i, p in enumerate(predictions) if p is None]
    num_errors = len(error_idxs)
    
    if error_idxs and verbose:
        print(f"Warning: {num_errors} errors in predictions. Replacing with 'comment'.")
        for idx in error_idxs[:5]:  # Show first 5 errors only
            print(f"Index: {idx}")
            print(f"Raw Model Output: '{raw_responses[idx]}'")
            print("-" * 30)
    
    predictions = [p if p is not None else 'comment' for p in predictions]
    
    
    
    

    true_labels = df['label_text'].tolist()
    
    # Metrics
    labels = ['support', 'deny', 'query', 'comment']
    macro_f1 = f1_score(true_labels, predictions, average='macro')
    per_class_f1 = f1_score(true_labels, predictions, average=None, labels=labels)
    
    if verbose:
        print(f"\n{'='*60}\nResults ({mode})\n{'='*60}")
        print(f"Macro F1: {macro_f1:.4f}")
        print(f"\nPer-class F1:")
        for lbl, f1 in zip(labels, per_class_f1):
            print(f"  {lbl}: {f1:.4f}")
        print(f"\n{classification_report(true_labels, predictions, labels=labels, zero_division=0.0)}")
    
    # Save predictions to CSV
    results_df = pd.DataFrame({
        'true_label': true_labels,
        'predicted': predictions,
    })
    csv_path = f"{RESULTS_DIR}predictions_{mode}.csv"
    results_df.to_csv(csv_path, index=False)
    if verbose:
        print(f"Saved predictions: {csv_path}")
    
    return {
        'predictions': predictions,
        'true_labels': true_labels,
        'raw_responses': raw_responses,
        'macro_f1': macro_f1,
        'per_class_f1': dict(zip(labels, per_class_f1)),
        'confusion_matrix': confusion_matrix(true_labels, predictions, labels=labels)
    }

## Zero-shot

In [25]:
# zero-shot
zero_results = evaluate_prompting(model, use_df, mode="zero-shot")

Evaluating (zero-shot):   0%|          | 0/281 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
W0102 18:13:08.965000 373702 torch/_dynamo/convert_frame.py:1707] WON'T CONVERT _apply_token_bitmask_inplace_kernel /home2/nchw73/venv312/lib/python3.12/site-packages/outlines_core/kernels/torch.py line 43 
W0102 18:13:08.965000 373702 torch/_dynamo/convert_frame.py:1707] due to: 
W0102 18:13:08.965000 373702 torch/_dynamo/convert_frame.py:1707] Traceback (most recent call last):
W0102 18:13:08.965000 373702 torch/_dynamo/convert_frame.py:1707]   File "/home2/nchw73/venv312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1433, in _compile
W0102 18:13:08.965000 373702 torch/_dynamo/convert_frame.py:1707]     guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
W0102 18:13:08.965000 373702 torch/_dynamo/convert_frame.py:1707]                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0102 18:13:08.965000 373702 torch/_dynamo/convert_frame.py:1707]   File "/home2/nchw73

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for


Results (zero-shot)
Macro F1: 0.2300

Per-class F1:
  support: 0.0286
  deny: 0.2069
  query: 0.2110
  comment: 0.4735

              precision    recall  f1-score   support

     support       1.00      0.01      0.03        69
        deny       0.17      0.27      0.21        11
       query       0.12      0.82      0.21        28
     comment       0.81      0.34      0.47       173

    accuracy                           0.30       281
   macro avg       0.52      0.36      0.23       281
weighted avg       0.76      0.30      0.33       281

Saved predictions: ./results/prompting/predictions_zero-shot.csv


In [None]:
sns.set_theme(style="whitegrid", font_scale=1.1)

# Component dictionary for easy reference
PROMPT_COMPONENTS = {
    'persona': PERSONA,
    'instruction': INSTRUCTION,
    'input_format': INPUT_FORMAT,
    'label_defns': LABEL_DEFNS,
    'output_format': OUTPUT_FORMAT,
}

# Build colour palette: light grey for minimal, seaborn default for components, black for full
_default_palette = sns.color_palette()  # Default seaborn qualitative palette
COMPONENT_COLOURS = {
    'minimal': '#D3D3D3',              # Light grey
    'persona': _default_palette[0],     # Blue
    'instruction': _default_palette[1], # Orange
    'input_format': _default_palette[2],# Green
    'label_defns': _default_palette[3], # Red
    'output_format': _default_palette[4],# Purple
    'full': '#000000',                  # Black
}

# --- Ablation Configurations ---

# ISOLATED: minimal + ONE component (shows individual contribution)
ISOLATED_CONFIGS = {
    'minimal': [],
    '+persona': ['persona'],
    '+instruction': ['instruction'],
    '+input_format': ['input_format'],
    '+label_defns': ['label_defns'],
    '+output_format': ['output_format'],
    'full': ['persona', 'instruction', 'input_format', 'label_defns', 'output_format'],
}

# CUMULATIVE: stacking components one by one
CUMULATIVE_CONFIGS = {
    'minimal': [],
    '+persona': ['persona'],
    '+instruction': ['persona', 'instruction'],
    '+input_format': ['persona', 'instruction', 'input_format'],
    '+label_defns': ['persona', 'instruction', 'input_format', 'label_defns'],
    '+output_format (full)': ['persona', 'instruction', 'input_format', 'label_defns', 'output_format'],
}


def build_custom_sys_prompt(component_keys):
    """Build system prompt from list of component keys."""
    if not component_keys:
        return None  # No system prompt
    components = [PROMPT_COMPONENTS[k] for k in component_keys]
    return "\n".join(components)


def build_zero_shot_messages_custom(thread_context, sys_prompt=None):
    """Build zero-shot messages with optional custom system prompt."""
    messages = []
    if sys_prompt is not None:
        messages.append({"role": "system", "content": sys_prompt})
    messages.append({"role": "user", "content": build_user_prompt(thread_context)})
    return messages


def evaluate_with_custom_prompt(pipe, df, sys_prompt=None, batch_size=BATCH_SIZE, verbose=False):
    """Evaluate zero-shot with a custom system prompt. Returns macro F1."""
    # Prepare all messages
    all_messages = []
    for _, row in df.iterrows():
        input_text = format_input_with_context(row, df, use_features=False, use_context=USE_CTX)
        messages = build_zero_shot_messages_custom(input_text, sys_prompt)
        all_messages.append(messages)
    
    # Batched inference
    raw_responses = []
    for i in range(0, len(all_messages), batch_size):
        batch = all_messages[i:i+batch_size]
        outputs = pipe(batch, max_new_tokens=10, pad_token_id=pipe.tokenizer.eos_token_id)
        for out in outputs:
            raw_responses.append(out[0]["generated_text"][-1]["content"].strip())
    
    # Parse responses
    predictions = [parse_stance_response(r) for r in raw_responses]
    predictions = [p if p is not None else 'comment' for p in predictions]
    
    true_labels = df['label_text'].tolist()
    macro_f1 = f1_score(true_labels, predictions, average='macro')
    
    if verbose:
        print(f"Macro F1: {macro_f1:.4f}")
    
    return macro_f1


def run_ablation_study(pipe, df, configs, desc="Ablation"):
    """Run ablation study with given configurations. Returns dict of {config_name: macro_f1}."""
    results = {}
    
    for config_name, component_keys in tqdm(configs.items(), desc=desc):
        sys_prompt = build_custom_sys_prompt(component_keys)
        macro_f1 = evaluate_with_custom_prompt(pipe, df, sys_prompt)
        results[config_name] = macro_f1
        print(f"  {config_name}: {macro_f1:.4f}")
    
    return results


def _get_ablation_color(config_name):
    """Get bar color based on config name using seaborn palette."""
    if config_name == 'minimal':
        return COMPONENT_COLOURS['minimal']
    elif 'full' in config_name:
        return COMPONENT_COLOURS['full']
    # Extract component name (e.g., '+persona' -> 'persona')
    for comp in COMPONENT_COLOURS:
        if comp in config_name:
            return COMPONENT_COLOURS[comp]
    return COMPONENT_COLOURS['minimal']


def plot_ablation(results, title, save_path=None, show_delta=False):
    """Plot horizontal bar chart for ablation results using seaborn."""
    # Create DataFrame for seaborn
    df_plot = pd.DataFrame({
        'config': list(results.keys()),
        'score': list(results.values())
    })
    df_plot['color'] = [_get_ablation_color(c) for c in df_plot['config']]
    palette = dict(zip(df_plot['config'], df_plot['color']))
    
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(data=df_plot, y='config', x='score', hue='config', palette=palette, 
                ax=ax, edgecolor='white', legend=False)
    
    ax.set_xlabel('Macro F1 Score')
    ax.set_ylabel('')
    ax.set_title(f'{title}\n(Zero-Shot on Dev Set)')
    ax.set_xlim(0, df_plot['score'].max() * 1.15)
    
    # Add value labels on bars
    for i, (score, config) in enumerate(zip(df_plot['score'], df_plot['config'])):
        ax.text(score + 0.005, i, f'{score:.3f}', va='center', fontsize=10)
    
    # Add baseline reference line (minimal)
    if 'minimal' in results:
        ax.axvline(x=results['minimal'], color=COMPONENT_COLOURS['minimal'], 
                   linestyle='--', alpha=0.7, linewidth=1.5, label='Minimal baseline')
        ax.legend(loc='lower right')
    
    # Add delta annotations for cumulative
    if show_delta:
        scores = list(results.values())
        for i in range(1, len(scores)):
            delta = scores[i] - scores[i-1]
            sign = '+' if delta >= 0 else ''
            ax.text(scores[i] + 0.04, i - 0.25, f'{sign}{delta:.3f}', 
                    fontsize=9, color='dimgray', style='italic')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    
    plt.show()
    return fig


def plot_isolated_ablation(results, save_path=None):
    """Plot isolated ablation (minimal + ONE component)."""
    return plot_ablation(results, 'System Prompt Ablation: Isolated Component Contribution', 
                         save_path=save_path, show_delta=False)


def plot_cumulative_ablation(results, save_path=None):
    """Plot cumulative ablation (stacking components)."""
    return plot_ablation(results, 'System Prompt Ablation: Cumulative Component Stacking', 
                         save_path=save_path, show_delta=True)

## Few-shot

In [30]:
# few-shot
few_shot_examples = get_few_shot_examples(train_df, n_per_class=1)
few_results = evaluate_prompting(model, use_df, mode="few-shot", examples=few_shot_examples)

Evaluating (few-shot):   0%|          | 0/281 [00:00<?, ?it/s]


Results (few-shot)
Macro F1: 0.4222

Per-class F1:
  support: 0.4118
  deny: 0.3478
  query: 0.3680
  comment: 0.5612

              precision    recall  f1-score   support

     support       0.42      0.41      0.41        69
        deny       0.33      0.36      0.35        11
       query       0.24      0.82      0.37        28
     comment       0.74      0.45      0.56       173

    accuracy                           0.47       281
   macro avg       0.43      0.51      0.42       281
weighted avg       0.60      0.47      0.50       281

Saved predictions: ./results/prompting/predictions_few-shot.csv


## CoT Prompting

In [None]:
# CoT prompting

cot_results = evaluate_prompting(model, use_df, mode="cot", examples=None)


Evaluating (cot):   0%|          | 0/281 [00:00<?, ?it/s]

In [43]:
# test a single eg
row = dev_df.iloc[0]
input_text = format_input_with_context(row, dev_df, use_features=False, use_context=USE_CTX)
messages = build_cot_messages(input_text, None)
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
response = model(prompt, COT_REGEX, max_new_tokens=100)
response

In [44]:
response

'Analyzing the tweet, [TARGET] is providing factual information with a link to a news source, which is a common way to verify the accuracy of a news story. The tone is informative and neutral, without expressing any opinion or bias. Therefore, the reply does not directly support, deny, or query the veracity of the source claim, but rather presents the information as is. Label: COMMENT'

In [45]:
row['label_text']

'support'

In [46]:
row['text']

'Heart goes out to 148 passengers and crew of Germanwings Airbus A320 that has crashed in French Alps, Southern France http://t.co/K7fmJLRt4G'