In [1]:
from sklearn.datasets import fetch_20newsgroups
# Newsgroups kept for this notebook; passed to every fetch below.
categories = [
    'alt.atheism',
    'talk.religion.misc',
    'comp.graphics',
    'sci.space',
]


def fetch_subset(subset):
    """Fetch one split of 20 Newsgroups restricted to ``categories``.

    Headers, footers and quotes are removed from each post, and the
    documents are shuffled with a fixed ``random_state`` so the ordering
    is the same on every run.

    Parameters
    ----------
    subset : str
        Forwarded as ``fetch_20newsgroups(subset=...)``; this notebook
        uses ``'train'`` and ``'test'``.

    Returns
    -------
    Whatever ``fetch_20newsgroups`` returns; later cells index it with
    ``['data']``, ``['target']`` and ``['target_names']``.
    """
    return fetch_20newsgroups(
        subset=subset, categories=categories,
        shuffle=True, random_state=42,
        remove=('headers', 'footers', 'quotes'))


train = fetch_subset('train')
test = fetch_subset('test')

In [2]:
from sklearn.pipeline import Pipeline
from sklearn.linear_model import SGDClassifier
from sklearn.feature_extraction.text import TfidfVectorizer

# TF-IDF features over character 3- and 4-grams ('char_wb' analyzer).
vec = TfidfVectorizer(ngram_range=(3, 4), analyzer='char_wb')
# NOTE(review): no random_state is passed, so the learned weights shown in the
# outputs below may differ from run to run - consider seeding the classifier.
clf = SGDClassifier(n_jobs=-1)
# `vec` and `clf` stay bound to their own names: later cells pass them to eli5.
pipeline = Pipeline(steps=[('vec', vec), ('clf', clf)])
# Last expression of the cell, so the fitted pipeline's repr is displayed.
pipeline.fit(X=train['data'], y=train['target'])

Pipeline(steps=[('vec', TfidfVectorizer(analyzer='char_wb', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(3, 4), norm='l2', preprocessor=None, smooth_idf=True,
...   penalty='l2', power_t=0.5, random_state=None, shuffle=True,
       verbose=0, warm_start=False))])

In [3]:
from eli5 import explain_weights, explain_prediction
from eli5 import format_as_html, format_as_text, format_html_styles

# Explain the fitted classifier's weights per class and print them as text.
weights_expl = explain_weights(clf, vec, target_names=train['target_names'])
print(format_as_text(weights_expl))

Explained as: linear model

Features with largest coefficients per class.
Caveats:
1. Be careful with features which are not
   independent - weights don't show their importance.
2. If scale of input features is different then scale of coefficients
   will also be different, making direct comparison between coefficient values
   incorrect.
3. Depending on regularization, rare features sometimes may have high
   coefficients; this doesn't mean they contribute much to the
   classification result for most examples.

y='alt.atheism' top features
--------------
  +2.950  heis
  +2.479  eis 
  +2.118  eist
  +2.013  ░ath
  +1.824  post
  +1.801  nat 
  +1.783  thei
  +1.748  ░pos
  +1.739  hei 
  +1.616  athe
  +1.588  mott
  +1.578  sla 
  +1.564  ish░
  +1.497  pos 
  +1.487  ish 
  +1.486  slam
  +1.446  rna 
  +1.443  ogi 
       …  (19968 more positive features)
       …  (31848 more negative features)
  -1.442  pac 
  -1.481  ░his

y='comp.graphics' top features
--------------
  +1.91

In [4]:
# `display` is imported from the public IPython.display module; importing it
# from IPython.core.display is deprecated.
from IPython.display import display, HTML


def show_html(html):
    """Render an HTML string as rich output in the notebook.

    Returns whatever ``display`` returns, so the call behaves the same as
    the one-line wrapper it replaces.
    """
    return display(HTML(html))


def show_html_expl(expl, **kwargs):
    """Format an eli5 explanation as HTML and display it.

    Parameters
    ----------
    expl
        Explanation object, as produced by ``explain_weights`` or
        ``explain_prediction``.
    **kwargs
        Forwarded unchanged to ``format_as_html`` (for example
        ``force_weights``).

    ``include_styles=False`` is always passed; the stylesheet is instead
    displayed once by the ``format_html_styles()`` call below.
    """
    return show_html(format_as_html(expl, include_styles=False, **kwargs))


# Display the eli5 stylesheet once for the whole notebook.
show_html(format_html_styles())

In [5]:
# Per-class weight explanation rendered as HTML, requesting top=100 features.
show_html_expl(
    explain_weights(
        clf, vec,
        target_names=train['target_names'],
        top=100,
    )
)

y=alt.atheism  top features,y=alt.atheism  top features,Unnamed: 2_level_0,Unnamed: 3_level_0
Weight,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1
y=comp.graphics  top features,y=comp.graphics  top features,Unnamed: 2_level_2,Unnamed: 3_level_2
Weight,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3
y=sci.space  top features,y=sci.space  top features,Unnamed: 2_level_4,Unnamed: 3_level_4
Weight,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5
y=talk.religion.misc  top features,y=talk.religion.misc  top features,Unnamed: 2_level_6,Unnamed: 3_level_6
Weight,Feature,Unnamed: 2_level_7,Unnamed: 3_level_7
+2.950,heis,,
+2.479,eis,,
+2.118,eist,,
+2.013,ath,,
+1.824,post,,
+1.801,nat,,
+1.783,thei,,
+1.748,pos,,
+1.739,hei,,
+1.616,athe,,

y=alt.atheism  top features,y=alt.atheism  top features
Weight,Feature
+2.950,heis
+2.479,eis
+2.118,eist
+2.013,ath
+1.824,post
+1.801,nat
+1.783,thei
+1.748,pos
+1.739,hei
+1.616,athe

y=comp.graphics  top features,y=comp.graphics  top features
Weight,Feature
+1.919,file
+1.872,phi
+1.843,3d
+1.825,gra
+1.822,mage
+1.743,ima
+1.726,raph
+1.717,fil
+1.708,mag
+1.668,imag

y=sci.space  top features,y=sci.space  top features
Weight,Feature
+3.164,spac
+3.135,pace
+2.690,spa
+2.509,spa
+2.465,pac
+1.940,nas
+1.885,orb
+1.875,rbit
+1.839,nas
+1.774,orbi

y=talk.religion.misc  top features,y=talk.religion.misc  top features
Weight,Feature
+2.397,*
+1.651,us
+1.608,rist
+1.589,ian
+1.566,he
+1.402,fbi
+1.348,ian
+1.340,de
+1.338,^^^^
+1.323,dea


In [6]:
# Explain the prediction for test document 7 (top=50); force_weights=True is
# forwarded to format_as_html by show_html_expl.
show_html_expl(
    explain_prediction(
        clf, test['data'][7], vec,
        target_names=train['target_names'],
        top=50,
    ),
    force_weights=True,
)

y=alt.atheism  (score -1.643) top features,y=alt.atheism  (score -1.643) top features,Unnamed: 2_level_0,Unnamed: 3_level_0
Weight,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1
y=comp.graphics  (score -2.708) top features,y=comp.graphics  (score -2.708) top features,Unnamed: 2_level_2,Unnamed: 3_level_2
Weight,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3
y=sci.space  (score 1.993) top features,y=sci.space  (score 1.993) top features,Unnamed: 2_level_4,Unnamed: 3_level_4
Weight,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5
y=talk.religion.misc  (score -1.275) top features,y=talk.religion.misc  (score -1.275) top features,Unnamed: 2_level_6,Unnamed: 3_level_6
Weight,Feature,Unnamed: 2_level_7,Unnamed: 3_level_7
+0.390,:,,
+0.049,be,,
+0.048,the,,
+0.046,ill,,
+0.040,ght,,
+0.039,as,,
+0.037,tin,,
+0.032,of,,
+0.032,up,,
+0.031,of,,

y=alt.atheism  (score -1.643) top features,y=alt.atheism  (score -1.643) top features
Weight,Feature
+0.390,:
+0.049,be
+0.048,the
+0.046,ill
+0.040,ght
+0.039,as
+0.037,tin
+0.032,of
+0.032,up
+0.031,of

y=comp.graphics  (score -2.708) top features,y=comp.graphics  (score -2.708) top features
Weight,Feature
+0.206,:
+0.047,fi
+0.045,rig
+0.039,it
+0.034,co
+0.034,ase
+0.033,ile
+0.032,fra
+0.032,li
… 480 more positive …,… 480 more positive …

y=sci.space  (score 1.993) top features,y=sci.space  (score 1.993) top features
Weight,Feature
+0.105,pac
+0.098,spac
+0.096,pace
+0.093,astr
+0.090,igh
+0.079,spa
+0.078,spa
+0.075,th
+0.072,orb
+0.071,rbit

y=talk.religion.misc  (score -1.275) top features,y=talk.religion.misc  (score -1.275) top features
Weight,Feature
+0.101,:
+0.084,th
+0.047,is
+0.043,fra
+0.038,his
+0.038,br
+0.034,der
+0.034,"ase,"
+0.034,of
+0.033,of

Weight,Feature
… 481 more positive …,… 481 more positive …
… 577 more negative …,… 577 more negative …
-0.070,Highlighted in text (sum)
-0.989,<BIAS>

Weight,Feature
… 480 more positive …,… 480 more positive …
… 566 more negative …,… 566 more negative …
-0.961,<BIAS>
-1.258,Highlighted in text (sum)

Weight,Feature
+1.990,Highlighted in text (sum)
… 539 more positive …,… 539 more positive …
… 518 more negative …,… 518 more negative …
-0.965,<BIAS>

Weight,Feature
+0.209,Highlighted in text (sum)
… 498 more positive …,… 498 more positive …
… 554 more negative …,… 554 more negative …
-0.969,<BIAS>


In [7]:
# Explain the prediction for test document 1 with eli5's default settings.
show_html_expl(
    explain_prediction(
        clf, test['data'][1], vec,
        target_names=train['target_names'],
    )
)

y=alt.atheism  (score -1.968) top features,y=alt.atheism  (score -1.968) top features,Unnamed: 2_level_0,Unnamed: 3_level_0
Weight,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1
y=comp.graphics  (score 1.531) top features,y=comp.graphics  (score 1.531) top features,Unnamed: 2_level_2,Unnamed: 3_level_2
Weight,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3
y=sci.space  (score -0.757) top features,y=sci.space  (score -0.757) top features,Unnamed: 2_level_4,Unnamed: 3_level_4
Weight,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5
y=talk.religion.misc  (score -2.463) top features,y=talk.religion.misc  (score -2.463) top features,Unnamed: 2_level_6,Unnamed: 3_level_6
Weight,Feature,Unnamed: 2_level_7,Unnamed: 3_level_7
+0.113,mad,,
+0.094,vat,,
+0.082,mad,,
+0.082,atic,,
+0.081,ican,,
+0.051,ble.,,
+0.048,bra,,
… 54 more positive …,… 54 more positive …,,
… 102 more negative …,… 102 more negative …,,
-0.049,col,,

y=alt.atheism  (score -1.968) top features,y=alt.atheism  (score -1.968) top features
Weight,Feature
+0.113,mad
+0.094,vat
+0.082,mad
+0.082,atic
+0.081,ican
+0.051,ble.
+0.048,bra
… 54 more positive …,… 54 more positive …
… 102 more negative …,… 102 more negative …
-0.049,col

y=comp.graphics  (score 1.531) top features,y=comp.graphics  (score 1.531) top features
Weight,Feature
+0.105,ftp
+0.102,ft
+0.092,ftp
+0.088,help
+0.086,elp
+0.083,hel
+0.082,tp
+0.080,ftp
+0.080,lib
+0.077,lp

y=sci.space  (score -0.757) top features,y=sci.space  (score -0.757) top features
Weight,Feature
+0.077,ndin
+0.070,ndi
+0.065,lle
+0.059,oll
+0.056,ry
+0.054,tou
+0.053,av
+0.052,vat
+0.050,a
… 90 more positive …,… 90 more positive …

y=talk.religion.misc  (score -2.463) top features,y=talk.religion.misc  (score -2.463) top features
Weight,Feature
+0.067,is
+0.066,us.
+0.061,indi
+0.051,th
… 57 more positive …,… 57 more positive …
… 99 more negative …,… 99 more negative …
-0.049,nyon
-0.049,anyo
-0.050,ican
-0.052,le.

Weight,Feature
… 54 more positive …,… 54 more positive …
… 102 more negative …,… 102 more negative …
-0.121,Highlighted in text (sum)
-0.989,<BIAS>

Weight,Feature
+1.450,Highlighted in text (sum)
… 95 more positive …,… 95 more positive …
… 61 more negative …,… 61 more negative …
-0.961,<BIAS>

Weight,Feature
… 90 more positive …,… 90 more positive …
… 66 more negative …,… 66 more negative …
-0.064,Highlighted in text (sum)
-0.965,<BIAS>

Weight,Feature
… 57 more positive …,… 57 more positive …
… 99 more negative …,… 99 more negative …
-0.659,Highlighted in text (sum)
-0.969,<BIAS>


In [9]:
import numpy as np
# For each of the first ten test documents, show the explanation of the
# highest-scoring class only.
for doc in test['data'][:10]:
    expl = explain_prediction(clf, doc, vec, target_names=train['target_names'])
    # Hack: overwrite the explanation's target list so that only the class
    # with the largest score is rendered.
    class_scores = [target.score for target in expl.targets]
    winner = expl.targets[np.argmax(class_scores)]
    expl.targets = [winner]
    show_html_expl(expl, force_weights=False)

Weight,Feature
+1.538,Highlighted in text (sum)
… 21 more positive …,… 21 more positive …
… 11 more negative …,… 11 more negative …
-0.965,<BIAS>


Weight,Feature
+1.450,Highlighted in text (sum)
… 95 more positive …,… 95 more positive …
… 61 more negative …,… 61 more negative …
-0.961,<BIAS>


Weight,Feature
+1.126,Highlighted in text (sum)
… 187 more positive …,… 187 more positive …
… 123 more negative …,… 123 more negative …
-0.961,<BIAS>


Weight,Feature
+0.276,Highlighted in text (sum)
… 885 more positive …,… 885 more positive …
… 740 more negative …,… 740 more negative …
-0.961,<BIAS>


Weight,Feature
+1.257,Highlighted in text (sum)
… 92 more positive …,… 92 more positive …
… 97 more negative …,… 97 more negative …
-0.961,<BIAS>


Weight,Feature
+1.010,Highlighted in text (sum)
… 138 more positive …,… 138 more positive …
… 111 more negative …,… 111 more negative …
-0.961,<BIAS>


Weight,Feature
+0.166,Highlighted in text (sum)
… 58 more positive …,… 58 more positive …
… 57 more negative …,… 57 more negative …
-0.961,<BIAS>


Weight,Feature
+0.933,Highlighted in text (sum)
… 566 more positive …,… 566 more positive …
… 521 more negative …,… 521 more negative …
-0.965,<BIAS>


Weight,Feature
+1.956,Highlighted in text (sum)
… 148 more positive …,… 148 more positive …
… 136 more negative …,… 136 more negative …
-0.989,<BIAS>


Weight,Feature
+0.340,Highlighted in text (sum)
… 387 more positive …,… 387 more positive …
… 361 more negative …,… 361 more negative …
-0.965,<BIAS>
