In [1]:
import json

# Notebook to clean in place -- set this to the name of your own file.
path = "Bert_pipeline.ipynb"
with open(path, "r", encoding="utf-8") as f:
    nb = json.loads(f.read())

# Remove the "widgets" entry from the notebook-level metadata (no-op if absent).
nb.get("metadata", {}).pop("widgets", None)

# Remove the same entry from each cell's own metadata.
for cell in nb.get("cells", []):
    cell.get("metadata", {}).pop("widgets", None)

# Write the cleaned notebook back over the original file.
with open(path, "w", encoding="utf-8") as f:
    f.write(json.dumps(nb, ensure_ascii=False, indent=1))

print("Done: removed metadata.widgets")


Done: removed metadata.widgets


In [2]:
import torch
import numpy as np
from transformers import BertTokenizer, BertForSequenceClassification

# ====== EDIT THIS if your path is different ======
MODEL_DIR = "outputs/bert_best"
# Class id -> label name; enumerate() numbers the names 0, 1, 2 in list order.
# NOTE(review): assumed to match the label ids used when the model was trained -- confirm.
id2label = dict(enumerate(["hate_speech", "offensive_language", "neither"]))
# ===============================================

# Use the GPU when torch reports one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Tokenizer and classifier are both loaded from MODEL_DIR.
tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)
model = BertForSequenceClassification.from_pretrained(MODEL_DIR)
model = model.to(device)
model.eval()  # switch to evaluation mode before inference

def predict(text, max_length=128):
    """Classify ``text`` with the module-level ``model`` and ``tokenizer``.

    Args:
        text: Input handed to the tokenizer.
        max_length: Truncation limit passed to the tokenizer.

    Returns:
        A tuple ``(pred, label, probs)`` where ``probs`` is a NumPy array with
        the softmax probabilities of the first row of the batch, ``pred`` is
        the index of its largest entry and ``label`` is ``id2label[pred]``.
    """
    encoded = tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        padding=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    # Move every encoded tensor onto the device the model lives on.
    model_inputs = {}
    for name, tensor in encoded.items():
        model_inputs[name] = tensor.to(device)

    with torch.no_grad():
        class_logits = model(**model_inputs).logits
        first_row = torch.softmax(class_logits, dim=1)[0]
        probs = first_row.detach().cpu().numpy()

    pred = int(probs.argmax())
    return pred, id2label[pred], probs


  from .autonotebook import tqdm as notebook_tqdm
  warn(


In [3]:
# Explainability: Integrated Gradients (word/token highlight)

import numpy as np
from captum.attr import IntegratedGradients

def _merge_wordpieces(tokens, scores):
    # Merge BERT wordpieces (##) into whole words by summing scores.
    words = []
    word_scores = []
    cur = ""
    cur_score = 0.0

    for t, s in zip(tokens, scores):
        if t in ["[CLS]", "[SEP]", "[PAD]"]:
            continue
        if t.startswith("##"):
            cur += t[2:]
            cur_score += float(s)
        else:
            if cur:
                words.append(cur)
                word_scores.append(cur_score)
            cur = t
            cur_score = float(s)
    if cur:
        words.append(cur)
        word_scores.append(cur_score)
    return words, np.array(word_scores, dtype=np.float32)

def explain_ig(text, target_label=None, max_length=128, n_steps=50):
    """Explain a prediction for one text with Integrated Gradients.

    Attributions are computed with respect to the input embeddings, using a
    baseline in which every position holds the tokenizer's pad token. Token
    attributions are summed over the embedding dimension, merged into whole
    words by ``_merge_wordpieces``, and turned into importances by taking
    absolute values and dividing by the largest one.

    Args:
        text: A single input text (the code indexes row 0 of the batch).
        target_label: Class to attribute towards. ``None`` uses the model's
            predicted class; a string is looked up among the values of
            ``id2label``; anything else is converted with ``int()``.
        max_length: Tokenizer truncation/padding length.
        n_steps: Number of Integrated Gradients steps passed to Captum.

    Returns:
        ``(pred_id, pred_label, probs, words, imp)`` where ``probs`` is a
        NumPy array of softmax probabilities, ``words`` is the list of merged
        words and ``imp`` is a NumPy array of importances in ``[0, 1]``
        (empty when the text yields no non-special tokens).

    Raises:
        KeyError: If ``target_label`` is a string that is not a value of
            ``id2label``.
    """
    model.eval()

    enc = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_attention_mask=True,
    )
    input_ids = enc["input_ids"].to(device)
    attn_mask = enc["attention_mask"].to(device)

    def forward_embeds(embeds):
        # Captum may hand over many interpolated copies of the input in one
        # batch; expand the single-row mask to that batch size explicitly
        # instead of relying on implicit broadcasting inside the model.
        mask = attn_mask.expand(embeds.shape[0], -1)
        out = model(inputs_embeds=embeds, attention_mask=mask)
        return out.logits

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attn_mask).logits
        probs = torch.softmax(logits, dim=1)[0].detach().cpu().numpy()
        pred_id = int(probs.argmax())
        pred_label = id2label[pred_id]

    # Resolve which class the attributions should explain.
    if target_label is None:
        target = pred_id
    else:
        if isinstance(target_label, str):
            inv = {v: k for k, v in id2label.items()}
            target = inv[target_label]
        else:
            target = int(target_label)

    embeddings = model.get_input_embeddings()
    input_embeds = embeddings(input_ids)
    # Baseline: the same sequence length, but every position is the pad token.
    baseline_ids = torch.full_like(input_ids, tokenizer.pad_token_id)
    baseline_embeds = embeddings(baseline_ids)

    ig = IntegratedGradients(forward_embeds)
    attributions = ig.attribute(
        inputs=input_embeds,
        baselines=baseline_embeds,
        target=target,
        n_steps=n_steps
    )

    # Collapse the embedding dimension to get one score per token.
    token_scores = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0).detach().cpu().tolist())

    words, word_scores = _merge_wordpieces(tokens, token_scores)

    imp = np.abs(word_scores)
    # Guard the empty case: .max() on a zero-size array raises ValueError,
    # which happens when the text contains no non-special tokens (e.g. "").
    if imp.size > 0 and imp.max() > 0:
        imp = imp / imp.max()
    return pred_id, pred_label, probs, words, imp

def render_highlight_html(words, scores, max_words=80):
    """Render words as HTML spans shaded red in proportion to their score.

    Args:
        words: Sequence of word strings to display.
        scores: Sequence of importances aligned with ``words``; each value is
            clamped to ``[0, 1]`` before being mapped to an opacity.
        max_words: Only the first ``max_words`` words/scores are rendered.

    Returns:
        An HTML string: a ``<div>`` containing one ``<span>`` per word, joined
        by spaces. Word text is HTML-escaped so characters such as ``<``,
        ``>`` and ``&`` coming from the input text cannot break the markup.
    """
    # Local import so this definition does not depend on another cell's imports.
    from html import escape

    words = words[:max_words]
    scores = scores[:max_words]
    spans = []
    for w, s in zip(words, scores):
        # Clamp to [0, 1]; the background alpha then ranges from 0.15 to 0.90.
        s = float(max(0.0, min(1.0, s)))
        spans.append(
            f'<span style="background: rgba(255, 0, 0, {0.15 + 0.75*s}); padding:2px 4px; margin:1px; border-radius:4px; display:inline-block;">{escape(str(w))}</span>'
        )
    return "<div style='line-height: 2.0;'>" + " ".join(spans) + "</div>"


In [4]:
from IPython.display import display, HTML

# Demo: explain the model's prediction for one sample sentence.
# target_label=None makes explain_ig attribute towards the predicted class.
text = "You are such a disgusting idiot."
pred_id, pred_label, probs, words, imp = explain_ig(text, target_label=None)
print("Text:", text)
print("Pred:", pred_id, pred_label)
# The bracketed names follow the id2label order (ids 0, 1, 2).
print("Probs [hate, offensive, neither]:", probs)

# Show the merged words with a red background whose opacity tracks importance.
display(HTML(render_highlight_html(words, imp)))


  attn_output = torch.nn.functional.scaled_dot_product_attention(


Text: You are such a disgusting idiot.
Pred: 1 offensive_language
Probs [hate, offensive, neither]: [0.16691194 0.8239382  0.00914981]


In [8]:
# Demo: a short sample sentence; attributions target the predicted class
# because target_label is None.
text = "I Hate you!"
pred_id, pred_label, probs, words, imp = explain_ig(text, target_label=None)
print("Text:", text)
print("Pred:", pred_id, pred_label)
# The bracketed names follow the id2label order (ids 0, 1, 2).
print("Probs [hate, offensive, neither]:", probs)

# Relies on display/HTML imported in an earlier cell.
display(HTML(render_highlight_html(words, imp)))

Text: I Hate you!
Pred: 1 offensive_language
Probs [hate, offensive, neither]: [0.44031054 0.54600155 0.01368789]


In [9]:
# Demo: a sample sentence that directs violent language at a group.
# NOTE(review): the recorded output for this cell reports class 1
# (offensive_language) at about 0.46 versus about 0.35 for hate_speech --
# a low-confidence call that is worth inspecting.
text = "Can we just kill all the jews, they are so annoying."
pred_id, pred_label, probs, words, imp = explain_ig(text, target_label=None)
print("Text:", text)
print("Pred:", pred_id, pred_label)
# The bracketed names follow the id2label order (ids 0, 1, 2).
print("Probs [hate, offensive, neither]:", probs)

# Relies on display/HTML imported in an earlier cell.
display(HTML(render_highlight_html(words, imp)))

Text: Can we just kill all the jews, they are so annoying.
Pred: 1 offensive_language
Probs [hate, offensive, neither]: [0.34869078 0.46079636 0.19051279]


In [10]:
# Demo: a benign sample sentence, used as a contrast to the previous examples.
text = "Can we all just get along?"
pred_id, pred_label, probs, words, imp = explain_ig(text, target_label=None)
print("Text:", text)
print("Pred:", pred_id, pred_label)
# The bracketed names follow the id2label order (ids 0, 1, 2).
print("Probs [hate, offensive, neither]:", probs)

# Relies on display/HTML imported in an earlier cell.
display(HTML(render_highlight_html(words, imp)))

Text: Can we all just get along?
Pred: 2 neither
Probs [hate, offensive, neither]: [0.04125289 0.1371559  0.8215912 ]
