## Setup

In [1]:
import io
import re
import string
import tqdm

import numpy as np

import tensorflow as tf
from tensorflow.keras import layers

In [2]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [3]:
SEED = 42
AUTOTUNE = tf.data.AUTOTUNE

### Vectorize an example sentence

## Compile all steps into one function


### Skip-gram sampling table

A large dataset means a larger vocabulary and a higher number of occurrences of frequent words such as stopwords. Training examples obtained from sampling commonly occurring words (such as `the`, `is`, `on`) don't add much useful information for the model to learn from. [Mikolov et al.](https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf) suggest subsampling frequent words as a helpful practice to improve embedding quality.
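As a rough illustration of the subsampling heuristic described in the paper, the sketch below computes the probability of discarding a word from its relative frequency, using `1 - sqrt(t / f)`. The function name `discard_probability`, the threshold `t = 1e-5`, and the example frequencies are assumptions chosen for illustration; this is not the Keras implementation.

```python
import math

def discard_probability(frequency, threshold=1e-5):
  """Probability of dropping a word with the given relative frequency."""
  # Words at or below the threshold frequency are always kept.
  if frequency <= threshold:
    return 0.0
  return 1.0 - math.sqrt(threshold / frequency)

# Hypothetical relative frequencies, for illustration only.
for word, freq in [("the", 0.05), ("model", 0.001), ("zipf", 0.00001)]:
  print(f"{word}: {discard_probability(freq):.3f}")
```

The more frequent the word, the more often it is discarded, which is why stopwords contribute fewer training pairs after subsampling.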

The `tf.keras.preprocessing.sequence.skipgrams` function accepts a `sampling_table` argument that encodes the probability of sampling each token. You can use `tf.keras.preprocessing.sequence.make_sampling_table` to generate a probabilistic sampling table based on word-frequency rank and pass it to the `skipgrams` function. Inspect the sampling probabilities for a `vocab_size` of 10.

In [None]:
sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(size=10)
print(sampling_table)

[0.00315225 0.00315225 0.00547597 0.00741556 0.00912817 0.01068435
 0.01212381 0.01347162 0.01474487 0.0159558 ]


`sampling_table[i]` denotes the probability of sampling the i-th most common word in a dataset. The function assumes that the word frequencies follow a [Zipf distribution](https://en.wikipedia.org/wiki/Zipf%27s_law).
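To make the Zipf assumption concrete, here is a pure-Python sketch that approximates each rank's inverse frequency and converts it into a sampling probability. The formula, the constant `gamma`, and the default `sampling_factor=1e-5` are assumptions about how such a table can be built; the helper name `make_sampling_table_sketch` is hypothetical. Under these assumptions the results agree with the values printed above.

```python
import math

def make_sampling_table_sketch(size, sampling_factor=1e-5):
  """Rank-based sampling probabilities under a Zipf-like approximation."""
  gamma = 0.577  # Approximation of the Euler-Mascheroni constant.
  table = []
  for i in range(size):
    rank = max(i, 1)  # Index 0 is treated like rank 1.
    # Approximate inverse relative frequency of the word at this rank.
    inv_freq = rank * (math.log(rank) + gamma) + 0.5 - 1.0 / (12.0 * rank)
    f = sampling_factor * inv_freq
    table.append(min(1.0, f / math.sqrt(f)))
  return table

print([round(p, 8) for p in make_sampling_table_sketch(10)])
```

Rare words (high rank) receive a higher sampling probability than common ones, which is the subsampling effect described earlier.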

Key point: `tf.random.log_uniform_candidate_sampler` already assumes that the vocabulary frequency follows a log-uniform (Zipf's) distribution. Using this distribution-weighted sampling also helps approximate the Noise Contrastive Estimation (NCE) loss with simpler loss functions for training a negative sampling objective.
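A small sketch of what a log-uniform distribution over class ids looks like may help here. It assumes the probability of class `c` is `(log(c + 2) - log(c + 1)) / log(range_max + 1)`, so that low ids (frequent words) are sampled more often; the helper name `log_uniform_prob` is hypothetical.

```python
import math

def log_uniform_prob(class_id, range_max):
  """Probability of drawing `class_id` from a log-uniform distribution."""
  return (math.log(class_id + 2) - math.log(class_id + 1)) / math.log(range_max + 1)

range_max = 4096
probs = [log_uniform_prob(c, range_max) for c in range(range_max)]

# The probabilities decay with the class id and sum to 1.
print(probs[0], probs[1], probs[100])
print(sum(probs))
```

This only works as intended when token ids are ordered by decreasing frequency, which is how a frequency-sorted vocabulary is laid out.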

### Generate training data

Compile all the steps described above into a function that can be called on a list of vectorized sentences obtained from any text dataset. Notice that the sampling table is built before sampling skip-gram word pairs. You will use this function in the later sections.

In [None]:
# Generates skip-gram pairs with negative sampling for a list of sequences
# (int-encoded sentences) based on window size, number of negative samples
# and vocabulary size.
def generate_training_data(sequences, window_size, num_ns, vocab_size, seed):
  # Elements of each training example are appended to these lists.
  targets, contexts, labels = [], [], []

  # Build the sampling table for `vocab_size` tokens.
  sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(vocab_size)

  # Iterate over all sequences (sentences) in the dataset.
  for sequence in tqdm.tqdm(sequences):

    # Generate positive skip-gram pairs for a sequence (sentence).
    positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
          sequence,
          vocabulary_size=vocab_size,
          sampling_table=sampling_table,
          window_size=window_size,
          negative_samples=0)

    # Iterate over each positive skip-gram pair to produce training examples
    # with a positive context word and negative samples.
    for target_word, context_word in positive_skip_grams:
      context_class = tf.expand_dims(
          tf.constant([context_word], dtype="int64"), 1)
      negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
          true_classes=context_class,
          num_true=1,
          num_sampled=num_ns,
          unique=True,
          range_max=vocab_size,
          seed=seed,
          name="negative_sampling")

      # Build context and label vectors (for one target word)
      context = tf.concat([tf.squeeze(context_class,1), negative_sampling_candidates], 0)
      label = tf.constant([1] + [0]*num_ns, dtype="int64")

      # Append each element from the training example to global lists.
      targets.append(target_word)
      contexts.append(context)
      labels.append(label)

  return targets, contexts, labels
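
To show the shape of the data this function produces without running TensorFlow, the following simplified sketch builds positive pairs from a window and then attaches negative samples. It is deterministic, skips subsampling and shuffling, and draws negatives uniformly instead of log-uniformly; the helper names `positive_pairs` and `build_example` are hypothetical, and `vocab_size` is assumed to be larger than `num_ns + 1`.

```python
import random

def positive_pairs(sequence, window_size):
  """(target, context) pairs within `window_size`, ignoring padding id 0."""
  pairs = []
  for i, target in enumerate(sequence):
    if target == 0:
      continue
    start = max(0, i - window_size)
    end = min(len(sequence), i + window_size + 1)
    for j in range(start, end):
      if j != i and sequence[j] != 0:
        pairs.append((target, sequence[j]))
  return pairs

def build_example(target, context_word, num_ns, vocab_size, rng):
  """One training example: 1 positive context followed by `num_ns` negatives."""
  negatives = []
  while len(negatives) < num_ns:
    candidate = rng.randrange(vocab_size)
    if candidate != context_word and candidate not in negatives:
      negatives.append(candidate)
  context = [context_word] + negatives
  label = [1] + [0] * num_ns
  return target, context, label

pairs = positive_pairs([1, 2, 3, 0], window_size=1)
print(pairs)
print(build_example(2, 3, num_ns=4, vocab_size=10, rng=random.Random(0)))
```

Each example carries `num_ns + 1` context ids and labels, which is why `num_ns=4` leads to arrays with a second dimension of 5.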

## Prepare training data for word2vec

With an understanding of how to work with one sentence for a skip-gram negative sampling based word2vec model, you can proceed to generate training examples from a larger list of sentences!

### Download text corpus


This notebook uses a collection of PDF documents stored as a zip archive on Google Drive and organized into topic folders (`AI`, `Stats`, `NLP`, `CV`). Change the paths in the following cells to run this code on your own data.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import zipfile

# Unzip the archive into the current working directory. The `with` block
# closes the archive even if extraction fails.
local_zip = '/content/drive/MyDrive/Colab Notebooks/Text.zip'
with zipfile.ZipFile(local_zip, 'r') as zip_ref:
  zip_ref.extractall()

In [None]:
import os

base_dir = 'Text'

print("Contents of base directory:")
print(os.listdir(base_dir))

print("\nContents of AI directory:")
print(os.listdir(f'{base_dir}/AI'))

print("\nContents of NLP directory:")
print(os.listdir(f'{base_dir}/NLP'))

Contents of base directory:
['AI', 'Stats', 'NLP', 'CV']

Contents of AI directory:
['Artificial intelligence and education in China.pdf', '10.2478_rem-2020-0003.pdf', 's00146-020-01033-8.pdf', 'sustainability-13-07941.pdf', 'article_222831.pdf', 's10639-022-11316-w.pdf', 's43681-021-00096-7.pdf', 'Computer Assisted Learning - 2022 - Pham - The development of artificial intelligence in education  A review in context.pdf', 'sustainability-14-01101.pdf', '978-981-19-0351-9_6-2.pdf', '0346202012112271.pdf', '25_1_03.pdf', 'Artificial intelligence in education challenges and opportunities for sustainable development.pdf', 'fpsyg-11-580820.pdf', 's40593-016-0095-y.pdf', 'Artificial_Intelligence_in_Education_A_Review.pdf', 'Becker-AI-in-Education-with-cover-sheet.pdf', 'Artificial intelligence for education  Knowledge and its assessment in AI-enabled learning ecologies.pdf', 'document.pdf', 's40593-016-0110-3.pdf', 'd37ca3f650e9f72613189003a8c49eddb75b.pdf', '1984-5745-1-PB.pdf', 'informatio

In [None]:
main_dir_files = os.listdir(base_dir)
main_dir_files

['AI', 'Stats', 'NLP', 'CV']

In [None]:
!pip install PyPDF2



In [None]:
import PyPDF2

# Extract the text of every PDF in each topic folder and save it as a
# numbered .txt file next to the PDFs.
for i in main_dir_files:
  sub_dir = os.path.join(base_dir, i)
  k = 1
  for j in os.listdir(sub_dir):
    # Skip anything that is not a PDF (for example, .txt files written by an
    # earlier run of this cell).
    if not j.lower().endswith('.pdf'):
      continue

    # Open the PDF in binary mode; the `with` block closes it afterwards.
    with open(os.path.join(sub_dir, j), 'rb') as pdf_file:
      reader = PyPDF2.PdfReader(pdf_file)

      # Concatenate the text of all pages, guarding against pages that
      # yield no text.
      text = ''
      for page in reader.pages:
        text += page.extract_text() or ''

    # Write with mode "w" so that re-running the cell does not append
    # duplicate text to an existing file.
    with open(os.path.join(sub_dir, str(k) + '.txt'), 'w') as out_file:
      out_file.write(text)
    k += 1



In [None]:
train_ds = tf.keras.utils.text_dataset_from_directory(
  base_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  batch_size=10)

Found 117 files belonging to 4 classes.
Using 94 files for training.


In [None]:
val_ds = tf.keras.utils.text_dataset_from_directory(
  base_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,
  batch_size=10)

Found 117 files belonging to 4 classes.
Using 23 files for validation.


In [None]:
text_ds = tf.keras.utils.text_dataset_from_directory(
  base_dir,
  seed=123)

Found 117 files belonging to 4 classes.


In [None]:
# Number of documents in the first batch of `text_ds` (at most the default
# batch size of 32).
doc_len = len(list(text_ds.as_numpy_iterator())[0][0])

### Vectorize sentences from the corpus

You can use the `TextVectorization` layer to vectorize sentences from the corpus. Learn more about using this layer in this [Text classification](https://www.tensorflow.org/tutorials/keras/text_classification) tutorial. The text needs to be in one case and punctuation needs to be removed. To do this, define a `custom_standardization` function that can be used in the `TextVectorization` layer.

In [None]:
# Now, create a custom standardization function to lowercase the text and
# remove punctuation.
def custom_standardization(input_data):
  lowercase = tf.strings.lower(input_data)
  return tf.strings.regex_replace(lowercase,
                                  '[%s]' % re.escape(string.punctuation), '')


# Define the vocabulary size and the number of words in a sequence.
vocab_size = 4096
sequence_length = 10

# Use the `TextVectorization` layer to normalize, split, and map strings to
# integers. Set the `output_sequence_length` length to pad all samples to the
# same length.
vectorize_layer = layers.TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size,
    output_mode='int',
    output_sequence_length=sequence_length)
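
For intuition, the standardize, split, map, and pad steps can be mimicked in plain Python. This is a sketch under stated assumptions: the vocabulary below is hypothetical, index 0 is reserved for padding and index 1 for out-of-vocabulary tokens, and the helper names `standardize` and `vectorize` are illustrative and not part of the Keras API.

```python
import re
import string

def standardize(text):
  """Lowercase the text and strip ASCII punctuation."""
  lowercase = text.lower()
  return re.sub('[%s]' % re.escape(string.punctuation), '', lowercase)

def vectorize(text, vocab, sequence_length):
  """Map whitespace tokens to ids, then truncate or pad to a fixed length."""
  index = {token: i for i, token in enumerate(vocab)}
  ids = [index.get(token, 1) for token in standardize(text).split()]
  ids = ids[:sequence_length]
  return ids + [0] * (sequence_length - len(ids))

# Hypothetical vocabulary: padding, out-of-vocabulary marker, then real tokens.
vocab = ["", "[UNK]", "the", "cat"]
print(standardize("Hello, World!"))
print(vectorize("The cat, the DOG!", vocab, sequence_length=6))
```

Because every sample is padded or truncated to `sequence_length`, short lines contain many zeros, and those padding ids are skipped when skip-gram pairs are generated.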

Call `TextVectorization.adapt` on the text dataset to create the vocabulary.
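
Conceptually, adapting the layer amounts to counting tokens and keeping the most frequent ones. The sketch below illustrates that idea only; it assumes two slots of `max_tokens` are reserved for the padding and out-of-vocabulary entries, and its tie-breaking between equally frequent tokens may differ from the layer's. The helper name `build_vocabulary` is hypothetical.

```python
from collections import Counter

def build_vocabulary(lines, max_tokens):
  """Frequency-sorted vocabulary with reserved padding and OOV entries."""
  counts = Counter(token for line in lines for token in line.lower().split())
  # Two slots are reserved: "" for padding and "[UNK]" for unknown tokens.
  most_common = [token for token, _ in counts.most_common(max_tokens - 2)]
  return ["", "[UNK]"] + most_common

print(build_vocabulary(["the cat", "the dog the"], max_tokens=4))
```

Sorting by frequency matters later on, because both the sampling table and the log-uniform negative sampler assume that low ids correspond to frequent words.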


In [None]:
# Number of elements in each batch of `text_ds`: a (texts, labels) tuple.
class_len = len(list(text_ds.as_numpy_iterator())[0])

In [None]:
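for i in main_dir_files:
  sub_dir_files = os.listdir(base_dir+'/'+i)
  for j in sub_dir_files:
    # Read the file line by line and drop empty lines.
    # Note: `os.listdir` returns every file in the folder, so any PDFs still
    # stored there are read as well; filter on the `.txt` suffix to process
    # only the extracted text.
    text_ds = tf.data.TextLineDataset(base_dir+'/'+i+'/'+j).filter(lambda x: tf.cast(tf.strings.length(x), bool))

    # Note: the vocabulary is rebuilt from scratch for each file.
    vectorize_layer.adapt(text_ds.batch(1024))

    #inverse_vocab = vectorize_layer.get_vocabulary()
    #print(inverse_vocab[:20])

    # Vectorize the data in text_ds.
    text_vector_ds = text_ds.batch(batch_size=1024).prefetch(buffer_size=AUTOTUNE).map(vectorize_layer).unbatch()

    # Materialize the int-encoded lines and report how many there are.
    sequences = list(text_vector_ds.as_numpy_iterator())
    print(len(sequences))

    # Note: `targets`, `contexts`, and `labels` are overwritten on every
    # iteration, so only the last file's examples remain after the loop.
    targets, contexts, labels = generate_training_data(
        sequences=sequences,
        window_size=2,
        num_ns=4,
        vocab_size=vocab_size,
        seed=SEED)

    targets = np.array(targets)
    contexts = np.array(contexts)
    labels = np.array(labels)

    print('\n')
    print(f"targets.shape: {targets.shape}")
    print(f"contexts.shape: {contexts.shape}")
    print(f"labels.shape: {labels.shape}")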


529


100%|██████████| 529/529 [00:00<00:00, 1089.02it/s]




targets.shape: (1690,)
contexts.shape: (1690, 5)
labels.shape: (1690, 5)
2900


100%|██████████| 2900/2900 [00:04<00:00, 702.59it/s]




targets.shape: (9015,)
contexts.shape: (9015, 5)
labels.shape: (9015, 5)
22320


100%|██████████| 22320/22320 [00:02<00:00, 8823.20it/s] 




targets.shape: (6790,)
contexts.shape: (6790, 5)
labels.shape: (6790, 5)
626


100%|██████████| 626/626 [00:00<00:00, 837.41it/s]




targets.shape: (2501,)
contexts.shape: (2501, 5)
labels.shape: (2501, 5)
647


100%|██████████| 647/647 [00:00<00:00, 1058.52it/s]




targets.shape: (2161,)
contexts.shape: (2161, 5)
labels.shape: (2161, 5)
3273


100%|██████████| 3273/3273 [00:01<00:00, 1895.01it/s]




targets.shape: (4136,)
contexts.shape: (4136, 5)
labels.shape: (4136, 5)
4158


100%|██████████| 4158/4158 [00:01<00:00, 4025.57it/s]




targets.shape: (3856,)
contexts.shape: (3856, 5)
labels.shape: (3856, 5)
11360


100%|██████████| 11360/11360 [00:04<00:00, 2278.11it/s]




targets.shape: (18385,)
contexts.shape: (18385, 5)
labels.shape: (18385, 5)
467


100%|██████████| 467/467 [00:00<00:00, 982.34it/s] 




targets.shape: (1476,)
contexts.shape: (1476, 5)
labels.shape: (1476, 5)
18756


100%|██████████| 18756/18756 [00:01<00:00, 11410.32it/s]




targets.shape: (4927,)
contexts.shape: (4927, 5)
labels.shape: (4927, 5)
6088


100%|██████████| 6088/6088 [00:00<00:00, 6263.36it/s]




targets.shape: (3336,)
contexts.shape: (3336, 5)
labels.shape: (3336, 5)
615


100%|██████████| 615/615 [00:00<00:00, 1690.16it/s]




targets.shape: (1399,)
contexts.shape: (1399, 5)
labels.shape: (1399, 5)
247


100%|██████████| 247/247 [00:00<00:00, 1467.75it/s]




targets.shape: (607,)
contexts.shape: (607, 5)
labels.shape: (607, 5)
4712


100%|██████████| 4712/4712 [00:01<00:00, 3798.40it/s]




targets.shape: (3907,)
contexts.shape: (3907, 5)
labels.shape: (3907, 5)
1018


100%|██████████| 1018/1018 [00:01<00:00, 761.92it/s]




targets.shape: (3440,)
contexts.shape: (3440, 5)
labels.shape: (3440, 5)
531


100%|██████████| 531/531 [00:00<00:00, 994.00it/s]




targets.shape: (1896,)
contexts.shape: (1896, 5)
labels.shape: (1896, 5)
15604


100%|██████████| 15604/15604 [00:01<00:00, 10167.97it/s]




targets.shape: (4841,)
contexts.shape: (4841, 5)
labels.shape: (4841, 5)
331


100%|██████████| 331/331 [00:00<00:00, 1149.70it/s]




targets.shape: (1001,)
contexts.shape: (1001, 5)
labels.shape: (1001, 5)
12468


100%|██████████| 12468/12468 [00:02<00:00, 5557.34it/s]




targets.shape: (6676,)
contexts.shape: (6676, 5)
labels.shape: (6676, 5)
1297


100%|██████████| 1297/1297 [00:00<00:00, 1660.67it/s]




targets.shape: (2002,)
contexts.shape: (2002, 5)
labels.shape: (2002, 5)
11042


100%|██████████| 11042/11042 [00:01<00:00, 7328.94it/s]




targets.shape: (5438,)
contexts.shape: (5438, 5)
labels.shape: (5438, 5)
12360


100%|██████████| 12360/12360 [00:01<00:00, 8108.40it/s]




targets.shape: (4997,)
contexts.shape: (4997, 5)
labels.shape: (4997, 5)
649


100%|██████████| 649/649 [00:00<00:00, 1061.32it/s]




targets.shape: (2331,)
contexts.shape: (2331, 5)
labels.shape: (2331, 5)
11640


100%|██████████| 11640/11640 [00:09<00:00, 1166.90it/s]




targets.shape: (33251,)
contexts.shape: (33251, 5)
labels.shape: (33251, 5)
6724


100%|██████████| 6724/6724 [00:01<00:00, 3593.92it/s]




targets.shape: (4796,)
contexts.shape: (4796, 5)
labels.shape: (4796, 5)
1581


100%|██████████| 1581/1581 [00:00<00:00, 1631.18it/s]




targets.shape: (3793,)
contexts.shape: (3793, 5)
labels.shape: (3793, 5)
633


100%|██████████| 633/633 [00:00<00:00, 979.35it/s]




targets.shape: (2320,)
contexts.shape: (2320, 5)
labels.shape: (2320, 5)
263


100%|██████████| 263/263 [00:00<00:00, 1065.84it/s]




targets.shape: (883,)
contexts.shape: (883, 5)
labels.shape: (883, 5)
1896


100%|██████████| 1896/1896 [00:00<00:00, 2094.48it/s]




targets.shape: (3315,)
contexts.shape: (3315, 5)
labels.shape: (3315, 5)
612


100%|██████████| 612/612 [00:00<00:00, 1234.69it/s]




targets.shape: (1840,)
contexts.shape: (1840, 5)
labels.shape: (1840, 5)
14170


100%|██████████| 14170/14170 [00:02<00:00, 6199.33it/s]




targets.shape: (6412,)
contexts.shape: (6412, 5)
labels.shape: (6412, 5)
616


100%|██████████| 616/616 [00:00<00:00, 1029.96it/s]




targets.shape: (1721,)
contexts.shape: (1721, 5)
labels.shape: (1721, 5)
1819


100%|██████████| 1819/1819 [00:02<00:00, 653.80it/s]




targets.shape: (7854,)
contexts.shape: (7854, 5)
labels.shape: (7854, 5)
8214


100%|██████████| 8214/8214 [00:01<00:00, 4145.11it/s]




targets.shape: (6838,)
contexts.shape: (6838, 5)
labels.shape: (6838, 5)
21989


100%|██████████| 21989/21989 [00:02<00:00, 9602.30it/s] 




targets.shape: (6729,)
contexts.shape: (6729, 5)
labels.shape: (6729, 5)
877


100%|██████████| 877/877 [00:00<00:00, 1371.21it/s]




targets.shape: (2423,)
contexts.shape: (2423, 5)
labels.shape: (2423, 5)
1078


100%|██████████| 1078/1078 [00:00<00:00, 1172.54it/s]




targets.shape: (3381,)
contexts.shape: (3381, 5)
labels.shape: (3381, 5)
1852


100%|██████████| 1852/1852 [00:01<00:00, 926.67it/s]




targets.shape: (7887,)
contexts.shape: (7887, 5)
labels.shape: (7887, 5)
893


100%|██████████| 893/893 [00:00<00:00, 1078.80it/s]




targets.shape: (3200,)
contexts.shape: (3200, 5)
labels.shape: (3200, 5)
878


100%|██████████| 878/878 [00:01<00:00, 854.52it/s]




targets.shape: (2908,)
contexts.shape: (2908, 5)
labels.shape: (2908, 5)
1463


100%|██████████| 1463/1463 [00:01<00:00, 994.59it/s] 




targets.shape: (4177,)
contexts.shape: (4177, 5)
labels.shape: (4177, 5)
725


100%|██████████| 725/725 [00:00<00:00, 1464.43it/s]




targets.shape: (1866,)
contexts.shape: (1866, 5)
labels.shape: (1866, 5)
25533


100%|██████████| 25533/25533 [00:01<00:00, 14089.88it/s]




targets.shape: (5327,)
contexts.shape: (5327, 5)
labels.shape: (5327, 5)
2386


100%|██████████| 2386/2386 [00:01<00:00, 2183.87it/s]




targets.shape: (3533,)
contexts.shape: (3533, 5)
labels.shape: (3533, 5)
1381


100%|██████████| 1381/1381 [00:01<00:00, 969.25it/s]




targets.shape: (3825,)
contexts.shape: (3825, 5)
labels.shape: (3825, 5)
877


100%|██████████| 877/877 [00:00<00:00, 1186.00it/s]




targets.shape: (2774,)
contexts.shape: (2774, 5)
labels.shape: (2774, 5)
1753


100%|██████████| 1753/1753 [00:01<00:00, 1656.47it/s]




targets.shape: (4108,)
contexts.shape: (4108, 5)
labels.shape: (4108, 5)
71789


100%|██████████| 71789/71789 [00:02<00:00, 26542.53it/s]




targets.shape: (5420,)
contexts.shape: (5420, 5)
labels.shape: (5420, 5)
3227


100%|██████████| 3227/3227 [00:01<00:00, 2833.14it/s]




targets.shape: (4007,)
contexts.shape: (4007, 5)
labels.shape: (4007, 5)
1111


100%|██████████| 1111/1111 [00:01<00:00, 949.60it/s] 




targets.shape: (4283,)
contexts.shape: (4283, 5)
labels.shape: (4283, 5)
1310


100%|██████████| 1310/1310 [00:01<00:00, 1073.47it/s]




targets.shape: (3484,)
contexts.shape: (3484, 5)
labels.shape: (3484, 5)
1782


100%|██████████| 1782/1782 [00:01<00:00, 1775.98it/s]




targets.shape: (2459,)
contexts.shape: (2459, 5)
labels.shape: (2459, 5)
3305


100%|██████████| 3305/3305 [00:01<00:00, 2917.31it/s]




targets.shape: (4000,)
contexts.shape: (4000, 5)
labels.shape: (4000, 5)
3668


100%|██████████| 3668/3668 [00:01<00:00, 3074.90it/s]




targets.shape: (4254,)
contexts.shape: (4254, 5)
labels.shape: (4254, 5)
134


100%|██████████| 134/134 [00:00<00:00, 1978.71it/s]




targets.shape: (233,)
contexts.shape: (233, 5)
labels.shape: (233, 5)
798


100%|██████████| 798/798 [00:00<00:00, 1027.85it/s]




targets.shape: (2896,)
contexts.shape: (2896, 5)
labels.shape: (2896, 5)
1821


100%|██████████| 1821/1821 [00:01<00:00, 1679.96it/s]




targets.shape: (3979,)
contexts.shape: (3979, 5)
labels.shape: (3979, 5)
78


100%|██████████| 78/78 [00:00<00:00, 42197.31it/s]




targets.shape: (0,)
contexts.shape: (0,)
labels.shape: (0,)
650


100%|██████████| 650/650 [00:00<00:00, 4094.25it/s]



targets.shape: (549,)
contexts.shape: (549, 5)
labels.shape: (549, 5)





7354


100%|██████████| 7354/7354 [00:01<00:00, 4201.44it/s]




targets.shape: (4418,)
contexts.shape: (4418, 5)
labels.shape: (4418, 5)
309


100%|██████████| 309/309 [00:00<00:00, 886.00it/s]




targets.shape: (737,)
contexts.shape: (737, 5)
labels.shape: (737, 5)
3475


100%|██████████| 3475/3475 [00:01<00:00, 3216.13it/s]




targets.shape: (4011,)
contexts.shape: (4011, 5)
labels.shape: (4011, 5)
2185


100%|██████████| 2185/2185 [00:00<00:00, 2286.76it/s]




targets.shape: (3779,)
contexts.shape: (3779, 5)
labels.shape: (3779, 5)
878


100%|██████████| 878/878 [00:00<00:00, 1070.68it/s]




targets.shape: (3174,)
contexts.shape: (3174, 5)
labels.shape: (3174, 5)
20816


100%|██████████| 20816/20816 [00:09<00:00, 2139.55it/s]




targets.shape: (32548,)
contexts.shape: (32548, 5)
labels.shape: (32548, 5)
2879


100%|██████████| 2879/2879 [00:01<00:00, 2744.78it/s]




targets.shape: (3905,)
contexts.shape: (3905, 5)
labels.shape: (3905, 5)
592


100%|██████████| 592/592 [00:00<00:00, 1140.91it/s]




targets.shape: (2010,)
contexts.shape: (2010, 5)
labels.shape: (2010, 5)
1298


100%|██████████| 1298/1298 [00:01<00:00, 1269.53it/s]




targets.shape: (3386,)
contexts.shape: (3386, 5)
labels.shape: (3386, 5)
897


100%|██████████| 897/897 [00:00<00:00, 1035.18it/s]




targets.shape: (2388,)
contexts.shape: (2388, 5)
labels.shape: (2388, 5)
2682


100%|██████████| 2682/2682 [00:01<00:00, 2665.65it/s]




targets.shape: (3510,)
contexts.shape: (3510, 5)
labels.shape: (3510, 5)
279


100%|██████████| 279/279 [00:00<00:00, 1175.36it/s]




targets.shape: (831,)
contexts.shape: (831, 5)
labels.shape: (831, 5)
1797


100%|██████████| 1797/1797 [00:01<00:00, 1727.27it/s]




targets.shape: (4069,)
contexts.shape: (4069, 5)
labels.shape: (4069, 5)
1528


100%|██████████| 1528/1528 [00:00<00:00, 3740.65it/s]




targets.shape: (1407,)
contexts.shape: (1407, 5)
labels.shape: (1407, 5)
6706


100%|██████████| 6706/6706 [00:01<00:00, 5558.39it/s]




targets.shape: (4096,)
contexts.shape: (4096, 5)
labels.shape: (4096, 5)
3552


100%|██████████| 3552/3552 [00:01<00:00, 3020.30it/s]




targets.shape: (4185,)
contexts.shape: (4185, 5)
labels.shape: (4185, 5)
9


100%|██████████| 9/9 [00:00<00:00, 922.50it/s]




targets.shape: (28,)
contexts.shape: (28, 5)
labels.shape: (28, 5)
3401


100%|██████████| 3401/3401 [00:01<00:00, 2644.25it/s]




targets.shape: (4075,)
contexts.shape: (4075, 5)
labels.shape: (4075, 5)
263


100%|██████████| 263/263 [00:00<00:00, 1110.49it/s]




targets.shape: (773,)
contexts.shape: (773, 5)
labels.shape: (773, 5)
243


100%|██████████| 243/243 [00:00<00:00, 1259.99it/s]




targets.shape: (439,)
contexts.shape: (439, 5)
labels.shape: (439, 5)
954


100%|██████████| 954/954 [00:01<00:00, 911.59it/s]




targets.shape: (3138,)
contexts.shape: (3138, 5)
labels.shape: (3138, 5)
11702


100%|██████████| 11702/11702 [00:01<00:00, 7526.46it/s]




targets.shape: (4893,)
contexts.shape: (4893, 5)
labels.shape: (4893, 5)
14


100%|██████████| 14/14 [00:00<00:00, 5348.42it/s]




targets.shape: (1,)
contexts.shape: (1, 5)
labels.shape: (1, 5)
345


100%|██████████| 345/345 [00:00<00:00, 1425.76it/s]




targets.shape: (846,)
contexts.shape: (846, 5)
labels.shape: (846, 5)
5160


100%|██████████| 5160/5160 [00:01<00:00, 4403.65it/s]




targets.shape: (4313,)
contexts.shape: (4313, 5)
labels.shape: (4313, 5)
19125


100%|██████████| 19125/19125 [00:01<00:00, 10181.01it/s]




targets.shape: (5940,)
contexts.shape: (5940, 5)
labels.shape: (5940, 5)
5559


100%|██████████| 5559/5559 [00:01<00:00, 3060.15it/s]




targets.shape: (4337,)
contexts.shape: (4337, 5)
labels.shape: (4337, 5)
189


100%|██████████| 189/189 [00:00<00:00, 2025.91it/s]



targets.shape: (348,)
contexts.shape: (348, 5)
labels.shape: (348, 5)





51636


100%|██████████| 51636/51636 [00:07<00:00, 6798.42it/s]




targets.shape: (23233,)
contexts.shape: (23233, 5)
labels.shape: (23233, 5)
496


100%|██████████| 496/496 [00:00<00:00, 662.55it/s]




targets.shape: (1693,)
contexts.shape: (1693, 5)
labels.shape: (1693, 5)
3682


100%|██████████| 3682/3682 [00:01<00:00, 2875.65it/s]




targets.shape: (4332,)
contexts.shape: (4332, 5)
labels.shape: (4332, 5)
3490


100%|██████████| 3490/3490 [00:01<00:00, 3011.58it/s]




targets.shape: (3667,)
contexts.shape: (3667, 5)
labels.shape: (3667, 5)
990


100%|██████████| 990/990 [00:00<00:00, 1193.61it/s]




targets.shape: (2842,)
contexts.shape: (2842, 5)
labels.shape: (2842, 5)
623


100%|██████████| 623/623 [00:00<00:00, 1289.83it/s]




targets.shape: (1413,)
contexts.shape: (1413, 5)
labels.shape: (1413, 5)
44


100%|██████████| 44/44 [00:00<00:00, 7606.52it/s]



targets.shape: (8,)
contexts.shape: (8, 5)
labels.shape: (8, 5)





10


100%|██████████| 10/10 [00:00<00:00, 725.52it/s]




targets.shape: (32,)
contexts.shape: (32, 5)
labels.shape: (32, 5)
630


100%|██████████| 630/630 [00:00<00:00, 1053.12it/s]




targets.shape: (2175,)
contexts.shape: (2175, 5)
labels.shape: (2175, 5)
2667


100%|██████████| 2667/2667 [00:01<00:00, 2231.74it/s]




targets.shape: (4213,)
contexts.shape: (4213, 5)
labels.shape: (4213, 5)
229


100%|██████████| 229/229 [00:00<00:00, 2476.21it/s]



targets.shape: (274,)
contexts.shape: (274, 5)
labels.shape: (274, 5)





779


100%|██████████| 779/779 [00:00<00:00, 1566.48it/s]




targets.shape: (1388,)
contexts.shape: (1388, 5)
labels.shape: (1388, 5)
21097


100%|██████████| 21097/21097 [00:02<00:00, 9534.68it/s] 




targets.shape: (6880,)
contexts.shape: (6880, 5)
labels.shape: (6880, 5)
5663


100%|██████████| 5663/5663 [00:01<00:00, 3605.03it/s]




targets.shape: (4311,)
contexts.shape: (4311, 5)
labels.shape: (4311, 5)
6280


100%|██████████| 6280/6280 [00:01<00:00, 4822.33it/s]




targets.shape: (4186,)
contexts.shape: (4186, 5)
labels.shape: (4186, 5)
796


100%|██████████| 796/796 [00:00<00:00, 2813.80it/s]




targets.shape: (882,)
contexts.shape: (882, 5)
labels.shape: (882, 5)
4149


100%|██████████| 4149/4149 [00:01<00:00, 3618.69it/s]




targets.shape: (4296,)
contexts.shape: (4296, 5)
labels.shape: (4296, 5)
1


100%|██████████| 1/1 [00:00<00:00, 288.96it/s]




targets.shape: (4,)
contexts.shape: (4, 5)
labels.shape: (4, 5)
3994


100%|██████████| 3994/3994 [00:01<00:00, 3715.51it/s]




targets.shape: (3973,)
contexts.shape: (3973, 5)
labels.shape: (3973, 5)
787


100%|██████████| 787/787 [00:00<00:00, 1474.33it/s]




targets.shape: (1887,)
contexts.shape: (1887, 5)
labels.shape: (1887, 5)
1666


100%|██████████| 1666/1666 [00:01<00:00, 1317.71it/s]




targets.shape: (3837,)
contexts.shape: (3837, 5)
labels.shape: (3837, 5)
389


100%|██████████| 389/389 [00:00<00:00, 860.85it/s]




targets.shape: (1373,)
contexts.shape: (1373, 5)
labels.shape: (1373, 5)
1395


100%|██████████| 1395/1395 [00:01<00:00, 1364.50it/s]




targets.shape: (3071,)
contexts.shape: (3071, 5)
labels.shape: (3071, 5)
889


100%|██████████| 889/889 [00:00<00:00, 1278.13it/s]




targets.shape: (2595,)
contexts.shape: (2595, 5)
labels.shape: (2595, 5)
1358


100%|██████████| 1358/1358 [00:01<00:00, 1088.30it/s]




targets.shape: (4689,)
contexts.shape: (4689, 5)
labels.shape: (4689, 5)
5194


100%|██████████| 5194/5194 [00:01<00:00, 3925.28it/s]




targets.shape: (4317,)
contexts.shape: (4317, 5)
labels.shape: (4317, 5)
4078


100%|██████████| 4078/4078 [00:01<00:00, 3422.00it/s]




targets.shape: (3930,)
contexts.shape: (3930, 5)
labels.shape: (3930, 5)
6975


100%|██████████| 6975/6975 [00:02<00:00, 3027.43it/s]




targets.shape: (5769,)
contexts.shape: (5769, 5)
labels.shape: (5769, 5)
1359


100%|██████████| 1359/1359 [00:01<00:00, 834.77it/s]




targets.shape: (6209,)
contexts.shape: (6209, 5)
labels.shape: (6209, 5)
342


100%|██████████| 342/342 [00:00<00:00, 1154.38it/s]




targets.shape: (1013,)
contexts.shape: (1013, 5)
labels.shape: (1013, 5)
1378


100%|██████████| 1378/1378 [00:01<00:00, 1317.54it/s]




targets.shape: (3718,)
contexts.shape: (3718, 5)
labels.shape: (3718, 5)
4371


100%|██████████| 4371/4371 [00:01<00:00, 3839.04it/s]




targets.shape: (3912,)
contexts.shape: (3912, 5)
labels.shape: (3912, 5)
15967


100%|██████████| 15967/15967 [00:01<00:00, 10092.58it/s]




targets.shape: (4827,)
contexts.shape: (4827, 5)
labels.shape: (4827, 5)
31851


100%|██████████| 31851/31851 [00:02<00:00, 11563.21it/s]




targets.shape: (7074,)
contexts.shape: (7074, 5)
labels.shape: (7074, 5)
471


100%|██████████| 471/471 [00:00<00:00, 680.67it/s]




targets.shape: (1534,)
contexts.shape: (1534, 5)
labels.shape: (1534, 5)
886


100%|██████████| 886/886 [00:00<00:00, 2403.58it/s]




targets.shape: (1337,)
contexts.shape: (1337, 5)
labels.shape: (1337, 5)
3631


100%|██████████| 3631/3631 [00:01<00:00, 3058.55it/s]




targets.shape: (4167,)
contexts.shape: (4167, 5)
labels.shape: (4167, 5)
467


100%|██████████| 467/467 [00:00<00:00, 1309.84it/s]




targets.shape: (1244,)
contexts.shape: (1244, 5)
labels.shape: (1244, 5)
1886


100%|██████████| 1886/1886 [00:01<00:00, 1318.65it/s]




targets.shape: (5371,)
contexts.shape: (5371, 5)
labels.shape: (5371, 5)
654


100%|██████████| 654/654 [00:00<00:00, 1664.80it/s]




targets.shape: (1406,)
contexts.shape: (1406, 5)
labels.shape: (1406, 5)
1224


100%|██████████| 1224/1224 [00:01<00:00, 1057.21it/s]




targets.shape: (4354,)
contexts.shape: (4354, 5)
labels.shape: (4354, 5)
11105


100%|██████████| 11105/11105 [00:01<00:00, 5586.04it/s]




targets.shape: (5484,)
contexts.shape: (5484, 5)
labels.shape: (5484, 5)
9027


100%|██████████| 9027/9027 [00:01<00:00, 6293.60it/s]




targets.shape: (4650,)
contexts.shape: (4650, 5)
labels.shape: (4650, 5)
3183


100%|██████████| 3183/3183 [00:01<00:00, 2969.83it/s]




targets.shape: (3853,)
contexts.shape: (3853, 5)
labels.shape: (3853, 5)
428


100%|██████████| 428/428 [00:00<00:00, 1830.17it/s]




targets.shape: (774,)
contexts.shape: (774, 5)
labels.shape: (774, 5)
6333


100%|██████████| 6333/6333 [00:01<00:00, 4735.68it/s]




targets.shape: (4412,)
contexts.shape: (4412, 5)
labels.shape: (4412, 5)
16097


100%|██████████| 16097/16097 [00:01<00:00, 9782.63it/s]




targets.shape: (4873,)
contexts.shape: (4873, 5)
labels.shape: (4873, 5)
3181


100%|██████████| 3181/3181 [00:01<00:00, 3176.84it/s]




targets.shape: (3549,)
contexts.shape: (3549, 5)
labels.shape: (3549, 5)
319


100%|██████████| 319/319 [00:00<00:00, 2395.63it/s]




targets.shape: (491,)
contexts.shape: (491, 5)
labels.shape: (491, 5)
2625


100%|██████████| 2625/2625 [00:00<00:00, 2647.97it/s]




targets.shape: (3772,)
contexts.shape: (3772, 5)
labels.shape: (3772, 5)
15405


100%|██████████| 15405/15405 [00:02<00:00, 7146.35it/s]




targets.shape: (5066,)
contexts.shape: (5066, 5)
labels.shape: (5066, 5)
1481


100%|██████████| 1481/1481 [00:00<00:00, 1582.64it/s]




targets.shape: (3684,)
contexts.shape: (3684, 5)
labels.shape: (3684, 5)
683


100%|██████████| 683/683 [00:00<00:00, 1207.63it/s]




targets.shape: (2074,)
contexts.shape: (2074, 5)
labels.shape: (2074, 5)
705


100%|██████████| 705/705 [00:00<00:00, 1098.76it/s]




targets.shape: (2203,)
contexts.shape: (2203, 5)
labels.shape: (2203, 5)
19579


100%|██████████| 19579/19579 [00:14<00:00, 1318.81it/s]




targets.shape: (48542,)
contexts.shape: (48542, 5)
labels.shape: (48542, 5)
952


100%|██████████| 952/952 [00:01<00:00, 666.21it/s]




targets.shape: (4637,)
contexts.shape: (4637, 5)
labels.shape: (4637, 5)
1956


100%|██████████| 1956/1956 [00:01<00:00, 1028.45it/s]




targets.shape: (7577,)
contexts.shape: (7577, 5)
labels.shape: (7577, 5)
2553


100%|██████████| 2553/2553 [00:01<00:00, 2443.34it/s]




targets.shape: (3861,)
contexts.shape: (3861, 5)
labels.shape: (3861, 5)
374


100%|██████████| 374/374 [00:00<00:00, 1384.66it/s]




targets.shape: (926,)
contexts.shape: (926, 5)
labels.shape: (926, 5)
4399


100%|██████████| 4399/4399 [00:01<00:00, 3676.32it/s]




targets.shape: (4131,)
contexts.shape: (4131, 5)
labels.shape: (4131, 5)
4666


100%|██████████| 4666/4666 [00:01<00:00, 3229.89it/s]




targets.shape: (4263,)
contexts.shape: (4263, 5)
labels.shape: (4263, 5)
762


100%|██████████| 762/762 [00:01<00:00, 685.85it/s]




targets.shape: (2807,)
contexts.shape: (2807, 5)
labels.shape: (2807, 5)
15168


100%|██████████| 15168/15168 [00:13<00:00, 1121.14it/s]




targets.shape: (41797,)
contexts.shape: (41797, 5)
labels.shape: (41797, 5)
5311


100%|██████████| 5311/5311 [00:01<00:00, 4276.69it/s]




targets.shape: (4211,)
contexts.shape: (4211, 5)
labels.shape: (4211, 5)
689


100%|██████████| 689/689 [00:00<00:00, 958.02it/s]




targets.shape: (2274,)
contexts.shape: (2274, 5)
labels.shape: (2274, 5)
865


100%|██████████| 865/865 [00:00<00:00, 1156.26it/s]




targets.shape: (2203,)
contexts.shape: (2203, 5)
labels.shape: (2203, 5)
280


100%|██████████| 280/280 [00:00<00:00, 820.02it/s]




targets.shape: (678,)
contexts.shape: (678, 5)
labels.shape: (678, 5)
435


100%|██████████| 435/435 [00:00<00:00, 1153.31it/s]




targets.shape: (1328,)
contexts.shape: (1328, 5)
labels.shape: (1328, 5)
842


100%|██████████| 842/842 [00:00<00:00, 928.70it/s]




targets.shape: (3363,)
contexts.shape: (3363, 5)
labels.shape: (3363, 5)
811


100%|██████████| 811/811 [00:00<00:00, 1426.00it/s]




targets.shape: (2170,)
contexts.shape: (2170, 5)
labels.shape: (2170, 5)
604


100%|██████████| 604/604 [00:00<00:00, 1274.45it/s]




targets.shape: (1778,)
contexts.shape: (1778, 5)
labels.shape: (1778, 5)
4754


100%|██████████| 4754/4754 [00:01<00:00, 3961.01it/s]




targets.shape: (4190,)
contexts.shape: (4190, 5)
labels.shape: (4190, 5)
3107


100%|██████████| 3107/3107 [00:02<00:00, 1152.83it/s]




targets.shape: (10325,)
contexts.shape: (10325, 5)
labels.shape: (10325, 5)
1327


100%|██████████| 1327/1327 [00:01<00:00, 727.03it/s]




targets.shape: (4995,)
contexts.shape: (4995, 5)
labels.shape: (4995, 5)
735


100%|██████████| 735/735 [00:00<00:00, 980.21it/s] 




targets.shape: (2467,)
contexts.shape: (2467, 5)
labels.shape: (2467, 5)
51639


100%|██████████| 51639/51639 [00:03<00:00, 15802.34it/s]




targets.shape: (9130,)
contexts.shape: (9130, 5)
labels.shape: (9130, 5)
2689


100%|██████████| 2689/2689 [00:01<00:00, 2430.04it/s]




targets.shape: (3857,)
contexts.shape: (3857, 5)
labels.shape: (3857, 5)
15347


100%|██████████| 15347/15347 [00:01<00:00, 9299.12it/s]




targets.shape: (4813,)
contexts.shape: (4813, 5)
labels.shape: (4813, 5)
1138


100%|██████████| 1138/1138 [00:00<00:00, 1154.67it/s]




targets.shape: (3630,)
contexts.shape: (3630, 5)
labels.shape: (3630, 5)
16802


100%|██████████| 16802/16802 [00:03<00:00, 5133.67it/s]




targets.shape: (7999,)
contexts.shape: (7999, 5)
labels.shape: (7999, 5)
4338


100%|██████████| 4338/4338 [00:01<00:00, 3710.04it/s]




targets.shape: (4067,)
contexts.shape: (4067, 5)
labels.shape: (4067, 5)
780


100%|██████████| 780/780 [00:00<00:00, 1185.65it/s]




targets.shape: (2547,)
contexts.shape: (2547, 5)
labels.shape: (2547, 5)
16913


100%|██████████| 16913/16913 [00:01<00:00, 10879.70it/s]




targets.shape: (4805,)
contexts.shape: (4805, 5)
labels.shape: (4805, 5)
3501


100%|██████████| 3501/3501 [00:01<00:00, 2301.39it/s]




targets.shape: (4132,)
contexts.shape: (4132, 5)
labels.shape: (4132, 5)
1510


100%|██████████| 1510/1510 [00:01<00:00, 1132.39it/s]




targets.shape: (5041,)
contexts.shape: (5041, 5)
labels.shape: (5041, 5)
6468


100%|██████████| 6468/6468 [00:01<00:00, 4685.69it/s]




targets.shape: (4726,)
contexts.shape: (4726, 5)
labels.shape: (4726, 5)
7739


100%|██████████| 7739/7739 [00:02<00:00, 2981.60it/s]




targets.shape: (4554,)
contexts.shape: (4554, 5)
labels.shape: (4554, 5)
404


100%|██████████| 404/404 [00:00<00:00, 424.47it/s]




targets.shape: (1111,)
contexts.shape: (1111, 5)
labels.shape: (1111, 5)
11904


100%|██████████| 11904/11904 [00:03<00:00, 3859.39it/s]




targets.shape: (10284,)
contexts.shape: (10284, 5)
labels.shape: (10284, 5)
124878


100%|██████████| 124878/124878 [00:06<00:00, 18426.54it/s]




targets.shape: (14879,)
contexts.shape: (14879, 5)
labels.shape: (14879, 5)
470


100%|██████████| 470/470 [00:00<00:00, 1283.25it/s]




targets.shape: (1070,)
contexts.shape: (1070, 5)
labels.shape: (1070, 5)
560


100%|██████████| 560/560 [00:00<00:00, 1420.37it/s]




targets.shape: (1469,)
contexts.shape: (1469, 5)
labels.shape: (1469, 5)
4


100%|██████████| 4/4 [00:00<00:00, 358.45it/s]




targets.shape: (24,)
contexts.shape: (24, 5)
labels.shape: (24, 5)
23302


100%|██████████| 23302/23302 [00:06<00:00, 3689.04it/s]




targets.shape: (21665,)
contexts.shape: (21665, 5)
labels.shape: (21665, 5)
628


100%|██████████| 628/628 [00:00<00:00, 1014.43it/s]




targets.shape: (1705,)
contexts.shape: (1705, 5)
labels.shape: (1705, 5)
122853


100%|██████████| 122853/122853 [00:05<00:00, 20536.49it/s]




targets.shape: (12496,)
contexts.shape: (12496, 5)
labels.shape: (12496, 5)
1996


100%|██████████| 1996/1996 [00:00<00:00, 2238.03it/s]




targets.shape: (3043,)
contexts.shape: (3043, 5)
labels.shape: (3043, 5)
1620


100%|██████████| 1620/1620 [00:01<00:00, 1413.37it/s]




targets.shape: (3436,)
contexts.shape: (3436, 5)
labels.shape: (3436, 5)
2186


100%|██████████| 2186/2186 [00:00<00:00, 3743.77it/s]




targets.shape: (2075,)
contexts.shape: (2075, 5)
labels.shape: (2075, 5)
1946


100%|██████████| 1946/1946 [00:01<00:00, 1378.05it/s]




targets.shape: (4022,)
contexts.shape: (4022, 5)
labels.shape: (4022, 5)
3964


100%|██████████| 3964/3964 [00:01<00:00, 3369.56it/s]




targets.shape: (4287,)
contexts.shape: (4287, 5)
labels.shape: (4287, 5)
8106


100%|██████████| 8106/8106 [00:01<00:00, 5829.52it/s]




targets.shape: (4758,)
contexts.shape: (4758, 5)
labels.shape: (4758, 5)
1370


100%|██████████| 1370/1370 [00:00<00:00, 1675.54it/s]




targets.shape: (3092,)
contexts.shape: (3092, 5)
labels.shape: (3092, 5)
107213


100%|██████████| 107213/107213 [00:06<00:00, 16442.14it/s]




targets.shape: (14938,)
contexts.shape: (14938, 5)
labels.shape: (14938, 5)
1050


100%|██████████| 1050/1050 [00:01<00:00, 1038.41it/s]




targets.shape: (3622,)
contexts.shape: (3622, 5)
labels.shape: (3622, 5)
576


100%|██████████| 576/576 [00:00<00:00, 1728.45it/s]




targets.shape: (1292,)
contexts.shape: (1292, 5)
labels.shape: (1292, 5)
12156


100%|██████████| 12156/12156 [00:01<00:00, 7204.18it/s]




targets.shape: (5592,)
contexts.shape: (5592, 5)
labels.shape: (5592, 5)
4646


100%|██████████| 4646/4646 [00:01<00:00, 4004.93it/s]




targets.shape: (4002,)
contexts.shape: (4002, 5)
labels.shape: (4002, 5)
958


100%|██████████| 958/958 [00:00<00:00, 1139.54it/s]




targets.shape: (2438,)
contexts.shape: (2438, 5)
labels.shape: (2438, 5)
369


100%|██████████| 369/369 [00:00<00:00, 1129.02it/s]




targets.shape: (755,)
contexts.shape: (755, 5)
labels.shape: (755, 5)
10071


100%|██████████| 10071/10071 [00:02<00:00, 4340.44it/s]




targets.shape: (8269,)
contexts.shape: (8269, 5)
labels.shape: (8269, 5)
30265


100%|██████████| 30265/30265 [00:27<00:00, 1090.73it/s]




targets.shape: (92993,)
contexts.shape: (92993, 5)
labels.shape: (92993, 5)
3915


100%|██████████| 3915/3915 [00:01<00:00, 3603.69it/s]




targets.shape: (3847,)
contexts.shape: (3847, 5)
labels.shape: (3847, 5)
90465


100%|██████████| 90465/90465 [00:04<00:00, 18130.89it/s]




targets.shape: (13007,)
contexts.shape: (13007, 5)
labels.shape: (13007, 5)
803


100%|██████████| 803/803 [00:01<00:00, 756.53it/s]




targets.shape: (2764,)
contexts.shape: (2764, 5)
labels.shape: (2764, 5)
906


100%|██████████| 906/906 [00:00<00:00, 1118.59it/s]




targets.shape: (3004,)
contexts.shape: (3004, 5)
labels.shape: (3004, 5)
894


100%|██████████| 894/894 [00:00<00:00, 1318.69it/s]




targets.shape: (2621,)
contexts.shape: (2621, 5)
labels.shape: (2621, 5)
991


100%|██████████| 991/991 [00:00<00:00, 1223.31it/s]




targets.shape: (3087,)
contexts.shape: (3087, 5)
labels.shape: (3087, 5)
820


100%|██████████| 820/820 [00:00<00:00, 1157.20it/s]




targets.shape: (2696,)
contexts.shape: (2696, 5)
labels.shape: (2696, 5)
490


100%|██████████| 490/490 [00:00<00:00, 1448.10it/s]




targets.shape: (1292,)
contexts.shape: (1292, 5)
labels.shape: (1292, 5)
313


100%|██████████| 313/313 [00:00<00:00, 1507.97it/s]




targets.shape: (785,)
contexts.shape: (785, 5)
labels.shape: (785, 5)
1350


100%|██████████| 1350/1350 [00:00<00:00, 1548.11it/s]




targets.shape: (3408,)
contexts.shape: (3408, 5)
labels.shape: (3408, 5)
9120


100%|██████████| 9120/9120 [00:01<00:00, 6212.16it/s]




targets.shape: (4486,)
contexts.shape: (4486, 5)
labels.shape: (4486, 5)
794


100%|██████████| 794/794 [00:00<00:00, 1113.33it/s]




targets.shape: (2141,)
contexts.shape: (2141, 5)
labels.shape: (2141, 5)
600


100%|██████████| 600/600 [00:00<00:00, 1811.94it/s]




targets.shape: (762,)
contexts.shape: (762, 5)
labels.shape: (762, 5)
1227


100%|██████████| 1227/1227 [00:01<00:00, 1084.64it/s]




targets.shape: (4334,)
contexts.shape: (4334, 5)
labels.shape: (4334, 5)
17877


100%|██████████| 17877/17877 [00:04<00:00, 3595.23it/s]




targets.shape: (16642,)
contexts.shape: (16642, 5)
labels.shape: (16642, 5)
5196


100%|██████████| 5196/5196 [00:01<00:00, 3747.32it/s]




targets.shape: (4120,)
contexts.shape: (4120, 5)
labels.shape: (4120, 5)
1057


100%|██████████| 1057/1057 [00:00<00:00, 2134.12it/s]




targets.shape: (1772,)
contexts.shape: (1772, 5)
labels.shape: (1772, 5)
915


100%|██████████| 915/915 [00:00<00:00, 1726.84it/s]




targets.shape: (2036,)
contexts.shape: (2036, 5)
labels.shape: (2036, 5)
30602


100%|██████████| 30602/30602 [00:18<00:00, 1633.75it/s]




targets.shape: (64992,)
contexts.shape: (64992, 5)
labels.shape: (64992, 5)
555


100%|██████████| 555/555 [00:00<00:00, 2296.03it/s]




targets.shape: (769,)
contexts.shape: (769, 5)
labels.shape: (769, 5)
622


100%|██████████| 622/622 [00:00<00:00, 1556.96it/s]




targets.shape: (1429,)
contexts.shape: (1429, 5)
labels.shape: (1429, 5)
34541


100%|██████████| 34541/34541 [00:02<00:00, 15179.39it/s]




targets.shape: (6151,)
contexts.shape: (6151, 5)
labels.shape: (6151, 5)
663


100%|██████████| 663/663 [00:00<00:00, 1169.54it/s]




targets.shape: (2051,)
contexts.shape: (2051, 5)
labels.shape: (2051, 5)
2086


100%|██████████| 2086/2086 [00:00<00:00, 2119.60it/s]




targets.shape: (3729,)
contexts.shape: (3729, 5)
labels.shape: (3729, 5)
1361


100%|██████████| 1361/1361 [00:01<00:00, 918.67it/s] 




targets.shape: (4461,)
contexts.shape: (4461, 5)
labels.shape: (4461, 5)
93472


100%|██████████| 93472/93472 [00:04<00:00, 19250.17it/s]




targets.shape: (11943,)
contexts.shape: (11943, 5)
labels.shape: (11943, 5)
4995


100%|██████████| 4995/4995 [00:01<00:00, 3240.96it/s]




targets.shape: (4292,)
contexts.shape: (4292, 5)
labels.shape: (4292, 5)
6614


100%|██████████| 6614/6614 [00:01<00:00, 4928.94it/s]




targets.shape: (4382,)
contexts.shape: (4382, 5)
labels.shape: (4382, 5)
1087


100%|██████████| 1087/1087 [00:01<00:00, 960.63it/s]




targets.shape: (4220,)
contexts.shape: (4220, 5)
labels.shape: (4220, 5)
4240


100%|██████████| 4240/4240 [00:01<00:00, 3791.26it/s]




targets.shape: (3773,)
contexts.shape: (3773, 5)
labels.shape: (3773, 5)
11538


100%|██████████| 11538/11538 [00:02<00:00, 5000.08it/s]




targets.shape: (5383,)
contexts.shape: (5383, 5)
labels.shape: (5383, 5)
508


100%|██████████| 508/508 [00:00<00:00, 1100.85it/s]




targets.shape: (1692,)
contexts.shape: (1692, 5)
labels.shape: (1692, 5)
67613


100%|██████████| 67613/67613 [00:04<00:00, 14928.67it/s]




targets.shape: (12947,)
contexts.shape: (12947, 5)
labels.shape: (12947, 5)
9101


100%|██████████| 9101/9101 [00:01<00:00, 5835.02it/s]




targets.shape: (4770,)
contexts.shape: (4770, 5)
labels.shape: (4770, 5)
17624


100%|██████████| 17624/17624 [00:01<00:00, 10376.76it/s]



targets.shape: (4960,)
contexts.shape: (4960, 5)
labels.shape: (4960, 5)





### Configure the dataset for performance

To batch the potentially large number of training examples efficiently, use the `tf.data.Dataset` API. After this step, you have a `tf.data.Dataset` object of `((target_word, context_word), label)` elements to train your word2vec model.

In [None]:
BATCH_SIZE = 1024
BUFFER_SIZE = 10000
dataset = tf.data.Dataset.from_tensor_slices(((targets, contexts), labels))
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
print(dataset)

<_BatchDataset element_spec=((TensorSpec(shape=(1024,), dtype=tf.int64, name=None), TensorSpec(shape=(1024, 5), dtype=tf.int64, name=None)), TensorSpec(shape=(1024, 5), dtype=tf.int64, name=None))>


Apply `Dataset.cache` and `Dataset.prefetch` to improve performance:

In [None]:
dataset = dataset.cache().prefetch(buffer_size=AUTOTUNE)
print(dataset)

<_PrefetchDataset element_spec=((TensorSpec(shape=(1024,), dtype=tf.int64, name=None), TensorSpec(shape=(1024, 5), dtype=tf.int64, name=None)), TensorSpec(shape=(1024, 5), dtype=tf.int64, name=None))>
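The fixed leading dimension of 1024 in the `element_spec` above comes from `drop_remainder=True`, which discards the final, smaller batch. The following is a minimal pure-Python sketch of that batching arithmetic only. The helper name `batch_indices` and the example count are assumptions made for illustration, and the sketch does not reproduce the buffered shuffle that `tf.data` performs.

```python
def batch_indices(num_examples, batch_size, drop_remainder=True):
  """Return (start, end) index pairs for consecutive batches."""
  batches = []
  for start in range(0, num_examples, batch_size):
    end = min(start + batch_size, num_examples)
    if drop_remainder and end - start < batch_size:
      break  # discard the final, smaller batch
    batches.append((start, end))
  return batches

# Hypothetical example count, chosen only for illustration.
num_examples = 4960
full = batch_indices(num_examples, 1024)
partial = batch_indices(num_examples, 1024, drop_remainder=False)
print(len(full), len(partial))
```

Dropping the remainder trades a few examples per epoch for a static batch shape, which is why the dataset reports `shape=(1024,)` rather than `shape=(None,)`.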


## Model and training

The word2vec model can be implemented as a classifier that distinguishes true context words from skip-grams and false context words obtained through negative sampling. Take the dot product of the target and context word embeddings to obtain logits for each candidate, then compute the loss against the true labels in the dataset.

### Subclassed word2vec model

Use the [Keras Subclassing API](https://www.tensorflow.org/guide/keras/custom_layers_and_models) to define your word2vec model with the following layers:

* `target_embedding`: A `tf.keras.layers.Embedding` layer, which looks up the embedding of a word when it appears as a target word. The number of parameters in this layer is `(vocab_size * embedding_dim)`.
* `context_embedding`: Another `tf.keras.layers.Embedding` layer, which looks up the embedding of a word when it appears as a context word. The number of parameters in this layer is the same as in `target_embedding`, i.e. `(vocab_size * embedding_dim)`.
* `dots`: The dot product of the target and context embeddings from a training pair. The implementation below computes it with `tf.einsum` rather than a separate `tf.keras.layers.Dot` layer.
* `flatten`: Because the `tf.einsum` result already has shape `(batch, context)`, no separate `tf.keras.layers.Flatten` layer is needed to turn the dot products into logits.

With the subclassed model, you can define the `call()` function to accept `(target, context)` pairs, pass each element to its corresponding embedding layer, take the dot product of `context_embedding` with `target_embedding` along the embedding axis, and return the result as logits.

Key point: The `target_embedding` and `context_embedding` layers can be shared as well. You could also use a concatenation of both embeddings as the final word2vec embedding.

In [None]:
class Word2Vec(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim):
    super().__init__()
    self.target_embedding = layers.Embedding(vocab_size,
                                             embedding_dim,
                                             name="w2v_embedding")
    self.context_embedding = layers.Embedding(vocab_size,
                                              embedding_dim)

  def call(self, pair):
    target, context = pair
    # target: (batch, dummy?)  # The dummy axis doesn't exist in TF2.7+
    # context: (batch, context)
    if len(target.shape) == 2:
      target = tf.squeeze(target, axis=1)
    # target: (batch,)
    word_emb = self.target_embedding(target)
    # word_emb: (batch, embed)
    context_emb = self.context_embedding(context)
    # context_emb: (batch, context, embed)
    dots = tf.einsum('be,bce->bc', word_emb, context_emb)
    # dots: (batch, context)
    return dots
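To make the shape bookkeeping in `call()` concrete, here is a small NumPy sketch of the same `'be,bce->bc'` contraction. The toy sizes and random arrays are assumptions standing in for real embedding lookups; `np.einsum` accepts the same subscript string as `tf.einsum` for this case.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_ctx, embed = 4, 5, 8  # toy sizes, assumed for illustration

word_emb = rng.normal(size=(batch, embed))              # stands in for target_embedding(target)
context_emb = rng.normal(size=(batch, num_ctx, embed))  # stands in for context_embedding(context)

# Same contraction as in Word2Vec.call: sum over the embedding axis `e`.
dots = np.einsum('be,bce->bc', word_emb, context_emb)

# Equivalent explicit form: broadcast each target vector over the context axis.
explicit = (word_emb[:, None, :] * context_emb).sum(axis=-1)

print(dots.shape)
```

Each row of `dots` holds one logit per candidate context word for that target, so no reshaping or flattening is needed before the loss.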

### Define loss function and compile model


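For simplicity, you can use `tf.keras.losses.CategoricalCrossentropy` as an alternative to the negative sampling loss. If you would like to write your own custom loss function, you can also do so as follows. Keras calls a loss with the true labels first and the model output second, and the integer labels must be cast to the dtype of the logits:

``` python
def custom_loss(y_true, y_pred):
  # y_pred holds the logits returned by Word2Vec.call.
  labels = tf.cast(y_true, y_pred.dtype)
  return tf.nn.sigmoid_cross_entropy_with_logits(logits=y_pred, labels=labels)
```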
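The two losses mentioned here treat the logits differently: sigmoid cross-entropy scores each candidate context word independently, while categorical cross-entropy applies a softmax across the candidates of one target. The NumPy sketch below is a rough illustration of both, not TensorFlow's implementation. The function names and the toy logits are assumptions, and the label row assumes one positive followed by four negatives.

```python
import numpy as np

def sigmoid_cross_entropy(logits, labels):
  """Elementwise sigmoid cross-entropy, computed stably from logits."""
  x = np.asarray(logits, dtype=np.float64)
  z = np.asarray(labels, dtype=np.float64)
  # max(x, 0) - x*z + log(1 + exp(-|x|)) avoids overflow for large |x|.
  return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))

def softmax_cross_entropy(logits, labels):
  """Categorical cross-entropy over the last axis, from logits."""
  x = np.asarray(logits, dtype=np.float64)
  z = np.asarray(labels, dtype=np.float64)
  shifted = x - x.max(axis=-1, keepdims=True)  # stabilize the exponentials
  log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
  return -(z * log_probs).sum(axis=-1)

# Toy logits for one target with five candidate context words (assumed values).
logits = np.array([[2.0, -1.0, 0.5, -3.0, 0.0]])
labels = np.array([[1, 0, 0, 0, 0]])

print(sigmoid_cross_entropy(logits, labels))  # one loss value per candidate
print(softmax_cross_entropy(logits, labels))  # one loss value per target
```

The sigmoid form is the one that matches the negative sampling objective; the softmax form is the simpler substitute used when compiling the model.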

It's time to build your model! Instantiate your word2vec class with an embedding dimension of 128 (you could experiment with different values). Compile the model with the `tf.keras.optimizers.Adam` optimizer.

In [None]:
embedding_dim = 128
word2vec = Word2Vec(vocab_size, embedding_dim)
word2vec.compile(optimizer='adam',
                 loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                 metrics=['accuracy'])

Also define a callback to log training statistics for TensorBoard:

In [None]:
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="logs")

Train the model on the `dataset` for some number of epochs:

In [None]:
word2vec.fit(dataset, epochs=20, callbacks=[tensorboard_callback])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.src.callbacks.History at 0x7a83b2106350>

TensorBoard now shows the word2vec model's accuracy and loss:

In [None]:
#docs_infra: no_execute
%tensorboard --logdir logs


## Embedding lookup and analysis

Obtain the weights from the model using `Model.get_layer` and `Layer.get_weights`. The `TextVectorization.get_vocabulary` function provides the vocabulary to build a metadata file with one token per line.

Create and save the vectors and metadata files:
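The cell that writes these files is not shown here, so the following is a minimal sketch of one way to do it. The helper name `write_embedding_files` and the toy data are assumptions for illustration, as is the convention that index 0 of the vocabulary is the padding token and is skipped.

```python
import io

def write_embedding_files(weights, vocab,
                          vectors_path='vectors.tsv',
                          metadata_path='metadata.tsv'):
  """Write one tab-separated vector and one token per line."""
  with io.open(vectors_path, 'w', encoding='utf-8') as out_v, \
       io.open(metadata_path, 'w', encoding='utf-8') as out_m:
    for index, word in enumerate(vocab):
      if index == 0:
        continue  # assumption: index 0 is the padding token
      out_v.write('\t'.join(str(x) for x in weights[index]) + '\n')
      out_m.write(word + '\n')

# Toy stand-ins for the trained weights and the vocabulary.
toy_vocab = ['', '[UNK]', 'the', 'king']
toy_weights = [[0.0, 0.0], [0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
write_embedding_files(toy_weights, toy_vocab)
```

With the trained model, `weights` would come from `word2vec.get_layer('w2v_embedding').get_weights()[0]` and `vocab` from the `get_vocabulary()` method of the text vectorization layer used earlier (its variable name is not shown in this section). Row `i` of `vectors.tsv` then lines up with row `i` of `metadata.tsv`, which is the pairing the Embedding Projector expects.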

Download the `vectors.tsv` and `metadata.tsv` to analyze the obtained embeddings in the [Embedding Projector](https://projector.tensorflow.org/):

In [None]:
try:
  from google.colab import files
  files.download('vectors.tsv')
  files.download('metadata.tsv')
except Exception:
  pass

## Next steps


This tutorial has shown you how to implement a skip-gram word2vec model with negative sampling from scratch and visualize the obtained word embeddings.

* To learn more about word vectors and their mathematical representations, refer to these [notes](https://web.stanford.edu/class/cs224n/readings/cs224n-2019-notes01-wordvecs1.pdf).

* To learn more about advanced text processing, read the [Transformer model for language understanding](https://www.tensorflow.org/tutorials/text/transformer) tutorial.

* If you’re interested in pre-trained embedding models, see [Exploring the TF-Hub CORD-19 Swivel Embeddings](https://www.tensorflow.org/hub/tutorials/cord_19_embeddings_keras) or the [Multilingual Universal Sentence Encoder](https://www.tensorflow.org/hub/tutorials/cross_lingual_similarity_with_tf_hub_multilingual_universal_encoder).

* You may also like to train the model on a new dataset (there are many available in [TensorFlow Datasets](https://www.tensorflow.org/datasets)).
