# Using ELMo Embeddings

In [1]:
import numpy as np
import os
import time
import random
import tensorflow as tf

# Not allocating full GPU memory upfront
%env TF_FORCE_GPU_ALLOW_GROWTH=true
# Making sure we cache the models and are not downloaded all the time
%env TFHUB_CACHE_DIR=./tfhub_modules

env: TF_FORCE_GPU_ALLOW_GROWTH=true
env: TFHUB_CACHE_DIR=./tfhub_modules


## Using pre-trained ELMo Model

### Downloading the ELMo Model from TFHub

In [2]:
import tensorflow_hub as hub
import tensorflow as tf
import tensorflow.keras.backend as K
K.clear_session()

elmo_layer = hub.KerasLayer("https://tfhub.dev/google/elmo/3",
                            signature="tokens",
                            signature_outputs_as_dict=True)

### Formatting the input for ELMo

ELMo expects the inputs to be in a specific format. Here we write a function to get the input in that format.

In [3]:
def format_text_for_elmo(texts, lower=True, split=" ", max_len=None):
  """Formats a given text for the ELMo model (takes in a list of strings)"""
  token_inputs = [] # Maintains individual tokens
  token_lengths = [] # Maintains the Length of each sequence

  max_len_inferred = 0
  # We keep a variable to maintain the max length of the input
  # Go through each text (string)
  for text in texts:
    # Process the text and get a list of tokens
    tokens = tf.keras.preprocessing.text.text_to_word_sequence(text, lower=lower, split=split)

    # Add the tokens
    token_inputs.append(tokens)

    # Compute the max length for the collection of sequences
    if len(tokens) > max_len_inferred:
      max_len_inferred = len(tokens)

  # It's important to make sure the maximum token length is only as large as the longest input in the sequence
  # You can't have arbitrarily large length as the maximum length. Otherwise, you'll get this error.
  #InvalidArgumentError:  Incompatible shapes: [2,6,1] vs. [2,10,1024]
  #    [[node mul (defined at .../python3.6/site-packages/tensorflow_hub/module_v2.py:106) ]] [Op:__inference_pruned_3391]
  # Here we make sure max_len is only as large as the longest input
  if max_len and max_len_inferred < max_len:
    max_len = max_len_inferred
  if not max_len:
    max_len = max_len_inferred

  # Go through each token sequence and modify sequences to have same length
  for i, token_seq in enumerate(token_inputs):
    token_lengths.append(min(len(token_seq), max_len))

    # If the maximum length is less than input length, truncate
    if max_len < len(token_seq):
      token_seq = token_seq[:max_len]

    # If the maximum length is greater than or equal to input length,
    # add padding as needed
    else:
      token_seq = token_seq + [""] * (max_len - len(token_seq))

    assert len(token_seq) == max_len

    token_inputs[i] = token_seq

  # Return the final output
  return {
      "tokens": tf.constant(token_inputs),
      "sequence_len": tf.constant(token_lengths)
  }

print(format_text_for_elmo(["the cat sat on the mat", "the mat sat"], max_len=10))

{'tokens': <tf.Tensor: shape=(2, 6), dtype=string, numpy=
array([[b'the', b'cat', b'sat', b'on', b'the', b'mat'],
       [b'the', b'mat', b'sat', b'', b'', b'']], dtype=object)>, 'sequence_len': <tf.Tensor: shape=(2,), dtype=int32, numpy=array([6, 3], dtype=int32)>}


## Generating Document Embeddings with ELMo

### Downloading the data

This code downloads a [BBC dataset](hhttp://mlg.ucd.ie/files/datasets/bbc-fulltext.zip) consisting of news articles published by BBC.

In [4]:
from six.moves.urllib.request import urlretrieve
import zipfile
import numpy as np
import pandas as pd
import os
import time
import random
import tensorflow as tf
from matplotlib import pylab
from scipy.sparse import lil_matrix

url = 'http://mlg.ucd.ie/files/datasets/bbc-fulltext.zip'

def download_data(url, data_dir):
    """Download a file if not present, and make sure it's the right size."""

    # Create the data directory if not exist
    os.makedirs(data_dir, exist_ok=True)

    file_path = os.path.join(data_dir, 'bbc-fulltext.zip')

    # If file doesnt exist, download
    if not os.path.exists(file_path):
        print('Downloading file...')
        filename, _ = urlretrieve(url, file_path)
    else:
        print("File already exists")

    extract_path = os.path.join(data_dir, 'bbc')

    # If data has not been extracted already, extract data
    if not os.path.exists(extract_path):
        with zipfile.ZipFile(os.path.join(data_dir, 'bbc-fulltext.zip'), 'r') as zipf:
            zipf.extractall(data_dir)
    else:
        print("bbc-fulltext.zip has already been extracted")

download_data(url, 'data')

File already exists
bbc-fulltext.zip has already been extracted


### Read Data without Preprocessing

Here we read all the files and keep them as a list of strings, where each string is a single article

In [5]:
def read_data(data_dir):
    # This will contain the full list of stories
    news_stories = []
    filenames = []
    print("Reading files")

    i = 0 # Just used for printing progress
    for root, dirs, files in os.walk(data_dir):
        for fi, f in enumerate(files):
            # We don't read the readme file
            if 'README' in f:
                continue

            # Printing progress
            i += 1
            print("."*i, f, end='\r')

            # Open the file
            with open(os.path.join(root, f), encoding='latin-1') as text_file:
                story = []
                # Read all the lines
                for row in text_file:
                    story.append(row.strip())

                # Create a single string with all the rows in the doc
                story = ' '.join(story)
                # Add that to the list
                news_stories.append(story)
                filenames.append(os.path.join(root, f))

        print('', end='\r')

    print(f"\nDetected {len(news_stories)} stories")
    return news_stories, filenames


news_stories, filenames = read_data(os.path.join('data', 'bbc'))

# Printing some stats and sample data
print(f"{sum([len(story.split(' ')) for story in news_stories])} words found in the total news set")
print('Example words (start): ',news_stories[0][:50])
print('Example words (end): ',news_stories[-1][-50:])

Reading files
..........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

### Check the length statistics

Here we look at the 95-percientile in order to decide a good sequence length for inputs.

In [6]:
import pandas as pd

pd.Series([len(x.split(' ')) for x in news_stories]).describe(percentiles=[0.05, 0.95])

count    2225.000000
mean      388.837303
std       241.484273
min        91.000000
5%        164.200000
50%       336.000000
95%       736.800000
max      4489.000000
dtype: float64

### Compute the document embeddings

ELMo provides several outputs as the output (in the form of a dictionary). The most important output is in a key called `default` which is the averaged vector resulting from vectors produced for all the tokens in the input. We will use this as the document embedding.

In [7]:
batch_size = 4

news_elmo_embeddings = []

# Go through batches
for i in range(0, len(news_stories), batch_size):
    # Print progress
    print('.', end='')
    # Format ELMo inputs
    elmo_inputs = format_text_for_elmo(news_stories[i: min(i+batch_size, len(news_stories))], max_len=768)
    # Get the result stored in default
    elmo_result = elmo_layer(elmo_inputs)["default"]
    # Add that to a list
    news_elmo_embeddings.append(elmo_result)

# Create an array
news_elmo_embeddings = np.concatenate(news_elmo_embeddings, axis=0)

.............................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

### Save the embeddings to disk

In [8]:
# Save the data to disk
os.makedirs('elmo_embeddings', exist_ok=True)

pd.DataFrame(
    news_elmo_embeddings, index=filenames
).to_pickle(
    os.path.join('elmo_embeddings', 'elmo_embeddings.pkl')
)

In [9]:
pd.read_pickle(os.path.join('elmo_embeddings', 'elmo_embeddings.pkl'))

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023
data/bbc/entertainment/181.txt,-0.054728,-0.227087,0.073870,-0.149917,0.202417,0.271801,-0.076706,0.124789,-0.167723,0.089672,...,-0.154803,0.237685,-0.058971,0.215510,0.129800,-0.084038,0.265275,-0.026663,0.576896,-0.011951
data/bbc/entertainment/162.txt,-0.257251,-0.286144,0.063412,0.063555,0.104869,0.496734,-0.113501,0.186762,0.103060,-0.040906,...,-0.333807,-0.065240,0.103234,0.178511,0.067251,-0.115650,-0.244925,-0.036662,0.453873,0.024653
data/bbc/entertainment/015.txt,-0.188401,-0.300901,0.023362,-0.128732,0.306804,0.241635,0.068264,0.128954,-0.128111,-0.200275,...,-0.253069,0.376615,-0.215489,0.362154,-0.009835,-0.058118,0.197030,0.056924,0.493857,-0.102432
data/bbc/entertainment/102.txt,-0.158987,-0.031365,0.076649,-0.037862,0.203332,0.194937,-0.061919,0.553116,0.235470,-0.074756,...,-0.115353,0.251829,-0.071439,0.164734,0.070648,0.282260,0.116603,0.053733,0.404045,-0.062439
data/bbc/entertainment/305.txt,0.086791,-0.278696,0.052923,-0.044659,0.348028,0.259595,0.103649,-0.065379,0.075237,-0.097908,...,-0.219957,0.270185,-0.175073,0.358120,0.088501,-0.005185,0.214776,0.000259,0.570177,-0.131550
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
data/bbc/tech/031.txt,0.078434,0.149747,-0.139460,0.167651,0.079742,-0.069641,-0.108829,-0.009174,-0.392660,0.045349,...,-0.526442,0.025496,-0.035541,0.098737,0.065343,0.070786,0.487475,-0.062433,0.374206,-0.247855
data/bbc/tech/394.txt,0.218646,0.029468,0.109840,-0.111590,0.289563,0.369362,0.142864,0.397531,-0.280676,0.065544,...,-0.239846,0.270525,0.012586,0.197014,0.145879,0.039718,0.171691,0.011018,0.252107,0.038584
data/bbc/tech/275.txt,0.033928,-0.024220,-0.081310,-0.139741,0.224108,0.348680,-0.081645,0.267504,-0.216363,0.018529,...,-0.225656,0.249019,0.104026,0.244571,0.137817,0.185246,0.563392,0.027944,0.352908,-0.067505
data/bbc/tech/079.txt,0.039464,0.112927,-0.202038,0.005816,0.145734,-0.210776,-0.180639,-0.084109,-0.262022,0.111711,...,-0.463283,0.042359,-0.063435,0.195232,-0.050059,0.004161,0.433020,-0.016625,0.245562,-0.294044


---