In [1]:
import os
import random
import functools
import csv
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score, precision_score, recall_score
from skmultilearn.model_selection import iterative_train_test_split
from sklearn.model_selection import StratifiedKFold


from datasets import Dataset, DatasetDict
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
)

from transformers_interpret import MultiLabelClassificationExplainer

In [2]:
def tokenize_examples(examples, tokenizer, classes):
    text = f"Issue: {examples['issue']}.\nAnswer: {examples['post_text']}"
    labels = [examples[label] for label in classes]
    tokenized_inputs = tokenizer(text, truncation=True, max_length=700, padding=True)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs


# define custom batch preprocessor
def collate_fn(batch, tokenizer):
    dict_keys = ['input_ids', 'attention_mask', 'labels']
    d = {k: [dic[k] for dic in batch] for k in dict_keys}
    d['input_ids'] = torch.nn.utils.rnn.pad_sequence(
        d['input_ids'], batch_first=True, padding_value=tokenizer.pad_token_id
    )
    d['attention_mask'] = torch.nn.utils.rnn.pad_sequence(
        d['attention_mask'], batch_first=True, padding_value=0
    )
    d['labels'] = torch.stack(d['labels']).type(torch.float)
    return d

In [3]:
from datasets import load_dataset
    
ds = load_dataset('timonziegenbein/appropriateness-corpus')
classes = [
    'Excessive Intensity',
    'Emotional Deception',
    'Missing Seriousness',
    'Missing Openness',
    'Unclear Meaning',
    'Missing Relevance',
    'Confusing Reasoning',
    'Detrimental Orthography',
    'Reason Unclassified'
]

class2id = {class_:id for id, class_ in enumerate(classes)}
id2class = {id:class_ for class_, id in class2id.items()}

model_name = 'multilabel_mistral'

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token #= tokenizer.special_tokens_map['pad_token']
tokenized_ds = ds.map(functools.partial(tokenize_examples, tokenizer=tokenizer, classes=classes), batched=False)
tokenized_ds = tokenized_ds.with_format('torch')

labels = tokenized_ds['train']['labels']
label_weights = torch.ones(len(classes))
print(label_weights)

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.])


In [4]:
# qunatization config
quantization_config = BitsAndBytesConfig(
    load_in_4bit = True, # enable 4-bit quantization
    bnb_4bit_quant_type = 'nf4', # information theoretically optimal dtype for normally distributed weights
    bnb_4bit_use_double_quant = True, # quantize quantized weights //insert xzibit meme
    bnb_4bit_compute_dtype = torch.bfloat16 # optimized fp format for ML
)

# load model
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    device_map="cuda:0",
    quantization_config=quantization_config,
    num_labels=len(classes),
    problem_type="multi_label_classification",
)
model.config.pad_token_id = tokenizer.pad_token_id

Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-large and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
cls_explainer = MultiLabelClassificationExplainer(model, tokenizer)

text = "Issue: India has the potential to lead the world:.\nAnswer: stupid india they actually really suck. so, BOO INDIA BOO!"

word_attributions = cls_explainer(text)

In [19]:
word_attributions

{'LABEL_0': [('[CLS]', 0.0),
  ('▁Issue', 0.013146981895163794),
  (':', -0.0013371696043348238),
  ('▁India', -0.00894609801953715),
  ('▁has', 0.8839374518453498),
  ('▁the', -0.011381605604238244),
  ('▁potential', 0.04311363652346493),
  ('▁to', 0.13874644952693216),
  ('▁lead', 0.01331531999228797),
  ('▁the', 0.08052670360180214),
  ('▁world', 0.023100105466326445),
  (':', -0.021043718394992198),
  ('▁.', 0.0003411970252437799),
  ('▁Answer', 0.00040444981959140687),
  (':', -0.0073317900444397595),
  ('▁stupid', 0.09524187597528284),
  ('▁india', -0.06423960001556826),
  ('▁they', 0.13550141199386959),
  ('▁actually', -0.0059014663633441735),
  ('▁really', 0.16093917350852688),
  ('▁suck', 0.256274797421075),
  ('▁.', -0.13510445092250212),
  ('▁so', 0.10576543138755036),
  ('▁,', -0.02081970671701328),
  ('▁BOO', 0.05797880908522764),
  ('▁INDIA', -0.030271763622700114),
  ('▁BOO', -0.03362972216534403),
  ('▁!', 0.17611191817966657),
  ('[SEP]', 0.0)],
 'LABEL_1': [('[CLS]', 

In [20]:
cls_explainer.visualize("multilabel_viz.html")

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
,(0.38),LABEL_0,1.84,"[CLS] ▁Issue : ▁India ▁has ▁the ▁potential ▁to ▁lead ▁the ▁world : ▁. ▁Answer : ▁stupid ▁india ▁they ▁actually ▁really ▁suck ▁. ▁so ▁, ▁BOO ▁INDIA ▁BOO ▁! [SEP]"
,,,,
,(0.35),LABEL_1,-0.87,"[CLS] ▁Issue : ▁India ▁has ▁the ▁potential ▁to ▁lead ▁the ▁world : ▁. ▁Answer : ▁stupid ▁india ▁they ▁actually ▁really ▁suck ▁. ▁so ▁, ▁BOO ▁INDIA ▁BOO ▁! [SEP]"
,,,,
,(0.41),LABEL_2,1.65,"[CLS] ▁Issue : ▁India ▁has ▁the ▁potential ▁to ▁lead ▁the ▁world : ▁. ▁Answer : ▁stupid ▁india ▁they ▁actually ▁really ▁suck ▁. ▁so ▁, ▁BOO ▁INDIA ▁BOO ▁! [SEP]"
,,,,
,(0.43),LABEL_3,0.6,"[CLS] ▁Issue : ▁India ▁has ▁the ▁potential ▁to ▁lead ▁the ▁world : ▁. ▁Answer : ▁stupid ▁india ▁they ▁actually ▁really ▁suck ▁. ▁so ▁, ▁BOO ▁INDIA ▁BOO ▁! [SEP]"
,,,,
,(0.53),LABEL_4,-1.2,"[CLS] ▁Issue : ▁India ▁has ▁the ▁potential ▁to ▁lead ▁the ▁world : ▁. ▁Answer : ▁stupid ▁india ▁they ▁actually ▁really ▁suck ▁. ▁so ▁, ▁BOO ▁INDIA ▁BOO ▁! [SEP]"
,,,,


n/a,Prediction Score,Attribution Label,Attribution Score,Word Importance
,(0.38),LABEL_0,1.84,"[CLS] ▁Issue : ▁India ▁has ▁the ▁potential ▁to ▁lead ▁the ▁world : ▁. ▁Answer : ▁stupid ▁india ▁they ▁actually ▁really ▁suck ▁. ▁so ▁, ▁BOO ▁INDIA ▁BOO ▁! [SEP]"
,,,,
,(0.35),LABEL_1,-0.87,"[CLS] ▁Issue : ▁India ▁has ▁the ▁potential ▁to ▁lead ▁the ▁world : ▁. ▁Answer : ▁stupid ▁india ▁they ▁actually ▁really ▁suck ▁. ▁so ▁, ▁BOO ▁INDIA ▁BOO ▁! [SEP]"
,,,,
,(0.41),LABEL_2,1.65,"[CLS] ▁Issue : ▁India ▁has ▁the ▁potential ▁to ▁lead ▁the ▁world : ▁. ▁Answer : ▁stupid ▁india ▁they ▁actually ▁really ▁suck ▁. ▁so ▁, ▁BOO ▁INDIA ▁BOO ▁! [SEP]"
,,,,
,(0.43),LABEL_3,0.6,"[CLS] ▁Issue : ▁India ▁has ▁the ▁potential ▁to ▁lead ▁the ▁world : ▁. ▁Answer : ▁stupid ▁india ▁they ▁actually ▁really ▁suck ▁. ▁so ▁, ▁BOO ▁INDIA ▁BOO ▁! [SEP]"
,,,,
,(0.53),LABEL_4,-1.2,"[CLS] ▁Issue : ▁India ▁has ▁the ▁potential ▁to ▁lead ▁the ▁world : ▁. ▁Answer : ▁stupid ▁india ▁they ▁actually ▁really ▁suck ▁. ▁so ▁, ▁BOO ▁INDIA ▁BOO ▁! [SEP]"
,,,,
