In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# When the notebook is run from its place inside the truera repository, put the
# repository root at the front of the module search path so packages found there
# (presumably trulens) take precedence over installed copies.
sys.path[0:0] = ["../.."]

# Install transformers / huggingface.
# ! {sys.executable} -m pip install torch
# ! {sys.executable} -m pip install transformers
# ! {sys.executable} -m pip install mkl
# ! {sys.executable} -m pip install vision

# Or otherwise install trulens.
# TODO: UPDATE TO GET THE CORRECT TRULENS:
# ! {sys.executable} -m pip install git+https://github.com/truera/trulens.git
# ! {sys.executable} -m pip uninstall trulens -y

from IPython.display import display, clear_output
import torch
import numpy as np
from pathlib import Path
import base64
from tqdm.auto import tqdm
import re
from typing import List, Tuple, Dict
import functools
from ipywidgets import widgets, interactive, interact
import os
import multiprocessing as mp
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go

# Disable HuggingFace tokenizers' internal parallelism: a multiprocessing pool is
# forked later in this notebook, and the tokenizers library warns about forking
# after its own parallelism has been used.
os.environ.update(TOKENIZERS_PARALLELISM='0')

# Twitter Sentiment Model

[Hugging Face](https://huggingface.co/models) offers a variety of pre-trained NLP models to explore. This notebook uses a [transformer-based sentiment classification model](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english) (DistilBERT fine-tuned on SST-2); the code notes how to switch to the [twitter sentiment model](https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment) instead. Before getting started, familiarize yourself with the general TruLens API as demonstrated in the [intro notebook using pytorch](intro_demo_pytorch.ipynb).

In [None]:
# AIQ: Talk about how this is required to update. Homework. Try out your own model.
# TODO: Homework part 2.

import transformers as hugs
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer

import numpy.typing as npt


# Wrap all of the necessary components.
class TwitterSentiment:
    """
    Everything needed to run and inspect the sentiment model, stored as class
    attributes. The class itself (not an instance) serves as the task object:
    `task = TwitterSentiment`.

    The model and tokenizer are loaded while the class body executes.
    """

    # Use the first GPU when one is available, otherwise fall back to the CPU.
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # model name, see https://huggingface.co/models for others
    MODEL = "distilbert-base-uncased-finetuned-sst-2-english"
    # For the twitter roberta model instead (also enable 'neutral' in `labels`):
    # MODEL = "cardiffnlp/twitter-roberta-base-sentiment"
    # embeddings = model.roberta.embeddings.word_embeddings.weight.detach().cpu().numpy()
    # embeddings_layer = 'roberta_embeddings_word_embeddings'

    # model, moved to `device`
    model: hugs.PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(
        MODEL
    ).to(device)

    # tokenizer
    tokenizer: hugs.PreTrainedTokenizer = AutoTokenizer.from_pretrained(MODEL)

    # the embedding vectors, one row per token id
    embeddings: npt.NDArray[np.float32] = \
        model.distilbert.embeddings.word_embeddings.weight.detach().cpu().numpy()

    # name of the layer that produces token embeddings
    embeddings_layer: str = 'distilbert_embeddings_word_embeddings'

    # number of dimensions in token embedding
    embedding_size: int = embeddings.shape[1]

    # maximum number of tokens to send to model
    max_length: int = 256

    @staticmethod
    def tokenize(texts: List[str]):
        """
        Tokenize a list of `texts` into a form appropriate for
        `TwitterSentiment.model`: padded, truncated to `max_length` tokens,
        returned as pytorch tensors moved to `device`.
        """
        return TwitterSentiment.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=TwitterSentiment.max_length,
            return_tensors='pt'
        ).to(TwitterSentiment.device)

    @staticmethod
    def token_str(token_id: int) -> str:
        """
        Given a `token_id`, produce a string of how it should be drawn: word
        continuation pieces (decoded with a leading "##") lose that prefix so
        they attach to the previous token; other tokens get a leading space.
        """
        # Refer to the class directly rather than the module-level `task` alias,
        # which does not exist until after the class is defined.
        tok = TwitterSentiment.tokenizer.decode(token_id)
        if tok.startswith("##"):
            return tok[2:]
        return " " + tok

    # vocabulary lookups in both directions
    id_of_token: Dict[str, int] = tokenizer.get_vocab()
    token_of_id: Dict[int, str] = {v: k for k, v in id_of_token.items()}

    # number of tokens in vocabulary
    vocab_size: int = len(id_of_token)

    # Tokens in id order. Class-level names are not visible inside a
    # comprehension body, and `task` does not exist yet while the class body
    # runs (NameError on a fresh kernel), so map over the bound lookup instead.
    vocab = np.array(list(map(token_of_id.__getitem__, range(vocab_size))))

    labels = [
        'negative',
        #'neutral', # for roberta
        'positive'
    ]

    NEGATIVE = labels.index('negative')
    #NEUTRAL = labels.index('neutral') # for roberta
    POSITIVE = labels.index('positive')


task = TwitterSentiment

This model classifies tweets (or really any text you give it) by sentiment: positive or negative. The twitter roberta model mentioned in the code additionally has a neutral class. Let's try it out on some examples.

In [None]:
# A few example inputs: clearly positive, clearly negative, mixed and noncommittal.
sentences = [
    "I'm so happy!",
    "I'm so sad!",
    "I cannot tell whether I should be happy or sad!",
    "meh",
]

# Input sentences must be tokenized first. The result holds vocabulary indexes
# for every input token (words and word pieces such as the "'m" of "I'm").
inputs = task.tokenize(sentences)
print(inputs)

# Decoding each token id separately (instead of each sentence as one string)
# keeps non-word tokens visible in the output.
print(task.tokenizer.batch_decode(torch.flatten(inputs['input_ids'])))

In [None]:
# Random linear projection of token embeddings down to 2-D. A seeded generator
# keeps the projection (and any plot drawn with it) identical across runs.
rand_embedding = np.random.default_rng(0).normal(size=(task.embedding_size, 2))


def reduce(embs):
    """
    Project embeddings (last axis of size `task.embedding_size`) to 2-D with
    the fixed random matrix `rand_embedding`.
    """
    return embs @ rand_embedding


def cb(iteration, error, embedding):
    """
    openTSNE progress callback: redraw a scatter plot of the current 2-D
    `embedding` (one point per row, i.e. per vocabulary token here).

    `iteration` and `error` are passed in by openTSNE but not used.
    """
    fig = go.FigureWidget(layout=dict(width=500, height=500))
    fig.add_scatter(
        x=embedding[:, 0],
        y=embedding[:, 1],
        mode='markers',
        marker_size=2,
    )
    # Replace the previous frame rather than stacking a new figure each call.
    clear_output(wait=True)
    display(fig)

# AIQ: If you have a few minutes and want to visualize the embedding space better, run this below instead:
import pickle

# t-SNE over the whole token-embedding table is slow, so the fitted result is
# cached on disk (pickled, despite the .npy suffix) and reused when present.
tsne_filename = Path("tsne_embedding.test.cosine.npy")

if not tsne_filename.exists():
    from openTSNE import TSNE

    man = TSNE(
        n_jobs=mp.cpu_count(),
        verbose=True,
        n_iter=10000,
        learning_rate=200,
        negative_gradient_method='bh',
        callbacks_every_iters=50,
        callbacks=[cb],
        metric="cosine",
    )
    tsne_embedding = man.fit(task.embeddings)

    with tsne_filename.open(mode='wb') as fh:
        pickle.dump(obj=tsne_embedding, file=fh)
else:
    # NOTE(review): unpickling can run arbitrary code; only load a cache file
    # that this notebook itself produced.
    with tsne_filename.open(mode='rb') as fh:
        tsne_embedding = pickle.load(fh)


def reduce(embs):
    """Map embeddings into the 2-D space of the fitted t-SNE `tsne_embedding`."""
    projected = tsne_embedding.transform(embs)
    return projected

In [None]:
# AIQ
# TODO: add comment on how the output is to be interpreted

# Shared layout for the interactive widgets in this notebook.
aiq_layout = dict(
    border="10px solid teal",
    padding="5px",
    width="100%",
    margin="0px",
)


def textbox(t="The last part of this sentence is not a real wordle.", c=True):
    """
    Build a styled text input widget.

    t: initial text.
    c: value for the widget's `continuous_update` option.
    """
    return widgets.Text(value=t, continuous_update=c, layout=aiq_layout)

@interact(text=textbox(c=False))
def show_parse(text: str):
    """
    Show side by side: the raw `text`, its tokenization, and each token's
    embedding projected to 2-D with `reduce`.
    """
    inputs = task.tokenize([text])
    tokens = task.tokenizer.batch_decode(torch.flatten(inputs['input_ids']))
    token_ids = inputs['input_ids'].detach().cpu().numpy()

    text_panel = widgets.VBox(
        [widgets.HTML("<h1>Text --></h1>"), widgets.Label(text)],
        layout=dict(width="20%"),
    )
    token_panel = widgets.VBox(
        [
            widgets.HTML("<h1>Tokenization --></h1>"),
            widgets.HTML(str(inputs)),
            widgets.HTML(str(tokens)),
        ],
        layout=dict(width="30%"),
    )

    # Exactly one text was tokenized, so row 0 holds all of its token ids.
    projected = reduce(task.embeddings[token_ids[0]])

    fig = go.FigureWidget(layout=dict(margin=dict(l=0, r=0, b=0, t=0)))
    fig.add_scatter(
        x=projected[:, 0], y=projected[:, 1], text=tokens, mode='text'
    )
    embedding_panel = widgets.VBox(
        [widgets.HTML("<h1>Embedding</h1>"), fig], layout=dict(width="50%")
    )

    display(widgets.HBox([text_panel, token_panel, embedding_panel]))

In [None]:
# Peek at the first few (id, token) entries; dumping every key is not readable.
list(task.token_of_id.items())[:20]

In [None]:
# AIQ: This is computationally intensive picture. It is only useful if you use the tsne reduction above.

# Every vocabulary token drawn at its t-SNE position; hovering shows the token.
n = len(task.embeddings)
embs_proj = tsne_embedding

fig = go.FigureWidget(layout=dict(width=1000, height=1000))
fig.add_scatter(
    x=embs_proj[:n, 0],
    y=embs_proj[:n, 1],
    text=task.vocab[:n],
    mode='markers',
    marker_size=2,
)
display(fig)

Evaluating Hugging Face models is straightforward: the structure produced by the tokenizer can be passed directly to the model as keyword arguments.

In [None]:
outputs = task.model(**inputs)

print(outputs)

# The predicted class of each sentence is the argmax of its row of logits;
# map that index to a readable label.
predictions = list(
    map(lambda class_id: task.labels[class_id], outputs.logits.argmax(axis=1))
)

for sentence, logits, prediction in zip(sentences, outputs.logits, predictions):
    print(logits.detach().cpu().numpy(), prediction, sentence)

In [None]:
# AIQ
# TODO: add comment on how the output is to be interpreted

# Most-recent-first history of (prediction, logits, text) across widget edits.
model_results = []

# How many past results are kept and printed.
MAX_MODEL_RESULTS = 10


@interact(text=textbox())
def show_output(text):
    """
    Classify `text`, prepend (prediction, logits, text) to `model_results`,
    and print the `MAX_MODEL_RESULTS` most recent results, newest first.
    """
    inputs = task.tokenize([text])

    # Inference only: no need to build an autograd graph.
    with torch.no_grad():
        logits = task.model(**inputs).logits

    prediction = task.labels[int(logits.argmax(axis=1)[0])]

    model_results.insert(0, (prediction, logits.cpu().numpy()[0], text))
    # Trim the shared history in place so it stays bounded (rebinding a local
    # name would leave the global list growing forever).
    del model_results[MAX_MODEL_RESULTS:]

    for result in model_results:
        print(*result)

In [None]:
def evaluate_to_probits(texts: List[str]) -> torch.Tensor:
    """
    Evaluate a collection of `texts` into class probabilities: the softmax of
    the model logits, one row per text, as a tensor on `task.device`.

    Runs without gradient tracking since the scores are only inspected.
    """
    inputs = task.tokenize(texts)

    with torch.no_grad():
        logits = task.model(**inputs).logits

    return torch.nn.functional.softmax(logits, dim=1)


evaluate_to_probits(["This is great.", "This is not great."])

# Performance

TODO

In [None]:
# AIQ: TODO: talk about dataset

from datasets import load_dataset

rotten_train = load_dataset("rotten_tomatoes", split="train")
rotten_test = load_dataset("rotten_tomatoes", split="test")
# One plain list of all review texts (train then test). The explicit list()
# keeps the concatenation working if column access returns a non-list sequence
# (as in newer `datasets` releases).
rotten_texts = list(rotten_train['text']) + list(rotten_test['text'])

# Trulens: Model Wrapper

As in the prior notebooks, we need to wrap the pytorch model with the appropriate TruLens functionality. The maximum input size (in tokens) of each text is already fixed by `TwitterSentiment.max_length`, which the tokenizer uses to truncate inputs.

In [None]:
from trulens.nn.models import get_model_wrapper
from trulens.nn.quantities import ClassQoI
from trulens.nn.attribution import Cut, IntegratedGradients, OutputCut
from trulens.utils.typing import ModelInputs

# Wrap the pytorch model so TruLens attribution methods can operate on it.
task.wrapper = get_model_wrapper(task.model, device=task.device)

In [None]:
# Print the layer names of the wrapped model. NOTE(review): `task.embeddings_layer`
# is used below as the attribution cut, so it presumably must appear in this list.
task.wrapper.print_layer_names()

# Attributions/Explainability

In [None]:
# AIQ: TODO: Explain this.

# Settings shared by every IntegratedGradients attributor in this notebook:
# attribute the POSITIVE class score, read from the 'logits' entry of the
# model output, to the token-embedding layer.
common_attributor_arguments = {
    'model': task.wrapper,
    'resolution': 128,
    'rebatch_size': 32,
    'doi_cut': Cut(task.embeddings_layer),
    'qoi': ClassQoI(task.POSITIVE),
    'qoi_cut': OutputCut(accessor=lambda o: o['logits']),
}

infl = IntegratedGradients(**common_attributor_arguments)

Raw attribution values are not very readable, so TruLens comes with some utilities to present token influences more concisely. First we need to set up a few parameters to make use of them:

In [None]:
# AIQ TODO: explain more

from trulens.visualizations import NLP

V = NLP(
    wrapper=task.wrapper,
    labels=task.labels,
    decode=task.token_str,
    # Huggingface models take the tokenizer output as keyword args, so wrap it
    # as keyword ModelInputs with every tensor moved to the model's device.
    tokenize=lambda sentences: ModelInputs(
        kwargs=task.tokenize(sentences)
    ).map(lambda t: t.to(task.device)),
    # Token ids are under the 'input_ids' key of those keyword args.
    input_accessor=lambda x: x.kwargs['input_ids'],
    # Logits are under the 'logits' key of the model output.
    output_accessor=lambda x: x['logits'],
    # Padding tokens are not displayed.
    hidden_tokens={task.tokenizer.pad_token_id},
)

print("QOI = POSITIVE")
display(V.tokens(sentences, infl))

In [None]:
# AIQ

# Most-recent-first attribution displays from this widget; at most 10 are kept.
results = []


@interact(text=textbox())
def show_attribution(text):
    """Prepend the attribution display for `text`, then redisplay the 10 most recent."""
    global results

    results = [V.tokens([text], infl)] + results[:9]

    for shown in results:
        display(shown)

## Baselines

We see in the above results that special tokens such as the sentence separator **[SEP]** are found to contribute a lot to the model outputs. While this may be useful in some contexts, we are more interested in the contributions of the actual words in these sentences. To focus on the words, we need to adjust the **baseline** used in the integrated gradients computation. By default, in the instantiation so far, the baseline for each token is a zero vector of the same shape as its embedding. By making the baseline identical to the explained instance on special tokens, we remove their impact from the measurement. TruLens provides a utility for this purpose, `token_baseline`, which constructs the methods that compute the appropriate baseline for you.

In [None]:
from trulens.utils.nlp import token_baseline

inputs_baseline_ids, inputs_baseline_embeddings = token_baseline(
    # Which tokens to preserve: the [CLS]/[SEP] special tokens.
    keep_tokens={task.tokenizer.cls_token_id, task.tokenizer.sep_token_id},
    # What to replace every other token with.
    # AIQ: Try changing `replacement_token` parameter to other special or non special tokens.
    # replacement_token=task.tokenizer.mask_token_id,
    # replacement_token=task.tokenizer.vocab["happy"],
    replacement_token=task.tokenizer.pad_token_id,
    # Where to find the token ids within the model inputs.
    input_accessor=lambda x: x.kwargs['input_ids'],
    # Callable to produce embeddings from token ids.
    ids_to_embeddings=task.model.get_input_embeddings(),
)

We can now inspect the baselines on some example sentences. The first method returned by `token_baseline` gives us token ids to inspect while the second gives us the embeddings of the baseline which we will pass to the attributions method.

In [None]:
print("originals=", task.tokenizer.batch_decode(inputs['input_ids']))

# Token ids of the baseline built for the same inputs.
baseline_word_ids = inputs_baseline_ids(
    model_inputs=ModelInputs(args=[], kwargs=inputs)
)

print("baselines=", task.tokenizer.batch_decode(baseline_word_ids))

In [None]:
# Same settings as `infl`, but integrating from the token baseline built above.
infl_positive_baseline = IntegratedGradients(
    **common_attributor_arguments,
    baseline=inputs_baseline_embeddings,
)

print("QOI = POSITIVE WITH BASELINE")
display(V.tokens(sentences, infl_positive_baseline))

In [None]:
# AIQ

# Most-recent-first (default, baseline) attribution pairs; at most 3 are kept.
results = []


@interact(text=textbox(c=False))
def show_attribution(text):
    """
    Prepend side-by-side attributions of `text` (zero baseline vs. token
    baseline) and redisplay the 3 most recent comparisons.
    """
    global results

    comparison = (
        widgets.HTML(V.tokens([text], infl).data),
        widgets.HTML(V.tokens([text], infl_positive_baseline).data),
    )
    results = [comparison] + results[:2]

    for pair in results:
        display(widgets.HBox(pair))

# Robustness

In [None]:
def word_pattern(word):
    """
    Create a pattern that matches the given `word` as long as it is not
    immediately next to an alpha-numeric character.
    """
    # Raw strings: "\w" in a plain literal is an invalid escape sequence
    # (a warning in current Python versions).
    return r"(?<!\w)" + re.escape(word) + r"(?!\w)"


def swap(thing1: str, thing2: str):
    """
    Create a method to swap occurrences of `thing1` and `thing2`.

    Matching is case-insensitive and whole-word (see `word_pattern`);
    replacements are inserted exactly as given. Both words are replaced in a
    single regex pass, so no placeholder text is needed that could collide
    with the sentence or with the words themselves.
    """
    pattern = re.compile(
        f"(?P<first>{word_pattern(thing1)})|(?P<second>{word_pattern(thing2)})",
        re.IGNORECASE,
    )

    def f(sentence: str):
        """
        Swap instances of thing1 and thing2 in sentence.
        """
        return pattern.sub(
            lambda m: thing2 if m.lastgroup == "first" else thing1, sentence
        )

    return f


def contains(s: str, pat: re.Pattern):
    """
    Determine whether the given string `s` satisfies regular expression `pat`.
    """
    return pat.search(s) is not None

In [None]:
token_pairs = [
    ("good", "great"),
    ("great", "amazing"),
    ("good", "amazing"),
]


def get_sentence_pairs(token_pairs: List[Tuple[str, str]],
                       texts: List[str]) -> List[Tuple[str, str]]:
    """
    Create sentence pairs from examples in `texts` that swap words from the
    pairs list `token_pairs`.

    For each word pair (a, b), every text containing a or b (whole word,
    case-insensitive) yields (text, text with a and b swapped). A text can
    appear in several pairs, once per word pair it matches.
    """
    sentence_pairs = []

    # Iterate over a list so the progress bar knows its total.
    for pair in tqdm(list(token_pairs), desc="finding swap pairs", unit="pair"):
        pattern = re.compile(
            "|".join(word_pattern(tok) for tok in pair), re.IGNORECASE
        )
        # Named `swapper` so it does not shadow the module-level `swap`.
        swapper = swap(*pair)
        sentence_pairs.extend(
            (sentence, swapper(sentence))
            for sentence in texts
            if contains(sentence, pattern)
        )

    print(f"found {len(sentence_pairs)} pair(s)")

    return sentence_pairs


sentence_pairs_quality = get_sentence_pairs(token_pairs, rotten_texts)

In [None]:
def compute_pair_disparities(
    sentence_pairs: List[Tuple[str, str]]
) -> List[Tuple[Tuple[str, str], float]]:
    """
    Given a collection of `sentence_pairs`, produce a list of tuples containing
    the pairs as the first element and the disparity in model scores as the
    second, sorted from most to least disparate.

    The disparity is the total variation distance between the two sentences'
    class-probability vectors: 0 when the model scores both identically, up
    to 1 when the probability mass moves entirely to different classes.
    """
    diffs = []

    with torch.no_grad():
        for pair in tqdm(sentence_pairs,
                         desc="evaluating sentence pair score differences"):
            a_probits, b_probits = evaluate_to_probits(list(pair))
            # cross_entropy expects logits, not probabilities, and is non-zero
            # even for identical outputs; total variation is 0 for equal scores.
            diffs.append(0.5 * float(torch.abs(a_probits - b_probits).sum()))

    return sorted(
        zip(sentence_pairs, diffs), key=lambda item: item[1], reverse=True
    )


diffs_pairs_quality = compute_pair_disparities(sentence_pairs_quality)

In [None]:
def show_biggest_disparities(
    diffs: List[Tuple[Tuple[str, str], float]],
    attributor=infl_positive_baseline,
    n=3
) -> None:
    """
    Display the first `n` sentence pairs of `diffs` (the most disparate when
    `diffs` is sorted that way) side by side with `attributor` attributions.
    """
    leading = diffs[0:n]
    originals = [item[0][0] for item in leading]
    counterparts = [item[0][1] for item in leading]

    display(
        V.tokens_stability(
            texts1=originals, texts2=counterparts, attributor=attributor
        )
    )


show_biggest_disparities(diffs_pairs_quality)

In [None]:
# AIQ


def tokenbox(t, c=False):
    """Styled text input for a single word; `c` sets `continuous_update`."""
    return widgets.Text(value=t, continuous_update=c, layout=aiq_layout)


@interact(token1=tokenbox("good"), token2=tokenbox("bad"))
def show_disparities(token1, token2):
    """
    Swap `token1`/`token2` in the rotten tomatoes texts and show the pairs
    whose model scores differ most; shows nothing when no text matches.
    """
    sentence_pairs = get_sentence_pairs([(token1, token2)], rotten_texts)

    if not sentence_pairs:
        return

    show_biggest_disparities(compute_pair_disparities(sentence_pairs))

# Fairness

In [None]:
# Word pairs that differ only in gender. Swapping them in the review texts
# probes whether the model's score depends on gendered words.
gender_pairs = [
    ('he', 'she'), ('guy', 'gal'), ('himself', 'herself'), ('boy', 'girl'),
    ('husband', 'wife'), ('man', 'woman'), ('men', 'women'),
    ('brother', 'sister'), ('uncle', 'aunt'), ('nephew', 'niece'),
    ('dad', 'mom'), ('father', 'mother'), ('son', 'daughter'),
    ('actor', 'actress'), ('male', 'female'), ('hero', 'heroine'),
]

sentence_pairs_gender = get_sentence_pairs(gender_pairs, rotten_texts)
diffs_pairs_gender = compute_pair_disparities(sentence_pairs_gender)

In [None]:
# Most gender-disparate pairs, attributed with the default (token-baseline) attributor.
show_biggest_disparities(diffs=diffs_pairs_gender)

In [None]:
embeddings = task.embeddings

# A vector approximating the difference between embeddings of pairs of words
# of the opposite gender. This one is for the token embedding used in
# distilbert.
# Stored as base85-encoded raw bytes and decoded into a 1-D float16 array.
# NOTE(review): its length presumably equals `task.embedding_size`, and whether
# it is unit length is not shown here — confirm before treating dot products
# with it as cosines.
gender_vector = np.frombuffer(
    base64.b85decode(
        b'?4-;r+n_BbxT&BWWUIlV^Qc9YM<?$ffFvBOyeVWPZ6QZ1ttD(IWT3Vs@g>z1(IHALq#<!8`6&IFovN59FQQ<rjVfRxf2awiYAAOf0xHHLzn)JeL8GLalpxcndmiqk3oS&V=c3uGKP<ea=_dCgN-Vq~^C<SDBB=f;{-m#?lBazp6`Fi20ii4{HmQFaeJNI`eJG!zN2p7nT&MsSPpc;*^d-=!XDFVgZ7N3{HKO`0uqm#msif;Ew5=Wq#ib{lJ)>zVdn~jjwWg4$E-Hm9mL8ELf+nV@>7XN}EFh#J6Du&8d89uqq^KIJJfWAH<s(?A83iFILnMlrgei)r<*S^fJgHDAWT^b8{i%wqmaLPhbS9T6Jt#99Sf6L8+bydo;3dAJS}Uz4E1CW%<)@0Q!Ypv2O(|R<c%D?4xR(l{UL)hBP^^)syPV#omZOCyB&3O+Iv%o~8KW|tw5a_oU?h^MQ!4KuK_zc2qAUp_g(MRpzbQy4vZ$V-e<o?1A1Fqg-lZKX51h559ip|Y3@9lngeeCdV4_MV1tujcN-QR*NT-IaRil6?RICoHU?Y?#zN8MQpeo*^uBGuJ$tk-V*C>ytF{UOeM4I9z{VUa<Vx*-d-l$9`fGFf76r}JNuPLCYkTBb<^`&<w5-3!r4Wiwe2&9!9nI?EFcBnfjS1ipX@gcS+XD6R2W3AF1h>!iL@|kC$dla6YHl}MM<0kVa@~U*KA0pVH3?n6;I4Dsj{2DZ*bRJnH^CRFT$f!e}dniDvR;=tH3Mn0=KqeCw)1Lk)mZ)|kEUBrWAE-g9fhH>_i>J6EXBbkVJ|R9RW2H-=R-}(7FDB6^EvphMpez}nRGA2_{iNV0kf_8hG_1QR;-b?fwIDDpnI}W4-5<av4k0q42p)f_%Ay~t-K6*-@+Mjz%$YHy51(}(H7Uv>(<l=m@*IUHzoEq}BB-CJ^(USw-Yae=oTara=C19iiYMJ3?yV>$*qY=iC!+l-R;E;>I4FCe5iVFIf2tZC5UZpoNG!mnxS~QM38iYMm?*WNN1_!Zh^a27Eu`Qoc$}IfR+ikN7@WSNXQ}EcY!wqH2PuOl#HIx#l_0O7ogf&f@u_Slmm`>_AgeQ^RwvCWk}20BY9RBYb|sFgpeAH0Dx#e#tt9TASt-Dq@1$NOs3duoK`EM`FQsOske;`wT_Z0g0w<>+xh5K%eJIJG%cHs?)}7U(Vx_Yrcpk?fF(C#ZaHY4W)TXW`zn~3}NvZsnOQu36Q6W#L)T+R%yeKOyr=sMY+ou^QsHsz@^&6EdH6?hZOqr#q&8vec0VOpo7b`0ylcE<TLM8XAv7>h-WRQC!;i)z*1gXcS@g$NXP%J#EOdu8^39HJf-!4$5t1F|X>nB1bOs2G`_@1C8;w3n$h$<i`f+X;tZX^b#;Ucys8?I0(!zlPFE1h{F1FV`R&#Ig!^d}%BvmpE}lBB68#i@S|a3H`YHy+|3qoeSwl&35tv!2JIffi^f_$lxpN~Hd%My0>0I32|%;S;zZ38EJ*&ZM^~8z+q-g_sARQK)mI6e2np%BPPeTP$EE%&YOJg(wxF_94wE;wph611IgKq9+HaTdc1oA14&4ysA2+%p*K3yQG9HBPA3nuBm+}Bc<mgCL^*Z>ztA#$1O7^A1jL_&?9yoLLzG;4W}ldVyJc}@u_hmODydxda9wBL7j9dl?n!@W-8FCn5dm7l`5;HN2rgX&7LBtm8K*teWQw>D=p^^x~S!>GaR6#Mkv3kktndAR2qe)DyAc)iKy+Q&ZhH}K&En`@+qt+lOfouwj+$EbEZh08mM%hgeI${1Sv!*_$5}TS09xr!ljccx+a^awVhe0JfIpJSSmoK&7Kw@hbW;WB_D1h{;R}`OQ|d^%d6>`KqUt#-K4;xLadW0q^d!nF(UV!=_;$CJ1X3%@2E4Yk*i)Mz>y9k<fnnC5Gh|R<|O+jB#_&s2B9h)'
    ),
    dtype='float16'
)


def normalize(v):
    """
    Scale a single vector to unit Euclidean (L2) length.
    """
    length = np.linalg.norm(v, ord=2)
    return v / length


def normalize_many(v):
    """
    Scale each vector (indexed along axis 0, components along axis 1) to
    unit Euclidean (L2) length.
    """
    row_lengths = np.linalg.norm(v, ord=2, axis=1, keepdims=True)
    return v / row_lengths


# Unit-length copy of every token embedding, one row per token id.
all_embs_norm = normalize_many(embeddings)
# Per token, the absolute alignment of its unit embedding with gender_vector
# (an |cosine| only if gender_vector is unit length — NOTE(review): confirm).
baseline_penalties = np.abs(all_embs_norm @ gender_vector)

# Direction used by the gender swap / neutralize helpers below.
direction_vector: np.ndarray = gender_vector


def embedding_opposite_id(emb: np.ndarray) -> Tuple[int, float]:
    """
    Get the token id of the token closest to the gender-opposite of the given
    `emb`.

    Every vocabulary token is scored as follows. Take the |dot product| of the
    unit-normalized difference (emb - candidate) with `direction_vector`, then
    subtract 0.55 times the candidate's own `baseline_penalties` entry. The
    highest-scoring candidate wins.

    Returns:
        (best token id, its score). The caller decides whether the score is
        high enough to use.
    """

    emb = normalize(emb)
    # The tiny constant keeps normalize_many from dividing by zero for the row
    # where the candidate equals `emb` itself.
    # 0.55 weights how strongly already-gendered candidates are penalized.
    # NOTE(review): the origin of this constant is not documented here.
    scores = np.abs(
        np.dot(
            normalize_many(emb - all_embs_norm + 0.000000001), direction_vector
        )
    ) - 0.55 * baseline_penalties

    best = np.argmax(scores)

    return best, scores[best]


def embedding_opposite(emb: np.ndarray) -> np.ndarray:
    """
    Try to find the embedding close to the opposite gender relative to the given
    `emb`: the raw embedding of the best candidate token when its score exceeds
    0.25, otherwise `emb` unchanged.
    """
    candidate_id, candidate_score = embedding_opposite_id(emb)
    return embeddings[candidate_id] if candidate_score > 0.25 else emb


def embedding_neutralize(emb: np.ndarray, direction: np.ndarray = None) -> np.ndarray:
    """
    Remove the component of the given embedding that points in the gender
    direction.

    emb: a single embedding vector.
    direction: direction to project out; defaults to the global
        `direction_vector`. It is normalized to unit length first, because
        subtracting `dot(emb, d) * d` only yields a result orthogonal to `d`
        when `d` has unit length.

    Returns the projected vector. Its dtype follows `emb` for float32/float64
    input, since the unit direction is kept in float32.
    """
    if direction is None:
        direction = direction_vector

    unit = np.asarray(direction, dtype=np.float32)
    unit = unit / np.linalg.norm(unit, ord=2)

    return emb - np.dot(emb, unit) * unit


@functools.lru_cache(maxsize=len(embeddings))
def token_id_opposite(token_id: int):
    """
    Memoized lookup of the gender-opposite of `token_id` along
    `direction_vector`. Returns the candidate id when its score exceeds 0.20,
    otherwise `token_id` itself.
    """
    candidate_id, candidate_score = embedding_opposite_id(all_embs_norm[token_id])
    return candidate_id if candidate_score > 0.20 else token_id


def swap_token(token: str) -> str:
    """
    Return the vocabulary token judged gender-opposite to `token` (or `token`
    itself when no candidate qualifies). Raises KeyError for unknown tokens.
    """
    return task.token_of_id[token_id_opposite(task.id_of_token[token])]

In [85]:
# geometry of gender in embedding space

# AIQ: This is computationally intensive picture. It is only useful if you use the tsne reduction.

# Color each token by the dot product of its unit embedding with the gender
# direction.
color = normalize_many(task.embeddings) @ direction_vector
cmin, cmax = color.min(), color.max()
# Threshold on |color| deciding which tokens are drawn; 0.0 keeps all of them.
most_gendered = np.abs(color) >= 0.0

embs_proj = tsne_embedding

fig = go.FigureWidget(layout=dict(width=1000, height=1000))
fig.add_scatter(
    x=embs_proj[most_gendered, 0],
    y=embs_proj[most_gendered, 1],
    text=task.vocab[most_gendered],
    mode='markers',
    marker={
        'cmin': cmin,
        'cmax': cmax,
        'colorscale': "Picnic",
        'color': color[most_gendered],
        'colorbar': dict(thickness=20),
    },
    marker_size=4,
)

display(fig)

FigureWidget({
    'data': [{'marker': {'cmax': 0.18029511,
                         'cmin': -0.40321934,
    …

In [None]:
# Compare a few original sentences with a token-by-token gender swap.
for (s1, _), _ in diffs_pairs_gender[:4]:
    toks = task.tokenize([s1])['input_ids'][0].detach().cpu().numpy()
    toks_opposites = list(map(token_id_opposite, toks))
    print("original sentence:", s1)
    print("swapped sentence:", task.tokenizer.decode(toks_opposites))
    print()

In [None]:
def baseline_swap(z: torch.Tensor) -> torch.Tensor:
    """
    Build a baseline from a batch of token embeddings `z` by replacing each
    embedding with `embedding_opposite` of it. Accepts a tensor or array;
    returns a tensor on `task.device`.

    NOTE(review): assumes z is (batch, tokens, embedding_size) — confirm.
    """
    as_array = z.detach().cpu().numpy() if isinstance(z, torch.Tensor) else z
    swapped = np.array(
        [[embedding_opposite(vector) for vector in row] for row in as_array]
    )
    return torch.tensor(swapped).to(task.device)


infl_swap_gender = IntegratedGradients(
    baseline=baseline_swap,
    **common_attributor_arguments
)

# The same top gender-disparate pairs under both baselines, for comparison.
show_biggest_disparities(diffs_pairs_gender, attributor=infl_positive_baseline)
show_biggest_disparities(diffs_pairs_gender, attributor=infl_swap_gender)

In [None]:
# AIQ

# Most-recent-first (token baseline, gender-swap baseline) attribution pairs;
# at most 3 are kept.
results = []


@interact(
    text=textbox(
        t="Johnson has, in his first film, set himself a task he is not early up to.",
        c=False,
    )
)
def show_attribution(text):
    """
    Prepend side-by-side attributions of `text` (token baseline vs.
    gender-swap baseline) and redisplay the 3 most recent comparisons.
    """
    global results

    comparison = (
        widgets.HTML(V.tokens([text], infl_positive_baseline).data),
        widgets.HTML(V.tokens([text], infl_swap_gender).data),
    )
    results = [comparison] + results[:2]

    for pair in results:
        display(widgets.HBox(pair))

In [None]:
def baseline_neutralize(z: torch.Tensor) -> torch.Tensor:
    """
    Given input tensor of embeddings, produce a baseline that removes their
    gender component (via `embedding_neutralize`). Accepts a tensor or array;
    returns a tensor on `task.device`.

    NOTE(review): assumes z is (batch, tokens, embedding_size) — confirm.
    """

    if isinstance(z, torch.Tensor):
        z = z.detach().cpu().numpy()

    # Stack into one ndarray before converting. Building a tensor from nested
    # lists of ndarrays is very slow (PyTorch warns about it), and this matches
    # how `baseline_swap` builds its baseline.
    neutralized = np.array(
        [[embedding_neutralize(emb) for emb in instance] for instance in z]
    )
    return torch.tensor(neutralized).to(task.device)


infl_neutralize_gender = IntegratedGradients(
    baseline=baseline_neutralize, **common_attributor_arguments
)

show_biggest_disparities(diffs_pairs_gender, attributor=infl_positive_baseline)
show_biggest_disparities(diffs_pairs_gender, attributor=infl_neutralize_gender)

In [None]:
# AIQ

# Most-recent-first (token baseline, gender-neutralized baseline) attribution
# pairs; at most 3 are kept.
results = []


@interact(
    text=textbox(
        t="Johnson has, in his first film, set himself a task he is not early up to.",
        c=False,
    )
)
def show_attribution(text):
    """
    Prepend side-by-side attributions of `text` (token baseline vs.
    gender-neutralized baseline) and redisplay the 3 most recent comparisons.
    """
    global results

    comparison = (
        widgets.HTML(V.tokens([text], infl_positive_baseline).data),
        widgets.HTML(V.tokens([text], infl_neutralize_gender).data),
    )
    results = [comparison] + results[:2]

    for pair in results:
        display(widgets.HBox(pair))

# Drift

In [None]:
# Get another dataset to compare to.

imdb_train = load_dataset("imdb", "plain_text", split="train")
imdb_test = load_dataset("imdb", "plain_text", split="test")
# One plain list of all review texts (train then test). The explicit list()
# keeps the concatenation working if column access returns a non-list sequence
# (as in newer `datasets` releases).
imdb_texts = list(imdb_train['text']) + list(imdb_test['text'])

In [None]:
def tokenize(portion):
    """
    Token ids for each text in `portion`, with special tokens added and
    truncated to 512 tokens; returns one list of ids per text.
    """
    return task.tokenizer.batch_encode_plus(
        portion,
        add_special_tokens=True,
        return_attention_mask=False,
        max_length=512,
        truncation=True
    )['input_ids']

# Worker pool used to tokenize large corpora in parallel chunks.
p = mp.Pool(24)

# Number of texts handed to each worker task.
TOKENIZE_CHUNK_SIZE = 1000


def toks_of_texts(texts):
    """
    Tokenize `texts` in parallel and return every token id, in text order,
    concatenated into one flat int64 array.
    """
    texts = list(texts)
    # Step through the texts so the final partial chunk is included. Counting
    # whole chunks only (len // 1000) would drop the trailing texts, and all of
    # them when there are fewer than 1000.
    chunks = [
        texts[start:start + TOKENIZE_CHUNK_SIZE]
        for start in range(0, len(texts), TOKENIZE_CHUNK_SIZE)
    ]
    per_chunk = p.map(tokenize, chunks)

    return np.array(
        [token_id for chunk in per_chunk for ids in chunk for token_id in ids],
        dtype=np.int64,
    )

def dists_of_texts(texts):
    """
    Token counts and relative frequencies of `texts` over the tokenizer
    vocabulary.

    Returns (counts, dist): float arrays of length `task.tokenizer.vocab_size`.
    `dist` is `counts` divided by the total number of tokens, or all zeros
    when `texts` produced no tokens (instead of 0/0 NaNs).
    """
    token_ids = np.asarray(toks_of_texts(texts), dtype=np.int64)

    counts = np.bincount(
        token_ids, minlength=task.tokenizer.vocab_size
    ).astype(float)
    total = len(token_ids)

    dist = counts / total if total > 0 else np.zeros_like(counts)

    return counts, dist

def tops_of_texts(texts, n = 10):
    """Summary rows (as produced by `tops_of_dists`) for the token distribution of `texts`."""
    return tops_of_dists(*dists_of_texts(texts), n=n)

def tops_of_dists(c, d, n=10, decode=None):
    """
    Summarize per-token values `d` (with matching counts `c`), ordered by `d`.

    Returns a list of (token_id, count, value, label) rows:
      * the `n` tokens with the smallest `d`,
      * an aggregate row (token_id -1, label "*") summing the remaining tokens
        whose `d` is negative,
      * an aggregate row summing the remaining tokens whose `d` is >= 0,
      * the `n` tokens with the largest `d`.

    decode: maps a token id to its label; defaults to `task.tokenizer.decode`.
    """
    if decode is None:
        decode = task.tokenizer.decode

    c = np.asarray(c)
    d = np.asarray(d)
    order = np.argsort(d)
    size = len(order)

    bottom = order[:n]
    # Explicit bounds: with n == 0, order[-0:] would be the whole array.
    top = order[size - n:] if n > 0 else order[:0]
    middle = order[n:size - n]

    rows = [(idx, c[idx], d[idx], decode(idx)) for idx in bottom]

    # Split the remainder by the sign of `d`, the quantity rows are ordered and
    # plotted by (splitting by `c` could put positive values in the "negative"
    # bucket and vice versa).
    middle_d = d[middle]
    negative = middle[middle_d < 0]
    positive = middle[middle_d >= 0]
    rows.append((-1, c[negative].sum(), d[negative].sum(), "*"))
    rows.append((-1, c[positive].sum(), d[positive].sum(), "*"))

    rows.extend((idx, c[idx], d[idx], decode(idx)) for idx in top)

    return rows

In [None]:
def plotdist(d1, d2, top, l1, l2):
    """
    Plot (1) each listed token's probability in both datasets as grouped bars
    labeled `l1`/`l2`, and (2) the difference value stored in each row of
    `top`.

    top: rows from `tops_of_dists`. Aggregate rows (token id -1) appear only in
    the difference chart: they are not a single token, and indexing `d1[-1]`
    would silently plot the last vocabulary entry instead.
    """
    token_rows = [t for t in top if t[0] >= 0]
    n = len(token_rows)

    dprobs = pd.DataFrame(
        {
            "token": [t[3] for t in token_rows] * 2,
            "dataset": ([l1] * n) + ([l2] * n),
            "prob": [d1[t[0]] for t in token_rows] + [d2[t[0]] for t in token_rows]
        }
    )
    fig = px.bar(dprobs, x="token", y="prob", color="dataset", barmode='group')
    display(fig)

    ddiff = pd.DataFrame(
        {
            "token": [t[3] for t in top],
            "prob": [t[2] for t in top]
        }
    )
    fig = px.bar(ddiff, x="token", y="prob")
    display(fig)


c1, d1 = dists_of_texts(imdb_train['text'])
c2, d2 = dists_of_texts(rotten_train['text'])
top = tops_of_dists(c1 - c2, d1 - d2, n = 20)

plotdist(d1, d2, top, l1='imdb', l2='rotten')

In [None]:
# Same comparison between the imdb train and test splits of a single dataset.
c1, d1 = dists_of_texts(imdb_train['text'])
c2, d2 = dists_of_texts(imdb_test['text'])
top = tops_of_dists(c1 - c2, d1 - d2, n=20)

plotdist(d1=d1, d2=d2, top=top, l1='imdb train', l2='imdb test')

## Drift in embedding distribution

In [None]:
c1, d1 = dists_of_texts(rotten_train['text'])
c2, d2 = dists_of_texts(imdb_train['text'])

# One column per embedding dimension, shared by both frames. The size comes
# from the model rather than a hard-coded 768.
embedding_columns = {
    f"dim{did}": task.embeddings[:, did] for did in range(task.embedding_size)
}

data1 = dict(
    prob=d1,
    token_id=range(len(task.embeddings)))
data1.update(embedding_columns)

df1 = pd.DataFrame(data1)

data2 = dict(
    prob=d2,
    token_id=range(len(task.embeddings)))
data2.update(embedding_columns)

df2 = pd.DataFrame(data2)

In [None]:
def show_hists(s1, s2, df1, df2, title):
    counts1, bin_edges = np.histogram(s1, bins=50, weights=df1.prob.values)
    counts2, _ = np.histogram(s2, bins=bin_edges, weights=df2.prob.values)

    fig = go.FigureWidget(layout=dict(title=title))
    bar1= fig.add_bar(x=bin_edges, y=counts1, name="rotten")
    bar2= fig.add_bar(x=bin_edges, y=counts2, name="imdb")

    display(fig)

@interact(dim=widgets.IntSlider(value=0, min=0, max=task.embedding_size - 1, layout=aiq_layout))
def show_dim_hist(dim):
    """Compare the probability-weighted distribution of embedding dimension `dim` between the rotten and imdb token distributions."""
    show_hists(df1[f'dim{dim}'], df2[f'dim{dim}'], df1, df2, title=f"embedding dimension {dim}")

## Drift in gender distribution

In [None]:
# Projection of every token embedding onto the gender direction. It is the
# same for both datasets; only the token-probability weights (`prob`) differ.
gender_of_token = np.dot(task.embeddings, direction_vector)

data1g = dict(gender=gender_of_token, prob=d1, token_id=range(len(task.embeddings)))
df1g = pd.DataFrame(data1g)

data2g = dict(gender=gender_of_token, prob=d2, token_id=range(len(task.embeddings)))
df2g = pd.DataFrame(data2g)

show_hists(df1g.gender, df2g.gender, df1g, df2g, title="gender histogram")