In [None]:
!pip install captum


Collecting captum
  Downloading captum-0.7.0-py3-none-any.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m42.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: captum
Successfully installed captum-0.7.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.3.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Identifier of the pretrained checkpoint; passed to both the tokenizer and the model loader below.
hf_weights_name = 'cardiffnlp/twitter-roberta-base-sentiment'


In [None]:
# Load the tokenizer that belongs to the checkpoint named in `hf_weights_name`.
hf_tokenizer = AutoTokenizer.from_pretrained(hf_weights_name)


config.json: 100%|██████████| 747/747 [00:00<00:00, 170kB/s]
vocab.json: 100%|██████████| 899k/899k [00:00<00:00, 33.2MB/s]
merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 45.1MB/s]
special_tokens_map.json: 100%|██████████| 150/150 [00:00<00:00, 68.0kB/s]


In [None]:
# Load the sequence-classification model for the same checkpoint as the tokenizer.
hf_model = AutoModelForSequenceClassification.from_pretrained(hf_weights_name)


pytorch_model.bin: 100%|██████████| 499M/499M [00:01<00:00, 303MB/s]


### Predicted probabilities

We define a function that takes a text as input and returns the predicted probability for each class. 
The input text is encoded with the tokenizer loaded above (from the pre-trained RoBERTa-base sentiment checkpoint), which yields a tensor of token IDs.

The input tensor is then passed through `hf_model` to obtain the logits (raw, unnormalized scores) for each class; `[0]` selects the logits from the model's output tuple.

Finally, the softmax function is applied along the second dimension (dimension 1) of the logits tensor. Softmax converts the raw logits into probabilities that sum to 1 across all classes.

In [None]:
def hf_predict_one_proba(text):
    """Return the model's class probabilities for a single text.

    Parameters
    ----------
    text : str
        Raw input text. It is encoded with `hf_tokenizer`, including the
        special tokens.

    Returns
    -------
    torch.Tensor
        Softmax probabilities over the classes, with the batch dimension
        squeezed out.

    Notes
    -----
    The model is switched to eval mode for the forward pass and is then
    returned to whatever mode it was in before the call. Unconditionally
    switching to train mode afterwards would re-enable dropout and make any
    later attribution or prediction on `hf_model` noisy.
    """
    input_ids = hf_tokenizer.encode(
        text, add_special_tokens=True, return_tensors='pt')
    # Remember the caller's mode so this helper has no lasting side effect.
    was_training = hf_model.training
    hf_model.eval()
    try:
        with torch.no_grad():
            # Index 0 of the model output holds the logits.
            logits = hf_model(input_ids)[0]
            preds = F.softmax(logits, dim=1)
    finally:
        hf_model.train(was_training)
    return preds.squeeze(0)

### Prepare the input and baseline token IDs required by the Integrated Gradients attribution method

In [None]:
def hf_ig_encodings(text):
    """Build the input and baseline token-id tensors for Integrated Gradients.

    The text is tokenized without special tokens. The baseline replaces every
    content token with the pad token. Both sequences are then wrapped in the
    tokenizer's CLS ... SEP ids, so they have identical length.

    Returns
    -------
    tuple of torch.LongTensor
        `(input_ids, base_ids)`, each holding a single sequence (batch of 1).
    """
    content_ids = hf_tokenizer.encode(text, add_special_tokens=False)
    padding_ids = [hf_tokenizer.pad_token_id for _ in content_ids]

    def wrap(middle):
        # Surround a token-id list with the CLS and SEP ids.
        return [hf_tokenizer.cls_token_id, *middle, hf_tokenizer.sep_token_id]

    return torch.LongTensor([wrap(content_ids)]), torch.LongTensor([wrap(padding_ids)])

In [None]:
def hf_ig_analyses(text2class):
    """Run the Integrated Gradients analysis for several texts and render it.

    Parameters
    ----------
    text2class : dict
        Maps each text to the class index to attribute towards. Each pair is
        passed to `hf_ig_analysis_one`.

    Returns
    -------
    list
        The visualization records, one per text, in the iteration order of
        `text2class`. They are returned so that callers who bind the result
        (e.g. `score_vis = hf_ig_analyses(...)`) get the records, not `None`.
    """
    data = [
        hf_ig_analysis_one(text, true_class)
        for text, true_class in text2class.items()
    ]
    # Renders the attribution table for all records at once.
    viz.visualize_text(data)
    return data


def hf_ig_analysis_one(text, true_class):
    """Compute Layer Integrated Gradients for one text and package the result.

    Parameters
    ----------
    text : str
        Text to explain.
    true_class : int
        Class index used as the attribution target. It is also reported as
        both the true label and the attribution label of the record.

    Returns
    -------
    viz.VisualizationDataRecord
        Record holding the per-token scores, the predicted class and its
        probability, the total attribution, and the convergence delta.
    """
    # Other layers that can be attributed instead:
    # layer = hf_model.roberta.encoder.layer[0]
    # layer = hf_model.roberta.embeddings.word_embeddings
    layer = hf_model.roberta.embeddings

    def ig_forward(inputs):
        return hf_model(inputs).logits

    ig = LayerIntegratedGradients(ig_forward, layer)

    input_ids, base_ids = hf_ig_encodings(text)

    # Attribute in eval mode so that dropout cannot add noise to the
    # integration path, then hand the model back in the mode it was found in.
    was_training = hf_model.training
    hf_model.eval()
    try:
        attrs, delta = ig.attribute(
            input_ids,
            base_ids,
            target=true_class,
            return_convergence_delta=True)
    finally:
        hf_model.train(was_training)

    # Collapse the last dimension to one score per token, then mean-centre
    # and scale by the norm of the raw scores (note: not a true z-score).
    scores = attrs.sum(dim=-1).squeeze(0)
    norm = scores.norm()
    centered = scores - scores.mean()
    # A zero norm means every score is zero; skip the division to avoid NaNs.
    scores = centered / norm if norm > 0 else centered

    # Intuitive tokens to help with analysis:
    raw_input = hf_tokenizer.convert_ids_to_tokens(input_ids.tolist()[0])
    # RoBERTa-specific clean-up: drop the "Ġ" marker characters.
    raw_input = [x.strip("Ġ") for x in raw_input]

    # Predictions for comparisons:
    pred_probs = hf_predict_one_proba(text)
    pred_class = pred_probs.argmax()

    score_vis = viz.VisualizationDataRecord(
        word_attributions=scores,
        pred_prob=pred_probs.max(),
        pred_class=pred_class,
        true_class=true_class,
        # The attribution target above is `true_class`, so label it as such.
        attr_class=true_class,
        attr_score=attrs.sum(),
        raw_input_ids=raw_input,
        convergence_score=delta)

    return score_vis

In [None]:
import pandas as pd
import numpy as np

# Read the train and test splits from their absolute /work paths.
df_train, df_test = (
    pd.read_csv(csv_path) for csv_path in ('/work/train.csv', '/work/test.csv')
)
# Preview the first rows of the test split.
df_test.head()

Unnamed: 0,textID,text,sentiment
0,f87dea47db,Last session of the day http://twitpic.com/67ezh,neutral
1,96d74cb729,Shanghai is also really exciting (precisely -...,positive
2,eee518ae67,"Recession hit Veronique Branquinho, she has to...",negative
3,01082688c6,happy bday!,positive
4,33987a8ee5,http://twitpic.com/4w75p - I like it!!,positive


In [None]:
# Integer id for each sentiment string.
sentiment_mapping = dict(negative=0, neutral=1, positive=2)

# Add a `sentiment_id` column to both frames; labels missing from the mapping become NaN.
df_test['sentiment_id'], df_train['sentiment_id'] = (
    frame['sentiment'].map(sentiment_mapping) for frame in (df_test, df_train)
)

In [None]:
# First five training rows, indexed by text, reduced to a {text: sentiment_id} dict.
train_set = df_train.iloc[:5].set_index('text')
train_set_explain = train_set.loc[:, 'sentiment_id'].to_dict()
print(train_set, train_set_explain, sep='\n')

# Explain the model's behaviour on these five training examples.
score_vis = hf_ig_analyses(train_set_explain)

                                                        textID  \
text                                                             
 I`d have responded, if I were going                cb774db0d1   
 Sooo SAD I will miss you here in San Diego!!!      549e992a42   
my boss is bullying me...                           088c60f138   
 what interview! leave me alone                     9642c003ef   
 Sons of ****, why couldn`t they put them on th...  358bd9e861   

                                                                          selected_text  \
text                                                                                      
 I`d have responded, if I were going                I`d have responded, if I were going   
 Sooo SAD I will miss you here in San Diego!!!                                 Sooo SAD   
my boss is bullying me...                                                   bullying me   
 what interview! leave me alone                                          leave me 

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.81),,0.81,"#s I ` d have responded , if I were going #/s"
,,,,
0.0,0 (0.92),,1.75,#s So oo S AD I will miss you here in San Diego !!! #/s
,,,,
0.0,0 (0.92),,2.58,#s my boss is bullying me ... #/s
,,,,
0.0,0 (0.78),,0.73,#s what interview ! leave me alone #/s
,,,,
0.0,0 (0.56),,1.66,"#s Sons of **** , why couldn ` t they put them on the releases we already bought #/s"
,,,,


In [None]:
# Peek at the text of test rows 30-34.
df_test['text'].iloc[30:35]

30     guess that depends on if you want to be on th...
31                  resorted to eating Mickey Ds ALONE.
32    There is a faux gothy chick looking at me, sor...
33       did he ask for your Twitter ID? Your sun sign?
34                                      where`d you go!
Name: text, dtype: object

In [None]:
# One-row slices (position 15) kept as Series, so the index label is printed with the value.
text, sentiment_id = (
    df_test[15:16][column] for column in ('text', 'sentiment_id')
)
print(text)
print(sentiment_id)

15    So hot today =_=  don`t like it and i hate my ...
Name: text, dtype: object
15    0
Name: sentiment_id, dtype: int64


In [None]:

# First ten test rows, indexed by text.
df_subset = df_test.iloc[:10].set_index('text')

# Reduce the subset to a {text: sentiment_id} dictionary.
text_test = df_subset.loc[:, 'sentiment_id'].to_dict()



In [None]:
# Explain the model's predictions for the selected test tweets; this renders the attribution table.
score_vis = hf_ig_analyses(text_test)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.90),,0.89,#s Last session of the day http :// tw it pic . com / 67 ez h #/s
,,,,
2.0,2 (0.99),,6.69,#s Shanghai is also really exciting ( pre cise ly -- skysc rap ers gal ore ). Good twe eps in China : ( SH ) ( B J ). #/s
,,,,
0.0,0 (0.91),,2.36,"#s Rec ession hit Ver on ique Br anqu inho , she has to quit her company , such a shame ! #/s"
,,,,
2.0,2 (0.98),,3.06,#s happy b day ! #/s
,,,,
2.0,2 (0.98),,4.1,#s http :// tw it pic . com / 4 w 75 p - I like it !! #/s
,,,,


<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=0f4a837a-38ea-4fe2-a29c-4b77ec476727' target="_blank">
 </img>
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>