In [13]:
from transformers import pipeline
import torch


In [14]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("notaphoenix/shakespeare_classifier_model")

model = AutoModelForSequenceClassification.from_pretrained("notaphoenix/shakespeare_classifier_model")

In [21]:
def predict(text):
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_id = logits.argmax().item()
    label = model.config.id2label[predicted_class_id]
    return (predicted_class_id, label) 

In [22]:
import evaluate
from datasets import load_dataset


In [56]:

clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])
f1_metric = evaluate.load("f1")
accuracy = evaluate.load("accuracy")
precision = evaluate.load("precision")

In [33]:
def get_pred_gold(split="test"):
    ds = load_dataset("notaphoenix/shakespeare_dataset", split=split)
    predicted = [predict(x['text'])[0] for x in ds]
    gold = [x['label'] for x in ds]
    return predicted, gold

In [61]:
predicted, gold= get_pred_gold(split="validation")

print(
    f"{clf_metrics.compute(predicted, gold, average='macro')}\n"
    f"macro-f1: {round(f1_metric.compute(predictions=predicted, references=gold, average='macro')['f1'], 2)}\n"
    f"{precision.compute(predictions=predicted, references=gold, average='macro')}\n"
    f"{accuracy.compute(predictions=predicted, references=gold)}\n"
)

Using custom data configuration notaphoenix--shakespeare_dataset-7d26b19ec4f377f7
Found cached dataset parquet (/home/elba_ro/.cache/huggingface/datasets/notaphoenix___parquet/notaphoenix--shakespeare_dataset-7d26b19ec4f377f7/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


{'accuracy': 0.8391608391608392, 'f1': 0.8270676691729323, 'precision': 0.8823529411764706, 'recall': 0.7783018867924528}
macro-f1: 0.84
{'precision': 0.8440690325717064}
{'accuracy': 0.8391608391608392}



In [62]:
predicted, gold= get_pred_gold(split="test")

print(
    f"{clf_metrics.compute(predicted, gold, average='macro')}\n"
    f"macro-f1: {round(f1_metric.compute(predictions=predicted, references=gold, average='macro')['f1'], 2)}\n"
    f"{precision.compute(predictions=predicted, references=gold, average='macro')}\n"
    f"{accuracy.compute(predictions=predicted, references=gold)}\n"
)

Using custom data configuration notaphoenix--shakespeare_dataset-7d26b19ec4f377f7
Found cached dataset parquet (/home/elba_ro/.cache/huggingface/datasets/notaphoenix___parquet/notaphoenix--shakespeare_dataset-7d26b19ec4f377f7/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


{'accuracy': 0.8666044776119403, 'f1': 0.8596663395485771, 'precision': 0.906832298136646, 'recall': 0.8171641791044776}
macro-f1: 0.87
{'precision': 0.8702242984740955}
{'accuracy': 0.8666044776119403}

