In [1]:
from sklearn.datasets import fetch_20newsgroups
categories = [
    'alt.atheism',
    'talk.religion.misc',
    'comp.graphics',
    'sci.space',
]
fetch_subset = lambda subset: fetch_20newsgroups(
    subset=subset, categories=categories,
    shuffle=True, random_state=42,
    remove=('headers', 'footers', 'quotes'))
train = fetch_subset('train')
test = fetch_subset('test')

In [2]:
from sklearn.pipeline import Pipeline
from sklearn.linear_model import SGDClassifier
from sklearn.feature_extraction.text import TfidfVectorizer

vec = TfidfVectorizer(analyzer='char_wb', ngram_range=(3, 4))
clf = SGDClassifier(n_jobs=-1)
pipeline = Pipeline([('vec', vec), ('clf', clf)])
pipeline.fit(train['data'], train['target'])

Pipeline(steps=[('vec', TfidfVectorizer(analyzer='char_wb', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(3, 4), norm='l2', preprocessor=None, smooth_idf=True,
...   penalty='l2', power_t=0.5, random_state=None, shuffle=True,
       verbose=0, warm_start=False))])

In [3]:
from eli5 import explain_weights, explain_prediction
from eli5 import format_as_html, format_as_text, format_html_styles

print(format_as_text(explain_weights(clf, vec, target_names=train['target_names'])))



Explained as: linear model

Features with largest coefficients per class.
Caveats:
1. Be careful with features which are not
   independent - weights don't show their importance.
2. If scale of input features is different then scale of coefficients
   will also be different, making direct comparison between coefficient values
   incorrect.
3. Depending on regularization, rare features sometimes may have high
   coefficients; this doesn't mean they contribute much to the
   classification result for most examples.

y='alt.atheism' top features
--------------
  +2.834  heis
  +2.299  eis 
  +2.086  eist
  +2.003  ░ath
  +1.851  athe
  +1.810  thei
  +1.702  hei 
  +1.679  sla 
  +1.669  ░pos
  +1.586  nat 
  +1.526  mott
  +1.526  post
  +1.503  slam
  +1.466  stin
  +1.453  ath 
  +1.438  rna 
  +1.434  ░mad
  +1.420  obb 
       …  (20056 more positive features)
       …  (33424 more negative features)
  -1.410  ░ro 
  -1.477  pac 

y='comp.graphics' top features
--------------
  +2.15

In [4]:
from IPython.core.display import display, HTML
show_html = lambda html: display(HTML(html))
show_html_expl = lambda expl, **kwargs: show_html(format_as_html(expl, include_styles=False, **kwargs))
show_html(format_html_styles())

In [5]:
show_html_expl(explain_weights(clf, vec, target_names=train['target_names'], top=100))

Weight,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0
Weight,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1
Weight,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2
Weight,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3
+2.834,heis,,
+2.299,eis,,
+2.086,eist,,
+2.003,ath,,
+1.851,athe,,
+1.810,thei,,
+1.702,hei,,
+1.679,sla,,
+1.669,pos,,
+1.586,nat,,

Weight,Feature
+2.834,heis
+2.299,eis
+2.086,eist
+2.003,ath
+1.851,athe
+1.810,thei
+1.702,hei
+1.679,sla
+1.669,pos
+1.586,nat

Weight,Feature
+2.158,3d
+2.001,file
+1.886,gra
+1.748,raph
+1.745,mage
+1.743,phi
+1.725,fil
+1.702,aph
+1.668,hics
+1.645,mag

Weight,Feature
+3.188,spac
+3.116,pace
+2.675,spa
+2.611,pac
+2.511,spa
+1.931,nas
+1.923,ace
+1.839,orb
+1.808,nas
+1.804,rbit

Weight,Feature
+2.152,*
+1.735,ian
+1.672,rist
+1.595,ian
+1.532,us
+1.524,he
+1.494,mor
+1.409,sa
+1.380,is
+1.367,amor


In [6]:
show_html_expl(explain_prediction(clf, test['data'][7], vec, target_names=train['target_names'], top=50), force_weights=True)

Weight,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0
Weight,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1
Weight,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2
Weight,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3
+0.341,:,,
+0.065,the,,
+0.054,be,,
+0.047,tin,,
+0.042,ill,,
+0.031,se,,
+0.030,as,,
+0.029,ght,,
+0.029,of,,
+0.029,as,,

Weight,Feature
+0.341,:
+0.065,the
+0.054,be
+0.047,tin
+0.042,ill
+0.031,se
+0.030,as
+0.029,ght
+0.029,of
+0.029,as

Weight,Feature
+0.117,:
+0.037,ase
+0.036,fi
+0.034,ile
+0.032,rig
+0.031,fra
+0.031,co
+0.029,is
… 484 more positive …,… 484 more positive …
… 562 more negative …,… 562 more negative …

Weight,Feature
+0.112,th
+0.111,pac
+0.098,spac
+0.096,pace
+0.092,astr
+0.086,igh
+0.080,the
+0.078,spa
+0.078,spa
+0.070,orb

Weight,Feature
+0.125,:
+0.078,th
+0.066,is
+0.044,fra
+0.043,sa
+0.039,eld
+0.038,as
+0.035,der
+0.032,his
+0.032,hat

Weight,Feature
… 467 more positive …,… 467 more positive …
… 591 more negative …,… 591 more negative …
-0.019,Highlighted in text (sum)
-0.999,<BIAS>

Weight,Feature
… 484 more positive …,… 484 more positive …
… 562 more negative …,… 562 more negative …
-0.948,<BIAS>
-1.479,Highlighted in text (sum)

Weight,Feature
+2.038,Highlighted in text (sum)
… 542 more positive …,… 542 more positive …
… 513 more negative …,… 513 more negative …
-0.984,<BIAS>

Weight,Feature
… 474 more positive …,… 474 more positive …
… 581 more negative …,… 581 more negative …
-0.146,Highlighted in text (sum)
-0.996,<BIAS>


In [7]:
show_html_expl(explain_prediction(clf, test['data'][1], vec, target_names=train['target_names']))

Weight,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0
Weight,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1
Weight,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2
Weight,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3
+0.122,mad,,
+0.094,vat,,
+0.086,mad,,
+0.077,atic,,
+0.073,ican,,
+0.062,ble.,,
+0.054,le.,,
+0.053,bra,,
+0.045,ade,,
+0.045,le.,,

Weight,Feature
0.122,mad
0.094,vat
0.086,mad
0.077,atic
0.073,ican
0.062,ble.
0.054,le.
0.053,bra
0.045,ade
0.045,le.

Weight,Feature
0.117,ftp
0.112,ft
0.103,ftp
0.103,tp
0.099,ftp
0.086,hel
0.078,site
0.076,help
0.074,lp
0.074,elp

Weight,Feature
0.076,ndin
0.069,th
0.059,lle
0.058,ry
0.056,ndi
0.055,oll
0.049,the
0.048,vat
0.047,coll
0.045,the

Weight,Feature
0.093,is
0.074,indi
0.052,us.
0.049,ite
0.048,th
0.041,vati
0.041,he
0.04,rece
0.039,his
0.039,tou

Weight,Feature
-0.999,<BIAS>
-1.31,Highlighted in text (sum)

Weight,Feature
2.74,Highlighted in text (sum)
-0.948,<BIAS>

Weight,Feature
0.023,Highlighted in text (sum)
-0.984,<BIAS>

Weight,Feature
-0.996,<BIAS>
-1.509,Highlighted in text (sum)


In [8]:
import numpy as np
for doc in test['data'][:10]:
    expl = explain_prediction(clf, doc, vec, target_names=train['target_names'])
    # haaack - leave only the winner
    max_class_idx = np.argmax([t.score for t in expl.targets])
    expl.targets = [expl.targets[max_class_idx]]
    show_html_expl(expl, force_weights=False)

Weight,Feature
1.51,Highlighted in text (sum)
-0.984,<BIAS>


Weight,Feature
2.74,Highlighted in text (sum)
-0.948,<BIAS>


Weight,Feature
2.793,Highlighted in text (sum)
-0.948,<BIAS>


Weight,Feature
1.88,Highlighted in text (sum)
-0.948,<BIAS>


Weight,Feature
2.05,Highlighted in text (sum)
-0.948,<BIAS>


Weight,Feature
1.688,Highlighted in text (sum)
-0.948,<BIAS>


Weight,Feature
0.451,Highlighted in text (sum)
-0.948,<BIAS>


Weight,Feature
2.786,Highlighted in text (sum)
-0.984,<BIAS>


Weight,Feature
3.084,Highlighted in text (sum)
-0.999,<BIAS>


Weight,Feature
1.237,Highlighted in text (sum)
-0.984,<BIAS>
