# Fruit-Fly-Net demo 🍄

In this notebook we'll show how to use our implementation of the fruit fly network (Fruit-Fly-Net) introduced in [Can a Fruit Fly Learn Word Embeddings?](https://arxiv.org/abs/2101.06887) to create and compare word embeddings from Pokedex entries. 🐸🌱🐢🌊🦎🔥

Useful links 👀:
* [paper](https://arxiv.org/pdf/2101.06887.pdf)
* [our repo](https://github.com/Ramos-Ramos/fruit-fly-net)
* [authors' official repo](https://github.com/bhoov/flyvec)

**Disclaimers 🚨:**
* we're not the original authors, we just took a crack at implementing it
* we haven't tried reproducing their results

**Note: This notebook requires a GPU runtime for training** (`Runtime > Change runtime type > Hardware accelerator > GPU`)

## How does it work? 🤔


Fruit-Fly-Net creates word embeddings by trying to learn the correlations between words and their contexts.

Given a vocabulary of $N_{voc}$ tokens, Fruit-Fly-Net takes in a token and its context in the form of a binary input vector $v^A$ of length $2 \times N_{voc}$, where the first $N_{voc}$ dimensions form a bag-of-words representation of the context words and the remaining $N_{voc}$ dimensions form a one-hot encoding of the target word. These input vectors are created from n-grams (which the authors refer to as w-grams) taken from the training corpus. The center element of each w-gram becomes the target while the surrounding elements comprise the context.

<table>
  <tr><td colspan=6><center>"Charizard breathes flames"</center></td></tr>
  <tr>
    <td>breathes</td><td>charizard</td><td>flames</td>
    <td>breathes</td><td>charizard</td><td>flames</td>
  </tr>
  <tr>
    <td>0</td><td>1</td><td>1</td>
    <td>1</td><td>0</td><td>0</td>
  </tr>
</table>
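
As a rough illustration of this encoding, here's a minimal sketch that rebuilds the vector from the table in plain Python. The `build_input_vector` helper and the alphabetically sorted toy vocabulary are our own assumptions for this example; they are not part of `fruit_fly_net`.

```python
def build_input_vector(w_gram, vocabulary):
    """Return the binary [context | target] vector for one w-gram (a list of tokens)."""
    n_voc = len(vocabulary)
    index = {token: i for i, token in enumerate(vocabulary)}
    middle = len(w_gram) // 2
    vector = [0] * (2 * n_voc)
    for position, token in enumerate(w_gram):
        if position == middle:
            vector[n_voc + index[token]] = 1  # one-hot target block (last n_voc dims)
        else:
            vector[index[token]] = 1          # bag-of-words context block (first n_voc dims)
    return vector


# toy vocabulary, assumed to be sorted alphabetically as in the table
toy_vocabulary = ["breathes", "charizard", "flames"]
toy_vector = build_input_vector(["charizard", "breathes", "flames"], toy_vocabulary)
print(toy_vector)  # [0, 1, 1, 1, 0, 0]
```

The first three entries mark the context words ("charizard" and "flames"), and the last three mark the target word ("breathes").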

Fruit-Fly-Net projects this input vector to $K$ dimensions, of which the top $k$ activations are set to 1 while the rest are suppressed to 0. To update the projection weights, Fruit-Fly-Net requires a $2 \times N_{voc}$-dimensional vector $p$, which is the concatenation of two copies of the vector holding each token's probability of appearing in the training set.
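
To make the projection and top-$k$ step concrete, here's a small sketch in plain Python. The weights are random stand-ins rather than trained values, and `project` and `top_k_hash` are hypothetical helpers written for this illustration, not the `fruit_fly_net` API.

```python
import random


def project(weights, vector):
    """Multiply a (K x input_dim) weight matrix by an input vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in weights]


def top_k_hash(activations, k):
    """Set the k largest activations to 1 and all others to 0."""
    top = sorted(range(len(activations)), key=lambda i: activations[i], reverse=True)[:k]
    code = [0] * len(activations)
    for i in top:
        code[i] = 1
    return code


random.seed(0)
n_voc, K, k = 3, 8, 2  # toy sizes; the notebook uses K=400 and k=51

# random stand-in weights (assumption: a trained model would supply these)
weights = [[random.random() for _ in range(2 * n_voc)] for _ in range(K)]

v = [0, 1, 1, 1, 0, 0]  # input vector from the "Charizard breathes flames" example
code = top_k_hash(project(weights, v), k)

# p is two copies of the token probabilities, one per half of the input vector
token_probabilities = [0.5, 0.3, 0.2]  # made-up values for illustration
p = token_probabilities * 2
```

Here `code` has length $K$ with exactly $k$ ones, and `p` has length $2 \times N_{voc}$.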

## Installations and imports 🔧
**Note**: You'll have to restart the runtime after installing the packages

In [None]:
pip install -U einops gradio numpy spacy git+https://github.com/Ramos-Ramos/fruit-fly-net

In [None]:
!python -m spacy download en_core_web_sm

In [None]:
from einops import rearrange
import cupy as cp
import cupy as xp
import numpy as np
import gradio as gr
import pandas as pd
from cupyx.scipy.sparse import csr_matrix, vstack
import spacy
from sklearn.metrics.pairwise import cosine_similarity
from tqdm.notebook import tqdm

from collections import Counter, OrderedDict
import pickle

from fruit_fly_net import FruitFlyNet, bio_hash_loss

## Tokenizing the dataset 🧩

The corpus from which we'll create word embeddings consists of several Pokedex entries. To create word embeddings, we first need a list of words. We get one by tokenizing the corpus, i.e. splitting it into words, or "tokens", and collecting the distinct tokens into a vocabulary. For Fruit-Fly-Net to work, we also need the probability of each token.

We download our corpus below. It's a CSV file, which we can open with pandas.

In [None]:
!wget https://raw.githubusercontent.com/Ramos-Ramos/fruit-fly-net/demo/dex.csv

In [None]:
df = pd.read_csv('dex.csv')
print('shape:', df.shape)
df.head()

To create longer pieces of text, we can combine Pokedex entries coming from the same Pokemon. This list of concatenated entries will be our corpus.

In [None]:
corpus = df.groupby('name').description.apply(' '.join)
print('shape:', corpus.shape)
corpus.head()

Now we proceed to the actual tokenization. We use SpaCy for this.

In [None]:
nlp = spacy.load('en_core_web_sm')

Our tokenization function splits a piece of text into tokens using SpaCy, lemmatizes and lowercases them, and ignores tokens that are punctuation, numbers, or stop words. Note that this is different from the tokenization process of the original authors.

In [None]:
def create_tokens_from_text(text):
  """Tokenizes text by:
  - splitting with SpaCy
  - lemmatizing and lowercasing each token
  - ignoring punctuation, numbers, and stop words
  """
  return [w.lemma_.lower() for w in nlp(text)
          if not w.is_punct and
          not w.like_num and
          w.lemma_.lower() not in nlp.Defaults.stop_words]

To finally create the vocabulary, we iterate over the corpus and tokenize the texts using our tokenization function. We can then calculate the probability of each token, i.e. its share of all token occurrences kept in the vocabulary. Following the authors, we start with an initial $N_{voc}$ of $20,000$, but our final vocabulary ends up being much smaller (~$6,500$).
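
The probability calculation can be sanity-checked on a toy token list before running it on the full corpus. In this sketch, `token_probabilities` and the sample tokens are made up for illustration; the logic follows the clipping and normalization described above.

```python
from collections import Counter


def token_probabilities(tokens, max_vocabulary_size):
    """Keep the most common tokens and normalize their counts to probabilities."""
    counts = dict(Counter(tokens).most_common(max_vocabulary_size))
    total = sum(counts.values())
    return {token: count / total for token, count in counts.items()}


toy_tokens = ["fire", "flame", "fire", "wing", "fire", "flame"]
toy_probabilities = token_probabilities(toy_tokens, max_vocabulary_size=2)
print(toy_probabilities)  # {'fire': 0.6, 'flame': 0.4}
```

Note that "wing" is clipped from the vocabulary, so the remaining probabilities are normalized over the kept tokens only and still sum to 1.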

In [None]:
tokens = []
init_vocabulary_size = 20000
batch_size = 100

# create tokens
for batch_start in tqdm(range(0, len(corpus), batch_size)):
  tokens += create_tokens_from_text(
      ' '.join(corpus.iloc[batch_start:batch_start+batch_size])
  )

# clip vocabulary if necessary and calculate probabilities
tokens_to_counts = dict(Counter(tokens).most_common(init_vocabulary_size))
total_count = sum(tokens_to_counts.values())
tokens_to_probabilities = {token : count / total_count for token, count in tokens_to_counts.items()}

# finalize vocabulary size, vocabulary and probabilities
vocabulary = list(tokens_to_probabilities.keys())
probabilities = xp.tile(xp.array(list(tokens_to_probabilities.values())), 2)
vocabulary_size = len(tokens_to_counts)

print(f'vocabulary size: {vocabulary_size}')

## Preparing trainset 🔨

We have the vocabulary and probabilities, but we still need to create a training set in the format accepted by Fruit-Fly-Net, as described in the "How does it work? 🤔" section. We create two helper functions: one that converts text into token IDs (the index of each token in our vocabulary) and one that builds the actual training input embeddings.
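
Here's a dependency-free sketch of what the second helper has to do: slide a window over the token IDs and shift the middle ID into the target half of the vector. The function names and the toy IDs are assumptions for illustration only.

```python
def sliding_w_grams(token_ids, w_gram_size):
    """Return every contiguous window of w_gram_size token IDs."""
    w = min(w_gram_size, len(token_ids))
    if w == 0:
        return []
    return [token_ids[i:i + w] for i in range(len(token_ids) - w + 1)]


def w_gram_to_active_indices(w_gram, vocabulary_size):
    """Return the sorted positions that are set to 1 in the input vector."""
    middle = len(w_gram) // 2
    return sorted({
        token_id + vocabulary_size if position == middle else token_id
        for position, token_id in enumerate(w_gram)
    })


toy_token_ids = [4, 0, 2, 1, 3]  # made-up IDs for a vocabulary of 5 tokens
toy_w_grams = sliding_w_grams(toy_token_ids, w_gram_size=3)
print(toy_w_grams)  # [[4, 0, 2], [0, 2, 1], [2, 1, 3]]

toy_active = [w_gram_to_active_indices(w_gram, vocabulary_size=5) for w_gram in toy_w_grams]
print(toy_active)  # [[2, 4, 5], [0, 1, 7], [2, 3, 6]]
```

In each result, indices below 5 are context words and indices from 5 upward are the target word.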

In [None]:
def create_token_ids_from_text(text, vocabulary):
  """Creates tokens from text then gets corresponding indices for tokens in the
  vocabulary
  """
  tokens = create_tokens_from_text(text)
  token_ids = [vocabulary.index(token) for token in tokens if token in vocabulary]
  return token_ids

In [None]:
def create_training_embeddings_from_token_ids(token_ids, w_gram_size, vocabulary_size):
  """Creates several w-grams, then creates input training embeddings by having
  the middle token be the target and the rest be the context
  """

  # create w-grams
  w_gram_size = min(w_gram_size, len(token_ids))
  middle_idx = w_gram_size//2
  w_grams = xp.array(np.lib.stride_tricks.sliding_window_view(token_ids, w_gram_size))
  w_grams[:, middle_idx] += vocabulary_size

  # create training embeddings
  training_embeddings = xp.zeros((w_grams.shape[0], vocabulary_size*2))
  training_embeddings[xp.indices(w_grams.shape)[0], w_grams] = 1
  training_embeddings = training_embeddings.astype(xp.bool_)

  return training_embeddings

All we have to do now is iterate over our corpus and create the training embeddings. We use a w-gram size of $15$.

In [None]:
w_gram_size = 15

training_embeddings = []
for text in tqdm(corpus):
  token_ids = create_token_ids_from_text(text, vocabulary)
  
  training_embeddings.append(
      csr_matrix(create_training_embeddings_from_token_ids(
          token_ids, w_gram_size, len(vocabulary)
      ))
  )
training_embeddings = vstack(training_embeddings)

## Training Fruit-Fly-Net 💪

Let's instantiate our Fruit-Fly-Net now. We use $K=400$, $k=51$, and a learning rate of $10^{-6}$.

In [None]:
model = FruitFlyNet(
  input_dim=vocabulary_size*2,  # input dimension size (vocab_size * 2)
  output_dim=400,               # output dimension size
  k=51,                         # top k cells to be left active in output layer
  lr=1e-6                       # learning rate (learning is performed internally)
)
model.to('gpu')

For each epoch in our training loop, we shuffle our training set and iterate over each batch. For each batch, we feed the inputs to the model. The weight update is performed internally. We also print out the average loss every 1000 batches and at the end of each epoch. We use a batch size of $32$.
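
The shuffle-then-batch pattern can be tried out on plain row indices. `shuffled_batches` below is a hypothetical helper for illustration, and it only handles the indexing, not the model update.

```python
import random


def shuffled_batches(num_rows, batch_size, rng):
    """Yield lists of row indices that cover a shuffled dataset once."""
    order = list(range(num_rows))
    rng.shuffle(order)
    for start in range(0, num_rows, batch_size):
        yield order[start:start + batch_size]


rng = random.Random(0)
toy_batches = list(shuffled_batches(num_rows=10, batch_size=4, rng=rng))
print([len(batch) for batch in toy_batches])  # [4, 4, 2]
```

The last batch is smaller whenever the number of rows isn't a multiple of the batch size, which is worth keeping in mind when averaging the loss.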

In [None]:
batch_size = 32

loss = 0
epochs = 10
for epoch in range(epochs):
  
  # shuffle trainset
  shuffled_idxs = xp.random.permutation(training_embeddings.shape[0])
  training_embeddings = training_embeddings[shuffled_idxs]
    
  for batch_start in tqdm(range(0, training_embeddings.shape[0], batch_size)):
    
    # train step
    input = training_embeddings[batch_start:batch_start+batch_size].toarray()
    model(input, probabilities)
    
    # get loss
    loss += bio_hash_loss(model.weights, input, probabilities)
    
    # print metrics every 1000 batches
    if batch_start//batch_size % 1000 == 999:
      print(f'epoch {epoch:2d} batch {batch_start//batch_size:4d}:\t{loss/(batch_size*1000):.3f}')
      loss = 0
        
  # print metrics after each epoch
  print(f'epoch {epoch:2d} batch {batch_start//batch_size:4d}:\t{loss/(batch_size*((training_embeddings.shape[0]//batch_size)%1000)):.3f}')
  loss = 0

## Optional: Switching to a CPU runtime ⚙️

Before proceeding to the interactive demo in the next section, you may switch to a CPU runtime if you'd like. To do so, follow the steps below.

1. Save the vocabulary, probabilities, and model weights. Make sure to download the files after saving.

In [None]:
# save vocabulary
with open('vocab.pkl', 'wb') as vocab_file:
  pickle.dump(vocabulary, vocab_file)

# save probabilities (copied to host memory first, so NumPy can write them)
with open('prob.npy', 'wb') as prob_file:
  np.save(prob_file, cp.asnumpy(probabilities))

# save model weights
with open('weights.pkl', 'wb') as file:
  pickle.dump(model.state_dict(), file)

2. Shut down this runtime by going to `Runtime > Factory Reset Runtime`. Then switch to a CPU runtime by going to `Runtime > Change Runtime Type > Hardware Accelerator > None`. After starting a new runtime, upload the `vocab.pkl`, `prob.npy`, and `weights.pkl` files to `/content/`.

3. Redo some installations, imports, and downloads.

In [None]:
pip install -U einops gradio spacy git+https://github.com/Ramos-Ramos/fruit-fly-net

In [None]:
!python -m spacy download en_core_web_sm

In [None]:
from einops import rearrange
import cupy as cp
import numpy as xp
import numpy as np
import gradio as gr
import pandas as pd
import spacy
from sklearn.metrics.pairwise import cosine_similarity
from tqdm.notebook import tqdm

from collections import OrderedDict
import pickle

from fruit_fly_net import FruitFlyNet, bio_hash_loss

4. Redefine functions, reinstantiate classes, and load vocabulary, probabilities, and model weights.

In [None]:
# load vocabulary
with open('vocab.pkl', 'rb') as vocab_file:
  vocabulary = pickle.load(vocab_file)
  vocabulary_size = len(vocabulary)

# load probabilities
with open('prob.npy', 'rb') as prob_file:
  probabilities = xp.load(prob_file)

# tokenization functions and classes
nlp = spacy.load('en_core_web_sm')

def create_tokens_from_text(text):
  """Tokenizes text by:
  - splitting with SpaCy
  - lemmatizing and lowercasing each token
  - ignoring punctuation, numbers, and stop words
  """
  return [w.lemma_.lower() for w in nlp(text)
          if not w.is_punct and
          not w.like_num and
          w.lemma_.lower() not in nlp.Defaults.stop_words]

# reinstantiate model and load weights
model = FruitFlyNet(
  input_dim=vocabulary_size*2,  # input dimension size (vocab_size * 2)
  output_dim=400,               # output dimension size
  k=51,                         # top k cells to be left active in output layer
  lr=1e-6                       # learning rate (learning is performed internally)
)

with open('weights.pkl', 'rb') as file:
  model.load_state_dict(pickle.load(file))

## Interactive demo ⌨️

Here we use Gradio to let you perform similarity search with static word embeddings.

The inputs for static embeddings differ from the training inputs in that they ignore context: the first $N_{voc}$ (context) dimensions are all zero, and only the last $N_{voc}$ dimensions are used, holding the one-hot encoded target word. Let's start with a helper function that creates this type of input from a token and a vocabulary.

In [None]:
def create_static_input_embedding_from_token(token, vocabulary):
  token = (create_tokens_from_text(token)+[None])[0]
  id = None if token not in vocabulary else vocabulary.index(token) + len(vocabulary)
  input_embedding = xp.zeros(len(vocabulary)*2)
  if id is not None:
    input_embedding[id] = 1
  return input_embedding

Now let's create static input embeddings for each token in our vocabulary.

In [None]:
static_input_embeddings = []
for token in tqdm(vocabulary):
  static_input_embeddings.append(
      create_static_input_embedding_from_token(token, vocabulary)
  )
static_input_embeddings = xp.stack(static_input_embeddings)

We then feed each input embedding into our model to create a static embedding for each token.

In [None]:
batch_size = 32
model.eval()
static_embeddings = []
for batch_start in tqdm(range(0, static_input_embeddings.shape[0], batch_size)):
  input = static_input_embeddings[batch_start:batch_start+batch_size]
  static_embeddings.append(model(input, probabilities))
static_embeddings = xp.concatenate(static_embeddings)

Now we can find the $n$ most similar words for a given input word (e.g. "fire", "wing", "night").
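
For intuition, here's a toy version of the ranking step in plain Python. The binary codes in `toy_codes` are invented for this example (real ones come from the trained model), and `cosine` and `most_similar` are hypothetical helpers.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a) * sum(y * y for y in b))
    return dot / norm if norm else 0.0


def most_similar(query, codes, n):
    """Rank every token except the query by cosine similarity to the query's code."""
    scores = {token: cosine(codes[query], code)
              for token, code in codes.items() if token != query}
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)[:n]


# invented binary codes, standing in for the model's top-k outputs
toy_codes = {
    "fire":  [1, 1, 0, 0],
    "flame": [1, 1, 0, 0],
    "ember": [1, 0, 1, 0],
    "water": [0, 0, 1, 1],
}
toy_ranking = most_similar("fire", toy_codes, n=2)
print(toy_ranking)  # [('flame', 1.0), ('ember', 0.5)]
```

Because the codes are binary with the same number of active cells, the cosine similarity here is simply the fraction of active cells two tokens share.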

Have fun!

In [None]:
def get_top_similar_tokens_with_scores(token, top_similar):
  
  token = (create_tokens_from_text(token)+[None])[0]
  id = None if token not in vocabulary else vocabulary.index(token)
  if id is None:
    return {'out of vocabulary': 1.0}
  
  input_embedding = create_static_input_embedding_from_token(token, vocabulary)
  input_embedding = rearrange(input_embedding, 'd -> () d')
  
  model.eval()
  embedding = model(input_embedding, probabilities)
  
  similarities = cosine_similarity(
      cp.asnumpy(embedding), cp.asnumpy(static_embeddings)
  )
  similarities = rearrange(similarities, '() i -> i')
  
  current_vocabulary = vocabulary
  if id is not None:
    similarities = np.concatenate((similarities[:id], similarities[id+1:]))
    current_vocabulary = current_vocabulary[:id]+current_vocabulary[id+1:]

  top_similar_ids = similarities.argsort(kind='stable')[-top_similar:].tolist()
  top_similar_scores = similarities[top_similar_ids]
  top_similar_tokens = [current_vocabulary[id] for id in top_similar_ids]
  return OrderedDict(zip(top_similar_tokens, top_similar_scores))

# slider for the number of similar tokens to return (gr.Slider requires Gradio 3 or newer)
r = gr.Slider(minimum=1, maximum=20, step=1, value=10)
gr.Interface(fn=get_top_similar_tokens_with_scores, inputs=['text', r], outputs='label').launch()