In [None]:
from __future__ import print_function, division
import os
import torch
import random
from torchvision.transforms import ToTensor, ToPILImage
import zipfile
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import RandomSampler, Sampler, Subset
from torchvision import transforms, utils
import torch.nn as nn
from tqdm import tqdm
from typing import Iterator, List, Callable, Tuple
from functools import partial
from math import *
from IPython.display import HTML
import pandas as pd
from sklearn.model_selection import train_test_split
from datetime import datetime
import time
import seaborn as sns
import random

from transformers import BertTokenizer, BertForSequenceClassification
from torch.optim import AdamW
from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_fscore_support, classification_report, roc_curve, auc, precision_recall_curve

from torch.optim.lr_scheduler import StepLR
from transformers import BatchEncoding

from matplotlib import rc, cm
rc('animation', html='jshtml')

import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
import matplotlib.animation as animation
%matplotlib notebook

# Ignore warnings: this filter silences every warning category for the rest of
# the session, including deprecation notices emitted by the libraries above.
import warnings
warnings.filterwarnings("ignore")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the three pre-split CSV files (train / validation / test) in one pass.
train_df, val_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in ('./train_data.csv', './val_data.csv', './test_data.csv')
)

In [None]:
# Build the text fed to the tokenizer as "<title> <text>".
# A missing title or body is read by pandas as NaN, and NaN + str stays NaN,
# which would later reach the tokenizer as a float instead of a string.
# Replacing missing values with '' keeps every `content` entry a str.
for _frame in (train_df, val_df, test_df):
    _frame['content'] = (
        _frame['title'].fillna('').astype(str)
        + ' '
        + _frame['text'].fillna('').astype(str)
    )

In [None]:
# WordPiece tokenizer matching the 'bert-base-uncased' checkpoint that the
# classifier below is built from.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
# Left disabled by the original author; presumably an alternative to passing
# weights_only=False when loading the cached BatchEncoding objects below.
# torch.serialization.add_safe_globals([BatchEncoding])

def _encode_contents(df):
    """Tokenize the ``content`` column of ``df`` into padded PyTorch tensors.

    Sequences are truncated to 512 tokens and padded to the longest sequence
    in the frame.
    """
    return tokenizer(df['content'].tolist(), truncation=True, padding=True, max_length=512, return_tensors='pt')


_encoding_cache_paths = (
    "./bert_train_encodings.pt",
    "./bert_val_encodings.pt",
    "./bert_test_encodings.pt",
)

# Use the cache only when ALL three files are present. Checking just the train
# file would crash with FileNotFoundError on a partially written cache.
# NOTE: the cache is not invalidated when the CSV files change; delete the
# .pt files to force re-tokenization.
if all(os.path.exists(path) for path in _encoding_cache_paths):
    # SECURITY: weights_only=False unpickles arbitrary Python objects. Only
    # load cache files that were produced by this notebook.
    train_encodings, val_encodings, test_encodings = (
        torch.load(path, weights_only=False) for path in _encoding_cache_paths
    )
else:
    train_encodings = _encode_contents(train_df)
    val_encodings = _encode_contents(val_df)
    test_encodings = _encode_contents(test_df)
    for _encodings, _path in zip((train_encodings, val_encodings, test_encodings), _encoding_cache_paths):
        torch.save(_encodings, _path)

In [None]:
class FakeNewsDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with labels.

    Args:
        encodings: Mapping with at least ``"input_ids"`` and
            ``"attention_mask"`` entries, each indexable by sample position
            (tensors or nested lists both work).
        labels: Sequence of labels, one per sample; its length defines the
            dataset length.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    @staticmethod
    def _as_new_tensor(value):
        """Return an independent tensor holding a copy of ``value``.

        ``torch.tensor(existing_tensor)`` works but emits a UserWarning on
        every call; ``detach().clone()`` is the recommended equivalent.
        Non-tensor values (ints, lists) still go through ``torch.tensor``.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        return torch.tensor(value)

    def __getitem__(self, idx):
        """Return a dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors for sample ``idx``."""
        return {
            'input_ids': self._as_new_tensor(self.encodings["input_ids"][idx]),
            'attention_mask': self._as_new_tensor(self.encodings["attention_mask"][idx]),
            'labels': self._as_new_tensor(self.labels[idx]),
        }

    def __len__(self):
        """Number of samples, taken from the label sequence."""
        return len(self.labels)

# Wrap each split's encodings together with its labels.
train_dataset = FakeNewsDataset(train_encodings, train_df['label'].tolist())
val_dataset = FakeNewsDataset(val_encodings, val_df['label'].tolist())
test_dataset = FakeNewsDataset(test_encodings, test_df['label'].tolist())

# Work on random subsets of train/val to keep runs short. The sizes are capped
# at the dataset sizes because random.sample raises ValueError when asked for
# more items than the population holds.
n_train_samples = min(4000, len(train_dataset))  # 32000
n_val_samples = min(500, len(val_dataset))  # 4000
random_indices_train = random.sample(range(len(train_dataset)), n_train_samples)
random_indices_val = random.sample(range(len(val_dataset)), n_val_samples)

train_subset = Subset(train_dataset, random_indices_train)
val_subset = Subset(val_dataset, random_indices_val)
# test_subset = Subset(test_dataset, range(500))

# The test loader iterates the full test set in its original order.
train_loader = DataLoader(train_subset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16)

In [None]:
# Pick 10 random test examples to explain. BertForSequenceClassification is
# already imported in the notebook's first cell.
sample_indices = random.sample(range(len(test_dataset)), 10)
samples = [test_dataset[i] for i in sample_indices]

# Rebuild the 2-class architecture, then load the fine-tuned weights into it.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
# The checkpoint is consumed as a state dict, so the restricted
# weights_only unpickler is sufficient and avoids executing arbitrary pickles.
model.load_state_dict(torch.load('best_model_init_num_epochs-5lr-1e-05step_size-Nonegamma-None.pth', map_location=device, weights_only=True))
model = model.to(device)
# This notebook only runs inference/attribution: switch off dropout right away.
model.eval()

In [None]:
from captum.attr import IntegratedGradients, visualization

def compute_bert_outputs(model_bert, embedding_output, attention_mask):
    """Run a BERT encoder and pooler on pre-computed embeddings.

    Args:
        model_bert: Module exposing ``encoder``, ``pooler`` and ``parameters()``.
        embedding_output: Embeddings passed as the encoder's first argument.
        attention_mask: Mask where 1 marks positions to attend to; two
            singleton dimensions are inserted after the first axis.

    Returns:
        Tuple ``(sequence_output, pooled_output, *extra_encoder_outputs)``.
    """
    # Insert two singleton axes after the batch axis, then match the model dtype.
    additive_mask = attention_mask[:, None, None]
    param_dtype = next(model_bert.parameters()).dtype
    additive_mask = additive_mask.to(dtype=param_dtype)
    # Turn the 1/0 keep-mask into an additive one: 0 where kept, -10000 where masked.
    additive_mask = (1.0 - additive_mask) * -10000.0

    encoded = model_bert.encoder(embedding_output, additive_mask)
    hidden_states = encoded[0]
    pooled = model_bert.pooler(hidden_states)
    return (hidden_states, pooled) + encoded[1:]


class BertModelWrapper(nn.Module):
    """Expose a sequence classifier as a function of input embeddings.

    The wrapped ``model`` must provide ``bert``, ``dropout`` and
    ``classifier`` attributes. ``forward`` returns the softmax probability of
    class index 1 with shape ``(batch, 1)``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, embeddings, attention_mask=None):
        """Return the class-1 probability for ``embeddings``.

        When ``attention_mask`` is omitted, an all-ones mask spanning the
        first two dimensions of ``embeddings`` is used.
        """
        if attention_mask is None:
            batch_size, seq_len = embeddings.shape[0], embeddings.shape[1]
            attention_mask = torch.ones(batch_size, seq_len).to(embeddings)

        # Element 1 of the encoder outputs is the pooled representation.
        pooled = compute_bert_outputs(self.model.bert, embeddings, attention_mask)[1]
        logits = self.model.classifier(self.model.dropout(pooled))
        class_one_prob = torch.softmax(logits, dim=1)[:, 1]
        return class_one_prob.unsqueeze(1)

def add_attributions_to_visualizer(attributions, tokens, pred, pred_ind, label, delta, vis_data_records):
    """Append one Captum visualization record built from ``attributions``.

    ``attributions[0]`` is summed over dimension 2 and divided by its norm to
    give the per-token scores. ``tokens`` is truncated to the number of
    scores. The record is appended to ``vis_data_records`` in place; nothing
    is returned.
    """
    # Collapse dimension 2, then drop the leading singleton axis.
    per_token = attributions[0].sum(dim=2).squeeze(0)
    normalized = per_token / torch.norm(per_token)
    scores = normalized.cpu().detach().numpy().tolist()

    record = visualization.VisualizationDataRecord(
        scores,                    # word attributions
        pred,                      # predicted probability
        pred_ind,                  # predicted class
        label,                     # true class
        "label",                   # attribution target name
        normalized.sum().item(),   # total attribution score
        tokens[:len(scores)],      # tokens aligned with the scores
        delta,                     # convergence delta
    )
    vis_data_records.append(record)

# Explain misclassified test samples with Integrated Gradients over the
# embedding layer, then render the token-level attributions.
bert_model_wrapper = BertModelWrapper(model)
bert_model_wrapper.eval()  # disable dropout so predictions are deterministic
ig = IntegratedGradients(bert_model_wrapper)
vis_data_records_ig = []
count = 0  # number of misclassified samples explained so far

# Special-token ids do not change between samples; look them up once.
cls_token_id = tokenizer.cls_token_id
sep_token_id = tokenizer.sep_token_id
pad_token_id = tokenizer.pad_token_id

for i, sample in enumerate(samples):
    # Add a batch dimension of 1 and move the sample to the model's device.
    input_ids = sample['input_ids'].to(device).unsqueeze(0)
    attention_mask = sample['attention_mask'].to(device).unsqueeze(0)
    true_label = sample['labels'].item()

    # Build a baseline input: [CLS] + [PAD] * N + [SEP]
    # (every token becomes PAD, then the special tokens are restored in place).
    sep_position = (input_ids == sep_token_id).nonzero(as_tuple=True)[1]
    cls_position = (input_ids == cls_token_id).nonzero(as_tuple=True)[1]
    baseline_input_ids = input_ids.clone()
    baseline_input_ids.fill_(pad_token_id)
    baseline_input_ids[0, cls_position] = cls_token_id
    baseline_input_ids[0, sep_position] = sep_token_id

    bert_model_wrapper.zero_grad()
    input_embedding = bert_model_wrapper.model.bert.embeddings(input_ids)
    baseline_embedding = bert_model_wrapper.model.bert.embeddings(baseline_input_ids)

    # The wrapper returns the class-1 probability; rounding gives the predicted class.
    pred = bert_model_wrapper(input_embedding, attention_mask).item()
    pred_ind = round(pred)
    if pred_ind != true_label:
        count += 1
        # The attention mask is passed as a second input with an identical
        # baseline, so attributions[0] corresponds to the embeddings.
        attributions_ig, delta = ig.attribute((input_embedding, attention_mask), baselines=(baseline_embedding, attention_mask), n_steps=50, return_convergence_delta=True)
        print(f"prediction: {pred_ind} ({pred}), label: {true_label}, delta: {delta}")

        # Move to the CPU before converting: Tensor.numpy() raises TypeError
        # on CUDA tensors, which broke this cell whenever a GPU was in use.
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().numpy().tolist())
        add_attributions_to_visualizer(attributions_ig, tokens, pred, pred_ind, true_label, delta, vis_data_records_ig)
    # Stop once two misclassified samples have been explained.
    if count == 2:
        break
visualization.visualize_text(vis_data_records_ig)