In [1]:
from transformers import BertTokenizer
import torch
import numpy as np
import os
import json
import random
import matplotlib.pyplot as plt
import tqdm

from torch import nn
from transformers import BertModel
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import random_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
with open('dev.json') as fp:
    stereoSet = json.load(fp)["data"]

all_example = []
all_label = []


inter_set = stereoSet["intersentence"]
for set_exmaple in inter_set:
    context = set_exmaple["context"]
    sentences = set_exmaple["sentences"]
    for sen in sentences:
        all_example.append(context + " " + sen["sentence"])
        if sen["gold_label"] == "unrelated":
            all_label.append(0)
        if sen["gold_label"] == "stereotype":
            all_label.append(1)
        if sen["gold_label"] == "anti-stereotype":
            all_label.append(2)


intra_set = stereoSet["intrasentence"]

for set_exmaple in intra_set:
    context = set_exmaple["context"]
    sentences = set_exmaple["sentences"]
    for sen in sentences:
        all_example.append(context + " " + sen["sentence"])
        if sen["gold_label"] == "unrelated":
            all_label.append(0)
        if sen["gold_label"] == "stereotype":
            all_label.append(1)
        if sen["gold_label"] == "anti-stereotype":
            all_label.append(2)


print(len(all_example))
print(len(all_label))

12687
12687


In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

class StereoDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len     

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            max_length=self.max_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long),
        }



In [5]:
indices = list(range(len(all_example)))
random.shuffle(indices)

train_split = int(len(indices) * 0.7)
val_split = int(len(indices) * 0.15)

train_idx = indices[:train_split]
val_idx = indices[train_split : train_split+val_split]
test_idx = indices[train_split+val_split : ]


train_sentences = [all_example[i] for i in train_idx]
train_labels = [all_label[i] for i in train_idx]
val_sentences = [all_example[i] for i in val_idx]
val_labels = [all_label[i] for i in val_idx]
test_sentences = [all_example[i] for i in test_idx]
test_labels = [all_label[i] for i in test_idx]

In [6]:
train_dataset = StereoDataset(train_sentences, train_labels, tokenizer, 32)
val_dataset = StereoDataset(val_sentences, val_labels, tokenizer, 32)
test_dataset = StereoDataset(test_sentences, test_labels, tokenizer, 32)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [15]:
class BertClassifier(nn.Module):
    def __init__(self, num_out):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_out)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_output)
        return logits

In [37]:
num_out = 3
model = BertClassifier(num_out)

optimizer = AdamW(model.parameters(), lr=5e-6)
loss_fn = nn.CrossEntropyLoss()

model.to(device)
loss_fn.to(device)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


CrossEntropyLoss()

In [None]:
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    total_acc = 0

    for inputs in tqdm(train_loader):
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        labels = inputs["label"].to(device)

        optimizer.zero_grad()
        output = model(input_ids, attention_mask)
        loss = loss_fn(output, labels)

        total_loss += loss.item()
        total_acc += (output.argmax(dim=1) == labels).sum().item()

        loss.backward()
        optimizer.step()

    total_dev_loss = 0
    total_dev_acc = 0

    with torch.no_grad():
        model.eval()
        for inputs in tqdm(val_loader):
            input_ids = inputs["input_ids"].to(device)
            attention_mask = inputs["attention_mask"].to(device)
            labels = inputs["label"].to(device)

            output = model(input_ids, attention_mask)
            loss = loss_fn(output, labels)

            total_dev_loss += loss.item()
            total_dev_acc += (output.argmax(dim=1) == labels).sum().item()

    print(f'Epochs: {epoch + 1}, Train Loss: {total_loss / len(train_dataset): .3f}, Train Accuracy: {total_acc / len(train_dataset): .3f}, Val Loss: {total_dev_loss / len(val_dataset): .3f}, Val Accuracy: {total_dev_acc / len(val_dataset): .3f}')



100%|██████████| 555/555 [00:21<00:00, 26.20it/s]
100%|██████████| 119/119 [00:01<00:00, 74.93it/s]


Epochs: 1, Train Loss:  0.049, Train Accuracy:  0.607, Val Loss:  0.038, Val Accuracy:  0.698


100%|██████████| 555/555 [00:21<00:00, 25.99it/s]
100%|██████████| 119/119 [00:01<00:00, 73.28it/s]


Epochs: 2, Train Loss:  0.033, Train Accuracy:  0.756, Val Loss:  0.036, Val Accuracy:  0.726


100%|██████████| 555/555 [00:21<00:00, 25.33it/s]
100%|██████████| 119/119 [00:01<00:00, 76.83it/s]


Epochs: 3, Train Loss:  0.027, Train Accuracy:  0.805, Val Loss:  0.036, Val Accuracy:  0.733


100%|██████████| 555/555 [00:21<00:00, 25.60it/s]
100%|██████████| 119/119 [00:01<00:00, 72.24it/s]


Epochs: 4, Train Loss:  0.022, Train Accuracy:  0.845, Val Loss:  0.037, Val Accuracy:  0.738


100%|██████████| 555/555 [00:21<00:00, 25.94it/s]
100%|██████████| 119/119 [00:01<00:00, 74.22it/s]

Epochs: 5, Train Loss:  0.018, Train Accuracy:  0.873, Val Loss:  0.038, Val Accuracy:  0.758





In [40]:
total_test_acc = 0

with torch.no_grad():
    for inputs in tqdm(test_loader):
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        labels = inputs["label"].to(device)

        output = model(input_ids, attention_mask)
        
        total_test_acc += (output.argmax(dim=1) == labels).sum().item()

print(f'Test Accuracy: {total_test_acc / len(test_dataset): .3f}')

100%|██████████| 119/119 [00:01<00:00, 74.71it/s]

Test Accuracy:  0.750





In [2]:
annotations_path = []
dir = './BASIL/annotations/'
for folder in os.listdir(dir):
    folder_path = os.path.join(dir, folder)
    for filename in os.listdir(folder_path):
      annotations_path.append(os.path.join(folder_path, filename))


articles_path = []
dir = './BASIL/articles/'
for folder in os.listdir(dir):
    folder_path = os.path.join(dir, folder)
    for filename in os.listdir(folder_path):
      articles_path.append(os.path.join(folder_path, filename))

In [3]:
annotations_path = sorted(annotations_path)
articles_path = sorted(articles_path)

print(annotations_path)
print(articles_path)
annotations = []
articles = []
for path in sorted(annotations_path):
  with open(path, 'r') as fp:
    annotations.append(json.load(fp))

for path in sorted(articles_path):
  with open(path, 'r') as fp:
    articles.append(json.load(fp))

['./BASIL/annotations/2010\\2b95d2cf-e979-4f9c-ae27-9a5370934f23_1_ann.json', './BASIL/annotations/2010\\2b95d2cf-e979-4f9c-ae27-9a5370934f23_2_ann.json', './BASIL/annotations/2010\\2b95d2cf-e979-4f9c-ae27-9a5370934f23_3_ann.json', './BASIL/annotations/2010\\38f7cbb7-5d6a-4c89-bcbd-8e164144172a_1_ann.json', './BASIL/annotations/2010\\38f7cbb7-5d6a-4c89-bcbd-8e164144172a_2_ann.json', './BASIL/annotations/2010\\38f7cbb7-5d6a-4c89-bcbd-8e164144172a_3_ann.json', './BASIL/annotations/2010\\45bd61bc-c356-4450-9e3a-cbfc862b09fd_1_ann.json', './BASIL/annotations/2010\\45bd61bc-c356-4450-9e3a-cbfc862b09fd_2_ann.json', './BASIL/annotations/2010\\45bd61bc-c356-4450-9e3a-cbfc862b09fd_3_ann.json', './BASIL/annotations/2010\\6b541575-99b1-40d2-8730-9bb868ee38ed_1_ann.json', './BASIL/annotations/2010\\6b541575-99b1-40d2-8730-9bb868ee38ed_2_ann.json', './BASIL/annotations/2010\\6b541575-99b1-40d2-8730-9bb868ee38ed_3_ann.json', './BASIL/annotations/2010\\6f95dcb9-e960-45ac-8c0e-91b85724c909_1_ann.json'

In [9]:
all_sentences = []
all_labels = []

for i in range(len(articles)):
    paragraphs = articles[i]["body-paragraphs"]
    sentences = [sent for para in paragraphs for sent in para]
    annotats = annotations[i]["phrase-level-annotations"]
    labels = [0 for _ in range(len(sentences))]
    for annot in annotats:
        if annot["id"][0] == 'p':
            id = int(annot["id"][1:])
            polarity = annot['polarity']
            if polarity == 'neg':
                labels[id] = 1
            elif polarity == 'pos':
                labels[id] = 2
    all_sentences.append(sentences)
    all_labels.append(labels)

sentence_data = [sent for sublist in all_sentences for sent in sublist]
label_data = [label for sublist in all_labels for label in sublist]

In [10]:
indices = list(range(len(sentence_data)))
random.shuffle(indices)

train_split = int(len(indices) * 0.7)
val_split = int(len(indices) * 0.15)

train_idx = indices[:train_split]
val_idx = indices[train_split : train_split+val_split]
test_idx = indices[train_split+val_split : ]


train_sentences = [sentence_data[i] for i in train_idx]
train_labels = [label_data[i] for i in train_idx]
val_sentences = [sentence_data[i] for i in val_idx]
val_labels = [label_data[i] for i in val_idx]
test_sentences = [sentence_data[i] for i in test_idx]
test_labels = [label_data[i] for i in test_idx]

In [11]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

class BasilDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len     

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            max_length=self.max_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long),
        }



In [12]:
train_basil_dataset = BasilDataset(train_sentences, train_labels, tokenizer, 128)
val_basil_dataset = BasilDataset(val_sentences, val_labels, tokenizer, 128)
test_basil_dataset = BasilDataset(test_sentences, test_labels, tokenizer, 128)

train_basil_loader = DataLoader(train_basil_dataset, batch_size=16, shuffle=True)
val_basil_loader = DataLoader(val_basil_dataset, batch_size=16, shuffle=False)
test_basil_loader = DataLoader(test_basil_dataset, batch_size=16, shuffle=False)

In [16]:
num_out = 3
model_basil = BertClassifier(num_out)

optimizer = AdamW(model_basil.parameters(), lr=5e-6)
loss_fn = nn.CrossEntropyLoss()

model_basil.to(device)
loss_fn.to(device)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


CrossEntropyLoss()

In [None]:
num_epochs = 5

for epoch in range(num_epochs):
    model_basil.train()
    total_loss = 0
    total_acc = 0

    for inputs in tqdm(train_basil_loader):
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        labels = inputs["label"].to(device)

        optimizer.zero_grad()
        output = model_basil(input_ids, attention_mask)
        loss = loss_fn(output, labels)

        total_loss += loss.item()
        total_acc += (output.argmax(dim=1) == labels).sum().item()

        loss.backward()
        optimizer.step()

    total_dev_loss = 0
    total_dev_acc = 0

    with torch.no_grad():
        model.eval()
        for inputs in tqdm(val_basil_loader):
            input_ids = inputs["input_ids"].to(device)
            attention_mask = inputs["attention_mask"].to(device)
            labels = inputs["label"].to(device)

            output = model_basil(input_ids, attention_mask)
            loss = loss_fn(output, labels)

            total_dev_loss += loss.item()
            total_dev_acc += (output.argmax(dim=1) == labels).sum().item()

    print(f'Epochs: {epoch + 1}, Train Loss: {total_loss / len(train_basil_dataset): .3f}, Train Accuracy: {total_acc / len(train_basil_dataset): .3f}, Val Loss: {total_dev_loss / len(val_basil_dataset): .3f}, Val Accuracy: {total_dev_acc / len(val_basil_dataset): .3f}')



100%|██████████| 350/350 [00:32<00:00, 10.82it/s]
100%|██████████| 75/75 [00:02<00:00, 28.97it/s]


Epochs: 1, Train Loss:  0.039, Train Accuracy:  0.778, Val Loss:  0.035, Val Accuracy:  0.811


100%|██████████| 350/350 [00:30<00:00, 11.38it/s]
100%|██████████| 75/75 [00:02<00:00, 29.70it/s]


Epochs: 2, Train Loss:  0.034, Train Accuracy:  0.801, Val Loss:  0.032, Val Accuracy:  0.818


100%|██████████| 350/350 [00:30<00:00, 11.41it/s]
100%|██████████| 75/75 [00:02<00:00, 29.64it/s]


Epochs: 3, Train Loss:  0.030, Train Accuracy:  0.824, Val Loss:  0.031, Val Accuracy:  0.824


100%|██████████| 350/350 [00:30<00:00, 11.41it/s]
100%|██████████| 75/75 [00:02<00:00, 29.93it/s]


Epochs: 4, Train Loss:  0.025, Train Accuracy:  0.853, Val Loss:  0.031, Val Accuracy:  0.835


100%|██████████| 350/350 [00:31<00:00, 11.20it/s]
100%|██████████| 75/75 [00:02<00:00, 28.41it/s]

Epochs: 5, Train Loss:  0.019, Train Accuracy:  0.888, Val Loss:  0.034, Val Accuracy:  0.810





In [29]:
total_test_acc = 0

with torch.no_grad():
    model_basil.eval()
    for inputs in tqdm(test_basil_loader):
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        labels = inputs["label"].to(device)

        output = model_basil(input_ids, attention_mask)
        
        total_test_acc += (output.argmax(dim=1) == labels).sum().item()

print(f'Test Accuracy: {total_test_acc / len(test_basil_dataset): .3f}')

100%|██████████| 75/75 [00:02<00:00, 30.17it/s]

Test Accuracy:  0.802





In [48]:
import captum

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

sentences = ["The recent protests led by women once again turned chaotic, with several reports of emotional outbursts and irrational behavior.", "He will even be considered like a traitor because we don\u2019t accept to be offended like we just heard, and it should not happen."]

inputs = tokenizer(sentences, padding='max_length', max_length=128, truncation=True, return_tensors="pt")
input_ids = inputs['input_ids']
attention_masks = inputs['attention_mask']

print(input_ids)
baseline = torch.zeros_like(input_ids)
baseline[input_ids == 101] = 101
baseline[input_ids == 102] = 102
print(baseline)

tensor([[  101,  1109,  2793,  7853,  1521,  1118,  1535,  1517,  1254,  1454,
         22911,   117,  1114,  1317,  3756,  1104,  6438,  1149, 22224,  1116,
          1105, 27447,  4658,   119,   102,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  

In [49]:
net = model_basil

lig = captum.attr.LayerIntegratedGradients(net, net.bert.embeddings)

target = torch.tensor([1, 1])

net = net.to('cpu')
attributions, delta = lig.attribute(inputs=(input_ids, attention_masks), target = target, baselines = (baseline, attention_masks), n_steps=50, return_convergence_delta=True)


In [50]:
print(attributions.shape)
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

attributions_sum = summarize_attributions(attributions)

torch.Size([2, 128, 768])


In [51]:
logits = net(input_ids, attention_masks)
predicted_labels = torch.argmax(logits, dim=1)
print(logits)
print(predicted_labels)

tensor([[ 0.7248,  0.8876, -2.7444],
        [-0.7599,  2.1960, -2.0471]], grad_fn=<AddmmBackward0>)
tensor([1, 1])


In [52]:
from captum.attr import visualization as viz

viz_records = []
for i in range(len(sentences)):
    # Get predicted label for current input sentence
    logits = net(input_ids[i].unsqueeze(0), attention_masks[i].unsqueeze(0))
    predicted_label = torch.argmax(logits, dim=1).item()
    
    # Normalize attributions for current input sentence
    normalized_attributions = attributions[i]
    raw_inputs = tokenizer.convert_ids_to_tokens(input_ids[i])
    raw_inputs = [t for t in raw_inputs if t != '[PAD]']
    
    # Create VisualizationDataRecord object
    record = viz.VisualizationDataRecord(
        word_attributions=attributions_sum[i].tolist(),
        pred_class=predicted_labels[i].item(),
        pred_prob=logits[0][predicted_label].item(),
        true_class=None,
        attr_class=None,
        raw_input_ids=raw_inputs,
        attr_score=torch.sum(attributions_sum[i]),
        convergence_score=torch.sum(torch.abs(attributions_sum[i])).tolist()
    )
    viz_records.append(record)

# Create visualization
viz.visualize_text(viz_records)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
,1 (0.89),,0.61,"[CLS] The recent protests led by women once again turned chaotic , with several reports of emotional out ##burst ##s and irrational behavior . [SEP]"
,,,,
,1 (2.20),,1.5,"[CLS] He will even be considered like a traitor because we don ’ t accept to be offended like we just heard , and it should not happen . [SEP]"
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
,1 (0.89),,0.61,"[CLS] The recent protests led by women once again turned chaotic , with several reports of emotional out ##burst ##s and irrational behavior . [SEP]"
,,,,
,1 (2.20),,1.5,"[CLS] He will even be considered like a traitor because we don ’ t accept to be offended like we just heard , and it should not happen . [SEP]"
,,,,
