# XAI - LIME AND SHAP

### Load the model

In [1]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModel
from transformers import BlipProcessor, BlipForImageTextRetrieval
from torchvision import transforms

# Run on the GPU when torch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# The CSV, the images and the checkpoint are read from Google Drive, so mount it first.
from google.colab import drive
drive.mount('/content/gdrive')

# Locations on the mounted drive (edit these for your own layout).
model_path = '/content/gdrive/MyDrive/MMD1/best_model_blip_bb.pth'
image_dir = '/content/gdrive/MyDrive/MMD1/Images'
csv_path = '/content/gdrive/MyDrive/MMD1/images_description.csv'

# Class name -> integer id. The position of each name in the tuple is its id.
label_map = {
    class_name: class_id
    for class_id, class_name in enumerate((
        'non-aggressive',
        'gendered aggression',
        'religious aggression',
        'political aggression',
        'others',
    ))
}

# Read the annotations and attach the numeric label next to the textual one.
# Labels that are not keys of label_map become NaN in 'label_encoded'.
df = pd.read_csv(csv_path)
df['label_encoded'] = df['Label'].map(label_map)

# Image pipeline: fixed 384x384 resize, tensor conversion, then per-channel
# normalisation with the mean/std constants below.
image_preprocess = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Text side: BanglaBERT tokenizer and encoder.
bangla_bert_model_name = "csebuetnlp/banglabert"
bangla_bert_tokenizer = AutoTokenizer.from_pretrained(bangla_bert_model_name)
bangla_bert_model = AutoModel.from_pretrained(bangla_bert_model_name).to(device)

# Image side: BLIP processor and model, both loaded from the captioning checkpoint.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

class MemeDataset(Dataset):
    """Dataset yielding one preprocessed meme image with its tokenized caption and label.

    Args:
        df: Frame providing the ``image_name``, ``Captions`` and
            ``label_encoded`` columns read by ``__getitem__``.
        image_dir: Directory that the ``image_name`` values are relative to.
        tokenizer: Callable invoked as
            ``tokenizer(text, max_length=..., padding='max_length',
            truncation=True, return_tensors='pt')``; its result must expose
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_length: Sequence length every caption is padded/truncated to.
    """

    def __init__(self, df, image_dir, tokenizer, max_length=128):
        self.df = df
        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the backing frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with ``image``, ``input_ids``, ``attention_mask`` and ``label``.

        Image loading is best-effort: when the file cannot be read or
        preprocessed, an all-zero ``3x384x384`` tensor is substituted and the
        failure is logged so unreadable files do not go unnoticed.
        """
        row = self.df.iloc[idx]
        image_path = os.path.join(self.image_dir, row['image_name'])

        # Load and preprocess image. ``Exception`` (rather than a bare
        # ``except``) keeps the blank-image fallback for any loading problem
        # while letting KeyboardInterrupt / SystemExit propagate.
        try:
            with Image.open(image_path) as raw_image:
                image = image_preprocess(raw_image.convert('RGB'))
        except Exception as exc:
            import logging

            logging.getLogger(__name__).warning(
                "Could not load image %s (%s: %s); using a blank image instead.",
                image_path, type(exc).__name__, exc,
            )
            image = torch.zeros(3, 384, 384)

        # Tokenize text to a fixed length so samples can be stacked into a batch.
        text = str(row['Captions'])
        text_encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        label = row['label_encoded']

        return {
            'image': image,
            # The tokenizer returns a leading batch axis of size 1; flatten drops it.
            'input_ids': text_encoding['input_ids'].flatten(),
            'attention_mask': text_encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# Stratified split: hold out 20% of all rows for testing, then take 10% of the
# remaining rows as the validation set. Both splits use the same seed.
train_val_df, test_df = train_test_split(
    df, test_size=0.2, stratify=df['Label'], random_state=42
)
train_df, val_df = train_test_split(
    train_val_df, test_size=0.1, stratify=train_val_df['Label'], random_state=42
)

print(f"Train size: {len(train_df)}")
print(f"Validation size: {len(val_df)}")
print(f"Test size: {len(test_df)}")

# One dataset per split, all sharing the image directory and the BanglaBERT tokenizer.
train_dataset, val_dataset, test_dataset = (
    MemeDataset(split_df, image_dir, bangla_bert_tokenizer)
    for split_df in (train_df, val_df, test_df)
)

batch_size = 16  # Adjust based on your GPU memory

# Only the training loader shuffles; validation and test keep dataset order.
loader_options = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

# Model
class MultimodalClassifier(nn.Module):
    """Classifier that concatenates a text feature and an image feature.

    The text encoder is the module-level ``bangla_bert_model`` and the image
    encoder is ``blip_model.vision_model``. From each encoder the hidden state
    at sequence position 0 is taken, the two vectors are concatenated, and a
    two-layer head (with dropout before each linear layer) produces the logits.

    Args:
        text_embed_dim: Width of the text feature vector.
        image_embed_dim: Width of the image feature vector.
        num_classes: Number of output logits.
        dropout: Dropout probability used in the head.
    """

    def __init__(self, text_embed_dim=768, image_embed_dim=768, num_classes=5, dropout=0.2):
        super().__init__()
        fused_dim = text_embed_dim + image_embed_dim
        # Attribute names are part of the checkpoint's state-dict keys; keep them.
        self.text_model = bangla_bert_model
        self.image_model = blip_model.vision_model
        self.dropout = nn.Dropout(p=dropout)
        self.fc1 = nn.Linear(fused_dim, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, input_ids, attention_mask, image):
        """Return unnormalised class logits for a batch of (caption, image) pairs."""
        text_hidden = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        image_hidden = self.image_model(pixel_values=image).last_hidden_state

        # Position 0 of each sequence serves as that modality's summary vector.
        fused = torch.cat((text_hidden[:, 0, :], image_hidden[:, 0, :]), dim=1)

        hidden = F.relu(self.fc1(self.dropout(fused)))
        return self.fc2(self.dropout(hidden))

# Build the classifier on the selected device, then restore the saved weights
# into it (tensors are mapped onto the same device while loading).
# NOTE(review): torch.load unpickles the checkpoint file, so only load
# checkpoints that come from a trusted source.
model = MultimodalClassifier().to(device)
model.load_state_dict(torch.load(model_path, map_location=device))

!pip install lime
!pip install shap
import shap
import lime

Using device: cuda
Mounted at /content/gdrive


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/119 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/586 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/443M [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


preprocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/506 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/443M [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

Some weights of BlipForImageTextRetrieval were not initialized from the model checkpoint at Salesforce/blip-image-captioning-base and are newly initialized: ['itm_head.bias', 'itm_head.weight', 'text_encoder.embeddings.LayerNorm.bias', 'text_encoder.embeddings.LayerNorm.weight', 'text_encoder.embeddings.position_embeddings.weight', 'text_encoder.embeddings.word_embeddings.weight', 'text_encoder.encoder.layer.0.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.0.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.0.attention.output.dense.bias', 'text_encoder.encoder.layer.0.attention.output.dense.weight', 'text_encoder.encoder.layer.0.attention.self.key.bias', 'text_encoder.encoder.layer.0.attention.self.key.weight', 'text_encoder.encoder.layer.0.attention.self.query.bias', 'text_encoder.encoder.layer.0.attention.self.query.weight', 'text_encoder.encoder.layer.0.attention.self.value.bias', 'text_encoder.encoder.layer.0.attention.self.value.weight', 'text_encoder.

Train size: 13961
Validation size: 1552
Test size: 3879
Collecting lime
  Downloading lime-0.2.0.1.tar.gz (275 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m275.7/275.7 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: lime
  Building wheel for lime (setup.py) ... [?25l[?25hdone
  Created wheel for lime: filename=lime-0.2.0.1-py3-none-any.whl size=283834 sha256=0b4fe29b4fb5b77de1a9e1f6f8fa10ea60e366c7cad3df0957cfaed93871097d
  Stored in directory: /root/.cache/pip/wheels/85/fa/a3/9c2d44c9f3cd77cf4e533b58900b2bf4487f2a17e8ec212a3d
Successfully built lime
Installing collected packages: lime
Successfully installed lime-0.2.0.1


### LIME AND SHAP (TEXT + IMAGE) -> LOCAL

In [None]:
import lime
import lime.lime_tabular
import lime.lime_image
import lime.lime_text
from lime import submodular_pick
from skimage.segmentation import mark_boundaries
import shap
import numpy as np
import torch
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from IPython.display import HTML, display  # For displaying HTML in notebooks

# Relies on `model` and `model_path` created by the loading cell; the constants
# used by the explanation helpers below are defined right here.

# Token length every caption (and dummy caption) is padded/truncated to.
MAX_LENGTH = 128

# Class names indexed by the integer label id (should match your label encoding).
label_names = [
    'non-aggressive',
    'gendered aggression',
    'religious aggression',
    'political aggression',
    'others',
]

# Restore the best checkpoint and switch to inference mode so dropout is disabled.
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

def lime_image_explanation(model, image, device, top_labels=5, num_samples=1000):
    """
    Explain image classification using LIME.

    LIME perturbs an ordinary ``uint8`` RGB image, so the normalised input
    tensor is first converted back to displayable pixels. The classifier
    callback then rescales and re-normalises every perturbed image before it
    reaches the model, so the model sees inputs on the same scale as the
    original tensor. The text branch receives a dummy caption (all-zero token
    ids with a full attention mask), so only the image varies.

    Args:
        model: Classifier called as ``model(input_ids=, attention_mask=, image=)``.
        image: ``(3, H, W)`` tensor. Assumed to be normalised with the
            mean/std constants below (the ones used by this notebook's
            preprocessing pipeline) -- confirm if the preprocessing changes.
        device: Device the model lives on.
        top_labels: Number of highest-probability classes LIME explains.
        num_samples: Number of perturbed images LIME generates.

    Returns:
        The LIME explanation object returned by ``explain_instance``; its
        image is the de-normalised ``uint8`` picture.
    """
    # Per-channel statistics, shaped to broadcast over (3, H, W) / (N, 3, H, W).
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def batch_predict(images_np):
        """Map a batch of HxWx3 uint8 images to class probabilities."""
        model.eval()
        batch = torch.as_tensor(np.asarray(images_np)).permute(0, 3, 1, 2).float()
        # Undo the uint8 scaling, then apply the same normalisation as the pipeline.
        batch = ((batch / 255.0 - mean) / std).to(device)
        dummy_input_ids = torch.zeros((len(images_np), MAX_LENGTH), dtype=torch.long).to(device)
        dummy_attention_mask = torch.ones((len(images_np), MAX_LENGTH), dtype=torch.long).to(device)
        with torch.no_grad():
            outputs = model(input_ids=dummy_input_ids, attention_mask=dummy_attention_mask, image=batch)
        return torch.softmax(outputs, dim=1).detach().cpu().numpy()

    # De-normalise to [0, 1] and quantise to uint8 HxWx3. Feeding the
    # normalised tensor (which has negative and >1 values) straight into an
    # 8-bit image conversion would corrupt the pixels.
    pixels = (image.detach().cpu().float() * std + mean).clamp(0.0, 1.0)
    image_uint8 = (pixels * 255.0).round().to(torch.uint8).permute(1, 2, 0).numpy()

    explainer = lime.lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(image_uint8,
                                             batch_predict,
                                             top_labels=top_labels,
                                             hide_color=0,
                                             num_samples=num_samples)
    return explanation

def lime_text_explanation(model, tokenizer, text, device, num_samples=1000):
    """
    Explain text classification using LIME with proper visualization.

    The image branch receives an all-zero dummy image so that only the text
    varies. LIME hands every perturbed text to the classifier callback in a
    single call, so the callback runs the model in small chunks instead of
    one ``num_samples``-sized batch.

    Args:
        model: Classifier called as ``model(input_ids=, attention_mask=, image=)``.
        tokenizer: Tokenizer used to encode the perturbed texts.
        text: The caption to explain.
        device: Device the model lives on.
        num_samples: Number of perturbed texts LIME generates.

    Returns:
        The LIME explanation object, with explanations for every class in
        ``label_names``.
    """
    chunk_size = 32   # upper bound on texts per forward pass, to bound memory use
    image_size = 384  # matches the 384x384 resolution used by the image pipeline

    def predict_proba(texts):
        """Map a str or sequence of str to an (n, num_classes) probability array."""
        model.eval()
        if isinstance(texts, str):
            texts = [texts]
        texts = list(texts)
        if not texts:
            return np.empty((0, len(label_names)), dtype=np.float32)

        probabilities = []
        for start in range(0, len(texts), chunk_size):
            chunk = texts[start:start + chunk_size]
            inputs = tokenizer(chunk, max_length=MAX_LENGTH, padding='max_length', truncation=True, return_tensors='pt').to(device)
            dummy_image = torch.zeros((len(chunk), 3, image_size, image_size), device=device)
            with torch.no_grad():
                outputs = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], image=dummy_image)
            probabilities.append(torch.softmax(outputs, dim=1).detach().cpu().numpy())
        return np.concatenate(probabilities, axis=0)

    # Split on whitespace only, instead of LIME's default tokenisation.
    explainer = lime.lime_text.LimeTextExplainer(
        class_names=label_names,
        split_expression=lambda x: x.split()
    )
    explanation = explainer.explain_instance(text, predict_proba, num_features=10, num_samples=num_samples, labels=list(range(len(label_names))))
    return explanation

def shap_image_explanation(model, image, device, n_evals=100):
    """
    Compute SHAP attributions for one image with the text input held fixed.

    The image tensor is handed to SHAP in channels-last layout and converted
    back to channels-first inside the prediction callback. The text branch
    always receives all-zero token ids with a full attention mask. Masked
    regions are filled by SHAP's image masker using the ``"blur(224,224)"``
    setting.

    Args:
        model: Classifier called as ``model(input_ids=, attention_mask=, image=)``.
        image: ``(C, H, W)`` image tensor, passed to the model unchanged in value.
        device: Device the model lives on.
        n_evals: Maximum number of model evaluations SHAP may spend.

    Returns:
        The SHAP explanation for the five highest-scoring outputs.
    """
    def predict(images_np):
        model.eval()
        # SHAP supplies (N, H, W, C); the model expects (N, C, H, W).
        pixels = torch.tensor(images_np).float().to(device).permute(0, 3, 1, 2)
        n_images = pixels.shape[0]
        placeholder_ids = torch.zeros((n_images, MAX_LENGTH), dtype=torch.long).to(device)
        placeholder_mask = torch.ones((n_images, MAX_LENGTH), dtype=torch.long).to(device)
        with torch.no_grad():
            logits = model(input_ids=placeholder_ids, attention_mask=placeholder_mask, image=pixels)
        return torch.softmax(logits, dim=1).detach().cpu().numpy()

    # Channels-last copy with a leading batch axis of size one.
    batch_hwc = image.permute(1, 2, 0).cpu().numpy()[np.newaxis, ...]

    image_masker = shap.maskers.Image("blur(224,224)", batch_hwc[0].shape)
    shap_explainer = shap.Explainer(predict, image_masker, output_names=label_names)
    top_five = shap.Explanation.argsort.flip[:5]
    return shap_explainer(batch_hwc, max_evals=n_evals, batch_size=1, outputs=top_five)

def shap_text_explanation(model, tokenizer, text, device):
    """
    Explain text classification using SHAP.

    The image branch receives an all-zero dummy image so that only the text
    varies. The dummy image uses the same 384x384 resolution as the rest of
    the image pipeline, so the vision encoder sees the sequence length it is
    fed elsewhere in this notebook.

    Args:
        model: Classifier called as ``model(input_ids=, attention_mask=, image=)``.
        tokenizer: Tokenizer used both by the SHAP text masker and to encode
            the masked texts.
        text: The caption to explain.
        device: Device the model lives on.

    Returns:
        The SHAP explanation for the single-element batch ``[text]``.
    """
    image_size = 384  # matches the 384x384 resolution used by the image pipeline

    def predict_text_only(texts_list):
        """Map a str / array / iterable of texts to class probabilities."""
        model.eval()
        # SHAP may pass a numpy array of strings; the tokenizer call below needs a list.
        if isinstance(texts_list, str):
            texts_list = [texts_list]
        elif isinstance(texts_list, np.ndarray):
            texts_list = texts_list.tolist()
        elif not isinstance(texts_list, list):
            texts_list = [str(t) for t in texts_list]
        inputs = tokenizer(texts_list, max_length=MAX_LENGTH, padding='max_length', truncation=True, return_tensors='pt').to(device)
        dummy_image = torch.zeros((len(texts_list), 3, image_size, image_size), device=device)
        with torch.no_grad():
            outputs = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], image=dummy_image)
        return torch.softmax(outputs, dim=1).detach().cpu().numpy()

    # Hidden tokens are replaced by the tokenizer's own mask token.
    masker = shap.maskers.Text(tokenizer, mask_token=tokenizer.mask_token)
    explainer = shap.Explainer(predict_text_only, masker, output_names=label_names)
    shap_values = explainer([text])
    return shap_values

def _denormalize_for_display(image):
    """Undo the pipeline's mean/std normalisation; return an HxWx3 float array in [0, 1]."""
    # Assumes the constants of this notebook's preprocessing pipeline -- confirm if it changes.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    pixels = (image.detach().cpu().float() * std + mean).clamp(0.0, 1.0)
    return pixels.permute(1, 2, 0).numpy()


def run_explanations(model, tokenizer, test_loader, device, num_samples=3):
    """
    Run LIME and SHAP explanations on sample test cases with correct true/predicted labels.

    Takes the first batch of ``test_loader`` and, for up to ``num_samples``
    items, prints the true and predicted label and produces four
    explanations, each shown inline and saved to the working directory:

    * ``lime_image_example_<k>.png``  -- LIME super-pixel overlay
    * ``lime_text_example_<k>.html``  -- LIME word weights
    * ``shap_image_example_<k>.png``  -- SHAP image attribution
    * ``shap_text_example_<k>.html``  -- SHAP token attribution (SHAP renders
      text explanations as HTML, not as a matplotlib figure)

    Args:
        model: Classifier called as ``model(input_ids=, attention_mask=, image=)``.
        tokenizer: Tokenizer used to decode the captions and by the text explainers.
        test_loader: Loader yielding dicts with ``image``, ``input_ids``,
            ``attention_mask`` and ``label`` entries.
        device: Device the model lives on.
        num_samples: Maximum number of samples to explain.

    Returns:
        None. Prints a message and returns early when the loader is empty.
    """
    try:
        test_batch = next(iter(test_loader))
    except StopIteration:
        print("Test dataloader is empty.")
        return
    num_samples = min(num_samples, test_batch['image'].shape[0])
    images_batch = test_batch['image'][:num_samples]
    input_ids_batch = test_batch['input_ids'][:num_samples]
    attention_mask_batch = test_batch['attention_mask'][:num_samples]
    labels_batch = test_batch['label'][:num_samples]
    print("\n" + "="*50)
    print("Running LIME and SHAP Explanations")
    print("="*50 + "\n")
    # Make sure dropout is disabled for the reference prediction below.
    model.eval()
    for i in range(num_samples):
        image = images_batch[i]
        input_ids = input_ids_batch[i].unsqueeze(0).to(device)
        attention_mask = attention_mask_batch[i].unsqueeze(0).to(device)
        label = labels_batch[i].item()
        # The caption is recovered from the token ids, so it reflects the
        # tokenized (possibly truncated) text rather than the raw CSV string.
        text = tokenizer.decode(input_ids_batch[i], skip_special_tokens=True)
        # Get model prediction (make sure image, text match)
        with torch.no_grad():
            output = model(input_ids=input_ids, attention_mask=attention_mask, image=image.unsqueeze(0).to(device))
            pred_probs = torch.softmax(output, dim=1)
            pred_label = torch.argmax(pred_probs, dim=1).item()
        print(f"\nSample {i+1}/{num_samples}")
        print(f"True Label: {label_names[label]}")
        print(f"Text: {text[:200]}...\n")
        print(f"Predicted Label: {label_names[pred_label]} (Confidence: {pred_probs[0][pred_label]:.2f})")

        # LIME Image Explanation
        print("\nLIME Image Explanation:")
        lime_img_exp = lime_image_explanation(model, image, device)
        if lime_img_exp.top_labels:
            # Draw the mask for the class named in the title. Prefer the
            # model's prediction; fall back to LIME's own top label when no
            # surrogate was fitted for the predicted class.
            if pred_label in lime_img_exp.local_exp:
                explained_label = pred_label
            else:
                explained_label = int(lime_img_exp.top_labels[0])
            temp, mask = lime_img_exp.get_image_and_mask(explained_label, positive_only=True, num_features=5, hide_rest=False)
            plt.figure(figsize=(8, 6))
            plt.imshow(mark_boundaries(temp / 255.0, mask))
            plt.title(f"LIME Explanation for {label_names[explained_label]}")
            plt.axis('off')
            plt.tight_layout()
            plt.savefig(f'lime_image_example_{i+1}.png', dpi=300, bbox_inches='tight')
            plt.show()
            plt.close()
        else:
            print("LIME image explanation returned no top labels.")

        # LIME Text Explanation
        print("\nLIME Text Explanation:")
        lime_text_exp = lime_text_explanation(model, tokenizer, text, device)
        lime_html = lime_text_exp.as_html()
        try:
            display(HTML(lime_html))
        except Exception:
            # Rich display unavailable: fall back to the plain word/weight list.
            print(lime_text_exp.as_list(label=pred_label))
        with open(f'lime_text_example_{i+1}.html', 'w', encoding='utf-8') as f:
            f.write(lime_html)
        print(f"Saved LIME text explanation to lime_text_example_{i+1}.html")

        # SHAP Image Explanation
        print("\nSHAP Image Explanation (this may take a while)...")
        shap_img_exp = shap_image_explanation(model, image, device, n_evals=100)
        # Pass the Explanation object itself so SHAP splits the per-class
        # attributions and picks up the class labels on its own; supply the
        # de-normalised picture as background. show=False keeps the figure
        # open so it can be titled and saved before it is displayed.
        shap.image_plot(
            shap_img_exp,
            pixel_values=_denormalize_for_display(image)[np.newaxis, ...],
            show=False,
        )
        plt.suptitle(f"SHAP Explanation for {label_names[pred_label]}")
        plt.savefig(f'shap_image_example_{i+1}.png', dpi=300, bbox_inches='tight')
        plt.show()
        plt.close()

        # SHAP Text Explanation
        print("\nSHAP Text Explanation:")
        shap_text_exp = shap_text_explanation(model, tokenizer, text, device)
        # shap.plots.text produces HTML; with display=False it returns the
        # markup so it can be both shown and written to disk.
        shap_html = shap.plots.text(shap_text_exp[0], display=False)
        display(HTML(shap_html))
        with open(f'shap_text_example_{i+1}.html', 'w', encoding='utf-8') as f:
            f.write(shap_html)
        print(f"Saved SHAP text explanation to shap_text_example_{i+1}.html")
        print("\n" + "-"*50 + "\n")

# Produce LIME and SHAP explanations for the first three samples of the test loader.
run_explanations(
    model=model,
    tokenizer=bangla_bert_tokenizer,
    test_loader=test_loader,
    device=device,
    num_samples=3,
)

Output hidden; open in https://colab.research.google.com to view.