In [1]:
!pip install captum

Collecting captum
  Downloading captum-0.5.0-py3-none-any.whl (1.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m0m
Installing collected packages: captum
Successfully installed captum-0.5.0
[0m

In [2]:
import csv
import torch
from datasets import Dataset
import transformers
from transformers import (
  AdamW,
  BertConfig,
  BertModel,
  BertTokenizer,
  DistilBertTokenizer,
  DistilBertModel,
  DistilBertForSequenceClassification,
  BertForSequenceClassification)
from torch.utils.data import DataLoader
import torch.nn as nn
import os
import captum
from captum.attr import IntegratedGradients, Occlusion, LayerGradCam, LayerAttribution
from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients

import numpy as np

In [3]:
finegrained_sentiments_dict = {
"anger": ["anger", "annoyance", "disapproval"],
"disgust": ["disgust"],
"fear": ["fear", "nervousness"],
"joy": ["joy", "amusement", "approval", "excitement", "gratitude",  "love", "optimism", "relief", "pride", "admiration", "desire", "caring"],
"sadness": ["sadness", "disappointment", "embarrassment", "grief",  "remorse"],
"surprise": ["surprise", "realization", "confusion", "curiosity"]
}

In [4]:
!ls ../

input  lib  working


In [5]:
DATA_DIR = "../input/emotionclss/"
train = {"input": [], "labels": []}
dev = {"input": [], "labels": []}
test = {"input": [], "labels": []}

with open(DATA_DIR + "train.tsv") as file:
    tsv_file = csv.reader(file, delimiter="\t") 
    for line in tsv_file:
        train["input"].append(line[0])
        labels = line[1].split(",")
        one_hot = [0 for i in range(28)]
        for label in labels:
            one_hot[int(label)] = 1
        train["labels"].append(one_hot)

with open(DATA_DIR + "dev.tsv") as file:
    tsv_file = csv.reader(file, delimiter="\t") 
    for line in tsv_file:
        dev["input"].append(line[0])
        labels = line[1].split(",")
        one_hot = [0 for i in range(28)]
        for label in labels:
            one_hot[int(label)] = 1
        dev["labels"].append(one_hot)

with open(DATA_DIR + "test.tsv") as file:
    tsv_file = csv.reader(file, delimiter="\t") 
    for line in tsv_file:
        test["input"].append(line[0])
        labels = line[1].split(",")
        one_hot = [0 for i in range(28)]
        for label in labels:
            one_hot[int(label)] = 1
        test["labels"].append(one_hot)
        
print("Number of train examples are {}".format(len(train["input"])))
print("Number of dev examples are {}".format(len(dev["input"])))
print("Number of test examples are {}".format(len(test["input"])))

Number of train examples are 43410
Number of dev examples are 5426
Number of test examples are 5427


In [6]:
# Creating higgingface datasets
train_dataset = Dataset.from_dict(train)
dev_dataset = Dataset.from_dict(dev)
test_dataset = Dataset.from_dict(test)

print(train_dataset)

Dataset({
    features: ['input', 'labels'],
    num_rows: 43410
})


In [7]:
from torch.utils.data import Dataset
class LoadData(Dataset):
    """
    Using this since dataloader expects map-style dataset objects
    
    """
    
    def __init__(
        self, dataset, tokenizer, source_length):
        """
        Initializes a Dataset class

        Args:
            dataset (Dataset object): Input Dataset
            tokenizer (Tokenizer object): Transformer tokenizer
            source_length (int): Max length of source text
        """
        
        self.tokenizer = tokenizer
        self.data = dataset
        self.source_length = source_length
        self.source_text = self.data["input"]
        self.target_labels = self.data["labels"]

    def __len__(self):
        return len(self.target_labels)

    def __getitem__(self, index):
        """
        return input ids, attention masks and target ids
        
        """
        source_text = str(self.source_text[index])
        target_label = self.target_labels[index]

        # cleaning data so as to ensure data is in string type
        source_text = " ".join(source_text.split())

        source = self.tokenizer.__call__(
            [source_text],
            max_length=self.source_length,
            pad_to_max_length=True,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        
        target = torch.tensor(target_label)

        source_ids = source["input_ids"].squeeze()
        source_mask = source["attention_mask"].squeeze()

        return {
            "source_ids": source_ids.to(dtype=torch.long),
            "source_mask": source_mask.to(dtype=torch.long),
            "target": target.squeeze().to(dtype=torch.long)
        }

In [8]:
#joeddav/distilbert-base-uncased-go-emotions-student
parameters = {"model": "bhadresh-savani/bert-base-go-emotion",  # model_type: t5-base/t5-large
    "train_bs": 8,  # training batch size
    "val_bs": 10,  # validation batch size
    "test_bs": 15,
    "epochs": 3,  # number of training epochs
    "lr": 6e-4,  # learning rate
    "wd": 0.0001,
    "max_source_length": 512,  # max length of source text
    "SEED": 42,
    "out_dir": "./",
    "hidden_size": 768,
    "num_classes": 28}

index_label = {0:"admiration", 1:"amusement", 2:"anger", 3:"annoyance", 4:"approval", 5:"caring", 6:"confusion",
            7:"curiosity", 8:"desire", 9:"disappointment", 10:"disapproval", 11:"disgust", 12:"embarrassment",
            13:"excitement", 14:"fear", 15:"gratitude", 16:"grief", 17:"joy", 18:"love", 19:"nervousness",
            20:"optimism", 21:"pride", 22:"realization", 23:"relief", 24:"remorse", 25:"sadness",
            26:"surprise", 27:"neutral"}
label_list = list(index_label.values())

In [9]:
def compute_test_outputs(model, test_dataloader, tokenizer, device, label_list, index_label):
    predictions = []
    labels = []
    
    with torch.no_grad():
        steps = 0
        for test_batch in test_dataloader:
            y = test_batch['target'].to(device, dtype = torch.float32)
            ids = test_batch['source_ids'].to(device, dtype = torch.long)
            mask = test_batch['source_mask'].to(device, dtype = torch.long)

            output = model(
                input_ids=ids,
                attention_mask=mask,
            )
            
            output = output["logits"]
            output = torch.sigmoid(output)
            
            predictions.extend(output.detach().cpu().numpy())
            labels.extend(y.detach().cpu().numpy())
            if steps == 5: break
    
    return predictions, labels

In [10]:
cuda =  torch.cuda.is_available()
device = torch.device("cuda") if cuda else torch.device("cpu")

tokenizer = DistilBertTokenizer.from_pretrained(parameters["model"])
model = BertForSequenceClassification.from_pretrained(parameters["model"])
model = model.to(device)

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/333 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/418M [00:00<?, ?B/s]

In [11]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator at the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the word sequence

In [32]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id, device):
    text_ids = tokenizer.encode(text, add_special_tokens=False)

    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]

    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * (len(input_ids)-2) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(input_ids)

def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

def predict(inputs, attention_mask=None):
    output = model(inputs, attention_mask=attention_mask, )
    output = output["logits"]
    output = torch.sigmoid(output)
    return output

def forward_func(inputs, i, device, attention_mask=None):
    pred = predict(inputs,
                   attention_mask=attention_mask)
    #return pred.max(1).values
    pred = torch.index_select(pred, 1, torch.tensor([i], device=device))
    return pred

def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [53]:
# preparing data
data = []
for i,text in enumerate(test_dataset["input"][0:100]):
    out = test_dataset["labels"][i]
    input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id, device)
    attention_mask = construct_attention_mask(input_ids)
    out = torch.tensor([out], device=device)

    indices = input_ids[0].detach().tolist()
    all_tokens = tokenizer.convert_ids_to_tokens(indices)

    #print(len(out[0]), len(input_ids[0]), len(ref_input_ids[0]), len(attention_mask[0]))
    data.append((out, input_ids, ref_input_ids, attention_mask, all_tokens))

In [67]:
lig = LayerIntegratedGradients(forward_func, model.bert.embeddings)  # embeddings is the first layer
attributions = []
predictions = []
for j,sample in enumerate(data[0:4]):
    pred = predict(sample[1], attention_mask=sample[3])
    
    classes = []
    for i in range(0,28):
        attribution, delta = lig.attribute(inputs=sample[1],
                                      baselines=sample[2],
                                      additional_forward_args=(i, device, sample[3]),
                                      return_convergence_delta=True)
        attribution = summarize_attributions(attribution).detach().cpu().numpy()
        attribution = [(attr,k) for k,attr in enumerate(attribution)]
        attribution.sort(key=(lambda x: x[0]), reverse=True)
        classes.append(attribution)
    attributions.append(classes)
    predictions.append(pred)

In [70]:
print(len(attributions[1][0]))
print(len(data[1][-1]))

16
16


In [72]:
print("Predicted label is:", index_label[np.argmax(predictions[3][0].detach().cpu().numpy())], np.argmax(predictions[3][0].detach().cpu().numpy()))
print("printing more valuable 3 words for each class:")
for i in range(28):
    print("Now we have class {}!!!!".format(index_label[i]))
    words = []
    for k in range(5):
        words.append(data[3][-1][attributions[3][i][k][1]])
    print(tuple(words))

Predicted label is: gratitude 15
printing more valuable 3 words for each class:
Now we have class admiration!!!!
(',', "'", '!', 'thank', 'today')
Now we have class amusement!!!!
('i', '!', 'me', "'", ',')
Now we have class anger!!!!
('t', 'me', 'didn', 'teaching', '!')
Now we have class annoyance!!!!
('t', 'didn', 'me', "'", '!')
Now we have class approval!!!!
("'", 'know', 'you', 'today', ',')
Now we have class caring!!!!
('you', 'thank', "'", ',', 'for')
Now we have class confusion!!!!
('know', 'didn', 't', 'something', 'i')
Now we have class curiosity!!!!
('know', 'didn', 'something', 'me', 'for')
Now we have class desire!!!!
('i', 'something', 'for', 'you', 'today')
Now we have class disappointment!!!!
('t', 'didn', 'me', 'teaching', 'for')
Now we have class disapproval!!!!
('t', "'", 'didn', 'teaching', '[CLS]')
Now we have class disgust!!!!
('t', 'didn', 'me', "'", 'teaching')
Now we have class embarrassment!!!!
('didn', 'me', 'i', 't', 'teaching')
Now we have class excitement!!

In [18]:
start_position_vis = viz.VisualizationDataRecord(
                        attributions_start_sum,
                        torch.max(pred),
                        torch.argmax(pred),
                        torch.argmax(pred),
                        str(14),
                        attributions_start_sum.sum(),
                        all_tokens,
                        delta_start)

print('\033[1m', 'Visualizations For Start Position', '\033[0m')
viz.visualize_text([start_position_vis])

[1m Visualizations For Start Position [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
20.0,20 (0.42),14.0,1.4,"[CLS] kings fan here , good luck to you guys ! will be an interesting game to watch ! [SEP]"
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
20.0,20 (0.42),14.0,1.4,"[CLS] kings fan here , good luck to you guys ! will be an interesting game to watch ! [SEP]"
,,,,
