Text Classification

<!-- Source: https://github.com/chseifert/tutorials/blob/master/nlp-ie/Text-Classification.ipynb -->
Method based on: https://miguelmalvarez.com/2016/11/07/classifying-reuters-21578-collection-with-python/


Goal:
- Predict which categories an article belongs to, using a multi-label classifier (an article may carry several categories at once).

In [1]:
# Make sure the Reuters corpus is present in the local NLTK data directory.
# The call reports its progress on stderr and returns True, as the cell
# output below shows, including when the package is already up to date.
import nltk
nltk.download('reuters')


[nltk_data] Downloading package reuters to
[nltk_data]     C:\Users\vreem\AppData\Roaming\nltk_data...
[nltk_data]   Package reuters is already up-to-date!


True

In [2]:
# Load the Reuters corpus reader and display the category labels it defines.
# The bare expression on the last line is rendered as the cell's output.
from nltk.corpus import reuters
reuters.categories()

['acq',
 'alum',
 'barley',
 'bop',
 'carcass',
 'castor-oil',
 'cocoa',
 'coconut',
 'coconut-oil',
 'coffee',
 'copper',
 'copra-cake',
 'corn',
 'cotton',
 'cotton-oil',
 'cpi',
 'cpu',
 'crude',
 'dfl',
 'dlr',
 'dmk',
 'earn',
 'fuel',
 'gas',
 'gnp',
 'gold',
 'grain',
 'groundnut',
 'groundnut-oil',
 'heat',
 'hog',
 'housing',
 'income',
 'instal-debt',
 'interest',
 'ipi',
 'iron-steel',
 'jet',
 'jobs',
 'l-cattle',
 'lead',
 'lei',
 'lin-oil',
 'livestock',
 'lumber',
 'meal-feed',
 'money-fx',
 'money-supply',
 'naphtha',
 'nat-gas',
 'nickel',
 'nkr',
 'nzdlr',
 'oat',
 'oilseed',
 'orange',
 'palladium',
 'palm-oil',
 'palmkernel',
 'pet-chem',
 'platinum',
 'potato',
 'propane',
 'rand',
 'rape-oil',
 'rapeseed',
 'reserves',
 'retail',
 'rice',
 'rubber',
 'rye',
 'ship',
 'silver',
 'sorghum',
 'soy-meal',
 'soy-oil',
 'soybean',
 'strategic-metal',
 'sugar',
 'sun-meal',
 'sun-oil',
 'sunseed',
 'tea',
 'tin',
 'trade',
 'veg-oil',
 'wheat',
 'wpi',
 'yen',
 'zinc']

In [6]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.svm import LinearSVC
from sklearn.multiclass import OneVsRestClassifier
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import classification_report

documents = reuters.fileids()

# Split the corpus by file-id prefix: ids starting with "train" form the
# training set, ids starting with "test" form the evaluation set.
train_docs_id = [doc_id for doc_id in documents if doc_id.startswith("train")]
test_docs_id = [doc_id for doc_id in documents if doc_id.startswith("test")]

# Raw article text for each split, in the same order as the id lists above.
train_docs = [reuters.raw(doc_id) for doc_id in train_docs_id]
test_docs = [reuters.raw(doc_id) for doc_id in test_docs_id]

# Token-count features. The vocabulary is fitted on the training text only and
# then applied unchanged to the test text, so nothing leaks from the test set.
vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform(train_docs)
X_test = vectorizer.transform(test_docs)

# Turn each document's list of categories into a binary indicator row.
# The set of label columns is fitted on the training documents.
mlb = MultiLabelBinarizer()
train_labels = mlb.fit_transform(
    [reuters.categories(doc_id) for doc_id in train_docs_id]
)
test_labels = mlb.transform(
    [reuters.categories(doc_id) for doc_id in test_docs_id]
)

# One-vs-rest: an independent linear SVM per category, which lets a single
# document receive several categories. random_state fixes the seed.
classifier = OneVsRestClassifier(
    LinearSVC(random_state=42, max_iter=100000, dual='auto')
)
classifier.fit(X_train, train_labels)

predictions = classifier.predict(X_test)

# Some rare categories never get predicted, which makes their precision
# undefined. zero_division=0 reports those scores as 0.0 explicitly instead of
# emitting an UndefinedMetricWarning for each affected average.
print(
    classification_report(
        test_labels,
        predictions,
        target_names=mlb.classes_,
        zero_division=0,
    )
)




                 precision    recall  f1-score   support

            acq       0.97      0.95      0.96       719
           alum       1.00      0.39      0.56        23
         barley       1.00      0.64      0.78        14
            bop       0.78      0.70      0.74        30
        carcass       0.92      0.67      0.77        18
     castor-oil       0.00      0.00      0.00         1
          cocoa       1.00      0.83      0.91        18
        coconut       0.00      0.00      0.00         2
    coconut-oil       0.00      0.00      0.00         3
         coffee       0.93      0.93      0.93        28
         copper       1.00      0.78      0.88        18
     copra-cake       0.00      0.00      0.00         1
           corn       0.91      0.86      0.88        56
         cotton       1.00      0.50      0.67        20
     cotton-oil       0.00      0.00      0.00         2
            cpi       0.68      0.46      0.55        28
            cpu       0.00    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
