In [22]:
import os
from random import sample

import pandas as pd
import numpy as np
from itertools import combinations
from pprint import pprint as print

from ripser import Rips
from gudhi.hera import wasserstein_distance
import gudhi

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.model_selection import cross_val_score

In [23]:
import warnings
warnings.filterwarnings("ignore")

In [24]:
def create_rips(distance_file, directory="distances", inf_death=1, max_dimension=2):
    """Compute the 0- and 1-dimensional persistence pairs of one distance matrix.

    The CSV file is read with its first column as the index, turned into a
    NumPy array and handed to ``gudhi.RipsComplex`` as a distance matrix. The
    persistence of the resulting simplex tree is then split by homology
    dimension.

    Parameters
    ----------
    distance_file : str
        File name of the CSV distance matrix inside ``directory``.
    directory : str, optional
        Folder that contains ``distance_file`` (default ``"distances"``).
    inf_death : float, optional
        Finite value substituted for an infinite death time (default ``1``).
        NOTE(review): the default presumably matches the largest possible
        entry of the distance matrices - confirm against how they are built.
    max_dimension : int, optional
        Maximal simplex dimension of the Rips complex (default ``2``).

    Returns
    -------
    list
        ``[components_matrix, loops_matrix]``: two lists of ``[birth, death]``
        pairs, for dimension 0 and dimension 1 respectively. Pairs of any
        other dimension are discarded.
    """
    path = os.path.join(directory, distance_file)
    distances = pd.read_csv(path, index_col=0).to_numpy()
    rips_complex = gudhi.RipsComplex(distance_matrix=distances)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=max_dimension)

    components_matrix = []
    loops_matrix = []
    # One shared code path for both dimensions instead of two copied loops.
    matrix_by_dimension = {0: components_matrix, 1: loops_matrix}

    for dimension, (birth, death) in simplex_tree.persistence():
        matrix = matrix_by_dimension.get(dimension)
        if matrix is None:
            # Only dimensions 0 and 1 are kept.
            continue
        # Infinite death times are replaced by a finite stand-in so that every
        # returned pair is finite.
        matrix.append([birth, inf_death if death == float("inf") else death])

    return [components_matrix, loops_matrix]

In [25]:
directory = "distances"
book_t0 = []
book_tminus = []
book_tplus = []
authors_novels = {}
novel_names = []

# List the folder once so the companion files can be looked up by name.
available_files = set(os.listdir(directory))

# os.listdir() returns entries in arbitrary order, while the three lists are
# later indexed in parallel. Each "_t0.csv" file therefore drives one entry in
# all three lists, which keeps position i pointing at the same novel
# everywhere; sorting makes the order reproducible between runs.
for filename in sorted(available_files):
    if not filename.endswith("_t0.csv"):
        continue

    stem = filename[: -len("_t0.csv")]
    tminus_file = stem + "_tminus1.csv"
    tplus_file = stem + "_tplus1.csv"
    missing = [name for name in (tminus_file, tplus_file) if name not in available_files]
    if missing:
        raise FileNotFoundError(
            "Missing companion file(s) for " + filename + ": " + ", ".join(missing)
        )

    book_tminus.append(tminus_file)
    book_t0.append(filename)
    book_tplus.append(tplus_file)

    # The first two "_"-separated fields of the file name are used as the
    # author key and the novel name.
    parts = filename.split("_")
    author, novel = parts[0], parts[1]
    authors_novels.setdefault(author, set()).add(novel)
    novel_names.append(novel)

In [26]:
rips_tminus = []
rips_t0 = []
rips_tplus = []
# Not used by the loop below; kept so the name `rips` stays defined for any
# other cell that may refer to it.
rips = Rips()

# The three file lists are consumed position by position, so they must have
# the same length and entry i of each must belong to the same novel.
if not (len(book_tminus) == len(book_t0) == len(book_tplus)):
    raise ValueError(
        "Expected the same number of tminus1/t0/tplus1 files, got "
        + str(len(book_tminus)) + "/" + str(len(book_t0)) + "/" + str(len(book_tplus))
    )

for tminus_file, t0_file, tplus_file in zip(book_tminus, book_t0, book_tplus):
    stem = t0_file[: -len("_t0.csv")]
    if tminus_file != stem + "_tminus1.csv" or tplus_file != stem + "_tplus1.csv":
        raise ValueError(
            "Distance files are not aligned: "
            + tminus_file + ", " + t0_file + ", " + tplus_file
        )

# Calculate rips diagrams for every distance matrix
for tminus_file, t0_file, tplus_file in zip(book_tminus, book_t0, book_tplus):
    rips_tminus.append(create_rips(tminus_file))
    rips_t0.append(create_rips(t0_file))
    rips_tplus.append(create_rips(tplus_file))

Rips(maxdim=1, thresh=inf, coeff=2, do_cocycles=False, n_perm = None, verbose=True)


In [27]:
def _diagram_distance(diagrams_a, diagrams_b):
    """Add up the order-1 Wasserstein distances of the paired diagrams.

    Each argument is a ``[components, loops]`` list; the components are
    compared with the components and the loops with the loops.
    """
    return sum(
        wasserstein_distance(np.array(diagram_a), np.array(diagram_b), order=1.0)
        for diagram_a, diagram_b in zip(diagrams_a, diagrams_b)
    )


# Make a distance matrix between all pairs of novels
dist_novels = pd.DataFrame(index=novel_names, columns=novel_names)
for i in range(len(novel_names)):
    for j in range(i + 1, len(novel_names)):
        tminus = _diagram_distance(rips_tminus[i], rips_tminus[j])
        t0 = _diagram_distance(rips_t0[i], rips_t0[j])
        tplus = _diagram_distance(rips_tplus[i], rips_tplus[j])

        # The three distances are combined as a sum of squares (no square
        # root is taken), and the matrix is filled symmetrically.
        combined = t0 ** 2 + tminus ** 2 + tplus ** 2
        dist_novels.at[novel_names[i], novel_names[j]] = combined
        dist_novels.at[novel_names[j], novel_names[i]] = combined

# A novel is at distance 0 from itself.
for name in novel_names:
    dist_novels.at[name, name] = 0

In [28]:
# Author keys in dictionary order; they label both axes of the (initially
# empty) author-by-author result table.
authors = list(authors_novels)
authors_acc = pd.DataFrame(columns=authors, index=authors)

In [29]:
def binary_knn_cross_validation(distance_matrix, authors, n_neighbors, n_splits,
                                n_repeats=5, metric="minkowski"):
    """Cross-validated k-NN score for every pair of authors.

    For each pair of authors an equal number of novels is taken from both
    (the author with more novels is randomly down-sampled), a k-NN classifier
    is cross-validated on the corresponding sub-matrix of ``distance_matrix``,
    and the mean score over ``n_repeats`` repetitions is stored.

    Parameters
    ----------
    distance_matrix : pandas.DataFrame
        Square frame indexed (rows and columns) by novel name.
    authors : list
        Author keys; each must be a key of the module-level ``authors_novels``
        mapping, which supplies the novels of every author.
    n_neighbors : int
        Number of neighbours of the k-NN classifier.
    n_splits : int
        Number of cross-validation folds.
    n_repeats : int, optional
        Number of independent down-sampling repetitions per pair (default 5).
    metric : str, optional
        Forwarded to ``KNeighborsClassifier``. The default ``"minkowski"``
        keeps the original behaviour, in which each row of the sub-matrix is
        used as a feature vector.
        NOTE(review): ``distance_matrix`` already holds pairwise distances;
        ``metric="precomputed"`` would classify on those directly - confirm
        which of the two is intended.

    Returns
    -------
    pandas.DataFrame
        Symmetric author-by-author table of mean cross-validation scores
        (``cross_val_score`` default scoring), with 1 on the diagonal. A new
        frame is built on every call; the module-level ``authors_acc`` is not
        modified.
    """
    accuracy = pd.DataFrame(index=authors, columns=authors)

    for i in range(len(authors) - 1):
        for j in range(i + 1, len(authors)):
            # sorted(): the novels are stored in sets, whose iteration order
            # is not stable between interpreter runs.
            novels_i = sorted(authors_novels[authors[i]])
            novels_j = sorted(authors_novels[authors[j]])
            size = min(len(novels_i), len(novels_j))

            # Scores of this pair only, so earlier pairs cannot leak into
            # the average.
            pair_scores = []
            for _ in range(n_repeats):
                # Always draw from the full list so that every repetition
                # gets a fresh sample of the larger author.
                picked_i = sample(novels_i, size) if len(novels_i) > size else novels_i
                picked_j = sample(novels_j, size) if len(novels_j) > size else novels_j
                novel_list = picked_i + picked_j

                subset = distance_matrix.loc[novel_list, novel_list].to_numpy(dtype=float)
                # 1-D label vector, one entry per selected novel.
                labels = np.array([authors[i]] * size + [authors[j]] * size)

                knn = KNeighborsClassifier(n_neighbors=n_neighbors, metric=metric)
                scores = cross_val_score(knn, subset, labels, cv=n_splits)
                pair_scores.append(np.mean(scores))

            mean_score = np.mean(pair_scores)
            accuracy.at[authors[i], authors[j]] = mean_score
            accuracy.at[authors[j], authors[i]] = mean_score

    for author in authors:
        accuracy.at[author, author] = 1

    return accuracy

In [30]:
# Run the pairwise evaluation for every combination of neighbour count (3-5)
# and fold count (2-5) and store one CSV per combination.
output_dir = "accuracy"
# to_csv() does not create missing folders, so make sure the target exists.
os.makedirs(output_dir, exist_ok=True)

for n_neighbours in range(3, 6):
    for n_folds in range(2, 6):
        acc = binary_knn_cross_validation(dist_novels, authors, n_neighbours, n_folds)
        acc.to_csv(
            os.path.join(
                output_dir,
                str(n_neighbours) + "_neighbours_" + str(n_folds) + "_folds.csv",
            )
        )