In [1]:
from stance_gator.sent_module import SentModule
from stance_gator.f1_calc import F1Calc
from stance_gator.constants import TriStance
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

In [2]:
# Make CUDA the default device for tensors created from this point on.
# NOTE(review): the device is hard-coded — presumably this notebook is only
# ever run on a GPU machine; confirm.
torch.set_default_device('cuda')

In [3]:
# Checkpoint to evaluate; the file name records epoch 02 and a validation
# macro-F1 of 0.757.
# NOTE(review): absolute, user-specific path — the notebook will not run
# elsewhere without editing this line.
ckpt_path = "/home/ethanlmines/blue_dir/experiments/lightning_logs/29MaySentMod/checkpoints/epoch=02-val_macro_f1=0.757.ckpt"
# torch.load unpickles the file, so only point this at checkpoints you trust.
ckpt = torch.load(ckpt_path)


In [4]:
# Build the module with its default constructor arguments, then restore the
# weights stored under the checkpoint's 'state_dict' key.
sent_mod = SentModule()
sent_mod.load_state_dict(ckpt['state_dict'])
# Switch to evaluation mode before running predictions.
sent_mod.eval()
# Reuse the tokenizer attached to the module's encoder so preprocessing
# matches what the model expects.
tokenizer = sent_mod.encoder.tokenizer

In [5]:
# Fetch the SetFit/sst5 dataset via Hugging Face `datasets`; the result is
# indexed by split name further down.
hf_dataset = load_dataset("SetFit/sst5")

Repo card metadata block was not found. Setting CardData to empty.


In [6]:
class MyDataset(Dataset):
    """Thin map-style ``Dataset`` wrapper around an in-memory sequence of samples."""

    def __init__(self, samples):
        # Kept by reference (no copy); length and indexing are delegated to it.
        self.samples = samples

    def __len__(self):
        """Return how many samples are wrapped."""
        count = len(self.samples)
        return count

    def __getitem__(self, index):
        """Return the sample stored at ``index``."""
        item = self.samples[index]
        return item

In [7]:
# Collapse SST-5's five sentiment classes onto the three TriStance labels.
label_map = {
    "very positive": TriStance.favor,
    "positive": TriStance.favor,
    "neutral": TriStance.neutral,
    "negative": TriStance.against,
    "very negative": TriStance.against,
}


def _encode_sample(sample):
    """Tokenize one dataset row and package it as a model-ready dict.

    Returns a dict with the tokenizer output under "context", a boolean mask
    that is True on every non-special token under "context_mask", and the
    mapped label tensor under "labels".
    """
    tokens = tokenizer(text=sample['text'], return_special_tokens_mask=True, return_tensors='pt')
    # Every position gets segment id 1.
    tokens['token_type_ids'] = torch.ones_like(tokens['input_ids'])

    # Remove the special-tokens mask from the tokenizer output so it is not
    # passed along with the inputs; its negation marks the real text tokens.
    is_special = tokens.pop('special_tokens_mask')

    return {
        "context": tokens,
        "context_mask": torch.logical_not(is_special),
        "labels": torch.tensor(label_map[sample['label_text']]),
    }


# Encode each row of the 'train' split one at a time (no padding or batching here).
encoded = [_encode_sample(sample) for sample in hf_dataset['train']]
dataset = MyDataset(encoded)

In [8]:
# Metric accumulator built from whatever TriStance.label2id() returns
# (presumably a label-name -> id mapping; confirm against F1Calc).
calc = F1Calc(TriStance.label2id())

In [9]:
# Iterate the encoded samples in their original order, 8 at a time; the
# encoder's own collate function turns each list of samples into a batch.
loader = DataLoader(dataset,
                    batch_size=8,
                    shuffle=False,
                    collate_fn=sent_mod.encoder.collate)

In [10]:
# Score the model on every batch from the loader, then aggregate the metrics.
sent_mod.to("cuda")
# Inference only: turn autograd off so no computation graph is built and
# intermediate activations are not retained. Predictions are unchanged.
with torch.no_grad():
    for batch in loader:
        # Pull the gold labels out of the collated batch; everything left in
        # it is passed to the model as keyword arguments.
        labels = batch.pop('labels')
        preds = sent_mod.predict_sent(**batch)
        calc.record(preds, labels)
# NOTE(review): `calc` is created in an earlier cell; re-running this cell
# without re-creating it presumably records every batch a second time — confirm.
calc.summarize()

In [12]:
# Display the per-class precision/recall/F1 and macro-F1 (presumably filled
# in by calc.summarize(); confirm against F1Calc).
# NOTE(review): the execution count jumps from 10 to 12 here, so a cell that
# is not shown ran in between — confirm that Restart & Run All reproduces
# these numbers.
calc.results

{'neutral_precision': tensor(0.2013, device='cuda:0'),
 'neutral_recall': tensor(0.5776, device='cuda:0'),
 'neutral_f1': tensor(0.2985, device='cuda:0'),
 'against_precision': tensor(0.5433, device='cuda:0'),
 'against_recall': tensor(0.4529, device='cuda:0'),
 'against_f1': tensor(0.4940, device='cuda:0'),
 'favor_precision': tensor(0.8187, device='cuda:0'),
 'favor_recall': tensor(0.2551, device='cuda:0'),
 'favor_f1': tensor(0.3890, device='cuda:0'),
 'macro_f1': tensor(0.3938, device='cuda:0')}