# TDA classification of Russian texts: literature vs. mGPT / GPT2 / balaboba corpora

In [None]:
import pickle

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score, roc_auc_score
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.svm import SVC
from tqdm import tqdm

from TopologicalFeaturesConstructor import TopologicalFeaturesConstructor

In [None]:
# Build topological features for every corpus with each embedding type and
# save them to disk, one .npy file per (embedding type, corpus).
# This cell is slow (over two hours in total, see the progress bars below);
# the classification cell only reads the saved .npy files.
for embeddings_type, dict_size in zip(('SVD', 'CBoW'), ('1000', '8')):
    tda_constructor = TopologicalFeaturesConstructor(
        dict_path=f'data/Russian/Russian_dict_{embeddings_type}_{dict_size}.npy',
        voids_embeddings_path=f'data/voids/voids_{embeddings_type}.npy',
        embeddings_len=8
    )

    for corpus_name in ('literature', 'mGPT', 'GPT2', 'balaboba'):
        # One text per line. Read with an explicit encoding instead of the
        # platform default (NOTE(review): UTF-8 is assumed for these Russian
        # corpora - confirm), and close the file handle when done.
        # NOTE(review): a trailing newline in the file yields an empty last
        # "text"; kept as-is because the class balance downstream was built
        # on these exact row counts.
        with open(
            f'classification/corpora/Russian_{corpus_name}_corpus.txt',
            encoding='utf-8'
        ) as corpus_file:
            corpus = corpus_file.read().split('\n')

        # One feature row per text (whitespace-tokenised), stacked into a matrix.
        topological_features = np.vstack([
            tda_constructor.construct_topological_features(text.split()) for text in tqdm(corpus)
        ])

        np.save(
            f'classification/topological_features/{embeddings_type}/{corpus_name}.npy',
            topological_features
        )

100%|██████████| 3905/3905 [33:14<00:00,  1.96it/s]
100%|██████████| 433/433 [03:25<00:00,  2.10it/s]
100%|██████████| 2629/2629 [20:23<00:00,  2.15it/s]
100%|██████████| 843/843 [06:52<00:00,  2.04it/s]
100%|██████████| 3905/3905 [38:57<00:00,  1.67it/s]
100%|██████████| 433/433 [03:54<00:00,  1.85it/s]
100%|██████████| 2629/2629 [20:59<00:00,  2.09it/s]
100%|██████████| 843/843 [07:00<00:00,  2.00it/s]


In [None]:
# 5-fold cross-validation of two linear classifiers on the saved topological
# features. Label 0 = 'literature' corpus; label 1 = every other corpus
# (mGPT, GPT2, balaboba).
scoring = ['accuracy', 'precision', 'recall', 'f1', 'roc_auc']
corpus_names = ('literature', 'mGPT', 'GPT2', 'balaboba')
negative_corpora = {'literature'}  # all remaining corpora form the positive class

# Rows are stacked grouped by class and by source corpus. An unshuffled
# stratified split (what an integer `cv` gives for classifiers) therefore
# produces folds with very different corpus mixes; e.g. the last fold's
# positives come almost entirely from 'balaboba'. Shuffle once with a fixed
# seed so the folds are comparable and the run stays reproducible.
cv_splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

scores = dict()
for embeddings_type in ('SVD', 'CBoW'):
    # NOTE(review): allow_pickle kept from the original; if the saved arrays
    # are plain numeric it can be dropped. Only enable it for trusted files.
    per_corpus_features = [
        np.load(
            f'classification/topological_features/{embeddings_type}/{corpus_name}.npy',
            allow_pickle=True
        ) for corpus_name in corpus_names
    ]
    topological_features = np.vstack(per_corpus_features)

    # Label each row by the corpus it came from, instead of assuming the
    # stacked matrix is exactly half negatives followed by half positives.
    targets = np.hstack([
        np.full(len(features), 0 if corpus_name in negative_corpora else 1)
        for corpus_name, features in zip(corpus_names, per_corpus_features)
    ])
    assert len(targets) == len(topological_features), 'labels/features length mismatch'

    models = (LogisticRegression(max_iter=10000), SVC(kernel='linear', random_state=42))

    for model, model_type in zip(models, ('logit', 'svm')):
        curr_scores = cross_validate(
            model, topological_features, targets, cv=cv_splitter, scoring=scoring
        )
        # Keep only the per-fold test metrics.
        curr_scores.pop('fit_time')
        curr_scores.pop('score_time')
        scores[f'{embeddings_type} {model_type}'] = curr_scores

scores

{'SVD logit': {'test_accuracy': array([0.81690141, 0.95454545, 0.98783611, 0.98271447, 0.7893726 ]),
  'test_precision': array([0.79151943, 0.92362768, 0.99097938, 0.99867725, 1.        ]),
  'test_recall': array([0.86043534, 0.99103713, 0.98463508, 0.96670935, 0.5787452 ]),
  'test_f1': array([0.82453988, 0.95614577, 0.98779705, 0.98243331, 0.73317113]),
  'test_roc_auc': array([0.89357844, 0.99466359, 0.99817037, 0.99829497, 0.99856876])},
 'SVD svm': {'test_accuracy': array([0.82586428, 0.95902689, 0.98975672, 0.9865557 , 0.81370038]),
  'test_precision': array([0.84161074, 0.9272944 , 0.98602287, 0.9973822 , 1.        ]),
  'test_recall': array([0.8028169 , 0.99615877, 0.99359795, 0.97567222, 0.62740077]),
  'test_f1': array([0.82175623, 0.96049383, 0.98979592, 0.98640777, 0.77104642]),
  'test_roc_auc': array([0.88841582, 0.99614893, 0.99841137, 0.99913273, 0.99910978])},
 'CBoW logit': {'test_accuracy': array([0.9865557 , 0.99295775, 0.9993598 , 0.99743918, 0.99231754]),
  'test_

In [None]:
# Print, for every (embedding type, model) setup, the mean of each metric over
# the cross-validation folds, rounded to 3 decimals.
for setup_name, setup_scores in scores.items():
    print(setup_name)
    for metric_name, fold_values in setup_scores.items():
        # Keys look like 'test_<metric>'; drop the 5-character 'test_' prefix.
        short_name = metric_name[5:]
        mean_value = round(fold_values.mean(), 3)
        print(f'    {short_name}: {mean_value}')
    print()

SVD logit
    accuracy: 0.906
    precision: 0.941
    recall: 0.876
    f1: 0.897
    roc_auc: 0.977

SVD svm
    accuracy: 0.915
    precision: 0.95
    recall: 0.879
    f1: 0.906
    roc_auc: 0.976

CBoW logit
    accuracy: 0.994
    precision: 0.995
    recall: 0.993
    f1: 0.994
    roc_auc: 0.998

CBoW svm
    accuracy: 0.993
    precision: 0.995
    recall: 0.991
    f1: 0.993
    roc_auc: 0.998

