In [15]:
from sklearn.datasets import fetch_20newsgroups
categories = [
    'alt.atheism',
    'talk.religion.misc',
    'comp.graphics',
    'sci.space',
]
fetch_subset = lambda subset: fetch_20newsgroups(
    subset=subset, categories=categories,
    shuffle=True, random_state=42,
    remove=('headers', 'footers', 'quotes'))
train = fetch_subset('train')
test = fetch_subset('test')

In [16]:
from sklearn.pipeline import Pipeline
from sklearn.linear_model import SGDClassifier
from sklearn.feature_extraction.text import TfidfVectorizer

vec = TfidfVectorizer(analyzer='char_wb', ngram_range=(3, 4))
clf = SGDClassifier(n_jobs=-1)
pipeline = Pipeline([('vec', vec), ('clf', clf)])
pipeline.fit(train['data'], train['target'])

Pipeline(steps=[('vec', TfidfVectorizer(analyzer='char_wb', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(3, 4), norm='l2', preprocessor=None, smooth_idf=True,
...   penalty='l2', power_t=0.5, random_state=None, shuffle=True,
       verbose=0, warm_start=False))])

In [17]:
from eli5.sklearn import explain_weights, explain_prediction
from eli5.formatters import format_as_html, format_as_text, format_html_styles

print(format_as_text(explain_weights(clf, vec, target_names=train['target_names'])))

Explained as: linear model

Features with largest coefficients per class.
Caveats:
1. Be careful with features which are not
   independent - weights don't show their importance.
2. If scale of input features is different then scale of coefficients
   will also be different, making direct comparison between coefficient values
   incorrect.
3. Depending on regularization, rare features sometimes may have high
   coefficients; this doesn't mean they contribute much to the
   classification result for most examples.

y='alt.atheism' top features
--------------
  +2.687  heis
  +2.227  eis 
  +2.098  eist
  +1.951   ath
  +1.824   pos
  +1.800  nat 
  +1.693  athe
  +1.648  thei
  +1.615  ish 
  +1.575  post
  +1.552  hei 
  +1.548  lai 
  +1.544  pos 
  +1.520  sla 
  +1.486  mott
  +1.459  slam
       …  (20295 more positive features)
       …  (34393 more negative features)
  -1.688   us 
  -1.617   *  
  -1.488   his
  -1.457  pac 

y='comp.graphics' top features
--------------
  +1.94

In [18]:
from IPython.core.display import display, HTML
show_html = lambda html: display(HTML(html))
show_html_expl = lambda expl, **kwargs: show_html(format_as_html(expl, include_styles=False, **kwargs))
show_html(format_html_styles())

In [19]:
show_html_expl(explain_weights(clf, vec, target_names=train['target_names']))

Weight,Feature
+2.687,heis
+2.227,eis
+2.098,eist
+1.951,ath
+1.824,pos
+1.800,nat
+1.693,athe
+1.648,thei
+1.615,ish
+1.575,post

Weight,Feature
+1.949,file
+1.942,3d
+1.842,mage
+1.825,fil
+1.737,gra
+1.724,phi
+1.716,mag
+1.675,raph
+1.603,aph
+1.586,imag

Weight,Feature
+3.260,spac
+3.219,pace
+2.801,spa
+2.592,pac
+2.561,spa
+1.971,nas
+1.943,ace
+1.931,nas
+1.874,ace
+1.856,sp

Weight,Feature
+2.113,*
+1.937,he
+1.748,rist
+1.687,ian
+1.650,us
+1.572,ian
+1.520,amor
+1.499,fbi
+1.426,fbi
+1.381,fbi


In [20]:
show_html_expl(explain_prediction(clf, test['data'][7], vec, target_names=train['target_names'], top=50), force_weights=True)

Weight,Feature
+0.111,:
+0.050,be
+0.045,tin
+0.044,the
+0.031,se
+0.030,ing
+0.028,up
+0.028,ill
+0.026,of
+0.026,as

Weight,Feature
+0.136,:
+0.047,fi
+0.037,rig
+0.035,co
+0.035,ile
+0.032,fra
+0.030,ase
+0.028,ha
… 472 more positive …,… 472 more positive …
… 579 more negative …,… 579 more negative …

Weight,Feature
+0.110,pac
+0.110,igh
+0.101,spac
+0.100,astr
+0.099,pace
+0.088,ight
+0.082,spa
+0.080,spa
+0.070,ght
+0.069,orb

Weight,Feature
+0.063,is
+0.053,th
+0.045,of
+0.044,br
+0.044,his
+0.042,fra
+0.039,of
+0.039,der
+0.037,of
+0.032,sa


In [21]:
show_html_expl(explain_prediction(clf, test['data'][1], vec, target_names=train['target_names']))

Weight,Feature
+0.120,mad
+0.094,vat
+0.084,atic
+0.083,mad
+0.076,ican
+0.052,fin
+0.052,ble.
… 58 more positive …,… 58 more positive …
… 98 more negative …,… 98 more negative …
-0.048,ava

Weight,Feature
+0.086,ftp
+0.086,hel
+0.085,ft
+0.083,help
+0.081,elp
+0.075,lib
+0.074,lib
+0.071,lp
+0.068,site
+0.068,ftp

Weight,Feature
+0.063,ry
+0.052,ndin
+0.052,lle
+0.049,oll
+0.045,ndi
… 84 more positive …,… 84 more positive …
… 72 more negative …,… 72 more negative …
-0.046,find
-0.046,atic
-0.047,lib

Weight,Feature
+0.088,is
+0.059,us.
+0.054,his
+0.052,he
… 55 more positive …,… 55 more positive …
… 101 more negative …,… 101 more negative …
-0.050,lib
-0.050,ican
-0.051,any
-0.051,ade


In [22]:
import numpy as np
for doc in test['data'][:10]:
    expl = explain_prediction(clf, doc, vec, target_names=train['target_names'])
    # haaack - leave only the winner
    max_class_idx = np.argmax([cl['score'] for cl in expl['classes']])
    expl['classes'] = [expl['classes'][max_class_idx]]
    show_html_expl(expl)

Weight,Feature
+0.137,sky
+0.135,rry
+0.124,sky
+0.105,riz
+0.100,rry
+0.098,sk
+0.093,ry
+0.093,roj
+0.093,roje
+0.093,ojec


Weight,Feature
+0.086,ftp
+0.086,hel
+0.085,ft
+0.083,help
+0.081,elp
+0.075,lib
+0.074,lib
+0.071,lp
+0.068,site
+0.068,ftp


Weight,Feature
+0.142,phi
+0.136,ware
+0.080,soft
+0.075,here
+0.074,sof
+0.072,sof
+0.072,war
+0.072,raph
+0.070,hics
+0.067,aphi


Weight,Feature
+0.087,|
+0.052,for
+0.049,line
+0.044,pli
+0.038,ine
+0.037,raph
+0.036,phi
+0.036,gr
+0.036,hics
+0.035,mati


Weight,Feature
+0.148,3d
+0.143,gra
+0.119,mac
+0.091,3d
+0.090,any
+0.088,ram
+0.080,3d
+0.079,int
+0.079,mac
+0.075,ima


Weight,Feature
+0.188,gra
+0.119,!!
+0.117,ram
+0.071,i'
+0.063,pc
+0.054,gif
+0.052,gram
+0.051,i
+0.048,li
+0.041,ctur


Weight,Feature
+0.051,ner
+0.050,er
+0.045,anne
+0.035,ect
+0.034,ner
+0.033,bou
+0.031,luti
+0.030,ow!
+0.030,po
+0.029,nner


Weight,Feature
+0.110,pac
+0.110,igh
+0.101,spac
+0.100,astr
+0.099,pace
+0.088,ight
+0.082,spa
+0.080,spa
+0.070,ght
+0.069,orb


Weight,Feature
+0.293,heis
+0.230,eis
+0.222,ath
+0.208,athe
+0.185,thei
+0.170,hei
+0.155,ath
+0.124,eist
+0.088,cau
+0.069,ca


Weight,Feature
+0.079,the
+0.077,th
+0.073,lo
+0.056,the
+0.051,nerg
+0.049,uni
+0.047,erg
+0.046,ene
+0.046,ry
+0.045,rgy
