In [1]:
from sklearn.datasets import fetch_20newsgroups
categories = [
    'alt.atheism',
    'talk.religion.misc',
    'comp.graphics',
    'sci.space',
]
fetch_subset = lambda subset: fetch_20newsgroups(
    subset=subset, categories=categories,
    shuffle=True, random_state=42,
    remove=('headers', 'footers', 'quotes'))
train = fetch_subset('train')
test = fetch_subset('test')

In [2]:
from sklearn.pipeline import Pipeline
from sklearn.linear_model import SGDClassifier
from sklearn.feature_extraction.text import TfidfVectorizer

vec = TfidfVectorizer(analyzer='char_wb', ngram_range=(3, 4))
clf = SGDClassifier(n_jobs=-1)
pipeline = Pipeline([('vec', vec), ('clf', clf)])
pipeline.fit(train['data'], train['target'])

Pipeline(steps=[('vec', TfidfVectorizer(analyzer='char_wb', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(3, 4), norm='l2', preprocessor=None, smooth_idf=True,
...   penalty='l2', power_t=0.5, random_state=None, shuffle=True,
       verbose=0, warm_start=False))])

In [3]:
from eli5 import explain_weights, explain_prediction
from eli5 import format_as_html, format_as_text, format_html_styles

print(format_as_text(explain_weights(clf, vec, target_names=train['target_names'])))

Explained as: linear model

Features with largest coefficients per class.
Caveats:
1. Be careful with features which are not
   independent - weights don't show their importance.
2. If scale of input features is different then scale of coefficients
   will also be different, making direct comparison between coefficient values
   incorrect.
3. Depending on regularization, rare features sometimes may have high
   coefficients; this doesn't mean they contribute much to the
   classification result for most examples.

y='alt.atheism' top features
--------------
  +2.875  heis
  +2.400  eis 
  +2.314  eist
  +2.106  ░ath
  +1.945  ░pos
  +1.944  thei
  +1.917  post
  +1.886  hei 
  +1.803  nat 
  +1.676  pos 
  +1.663  sla 
  +1.655  stin
  +1.597  athe
  +1.581  ish░
  +1.549  lam 
  +1.501  ░it░
  +1.475  rna 
  +1.468  it░ 
       …  (20273 more positive features)
       …  (31073 more negative features)
  -1.573  pac 
  -1.814  ░us 

y='comp.graphics' top features
--------------
  +1.97

In [4]:
from IPython.core.display import display, HTML
show_html = lambda html: display(HTML(html))
show_html_expl = lambda expl, **kwargs: show_html(format_as_html(expl, include_styles=False, **kwargs))
show_html(format_html_styles())

In [5]:
show_html_expl(explain_weights(clf, vec, target_names=train['target_names'], top=100))

Weight,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0
Weight,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1
Weight,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2
Weight,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3
+2.875,heis,,
+2.400,eis,,
+2.314,eist,,
+2.106,ath,,
+1.945,pos,,
+1.944,thei,,
+1.917,post,,
+1.886,hei,,
+1.803,nat,,
+1.676,pos,,

Weight,Feature
+2.875,heis
+2.400,eis
+2.314,eist
+2.106,ath
+1.945,pos
+1.944,thei
+1.917,post
+1.886,hei
+1.803,nat
+1.676,pos

Weight,Feature
+1.977,phi
+1.968,|||
+1.962,3d
+1.917,mag
+1.894,mage
+1.858,raph
+1.853,file
+1.810,gra
+1.776,ima
+1.746,||||

Weight,Feature
+3.120,spac
+3.043,pace
+2.600,spa
+2.482,spa
+2.458,pac
+1.995,nas
+1.902,nas
+1.880,ace
+1.829,orb
+1.717,..

Weight,Feature
+1.957,*
+1.763,rist
+1.703,he
+1.644,ian
+1.624,ian
+1.528,us
+1.441,fbi
+1.375,amor
+1.367,fbi
+1.339,and


In [6]:
show_html_expl(explain_prediction(clf, test['data'][7], vec, target_names=train['target_names'], top=50), force_weights=True)

Weight,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0
Weight,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1
Weight,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2
Weight,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3
+0.222,:,,
+0.059,tin,,
+0.042,be,,
+0.037,the,,
+0.035,as,,
+0.032,ght,,
+0.032,ting,,
+0.029,se,,
+0.029,ill,,
+0.028,up,,

Weight,Feature
+0.222,:
+0.059,tin
+0.042,be
+0.037,the
+0.035,as
+0.032,ght
+0.032,ting
+0.029,se
+0.029,ill
+0.028,up

Weight,Feature
+0.050,:
+0.037,fi
+0.036,co
+0.034,fra
+0.032,ile
+0.031,ase
+0.029,it
+0.028,rig
+0.027,li
… 487 more positive …,… 487 more positive …

Weight,Feature
+0.104,pac
+0.100,igh
+0.100,astr
+0.096,spac
+0.094,pace
+0.080,ight
+0.077,spa
+0.076,spa
+0.074,th
+0.070,orb

Weight,Feature
+0.096,th
+0.054,is
+0.051,he
+0.045,fra
+0.042,of
+0.038,serv
+0.037,der
+0.037,his
+0.036,the
+0.036,erv

Weight,Feature
… 483 more positive …,… 483 more positive …
… 571 more negative …,… 571 more negative …
-0.397,Highlighted in text (sum)
-1.003,<BIAS>

Weight,Feature
… 487 more positive …,… 487 more positive …
… 564 more negative …,… 564 more negative …
-0.997,<BIAS>
-1.441,Highlighted in text (sum)

Weight,Feature
+1.946,Highlighted in text (sum)
… 535 more positive …,… 535 more positive …
… 525 more negative …,… 525 more negative …
-0.937,<BIAS>

Weight,Feature
+0.087,Highlighted in text (sum)
… 489 more positive …,… 489 more positive …
… 566 more negative …,… 566 more negative …
-0.960,<BIAS>


In [7]:
show_html_expl(explain_prediction(clf, test['data'][1], vec, target_names=train['target_names']))

Weight,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0
Weight,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1
Weight,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2
Weight,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3
+0.113,mad,,
+0.094,vat,,
+0.077,ican,,
+0.076,mad,,
+0.073,atic,,
+0.051,bra,,
+0.049,ble.,,
+0.049,fin,,
+0.048,le.,,
+0.046,ent,,

Weight,Feature
0.113,mad
0.094,vat
0.077,ican
0.076,mad
0.073,atic
0.051,bra
0.049,ble.
0.049,fin
0.048,le.
0.046,ent

Weight,Feature
0.093,ftp
0.091,help
0.09,ft
0.089,elp
0.088,lib
0.088,lp
0.088,hel
0.083,lib
0.081,elp
0.08,ftp

Weight,Feature
0.063,ndin
0.057,lle
0.056,ndi
0.052,av
0.049,ry
0.045,th
0.044,tou
0.04,oll
0.039,the
0.039,te

Weight,Feature
0.075,is
0.059,th
0.057,vati
0.051,indi
0.049,he
0.046,us.
0.045,his
0.042,me
0.042,he
0.039,me

Weight,Feature
-1.003,<BIAS>
-1.397,Highlighted in text (sum)

Weight,Feature
2.666,Highlighted in text (sum)
-0.997,<BIAS>

Weight,Feature
-0.089,Highlighted in text (sum)
-0.937,<BIAS>

Weight,Feature
-0.96,<BIAS>
-1.34,Highlighted in text (sum)


In [8]:
import numpy as np
for doc in test['data'][:10]:
    expl = explain_prediction(clf, doc, vec, target_names=train['target_names'])
    # haaack - leave only the winner
    max_class_idx = np.argmax([t.score for t in expl.targets])
    expl.targets = [expl.targets[max_class_idx]]
    show_html_expl(expl, force_weights=False)

Weight,Feature
1.317,Highlighted in text (sum)
-0.937,<BIAS>


Weight,Feature
2.666,Highlighted in text (sum)
-0.997,<BIAS>


Weight,Feature
2.638,Highlighted in text (sum)
-0.997,<BIAS>


Weight,Feature
1.606,Highlighted in text (sum)
-0.997,<BIAS>


Weight,Feature
1.662,Highlighted in text (sum)
-0.997,<BIAS>


Weight,Feature
1.416,Highlighted in text (sum)
-0.997,<BIAS>


Weight,Feature
0.273,Highlighted in text (sum)
-0.997,<BIAS>


Weight,Feature
2.749,Highlighted in text (sum)
-0.937,<BIAS>


Weight,Feature
2.564,Highlighted in text (sum)
-1.003,<BIAS>


Weight,Feature
1.041,Highlighted in text (sum)
-0.937,<BIAS>
