In [38]:
from sklearn.datasets import fetch_20newsgroups
categories = [
    'alt.atheism',
    'talk.religion.misc',
    'comp.graphics',
    'sci.space',
]
fetch_subset = lambda subset: fetch_20newsgroups(
    subset=subset, categories=categories,
    shuffle=True, random_state=42,
    remove=('headers', 'footers', 'quotes'))
train = fetch_subset('train')
test = fetch_subset('test')

In [39]:
from sklearn.pipeline import Pipeline
from sklearn.linear_model import SGDClassifier
from sklearn.feature_extraction.text import TfidfVectorizer

vec = TfidfVectorizer(analyzer='char_wb', ngram_range=(3, 4))
clf = SGDClassifier(n_jobs=-1)
pipeline = Pipeline([('vec', vec), ('clf', clf)])
pipeline.fit(train['data'], train['target'])

Pipeline(steps=[('vec', TfidfVectorizer(analyzer='char_wb', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(3, 4), norm='l2', preprocessor=None, smooth_idf=True,
...   penalty='l2', power_t=0.5, random_state=None, shuffle=True,
       verbose=0, warm_start=False))])

In [60]:
from eli5.sklearn import explain_weights, explain_prediction
from eli5.formatters import format_as_html, format_as_text, format_html_styles

print(format_as_text(explain_weights(clf, vec, target_names=train['target_names'])))

Explained as: linear model

Features with largest coefficients per class.
Caveats:
1. Be careful with features which are not
   independent - weights don't show their importance.
2. If scale of input features is different then scale of coefficients
   will also be different, making direct comparison between coefficient values
   incorrect.
3. Depending on regularization, rare features sometimes may have high
   coefficients; this doesn't mean they contribute much to the
   classification result for most examples.

y='alt.atheism' top features
--------------
  +2.594  heis
  +2.101  eist
  +2.063  eis 
  +1.864  nat 
  +1.830  ▒ath
  +1.818  ▒pos
  +1.668  hei 
  +1.661  thei
  +1.613  sla 
  +1.606  post
  +1.587  pos 
  +1.581  athe
  +1.495  lai 
  +1.483  rna 
  +1.449  slam
  +1.431  aim 
  +1.427  laim
       …  (20330 more positive features)
       …  (31127 more negative features)
  -1.540  ▒us 
  -1.483  pac 
  -1.415  ▒his

y='comp.graphics' top features
--------------
  +2.09

In [55]:
from IPython.core.display import display, HTML
show_html = lambda html: display(HTML(html))
show_html_expl = lambda expl, **kwargs: show_html(format_as_html(expl, include_styles=False, **kwargs))
show_html(format_html_styles())

In [75]:
show_html_expl(explain_weights(clf, vec, target_names=train['target_names'], top=100))

Weight,Feature
+2.594,heis
+2.101,eist
+2.063,eis
+1.864,nat
+1.830,ath
+1.818,pos
+1.668,hei
+1.661,thei
+1.613,sla
+1.606,post

Weight,Feature
+2.098,file
+1.964,3d
+1.875,phi
+1.806,mage
+1.777,fil
+1.749,raph
+1.732,gra
+1.695,aph
+1.646,mag
+1.594,ima

Weight,Feature
+3.211,spac
+3.151,pace
+2.727,spa
+2.595,pac
+2.576,spa
+2.106,nas
+1.990,nas
+1.947,ace
+1.915,sp
+1.891,..

Weight,Feature
+2.493,*
+1.985,he
+1.751,ian
+1.695,us
+1.652,rist
+1.567,ian
+1.487,de
+1.480,amor
+1.413,fbi
+1.339,fbi


In [76]:
show_html_expl(explain_prediction(clf, test['data'][7], vec, target_names=train['target_names'], top=50), force_weights=True)

Weight,Feature
+0.354,:
+0.055,be
+0.037,ill
+0.034,tin
+0.030,is
+0.030,se
+0.029,of
+0.029,up
+0.029,the
+0.027,wha

Weight,Feature
+0.045,fi
+0.036,ile
+0.034,bri
+0.033,rig
+0.033,ha
+0.031,co
… 477 more positive …,… 477 more positive …
… 575 more negative …,… 575 more negative …
-0.030,moo
-0.030,tro

Weight,Feature
+0.110,pac
+0.106,th
+0.099,spac
+0.097,pace
+0.088,astr
+0.087,igh
+0.080,spa
+0.080,spa
+0.071,orb
+0.068,orb

Weight,Feature
+0.058,is
+0.054,th
+0.049,:
+0.042,he
+0.039,as
+0.038,fra
+0.038,eld
+0.035,der
+0.034,sa
+0.033,br


In [77]:
show_html_expl(explain_prediction(clf, test['data'][1], vec, target_names=train['target_names']))

Weight,Feature
+0.113,mad
+0.081,mad
+0.080,atic
+0.070,vat
+0.055,bra
+0.055,ble.
+0.055,ican
… 50 more positive …,… 50 more positive …
… 106 more negative …,… 106 more negative …
-0.050,vati

Weight,Feature
+0.104,ftp
+0.099,ft
+0.093,lib
+0.086,ftp
+0.086,lib
+0.085,hel
+0.083,tp
+0.082,help
+0.081,ftp
+0.080,elp

Weight,Feature
+0.065,th
+0.065,ry
+0.064,ndin
+0.055,ndi
+0.054,lle
+0.054,tou
+0.053,vat
+0.046,the
… 85 more positive …,… 85 more positive …
… 71 more negative …,… 71 more negative …

Weight,Feature
+0.081,is
+0.065,us.
+0.059,vati
+0.057,indi
+0.051,us.
+0.050,he
… 55 more positive …,… 55 more positive …
… 101 more negative …,… 101 more negative …
-0.047,ade
-0.049,tly


In [78]:
import numpy as np
for doc in test['data'][:10]:
    expl = explain_prediction(clf, doc, vec, target_names=train['target_names'])
    # haaack - leave only the winner
    max_class_idx = np.argmax([cl['score'] for cl in expl['classes']])
    expl['classes'] = [expl['classes'][max_class_idx]]
    show_html_expl(expl)

Weight,Feature
+0.141,rry
+0.122,sky
+0.111,sky
+0.109,rry
+0.104,riz
+0.096,sk
+0.095,ry
+0.087,oje
+0.087,ojec
+0.087,proj


Weight,Feature
+0.104,ftp
+0.099,ft
+0.093,lib
+0.086,ftp
+0.086,lib
+0.085,hel
+0.083,tp
+0.082,help
+0.081,ftp
+0.080,elp


Weight,Feature
+0.155,phi
+0.137,ware
+0.075,hics
+0.075,raph
+0.074,hic
+0.070,aphi
+0.070,phic
+0.070,aph
+0.069,soft
+0.068,grap


Weight,Feature
+0.077,|
+0.055,for
+0.044,line
+0.040,phi
+0.039,pli
+0.038,hics
+0.038,raph
+0.037,ine
+0.036,aphi
+0.036,phic


Weight,Feature
+0.150,3d
+0.142,gra
+0.102,mac
+0.093,ram
+0.085,any
+0.084,3d
+0.081,3d
+0.076,pc
+0.076,ima
+0.074,mac


Weight,Feature
+0.188,gra
+0.142,!!
+0.123,ram
+0.079,i'
+0.067,pc
+0.060,!!!
+0.057,i
+0.056,!!!
+0.054,gif
+0.050,gram


Weight,Feature
+0.054,er
+0.041,bou
+0.033,pol
+0.033,ab
+0.033,ect
+0.031,anne
+0.031,po
+0.031,ow!
+0.030,luti
… 70 more positive …,… 70 more positive …


Weight,Feature
+0.110,pac
+0.106,th
+0.099,spac
+0.097,pace
+0.088,astr
+0.087,igh
+0.080,spa
+0.080,spa
+0.071,orb
+0.068,orb


Weight,Feature
+0.283,heis
+0.213,eis
+0.208,ath
+0.194,athe
+0.186,thei
+0.183,hei
+0.149,ath
+0.124,eist
+0.062,cau
+0.059,ete


Weight,Feature
+0.131,th
+0.108,the
+0.099,the
+0.070,lo
+0.057,the
+0.057,he
+0.051,!!!
+0.047,ry
+0.047,uni
+0.041,!!!!
