<a href="https://colab.research.google.com/github/anihab/tokenization/blob/main/statistics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import csv
import json
import os
import matplotlib.pyplot as plt
from matplotlib_venn import venn2
from collections import Counter

## **Tokenization Statistics**

A collection of functions and figures to analyze tokenized output.

### **Calculate Subword Fertility**

This is normally defined as the average number of subwords produced per tokenized word. Since we don't have tokenized words, we will instead calculate the average number of subwords per input sequence (which is currently 500 nt).

In [None]:
def subword_fertility(filepath):
    """Return the average number of tokens per sequence in one tokenized CSV file.

    Every non-empty CSV row is treated as one input sequence. Only the first
    column of the row is read, and it is split on whitespace to obtain the
    tokens of that sequence.

    Args:
        filepath: path of the tokenized CSV file.

    Returns:
        Total token count divided by the number of non-empty rows, or 0 when
        the file has no non-empty rows.
    """
    with open(filepath, 'r') as handle:
        tokens_per_sequence = [
            len(record[0].split()) for record in csv.reader(handle) if record
        ]

    # Guard against dividing by zero for a file without any rows.
    if not tokens_per_sequence:
        return 0
    return sum(tokens_per_sequence) / len(tokens_per_sequence)

def subword_fertility_dir(directory):
    """Return the mean subword fertility over all CSV files in a directory.

    Each file whose name ends in '.csv' is scored with subword_fertility and
    the per-file scores are averaged with equal weight per file.

    Args:
        directory: path of the directory holding the tokenized CSV files.

    Returns:
        The unweighted mean of the per-file fertilities, or 0 when the
        directory contains no '.csv' files.
    """
    per_file = [
        subword_fertility(os.path.join(directory, name))
        for name in os.listdir(directory)
        if name.endswith('.csv')
    ]

    if not per_file:
        return 0
    return sum(per_file) / len(per_file)


### **Calculate the Max, Min and Average token length from tokenized files**

Additionally produce a *histogram* of token lengths.

In [None]:
def token_stats(filepath):
    """Return the max, min, and average token length found in one tokenized CSV file.

    Tokens are taken from the first column of every non-empty row, split on
    whitespace; the length of a token is its number of characters.

    Args:
        filepath: path of the tokenized CSV file.

    Returns:
        A tuple (max_len, min_len, avg_len), or (0, 0, 0) when the file
        yields no tokens.
    """
    token_lengths = []  # one entry per token, across every row of the file

    with open(filepath, 'r') as handle:
        for record in csv.reader(handle):
            if not record:
                continue
            token_lengths += [len(tok) for tok in record[0].split()]

    if not token_lengths:
        return 0, 0, 0

    longest = max(token_lengths)
    shortest = min(token_lengths)
    mean = sum(token_lengths) / len(token_lengths)
    return longest, shortest, mean

def token_stats_dir(directory):
    """Return the max, min, and average token length over all CSV files in a directory.

    Each file whose name ends in '.csv' is summarised with token_stats. The
    overall maximum and minimum are taken across the per-file values, and the
    average is the unweighted mean of the per-file averages.

    Files that contain no tokens are ignored: token_stats reports them as
    (0, 0, 0), and including that placeholder would force the overall minimum
    to 0 and pull the average down even though no zero-length token exists.

    Args:
        directory: path of the directory holding the tokenized CSV files.

    Returns:
        A tuple (max_len, min_len, avg_len), or (0, 0, 0) when no '.csv' file
        in the directory contains any token.
    """
    max_lengths = []
    min_lengths = []
    avg_lengths = []

    for filename in os.listdir(directory):
        if not filename.endswith(".csv"):
            continue

        filepath = os.path.join(directory, filename)
        max_len, min_len, avg_len = token_stats(filepath)

        # A real token has at least one character, so a maximum of 0 can only
        # be the "no tokens" placeholder; skip it instead of treating it as data.
        if max_len == 0:
            continue

        max_lengths.append(max_len)
        min_lengths.append(min_len)
        avg_lengths.append(avg_len)

    if not max_lengths:
        return 0, 0, 0

    return max(max_lengths), min(min_lengths), sum(avg_lengths) / len(avg_lengths)

def length_histogram(directory):
    """Plot a histogram of token lengths gathered from all CSV files in a directory.

    Tokens are read from the first column of every non-empty row of each file
    whose name ends in '.csv', split on whitespace. When no token is found a
    message is printed instead of drawing a plot.

    Args:
        directory: path of the directory holding the tokenized CSV files.

    Returns:
        None.
    """
    token_lengths = []
    csv_names = [name for name in os.listdir(directory) if name.endswith(".csv")]

    for name in csv_names:
        with open(os.path.join(directory, name), 'r') as handle:
            for record in csv.reader(handle):
                if record:
                    token_lengths.extend(map(len, record[0].split()))

    if not token_lengths:
        print("No CSV files found in the directory.")
        return

    shortest, longest = min(token_lengths), max(token_lengths)
    plt.hist(token_lengths, range=(shortest, longest), color='skyblue', edgecolor='black')
    plt.xlabel('Token Length')
    plt.ylabel('Frequency')
    plt.title('Histogram of Token Lengths')
    plt.show()

### **Calculate the Max, Min and Average token length from corpus**

Additionally produce a *histogram* of token lengths.

In [None]:
def get_vocab(filepath):
    """Return the vocabulary tokens stored in a tokenizer JSON file.

    The file is parsed as JSON and the mapping found under
    data["model"]["vocab"] is used as the vocabulary.

    Args:
        filepath: path of the JSON file.

    Returns:
        The keys view of the vocabulary mapping (one key per token).
    """
    with open(filepath, 'r') as handle:
        parsed = json.load(handle)

    return parsed["model"]["vocab"].keys()

def corpus_stats(filepath):
    """Return the max, min, and average token length of a JSON vocabulary and plot them.

    The vocabulary is loaded with get_vocab. When it is non-empty, a histogram
    of the token lengths is shown before the statistics are returned.

    Args:
        filepath: path of the JSON vocabulary file.

    Returns:
        A tuple (max_len, min_len, avg_len), or (0, 0, 0) when the vocabulary
        is empty.
    """
    vocab = get_vocab(filepath)
    token_lengths = [len(entry) for entry in vocab] if vocab else []

    if not token_lengths:
        return 0, 0, 0

    longest = max(token_lengths)
    shortest = min(token_lengths)
    mean = sum(token_lengths) / len(vocab)

    plt.hist(token_lengths, range=(shortest, longest), color='skyblue', edgecolor='black')
    plt.xlabel('Token Length')
    plt.ylabel('Frequency')
    plt.title('Histogram of Token Lengths')
    plt.show()

    return longest, shortest, mean

### **Calculate Coverage Metrics**

Returns the number of vocabulary words that never appear in the tokenized output, and additionally produces a *bar chart* and a *pie chart* of used-word frequencies.

In [None]:
def coverage_stats(jsonfile, directory):
    """Report how many vocabulary tokens never occur in the tokenized CSV files.

    The vocabulary is loaded with get_vocab. Tokens are read from the first
    column of every non-empty row of each file in the directory whose name
    ends in '.csv', split on whitespace. A bar chart and a pie chart of the
    used-token frequencies are shown as a side effect.

    Args:
        jsonfile: path of the JSON vocabulary file.
        directory: path of the directory holding the tokenized CSV files.

    Returns:
        "json file error" when the vocabulary is empty; the integer 0 when no
        token is found in the directory; otherwise the string
        "Unused words: N", where N is the number of vocabulary tokens absent
        from the tokenized files.
    """
    vocab = get_vocab(jsonfile)

    if not vocab:
        return "json file error"

    # Count occurrences while reading instead of storing every token in a
    # list: membership tests against the Counter are hashed lookups, whereas
    # scanning a list of all occurrences for each vocabulary entry is quadratic.
    usage = Counter()
    for filename in os.listdir(directory):
        if not filename.endswith(".csv"):
            continue
        filepath = os.path.join(directory, filename)
        with open(filepath, 'r') as file:
            for row in csv.reader(file):
                if row:
                    usage.update(row[0].split())

    if not usage:
        return 0

    unused = [token for token in vocab if token not in usage]

    # Counter keeps first-occurrence order, so the charts list tokens in the
    # order they were first seen in the files.
    tokens, frequencies = zip(*usage.items())
    plt.bar(tokens, frequencies)
    plt.xlabel('Words')
    plt.ylabel('Frequency')
    plt.title('Histogram of Word Frequencies')
    plt.show()

    plt.pie(frequencies, labels=tokens, autopct='%1.1f%%')
    plt.axis('equal')
    plt.title('Pie Chart of Word Frequencies')
    plt.show()
    return "Unused words: " + str(len(unused))

### **Vocabulary Comparison**

Produce a Venn diagram to compare the vocabularies of two tokenizers.

In [None]:
def vocab_comparison(filepath1, filepath2, name1, name2):
    """Show a two-set Venn diagram comparing the vocabularies of two JSON files.

    Args:
        filepath1: path of the first JSON vocabulary file.
        filepath2: path of the second JSON vocabulary file.
        name1: label passed to venn2 for the first vocabulary.
        name2: label passed to venn2 for the second vocabulary.

    Returns:
        None after the diagram is shown, or the string "json file error"
        when either vocabulary is empty.
    """
    first = get_vocab(filepath1)
    second = get_vocab(filepath2)

    # Both vocabularies must be non-empty for the comparison to be drawn.
    if not (first and second):
        return "json file error"

    venn2([set(first), set(second)], (name1, name2))
    plt.title('Vocabulary Venn Diagram')
    plt.show()

### **Main**

In [None]:
# Run every analysis against the tokenized CSV files and tokenizer JSON files in /content.
print("subword fertility: " + str(subword_fertility_dir("/content")))
print("max, min, and avg token lengths: " + str(token_stats_dir("/content")))
length_histogram("/content")
# corpus_stats and coverage_stats return their results rather than printing them.
# A notebook cell only echoes the value of its last expression, so print them
# explicitly; otherwise the statistics and the unused-word count are lost.
print("max, min, and avg vocabulary token lengths: " + str(corpus_stats("/content/tokenizer.json")))
print(coverage_stats("/content/dna_tokenizer.json", "/content"))
vocab_comparison("/content/dna_tokenizer.json", "/content/tokenizer.json", "Bacteria Vocab", "DNABERT2 Vocab")