In [1]:
from sklearn.datasets import fetch_20newsgroups
categories = [
    'alt.atheism',
    'talk.religion.misc',
    'comp.graphics',
    'sci.space',
]
fetch_subset = lambda subset: fetch_20newsgroups(
    subset=subset, categories=categories,
    shuffle=True, random_state=42,
    remove=('headers', 'footers', 'quotes'))
train = fetch_subset('train')
test = fetch_subset('test')

In [2]:
from sklearn.pipeline import Pipeline
from sklearn.linear_model import SGDClassifier
from sklearn.feature_extraction.text import TfidfVectorizer

vec = TfidfVectorizer(analyzer='char_wb', ngram_range=(3, 4))
clf = SGDClassifier(n_jobs=-1)
pipeline = Pipeline([('vec', vec), ('clf', clf)])
pipeline.fit(train['data'], train['target'])

Pipeline(steps=[('vec', TfidfVectorizer(analyzer='char_wb', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(3, 4), norm='l2', preprocessor=None, smooth_idf=True,
...   penalty='l2', power_t=0.5, random_state=None, shuffle=True,
       verbose=0, warm_start=False))])

In [3]:
from eli5.sklearn import explain_weights, explain_prediction
from eli5.formatters import format_as_html, format_as_text, format_html_styles

print(format_as_text(explain_weights(clf, vec, target_names=train['target_names'])))

Explained as: linear model

Features with largest coefficients per class.
Caveats:
1. Be careful with features which are not
   independent - weights don't show their importance.
2. If scale of input features is different then scale of coefficients
   will also be different, making direct comparison between coefficient values
   incorrect.
3. Depending on regularization, rare features sometimes may have high
   coefficients; this doesn't mean they contribute much to the
   classification result for most examples.

y='alt.atheism' top features
--------------
  +3.077  heis
  +2.572  eis 
  +2.408  eist
  +2.257   ath
  +2.049  thei
  +2.033  nat 
  +2.031  hei 
  +1.961   pos
  +1.850  athe
  +1.804  post
  +1.699  ish 
  +1.592  pos 
  +1.568  ish 
  +1.506  lam 
  +1.506  lai 
  +1.502   po 
  +1.488  it  
  +1.478   it 
       …  (20053 more positive features)
       …  (30025 more negative features)
  -1.586  pac 
  -1.557   us 

y='comp.graphics' top features
--------------
  +2.09

In [4]:
from IPython.core.display import display, HTML
show_html = lambda html: display(HTML(html))
show_html_expl = lambda expl, **kwargs: show_html(format_as_html(expl, include_styles=False, **kwargs))
show_html(format_html_styles())

In [5]:
show_html_expl(explain_weights(clf, vec, target_names=train['target_names']))

Weight,Feature
+3.077,heis
+2.572,eis
+2.408,eist
+2.257,ath
+2.049,thei
+2.033,nat
+2.031,hei
+1.961,pos
+1.850,athe
+1.804,post

Weight,Feature
+2.099,phi
+2.037,file
+2.030,gra
+2.006,3d
+1.905,mag
+1.878,mage
+1.870,ima
+1.850,fil
+1.841,hics
+1.840,imag

Weight,Feature
+3.324,spac
+3.306,pace
+2.837,spa
+2.805,spa
+2.739,pac
+2.051,ace
+2.022,orb
+1.978,nas
+1.902,ace
+1.871,rbit

Weight,Feature
+2.487,*
+1.728,rist
+1.683,ian
+1.653,us
+1.607,he
+1.504,ian
+1.402,that
+1.354,de
+1.349,ent?
+1.333,may


In [6]:
show_html_expl(explain_prediction(clf, test['data'][7], vec, target_names=train['target_names'], top=50), force_weights=True)

Weight,Feature
+0.175,:
+0.052,be
+0.039,tin
+0.035,the
+0.033,ill
+0.033,of
+0.032,se
+0.031,ght
+0.031,wh
+0.030,up

Weight,Feature
+0.138,:
+0.042,fi
+0.034,co
+0.033,ile
+0.032,fra
+0.030,ha
+0.028,a
+0.027,it
… 477 more positive …,… 477 more positive …
… 572 more negative …,… 572 more negative …

Weight,Feature
+0.116,pac
+0.103,spac
+0.102,igh
+0.102,pace
+0.099,th
+0.089,astr
+0.087,spa
+0.083,spa
+0.077,orb
+0.073,orb

Weight,Feature
+0.129,th
+0.048,of
+0.047,is
+0.044,of
+0.042,hat
+0.041,of
+0.037,fra
+0.037,sa
+0.036,his
+0.036,br


In [7]:
show_html_expl(explain_prediction(clf, test['data'][1], vec, target_names=train['target_names']))

Weight,Feature
+0.115,mad
+0.097,atic
+0.091,vat
+0.080,mad
+0.078,ican
+0.065,ble.
+0.058,an
+0.055,le.
+0.049,fin
… 50 more positive …,… 50 more positive …

Weight,Feature
+0.095,ftp
+0.091,ft
+0.086,lib
+0.083,ftp
+0.081,tp
+0.081,lib
+0.080,help
+0.079,ftp
+0.078,elp
+0.077,hel

Weight,Feature
+0.065,ndin
+0.063,ry
+0.061,th
+0.053,ndi
+0.052,lle
+0.048,tou
+0.044,the
+0.041,vat
+0.039,the
… 78 more positive …,… 78 more positive …

Weight,Feature
+0.079,th
+0.065,is
+0.060,indi
+0.058,vati
+0.054,us.
+0.050,he
… 57 more positive …,… 57 more positive …
… 99 more negative …,… 99 more negative …
-0.046,any
-0.046,vat


In [8]:
import numpy as np
for doc in test['data'][:10]:
    expl = explain_prediction(clf, doc, vec, target_names=train['target_names'])
    # haaack - leave only the winner
    max_class_idx = np.argmax([cl['score'] for cl in expl['classes']])
    expl['classes'] = [expl['classes'][max_class_idx]]
    show_html_expl(expl)

Weight,Feature
+0.150,rry
+0.135,sky
+0.126,rry
+0.124,sky
+0.093,sk
+0.092,ry
+0.089,riz
+0.081,roje
+0.081,ojec
+0.081,oje


Weight,Feature
+0.095,ftp
+0.091,ft
+0.086,lib
+0.083,ftp
+0.081,tp
+0.081,lib
+0.080,help
+0.079,ftp
+0.078,elp
+0.077,hel


Weight,Feature
+0.173,phi
+0.145,ware
+0.088,soft
+0.088,hics
+0.082,sof
+0.081,aphi
+0.078,phic
+0.077,raph
+0.077,sof
+0.073,war


Weight,Feature
+0.072,|
+0.049,for
+0.047,pli
+0.046,line
+0.046,gr
+0.045,hics
+0.044,phi
+0.041,aphi
+0.040,ine
+0.040,phic


Weight,Feature
+0.167,gra
+0.153,3d
+0.104,mac
+0.098,ram
+0.089,ima
+0.086,3d
+0.085,any
+0.080,3d
+0.078,pc
+0.073,int


Weight,Feature
+0.220,gra
+0.130,ram
+0.128,!!
+0.070,gram
+0.068,pc
+0.063,i'
+0.055,rog
+0.050,prog
+0.050,rogr
+0.050,gif


Weight,Feature
+0.051,ect
+0.051,er
+0.036,ner
+0.035,ct
+0.035,bou
+0.034,po
+0.033,nner
+0.028,it
+0.028,pol
+0.026,anne


Weight,Feature
+0.116,pac
+0.103,spac
+0.102,igh
+0.102,pace
+0.099,th
+0.089,astr
+0.087,spa
+0.083,spa
+0.077,orb
+0.073,orb


Weight,Feature
+0.335,heis
+0.265,eis
+0.257,ath
+0.230,thei
+0.227,athe
+0.223,hei
+0.161,ath
+0.142,eist
+0.081,cau
+0.063,eism


Weight,Feature
+0.122,th
+0.103,the
+0.096,the
+0.057,lo
+0.056,he
+0.052,the
+0.050,!!!
+0.045,ry
+0.041,!!!!
+0.039,uni
