<a href="https://colab.research.google.com/github/anihab/dnaTokenization/blob/main/statistics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [129]:
import csv
import json
import os
import scipy
import json

import pandas as pd

import plotly.express as px
import plotly.graph_objects as go

import matplotlib.pyplot as plt
from matplotlib_venn import venn2, venn3

from collections import Counter



In [2]:
# Colab-only step: mount Google Drive at /content/drive so the vocabulary
# paths defined below (all under /content/drive/MyDrive) can be opened.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [102]:
# define inputs
# Every vocabulary file lives in the same Drive folder; change DATA_DIR to
# relocate all inputs at once.
DATA_DIR = "/content/drive/MyDrive/tokenization/data"

# tokenizer vocabularies that the vocab* files are compared against
dnabert2 = DATA_DIR + "/DNABERT2.json"
bigbird = DATA_DIR + "/bigbird.json"

# vocabulary files, named by vocabulary size
vocab4096 = DATA_DIR + "/vocab4096.json"
vocab8192 = DATA_DIR + "/vocab8192.json"
vocab16384 = DATA_DIR + "/vocab16384.json"
vocab32768 = DATA_DIR + "/vocab32768.json"

## **Tokenization Statistics**

A collection of functions to analyze tokenized output and produce figures.

### **Calculate Subword Fertility**

Subword fertility is normally defined as the average number of subwords produced per tokenized word. DNA sequences have no word boundaries, so we instead calculate the average number of subwords per input sequence (each input sequence is currently 500 nt long).

In [120]:
# given a single tokenized csv file, calculate the subword fertility
# all of the tokens for a single input sequence are on a single line of the csv, separated by whitespace. the tokens for the next input sequence are on the next line.
def subword_fertility(filepath):
  """Return the average number of tokens per sequence in one tokenized CSV.

  Each non-empty CSV row is one input sequence; the tokens are the
  whitespace-separated pieces of the row's first field.

  Args:
    filepath: path to the tokenized CSV file.

  Returns:
    Total tokens divided by the number of non-empty rows, or 0 when the
    file contains no non-empty rows.
  """
  tokens_per_sequence = []

  with open(filepath, 'r') as handle:
      for record in csv.reader(handle):
          if not record:
              # csv.reader yields [] for blank lines; they are not sequences
              continue
          tokens_per_sequence.append(len(record[0].split()))

  if not tokens_per_sequence:
      return 0
  return sum(tokens_per_sequence) / len(tokens_per_sequence)

# given a directory of tokenized files, calculate the average subword fertility for all files
def subword_fertility_dir(directory):
  """Average the per-file subword fertility over every .csv in a directory.

  This is an unweighted mean of the per-file values returned by
  subword_fertility, so each file counts equally regardless of its size.

  Args:
    directory: folder containing tokenized .csv files.

  Returns:
    The mean per-file fertility, or 0 when the directory has no .csv files.
  """
  running_total = 0
  n_files = 0

  csv_names = (entry for entry in os.listdir(directory) if entry.endswith('.csv'))
  for name in csv_names:
      running_total += subword_fertility(os.path.join(directory, name))
      n_files += 1

  return running_total / n_files if n_files > 0 else 0


### **Calculate the Max, Min and Average token length from tokenized files**

Additionally produces a *histogram* of token lengths.

In [121]:
# given a single tokenized csv file, calculate and return a tuple of the max, min, and average token lengths.
def token_stats(filepath):
  """Return (max, min, average) token length for one tokenized CSV file.

  Tokens are the whitespace-separated pieces of the first field of every
  non-empty CSV row.

  Args:
    filepath: path to the tokenized CSV file.

  Returns:
    A tuple (max_len, min_len, avg_len), or (0, 0, 0) when the file holds
    no tokens.
  """
  with open(filepath, 'r') as handle:
      token_lengths = [
          len(tok)
          for record in csv.reader(handle)
          if record
          for tok in record[0].split()
      ]

  if not token_lengths:
      return 0, 0, 0

  longest = max(token_lengths)
  shortest = min(token_lengths)
  mean = sum(token_lengths) / len(token_lengths)
  return longest, shortest, mean

# given a directory of tokenized files, calculate the max, min, and average token lengths for all files.
def token_stats_dir(directory):
    """Return (max, min, average) token length over all .csv files in a directory.

    The maximum and minimum are taken across the per-file results of
    token_stats. The average is the unweighted mean of the per-file averages,
    so every file counts equally regardless of how many tokens it holds.

    Files that contain no tokens are skipped: token_stats reports them as
    (0, 0, 0), and including that placeholder would force the minimum to 0
    and pull the average down.

    Args:
        directory: folder containing tokenized .csv files.

    Returns:
        A tuple (max_len, min_len, avg_len), or (0, 0, 0) when no .csv file
        in the directory contains any token.
    """
    max_lengths = []
    min_lengths = []
    avg_lengths = []

    for filename in os.listdir(directory):
        if not filename.endswith(".csv"):
            continue
        max_len, min_len, avg_len = token_stats(os.path.join(directory, filename))
        if max_len == 0:
            # A real token has length >= 1, so a zero maximum means the file
            # had no tokens at all.
            continue
        max_lengths.append(max_len)
        min_lengths.append(min_len)
        avg_lengths.append(avg_len)

    if not max_lengths:
        return 0, 0, 0

    return max(max_lengths), min(min_lengths), sum(avg_lengths) / len(avg_lengths)

# given a directory of tokenized files, plot a histogram of the token lengths
def length_histogram(directory):
  """Show a 30-bin Plotly histogram of token lengths for a directory of CSVs.

  Token lengths are gathered from the first field of every non-empty row of
  every .csv file in the directory. When no token is found (no .csv files,
  or only empty ones) a message is printed instead of a figure.

  Args:
    directory: folder containing tokenized .csv files.

  Returns:
    None. The figure is displayed via fig.show().
  """
  token_lengths = []

  csv_paths = [
      os.path.join(directory, entry)
      for entry in os.listdir(directory)
      if entry.endswith(".csv")
  ]
  for path in csv_paths:
      with open(path, 'r') as handle:
          for record in csv.reader(handle):
              if record:
                  token_lengths.extend(map(len, record[0].split()))

  if not token_lengths:
      print("No CSV files found in the directory.")
      return

  fig = px.histogram(
      x=token_lengths,
      nbins=30,
      labels={'x': 'Token Length', 'y': 'Frequency'},
      title='Histogram of Token Lengths',
  )
  fig.update_layout(showlegend=False)
  fig.show()

### **Calculate the Max, Min and Average token length from corpus**

Additionally produces a *histogram* of token lengths.

In [132]:
# given a json file, extract the vocabulary
def get_vocab(filepath):
  """Extract the list of tokens from a tokenizer JSON file.

  The vocabulary is read from data["model"]["vocab"]. Two layouts are
  accepted:
    * a mapping of token -> id, whose keys are returned in file order;
    * a list whose entries are either token strings or sequences such as
      [token, score], in which case the first item of each entry is used.

  Args:
    filepath: path to the tokenizer JSON file.

  Returns:
    A list of token strings.

  Raises:
    KeyError: if the file has no "model" or "vocab" entry.
  """
  # JSON is UTF-8 by specification; do not depend on the platform default.
  with open(filepath, 'r', encoding='utf-8') as file:
    data = json.load(file)

  vocab = data["model"]["vocab"]
  if isinstance(vocab, dict):
    return list(vocab.keys())

  # List-style vocabulary: plain tokens or [token, score] pairs.
  return [entry if isinstance(entry, str) else entry[0] for entry in vocab]

# given a json vocabulary file, calculate the max, min, and average word/token lengths
# additionally, produce a histogram of the token lengths
def corpus_stats(filepath):
  """Return (max, min, average) entry length for a JSON vocabulary file.

  Also displays a Plotly histogram of the entry lengths and writes it as
  "<filepath without trailing .json>_histogram.html" next to the input.

  Args:
    filepath: path to the tokenizer JSON file (read through get_vocab).

  Returns:
    A tuple (max_len, min_len, avg_len), or (0, 0, 0) when the vocabulary
    is empty. No figure is shown or written in the empty case.
  """
  vocab = get_vocab(filepath)
  lengths = [len(token) for token in vocab] if vocab else []

  if not lengths:
    return 0, 0, 0

  max_len = max(lengths)
  min_len = min(lengths)
  avg_len = sum(lengths) / len(lengths)

  fig = px.histogram(x=lengths, labels={'x': 'Word Length', 'y': 'Frequency'},
                     title='Histogram of Word Lengths')
  fig.update_layout(showlegend=False)
  fig.show()

  # Strip only a trailing ".json"; str.replace would also remove the same
  # substring from directory names or the middle of the file name.
  suffix = '.json'
  name = filepath[:-len(suffix)] if filepath.endswith(suffix) else filepath
  fig.write_html(name + "_histogram.html")

  return max_len, min_len, avg_len

### **Calculate Coverage Metrics**

Returns the number of unused words in a tokenizer's vocabulary and additionally produces a *histogram*, *pie chart*, and *scatter plot* of used word frequencies.

In [123]:
# given a json vocabulary file and a directory of tokenized csv files, calculate the number of unused words
# additionally produce figures to display used word frequencies
def coverage_stats(jsonfile, directory):
  """Count vocabulary entries never used in a directory of tokenized CSVs.

  Tokens are collected from the first field of every non-empty row of every
  .csv file in the directory. Three Plotly figures of the used-token
  frequencies are displayed: a bar chart, a pie chart, and a scatter plot of
  token length against frequency.

  Args:
    jsonfile: path to the tokenizer JSON file (read through get_vocab).
    directory: folder containing tokenized .csv files.

  Returns:
    "json file error" when the vocabulary is empty;
    "Unused words: <n>" when at least one token was found;
    0 when the directory yielded no tokens.
  """
  vocab = get_vocab(jsonfile)

  if not vocab:
    return "json file error"

  # Count occurrences while reading; the Counter also gives constant-time
  # membership tests for the unused-token check below.
  counts = Counter()
  for filename in os.listdir(directory):
    if not filename.endswith(".csv"):
      continue
    with open(os.path.join(directory, filename), 'r') as file:
      for row in csv.reader(file):
        if row:
          counts.update(row[0].split())

  if not counts:
    return 0

  unused = [token for token in vocab if token not in counts]

  tokens, frequencies = zip(*counts.items())

  # Histogram
  fig_hist = px.bar(x=tokens, y=frequencies, labels={'x': 'Words', 'y': 'Frequency'},
                    title='Histogram of Word Frequencies')
  fig_hist.show()

  # Pie Chart
  # `labels` in Plotly Express is a dict of axis/legend renames, not the slice
  # names, so only `names` is passed here.
  fig_pie = px.pie(names=tokens, values=frequencies, title='Pie Chart of Word Frequencies',
                   hole=0.3)
  fig_pie.show()

  # Scatter Plot
  df = pd.DataFrame({'token': tokens,
                     'length': list(map(len, tokens)),
                     'frequency': frequencies})

  # hover_name shows the token itself when a point is hovered.
  fig_scatter = px.scatter(data_frame=df, x='length', y='frequency', hover_name='token',
                           title='Scatter Plot of Used Word Length Frequencies')
  fig_scatter.show()

  return "Unused words: " + str(len(unused))

### **Vocabulary Comparison**

Produces a Venn diagram comparing the vocabularies of two tokenizers.

In [131]:
def venn2_to_plotly(sets,labels,title):
    """Draw a two-set Venn diagram as a Plotly figure, show it, and save it.

    The layout (circle centers, radii, label positions and texts) is computed
    by matplotlib_venn's venn2 and then re-created with Plotly shapes and
    annotations. The figure is displayed and written to "<title>.html" in the
    current working directory.

    Args:
        sets: the sets to compare, passed unchanged to venn2.
        labels: the set labels, passed unchanged to venn2.
        title: figure title; also used as the output file name stem.

    Returns:
        None.
    """
    n_sets = len(sets)
    v = venn2(sets, labels)

    # suppress the matplotlib rendering of the Venn diagram; only its
    # geometry is needed
    plt.close()

    # shapes and annotations that will be attached to the plotly layout
    shapes = []
    annotation = []

    # fill/line color per set
    color = ['FireBrick','DodgerBlue','DimGrey']

    # bounding-box extremes of each circle, used to size the axes
    L_x_max = []
    L_y_max = []
    L_x_min = []
    L_y_min = []

    for i in range(0,n_sets):

        # bounding box of the circle for the current set
        x0 = v.centers[i][0] - v.radii[i]
        y0 = v.centers[i][1] - v.radii[i]
        x1 = v.centers[i][0] + v.radii[i]
        y1 = v.centers[i][1] + v.radii[i]

        shape = go.layout.Shape(
                type="circle",
                xref="x",
                yref="y",
                x0=x0,
                y0=y0,
                x1=x1,
                y1=y1,
                fillcolor=color[i],
                line_color=color[i],
                opacity = 0.75
            )

        shapes.append(shape)

        # set label for the current set (skipped if venn2 produced none)
        set_label = v.set_labels[i]
        if set_label is not None:
            anno_set_label = go.layout.Annotation(
                    xref="x",
                    yref="y",
                    x = set_label.get_position()[0],
                    y = set_label.get_position()[1],
                    text = set_label.get_text(),
                    showarrow=False
            )

            annotation.append(anno_set_label)

        L_x_max.append(x1)
        L_x_min.append(x0)
        L_y_max.append(y1)
        L_y_min.append(y0)

    # One label per region of the diagram. Iterating the labels directly
    # avoids computing the region count with scipy.special, which a bare
    # `import scipy` is not guaranteed to load.
    for subset_label in v.subset_labels:

        # NOTE(review): matplotlib_venn appears to use None for a region that
        # is empty (e.g. disjoint sets) - confirm; such regions get no label.
        if subset_label is None:
            continue

        # subset label (number of elements in the current region)
        anno_subset_label = go.layout.Annotation(
                xref="x",
                yref="y",
                x = subset_label.get_position()[0],
                y = subset_label.get_position()[1],
                text = subset_label.get_text(),
                showarrow=False
        )

        annotation.append(anno_subset_label)

    # padding around the circles for the figure range
    off_set = 0.2

    # axis ranges covering every circle plus padding
    x_max = max(L_x_max) + off_set
    x_min = min(L_x_min) - off_set
    y_max = max(L_y_max) + off_set
    y_min = min(L_y_min) - off_set

    # create plotly figure
    p_fig = go.Figure()

    # set xaxes range and hide ticks and ticklabels
    p_fig.update_xaxes(
        range=[x_min, x_max],
        showticklabels=False,
        ticklen=0
    )

    # set yaxes range and hide ticks and ticklabels; anchoring y to x with
    # ratio 1 keeps the circles round
    p_fig.update_yaxes(
        range=[y_min, y_max],
        scaleanchor="x",
        scaleratio=1,
        showticklabels=False,
        ticklen=0
    )

    # set figure properties and add shapes and annotations
    p_fig.update_layout(
        plot_bgcolor='white',
        margin = dict(b = 0, l = 10, pad = 0, r = 10, t = 40),
        width=800,
        height=400,
        shapes= shapes,
        annotations = annotation,
        title = dict(text = title, x=0.5, xanchor = 'center')
    )

    p_fig.show()
    p_fig.write_html(title + ".html")

In [125]:
def vocab_comparison(filepath1, filepath2, name1, name2):
  """Plot the overlap between two tokenizer vocabularies as a Venn diagram.

  Both vocabularies are read with get_vocab and handed to venn2_to_plotly,
  which shows the figure and writes it to an HTML file named after the title.

  Args:
    filepath1: JSON file of the first tokenizer.
    filepath2: JSON file of the second tokenizer.
    name1: display name of the first tokenizer.
    name2: display name of the second tokenizer.

  Returns:
    None after plotting, or "json file error" when either vocabulary is empty.
  """
  first_vocab = get_vocab(filepath1)
  second_vocab = get_vocab(filepath2)

  if not (first_vocab and second_vocab):
    return "json file error"

  title = "Vocabulary Overlap Between " + name1 + " and " + name2 + " Trained Tokenizers"
  venn2_to_plotly([set(first_vocab), set(second_vocab)], [name1, name2], title)

## **Figures**


### **Vocabulary Size 4096**

In [138]:
# Length histogram for the 4096-entry vocabulary (the returned max/min/avg
# tuple is not captured), then its overlap with DNABERT2 and BigBird.
corpus_stats(vocab4096)
vocab_comparison(vocab4096, dnabert2, "Bacteria 4096", "DNABERT2")
vocab_comparison(vocab4096, bigbird, "Bacteria 4096", "BigBird")

### **Vocabulary Size 8192**

In [139]:
# Length histogram for the 8192-entry vocabulary (the returned max/min/avg
# tuple is not captured), then its overlap with DNABERT2 and BigBird.
corpus_stats(vocab8192)
vocab_comparison(vocab8192, dnabert2, "Bacteria 8192", "DNABERT2")
vocab_comparison(vocab8192, bigbird, "Bacteria 8192", "BigBird")

### **Vocabulary Size 16384**

In [140]:
# Length histogram for the 16384-entry vocabulary (the returned max/min/avg
# tuple is not captured), then its overlap with DNABERT2 and BigBird.
corpus_stats(vocab16384)
vocab_comparison(vocab16384, dnabert2, "Bacteria 16384", "DNABERT2")
vocab_comparison(vocab16384, bigbird, "Bacteria 16384", "BigBird")

### **Vocabulary Size 32768**

In [141]:
# Length histogram for the 32768-entry vocabulary (the returned max/min/avg
# tuple is not captured), then its overlap with DNABERT2 and BigBird.
corpus_stats(vocab32768)
vocab_comparison(vocab32768, dnabert2, "Bacteria 32768", "DNABERT2")
vocab_comparison(vocab32768, bigbird, "Bacteria 32768", "BigBird")

In [137]:
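# Tokenized statistics (disabled): `tokenized_directory` and `my_vocabulary` are not defined in this notebook; set them before uncommenting the lines below.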
# print("subword fertility: " + str(subword_fertility_dir(tokenized_directory)))
# print("max, min, and avg token lengths: " + str(token_stats_dir(tokenized_directory)))
# length_histogram(tokenized_directory)
# coverage_stats(my_vocabulary, tokenized_directory)