In [1]:
import pickle

import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.pipeline import Pipeline
from skmultilearn.problem_transform import LabelPowerset

# Load training and test sets

In [2]:
# Read the pickled feature sets for the train and test splits in one pass;
# the paths are relative to the notebook's working directory.
X_train, X_test = map(
    pd.read_pickle,
    ("../../pickled_files/X_train.pkl", "../../pickled_files/X_test.pkl"),
)

In [3]:
# Load the pickled label artefacts. Each file is opened in a `with` block so
# the handle is closed as soon as the object has been deserialised (the bare
# `pickle.load(open(...))` form leaves the handles open until garbage collection).
#
# NOTE(security): unpickling can execute arbitrary code. Only load pickle files
# that were produced by this project / a trusted source.

# presumably the fitted label binarizer used to encode the targets - confirm
with open("../../pickled_files/mlb.pkl", 'rb') as mlb_file:
    mlb = pickle.load(mlb_file)

# Target labels for the training split.
with open("../../pickled_files/y_train.pkl", 'rb') as y_train_file:
    y_train = pickle.load(y_train_file)

# Target labels for the test split.
with open("../../pickled_files/y_test.pkl", 'rb') as y_test_file:
    y_test = pickle.load(y_test_file)

# Train the model

In [4]:
# Text-classification pipeline:
#   'vect'    - CountVectorizer turns the raw documents into token counts
#   'tfidf'   - TfidfTransformer re-weights those counts as TF-IDF features
#   'clf-svm' - SGDClassifier with hinge loss (a linear SVM objective) and L2
#               regularisation, wrapped in LabelPowerset so the multi-label
#               targets are handled through the wrapper rather than directly
text_clf = Pipeline([
    ('vect', CountVectorizer()),
    ('tfidf', TfidfTransformer()),
    ('clf-svm', LabelPowerset(SGDClassifier(
        loss='hinge',
        penalty='l2',
        alpha=1e-4,
        max_iter=12,
        # SGD shuffles the training data between epochs; without a fixed seed
        # the fitted model (and every metric reported below) changes from run
        # to run. Pin it so Restart & Run All reproduces the same numbers.
        random_state=42,
    ))),
])

# Fit the model (fit() trains the pipeline in place and returns the same
# object, so re-assigning the result is unnecessary)
text_clf.fit(X_train, y_train)

# Make prediction
predicted = text_clf.predict(X_test)



# Evaluate the model

In [5]:
# Per-label precision / recall / F1 / support on the test split, followed by
# the micro, macro, weighted and samples averages.
print(classification_report(y_true=y_test, y_pred=predicted))

              precision    recall  f1-score   support

           0       0.73      0.49      0.58       121
           1       0.68      0.84      0.75       289
           2       0.86      0.71      0.78       301
           3       0.82      0.59      0.69        80
           4       0.82      0.65      0.72       145

   micro avg       0.76      0.70      0.73       936
   macro avg       0.78      0.66      0.70       936
weighted avg       0.78      0.70      0.73       936
 samples avg       0.77      0.73      0.74       936



In [6]:
# Overall accuracy on the test split. For multi-label targets scikit-learn
# reports subset accuracy (a sample counts only if every label matches), which
# is why this figure is lower than the element-wise agreement computed below.
print(accuracy_score(y_true=y_test, y_pred=predicted))

0.6281208935611038


In [7]:
# Element-wise agreement between predictions and ground truth: the fraction of
# individual entries that match, averaged over the whole comparison result.
label_matches = predicted == y_test
np.mean(label_matches)

0.8733245729303548