In [1]:
import fastai

In [3]:
from fastai.text import *

In [4]:
EOS = '<eos>'
PATH=Path('./data/wikitext-2')

In [5]:
def read_file(filename):
    tokens = []
    with open(PATH/filename, encoding='utf8') as f:
        for line in f:
            tokens.append(line.split() + [EOS])
    return np.array(tokens)

In [6]:
trn_tok = read_file('wiki.train.tokens')
val_tok = read_file('wiki.valid.tokens')
tst_tok = read_file('wiki.test.tokens')

In [7]:
len(trn_tok)

36718

In [8]:
cnt = Counter(word for sent in trn_tok for word in sent)
itos = [o for o,c in cnt.most_common()]
itos.insert(0,'_pad_')

In [9]:
vocab_size = len(itos); vocab_size

33279

In [10]:
stoi = collections.defaultdict(lambda : 5, {w:i for i,w in enumerate(itos)})

In [32]:
trn_ids = np.array([([stoi[w] for w in s]) for s in trn_tok])
val_ids = np.array([([stoi[w] for w in s]) for s in val_tok])
tst_ids = np.array([([stoi[w] for w in s]) for s in tst_tok])

In [34]:
trn_ids = trn_ids[:5000]
val_ids = val_ids[:2000]

In [35]:
em_sz,nh,nl = 8,16,2
drops = np.array([0.6,0.4,0.5,0.05,0.2])

In [36]:
bptt, bs = 5, 128

In [37]:
trn_dl = LanguageModelLoader(np.concatenate(trn_ids), bs, bptt)
val_dl = LanguageModelLoader(np.concatenate(val_ids), bs, bptt)
md = LanguageModelData(PATH, 0, vocab_size, trn_dl, val_dl, bs=bs, bptt=bptt)

In [38]:
opt_fn = partial(optim.SGD, momentum=0.9)
learner= md.get_model(opt_fn, em_sz, nh, nl,
    dropouti=drops[0], dropout=drops[1], wdrop=drops[2], dropoute=drops[3], dropouth=drops[4])

In [39]:
learner.fit(lrs = 0.1, n_cycle = 1)

HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

epoch      trn_loss   val_loss                               
    0      8.04073    7.52068   


[7.520679618491501]

In [40]:
class TextReader():
    def __init__(self, nums, bptt, backwards=False):
        self.bptt,self.backwards = bptt,backwards
        self.data = self.batchify(nums)
        self.i,self.iter = 0,0
        self.n = len(self.data)

    def __iter__(self):
        self.i,self.iter = 0,0
        while self.i < self.n-1 and self.iter<len(self):
            res = self.get_batch(self.i, self.bptt)
            self.i += self.bptt
            self.iter += 1
            yield res

    def __len__(self): return self.n // self.bptt 

    def batchify(self, data):
        data = np.array(data)[:,None]
        if self.backwards: data=data[::-1]
        return T(data)

    def get_batch(self, i, seq_len):
        source = self.data
        seq_len = min(seq_len, len(source) - 1 - i)
        return source[i:i+seq_len], source[i+1:i+1+seq_len].view(-1)

In [41]:
def my_validate(model, source, bptt=2000):
    data_source = TextReader(source, bptt)
    model.eval()
    model.reset()
    total_loss = 0.
    for inputs, targets in tqdm(data_source):
        #The language model throws up a bucnh of things, we'll focus on that later. For now we just want the ouputs.
        outputs, raws, outs = model(V(inputs))
        #The output doesn't go through softmax so we can use the CrossEntropy loss directly 
        total_loss += F.cross_entropy(outputs, V(targets), size_average=False).data[0]
    #Total size is length of our iterator times bptt
    mean = total_loss / (bptt * len(data_source))
    #Returns loss and perplexity.
    return mean, np.exp(mean)

In [46]:
import warnings
warnings.filterwarnings('ignore')

In [47]:
my_validate(learner.model, np.concatenate(val_ids))

100%|██████████| 55/55 [00:41<00:00,  1.39it/s]


(tensor(7.5241), tensor(1852.1338))

In [44]:
def one_hot(vec, size=vocab_size, cuda=True):
    a = torch.zeros(len(vec), size)
    for i,v in enumerate(vec):
        a[i,v] = 1.
    return V(a)

In [52]:
def my_cache_pointer(model, source, theta = 0.662, lambd = 0.1279, window=200, bptt=200):
    data_source = TextReader(source, bptt)
    pdb.set_trace()
    #Set the model into eval mode.
    model.eval()
    #Just to create a hidden state.
    model.reset()
    total_loss = 0.
    #Containers for the previous targets/hidden states.
    targ_history = None
    hid_history = None
    for inputs, targets in tqdm(data_source):
        outputs, raws, outs = model(V(inputs))
        #The outputs aren't softmaxed, sowe have to do it to get the p_vocab vectors.
        p_vocab = F.softmax(outputs,1)
        #We take the last hidden states (raws contains one Tensor for the results of each layer) and remove the batch dimension.
        hiddens = raws[-1].squeeze() 
        #Start index inside our history.
        start = 0 if targ_history is None else targ_history.size(0)
        #Add the targets and hidden states to our history.
        targ_history = one_hot(targets) if targ_history is None else torch.cat([targ_history, one_hot(targets)])
        hid_history = hiddens if hid_history is None else torch.cat([hid_history, hiddens])
        for i, pv in enumerate(p_vocab):
            #Get the cached values
            p = pv
            if start + i > 0:
                targ_cache = targ_history[:start+i] if start + i <= window else targ_history[start+i-window:start+i]
                hid_cache = hid_history[:start+i] if start + i <= window else hid_history[start+i-window:start+i]
                #This is explained in the blog post.
                all_dot_prods = torch.mv(theta * hid_cache, hiddens[i])
                softmaxed = F.softmax(all_dot_prods).unsqueeze(1)
                p_cache = (softmaxed.expand_as(targ_cache) * targ_cache).sum(0).squeeze()
                p = (1-lambd) * pv + lambd * p_cache
            total_loss -= torch.log(p[targets[i]]).data[0]
        targ_history = targ_history[-window:]
        hid_history = hid_history[-window:]
    #Total size is length of our iterator times bptt
    mean = total_loss / (bptt * len(data_source))
    #Returns loss and perplexity
    return mean, np.exp(mean)

In [53]:
my_cache_pointer(learner.model, np.concatenate(val_ids))

> <ipython-input-52-a06c9eee4817>(5)my_cache_pointer()
-> model.eval()
(Pdb) ll
  1  	def my_cache_pointer(model, source, theta = 0.662, lambd = 0.1279, window=200, bptt=200):
  2  	    data_source = TextReader(source, bptt)
  3  	    pdb.set_trace()
  4  	    #Set the model into eval mode.
  5  ->	    model.eval()
  6  	    #Just to create a hidden state.
  7  	    model.reset()
  8  	    total_loss = 0.
  9  	    #Containers for the previous targets/hidden states.
 10  	    targ_history = None
 11  	    hid_history = None
 12  	    for inputs, targets in tqdm(data_source):
 13  	        outputs, raws, outs = model(V(inputs))
 14  	        #The outputs aren't softmaxed, sowe have to do it to get the p_vocab vectors.
 15  	        p_vocab = F.softmax(outputs,1)
 16  	        #We take the last hidden states (raws contains one Tensor for the results of each layer) and remove the batch dimension.
 17  	        hiddens = raws[-1].squeeze()
 18  	        #Start index inside our history.
 19

(Pdb) tagets
*** NameError: name 'tagets' is not defined
(Pdb) targets.shape
torch.Size([200])
(Pdb) targets
tensor([   11, 10854, 33171,    11,     9,     9, 10854, 33171,     2,   123,
           17,     1,   606,  9440,    47,   399,  9440,     2,    25,    10,
          216,     4,     5,  9440,    26,     1,   571,   660,  1728,     2,
         3384,  1476,     6,   765,     4,     1,   736,  1476,     3,    62,
           25,  2365,   941,     8,     1,   145,  9440,     2,  2137, 30470,
            3,    62,   148,  2914,     8,    10,   831,     4,   988,   982,
           24,   495,     7,    23,     6,    10,  1079,     4,   147, 12276,
           24,   361,  3221,    23,     2,     6,  4844,    10, 11690,  1469,
            4, 14076,     3,    35,   181,     2,     1, 21831,    37,  1493,
            2,    76,   962,    12,  9440,   900,    12,    19,  6810,     3,
        30363,  2636,     7,     1,  1006,     2,  2679,  3086,    34,    37,
         1091,    22,     1,  193

(Pdb) n
> <ipython-input-52-a06c9eee4817>(32)my_cache_pointer()
-> p_cache = (softmaxed.expand_as(targ_cache) * targ_cache).sum(0).squeeze()
(Pdb) softmaxed
tensor([[1.]])
(Pdb) softmaxed.expand_as(targ_cache).shape
torch.Size([1, 33279])
(Pdb) n
> <ipython-input-52-a06c9eee4817>(33)my_cache_pointer()
-> p = (1-lambd) * pv + lambd * p_cache
(Pdb) q


BdbQuit: 