In [None]:
# Automagically reimport haikulib if it changes: `%autoreload 2` reloads modules
# before each cell runs, so edits to the library are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# Render inline matplotlib figures as SVG (vector output).
%config InlineBackend.figure_format = 'svg'
%matplotlib inline

import collections
import itertools

import grakel
import matplotlib.pyplot as plt
import numpy as np
import nltk
import networkx as nx
import pandas as pd
import seaborn as sns

from haikulib import data, nlp, utils

# Output directory for this experiment's artifacts; created if missing.
data_dir = data.get_data_dir() / "experiments" / "similarity"
data_dir.mkdir(parents=True, exist_ok=True)

# Apply the seaborn theme FIRST: it rewrites matplotlib rcParams, so any custom
# rcParams overrides must come after it to be sure they are not clobbered.
sns.set()
# 16:9 aspect ratio, scaled to 60% of a 16x9-inch figure.
plt.rcParams["figure.figsize"] = (16 * 0.6, 9 * 0.6)

# Also emit LaTeX (longtable) representations of DataFrames.
# NOTE(review): the `display.latex.*` options may not exist in newer pandas
# releases; confirm against the pinned pandas version.
pd.set_option("display.latex.repr", True)
pd.set_option("display.latex.longtable", True)

In [None]:
def get_generated_df():
    """Load the haiku produced by the ``knesser-ney-ngram`` generation experiment.

    Returns
    -------
    pandas.DataFrame
        Contents of ``generated.csv``, with the CSV's first column used as the index.

    Raises
    ------
    FileNotFoundError
        If the CSV has not been generated yet (see the TODO below).
    """
    # TODO: Actually generate this CSV file.
    # NOTE(review): "knesser" looks like a misspelling of "Kneser", but the path must
    # match whatever the generator writes, so it is intentionally left unchanged.
    csv_path = (
        data.get_data_dir()
        / "experiments"
        / "generation"
        / "knesser-ney-ngram"
        / "generated.csv"
    )
    return pd.read_csv(csv_path, index_col=0)

In [None]:
# Load both inputs before the expensive step, so that a missing generated CSV
# (see the TODO in get_generated_df) fails immediately rather than after the
# whole corpus has been lemmatized.
corpus = data.get_df()
generated = get_generated_df()

# Lemmatization is exceedingly slow.
corpus["lemma"] = list(nlp.lemmatize(corpus["haiku"]))
generated["lemma"] = list(nlp.lemmatize(generated["haiku"]))
generated.head()

In [None]:
def haiku2graph(haiku):
    """Build the word-adjacency graph of a haiku as a ``grakel.Graph``.

    The text is split with ``nltk.word_tokenize``. Each token pair yielded by
    ``utils.pairwise`` (presumably consecutive tokens) becomes an edge whose
    weight is the number of times that pair occurs. Every token is labelled
    with itself, so label-based kernels compare the actual words.
    """
    words = nltk.word_tokenize(haiku)
    edge_weights = collections.Counter(utils.pairwise(words))
    labels = dict(zip(words, words))
    return grakel.Graph(edge_weights, node_labels=labels)

In [None]:
%%time
gen_graphs = list(map(haiku2graph, generated["lemma"]))
corpus_graphs = list(map(haiku2graph, corpus["lemma"]))

# gen_graphs = [grakel.Graph(g.edges.data(), node_labels=g.nodes) for g in nx_gen_graphs]
# corpus_graphs = [grakel.Graph(g.edges.data(), node_labels=g.nodes) for g in nx_corpus_graphs]

In [None]:
# Weisfeiler-Lehman graph kernel: 2 relabelling iterations over a sparse
# vertex-histogram base kernel. normalize=True rescales each kernel value by the
# self-similarities of the two graphs involved.
# NOTE(review): newer grakel releases name this keyword `base_graph_kernel`;
# confirm `base_kernel` matches the pinned grakel version.
graph_kernel = grakel.kernels.WeisfeilerLehman(
    n_iter=2,
    normalize=True,
    base_kernel=(grakel.kernels.VertexHistogram, {"sparse": True}),
)

In [None]:
%%time
for query_graph, query in zip(gen_graphs, generated["haiku"]):
    graph_kernel.fit([query_graph])
    kernel = graph_kernel.transform(corpus_graphs)
    
    # number of similar haiku to find
    n = 3
    indices = np.argsort(kernel[:, 0])[-n:]
    similar = corpus.iloc[indices]
    print("query:", query)
    for sim in similar["haiku"]:
        print("\tsimilar:", sim)