In [None]:
!pip install transformers 
!pip install lime

In [None]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BertModel
from utils.models import bertCNN
import torch
from utils.data_processor import data_loader
import torch.nn.functional as F
import numpy as np
import os
from lime.lime_text import LimeTextExplainer
import matplotlib.pyplot as plt
from torch import cuda
# Notebook-wide default torch device: use the GPU when CUDA is available,
# otherwise fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


class Prediction:
    """Load a fine-tuned 3-label text classifier and expose a LIME-style predictor.

    The checkpoint file name selects the architecture:

    * names starting with ``bert_cnn`` -> ``bertCNN`` on top of multilingual BERT,
    * other names starting with ``bert`` -> multilingual BERT sequence classifier,
    * names starting with ``xlm`` -> XLM-RoBERTa sequence classifier.

    Attributes:
        tokenizer: Tokenizer matching the selected pretrained weights.
        model: The classifier with the checkpoint weights loaded, in eval mode.
        device: ``'cuda'`` when CUDA is available, otherwise ``'cpu'``.
    """

    def __init__(self, model_dir, model_path, do_lower_case):
        """Build the tokenizer/model pair and restore the saved weights.

        Args:
            model_dir: Directory that contains the checkpoint file.
            model_path: Checkpoint file name inside ``model_dir``; its prefix
                (``bert``, ``bert_cnn`` or ``xlm``) selects the architecture.
            do_lower_case: Forwarded to ``AutoTokenizer.from_pretrained``.

        Raises:
            ValueError: If ``model_path`` starts with neither ``bert`` nor ``xlm``.
        """
        if model_path.startswith('bert'):
            pretrained_weights = 'bert-base-multilingual-cased'
        elif model_path.startswith('xlm'):
            pretrained_weights = 'xlm-roberta-base'
        else:
            raise ValueError('error path!')
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_weights, do_lower_case=do_lower_case)
        if model_path.startswith('bert_cnn'):
            print('bert_cnn')
            embed_model = BertModel.from_pretrained(pretrained_weights)
            self.model = bertCNN(embed_model=embed_model, dropout=0.2, kernel_num=4, kernel_sizes=[3, 4, 5, 6], num_labels=3)
        else:
            print('sequence classification')
            self.model = AutoModelForSequenceClassification.from_pretrained(pretrained_weights,
                                                                       num_labels=3,
                                                                       output_attentions=False,
                                                                       output_hidden_states=False)
        self.device = 'cuda' if cuda.is_available() else 'cpu'
        # map_location lets a checkpoint saved from a GPU be restored on a
        # CPU-only machine instead of failing while deserializing CUDA tensors.
        state_dict = torch.load(os.path.join(model_dir, model_path), map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        # Inference only: switch off dropout so repeated calls on the same text
        # return the same probabilities (a freshly constructed torch module
        # starts in training mode).
        self.model.eval()

    def convert_text_to_features(self, text, max_length=150):
        """Tokenize one text into single-example model input tensors.

        Runs of whitespace are collapsed to single spaces before tokenizing.
        The encoding is truncated/padded to exactly ``max_length`` tokens.

        Args:
            text: The raw input string.
            max_length: Fixed sequence length of the encoding (default 150).

        Returns:
            Tuple ``(ids, mask, token_type_ids)`` of ``torch.long`` tensors,
            each with a leading batch dimension of 1, placed on ``self.device``.
        """
        text = " ".join(text.split())

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_length,
            # Replaces the deprecated ``pad_to_max_length=True``.
            padding='max_length',
            truncation=True,
            return_token_type_ids=True,
        )

        # Use the instance's device (not a notebook global) so the inputs
        # always live where the model was moved to.
        ids = torch.tensor([inputs['input_ids']], dtype=torch.long, device=self.device)
        mask = torch.tensor([inputs['attention_mask']], dtype=torch.long, device=self.device)
        token_type_ids = torch.tensor([inputs['token_type_ids']], dtype=torch.long, device=self.device)

        return ids, mask, token_type_ids

    def predictor(self, texts):
        """Return class probabilities for each text (LIME classifier function).

        Each text is encoded and pushed through the model on its own; the
        first element of the model output is treated as the logits and
        soft-maxed over the label dimension.

        Args:
            texts: Iterable of raw strings.

        Returns:
            ``numpy.ndarray`` with one row of class probabilities per input
            text, in input order.
        """
        results = []
        # No gradients are needed for inference; this saves memory and time.
        with torch.no_grad():
            for text in texts:
                ids, mask, token_type_ids = self.convert_text_to_features(text)
                outputs = self.model(ids, mask, token_type_ids)
                probabilities = F.softmax(outputs[0], dim=1)
                # Drop the batch dimension of 1 to get a flat probability row.
                results.append(probabilities.cpu().numpy()[0])

        return np.array(results)


In [None]:
from lime.lime_text import LimeTextExplainer
# Fine-tuned classifier checkpoint whose prediction will be explained.
model_dir = 'saved_models/saved_weights/ensemble_more'
model_name = 'xlm_finetune_mlm_8000steps+trans+auto_data_0.734_0.781.bin'
prediction = Prediction(
    model_dir=model_dir,
    model_path=model_name,
    do_lower_case=False,
)

# LIME text explainer; class_names are the display labels for the classifier's
# output columns (assumed to be in the same order the model was trained with).
class_names = ['solidarity', 'anti-solidarity', 'other']
explainer = LimeTextExplainer(class_names=class_names)

# Use LIME to interpret the classifier's prediction for a single tweet.
example = "Countries of south-eastern europe, do not let the migrants in your country. We will not take them this time. Good luck! #Greece #Macedonia #Serbia #Bosnia #Croatia #Hungary #Slovenia #Migration #HumanRightsRefugee #NeverAgain2015 #Turkey #Syria #refugee #RefugeesWelcome https://twitter.com/DerSteirische/status/1233868109045518337"
exp = explainer.explain_instance(
    example,
    prediction.predictor,
    top_labels=1,      # explain only the highest-probability label
    num_samples=2500,  # number of perturbed texts LIME feeds to the predictor
)
exp.show_in_notebook(text=example)