In [1]:
from sklearn.datasets import fetch_20newsgroups
categories = [
    'alt.atheism',
    'talk.religion.misc',
    'comp.graphics',
    'sci.space',
]


def fetch_subset(subset):
    """Fetch one split of 20 newsgroups limited to ``categories``.

    ``subset`` is passed straight through to ``fetch_20newsgroups``.
    Headers, footers and quotes are removed from every post, and the
    documents are shuffled with a fixed seed.
    """
    return fetch_20newsgroups(
        subset=subset,
        categories=categories,
        shuffle=True,
        random_state=42,
        remove=('headers', 'footers', 'quotes'),
    )


train, test = fetch_subset('train'), fetch_subset('test')

In [2]:
from sklearn.pipeline import Pipeline
from sklearn.linear_model import SGDClassifier
from sklearn.feature_extraction.text import TfidfVectorizer

# TF-IDF over character 3- and 4-grams ('char_wb' analyzer).
vec = TfidfVectorizer(analyzer='char_wb', ngram_range=(3, 4))
# The classifier shuffles the training data (shuffle=True by default), so
# without a seed the learned weights differ from run to run. Pin it so the
# explanations in the following cells are reproducible.
# NOTE: stored outputs were produced before seeding; re-run to refresh them.
clf = SGDClassifier(n_jobs=-1, random_state=42)
pipeline = Pipeline([('vec', vec), ('clf', clf)])
# Later cells pass `clf` and `vec` directly to eli5, relying on this call
# having fitted those same objects.
pipeline.fit(train['data'], train['target'])

Pipeline(steps=[('vec', TfidfVectorizer(analyzer='char_wb', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(3, 4), norm='l2', preprocessor=None, smooth_idf=True,
...   penalty='l2', power_t=0.5, random_state=None, shuffle=True,
       verbose=0, warm_start=False))])

In [3]:
from eli5 import explain_weights, explain_prediction
from eli5 import format_as_html, format_as_text, format_html_styles

# Plain-text dump of the model weights, labelled with the newsgroup names.
weights_expl = explain_weights(
    clf, vec, target_names=train['target_names'])
print(format_as_text(weights_expl))

Explained as: linear model

Features with largest coefficients per class.
Caveats:
1. Be careful with features which are not
   independent - weights don't show their importance.
2. If scale of input features is different then scale of coefficients
   will also be different, making direct comparison between coefficient values
   incorrect.
3. Depending on regularization, rare features sometimes may have high
   coefficients; this doesn't mean they contribute much to the
   classification result for most examples.

y='alt.atheism' top features
--------------
  +2.579  heis
  +2.158  eis 
  +2.092  eist
  +2.036  ░pos
  +1.916  post
  +1.840  nat 
  +1.787  ░ath
  +1.763  thei
  +1.710  athe
  +1.710  hei 
  +1.693  pos 
  +1.506  mott
  +1.495  sla 
  +1.464  rna 
  +1.453  ░up?
  +1.448  up? 
  +1.444  up?░
       …  (20146 more positive features)
       …  (30471 more negative features)
  -1.591  pac 
  -1.440  et░ 
  -1.429  ░us 

y='comp.graphics' top features
--------------
  +2.17

In [4]:
# `IPython.display` is the public entry point; importing `display` from
# `IPython.core.display` is deprecated.
from IPython.display import display, HTML


def show_html(html):
    """Render a raw HTML string as rich output of the current cell."""
    display(HTML(html))


def show_html_expl(expl, **kwargs):
    """Render an eli5 explanation as HTML, without inline styles.

    Extra keyword arguments are forwarded to ``format_as_html``.
    """
    show_html(format_as_html(expl, include_styles=False, **kwargs))


# Emit the styles once here, since show_html_expl always leaves them out.
show_html(format_html_styles())

In [5]:
# Same weights explanation as HTML, this time requesting top=100.
top_weights_expl = explain_weights(
    clf, vec,
    target_names=train['target_names'],
    top=100,
)
show_html_expl(top_weights_expl)

Weight,Feature
+2.579,heis
+2.158,eis
+2.092,eist
+2.036,pos
+1.916,post
+1.840,nat
+1.787,ath
+1.763,thei
+1.710,athe
+1.710,hei

Weight,Feature
+2.173,file
+1.992,3d
+1.980,phi
+1.925,gra
+1.834,mage
+1.831,fil
+1.824,ima
+1.772,raph
+1.745,mag
+1.737,imag

Weight,Feature
+3.241,spac
+3.180,pace
+2.816,spa
+2.673,spa
+2.527,pac
+2.000,orb
+1.992,rbit
+1.912,rbi
+1.910,orbi
+1.881,ace

Weight,Feature
+1.965,*
+1.819,ian
+1.764,rist
+1.656,ian
+1.607,he
+1.419,us
+1.389,fbi
+1.333,ord
+1.311,fbi
+1.297,fbi


In [6]:
# Explain the prediction for test document 7 (top=50), rendered with
# force_weights=True.
doc7_expl = explain_prediction(
    clf, test['data'][7], vec,
    target_names=train['target_names'],
    top=50,
)
show_html_expl(doc7_expl, force_weights=True)

Weight,Feature
+0.057,:
+0.054,be
+0.042,the
+0.037,tin
+0.036,ght
+0.030,ill
+0.027,it
+0.027,ost
+0.025,as
+0.023,n't

Weight,Feature
… 467 more positive …,… 467 more positive …
… 588 more negative …,… 588 more negative …
-0.592,Highlighted in text (sum)
-1.024,<BIAS>

Weight,Feature
+0.232,:
+0.042,fi
+0.039,rig
+0.035,ile
+0.035,co
+0.031,cop
+0.031,is
+0.031,fra
+0.030,ha
… 486 more positive …,… 486 more positive …

Weight,Feature
… 486 more positive …,… 486 more positive …
… 561 more negative …,… 561 more negative …
-0.998,<BIAS>
-1.320,Highlighted in text (sum)

Weight,Feature
+0.107,pac
+0.100,spac
+0.098,astr
+0.098,pace
+0.083,spa
+0.082,spa
+0.076,orb
+0.075,rbit
+0.074,orb
+0.074,igh

Weight,Feature
+1.869,Highlighted in text (sum)
… 552 more positive …,… 552 more positive …
… 509 more negative …,… 509 more negative …
-0.998,<BIAS>

Weight,Feature
+0.055,th
+0.052,is
+0.038,eld
+0.038,der
+0.037,he
+0.036,sa
+0.036,fra
+0.036,:
+0.035,serv
+0.035,br

Weight,Feature
… 488 more positive …,… 488 more positive …
… 565 more negative …,… 565 more negative …
-0.317,Highlighted in text (sum)
-0.996,<BIAS>


In [7]:
# Explain the prediction for test document 1 with the default settings.
doc1_expl = explain_prediction(
    clf, test['data'][1], vec,
    target_names=train['target_names'],
)
show_html_expl(doc1_expl)

Weight,Feature
+0.115,mad
+0.094,vat
+0.083,mad
+0.081,atic
+0.077,ican
+0.055,ble.
+0.051,fin
+0.044,an
… 61 more positive …,… 61 more positive …
… 95 more negative …,… 95 more negative …

Weight,Feature
… 61 more positive …,… 61 more positive …
… 95 more negative …,… 95 more negative …
-0.049,Highlighted in text (sum)
-1.024,<BIAS>

Weight,Feature
+0.096,ftp
+0.093,ft
+0.092,lib
+0.086,lib
+0.082,ftp
+0.081,can
+0.080,hel
+0.077,tp
+0.075,ftp
+0.072,can

Weight,Feature
+1.397,Highlighted in text (sum)
… 101 more positive …,… 101 more positive …
… 55 more negative …,… 55 more negative …
-0.998,<BIAS>

Weight,Feature
+0.072,ndin
+0.065,lle
+0.062,oll
+0.050,ry
+0.050,ndi
+0.048,te
+0.046,col
… 82 more positive …,… 82 more positive …
… 74 more negative …,… 74 more negative …
-0.048,ftp

Weight,Feature
… 82 more positive …,… 82 more positive …
… 74 more negative …,… 74 more negative …
-0.375,Highlighted in text (sum)
-0.998,<BIAS>

Weight,Feature
+0.076,indi
+0.073,is
+0.055,us.
+0.049,he
+0.046,us
+0.046,ecen
… 62 more positive …,… 62 more positive …
… 94 more negative …,… 94 more negative …
-0.046,vat
-0.046,ade

Weight,Feature
… 62 more positive …,… 62 more positive …
… 94 more negative …,… 94 more negative …
-0.442,Highlighted in text (sum)
-0.996,<BIAS>


In [8]:
import numpy as np
for doc in test['data'][:10]:
    expl = explain_prediction(
        clf, doc, vec, target_names=train['target_names'])
    # Hack: overwrite the per-class list so that only the highest-scoring
    # class is left, and a single explanation is rendered per document.
    class_scores = [cl['score'] for cl in expl['classes']]
    winner = expl['classes'][np.argmax(class_scores)]
    expl['classes'] = [winner]
    show_html_expl(expl, force_weights=False)

Weight,Feature
+1.247,Highlighted in text (sum)
… 20 more positive …,… 20 more positive …
… 12 more negative …,… 12 more negative …
-0.998,<BIAS>


Weight,Feature
+1.397,Highlighted in text (sum)
… 101 more positive …,… 101 more positive …
… 55 more negative …,… 55 more negative …
-0.998,<BIAS>


Weight,Feature
+1.309,Highlighted in text (sum)
… 177 more positive …,… 177 more positive …
… 132 more negative …,… 132 more negative …
-0.998,<BIAS>


Weight,Feature
+0.286,Highlighted in text (sum)
… 912 more positive …,… 912 more positive …
… 705 more negative …,… 705 more negative …
-0.998,<BIAS>


Weight,Feature
+1.396,Highlighted in text (sum)
… 96 more positive …,… 96 more positive …
… 94 more negative …,… 94 more negative …
-0.998,<BIAS>


Weight,Feature
+1.175,Highlighted in text (sum)
… 141 more positive …,… 141 more positive …
… 109 more negative …,… 109 more negative …
-0.998,<BIAS>


Weight,Feature
… 63 more positive …,… 63 more positive …
… 51 more negative …,… 51 more negative …
-0.118,Highlighted in text (sum)
-0.998,<BIAS>


Weight,Feature
+0.824,Highlighted in text (sum)
… 579 more positive …,… 579 more positive …
… 512 more negative …,… 512 more negative …
-0.998,<BIAS>


Weight,Feature
+1.927,Highlighted in text (sum)
… 160 more positive …,… 160 more positive …
… 124 more negative …,… 124 more negative …
-1.024,<BIAS>


Weight,Feature
+0.375,Highlighted in text (sum)
… 391 more positive …,… 391 more positive …
… 355 more negative …,… 355 more negative …
-0.998,<BIAS>
