In [1]:
from src.parsers import NewsgroupsParser, ReutersParser
from src.engines.doc2vec import Doc2VecModel

NEWSGROUPS = NewsgroupsParser()
REUTERS = ReutersParser()

In [12]:
import numpy as np
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

def train(model: Doc2VecModel):
    # get the unique labels
    labels = []
    for entry in model.dataset.entries:
        for label in entry.labels:
            if label not in labels:
                labels.append(label)
    print(*enumerate(labels))

    # get train and test sets
    X = np.array([vector for vector in map(
        lambda e: model.model.dv[e.id],
        model.dataset.entries)])
    y = np.array([
        y_labs for y_labs in map(
        lambda e: np.array(np.array(
            [int(label in e.labels) for label in labels])),
        model.dataset.entries)]
    )
    # print(X)
    print(y, y.shape)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

    # train classifier
    clf = OneVsRestClassifier(SVC(kernel="poly", decision_function_shape="ovo"))
    clf.fit(X_train, y_train)

    # print report of training
    y_pred = clf.predict(X_test)
    report = classification_report(y_test, y_pred, zero_division=0)

    return y_pred, y_test, report


In [6]:
news_model = Doc2VecModel(NEWSGROUPS, use_predictor=False)
train(news_model)


[[1 0 0 ... 0 0 0]
 [1 0 0 ... 0 0 0]
 [1 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 1]
 [0 0 0 ... 0 0 1]
 [0 0 0 ... 0 0 1]] (18828, 20)
              precision    recall  f1-score   support

           0       0.87      0.33      0.48       208
           1       0.97      0.76      0.85       205
           2       0.94      0.60      0.74       207
           3       0.96      0.73      0.83       223
           4       0.86      0.47      0.61       163
           5       0.97      0.83      0.90       188
           6       0.90      0.56      0.69       141
           7       0.91      0.57      0.70       190
           8       0.93      0.55      0.69       211
           9       0.83      0.38      0.52       193
          10       0.94      0.67      0.78       176
          11       0.95      0.70      0.80       210
          12       0.97      0.76      0.85       191
          13       0.82      0.48      0.60       205
          14       0.90      0.64      0.75       188
   

In [13]:
reuters_model = Doc2VecModel(REUTERS, use_predictor=False)
y_pred, report = train(reuters_model)
print(report)


(0, 'trade') (1, 'grain') (2, 'crude') (3, 'nat-gas') (4, 'corn') (5, 'rice') (6, 'rubber') (7, 'sugar') (8, 'tin') (9, 'palm-oil') (10, 'veg-oil') (11, 'ship') (12, 'coffee') (13, 'lumber') (14, 'wheat') (15, 'gold') (16, 'acq') (17, 'interest') (18, 'money-fx') (19, 'copper') (20, 'ipi') (21, 'carcass') (22, 'livestock') (23, 'oilseed') (24, 'soybean') (25, 'earn') (26, 'bop') (27, 'gas') (28, 'lead') (29, 'jobs') (30, 'zinc') (31, 'cpi') (32, 'gnp') (33, 'soy-oil') (34, 'dlr') (35, 'yen') (36, 'nickel') (37, 'groundnut') (38, 'heat') (39, 'sorghum') (40, 'sunseed') (41, 'pet-chem') (42, 'cocoa') (43, 'rapeseed') (44, 'cotton') (45, 'money-supply') (46, 'iron-steel') (47, 'l-cattle') (48, 'alum') (49, 'palladium') (50, 'platinum') (51, 'strategic-metal') (52, 'reserves') (53, 'groundnut-oil') (54, 'lin-oil') (55, 'meal-feed') (56, 'rape-oil') (57, 'sun-meal') (58, 'sun-oil') (59, 'hog') (60, 'barley') (61, 'potato') (62, 'orange') (63, 'retail') (64, 'soy-meal') (65, 'cotton-oil') (6

In [None]:
from sklearn.metrics import confusion_matrix
confusion_matrix(y_pred)