In [86]:
import torch
import json
import random
import re
import ast
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report

In [150]:
# load the sentences and correct labels
all_sentences = []

with open("../data/annotations.json", "r") as f:
    data = json.load(f)

for sentence in data:
    text = sentence["data"]["sentence"]
    labels = []
    results = sentence["annotations"][0]["result"]
    labels = [r["value"]["text"] for r in results]
    
    all_sentences.append({
        "text": text,
        "labels": labels
    })

In [57]:
# load the model
checkpoint = "HuggingFaceTB/SmolLM-1.7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

In [165]:
# define chat template for normal inference
def compile_ner_prompt(few_shot_examples, test_sentence):
    chat = [
          {"role": "system",
           "content": (
                "You are a helpful assistant that extracts social group mentions from text.\n"
                "Definition of a social group: A social group is a collective of people with common socio-demographic characteristics "
                "(e.g., students, migrants, teachers, women, workers). It can also be formed by shared values or life experiences. "
                "Institutions or institutional groupings do not count as social groups. However, if the sentence refers to the people "
                "of an institution (e.g., 'the patients in the hospital'), this does count as a social group.\n\n"
                "Your task is to extract all social group mentions from a given sentence.\n"
                "Collect all social group mentions into a single list. If there are several group mentions, this list will have several entries.\n"
                "If there are no social group mentions in the sentence, respond with 'None'."
                )
        }
    ]
    # add few-shot examples
    for example in few_shot_examples:
        context = example["text"]
        if example["labels"]:
            answer = str(example["labels"])
        else:
            answer = "None"
        chat.append(
            {"role": "user", "content": f"Sentence: {context}"})
        chat.append({"role": "assistant", "content": answer})
    
    # add the test sentence
    chat.append(
        {"role": "user", "content": f"Sentence: {test_sentence}"})

    # compile the prompt
    prompt = tokenizer.apply_chat_template(
    chat, return_tensors="pt", tokenize=False, add_generation_prompt=True)
    return prompt

In [163]:
# create some few-shot examples
non_empty_examples = [ex for ex in all_sentences if ex["labels"]]
empty_examples = [ex for ex in all_sentences if not ex["labels"]]
few_shot_examples = random.sample(non_empty_examples, 4) + random.sample(empty_examples, 1)

# create test dataset
split_idx = int(len(non_empty_examples)*0.75)
test_dataset = non_empty_examples[split_idx:] + random.sample(empty_examples, int(len(empty_examples)*0.2))
random.shuffle(test_dataset)

In [186]:
def to_list_or_empty(entry):
    try:
        val = ast.literal_eval(entry)
        if isinstance(val, list):
            return val
        else:
            return []
    except (ValueError, SyntaxError):
        return []

In [193]:
# generate the answers for the normal format and store in a list
gen_answers = []
for i in range(len(test_dataset)):
    sentence = test_dataset[i]["text"]
    prompt = compile_ner_prompt(few_shot_examples, sentence)
    prompt_ids = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    outputs = model.generate(**prompt_ids)
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens = True)
    answer = generated_text.split("assistant\n")[-1]
    answer_list = to_list_or_empty(answer)
    gen_answers.append(answer_list)

In [206]:
# evaluate the generated answers
results = []

for idx in range(len(test_dataset)):
    ground_truth = test_dataset[idx]["labels"]
    prediction = gen_answers[idx]
    results.append({"labels": ground_truth,
                    "prediction": prediction})

def evaluate_predictions(results):
    y_true = []
    y_pred = []

    for example in results:
        gold_mentions = set([m.lower().strip() for m in example["labels"]])
        pred_mentions = set([m.lower().strip() for m in example["prediction"]])

        for mention in gold_mentions:
            y_true.append(1)
            y_pred.append(1 if mention in pred_mentions else 0)

        for mention in pred_mentions:
            if mention not in gold_mentions:
                y_true.append(0)
                y_pred.append(1)

    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    return precision, recall, f1

p, r, f1 = evaluate_predictions(results)

print(f"Precision: {p:.4f} \nRecall: {r:.4f} \nF1: {f1:.4f}")

Precision: 0.1509 
Recall: 0.4211 
F1: 0.2222


In [204]:
# print some examples
for idx in range(5):
    print(test_dataset[idx]["text"])
    print(results[idx]["labels"])
    print(results[idx]["prediction"])
    print("-"*80)

Looking forward to discussing top issues later today on Politics Live
[]
['Politics Live']
--------------------------------------------------------------------------------
Constituency offices wholly paid for by the taxpayer can't be used for party activity (some split rent/space).
['taxpayer']
['party activity']
--------------------------------------------------------------------------------
11th December

DAY 11 of my Digital Advent Calendar - I’m featuring Depaul UK

They provide accommodation & support for homeless young people aged 16-25, as well as employment workshops to help with CVs and applying for jobs, colleges & university.
['homeless young people']
[]
--------------------------------------------------------------------------------
To mark , I joined the Maternal Mental Health Alliance to highlight the need for better perinatal mental health support.
[]
['Maternal Mental Health Alliance']
--------------------------------------------------------------------------------
(2/2