In [16]:
from stance_gator.data_modules import StanceCorpus, StanceDataModule
from stance_gator.sent_module import SentModule
from stance_gator.torch_utils import load_module
from stance_gator.constants import TriStance
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from lightning.fabric.utilities.apply_func import move_data_to_device
import numpy as np

In [15]:
import html
from IPython.display import HTML, display

In [3]:
# Model checkpoint to inspect; edit this path to load a different run.
CHECKPOINT_PATH = '/home/ethanlmines/blue_dir/experiments/lightning_logs/30MayNesyStance/checkpoints/epoch=03-val_macro_f1=0.761.ckpt'
sent_mod: SentModule = load_module(CHECKPOINT_PATH)

In [4]:
# Single-corpus data module over the VAST zero-shot dev CSV.
# NOTE(review): data_ratio=(0, 0, 1) presumably routes every example to the
# test split (a later cell iterates test_dataloader()) — confirm against StanceCorpus.
vast_dev_corpus = StanceCorpus(
    path="/home/ethanlmines/blue_dir/datasets/VAST/vast_zero_dev.csv",
    corpus_type='vast',
    data_ratio=(0, 0, 1),
)
data_mod = StanceDataModule([vast_dev_corpus])
# Reuse the model's encoder, presumably so inputs are encoded the way the model expects.
data_mod.encoder = sent_mod.encoder
data_mod.setup('predict')

Parsing /home/ethanlmines/blue_dir/datasets/VAST/vast_zero_dev.csv: 0it [00:00, ?it/s]

Parsing /home/ethanlmines/blue_dir/datasets/VAST/vast_zero_dev.csv: 1019it [00:00, 2043.76it/s]


In [23]:
# Put the model in eval mode on the GPU when one is available, falling back to
# CPU so this cell does not crash on machines without CUDA. Later code moves
# batches to `sent_mod.device`, so it follows whichever device is chosen here.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
sent_mod.eval().to(DEVICE)
# Tokenizer used to turn token ids back into text for the HTML report.
tokenizer = sent_mod.encoder.tokenizer

def select_from_mask(tensor, mask):
    """Return a list of the elements of `tensor` whose paired `mask` entry is truthy.

    Elements and mask entries are paired positionally; pairing stops at the
    shorter of the two inputs.
    """
    kept = []
    for element, keep in zip(tensor, mask):
        if keep:
            kept.append(element)
    return kept

def ids_to_html_str(ids):
    """Decode token ids into text, dropping special tokens, and HTML-escape the result.

    Relies on the module-level `tokenizer`.
    """
    tokens = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)
    text = tokenizer.convert_tokens_to_string(tokens)
    return html.escape(text)

# Build an HTML report explaining one prediction from the first test batch:
# the document, the target, the predicted stance with its probability, and the
# TOP_K unmasked context tokens whose summands contribute most to that stance.
# Leaves `batch`, `labels`, `output` and `html_str` defined for later cells.
TOP_K = 10         # number of highest-contribution context tokens to list
EXAMPLE_INDEX = 0  # which example of the first test batch to explain


def render_explanation_html(context_str, target_str, stance_name, stance_prob, tokens, contributions):
    """Return the HTML report for one example.

    `context_str` and `target_str` are inserted verbatim, so they must already be
    HTML-escaped (ids_to_html_str does this). `tokens` are raw tokenizer tokens and
    are escaped here, because characters such as '<' or '&' would otherwise break
    the table markup.
    """
    rows = "".join(
        f'<tr> <td>{html.escape(str(tok))}</td> <td>{mass}</td> </tr>'
        for tok, mass in zip(tokens, contributions)
    )
    return "".join([
        f'<p> <strong>Document</strong>: {context_str} </p>',
        f'<p> <strong>Target</strong>: {target_str} </p>',
        f'<p> <strong>Prediction</strong>: P({stance_name}) = {stance_prob} </p>',
        "<table>",
        '<thead> <tr> <th>Token</th> <th>Contribution</th> </tr> </thead>',
        "<tbody>",
        rows,
        "</tbody>",
        "</table>",
    ])


batch = next(iter(data_mod.test_dataloader()))
batch = move_data_to_device(batch, sent_mod.device)
labels = batch.pop('labels')  # popped so it is not passed to sent_mod(**batch)

with torch.no_grad():  # inference only: don't build an autograd graph
    output = sent_mod(**batch)

stance_prob_dist = output.stance_prob[EXAMPLE_INDEX].detach().cpu().numpy()
summands = output.summands[EXAMPLE_INDEX].detach().cpu().numpy()
context_ids = batch['context']['input_ids'][EXAMPLE_INDEX].tolist()
target_ids = batch['target']['input_ids'][EXAMPLE_INDEX].tolist()
# Force a boolean mask: NumPy treats a list of 0/1 ints as *integer* indices,
# not as a mask, which would silently select the wrong rows.
keep = batch['context_mask'][EXAMPLE_INDEX].detach().cpu().numpy().astype(bool)

context_tokens = select_from_mask(tokenizer.convert_ids_to_tokens(context_ids), keep)
prediction = int(np.argmax(stance_prob_dist))
# Per-token summands for the predicted stance, restricted to unmasked context tokens.
contributions = summands[keep][:, prediction]
assert len(contributions) == len(context_tokens), "context mask and summands disagree in length"
top_indices = np.flip(np.argsort(contributions))[:TOP_K]  # descending by contribution

html_str = render_explanation_html(
    context_str=ids_to_html_str(context_ids),
    target_str=ids_to_html_str(target_ids),
    stance_name=TriStance(prediction).name,
    stance_prob=stance_prob_dist[prediction],
    tokens=[context_tokens[ind] for ind in top_indices],
    contributions=[contributions[ind] for ind in top_indices],
)

In [24]:
# Render the explanation report assembled in the previous cell.
report = HTML(html_str)
display(report)

Token,Contribution
guy,0.0263575501739978
house,0.0167811382561922
70,0.0149161564186215
hardwood,0.013802234083414
vintage,0.0129745919257402
type,0.0129716517403721
brilliant,0.0129654742777347
a,0.0121612772345542
he,0.0119147608056664
another,0.0117258103564381
