In [None]:
pip install sentence-transformers


In [None]:
pip install pandas

In [10]:
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, multilabel_confusion_matrix


In [12]:
# Read the train and test CSVs (semicolon-separated) from the same directory;
# only the split name differs between the two file paths.
train, test = (
    pd.read_csv(f'dataset/ptc_adjust/ptc_preproc_{split}.csv', sep=";")
    for split in ("train", "test")
)


In [13]:
# Preview the first rows of the training frame to check the columns that were loaded.
train.head()

Unnamed: 0.1,Unnamed: 0,text,label,category
0,0,Stop Islamization of America.\t,Slogans,Call
1,1,We condemn all those whose behaviours and view...,Black-and-White_Fallacy,
2,2,Defeat Jihad`,Slogans,Call
3,3,the nation that gave the world the Magna Carta...,Loaded_Language,Manipulative_wording
4,4,The UK should never become a stage for inflamm...,Flag-Waving,Justification


In [14]:
# Keep only training rows that have both a text and a label; other columns may be empty.
train = train.dropna(subset=["text", "label"])


In [15]:
# Remove training rows whose text repeats an earlier row (comparison is on `text` only).
train = train.drop_duplicates(subset=["text"])


In [16]:
# Remove test rows whose text repeats an earlier row (comparison is on `text` only).
test = test.drop_duplicates(subset=["text"])


In [17]:
# Count how many training rows carry each distinct `label` value (most frequent first).
train["label"].value_counts()


label
Loaded_Language                                               1595
Name_Calling_Labeling                                          824
Doubt                                                          408
Exaggeration_Minimisation                                      349
Repetition                                                     230
Causal_Oversimplification                                      162
Appeal_to_fear-prejudice                                       160
Flag-Waving                                                    144
Slogans                                                         95
Black-and-White_Fallacy                                         91
Appeal_to_Authority                                             86
Thought-terminating_Cliches                                     57
Whataboutism                                                    52
Reductio_ad_hitlerum                                            38
Red_Herring                                             

In [18]:
# Identifier of the pretrained checkpoint that is passed to `SentenceTransformer` below.
model_name = "sentence-transformers/stsb-xlm-r-multilingual"


In [19]:
import torch


In [20]:
# Use the GPU when PyTorch reports one as available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


In [21]:
from sentence_transformers import SentenceTransformer
# Load the pretrained sentence encoder directly onto the selected device.
model = SentenceTransformer(model_name, device=device)


  from .autonotebook import tqdm as notebook_tqdm
.gitattributes: 100%|██████████| 574/574 [00:00<00:00, 288kB/s]
1_Pooling/config.json: 100%|██████████| 190/190 [00:00<00:00, 245kB/s]
README.md: 100%|██████████| 3.70k/3.70k [00:00<00:00, 7.50MB/s]
config.json: 100%|██████████| 709/709 [00:00<00:00, 1.80MB/s]
config_sentence_transformers.json: 100%|██████████| 122/122 [00:00<00:00, 578kB/s]
pytorch_model.bin: 100%|██████████| 1.11G/1.11G [04:00<00:00, 4.63MB/s]
sentence_bert_config.json: 100%|██████████| 53.0/53.0 [00:00<00:00, 23.0kB/s]
sentencepiece.bpe.model: 100%|██████████| 5.07M/5.07M [00:01<00:00, 4.88MB/s]
special_tokens_map.json: 100%|██████████| 150/150 [00:00<00:00, 47.8kB/s]
tokenizer.json: 100%|██████████| 9.10M/9.10M [00:01<00:00, 4.59MB/s]
tokenizer_config.json: 100%|██████████| 505/505 [00:00<00:00, 258kB/s]
modules.json: 100%|██████████| 229/229 [00:00<00:00, 106kB/s]


In [22]:
# Display which device was selected ('cuda' or 'cpu').
device

'cpu'

In [23]:
# `dropna()` returns a new frame and leaves the original untouched, so the result
# has to be assigned for the cleaning to take effect.
# Only `text` and `label` are checked: rows that merely lack a `category`
# (e.g. the second training row) are still usable and must not be discarded.
train = train.dropna(subset=["text", "label"])
test = test.dropna(subset=["text", "label"])
test

Unnamed: 0.1,Unnamed: 0,text,label,category
0,0,The next transmission could be more pronounced...,Appeal_to_Authority,Justification
1,1,when (the plague) comes again it starts from m...,Appeal_to_Authority,Justification
2,2,appeared,Doubt,Attack_on_reputation
3,3,"a very, very different",Repetition,Manipulative_wording
4,4,He also pointed to the presence of the pneumon...,Appeal_to_fear-prejudice,Justification
...,...,...,...,...
1375,1375,a great First Amendment victory,Exaggeration_Minimisation,Manipulative_wording
1376,1376,Trump-hating Republican,Name_Calling_Labeling,Attack_on_reputation
1377,1377,grave hardship,Loaded_Language,Manipulative_wording
1378,1378,unbelievably rude,Name_Calling_Labeling,Attack_on_reputation


In [24]:
def encode(data):
  """Embed `data` with the notebook-level sentence encoder `model`.

  `data` is forwarded unchanged to `model.encode`, and that call's result is
  returned as-is.
  """
  embeddings = model.encode(data)
  return embeddings

In [25]:
# Encode each split with a single batched call instead of one call per row:
# `encode` forwards the list of texts to `model.encode`, which processes a list
# of sentences in batches and returns one embedding row per input.
# `list(...)` splits that 2-D result into one 1-D vector per DataFrame row, so the
# `embeddings` column has the same layout as with the row-wise `.apply`.
train['embeddings'] = list(encode(train['text'].to_list()))
test['embeddings'] = list(encode(test['text'].to_list()))

train.head()

Unnamed: 0.1,Unnamed: 0,text,label,category,embeddings
0,0,Stop Islamization of America.\t,Slogans,Call,"[-0.6336295, 0.2622935, -0.058359243, -0.45848..."
1,1,We condemn all those whose behaviours and view...,Black-and-White_Fallacy,,"[0.26669678, 0.43944684, 0.4905479, -0.0066375..."
2,2,Defeat Jihad`,Slogans,Call,"[-0.07159674, 0.33379942, 0.71964896, 0.092027..."
3,3,the nation that gave the world the Magna Carta...,Loaded_Language,Manipulative_wording,"[0.22905102, 0.75942534, 0.7189353, 0.4899605,..."
4,4,The UK should never become a stage for inflamm...,Flag-Waving,Justification,"[0.20815654, 0.65153056, 1.1221485, -0.8420495..."


In [26]:
# Classifier input for training: a plain list holding each row's embedding vector.
train_features = train['embeddings'].tolist()


In [27]:
# Classifier input for evaluation: a plain list holding each row's embedding vector.
test_features = test['embeddings'].tolist()

In [28]:
import numpy as np

In [29]:
# Each `label` value is split on commas, giving a list of label names per row.
train_labels = train["label"].str.split(",").to_numpy()
test_labels = test["label"].str.split(",").to_numpy()
print(f'train labels: {len(train_labels)}')
print(f'test labels: {len(test_labels)}')

# Flatten the per-row lists of both splits into one array of label names (repeats kept).
labels_with_duplicates = np.array(
    [name for row in (*train_labels, *test_labels) for name in row]
)
# Sort the unique names so the vocabulary has the same order on every kernel start;
# the iteration order of a set of strings is not stable between interpreter runs.
# The vocabulary is wrapped in an outer list because it is later handed to
# `mlb.fit`, which expects an iterable of label collections.
labels = [sorted(set(labels_with_duplicates.tolist()))]
print(f'qty of labels: {len(labels[0])}')


train labels: 4464
test labels: 1210
qty of labels: 18


In [30]:
# Fit the binarizer once on the combined label vocabulary (`labels`) and reuse the
# fitted instance for both splits, so train and test share the same indicator columns.
# Calling `fit_transform` on each split separately would refit on that split alone.
mlb = MultiLabelBinarizer()
mlb.fit(labels)
train_labels_binarized = mlb.transform(train_labels)
test_labels_binarized = mlb.transform(test_labels)
# Width of the first indicator row = number of label columns in each split.
print(f'qty labels train: {len(train_labels_binarized[0])}')
print(f'qty labels test: {len(test_labels_binarized[0])}')


qty labels train: 18
qty labels test: 18


In [31]:
# Feed-forward classifier trained on the sentence embeddings against the
# multi-label indicator matrix. `random_state` fixes the seed; with
# `early_stopping` and `verbose` enabled, the loss and a validation score are
# printed every iteration.
ff = MLPClassifier(
    random_state=1,
    max_iter=400,
    alpha=0.001,
    shuffle=True,
    early_stopping=True,
    verbose=True,
)
ff.fit(train_features, train_labels_binarized)


Iteration 1, loss = 6.34985815
Validation score: 0.279642
Iteration 2, loss = 2.71354466
Validation score: 0.393736
Iteration 3, loss = 2.33581964
Validation score: 0.416107
Iteration 4, loss = 2.15033951
Validation score: 0.440716
Iteration 5, loss = 2.02521041
Validation score: 0.436242
Iteration 6, loss = 1.93985139
Validation score: 0.442953
Iteration 7, loss = 1.87019953
Validation score: 0.458613
Iteration 8, loss = 1.80722805
Validation score: 0.460850
Iteration 9, loss = 1.75211482
Validation score: 0.467562
Iteration 10, loss = 1.70055889
Validation score: 0.476510
Iteration 11, loss = 1.65745585
Validation score: 0.474273
Iteration 12, loss = 1.61534026
Validation score: 0.474273
Iteration 13, loss = 1.57936624
Validation score: 0.469799
Iteration 14, loss = 1.54437762
Validation score: 0.476510
Iteration 15, loss = 1.50903695
Validation score: 0.496644
Iteration 16, loss = 1.47102536
Validation score: 0.478747
Iteration 17, loss = 1.44290129
Validation score: 0.474273
Iterat

In [32]:
# Predict the indicator matrix for the test embeddings and score it against the truth.
test_predicted_labels_binarized = ff.predict(test_features)

# Micro averaging pools the decisions of all labels before computing each score.
micro_f1 = f1_score(test_labels_binarized, test_predicted_labels_binarized, average="micro")
prec = precision_score(test_labels_binarized, test_predicted_labels_binarized, average="micro")
rec = recall_score(test_labels_binarized, test_predicted_labels_binarized, average="micro")
# For multi-label indicator input, scikit-learn's `accuracy_score` is subset accuracy:
# a row counts as correct only when its full label set matches.
acc = accuracy_score(test_labels_binarized, test_predicted_labels_binarized)

for metric_name, metric_value in (
    ("micro-f1", micro_f1),
    ("accuracy", acc),
    ("micro-precision", prec),
    ("micro-recall", rec),
):
    print(f'{metric_name}: {metric_value}')


micro-f1: 0.511904761904762
accuracy: 0.4041322314049587
micro-precision: 0.6043243243243244
micro-recall: 0.44400317712470216


In [33]:
# One 2x2 confusion matrix per label column, stacked along the first axis.
cf_mtx = multilabel_confusion_matrix(
    y_true=test_labels_binarized,
    y_pred=test_predicted_labels_binarized,
)
cf_mtx.shape


(18, 2, 2)

In [34]:
# Display the stacked per-label 2x2 confusion matrices computed above.
cf_mtx


array([[[1161,    0],
        [  49,    0]],

       [[1050,   32],
        [  97,   31]],

       [[1206,    0],
        [   4,    0]],

       [[1176,    9],
        [  23,    2]],

       [[1171,    8],
        [  27,    4]],

       [[1115,   24],
        [  46,   25]],

       [[1077,   48],
        [  46,   39]],

       [[1117,   15],
        [  52,   26]],

       [[ 685,  130],
        [  85,  310]],

       [[ 956,   75],
        [  79,  100]],

       [[1204,    0],
        [   6,    0]],

       [[1197,    1],
        [  12,    0]],

       [[1197,    1],
        [  11,    1]],

       [[1082,   15],
        [ 100,   13]],

       [[1170,    6],
        [  26,    8]],

       [[1208,    0],
        [   2,    0]],

       [[1192,    2],
        [  16,    0]],

       [[1191,    0],
        [  19,    0]]])