In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, PretrainedConfig

import pandas as pd
import torch
import torch.nn.functional as F

In [None]:
# Raw WNDP export: each record has free-text notes ("text") and annotated terms ("terms").
RAW_DATA_PATH = "../data/raw/wndp_api.csv"
data = pd.read_csv(RAW_DATA_PATH)

In [None]:
# Single source of truth for which fine-tuned checkpoint this notebook evaluates.
CHECKPOINT_DIR = "wndp-exp/checkpoint-1500/"

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
model = AutoModelForSequenceClassification.from_pretrained(CHECKPOINT_DIR)
# Reuse the config the model was actually built from (it carries id2label),
# instead of re-reading config.json through the generic PretrainedConfig base class.
config = model.config

In [None]:
# Hand-written example note used for a quick sanity check of the classifier.
sample = (
    "found on the ground by window - breathing hard, eyes not open, "
    "couldn't stand up, ants covering him, some spazmotic movements of leg, "
    "wing, seemed better today. emaciated fledgling with torticollis. "
    "Neurologic: torticollis Legs / Feet / Hocks: not using legs. "
    "poor prognosis given age, emaciation, and degree of debilitation"
)

In [None]:
# Prefer the GPU when present, but fall back to CPU so the notebook still
# runs top-to-bottom on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# .eval() makes sure dropout etc. are disabled even if the model was toggled
# into training mode earlier in the kernel session; both calls return the module.
model = model.to(device).eval()

In [None]:
def _pre(text, device, max_length=None):
    """Tokenize ``text`` with the notebook's ``tokenizer`` and move the tensors to ``device``.

    Args:
        text: input string to classify.
        device: torch device the model lives on; every returned tensor is moved there.
        max_length: optional explicit truncation length. When None, the tokenizer
            truncates to its own configured maximum (if it has one).

    Returns:
        dict mapping tokenizer output names to tensors on ``device``.
    """
    # Truncate so long notes do not exceed the model's maximum sequence length,
    # which would otherwise typically fail inside the model's position embeddings.
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
    return {name: tensor.to(device) for name, tensor in encoded.items()}

def _inf(model, tokens):
    out = model(**tokens)
    return out

def _post(out):
    probs = F.sigmoid(out.logits.squeeze().detach().cpu())
    preds = (probs > 0.5).int()
    labels = [config.id2label[idx] for idx, label in enumerate(preds) if label == 1.0]
    return labels

def infer(model, text):
    """Predict the label names for one piece of ``text`` with ``model``.

    Chains the three stages: tokenize onto the model's device (``_pre``),
    forward pass (``_inf``), then decode thresholded scores into labels (``_post``).
    """
    return _post(_inf(model, _pre(text, model.device)))

In [None]:
# Quick sanity check on the hand-written example note.
infer(model, text=sample)

In [None]:
# Eyeball predictions against the annotated terms for the first 20 records.
preview = data.head(20)
for note, terms in zip(preview["text"], preview["terms"]):
    print(note)
    print("actuals: ", terms)
    print("predics: ", infer(model, note))
    print("=" * 20)