In [None]:
from collections import Counter
import embeddings
import faiss
import json
import glob
import numpy as np
import random

from kl_divergence import get_article_dicts, get_kl_divs, get_word_2_ind, B
from process_words import *
from process_sentences import *
from similarity_search import *
import utils

In [None]:
# Sample num_samples (non-unique) words from the word_2_count dict and return them as a list
def get_sample(word_2_count, num_samples):
  """Draw a random sample of words, weighted by how often each word occurs.

  Args:
    word_2_count: Mapping from word to the number of times it occurs.
    num_samples: Maximum number of words to return.

  Returns:
    A list of at most num_samples words. A word can appear several times,
    but never more often than its count. If the counts add up to fewer than
    num_samples, every occurrence is returned (in shuffled order).
  """
  # Expand the counts into one entry per occurrence, then shuffle and cut.
  pool = [word for word, count in word_2_count.items() for _ in range(count)]
  random.shuffle(pool)
  return pool[:num_samples]

In [None]:
# For a given source_dir, sample num_samples words from each file in the directory
# and write the result to a same-named file in out_dir.
# out_dir is assumed to exist
def sample_from_text(num_samples, source_dir, out_dir):
  """Write a fixed-size random word sample for every book in a directory.

  Args:
    num_samples: Number of words to sample from each book.
    source_dir: Directory whose entries are read, one book per file.
    out_dir: Existing directory that receives one output file per book,
      named after the source file. Each output file holds the sampled
      words joined by single spaces.

  Books whose word counts (as returned by get_words) add up to fewer than
  num_samples are reported with a "MISS" line and skipped.
  """
  embedder = embeddings.FastTextEmbedding()
  stopwords = utils.make_stopwords_list() + ["'", "’", "”", "(", ")", "‘"]
  # Each file is processed exactly once. An earlier outer loop over reading
  # levels never used its loop variable and only repeated the work.
  for path in glob.glob(f'{source_dir}/*'):
    # Get the name of the book from the filename
    book = path.split("/")[-1]
    print(book)
    # Read all words from the book
    _, _, word_2_count, _ = get_words(path, embedder, keep_misses=False, stopwords=stopwords)
    total_words = sum(word_2_count.values())
    # If the book is too short, ignore it
    if total_words < num_samples:
      print("MISS", book, total_words)
      continue
    # Otherwise, sample and write to output dir
    text = " ".join(get_sample(word_2_count, num_samples))
    with open(f'{out_dir}/{book}', 'w') as out_file:
      out_file.write(text)

In [None]:
# Reads data from dataset and hands it to get_kl_divs to calculate the pairwise KL divergences
def calculate_pairwise_divergences(dataset):
    """Run the pairwise divergence computation for one dataset.

    Args:
        dataset: Dataset name passed to both get_article_dicts and
            get_kl_divs, e.g. "books_sample_40k".

    Returns nothing: the value returned by get_kl_divs is discarded.
    NOTE(review): presumably get_kl_divs writes its results itself; confirm
    against kl_divergence.py.
    """
    epsilon = 10E-5
    # Two alpha values sitting just below and just above 1.
    alphas = [1-epsilon, 1+epsilon]
    # The second returned dict (texts) is not needed here.
    authors, _texts, counts, embs, probs = get_article_dicts(dataset)
    get_kl_divs(dataset, probs, authors, embs, counts, alphas=alphas, gpu=False)

In [None]:
# Build from individual files to single json
# out_file should be a json
def combine_results(source_dir, out_file):
  """Merge the per-work divergence files of a directory into one JSON file.

  Each "<source_dir>/<left>.json" is expected to map the name of another
  work ("right") to a dict holding one value per variant ("original", "3",
  "5", "10", "25", "50", "100").

  Args:
    source_dir: Directory containing the per-work JSON files.
    out_file: Path of the combined JSON file to write.

  The output has a "pairs" list of [left, right] names plus one list per
  variant; entry i of every variant list belongs to pair i. Files are read
  in sorted order so the output is reproducible.

  Raises:
    KeyError: If an entry lacks one of the variants.
  """
  variants = ["original", "3", "5", "10", "25", "50", "100"]

  # (left, right) -> values in the order of `variants`. Keyed by pair so a
  # pair seen twice keeps only its last values.
  values_by_pair = {}
  for path in sorted(glob.glob(f"{source_dir}/*.json")):
    # Use a separate handle name: reusing `out_file` here would clobber the
    # output path needed below.
    with open(path, 'r') as in_file:
      per_work = json.load(in_file)
    left = path.split("/")[-1].split(".")[0]
    for right, divs in per_work.items():
      values_by_pair[(left, right)] = [divs[variant] for variant in variants]

  data = {"pairs": list(values_by_pair.keys())}
  for idx, variant in enumerate(variants):
    data[variant] = [row[idx] for row in values_by_pair.values()]

  with open(out_file, 'w') as out_handle:
    json.dump(data, out_handle, indent=2)


In [None]:
# For a given file of divergences from one work to the other works, predict that
# work's author, reading level and genre and write the predictions to out_file
def get_predictions(source_file, out_file, works, categories_file='authorship_categories.json'):
  """Predict category labels for one work from its divergences to other works.

  For each comparison ("author", "level", "genre") and each divergence
  variant, the divergences found in source_file are averaged per category.
  The category with the smallest average is the prediction, provided that
  the category labels at least two works of the dataset.

  Args:
    source_file: JSON file mapping the name of each other work to a dict
      holding one value per variant ("original", "3", "5", "10", "25",
      "50", "100").
    out_file: Path of the JSON file to write.
    works: Names of the works that make up the dataset. Entries of
      source_file that are not listed here are ignored.
    categories_file: JSON file with "author", "reading_level" and "genre"
      sections, each mapping a label to the list of works carrying it.

  The output has the form
  {"predictions": {comparison: {variant: label, ..., "nearest": {variant: [[label, gap], ...]}}}}
  where "nearest" lists the non-chosen labels sorted by gap. The gap is the
  distance of the label's average from the best average, or -1 for labels
  that apply to fewer than two works. An empty string is stored when no
  label qualifies as a prediction.

  Raises:
    KeyError: If a work in works, or a used entry of source_file, has no
      label in the categories file, or an entry lacks a variant.
  """
  variants = ["original", "3", "5", "10", "25", "50", "100"]
  # Comparison name used in the output -> section name in the categories file.
  comp_2_section = {"author": "author", "level": "reading_level", "genre": "genre"}
  works = list(works)
  work_set = set(works)

  # Get the "true" label for each book
  with open(categories_file, 'r') as f:
    categories = json.load(f)

  # Invert label -> works into work -> label, once per comparison.
  work_2_label = {
    comp: {work: label for label, members in categories[section].items() for work in members}
    for comp, section in comp_2_section.items()
  }

  # Load results from this file and initialize data structure
  with open(source_file, 'r') as f:
    data = json.load(f)
  new = {
    comp: {label: {val: [] for val in variants} for label in categories[section]}
    for comp, section in comp_2_section.items()
  }

  # Gather the results for every other work
  for right, divs in data.items():
    if right in comp_2_section or right not in work_set:
      continue
    for comp in comp_2_section:
      bucket = new[comp][work_2_label[comp][right]]
      for val in variants:
        bucket[val].append(divs[val])

  # Take the average. Categories without any sample get NaN explicitly, so
  # numpy is never asked for the mean of an empty list.
  for comp in comp_2_section:
    for bucket in new[comp].values():
      for val in variants:
        samples = bucket[val]
        bucket[f"{val}_avg"] = float(np.mean(samples)) if samples else float("nan")

  # Get the number of times each label appears
  counts = {
    comp: Counter(work_2_label[comp][w] for w in works)
    for comp in comp_2_section
  }

  predictions = {}
  for comp in comp_2_section:
    predictions[comp] = {val: "" for val in variants}
    predictions[comp]["nearest"] = {val: [] for val in variants}
    for val in variants:
      avg_key = f"{val}_avg"
      # Start from infinity so that arbitrarily large divergences still compete.
      min_val = float("inf")
      min_category = ""
      for cat, bucket in new[comp].items():
        # This is the new label if 1) the divergence is less than the previously recorded divergence
        # and 2) the label applies to at least two works in the dataset.
        # A NaN average never compares as smaller, so empty categories are skipped.
        if bucket[avg_key] < min_val and counts[comp][cat] > 1:
          min_val = bucket[avg_key]
          min_category = cat
      predictions[comp][val] = min_category

      # Order the rest of the results by how far off from optimal they are. This helps with debugging
      nearest_list = []
      for cat, bucket in new[comp].items():
        if cat == min_category or np.isnan(bucket[avg_key]):
          continue
        if counts[comp][cat] > 1:
          nearest_list.append((cat, bucket[avg_key] - min_val))
        else:
          nearest_list.append((cat, -1))
      nearest_list.sort(key=lambda x: x[1])
      predictions[comp]["nearest"][val] = nearest_list

  # Dump results to file
  final = {"predictions": predictions}
  with open(out_file, "w") as f:
    json.dump(final, f, indent=2)

In [None]:
# Get accuracy stats
def show_accuracy(source_dir, categories_file='authorship_categories.json'):
    """Print how many prediction files agree with the true labels.

    Every "<source_dir>/<book>.json" is read and its "predictions" entry is
    compared with the true author, reading level and genre of <book>, for
    each divergence variant.

    Args:
        source_dir: Directory containing one prediction JSON file per book.
        categories_file: JSON file with "author", "reading_level" and
            "genre" sections, each mapping a label to the list of works
            carrying it.

    Prints the dict of correct-prediction counts (comparison -> variant ->
    count), followed by the number of files examined. Returns nothing.

    Raises:
        KeyError: If a book has no label in the categories file.
    """
    variants = ["original", "3", "5", "10", "25", "50", "100"]
    # Comparison name used in the prediction files -> section in the categories file.
    comp_2_section = {"author": "author", "level": "reading_level", "genre": "genre"}

    # Load "true" results
    with open(categories_file, 'r') as f:
        categories = json.load(f)

    # Invert label -> works into work -> label, once per comparison.
    work_2_label = {
        comp: {work: label for label, members in categories[section].items() for work in members}
        for comp, section in comp_2_section.items()
    }

    scores = {comp: {val: 0 for val in variants} for comp in comp_2_section}

    # Collect list of books
    files = glob.glob(f"{source_dir}/*.json")
    number_files = len(files)

    for path in files:
        with open(path, 'r') as in_file:
            data = json.load(in_file)["predictions"]
        book = path.split("/")[-1].split(".")[0]
        for comp in comp_2_section:
            truth = work_2_label[comp][book]
            for val in variants:
                if data[comp][val] == truth:
                    scores[comp][val] += 1

    print(scores)
    print(number_files)

In [None]:
# An example usage
# The output directories ./data/books_sample_32k and ./kl_predictions_32k must already exist.
sample_from_text(32000, './data/books', './data/books_sample_32k')
calculate_pairwise_divergences('books_sample_32k')
combine_results("./kl_results/books_sample_32k", "./kl_combined_32k.json")

# One divergence file per book; the file names double as the list of works in the dataset.
files = glob.glob("./kl_results/books_sample_32k/*.json")
works = [x.split("/")[-1].split(".")[0] for x in files]
for f in files:
  book = f.split("/")[-1].split(".")[0]
  get_predictions(f, f"./kl_predictions_32k/{book}.json", works)

# Score the prediction files written by the loop above.
show_accuracy("./kl_predictions_32k")