In [163]:
# imports
import conllu
from sklearn.naive_bayes import MultinomialNB
import numpy as np
from ordered_set import OrderedSet
from tqdm import tqdm
from sklearn.metrics import classification_report

In [164]:
# load the conllu dataset
with open("../data/dataset.conllu") as f:
    data = conllu.parse(f.read())

In [165]:
# extract headlines (since a headline can have more than one sentence)
headlines = []
for i, sentence in enumerate(data):
    if sentence.metadata["sent_id"] == "0":
        headlines.append(data[i:i+1])
    else:
        headlines[-1].append(sentence)
print(len(headlines))

28619


In [None]:
# gather all features (i.e. all tokens)
features = OrderedSet()
for headline in headlines:
    for sentence in headline:
        for token in sentence:
            features |= [token["lemma"]]
print(len(features))
print(features)

21280


In [167]:
# create the Bag of Words model
bow = MultinomialNB()
classes = [0, 1]

# fit the model to the data
for headline in tqdm(headlines):
    X = np.zeros(len(features))
    y = np.array([int(headline[0].metadata["class"])])
    for sentence in headline:
        for token in sentence:
            X[features.index(token["lemma"])] += 1
    bow.partial_fit(X.reshape(1, -1), y, classes)

100%|██████████| 28619/28619 [00:26<00:00, 1085.74it/s]


In [168]:
# test some headlines
y_true = []
y_pred = []
for headline in tqdm(headlines):
    X = np.zeros(len(features))
    y = np.array([int(headline[0].metadata["class"])])
    for sentence in headline:
        for token in sentence:
            X[features.index(token["lemma"])] += 1
    y_pred.append(bow.predict(X.reshape(1, -1))[0])
    y_true.append(y[0])
target_names = ['Non-sarcastic', 'Sarcastic']
print(classification_report(y_true, y_pred, target_names=target_names))

100%|██████████| 28619/28619 [00:03<00:00, 7917.50it/s]

               precision    recall  f1-score   support

Non-sarcastic       0.91      0.93      0.92     14985
    Sarcastic       0.92      0.90      0.91     13634

     accuracy                           0.91     28619
    macro avg       0.91      0.91      0.91     28619
 weighted avg       0.91      0.91      0.91     28619




