In [None]:
import os
# Expose only physical GPU 2 to this kernel; set before torch is imported in the next cell.
os.environ.update(CUDA_VISIBLE_DEVICES="2")

In [None]:
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForSequenceClassification
from torch.nn.utils.rnn import pad_sequence
from transformers import XLMRobertaTokenizer, XLMRobertaForSequenceClassification, Trainer, TrainingArguments


In [None]:
class BertWithAttentionSupervision(nn.Module):
    """XLM-RoBERTa sequence classifier with an optional attention-supervision loss.

    Wraps ``XLMRobertaForSequenceClassification`` ("xlm-roberta-base") and adds a
    second loss term comparing the final layer's head-averaged attention *from*
    position 0 (the first/CLS token) against a caller-supplied target.
    """

    def __init__(self, num_labels=2):
        super().__init__()
        self.bert_classifier = XLMRobertaForSequenceClassification.from_pretrained("xlm-roberta-base", num_labels=num_labels)
        # NOTE(review): BCEWithLogitsLoss applies a sigmoid to its input, but forward()
        # feeds it attention weights that are already softmax-normalised (in [0, 1]).
        # nn.BCELoss or a KL/MSE loss may have been intended; kept as-is so existing
        # training runs stay reproducible.
        self.attention_loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, labels=None, ground_truth_attention=None, token_type_ids=None):
        """Classify a batch and optionally add the attention-supervision loss.

        Args:
            input_ids, attention_mask: tokenizer outputs, passed straight to the classifier.
            labels: optional class labels; when given, the classifier's loss is included.
            ground_truth_attention: optional target with the same shape as the
                head-averaged first-token attention (batch, seq). It is cast to the
                attention's dtype/device, so integer 0/1 masks are accepted.
            token_type_ids: forwarded only when provided (XLM-R usually omits them).

        Returns:
            (total_loss, logits). total_loss is the int 0 when neither labels nor
            ground_truth_attention is supplied.

        Raises:
            ValueError: if ground_truth_attention's shape does not match the attention scores.
        """
        model_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_attentions=True,
            return_dict=True,
        )
        # Only forward token_type_ids when the caller supplied them
        if token_type_ids is not None:
            model_kwargs["token_type_ids"] = token_type_ids
        outputs = self.bert_classifier(**model_kwargs)

        classification_loss = outputs.loss if labels is not None else 0  # 0 when no labels are provided
        logits = outputs.logits

        # Last layer's attentions are (batch, heads, seq, seq) per HF docs; row 0 is the
        # attention distribution of the first (CLS) token over all positions.
        attention_scores = outputs.attentions[-1][:, :, 0, :]
        avg_attention_scores = attention_scores.mean(dim=1)  # average over heads -> (batch, seq)

        attention_loss = 0
        if ground_truth_attention is not None:
            if tuple(ground_truth_attention.shape) != tuple(avg_attention_scores.shape):
                raise ValueError(
                    f"Expected ground_truth_attention of shape {tuple(avg_attention_scores.shape)}, "
                    f"but got {tuple(ground_truth_attention.shape)}"
                )
            # BCEWithLogitsLoss needs a floating target on the same device as its input
            target = ground_truth_attention.to(device=avg_attention_scores.device, dtype=avg_attention_scores.dtype)
            attention_loss = self.attention_loss_fn(avg_attention_scores, target)

        total_loss = classification_loss + attention_loss

        return total_loss, logits

In [None]:
# Load the fine-tuned classifier and switch it to inference mode.
CHECKPOINT_PATH = "/path"  # TODO: point at the saved state_dict file
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
model = BertWithAttentionSupervision(num_labels=2).to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints
# (or pass weights_only=True on torch >= 1.13).
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(state_dict)
model.eval()

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from collections import defaultdict

def calculate_logits(model, tokenizer, text, device=None):
    """Run ``model`` on ``text`` and return softmax class probabilities.

    Despite the name, the return value is probabilities (softmax over the last
    dimension of the logits), not raw logits.

    Args:
        model: module whose forward accepts ``input_ids``/``attention_mask`` keywords
            and returns ``(loss, logits)``.
        tokenizer: HF-style tokenizer callable returning ``input_ids`` and ``attention_mask``.
        text: the string (or list of strings) to classify.
        device: where to place the inputs. Defaults to the device of the model's
            first parameter (CPU if it has none), so no notebook-global is needed.

    Returns:
        Tensor of probabilities, presumably shaped (batch, num_labels).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    encoded = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    # XLM-RoBERTa needs only input_ids and attention_mask (no token_type_ids)
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    with torch.no_grad():
        _, logits = model(input_ids=input_ids, attention_mask=attention_mask)

    return torch.softmax(logits, dim=-1)

def ngrams(text, n):
    """Return every contiguous window of ``n`` whitespace-separated words in ``text``.

    Windows are lists of words, produced left to right; the result is empty when
    the text has fewer than ``n`` words.
    """
    words = text.split()
    windows = []
    last_start = len(words) - n
    start = 0
    while start <= last_start:
        windows.append(words[start:start + n])
        start += 1
    return windows

def calculate_explainability(model, tokenizer, original_text, ngram_weights=None):
    """Score each word by how far n-gram-only predictions drift from the full-text prediction.

    For every n-gram size in ``ngram_weights`` (default: unigrams, bigrams, trigrams
    over whitespace-split words) the model is run on the n-gram text alone. The mean
    absolute difference between its class probabilities and the full-text
    probabilities, scaled by the weight for that n, is added to every word in the
    n-gram. Each word's total is then divided by the number of n-grams it occurred in.

    Args:
        model, tokenizer: passed through to ``calculate_logits``.
        original_text: the text to explain.
        ngram_weights: optional mapping ``{n: weight}``; its keys select which n-gram
            sizes are scored. Defaults to ``{1: 0.5, 2: 0.3, 3: 0.2}``.

    Returns:
        defaultdict(float) mapping each distinct word to its averaged score. Repeated
        words share a single entry; the dict is empty for text with no words.
    """
    if ngram_weights is None:
        ngram_weights = {1: 0.5, 2: 0.3, 3: 0.2}

    # Class probabilities for the full text
    original_probs = calculate_logits(model, tokenizer, original_text)

    explainability_scores = defaultdict(float)
    token_count = defaultdict(int)
    # Identical n-gram texts give identical model inputs; reuse their score instead
    # of running the model again.
    diff_cache = {}

    for n, weight in ngram_weights.items():
        for ngram in ngrams(original_text, n):
            ngram_text = " ".join(ngram)
            if ngram_text not in diff_cache:
                ngram_probs = calculate_logits(model, tokenizer, ngram_text)
                diff_cache[ngram_text] = torch.abs(original_probs - ngram_probs).mean().item()
            prob_diff = diff_cache[ngram_text]

            for token in ngram:
                explainability_scores[token] += weight * prob_diff
                token_count[token] += 1

    # Average by the number of n-grams each word appeared in
    for token in explainability_scores:
        explainability_scores[token] /= token_count[token]

    return explainability_scores

def normalize_scores(explainability_scores):
    """Rescale the scores so they sum to 1.

    Returns a new plain dict of token -> score / total when the total is positive;
    otherwise (empty input, all-zero or NaN total) the input mapping itself is
    returned unchanged.
    """
    total = sum(explainability_scores.values())
    if total > 0:
        return {word: value / total for word, value in explainability_scores.items()}
    # Nothing meaningful to divide by: hand back the original mapping.
    return explainability_scores

In [None]:
data = pd.read_excel('path')
texts = data['text'].astype(str).tolist()

# One list of normalised word scores per row. Scores appear in the order the words
# were first seen in that row's text; the words themselves are not stored.
final_scores = []
for row_idx, text in enumerate(texts):
    word_scores = normalize_scores(calculate_explainability(model, tokenizer, text))
    final_scores.append(list(word_scores.values()))
    print(row_idx)  # progress indicator

In [None]:
# Store each row's normalised word-score list (built by the loop cell above) as a
# new column; every cell holds a Python list of floats.
# NOTE(review): the lists carry only scores, not the words they belong to, and
# final_scores must have exactly one entry per row of `data` - re-run the loop
# cell to completion before this one.
data['ngram_scores'] = final_scores

In [None]:
# pandas picks the Excel writer engine from the file extension, so an
# extensionless name such as 'save file' raises ValueError; keep an .xlsx suffix.
OUTPUT_PATH = 'save file.xlsx'  # TODO: replace placeholder with the real output location
data.to_excel(OUTPUT_PATH)