<a href="https://colab.research.google.com/github/XuTiany1/Explainable_Misinformation_Detection/blob/main/Explainable_Misinformation_Detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import Libraries

In [9]:
!pip install transformers datasets matplotlib captum

import torch
from transformers import BertTokenizer, BertForSequenceClassification, GPT2Tokenizer, GPT2LMHeadModel
from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
from captum.attr import IntegratedGradients, visualization

Collecting captum
  Downloading captum-0.7.0-py3-none-any.whl.metadata (26 kB)
Downloading captum-0.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: captum
Successfully installed captum-0.7.0


# Load and Preprocess the LIAR Dataset

In [10]:
# Fetch the LIAR dataset by name through `datasets.load_dataset`
dataset = load_dataset(path='liar')

# Preprocess the dataset
def preprocess_data(
    example,
    label_names=('false', 'half-true', 'mostly-true', 'true', 'barely-true', 'pants-fire'),
):
    """Collapse a six-way LIAR label into a binary misinformation label.

    The result is 1 (misinformation) for 'pants-fire', 'false' and
    'barely-true', and 0 for 'half-true', 'mostly-true' and 'true'.

    The `datasets` library hands a `ClassLabel` column to `map` as integer class
    ids, not as strings, so comparing the raw value against a list of label
    names never matches. Integer ids are therefore translated to names first;
    string labels are accepted unchanged.

    Args:
        example: A single dataset row (mapping) containing a 'label' entry that
            is either a label name or an integer index into `label_names`.
        label_names: Class names indexed by integer label id. The default is
            assumed to match the class order declared by the LIAR dataset;
            override it if the loaded dataset reports a different order.

    Returns:
        The same `example` mapping with 'label' overwritten by 0 or 1.

    Note:
        This is not idempotent for integer labels: a value that has already
        been binarised would be interpreted as a class id again. Apply it once
        to freshly loaded data.
    """
    misinfo_labels = ('pants-fire', 'false', 'barely-true')
    label = example['label']
    if not isinstance(label, str):
        # Integer class id -> human-readable class name.
        label = label_names[label]
    example['label'] = 1 if label in misinfo_labels else 0
    return example

# Run the label binarisation over every split of the dataset
dataset = dataset.map(function=preprocess_data)

# Keep handles on the two splits used below for fine-tuning and evaluation
train_dataset, test_dataset = dataset['train'], dataset['test']

Map:   0%|          | 0/10269 [00:00<?, ? examples/s]

Map:   0%|          | 0/1283 [00:00<?, ? examples/s]

Map:   0%|          | 0/1284 [00:00<?, ? examples/s]

# Train/Fine-Tune Model

In [None]:
# Load the pretrained BERT tokenizer and a sequence-classification model
# built from the same 'bert-base-uncased' checkpoint
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')
model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

# Tokenize the data
def tokenize_function(examples):
    """Tokenize the 'statement' field of a batch of examples.

    Every statement is truncated and padded to exactly 128 tokens so that the
    resulting rows can be stacked into fixed-size tensors.

    Args:
        examples: A batch mapping with a 'statement' entry, as supplied by
            `Dataset.map(..., batched=True)`.

    Returns:
        The output of the module-level `tokenizer` for those statements.
    """
    statements = examples['statement']
    return tokenizer(
        statements,
        max_length=128,
        padding='max_length',
        truncation=True,
    )

from torch.utils.data import DataLoader

# Tokenize both splits batch-wise
tokenized_train = train_dataset.map(tokenize_function, batched=True)
tokenized_test = test_dataset.map(tokenize_function, batched=True)

# Expose only the model inputs and the label, as torch tensors
tokenized_train.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
tokenized_test.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

# Mini-batch iterators: training data is reshuffled, test data keeps its order
train_dataloader = DataLoader(dataset=tokenized_train, batch_size=8, shuffle=True)
test_dataloader = DataLoader(dataset=tokenized_test, batch_size=8)

# Fine-tune the model
# The AdamW class exported by `transformers` is deprecated and missing from
# recent releases, so the PyTorch implementation is used instead.
from torch.optim import AdamW

# `eps` and `weight_decay` are set explicitly to the values the deprecated
# transformers optimizer used by default, so the optimisation itself is unchanged
# (torch's own defaults are eps=1e-8 and weight_decay=0.01).
optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Training loop
epochs = 1  # Set to 1 for quick demonstration; increase for better performance
for epoch in range(epochs):
    model.train()  # training mode for the whole epoch
    for batch in train_dataloader:
        # Move the batch to the training device; the label column is handed
        # to the model separately through its `labels` argument.
        labels = batch['label'].to(device)
        inputs = {
            name: tensor.to(device)
            for name, tensor in batch.items()
            if name != 'label'
        }

        # One optimisation step: reset gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(labels=labels, **inputs)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/10269 [00:00<?, ? examples/s]

Map:   0%|          | 0/1283 [00:00<?, ? examples/s]



# Attention Visualization

In [None]:
def visualize_attention(text):
    """Plot the head-averaged last-layer attention matrix for `text`.

    The text is tokenized with the module-level `tokenizer` (truncated to 128
    tokens, the same limit used when tokenizing the training data) and passed
    through the module-level `model`. The attention weights of the last layer
    are averaged over heads and drawn as a token-by-token heatmap.

    Args:
        text: The statement to inspect.

    Returns:
        None. The figure is displayed with `plt.show()`.

    Side effects:
        Switches `model` to evaluation mode and leaves it there.
    """
    # Dropout must be disabled, otherwise the attention map is randomly
    # perturbed whenever this runs right after the training loop.
    model.eval()

    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    attentions = outputs.attentions  # Tuple of attention weights from all layers

    # For simplicity, use the attention weights from the last layer
    attention = attentions[-1][0]  # Shape: [Heads, Tokens, Tokens]
    attention = attention.mean(dim=0)  # Average over heads

    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    positions = range(len(tokens))

    # Plot heatmap
    fig, ax = plt.subplots(figsize=(10, 10))
    image = ax.imshow(attention.cpu().numpy(), cmap='hot', interpolation='nearest')
    ax.set_xticks(positions)
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(positions)
    ax.set_yticklabels(tokens)
    ax.set_xlabel('Attended-to token')
    ax.set_ylabel('Attending token')
    ax.set_title('Last-layer attention (mean over heads)')
    fig.colorbar(image, ax=ax)
    plt.show()

In [None]:
# Example statement reused by the demo cells, followed by its attention heatmap

sample_text = "The Earth is flat and NASA faked the moon landing."
visualize_attention(text=sample_text)

# Counterfactual Generation with GPT-2

In [None]:
# Load the GPT-2 tokenizer and language model, placing the model on `device`
gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)

def generate_counterfactual(text, max_new_tokens=80):
    """Ask GPT-2 to continue a "correct this misinformation" prompt.

    Args:
        text: The statement to be corrected.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt. Bounding the new tokens (rather than the total length)
            keeps room for the continuation however long `text` is.

    Returns:
        The generated continuation with surrounding whitespace stripped.
    """
    prompt = f"Correct the following misinformation: {text}\nCorrection:"

    # Calling the tokenizer (instead of `.encode`) also yields the attention
    # mask, which `generate` needs to treat the prompt tokens reliably.
    encoded = gpt2_tokenizer(prompt, return_tensors='pt').to(device)
    prompt_length = encoded['input_ids'].shape[1]

    outputs = gpt2_model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        # GPT-2 defines no padding token; reuse end-of-sequence for padding.
        pad_token_id=gpt2_tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens. Slicing by prompt length is robust
    # even when the marker "Correction:" occurs in the input or in the output.
    generated_ids = outputs[0][prompt_length:]
    correction = gpt2_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    return correction

# Produce a GPT-2 "correction" for the sample statement and show both texts
counterfactual = generate_counterfactual(text=sample_text)
print("Original:", sample_text)
print("Counterfactual:", counterfactual)

# Saliency Maps with Integrated Gradients

In [None]:
# Function to compute saliency map
def compute_saliency(model, text):
    """Show per-token Integrated Gradients attributions for the misinformation class.

    Token ids are discrete, so gradients cannot be taken with respect to them
    directly. Attributions are therefore computed with Captum's
    `LayerIntegratedGradients` on the model's input embedding layer and then
    summed over the embedding dimension to obtain one score per token.

    Args:
        model: The sequence-classification model to explain; class index 1 is
            treated as "misinformation".
        text: The statement to explain. It is tokenized with the module-level
            `tokenizer` and truncated to 128 tokens.

    Returns:
        None. The attribution view is rendered through
        `captum.attr.visualization.visualize_text`.

    Side effects:
        Switches `model` to evaluation mode and leaves it there.
    """
    # Local import: the notebook's import cell only pulls in `IntegratedGradients`.
    from captum.attr import LayerIntegratedGradients

    model.eval()

    # No padding: a single sequence needs none, and padded positions would only
    # fill the visualisation with [PAD] tokens.
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=128).to(device)
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    # Reference input: the same length, made entirely of padding tokens.
    baseline = torch.full_like(input_ids, tokenizer.pad_token_id)

    def forward_func(ids, mask):
        # Return the full logits; the class to explain is chosen via `target`.
        return model(ids, attention_mask=mask).logits

    lig = LayerIntegratedGradients(forward_func, model.get_input_embeddings())
    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=baseline,
        # Passed as an extra argument so Captum expands it alongside the
        # interpolated inputs.
        additional_forward_args=(attention_mask,),
        target=1,  # the 'misinformation' class
        return_convergence_delta=True,
    )

    # Collapse the embedding dimension to one score per token, then L2-normalise.
    attributions = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(attributions)
    if norm > 0:  # guard against dividing by zero for an all-zero attribution
        attributions = attributions / norm

    tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0))

    # Model prediction, shown next to the attributions.
    with torch.no_grad():
        probs = torch.softmax(forward_func(input_ids, attention_mask), dim=-1).squeeze(0)
    pred_index = int(torch.argmax(probs))
    pred_label = 'misinformation' if pred_index == 1 else 'not misinformation'

    # Visualize
    # Arguments are positional because the keyword for the token list differs
    # between Captum releases.
    record = visualization.VisualizationDataRecord(
        attributions.detach().cpu().numpy(),  # word attributions
        probs[pred_index].item(),             # probability of the predicted class
        pred_label,                           # predicted class
        '',                                   # true class (unknown for free text)
        'misinformation',                     # class the attributions refer to
        attributions.sum().item(),            # total attribution score
        tokens,                               # tokens shown in the view
        delta,                                # convergence delta of the approximation
    )
    visualization.visualize_text([record])

In [None]:
# Render the Integrated Gradients view for the sample statement
compute_saliency(model=model, text=sample_text)

# MAIN CODE

In [None]:
def explain_misinformation(text):
    """Classify `text` and, if it is flagged, show the available explanations.

    When the module-level `model` predicts class 1 (misinformation), the
    attention heatmap, the saliency view and a GPT-2 counterfactual are
    produced in that order. Otherwise only a short message is printed.

    Args:
        text: The statement to classify and explain.

    Returns:
        None. All results are printed or displayed.

    Side effects:
        Switches `model` to evaluation mode and leaves it there.
    """
    # Prediction
    # Evaluation mode disables dropout so the same text always yields the same
    # prediction, even when this runs straight after the training loop.
    model.eval()
    # Truncate to the 128-token limit used when tokenizing the training data.
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=128).to(device)
    with torch.no_grad():  # inference only
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    pred = torch.argmax(probs, dim=-1).item()

    if pred == 1:
        print("Misinformation detected.")
        # Attention Visualization
        visualize_attention(text)
        # Saliency Map
        compute_saliency(model, text)
        # Counterfactual Generation
        counterfactual = generate_counterfactual(text)
        print("Counterfactual:", counterfactual)
    else:
        print("No misinformation detected.")

# Run the full classify-and-explain pipeline on the sample statement
explain_misinformation(text=sample_text)