In [2]:
%config InlineBackend.figure_format = 'retina'

In [3]:
%load_ext autoreload

%autoreload 1

In [4]:
import numpy as np
import pandas as pd

import pickle

from pathlib import Path

# Load data

In [5]:
data_root = Path.home() / "data" / "tmp"
reuters_dir = data_root / "reuters21578"
reuters_corpus_path = reuters_dir / "corpus.pkl"
reuters = pickle.load(open(reuters_corpus_path, "rb"))
top_ten_ids, top_ten_names = reuters.top_n(n=10)

cache_dir = reuters_dir / "cache"

# Build dataframe

In [6]:
train_docs, test_docs = reuters.split_modapte()
docs = train_docs + test_docs
train_labels = reuters.get_labels(train_docs, set(top_ten_ids))
test_labels = reuters.get_labels(test_docs, set(top_ten_ids))

In [7]:
df = pd.DataFrame()
df["modapte"] = [d["modapte"] for d in reuters.docs]
df["label"] = train_labels + test_labels
df["date"] = [d["date"] for d in reuters.docs]
df["title"] = [d["title"] for d in reuters.docs]
df["dateline"] = [d["dateline"] for d in reuters.docs]
df["body"] = [d["body"] for d in reuters.docs]
df["date"] = pd.to_datetime(df.date.str.split(".").apply(lambda x:x[0].lstrip()), format="%d-%b-%Y %H:%M:%S")

In [8]:
df

Unnamed: 0,modapte,label,date,title,dateline,body
0,train,1,1987-03-11 18:14:49,U.S. ECONOMIC DATA KEY TO DEBT FUTURES OUTLOOK,"CHICAGO, March 11 -",U.S. economic data this week could be\nthe key...
1,train,4,1987-03-11 18:36:05,BANK OF BRITISH COLUMBIA 1ST QTR JAN 31 NET,"VANCOUVER, British Columbia, March 11 -\n",Oper shr loss two cts vs profit three cts\n ...
2,train,4,1987-03-11 18:38:02,RESTAURANT ASSOCIATES INC <RA> 4TH QTR JAN 3,"NEW YORK, March 11 -\n",Shr 25 cts vs 36 cts\n Net 1.4 mln vs 1.4 m...
3,train,4,1987-03-11 18:41:59,MICHIGAN GENERAL CORP <MGL> 4TH QTR,"SADDLE BROOK, N.J., March 11 -\n",Shr loss 1.02 dlrs vs 1.01 dlr\n Net loss 1...
4,train,5,1987-03-11 18:45:36,"USX <X> PROVED OIL, GAS RESERVES FALL IN 1986","NEW YORK, March 11 -",USX Corp said proved reserves of oil\nand natu...
5,train,4,1987-03-11 18:53:18,BANK OF B.C. REVISES SHARE PAYOUT ESTIMATE,"VANCOUVER, British Columbia, March 11 -",Bank of British\nColumbia said it revised its ...
6,train,4,1987-03-11 18:56:34,<KIENA GOLD MINES LTD> 4TH QTR NET,"TORONTO, March 11 -\n","Shr 17 cts vs 16 cts\n Net 1,019,000 vs 985..."
7,train,8,1987-03-11 18:56:43,ARGENTINE MEAT EXPORTS HIGHER IN JAN/FEB 1987,"BUENOS AIRES, March 11 -",Argentine meat exports during\nJan/Feb 1987 to...
8,train,4,1987-03-11 19:02:33,KIENA PLANS TWO-FOR-ONE STOCK SPLIT,"TORONTO, March 11 -",<Kiena Gold Mines Ltd> said it planned\na two-...
9,train,4,1987-03-11 19:04:38,ROWE FURNITURE CORP <ROWE> SETS QTLY DIVIDEND,"SALEM, Va., March 11 - \n",Qtly div four cts vs four cts prior\n Pay A...


In [131]:
df.title.str.len().values

array([46., 43., 44., ..., 49., 45., 46.])

In [136]:
df.title.str.count("\.").values

array([2., 0., 0., ..., 0., 0., 0.])

In [134]:
foo = "U.S. ECONOMIC DATA KEY TO DEBT FUTURES OUTLOOK"
print(foo.count("."))

2


In [135]:
help(df.title.str.count)

Help on method str_count in module pandas.core.strings:

str_count(pat, flags=0, **kwargs) method of pandas.core.strings.StringMethods instance
    Count occurrences of pattern in each string of the Series/Index.
    
    This function is used to count the number of times a particular regex
    pattern is repeated in each of the string elements of the
    :class:`~pandas.Series`.
    
    Parameters
    ----------
    pat : str
        Valid regular expression.
    flags : int, default 0, meaning no flags
        Flags for the `re` module. For a complete list, `see here
        <https://docs.python.org/3/howto/regex.html#compilation-flags>`_.
    **kwargs
        For compatability with other string methods. Not used.
    
    Returns
    -------
    counts : Series or Index
        Same type as the calling object containing the integer counts.
    
    Notes
    -----
    Some characters need to be escaped when passing in `pat`.
    eg. ``'$'`` has a special meaning in regex and must b

In [133]:
df.title

0           U.S. ECONOMIC DATA KEY TO DEBT FUTURES OUTLOOK
1              BANK OF BRITISH COLUMBIA 1ST QTR JAN 31 NET
2             RESTAURANT ASSOCIATES INC <RA> 4TH QTR JAN 3
3                      MICHIGAN GENERAL CORP <MGL> 4TH QTR
4            USX <X> PROVED OIL, GAS RESERVES FALL IN 1986
5               BANK OF B.C. REVISES SHARE PAYOUT ESTIMATE
6                       <KIENA GOLD MINES LTD> 4TH QTR NET
7            ARGENTINE MEAT EXPORTS HIGHER IN JAN/FEB 1987
8                      KIENA PLANS TWO-FOR-ONE STOCK SPLIT
9            ROWE FURNITURE CORP <ROWE> SETS QTLY DIVIDEND
10           U.S. HOUSE PANEL TAKES FIRST TRADE BILL VOTES
11           SOVIET MINISTER SAYS TRADE BOOST UP TO FRENCH
12             VENEZUELA TO LEND OIL TO ECUADOR FOR EXPORT
13                  EAGLE CLOTHES INC <EGL> 2nD QTR JAN 31
14              BRITAIN CALLS ON JAPAN TO INCREASE IMPORTS
15       TAFT BROADCASTING REJECTS 145 DLR PER SHARE BU...
16                    TAFT <TFB> REJECTS 145 DLR/SHR OFF

In [63]:
df.body.str.count("\n")

0        50.0
1        18.0
2        11.0
3        11.0
4        49.0
5        31.0
6         8.0
7        23.0
8         7.0
9         4.0
10       61.0
11       47.0
12       23.0
13       12.0
14       26.0
15        NaN
16       12.0
17       10.0
18       10.0
19       11.0
20       22.0
21       11.0
22       22.0
23       19.0
24       33.0
25       10.0
26       21.0
27       86.0
28       11.0
29       41.0
         ... 
10759     7.0
10760    12.0
10761    12.0
10762    39.0
10763    10.0
10764    21.0
10765    31.0
10766    13.0
10767    22.0
10768     7.0
10769    19.0
10770    33.0
10771    76.0
10772    17.0
10773     8.0
10774     8.0
10775     NaN
10776    23.0
10777    47.0
10778    33.0
10779     9.0
10780    23.0
10781    22.0
10782    23.0
10783    12.0
10784    23.0
10785    49.0
10786     9.0
10787    40.0
10788    10.0
Name: body, Length: 10789, dtype: float64

# Build feature extraction pipeline

In [380]:
from sklearn.pipeline import Pipeline
from sklearn.pipeline import FeatureUnion

from sklearn.preprocessing import Imputer
from sklearn.preprocessing import StandardScaler

from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import classification_report
from sklearn.feature_extraction import DictVectorizer
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.feature_extraction.text import TfidfVectorizer

from sklearn.svm import LinearSVC
from sklearn.ensemble import RandomForestClassifier

import xgboost as xgb

In [51]:
class EmptyFitMixin:
    def fit(self, x, y=None):
        return self

In [331]:
class TextFromColumns(EmptyFitMixin, BaseEstimator, TransformerMixin):
    """Extract the text from a list of columns in a single pass.

    Takes a pandas dataframe and produces a series of texts
    from joined columns defined in `text_cols`.
    """
    def __init__(self, columns=["title", "body"]):
        self.text_cols = columns

    def transform(self, df):
        def join(items):
            return ' '.join([str(item) for item in items])

        data = df[self.text_cols].apply(lambda x: "" if x is None else x)
        texts = data.apply(join, axis=1)
        return texts

In [221]:
class ColumnSelector(EmptyFitMixin, BaseEstimator, TransformerMixin):
    def __init__(self, column, filter_none=True):
        self.column = column
        self.filter_none = filter_none

    def transform(self, df):
        col = df[self.column]
        if self.filter_none:
            col = col.apply(lambda x: "" if x is None else x)
        return col

In [None]:
class FilterNone(EmptyFitMixin, BaseEstimator, TransformerMixin):
    def __init__(self, column):
        self.column = column

    def transform(self, data):
        return df[self.column].values

In [303]:
class TextStats(BaseEstimator, EmptyFitMixin, TransformerMixin):
    """Extract features from each document"""

    def transform(self, col):
        tc = col.str
        features = [
            tc.len(),  # character count
            tc.count("\n"),  # line count
            tc.count("\."),  # sentence count
            tc.split().apply(lambda x:len(x) if x is not None else 0),  # word count
        ]
        features = np.concatenate(
            [f.values.reshape(-1, 1) for f in features],
            axis=1,
        )
        where_are_NaNs = np.isnan(features)
        features[where_are_NaNs] = 0
        return features.astype(np.float)

In [138]:
class FooTransform(BaseEstimator, EmptyFitMixin, TransformerMixin):
    def transform(self, data):
        print(data.shape)

In [308]:
df_train = df.query("modapte == 'train'")
df_test = df.query("modapte == 'test'")
y_train_wrong = df_train.label.values
y_test_wrong = df_test.label.values
y_train = np.array(train_labels)
y_test = np.array(test_labels)

In [309]:
y_train, y_train2

(array([1, 4, 4, ..., 4, 4, 4]), array([ 1,  4,  4, ..., 20, 11, 11]))

In [310]:
np.all(y_train == y_train2)

False

In [311]:
np.all(y_test == y_test2)

False

In [None]:
pipeline = Pipeline(memory=str(cache_dir), steps=[
    ("union", FeatureUnion(transformer_list=[
#        ("title_stats", Pipeline([
#            ("column", ColumnSelector("title")),
#            ("stats", TextStats()),
#            ("scaled", StandardScaler()),
#        ])),
#        ("body_stats", Pipeline([
#            ("column", ColumnSelector("body")),
#            ("stats", TextStats()),
#            ("scaled", StandardScaler()),
#        ])),
        ("combined_text", Pipeline([
            ("column", TextFromColumns(columns=["title", "body"])),
            ("tfidf", TfidfVectorizer()),
            ("best", TruncatedSVD(n_components=200, random_state=2018))
        ])),
#        ("title_text", Pipeline([
#            ("column", ColumnSelector("title")),
#            ("tfidf", TfidfVectorizer()),
#            ("best", TruncatedSVD(n_components=200, random_state=2018))
#        ])),
#        ("body_text", Pipeline([
#            ("column", ColumnSelector("body")),
#            ("tfidf", TfidfVectorizer()),
#            ("best", TruncatedSVD(n_components=200, random_state=2018))
#        ])),
    ])),
    # ("clf", LinearSVC()),
    # ("clf", RandomForestClassifier())
    # ("clf", xgb.XGBClassifier())
])

In [397]:
help(xgb.XGBClassifier)

Help on class XGBClassifier in module xgboost.sklearn:

class XGBClassifier(XGBModel, sklearn.base.ClassifierMixin)
 |  Implementation of the scikit-learn API for XGBoost classification.
 |  
 |      Parameters
 |  ----------
 |  max_depth : int
 |      Maximum tree depth for base learners.
 |  learning_rate : float
 |      Boosting learning rate (xgb's "eta")
 |  n_estimators : int
 |      Number of boosted trees to fit.
 |  silent : boolean
 |      Whether to print messages while running boosting.
 |  objective : string or callable
 |      Specify the learning task and the corresponding learning objective or
 |      a custom objective function to be used (see note below).
 |  booster: string
 |      Specify which booster to use: gbtree, gblinear or dart.
 |  nthread : int
 |      Number of parallel threads used to run xgboost.  (Deprecated, please use n_jobs)
 |  n_jobs : int
 |      Number of parallel threads used to run xgboost.  (replaces nthread)
 |  gamma : float
 |      Minimum

In [421]:
%%time
X_train = pipeline.fit_transform(df_train)
X_test = pipeline.transform(df_test)

CPU times: user 34.3 s, sys: 2.82 s, total: 37.1 s
Wall time: 38.5 s


In [389]:
X_train

array([[ 1.15745524e-01, -3.00533858e-01,  2.23166285e+00, ...,
        -7.47255889e-03, -4.37132822e-02,  2.85387423e-02],
       [-1.59030743e-01, -3.00533858e-01, -4.31934746e-01, ...,
        -2.01396917e-02,  1.99619761e-02, -2.12017013e-02],
       [-6.74386541e-02, -3.00533858e-01, -4.31934746e-01, ...,
         1.33891956e-03,  2.27831241e-03, -1.48789167e-02],
       ...,
       [ 1.15745524e-01, -3.00533858e-01, -4.31934746e-01, ...,
         1.20381849e-03, -1.21296789e-02,  5.43001439e-03],
       [ 1.15745524e-01, -3.00533858e-01,  3.56346165e+00, ...,
         2.65797312e-02, -3.72713485e-03,  1.51672852e-02],
       [ 1.15745524e-01, -3.00533858e-01, -4.31934746e-01, ...,
         3.28866663e-03, -1.64978059e-02,  4.49500906e-02]])

In [387]:
X_test = pipeline.transform(df_test)

In [403]:
model = xgb.XGBClassifier(
    early_stopping_rounds=10, n_estimators=50,
    silent=False, n_jobs=4, booster="gblinear",
    random_state=2018, verbose_eval=True,
    objective="binary:logistic"
)

In [423]:
model = xgb.XGBClassifier(n_jobs=4, random_state=2018)

In [424]:
%%time
model.fit(X_train, y_train)

CPU times: user 35min 37s, sys: 17.5 s, total: 35min 55s
Wall time: 17min 2s


XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bytree=1, gamma=0, learning_rate=0.1, max_delta_step=0,
       max_depth=3, min_child_weight=1, missing=None, n_estimators=100,
       n_jobs=4, nthread=None, objective='multi:softprob',
       random_state=2018, reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
       seed=None, silent=True, subsample=1)

In [425]:
y_pred = model.predict(X_test)

  if diff:


In [426]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.94      0.99      0.96      1087
        acq       0.89      0.95      0.92       710
   money-fx       0.63      0.85      0.72       145
      grain       0.33      0.38      0.36        42
      crude       0.70      0.81      0.75       164
      trade       0.71      0.83      0.76       109
   interest       0.73      0.57      0.64       117
       ship       0.51      0.59      0.55        71
      wheat       0.69      0.56      0.62        55
       corn       0.47      0.62      0.53        45

avg / total       0.84      0.89      0.86      2545



In [409]:
%%time
model.fit(X_train, y_train)

CPU times: user 19min 18s, sys: 2.08 s, total: 19min 20s
Wall time: 19min 29s


XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bytree=1, gamma=0, learning_rate=0.1, max_delta_step=0,
       max_depth=3, min_child_weight=1, missing=None, n_estimators=100,
       n_jobs=1, nthread=None, objective='multi:softprob',
       random_state=2018, reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
       seed=None, silent=True, subsample=1)

In [410]:
y_pred = model.predict(X_test)

  if diff:


In [411]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.92      0.98      0.95      1087
        acq       0.87      0.95      0.91       710
   money-fx       0.62      0.78      0.69       145
      grain       0.30      0.33      0.31        42
      crude       0.70      0.84      0.76       164
      trade       0.67      0.81      0.73       109
   interest       0.68      0.62      0.65       117
       ship       0.57      0.65      0.61        71
      wheat       0.67      0.58      0.62        55
       corn       0.54      0.56      0.55        45

avg / total       0.82      0.89      0.85      2545



In [422]:
%%time
model = LinearSVC()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.97      0.99      0.98      1087
        acq       0.90      0.97      0.94       710
   money-fx       0.74      0.81      0.78       145
      grain       0.52      0.31      0.39        42
      crude       0.76      0.83      0.79       164
      trade       0.72      0.86      0.78       109
   interest       0.78      0.77      0.78       117
       ship       0.64      0.69      0.67        71
      wheat       0.77      0.75      0.76        55
       corn       0.59      0.73      0.65        45

avg / total       0.88      0.92      0.90      2545

CPU times: user 8.47 s, sys: 122 ms, total: 8.59 s
Wall time: 11.5 s


In [382]:
%%time
pipeline.fit(df_train, y_train)

KeyboardInterrupt: 

In [383]:
y_pred = pipeline.predict(df_test)

XGBoostError: need to call fit beforehand

In [378]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.83      0.99      0.90      1087
        acq       0.79      0.92      0.85       710
   money-fx       0.51      0.49      0.50       145
      grain       0.23      0.14      0.18        42
      crude       0.64      0.75      0.69       164
      trade       0.56      0.66      0.61       109
   interest       0.57      0.59      0.58       117
       ship       0.50      0.30      0.37        71
      wheat       0.64      0.42      0.51        55
       corn       0.38      0.33      0.35        45

avg / total       0.73      0.84      0.78      2545



In [373]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.97      0.99      0.98      1087
        acq       0.91      0.97      0.94       710
   money-fx       0.70      0.79      0.74       145
      grain       0.46      0.31      0.37        42
      crude       0.78      0.81      0.80       164
      trade       0.71      0.86      0.78       109
   interest       0.77      0.76      0.76       117
       ship       0.64      0.66      0.65        71
      wheat       0.73      0.67      0.70        55
       corn       0.56      0.76      0.64        45

avg / total       0.88      0.91      0.89      2545



In [345]:
# linear svc svd 200
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.97      0.99      0.98      1087
        acq       0.92      0.97      0.94       710
   money-fx       0.73      0.81      0.77       145
      grain       0.52      0.33      0.41        42
      crude       0.76      0.82      0.79       164
      trade       0.72      0.87      0.79       109
   interest       0.76      0.78      0.77       117
       ship       0.65      0.69      0.67        71
      wheat       0.75      0.75      0.75        55
       corn       0.56      0.64      0.60        45

avg / total       0.88      0.92      0.90      2545



In [339]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.98      0.99      0.98      1087
        acq       0.94      0.98      0.96       710
   money-fx       0.74      0.79      0.76       145
      grain       0.68      0.55      0.61        42
      crude       0.80      0.89      0.84       164
      trade       0.74      0.85      0.79       109
   interest       0.82      0.75      0.79       117
       ship       0.70      0.63      0.67        71
      wheat       0.75      0.75      0.75        55
       corn       0.65      0.71      0.68        45

avg / total       0.90      0.92      0.91      2545



In [335]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.98      0.99      0.98      1087
        acq       0.93      0.98      0.96       710
   money-fx       0.73      0.79      0.76       145
      grain       0.66      0.55      0.60        42
      crude       0.80      0.88      0.84       164
      trade       0.73      0.83      0.78       109
   interest       0.82      0.75      0.79       117
       ship       0.74      0.63      0.68        71
      wheat       0.75      0.75      0.75        55
       corn       0.65      0.69      0.67        45

avg / total       0.90      0.92      0.91      2545



In [293]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.98      0.99      0.98      1087
        acq       0.93      0.98      0.96       710
   money-fx       0.70      0.77      0.73       145
      grain       0.67      0.52      0.59        42
      crude       0.77      0.88      0.82       164
      trade       0.75      0.86      0.80       109
   interest       0.79      0.73      0.76       117
       ship       0.71      0.55      0.62        71
      wheat       0.73      0.69      0.71        55
       corn       0.64      0.67      0.65        45

avg / total       0.89      0.92      0.90      2545



In [184]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.40      0.88      0.55      1164
        acq       0.24      0.16      0.20       664
   money-fx       0.00      0.00      0.00       161
      grain       0.00      0.00      0.00        50
      crude       0.00      0.00      0.00       144
      trade       0.00      0.00      0.00       107
   interest       0.00      0.00      0.00        90
       ship       0.00      0.00      0.00        67
      wheat       0.00      0.00      0.00        65
       corn       0.00      0.00      0.00        64

avg / total       0.24      0.44      0.30      2576



  'precision', 'predicted', average, warn_for)


In [171]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.40      0.61      0.49      1164
        acq       0.23      0.22      0.23       664
   money-fx       0.07      0.04      0.05       161
      grain       0.00      0.00      0.00        50
      crude       0.00      0.00      0.00       144
      trade       0.02      0.11      0.04       107
   interest       0.00      0.00      0.00        90
       ship       0.00      0.00      0.00        67
      wheat       0.00      0.00      0.00        65
       corn       0.00      0.00      0.00        64

avg / total       0.25      0.34      0.28      2576



  'precision', 'predicted', average, warn_for)


In [260]:
df_train.body.values

array(['U.S. economic data this week could be\nthe key in determining whether U.S. interest rate futures break\nout of a 3-1/2 month trading range, financial analysts said.\n    Although market expectations are for February U.S. retail\nsales Thursday and industrial production Friday to show healthy\ngains, figures within or slightly below expectations would be\npositive for the market, the analysts said.\n    "You have to be impressed with the resiliency of bonds\nright now," said Smith Barney Harris Upham analyst Craig\nSloane.\n    Treasury bond futures came under pressure today which\ntraders linked to a persistently firm federal funds rate and a\nrise in oil prices. However, when sufficient selling interest\nto break below chart support in the June contract failed to\nmaterialize, participants who had sold bond futures early\nquickly covered short positions, they said.\n    "Everyone is expecting strong numbers, and if they come in\nas expected it won\'t be that bad for the market

In [263]:
vectorizer = TfidfVectorizer()
X_train = vectorizer.fit_transform(df_train.body.apply(lambda x: "" if x is None else x).values)
X_test = vectorizer.transform(df_test.body.apply(lambda x: "" if x is None else x).values)

In [287]:
vectorizer = TfidfVectorizer()
X_train = vectorizer.fit_transform([d["body"] if d["body"] is not None else "" for d in train_docs])
X_test = vectorizer.transform([d["body"] if d["body"] is not None else "" for d in test_docs])

In [276]:
vectorizer = TfidfVectorizer()
X_train = vectorizer.fit_transform([d["text"] for d in train_docs])
X_test = vectorizer.transform([d["text"] for d in test_docs])

In [282]:
y_train = np.array(train_labels)
y_test = np.array(test_labels)

In [283]:
model = LinearSVC()

In [288]:
model.fit(X_train, y_train)

LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True,
     intercept_scaling=1, loss='squared_hinge', max_iter=1000,
     multi_class='ovr', penalty='l2', random_state=None, tol=0.0001,
     verbose=0)

In [289]:
y_pred = model.predict(X_test)

In [290]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.98      0.95      0.96      1087
        acq       0.74      0.97      0.84       710
   money-fx       0.70      0.60      0.65       145
      grain       0.56      0.33      0.42        42
      crude       0.78      0.73      0.75       164
      trade       0.72      0.79      0.75       109
   interest       0.75      0.57      0.65       117
       ship       0.71      0.62      0.66        71
      wheat       0.70      0.71      0.70        55
       corn       0.66      0.56      0.60        45

avg / total       0.84      0.87      0.84      2545



In [None]:
np.arra

In [9]:
np.array?

In [10]:
np.lookfor('diagonal')

Search results for 'diagonal'
-----------------------------
numpy.diagonal
    Return specified diagonals.
numpy.eye
    Return a 2-D array with ones on the diagonal and zeros elsewhere.
numpy.tri
    An array with ones at and below the given diagonal and zeros elsewhere.
numpy.diag
    Extract a diagonal or construct a diagonal array.
numpy.tril
    Lower triangle of an array.
numpy.triu
    Upper triangle of an array.
numpy.trace
    Return the sum along diagonals of the array.
numpy.fill_diagonal
    Fill the main diagonal of the given array of any dimensionality.
numpy.diagflat
    Create a two-dimensional array with the flattened input as a diagonal.
numpy.ma.diagonal
    a.diagonal(offset=0, axis1=0, axis2=1)
numpy.diag_indices
    Return the indices to access the main diagonal of an array.
numpy.ma.diag
    Extract a diagonal or construct a diagonal array.
numpy.chararray.diagonal
    Return specified diagonals. In NumPy 1.9 the returned array is a
numpy.matlib.eye
    Return a 