In [1]:
from sklearn.datasets import fetch_20newsgroups
# Newsgroups included in this demo.
categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']


def fetch_subset(subset):
    """Fetch one split of 20 newsgroups restricted to ``categories``.

    ``subset`` is forwarded unchanged to ``fetch_20newsgroups`` (this notebook
    uses 'train' and 'test'). Documents are shuffled with a fixed seed (42) and
    requested with headers, footers and quotes removed.
    """
    return fetch_20newsgroups(
        subset=subset,
        categories=categories,
        shuffle=True,
        random_state=42,
        remove=('headers', 'footers', 'quotes'),
    )


train, test = fetch_subset('train'), fetch_subset('test')

In [2]:
from sklearn.pipeline import Pipeline
from sklearn.linear_model import SGDClassifier
from sklearn.feature_extraction.text import TfidfVectorizer

# TF-IDF over character n-grams of length 3 to 4, using the 'char_wb' analyzer.
vec = TfidfVectorizer(ngram_range=(3, 4), analyzer='char_wb')
# NOTE(review): no random_state is passed here, so the fitted weights may
# differ from run to run.
clf = SGDClassifier(n_jobs=-1)
pipeline = Pipeline(steps=[('vec', vec), ('clf', clf)])
# Kept as the last expression so the fitted pipeline's repr is the cell output.
pipeline.fit(train['data'], train['target'])

Pipeline(steps=[('vec', TfidfVectorizer(analyzer='char_wb', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(3, 4), norm='l2', preprocessor=None, smooth_idf=True,
...   penalty='l2', power_t=0.5, random_state=None, shuffle=True,
       verbose=0, warm_start=False))])

In [3]:
from eli5 import explain_weights, explain_prediction
from eli5 import format_as_html, format_as_text, format_html_styles

# Plain-text rendering of the classifier's weights, labelled with the newsgroup names.
weights_expl = explain_weights(clf, vec, target_names=train['target_names'])
print(format_as_text(weights_expl))

Explained as: linear model

Features with largest coefficients per class.
Caveats:
1. Be careful with features which are not
   independent - weights don't show their importance.
2. If scale of input features is different then scale of coefficients
   will also be different, making direct comparison between coefficient values
   incorrect.
3. Depending on regularization, rare features sometimes may have high
   coefficients; this doesn't mean they contribute much to the
   classification result for most examples.

y='alt.atheism' top features
--------------
  +2.677  heis
  +2.202  eis 
  +2.030  eist
  +1.927  ░pos
  +1.918  nat 
  +1.880  post
  +1.846  pos 
  +1.807  ish░
  +1.709  ░ath
  +1.694  thei
  +1.634  ░mad
  +1.632  ish 
  +1.612  hei 
  +1.592  sla 
  +1.545  lai 
  +1.503  rna 
  +1.448  logi
       …  (20181 more positive features)
       …  (31794 more negative features)
  -1.492  et░ 
  -1.553  ░*░ 
  -1.660  ░us 

y='comp.graphics' top features
--------------
  +2.13

In [4]:
# `display` and `HTML` are imported from the public IPython.display module;
# importing `display` from IPython.core.display has been deprecated since
# IPython 7.14 and fails on recent releases.
from IPython.display import display, HTML


def show_html(html):
    """Render a raw HTML string as rich output in the notebook."""
    return display(HTML(html))


def show_html_expl(expl, **kwargs):
    """Render an eli5 explanation as HTML.

    The explanation is formatted with ``include_styles=False``; the styles are
    emitted separately, once, by the ``format_html_styles()`` call in this
    cell. Any extra keyword arguments are passed through to ``format_as_html``.
    """
    return show_html(format_as_html(expl, include_styles=False, **kwargs))


# Output eli5's HTML styles once; explanations rendered via show_html_expl
# do not include them.
show_html(format_html_styles())

In [5]:
# HTML rendering of the classifier's weights, requested with top=100.
top_weights_expl = explain_weights(clf, vec, target_names=train['target_names'], top=100)
show_html_expl(top_weights_expl)

y=alt.atheism  top features,y=alt.atheism  top features,Unnamed: 2_level_0,Unnamed: 3_level_0
Weight,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1
y=comp.graphics  top features,y=comp.graphics  top features,Unnamed: 2_level_2,Unnamed: 3_level_2
Weight,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3
y=sci.space  top features,y=sci.space  top features,Unnamed: 2_level_4,Unnamed: 3_level_4
Weight,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5
y=talk.religion.misc  top features,y=talk.religion.misc  top features,Unnamed: 2_level_6,Unnamed: 3_level_6
Weight,Feature,Unnamed: 2_level_7,Unnamed: 3_level_7
+2.677,heis,,
+2.202,eis,,
+2.030,eist,,
+1.927,pos,,
+1.918,nat,,
+1.880,post,,
+1.846,pos,,
+1.807,ish,,
+1.709,ath,,
+1.694,thei,,

y=alt.atheism  top features,y=alt.atheism  top features
Weight,Feature
+2.677,heis
+2.202,eis
+2.030,eist
+1.927,pos
+1.918,nat
+1.880,post
+1.846,pos
+1.807,ish
+1.709,ath
+1.694,thei

y=comp.graphics  top features,y=comp.graphics  top features
Weight,Feature
+2.130,file
+2.028,3d
+1.976,phi
+1.877,fil
+1.873,raph
+1.855,mage
+1.775,aph
+1.747,ima
+1.709,hics
+1.703,gra

y=sci.space  top features,y=sci.space  top features
Weight,Feature
+3.268,spac
+3.215,pace
+2.802,spa
+2.645,pac
+2.541,spa
+1.917,ace
+1.872,nas
+1.827,ace
+1.815,nas
+1.814,orb

y=talk.religion.misc  top features,y=talk.religion.misc  top features
Weight,Feature
+2.205,*
+1.843,he
+1.680,rist
+1.664,sa
+1.619,ian
+1.522,ian
+1.484,us
+1.463,de
+1.404,fbi
+1.330,fbi


In [6]:
# Explain the prediction for test document 7 (top=50); force_weights=True is
# forwarded to the HTML formatter.
doc7_expl = explain_prediction(clf, test['data'][7], vec, target_names=train['target_names'], top=50)
show_html_expl(doc7_expl, force_weights=True)

y=alt.atheism  (score -1.429) top features,y=alt.atheism  (score -1.429) top features,Unnamed: 2_level_0,Unnamed: 3_level_0
Weight,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1
y=comp.graphics  (score -2.545) top features,y=comp.graphics  (score -2.545) top features,Unnamed: 2_level_2,Unnamed: 3_level_2
Weight,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3
y=sci.space  (score 1.911) top features,y=sci.space  (score 1.911) top features,Unnamed: 2_level_4,Unnamed: 3_level_4
Weight,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5
y=talk.religion.misc  (score -1.523) top features,y=talk.religion.misc  (score -1.523) top features,Unnamed: 2_level_6,Unnamed: 3_level_6
Weight,Feature,Unnamed: 2_level_7,Unnamed: 3_level_7
+0.399,:,,
+0.053,tin,,
+0.049,be,,
+0.038,the,,
+0.038,as,,
+0.035,ght,,
+0.032,up,,
+0.031,ill,,
+0.031,ting,,
+0.028,what,,

y=alt.atheism  (score -1.429) top features,y=alt.atheism  (score -1.429) top features
Weight,Feature
+0.399,:
+0.053,tin
+0.049,be
+0.038,the
+0.038,as
+0.035,ght
+0.032,up
+0.031,ill
+0.031,ting
+0.028,what

y=comp.graphics  (score -2.545) top features,y=comp.graphics  (score -2.545) top features
Weight,Feature
+0.253,:
+0.042,fi
+0.039,co
+0.034,ile
+0.029,ha
+0.028,fra
+0.028,a
+0.026,li
… 483 more positive …,… 483 more positive …
… 570 more negative …,… 570 more negative …

y=sci.space  (score 1.911) top features,y=sci.space  (score 1.911) top features
Weight,Feature
+0.112,pac
+0.101,spac
+0.099,pace
+0.094,th
+0.094,astr
+0.082,spa
+0.079,spa
+0.078,igh
+0.075,the
+0.069,orb

y=talk.religion.misc  (score -1.523) top features,y=talk.religion.misc  (score -1.523) top features
Weight,Feature
+0.090,th
+0.057,he
+0.051,sa
+0.043,br
+0.042,the
+0.040,as
+0.039,fra
+0.038,of
+0.036,der
+0.035,eld

Weight,Feature
+0.110,Highlighted in text (sum)
… 489 more positive …,… 489 more positive …
… 568 more negative …,… 568 more negative …
-0.979,<BIAS>

Weight,Feature
… 483 more positive …,… 483 more positive …
… 570 more negative …,… 570 more negative …
-1.007,<BIAS>
-1.158,Highlighted in text (sum)

Weight,Feature
+1.940,Highlighted in text (sum)
… 549 more positive …,… 549 more positive …
… 505 more negative …,… 505 more negative …
-0.916,<BIAS>

Weight,Feature
+0.039,Highlighted in text (sum)
… 481 more positive …,… 481 more positive …
… 572 more negative …,… 572 more negative …
-1.008,<BIAS>


In [7]:
# Explain the prediction for test document 1 with the formatter's default options.
doc1_expl = explain_prediction(clf, test['data'][1], vec, target_names=train['target_names'])
show_html_expl(doc1_expl)

y=alt.atheism  (score -2.318) top features,y=alt.atheism  (score -2.318) top features,Unnamed: 2_level_0,Unnamed: 3_level_0
Weight,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1
y=comp.graphics  (score 1.767) top features,y=comp.graphics  (score 1.767) top features,Unnamed: 2_level_2,Unnamed: 3_level_2
Weight,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3
y=sci.space  (score -0.776) top features,y=sci.space  (score -0.776) top features,Unnamed: 2_level_4,Unnamed: 3_level_4
Weight,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5
y=talk.religion.misc  (score -2.192) top features,y=talk.religion.misc  (score -2.192) top features,Unnamed: 2_level_6,Unnamed: 3_level_6
Weight,Feature,Unnamed: 2_level_7,Unnamed: 3_level_7
+0.139,mad,,
+0.104,mad,,
+0.076,atic,,
+0.070,vat,,
+0.062,ican,,
+0.053,ade,,
+0.052,ble.,,
+0.049,made,,
+0.048,bra,,
+0.048,ade,,

y=alt.atheism  (score -2.318) top features,y=alt.atheism  (score -2.318) top features
Weight,Feature
0.139,mad
0.104,mad
0.076,atic
0.07,vat
0.062,ican
0.053,ade
0.052,ble.
0.049,made
0.048,bra
0.048,ade

y=comp.graphics  (score 1.767) top features,y=comp.graphics  (score 1.767) top features
Weight,Feature
0.093,lib
0.091,lib
0.086,ftp
0.084,ft
0.083,hel
0.082,help
0.082,can
0.082,brar
0.08,elp
0.078,ibra

y=sci.space  (score -0.776) top features,y=sci.space  (score -0.776) top features
Weight,Feature
0.079,ndin
0.064,ndi
0.06,lle
0.058,th
0.052,ry
0.047,the
0.046,the
0.044,vat
0.039,oll
0.038,ary

y=talk.religion.misc  (score -2.192) top features,y=talk.religion.misc  (score -2.192) top features
Weight,Feature
0.072,us.
0.067,vati
0.065,indi
0.056,rece
0.055,th
0.051,us.
0.048,is
0.047,ecen
0.047,he
0.046,he

Weight,Feature
-0.979,<BIAS>
-1.339,Highlighted in text (sum)

Weight,Feature
2.774,Highlighted in text (sum)
-1.007,<BIAS>

Weight,Feature
0.14,Highlighted in text (sum)
-0.916,<BIAS>

Weight,Feature
-1.008,<BIAS>
-1.185,Highlighted in text (sum)


In [8]:
import numpy as np
# Show, for each of the first ten test documents, only the explanation of the
# highest-scoring class.
for doc in test['data'][:10]:
    expl = explain_prediction(clf, doc, vec, target_names=train['target_names'])
    # Hack: overwrite the explanation's target list so that only the winning
    # class (the one with the largest score) is rendered.
    class_scores = [target.score for target in expl.targets]
    winner = expl.targets[np.argmax(class_scores)]
    expl.targets = [winner]
    show_html_expl(expl, force_weights=False)

Weight,Feature
1.311,Highlighted in text (sum)
-0.916,<BIAS>


Weight,Feature
2.774,Highlighted in text (sum)
-1.007,<BIAS>


Weight,Feature
2.947,Highlighted in text (sum)
-1.007,<BIAS>


Weight,Feature
1.562,Highlighted in text (sum)
-1.007,<BIAS>


Weight,Feature
1.48,Highlighted in text (sum)
-1.007,<BIAS>


Weight,Feature
1.349,Highlighted in text (sum)
-1.007,<BIAS>


Weight,Feature
0.422,Highlighted in text (sum)
-1.007,<BIAS>


Weight,Feature
2.827,Highlighted in text (sum)
-0.916,<BIAS>


Weight,Feature
2.818,Highlighted in text (sum)
-0.979,<BIAS>


Weight,Feature
0.874,Highlighted in text (sum)
-0.916,<BIAS>
