In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/nepalihatee/train.json
/kaggle/input/nepalihatee/test.json
/kaggle/input/nepalihatee/val_final.json
/kaggle/input/nepalihatee/split_metadata.json
/kaggle/input/nepalihatee/train_final.json
/kaggle/input/fontss/Kalimati.ttf


In [2]:
!pip install transformers joblib

!pip install emoji regex -q
!pip install huggingface_hub





In [None]:
# =============================
# Kaggle-ready Captum Explainer
# =============================
!pip install --quiet --upgrade protobuf==3.20.3
!pip install --quiet captum indic-transliteration

import os
import numpy as np
import pandas as pd
import torch
import joblib
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from matplotlib import pyplot as plt, cm
from matplotlib.font_manager import FontProperties
import matplotlib.colors as mcolors
import regex
import re
import emoji
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm

# Captum
from captum.attr import LayerIntegratedGradients

# Optional transliteration
try:
    from indic_transliteration import sanscript
    from indic_transliteration.sanscript import transliterate
    TRANSLITERATION_AVAILABLE = True
except Exception:
    TRANSLITERATION_AVAILABLE = False

# ----------------------------
# CONFIGURATION
# ----------------------------
LOCAL_MODEL_PATH = '/kaggle/input/xlm-roberta-nepali-hate'
HF_MODEL_ID = "UDHOV/xlm-roberta-large-nepali-hate-classification"
TEST_FILE = '/kaggle/input/nepalihatee/test.json'
OUT_DIR = '/kaggle/working/captum_explanations'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MAX_LENGTH = 256
os.makedirs(OUT_DIR, exist_ok=True)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Nepali font
FONT_PATH = '/kaggle/input/fontss/Kalimati.ttf'
def load_nepali_font(font_path):
    if not os.path.exists(font_path):
        return None
    try:
        fp = FontProperties(fname=font_path)
        print(f"✓ Loaded Nepali font: {fp.get_name()}")
        return fp
    except:
        return None
NEPALI_FONT = load_nepali_font(FONT_PATH)

# ----------------------------
# Preprocessing
# ----------------------------
DIRGHIKARAN_MAP = {
    "उ": "ऊ", "इ": "ई", "ऋ": "रि", "ए": "ऐ", "अ": "आ",
    "\u200d": "", "\u200c": "", "।": ".", "॥": ".",
    "ि": "ी", "ु": "ू"
}
def is_devanagari(text): return bool(regex.search(r'\p{Devanagari}', text)) if isinstance(text, str) else False
def roman_to_devanagari(text):
    if not TRANSLITERATION_AVAILABLE: return text
    try: return transliterate(text, sanscript.ITRANS, sanscript.DEVANAGARI)
    except: return text
def normalize_dirghikaran(text):
    for k,v in DIRGHIKARAN_MAP.items(): text=text.replace(k,v)
    return text
def clean_text(text):
    if not isinstance(text,str): return ""
    text = text.lower()
    text = re.sub(r"http\S+|www\S+", "", text)
    text = re.sub(r"@\w+|#\w+", "", text)
    text = emoji.replace_emoji(text,"")
    text = re.sub(r"\s+"," ",text).strip()
    return text
def preprocess_for_transformer(text):
    if not isinstance(text,str): return ""
    if not is_devanagari(text): text=roman_to_devanagari(text)
    text=clean_text(text)
    return normalize_dirghikaran(text)

# ----------------------------
# Load model
# ----------------------------
def load_model():
    try:
        if os.path.exists(LOCAL_MODEL_PATH):
            tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_PATH)
            model = AutoModelForSequenceClassification.from_pretrained(LOCAL_MODEL_PATH).to(DEVICE).eval()
            le_path=os.path.join(LOCAL_MODEL_PATH,"label_encoder.pkl")
            le = joblib.load(le_path) if os.path.exists(le_path) else None
            print("✅ Model loaded from LOCAL path")
            return model, tokenizer, le
        else:
            tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID)
            model = AutoModelForSequenceClassification.from_pretrained(HF_MODEL_ID).to(DEVICE).eval()
            try:
                le_file = hf_hub_download(repo_id=HF_MODEL_ID, filename="label_encoder.pkl", cache_dir="models/cache")
                le = joblib.load(le_file)
                print("✅ Label encoder loaded from HF repo")
            except:
                from sklearn.preprocessing import LabelEncoder
                le = LabelEncoder()
                le.fit(['NO','OO','OR','OS'])
                print("⚠️ Label encoder not found in HF repo, using new encoder")
            print("✅ Model loaded from HuggingFace Hub")
            return model, tokenizer, le
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        return None,None,None

# ----------------------------
# Captum utilities
# ----------------------------
def explain_text_with_captum(text, model, tokenizer, label_encoder, target=None, n_steps=50, out_dir=OUT_DIR, nepali_font=NEPALI_FONT):
    preprocessed = preprocess_for_transformer(text)
    if not preprocessed: return None
    encoding = tokenizer(preprocessed, return_tensors="pt", truncation=True, padding="max_length", max_length=MAX_LENGTH)
    input_ids = encoding['input_ids'].to(DEVICE)
    attention_mask = encoding['attention_mask'].to(DEVICE)
    
    # Prediction
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)
        probs = torch.softmax(out.logits, dim=-1)[0].cpu().numpy()
        pred_idx = int(np.argmax(probs))
        pred_label = label_encoder.classes_[pred_idx]
        pred_conf = float(probs[pred_idx])
    if target is None: target=pred_idx
    
    embedding_layer = model.roberta.embeddings.word_embeddings
    lig = LayerIntegratedGradients(lambda ids, mask: model(ids, attention_mask=mask).logits[:, target], embedding_layer)
    attributions, delta = lig.attribute(input_ids, baselines=torch.full_like(input_ids, tokenizer.pad_token_id),
                                        additional_forward_args=(attention_mask,), return_convergence_delta=True, n_steps=n_steps)
    attributions_sum = attributions.sum(dim=-1).squeeze(0)
    toks = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().tolist(), skip_special_tokens=False)
    
    # Aggregate word-level attributions
    words, groups = [], []
    current_tokens, current_word_pieces = [], []
    for i, tok in enumerate(toks):
        if tok.startswith("▁") or tok in [tokenizer.cls_token, tokenizer.sep_token]:
            if current_tokens:
                words.append("".join(current_word_pieces))
                groups.append(current_tokens)
            current_tokens=[i]
            current_word_pieces=[tok.replace("▁","")]
        else:
            current_tokens.append(i)
            current_word_pieces.append(tok.replace("▁",""))
    if current_tokens:
        words.append("".join(current_word_pieces))
        groups.append(current_tokens)
    
    word_attributions=[]
    for grp in groups:
        grp_vals = attributions_sum[grp].detach().cpu().numpy()
        score = float(np.sum(np.abs(grp_vals)))
        signed_score = float(np.sum(grp_vals))
        word="".join([toks[i].replace("▁","") for i in grp])
        word_attributions.append((word, score, signed_score))
    
    # Bar chart
    scores_plot=[s for _,s,_ in word_attributions]
    words_plot=[w for w,_,_ in word_attributions]
    maxscore=max(scores_plot) if scores_plot else 1.0
    fig, ax=plt.subplots(figsize=(max(6,0.6*len(words_plot)),4))
    ax.bar(range(len(words_plot)), scores_plot, tick_label=words_plot)
    ax.set_ylabel("Attribution (sum abs)")
    ax.set_title(f"Word attributions -> Pred: {pred_label} ({pred_conf:.2%})")
    if nepali_font:
        for lbl in ax.get_xticklabels():
            if regex.search(r"\p{Devanagari}", lbl.get_text()):
                lbl.set_fontproperties(nepali_font)
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    barpath=os.path.join(out_dir,f"ig_bar_pred_{pred_label}_{abs(hash(preprocessed))%10**8}.png")
    fig.savefig(barpath,dpi=300)
    plt.close(fig)
    
    # Text heatmap
    cmap = cm.get_cmap("Reds")
    color_boxes = [mcolors.to_hex(cmap(s/maxscore if maxscore!=0 else 0.0)) for s in scores_plot]
    fig2, ax2=plt.subplots(figsize=(max(6,0.6*len(words_plot)),2))
    ax2.axis('off')
    x, y=0.01, 0.6
    for i, (w, s, sig) in enumerate(word_attributions):
        color=color_boxes[i]
        ax2.text(x,y,f" {w} ", fontsize=12,
                 bbox=dict(facecolor=color, alpha=0.7, boxstyle="round,pad=0.2"),
                 fontproperties=nepali_font if nepali_font and regex.search(r"\p{Devanagari}", w) else None)
        x+=(len(w)*0.03)+0.03
        if x>0.95: x=0.01; y-=0.3
    plt.tight_layout()
    heatpath=os.path.join(out_dir,f"ig_text_pred_{pred_label}_{abs(hash(preprocessed))%10**8}.png")
    fig2.savefig(heatpath,dpi=300,bbox_inches='tight')
    plt.close(fig2)
    
    return {
        "preprocessed": preprocessed,
        "pred_label": pred_label,
        "pred_confidence": pred_conf,
        "word_attributions": word_attributions,
        "bar_chart": barpath,
        "text_heatmap": heatpath,
        "convergence_delta": float(delta.sum().cpu().numpy())
    }

# ----------------------------
# Main execution (LIMITED SAMPLES PER CLASS)
# ----------------------------
if __name__ == "__main__":
    model, tokenizer, le = load_model()
    if model is None:
        raise RuntimeError("Model not loaded")

    df = pd.read_json(TEST_FILE)
    pd.set_option("display.max_colwidth", None)

    # ----------------------------
    # CONFIG: samples per class
    # ----------------------------
    SAMPLES_PER_CLASS = 3   # change to 2 or 3 as needed
    LABEL_COL = "Label_Multiclass"
    TEXT_COL = "Comment"

    # ----------------------------
    # Sample per class (reproducible)
    # ----------------------------
    sampled_df = (
        df
        .dropna(subset=[TEXT_COL, LABEL_COL])
        .groupby(LABEL_COL, group_keys=False)
        .apply(lambda x: x.sample(
            n=min(len(x), SAMPLES_PER_CLASS),
            random_state=SEED
        ))
        .reset_index(drop=True)
    )

    print("✅ Samples per class:")
    print(sampled_df[LABEL_COL].value_counts())

    # ----------------------------
    # Run Captum
    # ----------------------------
    for idx, row in tqdm(sampled_df.iterrows(), total=len(sampled_df)):
        text = str(row[TEXT_COL]).strip()
        true_label = row[LABEL_COL]

        if not text:
            continue

        try:
            res = explain_text_with_captum(
                text=text,
                model=model,
                tokenizer=tokenizer,
                label_encoder=le,
                out_dir=OUT_DIR,
                nepali_font=NEPALI_FONT
            )

            if res:
                print(
                    f"Saved Captum explanation | "
                    f"True: {true_label} | "
                    f"Pred: {res['pred_label']} "
                    f"({res['pred_confidence']:.3f})"
                )

        except Exception as e:
            print(f"⚠️ Failed for sample {idx}: {e}")



[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.1/162.1 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.12.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
opentelemetry-proto 1.37.0 requires protobuf<7.0,>=5.0, but you have protobuf 3.20.3 which is incompatible.
onnx 1.18.0 requires protobuf>=4.25.1, but you have protobuf 3.20.3 which is incompatible.
a2a-sdk 0.3.10 requires protobuf>=5.29.5, but you have protobuf 3.20.3 which is incompatible.
ray 2.51.1 requires click!=8.3.0,>=7.0, but you have click 8.3.0 which is incompatible.
bigframes 2.12.0 requires rich<14,>=12.4.4, but you have rich 14.2.0 which is incompatible.
tensorflow-metadata 1.17.2 requires protobuf>=4.25.2; python_version >= "3.11", but you have protobuf 3.20.3 which is

tokenizer_config.json: 0.00B [00:00, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/280 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/884 [00:00<?, ?B/s]

2025-12-17 09:15:33.560612: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765962933.870718      13 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765962933.951179      13 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


model.safetensors:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

⚠️ Label encoder not found in HF repo, using new encoder
✅ Model loaded from HuggingFace Hub
✅ Samples per class:
Label_Multiclass
NO    3
OO    3
OR    3
OS    3
Name: count, dtype: int64


  8%|▊         | 1/12 [03:13<35:27, 193.41s/it]

Saved Captum explanation | True: NO | Pred: NO (0.932)


 17%|█▋        | 2/12 [06:03<29:56, 179.61s/it]

Saved Captum explanation | True: NO | Pred: NO (0.919)


 25%|██▌       | 3/12 [08:53<26:15, 175.06s/it]

Saved Captum explanation | True: NO | Pred: OO (0.748)


 33%|███▎      | 4/12 [12:05<24:15, 181.95s/it]

Saved Captum explanation | True: OO | Pred: OS (0.484)


 42%|████▏     | 5/12 [15:09<21:19, 182.82s/it]

Saved Captum explanation | True: OO | Pred: NO (0.506)


 50%|█████     | 6/12 [17:58<17:47, 177.95s/it]

Saved Captum explanation | True: OO | Pred: OO (0.837)


 58%|█████▊    | 7/12 [20:59<14:55, 179.02s/it]

Saved Captum explanation | True: OR | Pred: OR (0.557)


 67%|██████▋   | 8/12 [23:51<11:47, 176.84s/it]

Saved Captum explanation | True: OR | Pred: OR (0.989)


 75%|███████▌  | 9/12 [26:44<08:46, 175.54s/it]

Saved Captum explanation | True: OR | Pred: OO (0.725)


 83%|████████▎ | 10/12 [29:36<05:48, 174.45s/it]

Saved Captum explanation | True: OS | Pred: NO (0.828)


 92%|█████████▏| 11/12 [32:51<03:00, 180.66s/it]

Saved Captum explanation | True: OS | Pred: OO (0.547)


100%|██████████| 12/12 [35:41<00:00, 178.49s/it]

Saved Captum explanation | True: OS | Pred: OS (0.994)



