In [1]:
import json
import random
import time
from pathlib import Path
import numpy as np
from collections import defaultdict, Counter
from pprint import pprint
import tqdm
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer, util
import torch


def read_event_dataset(path, encoding="utf-8"):
    """Load a tab-separated event dataset into a list of records.

    The first line of the file is treated as a header and skipped; every other
    non-blank line must hold exactly three tab-separated fields:
    ``id<TAB>text<TAB>label``.

    Args:
        path: Location of the ``.tsv`` file (``str`` or ``pathlib.Path``).
        encoding: Text encoding used to read the file. Defaults to UTF-8 so the
            result does not depend on the platform's locale encoding.

    Returns:
        A list of dicts with string values under the keys ``"id"``, ``"text"``
        and ``"label"``, in file order.

    Raises:
        ValueError: If a non-blank data line does not split into exactly three
            tab-separated fields.
    """
    dataset = []
    with open(path, encoding=encoding) as f:
        next(f, None)  # skip the header row (no-op for an empty file)
        for line_no, line in enumerate(f, start=2):
            stripped = line.strip()
            if not stripped:
                # Tolerate blank lines, e.g. an empty line at the end of the file.
                continue
            fields = stripped.split("\t")
            if len(fields) != 3:
                raise ValueError(
                    f"{path}:{line_no}: expected 3 tab-separated fields, got {len(fields)}"
                )
            example_id, text, label = fields
            dataset.append({"id": example_id, "text": text, "label": label})
    return dataset

In [2]:
from sentence_transformers import SentenceTransformer, util


class ZeroShotClassifier:
    """Zero-shot text classifier based on embedding similarity.

    Each label is represented by the embedding of a natural-language
    description. An input gets the label whose description embedding has the
    highest cosine similarity with the input's embedding. If that best score is
    below ``threshold``, ``null_label`` is returned instead.
    """

    def __init__(self, model=None, threshold=0.0, null_label="OTHER"):
        """
        Args:
            model: Encoder exposing ``encode(texts)`` (e.g. a
                ``SentenceTransformer``). Required by ``train`` and by
                ``predict`` when raw texts are passed.
            threshold: Minimum cosine similarity for accepting the best label.
            null_label: Label emitted when the best score is below ``threshold``.
        """
        self.model = model
        self.labels = []
        self.label_embeddings = None
        self.threshold = threshold
        self.null_label = null_label

    def _require_model(self):
        # Fail with a clear message instead of an AttributeError on `None.encode`.
        if self.model is None:
            raise RuntimeError("ZeroShotClassifier needs a `model` to encode texts")

    def train(self, labels, descriptions):
        """Embed one description per label and remember the label order.

        Args:
            labels: Label names, in the order matching ``descriptions``.
            descriptions: One text per label; its embedding is that label's prototype.

        Raises:
            ValueError: If there are no labels, or the two sequences differ in length.
            RuntimeError: If the classifier was built without a model.
        """
        labels = list(labels)
        descriptions = list(descriptions)
        if not labels:
            raise ValueError("at least one label is required")
        if len(labels) != len(descriptions):
            raise ValueError(
                f"got {len(labels)} labels but {len(descriptions)} descriptions"
            )
        self._require_model()
        self.labels = labels
        # Use this instance's encoder, not whatever global `model` happens to exist.
        self.label_embeddings = self.model.encode(descriptions)

    def predict(self, input_texts=None, input_embeddings=None, output_scores=False):
        """Predict one label per input.

        Args:
            input_texts: Texts to encode with ``self.model``. Ignored when
                ``input_embeddings`` is given.
            input_embeddings: Precomputed input embeddings (one row per input),
                e.g. to avoid re-encoding a large dataset.
            output_scores: If True, also return, per input, every
                ``(label, score)`` pair sorted by descending score.

        Returns:
            The list of predicted labels, or ``(labels, scores)`` when
            ``output_scores`` is True.

        Raises:
            RuntimeError: If called before ``train`` or without a model for raw texts.
            ValueError: If neither ``input_texts`` nor ``input_embeddings`` is given.
        """
        if self.label_embeddings is None:
            raise RuntimeError("call train() before predict()")
        if input_embeddings is None:
            if input_texts is None:
                raise ValueError("pass either input_texts or input_embeddings")
            self._require_model()
            input_embeddings = self.model.encode(input_texts)

        # Rows are inputs, columns follow the order of self.labels.
        similarity = util.pytorch_cos_sim(input_embeddings, self.label_embeddings)

        predicted_labels = []
        predicted_scores = []
        # Iterate the similarity rows rather than input_embeddings.shape[0]. For a
        # single 1-D embedding, shape[0] is the embedding size, not the number of
        # inputs. NOTE(review): relies on pytorch_cos_sim promoting 1-D input to one row.
        for row in similarity:
            scored = sorted(
                zip(self.labels, row.tolist()),
                key=lambda pair: pair[1],
                reverse=True,
            )
            pred, score = scored[0]
            if score < self.threshold:
                pred = self.null_label
            predicted_scores.append(scored)
            predicted_labels.append(pred)

        if output_scores:
            return predicted_labels, predicted_scores
        return predicted_labels

In [3]:
# Folder, relative to the notebook's working directory, holding the test set
# and the label-name mapping.
DATA_SUBDIR = "data"
DIR = Path(DATA_SUBDIR)

In [4]:
dataset = read_event_dataset(DIR / "test_set_final_release_with_labels.tsv")

# Split the records into parallel lists: model inputs and gold labels.
texts, y_true = [], []
for record in dataset:
    texts.append(record["text"])
    y_true.append(record["label"])

In [5]:
# Label -> human-readable name mapping. The names are the descriptions the
# zero-shot classifier embeds as label prototypes. JSON text is UTF-8, so read
# it as such instead of relying on the locale's default encoding.
with open(DIR / "acled_label_to_name.json", encoding="utf-8") as f:
    label_to_text = json.load(f)

# Sort the labels so their order (and hence the prediction tie-breaking order)
# is deterministic.
label_names = sorted(label_to_text)
label_texts = [label_to_text[name] for name in label_names]

In [6]:
# NOTE(review): not referenced by any visible cell. Presumably these are the
# labels meant to be treated as unseen (zero-shot); confirm or remove.
ZS_LABELS = [
    "ORG_CRIME",
    "NATURAL_DISASTER",
    "MAN_MADE_DISASTER",
    "DIPLO",
    "ATTRIB",
]

In [7]:
# Sentence encoder shared by every classifier in this notebook, pinned to CPU.
model = SentenceTransformer(
    model_name_or_path="paraphrase-mpnet-base-v2",
    device="cpu",
)

In [8]:
# Classifier over the full label set. Threshold 0.0 and null label "OTHER" are
# spelled out explicitly (they match the constructor defaults).
zs_classifier = ZeroShotClassifier(model=model, threshold=0.0, null_label="OTHER")
zs_classifier.train(descriptions=label_texts, labels=label_names)

In [9]:
# Encoding the whole test set is the slowest step (a previous run of this cell
# was interrupted), so cache the embeddings and reuse them on re-runs.
# The cache name embeds the model name; change it if the encoder above changes.
EMBEDDINGS_CACHE = DIR / "test_set_embeddings_paraphrase-mpnet-base-v2.npy"

input_embeddings = None
if EMBEDDINGS_CACHE.exists():
    cached_embeddings = np.load(EMBEDDINGS_CACHE)
    # Cheap staleness check: discard the cache if the dataset size changed.
    if cached_embeddings.shape[0] == len(texts):
        input_embeddings = cached_embeddings

if input_embeddings is None:
    input_embeddings = model.encode(texts, show_progress_bar=True)
    np.save(EMBEDDINGS_CACHE, input_embeddings)

KeyboardInterrupt: 

In [None]:
# Invariant: exactly one embedding row per test-set text.
assert input_embeddings.shape[0] == len(texts), "expected one embedding per input text"
input_embeddings.shape

In [27]:
# Reuse the precomputed embeddings; predict(input_texts=texts) would re-encode
# the whole test set.
predicted_labels = zs_classifier.predict(input_embeddings=input_embeddings)

In [32]:
def evaluate(y_true, y_pred, averages=("micro", "macro", "weighted")):
    """Print precision, recall and F1 for each averaging strategy.

    Args:
        y_true: Gold labels.
        y_pred: Predicted labels, aligned with ``y_true``.
        averages: ``average`` modes passed to scikit-learn, one printed line each.
    """
    for average in averages:
        # zero_division=0: labels never predicted score 0 instead of warning.
        precision, recall, fscore, _ = precision_recall_fscore_support(
            y_true, y_pred, average=average, zero_division=0
        )
        print(
            f"{average:<9}precision: {precision:.3f}, "
            f"recall: {recall:.3f}, f-score: {fscore:.3f}"
        )


evaluate(y_true, predicted_labels)

micro    precision: 0.520, recall: 0.520, f-score: 0.520
macro    precision: 0.528, recall: 0.495, f-score: 0.461
weighted precision: 0.569, recall: 0.520, f-score: 0.489


## Build Your Own Zero-Shot Classifier

In [10]:
# Custom label set: each label is described by a single keyword.
custom_label_descriptions = {
    "EARTHQUAKE": "earthquake",
    "WILDFIRE": "wildfire",
    "FLOODS": "floods",
}
my_classifier = ZeroShotClassifier(model=model, threshold=0.3, null_label="OTHER")
my_classifier.train(
    labels=list(custom_label_descriptions),
    descriptions=list(custom_label_descriptions.values()),
)

In [11]:
# Headlines covering the three custom labels plus two unrelated stories, which
# should fall below the 0.3 threshold and come back as "OTHER".
sample_headlines = [
    "Death toll from Hurricane Ida floods rises to 65 in US",
    "As California burns, some ecologists say it’s time to rethink forest management",
    "Maharashtra: Tremor in Kolhapur, no casualty",
    "Leaked Guntrader firearms data file shared. Worst case scenario?",
    "Taliban take control of last holdout in Panjshir Valley",
]
my_classifier.predict(input_texts=sample_headlines)

['FLOODS', 'WILDFIRE', 'EARTHQUAKE', 'OTHER', 'OTHER']