In [1]:
from huggingface_hub import hf_hub_download
import torch
import sys
sys.path.append("src")
from model import CustomModel,MeanPooling
from model import CFG_base as CFG
from transformers import AutoTokenizer, AutoModel, AutoConfig
import numpy as np

In [2]:
# Hugging Face Hub repository that hosts both the pickled model and its tokenizer.
name = "OrK7/parler_hate_speech"
downloaded_model_path = hf_hub_download(repo_id=name, filename="pytorch_model.bin")
# SECURITY NOTE(review): torch.load unpickles the downloaded file, and unpickling can
# execute arbitrary code -- only load checkpoints from a repository you trust.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True, which
# presumably rejects a fully pickled module like this one; confirm whether
# weights_only=False is needed for the installed version.
# map_location="cpu" lets a checkpoint that was saved from a GPU deserialize on a
# CPU-only machine and keeps the model on the same device as the pipeline inputs.
model = torch.load(downloaded_model_path, map_location="cpu")
# Put the module in inference mode so layers such as dropout behave deterministically.
model.eval()
tokenizer = AutoTokenizer.from_pretrained(name)

In [3]:
def prepare_input(text, max_length=512):
    """Tokenize one text into single-row ``torch.long`` tensors.

    Args:
        text: Input handed to the module-level ``tokenizer``.
        max_length: Length the encoding is truncated and padded to. Defaults to
            512, the value that used to be hard-coded.

    Returns:
        The tokenizer's encoding, with every field converted to a
        ``torch.long`` tensor reshaped to a single row.
    """
    inputs = tokenizer.encode_plus(
        text,
        return_tensors=None,
        add_special_tokens=True,
        max_length=max_length,
        # `padding="max_length"` is the supported replacement for the deprecated
        # `pad_to_max_length=True` flag.
        padding="max_length",
        truncation=True,
    )
    # The tokenizer returns plain Python lists here (return_tensors=None); turn each
    # field into a (1, N) tensor so it can be fed to the model as a batch of one.
    for k, v in inputs.items():
        inputs[k] = torch.tensor(np.array(v).reshape(1, -1), dtype=torch.long)
    return inputs

def collate(inputs):
    """Trim every tensor in ``inputs`` to the longest attended length.

    The cut-off is the largest per-row sum of ``inputs["attention_mask"]``;
    columns past that index are dropped from every entry. ``inputs`` is
    updated in place and the same object is returned.
    """
    longest = int(inputs["attention_mask"].sum(axis=1).max())
    for key in list(inputs.keys()):
        inputs[key] = inputs[key][:, :longest]
    return inputs

In [4]:
from transformers import Pipeline
class HatePipeline(Pipeline):
    """Pipeline that turns one text into a score between 1 and 5.

    Tokenization is delegated to the module-level ``prepare_input`` and
    ``collate`` helpers; the wrapped model is called with the resulting dict.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) dicts.

        Only ``maybe_arg`` is recognised, and it is routed to ``preprocess``.
        """
        preprocess_kwargs = {}
        if "maybe_arg" in kwargs:
            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, maybe_arg=None):
        """Tokenize ``inputs`` and trim the padding.

        ``maybe_arg`` is accepted because ``_sanitize_parameters`` forwards it
        here; without this parameter, passing it to the pipeline raised a
        ``TypeError``. It is currently unused.
        """
        out = prepare_input(inputs)
        return collate(out)

    def _forward(self, model_inputs):
        """Call the model with the whole input dict as one positional argument."""
        outputs = self.model(model_inputs)
        return outputs

    def postprocess(self, model_outputs):
        """Clip the value at ``[0, 0]`` to [0, 1] and rescale it to [1, 5]."""
        # clip(0, 1) bounds the raw value; * 4 + 1 maps 0 -> 1 and 1 -> 5.
        return np.array(model_outputs[0,0].numpy()).clip(0,1)*4+1

In [5]:
# Wrap the loaded model in the custom pipeline; tokenization uses the
# module-level `tokenizer` through `prepare_input`.
pipe = HatePipeline(model=model)
# Score a sample containing abusive wording mixed with symbol noise.
pipe("I Hate #$%#$%Jewish%$#@%^^@#")



4.155198812484741

In [6]:
# Benign control sentence scored with the same pipeline.
pipe("I Love you #")

1.0

In [7]:
# Sample with insulting wording.
# NOTE(review): "itâs" looks like a mis-decoded "it's"; presumably kept as-is
# to mirror the raw input text -- confirm before normalizing.
pipe("It should be easy to do............itâs just a big mouth and two tits!!!! Wendy Williams is getting a Madame Tussauds wax figure")

3.473475933074951

In [9]:
# News-headline style text that mentions minority groups without abusive wording.
pipe("Report: Google A.I. Is ‘Biased’ Against Gay People, Jews  http://www.breitbart.com/tech/2017/10/26/report-go...")

1.5166422724723816

In [10]:
# Same headline tail as the previous example, but with an abusive prefix.
pipe("fuck Gay People, Jews  http://www.breitbart.com/tech/2017/10/26/report-go...")

4.308754205703735