Simple visualization for PyTorch models. Test version, for classification tasks.
Implementation of the paper: Axiomatic Attribution for Deep Networks
- python > 3.6
- pytorch > 1.4
- Colr > 0.9
pip install transformer_viewer
import numpy as np
import torch
import json
from transformers import *
from transformer_viewer import Glimpse
# Load the id -> label mapping. PATH_ID2LABEL is a placeholder that the user
# must define. The file is read as UTF-8 explicitly so non-ASCII label names
# do not depend on the platform's default locale encoding.
with open(PATH_ID2LABEL, encoding="utf-8") as infile:
    idx_label_map = json.load(infile)

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the tokenizer and the classifier. MODEL_PATH is a placeholder that the
# user must define. The classification head gets one output per known label.
# NOTE(review): the model is not moved to `device` here; presumably Glimpse
# does that itself - confirm.
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
model = BertForSequenceClassification.from_pretrained(
    MODEL_PATH,
    num_labels=len(idx_label_map),
)
def adaptor_embed(model, sign):
    """Return the word-embedding weights of ``model`` or their gradient.

    Args:
        model: Object exposing ``bert.embeddings.word_embeddings.weight``.
        sign: ``'value'`` to get ``weight.data``, ``'grad'`` to get
            ``weight.grad``.

    Returns:
        ``weight.data`` when ``sign == 'value'``, ``weight.grad`` when
        ``sign == 'grad'``.

    Raises:
        ValueError: If ``sign`` is neither ``'value'`` nor ``'grad'``.
    """
    if sign == 'value':
        return model.bert.embeddings.word_embeddings.weight.data
    if sign == 'grad':
        return model.bert.embeddings.word_embeddings.weight.grad
    # Fail loudly: silently returning None here would only surface later as a
    # confusing error far away from the actual mistake.
    raise ValueError(
        f"sign must be 'value' or 'grad', got {sign!r}"
    )
def adaptor_model(text, model, tokenizer, device, label=None):
    """Run ``model`` on ``text`` to get either a prediction or a loss.

    The text is tokenized with special tokens added and returned as "pt"
    tensors, then moved to ``device``.

    Args:
        text: Input text passed to ``tokenizer``.
        model: Callable invoked as ``model(**encoded)``; the first element of
            its output is used.
        tokenizer: Callable producing the encoded model inputs.
        device: Device the inputs (and the label tensor) are moved to.
        label: Optional target label id. When omitted, a prediction is made.

    Returns:
        When ``label`` is ``None``: a tuple ``(input_ids, label_id)`` where
        ``label_id`` is the NumPy argmax over axis 1 of the first model
        output, computed without gradient tracking.
        Otherwise: the first model output obtained with ``labels`` supplied
        (treated as the loss by this example).
    """
    encoded = tokenizer(text=text, add_special_tokens=True, return_tensors="pt")
    encoded.to(device=device)

    if label is not None:
        # Loss mode: gradients stay enabled so the caller can backpropagate.
        encoded["labels"] = torch.tensor(label).to(device, dtype=torch.long)
        return model(**encoded)[0]

    # Prediction mode: no gradients are needed for a plain forward pass.
    with torch.no_grad():
        scores = model(**encoded)[0]
        predicted = np.argmax(scores.cpu().detach().numpy(), axis=1)
    return encoded["input_ids"], predicted
# Build the viewer from the model, the tokenizer and the two adaptor functions.
viewer = Glimpse(
    model,
    tokenizer,
    adaptor_embed,
    adaptor_model,
    device,
    spliter=' ',
    id2label=idx_label_map,
)
viewer.color_bar()

# Visualize each (text, label) example in turn.
examples = (
    ("乌鲁木齐市 新增 一处 城市 中心 旅游 目的地", "travel"),
    ("郭晶晶 曾 撮合 吴敏霞 与 章子怡 前男友 , 拒绝 豪门 平淡 才 是 真", "sports"),
)
for sentence, target in examples:
    viewer.view(sentence, target)
Glimpse(model, tokenizer, adaptor_embed, adaptor_model, device, spliter=' ', id2label=None, step=20)
parameter | type | description | example |
---|---|---|---|
model | object | pytorch model | Bert |
tokenizer | object | transformers tokenizer | BertTokenizer |
adaptor_embed | func | extract embedding and grad | see example |
adaptor_model | func | return the loss for a given label, or the input ids and predicted label when no label is given | see example |
device | object | torch device | see example |
spliter | str | how to connect tokens | ' ' for English, '' for Chinese |
id2label | dict | mapping from id to label | {0:'sports'} |
parameter | type | description | example |
---|---|---|---|
text | str | input text | '我 爱 中国' |
label | int or str | target label (id or name) | 'car' |