In [59]:
import pickle
import numpy as np
import re
import pandas as pd
from nltk.corpus import stopwords
import numpy as np
import sklearn
import nltk
from sklearn.naive_bayes import MultinomialNB
import warnings
warnings.filterwarnings('ignore')
from lime import lime_text
from sklearn.pipeline import make_pipeline
from lime.lime_text import LimeTextExplainer
from nltk.stem import PorterStemmer, WordNetLemmatizer
lemmetizer = WordNetLemmatizer()
stemmer = PorterStemmer()

In [60]:
class_names = np.asarray([u'ARTS', u'ARTS & CULTURE', u'BLACK VOICES', u'BUSINESS', u'COLLEGE',
 u'COMEDY', u'CRIME', u'EDUCATION', u'ENTERTAINMENT', u'FIFTY', u'GOOD NEWS',
 u'GREEN', u'HEALTHY LIVING', u'IMPACT', u'LATINO VOICES', u'MEDIA', u'PARENTS',
 u'POLITICS', u'QUEER VOICES', u'RELIGION', u'SCIENCE', u'SPORTS', u'STYLE',
 u'TASTE', u'TECH', u'THE WORLDPOST', u'TRAVEL', u'WEIRD NEWS', u'WOMEN', u'WORLD NEWS', u'WORLDPOST'])
class_names.shape

(31,)

In [61]:
model_filename = 'interests_classifier.sav'
vect_filename = 'vectorize_interests_classifier.sav'

def load_models(model_filename, vect_filename):
    loaded_model = pickle.load(open(model_filename, 'rb'))
    vectorize = pickle.load(open(vect_filename, 'rb'))
    return loaded_model, vectorize

loaded_model, vectorize = load_models(model_filename, vect_filename)

In [62]:
def get_words(headlines_list):
    headlines = headlines_list[0]   
    author_names = [x for x in headlines_list[1].lower().replace('and',',').replace(' ', '').split(',') if x != '']
    headlines_only_letters = re.sub('[^a-zA-Z]', ' ', headlines)
    words = nltk.word_tokenize(headlines_only_letters.lower())
    stops = set(stopwords.words('english'))
    meaningful_words = [lemmetizer.lemmatize(w) for w in words if w not in stops]
    return ' '.join(meaningful_words + author_names)

In [82]:
input_val = np.asarray([u"Cricket is a beatiful sport. Baseball, hockey. Dhoni is the best!!!",
       u'some_author'])

def clean_input_and_vectorize(input_val):
    cleanHeadlines_list = []
    cleanHeadline = get_words(input_val) #Processing the data and getting words with no special characters, numbers or html tags
    cleanHeadlines_list.append( cleanHeadline )
    tfidwords_input = vectorize.transform(cleanHeadlines_list)
    return tfidwords_input, cleanHeadlines_list

tfidwords_input, cleanHeadlines_list = clean_input_and_vectorize(input_val)

In [83]:
def predict(tfidwords_input):
    return loaded_model.predict(tfidwords_input)[0]

pred_category = predict(tfidwords_input)

In [84]:
def get_explainer(class_names, loaded_model, vectorize, cleanHeadlines_list, tfidwords_input):
    explainer = LimeTextExplainer(class_names=class_names)
    c = make_pipeline(vectorize, loaded_model)
    exp = explainer.explain_instance(cleanHeadlines_list[0], c.predict_proba, labels=range(class_names.shape[0]))
    return exp

def explain_all(exp, class_names):
    all_explainers = {}
    print 'Predicted class =', loaded_model.predict(tfidwords_input_test).reshape(1,-1)[0,0]
    for idx in range(class_names.shape[0]):
        try:
            all_explainers[class_names[idx]] = (exp.as_list(label=i))
        except:
            print("Error for {}".format(i))
    return all_explainers

exp = get_explainer(class_names, loaded_model, vectorize, cleanHeadlines_list, tfidwords_input)
all_explainers = explain_all(exp, class_names)
print(all_explainers)

Predicted class = SPORTS
{u'BUSINESS': [(u'hockey', 0.2976336465958332), (u'baseball', 0.2688058903031485), (u'cricket', -0.2655826509364169), (u'sport', 0.20616728084130584), (u'best', -0.02856751391295393), (u'beatiful', -0.004072770293723184), (u'some_author', -0.0035699197318947037), (u'dhoni', -0.002496119789379354)], u'ENTERTAINMENT': [(u'hockey', 0.2976336465958332), (u'baseball', 0.2688058903031485), (u'cricket', -0.2655826509364169), (u'sport', 0.20616728084130584), (u'best', -0.02856751391295393), (u'beatiful', -0.004072770293723184), (u'some_author', -0.0035699197318947037), (u'dhoni', -0.002496119789379354)], u'CRIME': [(u'hockey', 0.2976336465958332), (u'baseball', 0.2688058903031485), (u'cricket', -0.2655826509364169), (u'sport', 0.20616728084130584), (u'best', -0.02856751391295393), (u'beatiful', -0.004072770293723184), (u'some_author', -0.0035699197318947037), (u'dhoni', -0.002496119789379354)], u'GREEN': [(u'hockey', 0.2976336465958332), (u'baseball', 0.268805890303148

In [86]:
def get_class_name_index(class_names, class_name):
    return list(class_names).index(class_name)

idx = get_class_name_index(class_names, pred_category)

def explain_class(exp, idx):
    try:
        return exp.as_list(label=idx)
    except:
        print("Error for {}".format(i))
        return None
    
print(explain_class(exp, idx))

[(u'hockey', 0.2976336465958332), (u'baseball', 0.2688058903031485), (u'cricket', -0.2655826509364169), (u'sport', 0.20616728084130584), (u'best', -0.02856751391295393), (u'beatiful', -0.004072770293723184), (u'some_author', -0.0035699197318947037), (u'dhoni', -0.002496119789379354)]
