In [15]:
import sys
import glob
import numpy
from random import shuffle

In [16]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.pipeline import Pipeline
from nltk.stem.snowball import EnglishStemmer

from sklearn.linear_model import SGDClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import PassiveAggressiveClassifier

from sklearn.neighbors import KNeighborsClassifier

from sklearn.naive_bayes import MultinomialNB
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import LinearSVC

from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

In [17]:
def sz(a): # str/eval
    return {
        'text': a.text,
        'title': a.title,
        'url': a.url,
        'keywords': a.keywords,
        'tags': a.tags,
        'summary': a.summary,
        'date': int(a.publish_date.timestamp()) if a.publish_date != None else None,
        'raw' : a.html
    }

In [18]:
def objRead(f):
    return eval(open(f, 'r').read())

def objWrite(f, data):
    f = f.split('.')[0]
    open(f + '.lobj', 'w').write(str(data)) # light obj

In [19]:
for f in glob.iglob('**/*.lobj', recursive=True):
    print(f)

other.lobj
disasters/shootings/sandyhook.lobj
disasters/shootings/orlando.lobj
disasters/shootings/lasvegas.lobj
disasters/shootings/sanbernardino.lobj
disasters/earthquakes/sumatra.lobj
disasters/earthquakes/haiti.lobj
disasters/earthquakes/sichuan.lobj
disasters/earthquakes/kashmir.lobj
disasters/wildfires/thomas.lobj
disasters/wildfires/attica.lobj
disasters/wildfires/mountcarmel.lobj
disasters/wildfires/nevada.lobj
disasters/hurricanes/jeanne.lobj
disasters/hurricanes/stan.lobj
disasters/hurricanes/maria.lobj
disasters/hurricanes/katrina.lobj


In [28]:
data = [
    { 'articles': objRead('disasters/shootings/lasvegas.lobj'), 'label': 'lasvegas' },
    { 'articles': objRead('disasters/shootings/orlando.lobj'), 'label': 'orlando' },
    { 'articles': objRead('disasters/shootings/sandyhook.lobj'), 'label': 'sandyhook' },
    { 'articles': objRead('disasters/shootings/sanbernardino.lobj'), 'label': 'sanbernardino' },
    { 'articles': objRead('disasters/earthquakes/kashmir.lobj'), 'label': 'kashmir' },
    { 'articles': objRead('disasters/earthquakes/haiti.lobj'), 'label': 'haiti' },
    { 'articles': objRead('disasters/earthquakes/sumatra.lobj'), 'label': 'sumatra' },
    { 'articles': objRead('disasters/earthquakes/sichuan.lobj'), 'label': 'sichuan' },
#     { 'articles': objRead('disasters/wildfires/mountcarmel.lobj'), 'label': 'mountcarmel' }, # only 19 articles
    { 'articles': objRead('disasters/wildfires/attica.lobj'), 'label': 'attica' },
    { 'articles': objRead('disasters/wildfires/nevada.lobj'), 'label': 'nevada' },
    { 'articles': objRead('disasters/wildfires/thomas.lobj'), 'label': 'thomas' },
    { 'articles': objRead('disasters/hurricanes/maria.lobj'), 'label': 'maria' },
    { 'articles': objRead('disasters/hurricanes/katrina.lobj'), 'label': 'katrina' }
#     { 'articles': objRead('disasters/hurricanes/jeanne.lobj'), 'label': 'jeanne' }, # only 35 articles
#     { 'articles': objRead('disasters/hurricanes/stan.lobj'), 'label': 'stan'} # only 29 articles
]

labels = [x['label'] for x in data]

train_data, test_data = [], []
train_label, test_label = numpy.array([]), numpy.array([])

TRAIN_SIZE = 40
l = 0

for cat in data:
    shuffle(cat['articles'])
    for x in cat['articles'][:TRAIN_SIZE]:
        if 'washingtonpost' in x['url']:
            continue
        train_data.append(x['text'])
        train_label = numpy.append(train_label, l)
    for x in cat['articles'][TRAIN_SIZE:]:
        if 'washingtonpost' in x['url']:
            continue
        test_data.append(x['text'])
        test_label = numpy.append(test_label, l)
    l += 1

In [29]:
classifiers = [
    SGDClassifier(loss='hinge',
                  penalty='elasticnet',
                  alpha=1e-3,
                  max_iter=1000,
                  tol=None),
    MultinomialNB(),
    PassiveAggressiveClassifier(max_iter=200),
    KNeighborsClassifier(n_neighbors=len(data)),
    RandomForestClassifier(n_estimators=200,
                           max_depth=3),
    LogisticRegression(),
    LinearSVC()
]

In [30]:
stemmer = EnglishStemmer()
analyzer = CountVectorizer().build_analyzer()

def stemmed_words(text):
    return (stemmer.stem(word) for word in analyzer(text))

In [31]:
for c in classifiers:
    clf = Pipeline([('vect', CountVectorizer(analyzer=stemmed_words,
                                             stop_words='english')),
                    ('tfidf', TfidfTransformer()),
                    ('clf', c),
    ])
    clf.fit(train_data, train_label)
    pred_label = clf.predict(test_data)
    print(type(c).__name__, accuracy_score(test_label, pred_label))
    print(classification_report(test_label, pred_label, target_names = labels))
    print(confusion_matrix(test_label, pred_label))

SGDClassifier 0.9162861491628614
               precision    recall  f1-score   support

     lasvegas       0.81      0.93      0.86        54
      orlando       0.84      0.92      0.88        51
    sandyhook       1.00      0.88      0.94        58
sanbernardino       0.98      0.88      0.93        57
      kashmir       1.00      0.95      0.97        19
        haiti       1.00      0.91      0.95        55
      sumatra       0.95      0.95      0.95        43
      sichuan       0.96      0.94      0.95        54
       attica       1.00      0.97      0.98        60
       nevada       0.76      0.96      0.85        55
       thomas       0.92      0.85      0.88        40
        maria       0.88      0.86      0.87        57
      katrina       0.96      0.93      0.94        54

  avg / total       0.92      0.92      0.92       657

[[50  3  0  0  0  0  0  0  0  1  0  0  0]
 [ 0 47  0  1  0  0  0  0  0  3  0  0  0]
 [ 1  2 51  0  0  0  0  0  0  0  0  4  0]
 [ 0  1  0 50

LinearSVC 0.9193302891933028
               precision    recall  f1-score   support

     lasvegas       0.74      0.98      0.84        54
      orlando       1.00      0.90      0.95        51
    sandyhook       1.00      0.88      0.94        58
sanbernardino       0.96      0.89      0.93        57
      kashmir       1.00      0.95      0.97        19
        haiti       0.98      0.91      0.94        55
      sumatra       0.97      0.91      0.94        43
      sichuan       0.98      0.94      0.96        54
       attica       1.00      0.97      0.98        60
       nevada       0.75      0.98      0.85        55
       thomas       0.97      0.85      0.91        40
        maria       0.88      0.86      0.87        57
      katrina       0.94      0.93      0.93        54

  avg / total       0.93      0.92      0.92       657

[[53  0  0  0  0  0  0  0  0  1  0  0  0]
 [ 1 46  0  1  0  0  0  0  0  3  0  0  0]
 [ 3  0 51  0  0  0  0  0  0  0  0  4  0]
 [ 1  0  0 51  0 