# Model Creation - SVM
This notebook creates the model that classify the abstracts

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.feature_extraction.text import TfidfVectorizer, TfidfTransformer, CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report,accuracy_score,confusion_matrix,precision_score,recall_score,f1_score

Import data

In [2]:
data = pd.read_pickle('preprocessed_abstracts.pkl')
data.iloc[0]['abstract']


'bojan pandžić  born 13 march 1982  swedish football referee  pandžić currently resides hisings backa  part gothenburg  he full international referee fifa since 2014  he became professional referee 2004 allsvenskan referee since 2009  pandzic refereed 42 matches allsvenskan  65 matches superettan 8 international matches 2014 '

Divide into training and test data. 80% train 20% test.

In [3]:
x = data['abstract']
y = data['label']
classes = y.unique()
classes.sort()
train_x, test_x, train_y, test_y = train_test_split(x,y,test_size=0.2, random_state=1)

In [5]:
nb = Pipeline([('vect', CountVectorizer()),
               ('tfidf', TfidfTransformer()),
               ('clf', SGDClassifier(loss='hinge', penalty='l2',alpha=1e-3, random_state=42, max_iter=5, tol=None),),
              ])

In [6]:
nb.fit(train_x, train_y)

Pipeline(memory=None,
     steps=[('vect', CountVectorizer(analyzer='word', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(1, 1), preprocessor=None, stop_words=None,
        strip...ty='l2', power_t=0.5, random_state=42, shuffle=True,
       tol=None, verbose=0, warm_start=False))])

In [7]:
predictions = nb.predict(test_x)

In [8]:
print('accuracy %s' % accuracy_score(predictions, test_y))
print(classification_report(test_y, predictions))


accuracy 0.991045991045991
             precision    recall  f1-score   support

     Animal       0.98      1.00      0.99      1976
       City       1.00      1.00      1.00      1834
    Country       1.00      0.34      0.51        56
     Person       1.00      1.00      1.00      1048

avg / total       0.99      0.99      0.99      4914



In [9]:
print(confusion_matrix(test_y, predictions))

[[1975    0    0    1]
 [   0 1832    0    2]
 [  28    7   19    2]
 [   4    0    0 1044]]


Show faulty predictions

In [10]:
for idx, row in enumerate(test_x):
    if test_y.iloc[idx] != predictions[idx]:
        print('Prediction:',predictions[idx],', True:',test_y.iloc[idx])
        print('Abstract:',row)

Prediction: Animal , True: Country
Abstract: the federation bosnia herzegovina  bosnian  croatian serbian  federacija bosne hercegovine  cyrillic script  федерација босне и херцеговине  pronounced  federǎːtsija bôsneː xěrtsegoʋineː   one two political entities compose bosnia herzegovina  republika srpska  the federation bosnia herzegovina consists 10 autonomous cantons governments  it inhabited primarily bosniaks bosnian croats  sometimes informally referred bosniakcroat federation  bosnian serbs third constituency entity   it sometimes known shorter name federation b  h  federacija bih   the federation created 1994 washington agreement  ended part conflict whereby bosnian croats fought bosniaks  it established constituent assembly continued work october 1996  the federation capital  government  president  parliament  customs police departments  two postal systems airline  bh airlines   it army  army federation bosnia herzegovina  merged army republika srpska form armed forces bosnia h

Precision score per class

In [11]:
precision = precision_score(test_y, predictions, average=None)
recall = recall_score(test_y, predictions, average=None)
f1 = f1_score(test_y, predictions, average=None)

In [12]:
metrics = pd.DataFrame(np.c_[precision,recall,f1], index=classes)
metrics.columns = ['precision', 'recall', 'f1 score']
metrics.plot.bar(rot=0)
plt.show()

SyntaxError: invalid syntax (<ipython-input-12-ec8dc30afd40>, line 4)