In [62]:
import os
import numpy as np
from infersent import InferSent
import torch

In [3]:
W2V_PATH = os.path.join(os.getenv('HOME'), 'fastText', 'crawl-300d-2M.vec')     # check your model location, plz
MODEL_PATH = 'encoder/infersent2.pkl'                                           # check your model location, plz
BATCH_SIZE = 64
EMB_DIM = 300
LSTM_DIM = 2048
POOL_TYPE = 'max'       # max / mean
DPOUT_MODEL = 0.0
MODEL_VER = 2

params_model = {'bsize': BATCH_SIZE, 
                'word_emb_dim': EMB_DIM, 
                'enc_lstm_dim': LSTM_DIM,
                'pool_type': POOL_TYPE, 
                'dpout_model': DPOUT_MODEL, 
                'version': MODEL_VER}

In [4]:
# utility functions #
def log(content, verbose=True):
    if verbose:
        print(f'[ LOG ] {content}')

In [5]:
def load_model(model_path, config, w2v_path=None):
    # load model
    model = InferSent(config)
    model.load_state_dict(torch.load(model_path))
    # model = model.cuda()   # GPU
    # model.set_w2v_path(w2v_path)
    # model.build_vocab(sentences)
    return model

In [6]:
def get_word_emb(path):
    if not os.path.exists(path):
        return None
    return np.load(path)

In [92]:
# @param bundles: list of bundle
# @ret: all bundle embedding's np array
def get_embeddings(bundles, emb_dir):
    # load beginning-of-sent and end-of-sent embedding
    emb_bos = np.load(os.path.join(emb_dir, 'bos.npy'))
    emb_eos = np.load(os.path.join(emb_dir, 'eos.npy'))
    
    embeddings = []
    lengths = []
    bundle_list = []
    failures = []
    max_len = 0
    for bundle in bundles:
        words = bundle.split(' ')
        emb = []
        emb.append(emb_bos)
        for w in words:
            emb_path = os.path.join(emb_dir, w + '.npy')
            emb.append(get_word_emb(emb_path))
        emb.append(emb_eos)
        if any(e is None for e in emb):
            failures.append(bundle)
            continue
        embeddings.append(emb)
        lengths.append(len(emb))
        bundle_list.append(bundle)
        max_len = len(emb) if len(emb) > max_len else max_len
    
    batches = np.zeros((max_len, len(embeddings), embeddings[0][0].shape[0]))
    for i in range(len(embeddings)):
        for j in range(len(embeddings[i])):
            batches[j][i][:] = embeddings[i][j]
    return batches, np.array(lengths), bundle_list

## main

In [None]:
def main(bundle_path, model_path, emb_dir, batch_size=64):
    log('getting bundles...')
    with open(bundle_path, 'r') as f:
        bundles = f.read().split('\n')
    embeddings = get_embeddings(bundles, emb_dir)
    log(f'{len(bundles)} bundles loaded')

    log(f'loading model...')
    model = load_model(model_path, config)

    for idx in range(0, len(embeddings), batch_size):
        batch = torch.FloatTensor(embeddings[])

In [10]:
bundle_path = 'data/bundles_all.txt'
model_path = 'encoder/infersent2.pkl'
emb_dir = 'word_emb'

In [64]:
batch_size = 64

In [11]:
with open(bundle_path, 'r') as f:
    bundles = f.read().split('\n')

In [14]:
bundles = bundles[:-1]   # exclude spaces

In [15]:
bundles[-5:]

['weekends and holidays',
 'wide range of',
 'worth a visit',
 'worth checking out',
 'years ago in']

In [85]:
embeddings, lengths = get_embeddings(bundles, emb_dir)

In [86]:
lengths[:10]

array([5, 5, 6, 5, 5, 5, 5, 5, 5, 5])

In [75]:
embeddings.shape

(6, 132, 300)

In [60]:
config = params_model

In [63]:
model = load_model(model_path, config)

In [66]:
len(embeddings)

6

In [88]:
p_embs = []
for idx in range(0, embeddings.shape[1], batch_size):
    batch = torch.FloatTensor(embeddings[:,idx:idx+batch_size, :])
    length = lengths[idx:idx+batch_size]
    with torch.no_grad():
        emb = model.forward((batch, length)).data.cpu().numpy()
    p_embs.append(emb)

In [90]:
ret = np.vstack(p_embs)

In [91]:
ret.shape

(132, 4096)

## get_embeddings

In [46]:
emb_bos = np.load(os.path.join(emb_dir, 'bos.npy'))
emb_eos = np.load(os.path.join(emb_dir, 'eos.npy'))

In [47]:
embeddings = []
failures = []
max_len = 0
for bundle in bundles:
    words = bundle.split(' ')
    emb = []
    emb.append(emb_bos)
    for w in words:
        emb_path = os.path.join(emb_dir, w + '.npy')
        emb.append(get_word_emb(emb_path))
    emb.append(emb_eos)
    if any(e is None for e in emb):
        failures.append(bundle)
        continue
    embeddings.append(emb)
    max_len = len(emb) if len(emb) > max_len else max_len

In [27]:
embeddings[0][0].shape

(300,)

In [28]:
max_len

6

In [None]:
    batch = np.zeros((max_len, len(embeddings), embeddings[0][0].shape[0]))
    for i in range(len(embeddings)):
        for j in range(len(embeddings[i])):
            batch[j][i][:] = embeddings[i][j]
    return batch

In [29]:
len(embeddings)

132

In [32]:
embeddings[0][0].shape[0]

300

In [48]:
batch = np.zeros((max_len, len(embeddings), embeddings[0][0].shape[0]))

In [44]:
batch.shape

(6, 132, 300)

In [49]:
for i in range(len(embeddings)):
    for j in range(len(embeddings[i])):
        batch[j][i][:] = embeddings[i][j]

In [50]:
emb_bos

array([-0.3398,  0.301 ,  0.1689, -0.059 ,  0.31  , -0.1625,  0.3938,
       -0.4378,  0.038 ,  0.0879,  0.26  ,  1.5207, -0.0033,  0.0499,
       -0.1864,  0.3068,  0.1511,  0.5598,  0.131 , -0.3409,  0.1883,
       -0.2301,  0.1512, -0.0274, -0.0101,  0.0052, -0.1908,  0.062 ,
        0.7689, -0.0926,  0.0714, -0.332 , -0.1068,  0.2163, -0.456 ,
        0.318 , -0.0541, -0.1677,  0.1038,  0.008 ,  0.4726, -0.2714,
        0.2154,  0.0421, -0.0844, -0.0081, -0.2349,  0.2663, -0.3735,
       -0.194 ,  0.1594,  0.3434,  0.8196,  0.2394,  0.0417,  0.4827,
        0.1411,  0.1159, -0.0286,  0.0492, -0.2025,  0.4332,  0.1325,
        0.064 ,  0.8302,  0.3763, -0.201 , -0.1348,  0.0174, -0.1784,
       -0.3994,  0.2344,  0.1994, -0.1032,  0.248 ,  0.2627, -0.2558,
        0.1891,  0.2943,  0.049 ,  0.033 ,  0.0905, -0.2289, -0.5167,
        0.0826, -0.0607, -0.2633, -0.1619,  0.7178, -0.0209,  0.0718,
        0.2381, -0.1027,  0.1029, -0.6198, -0.223 ,  0.1356,  0.2053,
        0.2352,  0.1

In [51]:
emb_eos

array([-2.0100e-01,  3.2120e-01, -2.7000e-02,  2.3670e-01,  9.2500e-02,
       -1.6900e-01,  2.6030e-01,  8.6500e-02,  1.3430e-01,  2.6290e-01,
       -4.5400e-02,  1.0735e+00,  5.9400e-02,  1.4350e-01, -2.1810e-01,
        5.9800e-01,  3.1390e-01,  2.8100e-01, -2.1340e-01, -9.7800e-02,
        4.5400e-01, -3.5600e-02,  4.3600e-01,  1.9600e-02, -2.5890e-01,
        4.7900e-02, -9.9000e-02, -3.0880e-01,  4.2010e-01,  8.0000e-03,
        1.5730e-01, -4.9630e-01,  6.4100e-02, -2.2880e-01, -7.9750e-01,
       -2.2830e-01, -2.9000e-02, -5.9100e-01, -1.1740e-01,  4.2200e-02,
        8.3000e-03,  2.4720e-01,  4.6340e-01, -1.7100e-02,  1.6050e-01,
       -1.9480e-01, -2.7890e-01,  2.4640e-01, -5.0350e-01, -1.2890e-01,
        4.1980e-01,  5.0280e-01,  7.7190e-01,  5.1000e-03,  1.4780e-01,
        7.3530e-01,  1.6720e-01,  2.2220e-01,  1.2980e-01,  9.7100e-02,
       -3.1400e-02,  6.8010e-01,  1.7010e-01, -4.9820e-01,  1.2317e+00,
        5.3420e-01, -3.5830e-01,  3.6420e-01, -7.2000e-02, -3.55

In [55]:
batch[:,1,:]

array([[-0.3398    ,  0.301     ,  0.1689    , ...,  0.0639    ,
         0.012     , -0.1761    ],
       [ 0.25195312, -0.19140625, -0.12158203, ..., -0.1171875 ,
        -0.15625   , -0.19628906],
       [ 0.17675781,  0.02929688,  0.01031494, ..., -0.22265625,
        -0.01416016, -0.08935547],
       [-0.01177979, -0.04736328,  0.04467773, ...,  0.07128906,
        -0.03491211,  0.02416992],
       [-0.20100001,  0.32120001, -0.027     , ...,  0.16670001,
        -0.0982    , -0.0186    ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ]])