In [101]:
# Clear the interactive namespace so this notebook does not rely on
# variables left over in the kernel from earlier runs.
%reset -f

In [102]:
# Standard library first, then third-party packages.
import gc
from collections import Counter
from itertools import islice

import numpy as np

# Run a full garbage collection after the namespace reset; the trailing ';'
# suppresses the otherwise-displayed count of collected objects.
gc.collect();

0

In [103]:
# Load the corpus as a flat list of whitespace-separated words.
# An explicit encoding avoids depending on the platform's default codec.
with open('./dataset/corpus', 'rt', encoding='utf-8') as fp:
    # str.split() with no argument already ignores leading/trailing
    # whitespace, so a separate strip() is unnecessary.
    corpus = fp.read().split()

# Fail early: later cells divide by len(corpus).
assert corpus, './dataset/corpus contains no words'
corpus[:5]

['anarchism', 'originated', 'as', 'a', 'term']

In [104]:
class SkipGramBatcher:
    '''Generate (central word, side word) index pairs for skip-gram training.

    Every word of the corpus is paired with each word up to ``window_size``
    positions to its left and right (left neighbours first). Pairs are produced
    in corpus order and packed into numpy int64 arrays of length ``batch_size``.
    '''

    def __init__(self, corpus, window_size=4, batch_size=32):
        '''
        corpus      - list of words
        window_size - number of neighbours on each side of the central word used as side words
        batch_size  - number of (central, side) pairs per yielded batch
        '''
        self.corpus = corpus
        self.window_size = window_size
        self.batch_size = batch_size
        self.make_vocab()

    def make_vocab(self):
        '''Build the sorted vocabulary and the word <-> index lookup tables from ``self.corpus``.'''
        # Use the instance's own corpus, not a notebook-level global that
        # merely happens to share the name.
        self.vocab = sorted(set(self.corpus))
        self.word2index = {w: idx for idx, w in enumerate(self.vocab)}
        self.index2word = dict(enumerate(self.vocab))

    def batch_gen(self, drop_last=True):
        '''Yield ``(x_batch, y_batch)`` tuples of int64 vocabulary-index arrays.

        x_batch[k] is the index of a central word, y_batch[k] the index of one
        of its side words. Naming: c - corpus, v - vocab ; i - central, j - side.

        drop_last - if True (default), a final incomplete batch is discarded;
                    if False, it is yielded truncated to the pairs actually filled.

        Every yielded batch uses freshly allocated arrays, so batches remain
        valid after the generator advances (e.g. when collected with ``list``).
        '''
        x_batch = np.empty(self.batch_size, dtype=np.int64)
        y_batch = np.empty(self.batch_size, dtype=np.int64)
        curr_idx = 0
        for c_i, w in enumerate(self.corpus):
            v_i = self.word2index[w]
            window_left_border = max(c_i - self.window_size, 0)
            side_words = (self.corpus[window_left_border:c_i]
                          + self.corpus[c_i + 1:c_i + self.window_size + 1])
            for side_w in side_words:
                x_batch[curr_idx] = v_i
                y_batch[curr_idx] = self.word2index[side_w]
                curr_idx += 1
                if curr_idx == self.batch_size:
                    yield (x_batch, y_batch)
                    # Allocate new buffers rather than overwriting the arrays
                    # the caller just received.
                    x_batch = np.empty(self.batch_size, dtype=np.int64)
                    y_batch = np.empty(self.batch_size, dtype=np.int64)
                    curr_idx = 0
        if not drop_last and curr_idx:
            # Only the first curr_idx slots hold valid pairs.
            yield (x_batch[:curr_idx], y_batch[:curr_idx])

In [105]:
def freq_filter(corpus, threshold, unk_token='UNK'):
    '''Replace rare words with a placeholder token.

    corpus    - iterable of words (a one-shot iterator is fine: it is materialised first)
    threshold - words occurring ``threshold`` times or fewer are replaced
    unk_token - replacement for rare words (default 'UNK')

    Returns a new list of the same length as ``corpus``; the input is not modified.
    '''
    # Materialise first: Counter() would otherwise exhaust a generator and
    # leave nothing for the replacement pass.
    words = list(corpus)
    freq_map = Counter(words)
    return [unk_token if freq_map[w] <= threshold else w for w in words]

In [106]:
# Words seen this many times or fewer are collapsed into the 'UNK' token.
FREQ_THRESHOLD = 2
corpus = freq_filter(corpus, FREQ_THRESHOLD)
# Fraction of tokens that were replaced by 'UNK'.
corpus.count('UNK') / len(corpus)

0.011120887855114024

In [107]:
# Batch generator over the filtered corpus: 4 side words on each side, 32 pairs per batch.
batcher = SkipGramBatcher(corpus, window_size=4, batch_size=32)

### Checking batch shapes

In [108]:
max_iter = 5
for x_batch, y_batch in islice(batcher.batch_gen(), max_iter):
    print('x_batch.shape: {}, y_batch.shape: {}'.format(x_batch.shape, y_batch.shape))
    # Make the check fail loudly instead of relying on reading the printout.
    assert x_batch.shape == y_batch.shape == (batcher.batch_size,)

x_batch.shape: (32,), y_batch.shape: (32,)
x_batch.shape: (32,), y_batch.shape: (32,)
x_batch.shape: (32,), y_batch.shape: (32,)
x_batch.shape: (32,), y_batch.shape: (32,)
x_batch.shape: (32,), y_batch.shape: (32,)


### Printing the generated batches

In [109]:
# Bare last expression: let Jupyter render the first 20 tokens.
corpus[:20]

['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse', 'first', 'used', 'against', 'early', 'working', 'class', 'radicals', 'including', 'the', 'diggers', 'of', 'the', 'english']


In [112]:
i2w = batcher.index2word
gen = batcher.batch_gen()

# Show two consecutive batches, one "central side" word pair per line.
for _ in range(2):
    x_ids, y_ids = next(gen)
    print('----------------------')
    for x_id, y_id in zip(x_ids, y_ids):
        x_w, y_w = i2w[x_id], i2w[y_id]
        print(x_w, y_w)

----------------------
anarchism originated
anarchism as
anarchism a
anarchism term
originated anarchism
originated as
originated a
originated term
originated of
as anarchism
as originated
as a
as term
as of
as abuse
a anarchism
a originated
a as
a term
a of
a abuse
a first
term anarchism
term originated
term as
term a
term of
term abuse
term first
term used
of originated
of as
----------------------
of a
of term
of abuse
of first
of used
of against
abuse as
abuse a
abuse term
abuse of
abuse first
abuse used
abuse against
abuse early
first a
first term
first of
first abuse
first used
first against
first early
first working
used term
used of
used abuse
used first
used against
used early
used working
used class
against of
against abuse


### <span style="color:blue">*Clean code*</span> ← (in case you've lost it)