In [1]:
# These are all the modules we'll be using later. Make sure you can import them
# before proceeding further.
import numpy as np
import os
import time
import random
import tensorflow as tf

%env TF_FORCE_GPU_ALLOW_GROWTH=true
# Making sure we cache the models and are not downloaded all the time
%env TFHUB_CACHE_DIR=./tfhub_modules

env: TF_FORCE_GPU_ALLOW_GROWTH=true
env: TFHUB_CACHE_DIR=./tfhub_modules


## Using pre-trained ELMo Model

### Downloading the ELMo Model from TFHub

In [2]:
import tensorflow_hub as hub
import tensorflow.keras.backend as K

# Remove any ongoing sessions
K.clear_session()

# Download the ELMo model and save to disk
elmo_layer = hub.KerasLayer("https://tfhub.dev/google/elmo/3", signature="tokens",signature_outputs_as_dict=True)

### Formatting the input for ELMo

ELMo expects the inputs to be in a specific format. Here we write a function to get the input in that format.

In [4]:
def format_text_for_elmo(texts, lower=True, split=" ", max_len=None):
    
    """ Formats a given text for the ELMo model (takes in a list of strings) """
        
    token_inputs = [] # Maintains individual tokens
    token_lengths = [] # Maintains the length of each sequence
    
    max_len_inferred = 0 # We keep a variable to matain the max length of the input
    
    # Go through each text (string)
    for text in texts:    
        
        # Process the text and get a list of tokens
        tokens = tf.keras.preprocessing.text.text_to_word_sequence(text, lower=lower, split=split)
        
        # Add the tokens 
        token_inputs.append(tokens)                   
        
        # Compute the max length for the collection of sequences
        if len(tokens)>max_len_inferred:
            max_len_inferred = len(tokens)
    
    # It's important to make sure the maximum token length is only as large as the longest input in the sequence
    # You can't have arbitrarily large length as the maximum length. Otherwise, you'll get this error.
    #InvalidArgumentError:  Incompatible shapes: [2,6,1] vs. [2,10,1024]
    #    [[node mul (defined at .../python3.6/site-packages/tensorflow_hub/module_v2.py:106) ]] [Op:__inference_pruned_3391]
    
    # Here we make sure max_len is only as large as the longest input
    if max_len and max_len_inferred < max_len:
        max_len = max_len_inferred
    if not max_len:
        max_len = max_len_inferred
    
    # Go through each token sequence and modify sequences to have same length
    for i, token_seq in enumerate(token_inputs):
        
        token_lengths.append(min(len(token_seq), max_len))
        
        # If the maximum length is less than input length, truncate
        if max_len < len(token_seq):
            token_seq = token_seq[:max_len]            
        # If the maximum length is greater than or equal to input length, add padding as needed
        else:            
            token_seq = token_seq+[""]*(max_len-len(token_seq))
                
        assert len(token_seq)==max_len
        
        token_inputs[i] = token_seq
    
    # Return the final output
    return {
        "tokens": tf.constant(token_inputs), 
        "sequence_len": tf.constant(token_lengths)
    }


print(format_text_for_elmo(["the cat sat on the mat", "the mat sat"], max_len=10))

{'tokens': <tf.Tensor: shape=(2, 6), dtype=string, numpy=
array([[b'the', b'cat', b'sat', b'on', b'the', b'mat'],
       [b'the', b'mat', b'sat', b'', b'', b'']], dtype=object)>, 'sequence_len': <tf.Tensor: shape=(2,), dtype=int32, numpy=array([6, 3])>}


In [5]:
# Titles of 001.txt - 005.txt in bbc/business
elmo_inputs = format_text_for_elmo([
    "Ad sales boost Time Warner profit",
    "Dollar gains on Greenspan speech",
    "Yukos unit buyer faces loan claim",
    "High fuel prices hit BA's profits",
    "Pernod takeover talk lifts Domecq"
])

# Get the result from ELMo
elmo_result = elmo_layer(elmo_inputs)

# Print the result
for k,v in elmo_result.items():    
    print(f"Tensor under key={k} is a {v.shape} shaped Tensor")

Tensor under key=word_emb is a (5, 6, 512) shaped Tensor
Tensor under key=lstm_outputs2 is a (5, 6, 1024) shaped Tensor
Tensor under key=sequence_len is a (5,) shaped Tensor
Tensor under key=elmo is a (5, 6, 1024) shaped Tensor
Tensor under key=lstm_outputs1 is a (5, 6, 1024) shaped Tensor
Tensor under key=default is a (5, 1024) shaped Tensor


## Generating Document Embeddings with ELMo

### Downloading the data

This code downloads a [BBC dataset](hhttp://mlg.ucd.ie/files/datasets/bbc-fulltext.zip) consisting of news articles published by BBC. 

In [6]:
url = 'http://mlg.ucd.ie/files/datasets/bbc-fulltext.zip'


def download_data(url, data_dir):
    """Download a file if not present, and make sure it's the right size."""
    
    # Create the data directory if not exist
    os.makedirs(data_dir, exist_ok=True)

    file_path = os.path.join(data_dir, 'bbc-fulltext.zip')
    
    # If file doesnt exist, download
    if not os.path.exists(file_path):
        print('Downloading file...')
        filename, _ = urlretrieve(url, file_path)
    else:
        print("File already exists")
  
    extract_path = os.path.join(data_dir, 'bbc')
    
    # If data has not been extracted already, extract data
    if not os.path.exists(extract_path):        
        with zipfile.ZipFile(os.path.join(data_dir, 'bbc-fulltext.zip'), 'r') as zipf:
            zipf.extractall(data_dir)
    else:
        print("bbc-fulltext.zip has already been extracted")
    
download_data(url, 'data')

File already exists
bbc-fulltext.zip has already been extracted


### Read Data without Preprocessing 

Here we read all the files and keep them as a list of strings, where each string is a single article

In [7]:
def read_data(data_dir):
    
    # This will contain the full list of stories
    news_stories = []    
    filenames = []
    print("Reading files")
    
    i = 0 # Just used for printing progress
    for root, dirs, files in os.walk(data_dir):
        
        for fi, f in enumerate(files):
            
            # We don't read the readme file
            if 'README' in f:
                continue
            
            # Printing progress
            i += 1
            print("."*i, f, end='\r')
            
            # Open the file
            with open(os.path.join(root, f), encoding='latin-1') as text_file:
                
                story = []
                # Read all the lines
                for row in text_file:
                                        
                    story.append(row.strip())
                    
                # Create a single string with all the rows in the doc
                story = ' '.join(story)                        
                # Add that to the list
                news_stories.append(story)  
                filenames.append(os.path.join(root, f))
                
        print('', end='\r')
        
    print(f"\nDetected {len(news_stories)} stories")
    return news_stories, filenames
                
  
news_stories, filenames = read_data(os.path.join('data', 'bbc'))

# Printing some stats and sample data
print(f"{sum([len(story.split(' ')) for story in news_stories])} words found in the total news set")
print('Example words (start): ',news_stories[0][:50])
print('Example words (end): ',news_stories[-1][-50:])

Reading files


. readme.txt.. 001.txt... 002.txt.... 003.txt..... 004.txt...... 005.txt....... 006.txt........ 007.txt......... 008.txt.......... 009.txt........... 010.txt............ 011.txt............. 012.txt.............. 013.txt............... 014.txt................ 015.txt................. 016.txt.................. 017.txt................... 018.txt.................... 019.txt..................... 020.txt...................... 021.txt....................... 022.txt........................ 023.txt......................... 024.txt.......................... 025.txt........................... 026.txt............................ 027.txt............................. 028.txt.............................. 029.txt............................... 030.txt................................ 031.txt................................. 032.txt.................................. 033.txt................................... 034.txt.................................... 035.txt......

........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

### Check the length statistics 

Here we look at the 95-percientile in order to decide a good sequence length for inputs.

In [8]:
import pandas as pd

pd.Series([len(x.split(' ')) for x in news_stories]).describe(percentiles=[0.05, 0.95])

count    2226.000000
mean      388.701707
std       241.514748
min        87.000000
5%        164.000000
50%       336.000000
95%       736.750000
max      4489.000000
dtype: float64

### Compute the document embeddings

ELMo provides several outputs as the output (in the form of a dictionary). The most important output is in a key called `default` which is the averaged vector resulting from vectors produced for all the tokens in the input. We will use this as the document embedding.

In [9]:
batch_size = 4

news_elmo_embeddings = []

# Go through batches
for i in range(0, len(news_stories), batch_size):
    
    # Print progress
    print('.', end='')
    # Format ELMo inputs
    elmo_inputs = format_text_for_elmo(news_stories[i: min(i+batch_size, len(news_stories))], max_len=768)    
    # Get the result stored in default
    elmo_result = elmo_layer(elmo_inputs)["default"]
    # Add that to a list
    news_elmo_embeddings.append(elmo_result)

# Create an array
news_elmo_embeddings = np.concatenate(news_elmo_embeddings, axis=0)    

.............................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

### Save the embeddings to disk

In [10]:
# Save the data to disk
os.makedirs('elmo_embeddings', exist_ok=True)

pd.DataFrame(
    news_elmo_embeddings, index=filenames
).to_pickle(
    os.path.join('elmo_embeddings', 'elmo_embeddings.pkl')
)

In [11]:
pd.read_pickle(os.path.join('elmo_embeddings', 'elmo_embeddings.pkl'))

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023
data\bbc\readme.txt,0.317657,-0.037904,0.096557,-0.120920,-0.031368,0.100202,-0.026247,0.396833,-0.147786,-0.080523,...,-0.287333,0.232494,0.078371,0.287477,0.104193,0.087736,0.313276,-0.105956,0.213654,0.086695
data\bbc\business\001.txt,0.052209,-0.108837,0.103078,0.060255,0.277382,0.101622,0.147006,0.390242,0.053942,-0.051409,...,-0.179657,0.053069,0.091880,0.257571,0.044403,-0.093912,0.032205,-0.116896,0.420686,0.049002
data\bbc\business\002.txt,-0.277251,-0.498861,-0.000488,0.090846,0.320750,-0.054692,-0.041421,0.329660,-0.438261,-0.232513,...,-0.140432,-0.022603,0.369779,0.214081,-0.019910,-0.004382,-0.073545,0.050382,0.697808,-0.038186
data\bbc\business\003.txt,0.001640,-0.078992,0.178254,-0.076050,-0.188926,0.156753,-0.178013,0.316145,0.339274,0.103716,...,0.077891,0.137827,0.199895,-0.000208,0.204819,-0.148592,0.030815,-0.008611,0.582661,-0.055719
data\bbc\business\004.txt,-0.176339,-0.215292,0.215862,0.094050,0.498871,0.291168,-0.037569,0.238697,0.275630,0.069215,...,-0.472698,0.062436,0.207713,0.127696,0.203626,-0.116648,0.153931,-0.271529,0.392614,0.046904
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
data\bbc\tech\397.txt,0.211539,0.073864,-0.228550,-0.021582,0.234441,0.221942,-0.166678,0.108783,-0.160235,-0.005722,...,-0.432023,-0.045294,0.049274,0.030276,0.052349,0.131190,0.367616,-0.109777,0.069831,-0.080369
data\bbc\tech\398.txt,0.165850,0.008826,-0.154758,0.104231,0.289332,0.079604,-0.154873,-0.315355,-0.238488,0.168557,...,-0.419998,0.051438,-0.045387,0.279227,0.046775,0.104169,0.212102,-0.126483,0.122837,-0.329561
data\bbc\tech\399.txt,-0.066523,-0.190664,0.104273,-0.145981,0.181983,-0.011872,-0.086985,0.176678,-0.357082,-0.030099,...,-0.234072,0.192155,0.049397,0.103218,-0.009068,0.135200,0.389549,0.052150,0.347691,-0.027860
data\bbc\tech\400.txt,-0.054691,-0.251641,-0.107450,-0.015446,0.133443,0.000446,-0.091617,-0.058762,-0.347266,0.109252,...,-0.390207,0.009347,0.102706,0.475566,0.006273,-0.224070,0.159726,-0.073267,0.323672,-0.324118
