In [1]:
%run -i "../util/lang_utils.ipynb"

In [2]:
import pandas as pd
from spacy.cli.train import train
from spacy.cli.evaluate import evaluate
from spacy.cli.debug_data import debug_data
from spacy.tokens import DocBin
from sklearn.metrics import classification_report
# Config generated at https://spacy.io/usage/training

In [3]:
def preprocess_data_entry(input_text, label, label_list):
    doc = small_model(input_text)
    cats = [0] * len(label_list)
    cats[label] = 1
    final_cats = {}
    for i, label in enumerate(label_list):
        final_cats[label] = cats[i]
    doc.cats = final_cats
    return doc

In [4]:
# Load and prepare data
train_db = DocBin()
test_db = DocBin()
label_list = ["tech", "business", "sport", "entertainment", "politics"]
train_df = pd.read_json("../data/bbc_train.json")
test_df = pd.read_json("../data/bbc_test.json")
train_df.sample(frac=1)
for idx, row in train_df.iterrows():
    text = row["text"]
    label = row["label"]
    doc = preprocess_data_entry(text, label, label_list)
    train_db.add(doc)
for idx, row in test_df.iterrows():
    text = row["text"]
    label = row["label"]
    doc = preprocess_data_entry(text, label, label_list)
    test_db.add(doc)
train_db.to_disk('../data/bbc_train.spacy')
test_db.to_disk('../data/bbc_test.spacy')

In [39]:
train("../data/spacy_config.cfg", output_path="../models/spacy_textcat_bbc")

[38;5;4mℹ Saving to output directory: ../models/spacy_textcat_bbc[0m
[38;5;4mℹ Using CPU[0m
[1m
[38;5;2m✔ Initialized pipeline[0m
[1m
[38;5;4mℹ Pipeline: ['tok2vec', 'textcat'][0m
[38;5;4mℹ Initial learn rate: 0.001[0m
E    #       LOSS TOK2VEC  LOSS TEXTCAT  CATS_SCORE  SCORE 
---  ------  ------------  ------------  ----------  ------
  0       0          0.00          0.16        8.48    0.08
  0     200         20.77         37.26       35.58    0.36
  0     400         98.56         35.96       26.90    0.27
  0     600         49.83         37.31       36.60    0.37
  0     800         96.46         27.11       38.64    0.39
  0    1000        102.35         22.53       43.11    0.43
  1    1200        101.42         23.68       61.93    0.62
  1    1400         50.70         19.16       55.75    0.56
  1    1600        224.28         15.09       46.57    0.47
  1    1800        354.78         18.86       84.20    0.84
  1    2000        131.65         16.03       64.

In [15]:
# Use the trained model
nlp = spacy.load("../models/spacy_textcat_bbc/model-last")
input_text = test_df.iloc[1, test_df.columns.get_loc('text')]
print(input_text)
print(test_df["label_text"].iloc[[1]])
doc = nlp(input_text)
print("Predicted probabilities: ", doc.cats)

lib dems  new election pr chief the lib dems have appointed a senior figure from bt to be the party s new communications chief for their next general election effort.  sandy walkington will now work with senior figures such as matthew taylor on completing the party manifesto. party chief executive lord rennard said the appointment was a  significant strengthening of the lib dem team . mr walkington said he wanted the party to be ready for any  mischief  rivals or the media tried to throw at it.   my role will be to ensure this new public profile is effectively communicated at all levels   he said.  i also know the party will be put under scrutiny in the media and from the other parties as never before - and we will need to show ourselves ready and prepared to counter the mischief and misrepresentation that all too often comes from the party s opponents.  the party is already demonstrating on every issue that it is the effective opposition.  mr walkington s new job title is director of 

In [16]:
# Evaluate the model on test data
def get_prediction(input_text, nlp_model, target_names):
    doc = nlp_model(input_text)
    category = max(doc.cats, key = doc.cats.get)
    return target_names.index(category)
test_df["prediction"] = test_df["text"].apply(lambda x: get_prediction(x, nlp, label_list))

In [17]:
print(classification_report(test_df["label"], test_df["prediction"], target_names=label_list))

               precision    recall  f1-score   support

         tech       0.82      0.94      0.87        80
     business       0.94      0.83      0.89       102
        sport       0.89      0.89      0.89       102
entertainment       0.94      0.87      0.91        77
     politics       0.78      0.83      0.80        84

     accuracy                           0.87       445
    macro avg       0.87      0.87      0.87       445
 weighted avg       0.88      0.87      0.87       445



In [3]:
evaluate('../models/spacy_textcat_bbc/model-last', '../data/bbc_test.spacy')

{'token_acc': 1.0,
 'token_p': 1.0,
 'token_r': 1.0,
 'token_f': 1.0,
 'cats_score': 0.8719339318444819,
 'cats_score_desc': 'macro F',
 'cats_micro_p': 0.8719101123595505,
 'cats_micro_r': 0.8719101123595505,
 'cats_micro_f': 0.8719101123595505,
 'cats_macro_p': 0.8746516896205309,
 'cats_macro_r': 0.8732906799083269,
 'cats_macro_f': 0.8719339318444819,
 'cats_macro_auc': 0.9800144873453936,
 'cats_f_per_type': {'tech': {'p': 0.8152173913043478,
   'r': 0.9375,
   'f': 0.872093023255814},
  'business': {'p': 0.9444444444444444,
   'r': 0.8333333333333334,
   'f': 0.8854166666666667},
  'sport': {'p': 0.8921568627450981,
   'r': 0.8921568627450981,
   'f': 0.8921568627450981},
  'entertainment': {'p': 0.9436619718309859,
   'r': 0.8701298701298701,
   'f': 0.9054054054054054},
  'politics': {'p': 0.7777777777777778,
   'r': 0.8333333333333334,
   'f': 0.8045977011494253}},
 'cats_auc_per_type': {'tech': 0.9842808219178081,
  'business': 0.9824501229063054,
  'sport': 0.993354484651003