In [1]:
import numpy as np
import torch
from lightning.fabric.utilities.apply_func import move_data_to_device

from stance_gator.constants import TriStance
from stance_gator.data_modules import StanceCorpus, StanceDataModule
from stance_gator.sent_module import SentModule
from stance_gator.torch_utils import load_module

In [2]:
import html
from IPython.display import HTML, display

In [3]:
# Checkpoint whose predictions are inspected in this notebook.
# The alternative run below is kept commented out for quick switching:
# ckpt_path = '/home/ethanlmines/blue_dir/experiments/lightning_logs/30MayNesyStance/checkpoints/epoch=03-val_macro_f1=0.761.ckpt'
ckpt_path = (
    '/home/ethanlmines/blue_dir/experiments/lightning_logs/31MayQuarterTemp/checkpoints/epoch=03-val_macro_f1=0.757.ckpt'
)

# load_module is a project helper; presumably it restores the module from the
# checkpoint file -- confirm against stance_gator.torch_utils.
sent_mod: SentModule = load_module(ckpt_path)

In [4]:
# Single corpus backed by the VAST zero-shot dev CSV.
# NOTE(review): data_ratio=(0, 0, 1) presumably routes every row to the last
# (test) split -- confirm against StanceCorpus.
vast_dev_corpus = StanceCorpus(
    path="/home/ethanlmines/blue_dir/datasets/VAST/vast_zero_dev.csv",
    corpus_type='vast',
    data_ratio=(0, 0, 1),
)

data_mod = StanceDataModule([vast_dev_corpus])

# Reuse the loaded model's encoder so the data module and the model share it.
data_mod.encoder = sent_mod.encoder

# NOTE(review): setup is called with 'predict' while the explanation cell
# iterates data_mod.test_dataloader(); confirm this stage populates that split.
data_mod.setup('predict')

Parsing /home/ethanlmines/blue_dir/datasets/VAST/vast_zero_dev.csv: 1019it [00:00, 2022.42it/s]


In [8]:
# Call eval() on the loaded module, then move it to the GPU.
sent_mod.eval()
sent_mod.to('cuda')

# Tokenizer used by the helpers below to turn ids back into readable text.
tokenizer = sent_mod.encoder.tokenizer

# Not referenced by any visible cell; kept in case other cells rely on it.
i = 0

def select_from_mask(tensor, mask):
    """Return a list of the items in ``tensor`` whose paired ``mask`` entry is truthy.

    The two iterables are walked in lockstep with ``zip``, so iteration stops
    at the shorter of the two.
    """
    kept = []
    for element, keep in zip(tensor, mask):
        if keep:
            kept.append(element)
    return kept

def ids_to_html_str(ids):
    """Decode token ids to a string that is safe to embed in HTML.

    Uses the module-level ``tokenizer``: ids are converted to tokens with
    special tokens skipped, the tokens are joined into a string, and the
    result is passed through ``html.escape``.
    """
    tokens = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)
    text = tokenizer.convert_tokens_to_string(tokens)
    return html.escape(text)

# Number of highest-contribution tokens listed in each example's table.
TOP_K_TOKENS = 10

# Set once the user answers 'q' at the prompt. Initialised up front so the
# outer loop's check is well-defined even if a batch yields no examples, and
# named so it does not shadow the builtin/IPython `exit`.
should_quit = False

for batch in data_mod.test_dataloader():
    batch = move_data_to_device(batch, sent_mod.device)
    # 'labels' is taken out of the batch because every remaining entry is
    # forwarded to the model as a keyword argument.
    labels = batch.pop('labels')

    # Inspection only: no gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        output = sent_mod(**batch)

    # One tuple per example in the batch.
    iterator = zip(
        batch['context_mask'].detach().tolist(),
        batch['context']['input_ids'].detach().tolist(),
        batch['target']['input_ids'].detach().tolist(),
        output.stance_prob.detach().cpu().numpy(),
        output.summands.detach().cpu().numpy(),
        output.attention.detach().tolist(),
        output.token_sents.detach().tolist(),
    )
    # NOTE(review): `_mask` is unpacked but not used, so the ranking below runs
    # over every position in `id_list` -- presumably including padding and
    # special tokens. Confirm whether those should be filtered out.
    for _mask, id_list, target_id_list, stance_prob_dist, summand_list, attention_vec, token_sent in iterator:

        context_str = ids_to_html_str(id_list)
        target_str = ids_to_html_str(target_id_list)

        # Predicted class and its probability.
        prediction = int(np.argmax(stance_prob_dist))
        prediction_prob = stance_prob_dist[prediction]

        # Keep only the predicted class's column, then rank positions from
        # largest to smallest contribution.
        summand_list = summand_list[:, prediction]
        top_tok_indices = np.flip(np.argsort(summand_list))

        # Unlike the strings above, special tokens are kept here so that list
        # positions stay aligned with the per-position model outputs.
        token_list = tokenizer.convert_ids_to_tokens(id_list)

        class_str = TriStance(prediction).name
        html_toks = []
        html_toks.append(f'<p> <strong>Target</strong>: {target_str} </p>')
        html_toks.append(f'<p> <strong>Document</strong>: {context_str} </p>')
        html_toks.append(f'<p> <strong>Prediction</strong>: P({class_str}) = {prediction_prob:.6f} </p>')

        html_toks.append("<table>")
        html_toks.append(f'<thead> <tr> <th>Token</th> <th>Relevance to Target (Attention)</th> <th>P({class_str}|Token)</th> <th>Total Contribution</th> </tr> </thead>')
        html_toks.append("<tbody>")

        # Slicing already copes with sequences shorter than TOP_K_TOKENS.
        for ind in top_tok_indices[:TOP_K_TOKENS]:
            html_toks.append(
                '<tr>'
                # Token text is escaped so '<', '>' or '&' cannot break the table.
                f'<td>{html.escape(token_list[ind])}</td>'
                f'<td>{attention_vec[ind]:.6f}</td>'
                f'<td>{token_sent[ind][prediction]:.6f}</td>'
                f'<td>{summand_list[ind]:.6f}</td>'
                '</tr>'
            )
        html_toks.append("</tbody>")
        html_toks.append("</table>")
        html_str = "".join(html_toks)

        display(HTML(html_str))
        should_quit = input("Press enter to continue (q to quit)").strip().lower() == 'q'
        if should_quit:
            break
    if should_quit:
        break

Token,Relevance to Target (Attention),P(favor|Token),Total Contribution
special,0.021876,0.838442,0.018342
effects,0.022261,0.702112,0.01563
c,0.018786,0.80357,0.015096
c,0.014697,0.93669,0.013766
c,0.014501,0.863921,0.012528
themselves,0.022736,0.511724,0.011634
##gi,0.013411,0.834229,0.011188
hyper,0.020384,0.542922,0.011067
-,0.016272,0.603252,0.009816
enhanced,0.025046,0.38491,0.00964


Token,Relevance to Target (Attention),P(favor|Token),Total Contribution
fit,0.046822,0.911162,0.042663
see,0.036889,0.915169,0.03376
really,0.034388,0.905393,0.031134
i,0.033429,0.909294,0.030397
me,0.034077,0.843914,0.028758
what,0.034463,0.802786,0.027666
##cta,0.03658,0.745649,0.027276
i,0.02795,0.906735,0.025343
as,0.026899,0.919415,0.024731
where,0.026483,0.921976,0.024417


Token,Relevance to Target (Attention),P(favor|Token),Total Contribution
usual,0.110427,0.906712,0.100125
reservation,0.095905,0.826215,0.079239
spots,0.070051,0.766134,0.053669
own,0.051695,0.813491,0.042053
i,0.034968,0.971382,0.033967
go,0.028963,0.940204,0.027232
cook,0.027518,0.909484,0.025028
my,0.026555,0.935288,0.024837
bring,0.025789,0.863347,0.022265
my,0.02272,0.917172,0.020838


Token,Relevance to Target (Attention),P(favor|Token),Total Contribution
professors,0.051495,0.793829,0.040878
professors,0.041608,0.870402,0.036216
has,0.043564,0.745276,0.032467
work,0.029615,0.962042,0.028491
##p,0.027598,0.941754,0.02599
students,0.030444,0.835242,0.025428
it,0.027936,0.897242,0.025066
the,0.025728,0.917597,0.023608
class,0.026872,0.78415,0.021072
semester,0.050516,0.400243,0.020219


Token,Relevance to Target (Attention),P(favor|Token),Total Contribution
organic,0.176518,0.676981,0.119499
grown,0.143932,0.805783,0.115978
organic,0.130318,0.750914,0.097858
natural,0.12302,0.718798,0.088427
it,0.090176,0.93866,0.084645
cultivated,0.102926,0.754031,0.077609
natural,0.030612,0.801144,0.024525
food,0.03995,0.445075,0.017781
fda,0.017731,0.782079,0.013867
them,0.008217,0.968236,0.007956


Token,Relevance to Target (Attention),P(neutral|Token),Total Contribution
be,0.039267,0.999426,0.039245
.,0.039106,0.999435,0.039084
.,0.039,0.999446,0.038978
the,0.03878,0.999471,0.038759
to,0.038563,0.999474,0.038542
an,0.038519,0.999488,0.038499
to,0.038513,0.999492,0.038493
",",0.038468,0.999484,0.038448
and,0.038435,0.99946,0.038414
with,0.038382,0.999476,0.038362


Token,Relevance to Target (Attention),P(against|Token),Total Contribution
fascism,0.10639,0.980542,0.104319
christians,0.109972,0.900296,0.099008
religious,0.104391,0.82426,0.086045
muslims,0.069402,0.943662,0.065493
do,0.062708,0.94339,0.059158
freedom,0.063281,0.833488,0.052744
##te,0.059724,0.87581,0.052307
el,0.060665,0.838385,0.050861
##eva,0.055189,0.800674,0.044188
law,0.048828,0.766025,0.037404


In [11]:
# Commented-out sample markup, kept for experimenting with the table layout:
# html_str = '<table><tbody> <tr> <td colspan="2">Doof</td> </tr> <tr> <td>Hi</td> <td>Yo</td> <td>Wassup</td> </tbody></table>'
# Re-render the most recent table. `html_str` is whatever the explanation loop
# cell last assigned, so this cell only works after that cell has been run.
display(HTML(html_str))

Token,Att,Sent,Contribution
modest,0.0736158862709999,0.7244386076927185,0.0533301904797554
female,0.045711301267147,0.7637415528297424,0.0349116213619709
##ital,0.0442905910313129,0.7480758428573608,0.0331327207386493
male,0.04917573928833,0.5769947171211243,0.0283741410821676
sexual,0.0362986996769905,0.734909176826477,0.0266762468963861
sexuality,0.0341798439621925,0.7561362385749817,0.0258446186780929
women,0.0343526490032672,0.7336478233337402,0.0252027455717325
women,0.029122395440936,0.6735926866531372,0.0196166317909955
family,0.0314452424645423,0.5897860527038574,0.0185459647327661
male,0.0345892906188964,0.5014801621437073,0.0173458438366651


In [9]:
# Sanity check on the top row ("modest") of the table displayed above:
# rounded attention (0.0736) times rounded sentiment (0.724) should land
# close to the reported contribution (~0.0533).
0.0736 * 0.724

0.0532864