In [42]:
import pandas as pd
from sentence_transformers import SentenceTransformer

from sklearn.metrics.pairwise import cosine_similarity # Expects 2D arrays as input
from scipy.spatial.distance import cosine # Works with 1D vectors

from sklearn.metrics import classification_report

In [6]:
# Instantiate the sentence-embedding model; every `smodel.encode(...)` call below uses it.
smodel = SentenceTransformer('distilbert-base-nli-mean-tokens')

In [2]:
# Read the training CSV (path is relative to the working directory);
# later cells use its 'Text' and 'Category' columns.
df1 = pd.read_csv('bbc_news_train.csv')

In [8]:
# Preview the first rows of the loaded frame.
df1.head()

Unnamed: 0,ArticleId,Text,Category
0,1833,worldcom ex-boss launches defence lawyers defe...,business
1,154,german business confidence slides german busin...,business
2,1101,bbc poll indicates economic gloom citizens in ...,business
3,1976,lifestyle governs mobile choice faster bett...,tech
4,917,enron bosses in $168m payout eighteen former e...,business


In [7]:
def get_sentence_vector(query):
    """Return the embedding that `smodel` produces for a single text.

    The text is wrapped in a one-element list for `smodel.encode`, and the
    first (only) entry of the result is returned.
    """
    return smodel.encode([query])[0]

In [10]:
%%time
# Embed every article text: one get_sentence_vector call per row of 'Text'.
df1['textVec'] = df1['Text'].apply(get_sentence_vector)

In [23]:
# Preview the frame again to check the new 'textVec' column.
df1.head()

Unnamed: 0,ArticleId,Text,Category,textVec
0,1833,worldcom ex-boss launches defence lawyers defe...,business,"[0.0763296, -0.21301335, 0.1430163, -0.6761660..."
1,154,german business confidence slides german busin...,business,"[-0.6852783, -0.17200232, 0.65315604, -1.12945..."
2,1101,bbc poll indicates economic gloom citizens in ...,business,"[-0.3528403, -1.0351946, 0.756114, -0.704674, ..."
3,1976,lifestyle governs mobile choice faster bett...,technology,"[-0.40892917, 0.105635695, 0.47965625, -0.8354..."
4,917,enron bosses in $168m payout eighteen former e...,business,"[0.37398607, -0.81744885, 0.3845015, -0.735170..."


In [26]:
def std_category(x):
    """Map the short category labels to their long spellings.

    'tech' becomes 'technology' and 'sport' becomes 'sports'; any other
    value is returned unchanged.
    """
    for short, full in (('tech', 'technology'), ('sport', 'sports')):
        if x == short:
            return full
    return x

In [27]:
# Overwrite 'Category' with the normalised spellings returned by std_category.
df1['Category'] = df1['Category'].apply(std_category)

In [28]:
# Label text -> embedding, so each label is encoded once instead of once per row.
# Re-run this cell after changing `smodel` so stale embeddings are discarded.
_label_vec_cache = {}


def get_cosine_sim(x, Y):
    """Return the cosine *distance* between a vector and a label's embedding.

    Despite the name, the value comes from `scipy.spatial.distance.cosine`,
    which is 1 - cosine similarity: a smaller value means a closer match.

    Parameters
    ----------
    x : 1-D vector
        Embedding to compare (e.g. one entry of the 'textVec' column).
    Y : str
        Label text. It is embedded with `smodel.encode` on first use and the
        result is cached for later calls with the same label.

    Returns
    -------
    The cosine distance between `x` and the embedding of `Y`.
    """
    if Y not in _label_vec_cache:
        _label_vec_cache[Y] = smodel.encode([Y])[0]
    return cosine(x, _label_vec_cache[Y])

In [24]:
# Distance from each article embedding to the 'business' label.
df1['cdist_business'] = df1['textVec'].apply(get_cosine_sim, args=('business',))

In [29]:
# Distinct category labels present in the frame.
df1['Category'].unique()

array(['business', 'technology', 'politics', 'sports', 'entertainment'],
      dtype=object)

In [30]:
# Distance from each article embedding to the 'technology' label.
df1['cdist_technology'] = df1['textVec'].apply(get_cosine_sim, args=('technology',))

In [31]:
# Distance from each article embedding to the 'politics' label.
df1['cdist_politics'] = df1['textVec'].apply(get_cosine_sim, args=('politics',))

In [32]:
# Distance from each article embedding to the 'sports' label.
df1['cdist_sports'] = df1['textVec'].apply(get_cosine_sim, args=('sports',))

In [33]:
# Distance from each article embedding to the 'entertainment' label.
df1['cdist_entertainment'] = df1['textVec'].apply(get_cosine_sim, args=('entertainment',))

In [38]:
def get_prediction(in_row, dist_cols=('cdist_business', 'cdist_technology', 'cdist_politics',
                                      'cdist_sports', 'cdist_entertainment')):
    """Return the label whose distance column holds the smallest value in a row.

    Parameters
    ----------
    in_row : mapping-like (e.g. a DataFrame row)
        Must support ``in_row[col]`` for every name in `dist_cols`.
    dist_cols : sequence of str, optional
        Distance columns to compare, each named ``<prefix>_<label>``.
        Defaults to the five ``cdist_*`` columns built in this notebook.

    Returns
    -------
    str
        The ``<label>`` part of the winning column name (everything after the
        first underscore). Ties go to the column listed first. An empty string
        is returned when no value compares below infinity (e.g. all NaN).
    """
    best_label = ""
    best_dist = float('inf')  # true upper bound instead of an arbitrary large number
    for col in dist_cols:
        dist = in_row[col]
        if dist < best_dist:
            best_dist = dist
            # maxsplit=1 keeps labels that themselves contain underscores intact
            best_label = col.split('_', 1)[1]
    return best_label

In [39]:
# Row-wise prediction: get_prediction receives each row of the frame.
df1['prediction'] = df1.apply(get_prediction, axis=1)

In [41]:
# Preview the distance columns next to the resulting 'prediction'.
df1.head()

Unnamed: 0,ArticleId,Text,Category,textVec,cdist_business,cdist_technology,cdist_politics,cdist_sports,cdist_entertainment,prediction
0,1833,worldcom ex-boss launches defence lawyers defe...,business,"[0.0763296, -0.21301335, 0.1430163, -0.6761660...",1.054001,1.186766,0.968711,1.15906,1.213088,politics
1,154,german business confidence slides german busin...,business,"[-0.6852783, -0.17200232, 0.65315604, -1.12945...",1.038498,1.18686,1.000978,1.192701,1.182215,politics
2,1101,bbc poll indicates economic gloom citizens in ...,business,"[-0.3528403, -1.0351946, 0.756114, -0.704674, ...",1.109749,1.203442,0.975895,1.153227,1.170546,politics
3,1976,lifestyle governs mobile choice faster bett...,technology,"[-0.40892917, 0.105635695, 0.47965625, -0.8354...",0.954146,1.072351,0.903048,1.059741,1.00419,politics
4,917,enron bosses in $168m payout eighteen former e...,business,"[0.37398607, -0.81744885, 0.3845015, -0.735170...",1.064061,1.26145,1.027481,1.2025,1.216041,politics


In [47]:
# Display names for the report rows, listed in alphabetical order.
target_names = ['business', 'entertainment', 'politics', 'sports', 'technology']
print(
    classification_report(
        y_true=df1['Category'],
        y_pred=df1['prediction'],
        target_names=target_names,
    )
)

               precision    recall  f1-score   support

     business       0.60      0.20      0.30       336
entertainment       0.32      0.08      0.12       273
     politics       0.22      1.00      0.36       274
       sports       0.79      0.12      0.21       346
   technology       0.85      0.04      0.08       261

     accuracy                           0.28      1490
    macro avg       0.55      0.29      0.22      1490
 weighted avg       0.57      0.28      0.22      1490



In [45]:
from collections import Counter

In [46]:
# Count how many rows carry each category label.
Counter(df1['Category'])

Counter({'business': 336,
         'technology': 261,
         'politics': 274,
         'sports': 346,
         'entertainment': 273})