In [1]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

## Import external dependencies

In [2]:
import argparse
import time
import math
import os
import itertools

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
# import devastator.blur as blur

## Import Blur modules

In [3]:
from blur import init_config_run, Config, Trainer, ScheduledOptimizer
from blur import Blur, DecoderXL, AdaptiveInput, AdaptiveLogSoftmaxWithLoss, DecoderFT
from blur.processing import get_lm_corpus

## Define model and training configuration

In [4]:
###############################################################################
# Define config paths
###############################################################################

config_dir = os.path.join('configs', 'xl')
config_run_path = os.path.join(config_dir, 'config_run.json')
config_encoder_path = os.path.join(config_dir, 'config_encoder.json')
config_decoder_path = os.path.join(config_dir, 'config_decoder.json')

###############################################################################
# Load config files
###############################################################################


config_model = Config(**{"tgt_len": 150, "mem_len": 150, "ext_len": 0})
config_run = init_config_run(config_run=Config.from_json(config_run_path), config_model=config_model)
config_encoder = Config.from_json(config_encoder_path)
config_decoder = Config.from_json(config_decoder_path)

## Set parameters for sandbox testing

In [5]:
device = 'cpu'
config_run.multi_gpu = False
config_run.max_step = 100

## Load data and set vocabulary length

## Build model

In [6]:
config_decoder.nft = 0
config_decoder.nxl = 16
config_decoder.ft_first = False

###############################################################################
# Build the model
###############################################################################

corpus = get_lm_corpus(config_run.data, config_run.dataset)
config_encoder.n_classes = len(corpus.vocab)

model = Blur(
    tgt_len = 150, mem_len = 150, ext_len = 0,
    encoder = AdaptiveInput(**config_encoder.parameters()),
    decoder = DecoderFT(**config_decoder.parameters()),
    lm_loss = AdaptiveLogSoftmaxWithLoss(**config_encoder.parameters()),
)
trainer = Trainer(config_run = config_run, device = device)

pass;


Loading cached dataset...


In [7]:
# ###############################################################################
# # Build the model
# ###############################################################################

# corpus = get_lm_corpus(config_run.data, config_run.dataset)
# config_encoder.n_classes = len(corpus.vocab)

# model = Blur(
#     tgt_len = 150, mem_len = 150, ext_len = 0,
#     encoder = AdaptiveInput(**config_encoder.parameters()),
#     decoder = DecoderXL(**config_decoder.parameters()),
#     lm_loss = AdaptiveLogSoftmaxWithLoss(**config_encoder.parameters()),
# )
# trainer = Trainer(config_run = config_run, device = device)

# pass;


## Load some example data

In [8]:
from blur.training import StreamDataset, StreamCollator

In [9]:
train_set = StreamDataset(
    data=corpus.train, tgt_len=model.tgt_len, batch_size=config_run.batch_size)
train_loader = torch.utils.data.DataLoader(
        dataset=train_set, batch_size=config_run.batch_size, shuffle=False,
        num_workers=0, pin_memory=True, sampler=None, collate_fn=StreamCollator())

In [10]:
detokenize_fn = np.vectorize(lambda x: corpus.vocab.idx2sym[x])

iterdata = {}
mems = tuple()

for batch, (data, target) in enumerate(train_loader):
    iterdata[batch] = (data, target, mems)
    
#     ret = model(data, target, *mems)
#     loss, mems = ret[0], ret[1:]
#     loss = loss.float().mean().type_as(loss)
    
    if batch>2:
        break

## Inspect example data

In [11]:
detokenize_fn(iterdata[0][1][0,:])[:50]
detokenize_fn(iterdata[0][1][1,:])[:50]

array(['=', 'Valkyria', 'Chronicles', 'III', '=', '<eos>', '<eos>',
       'Senjō', 'no', 'Valkyria', '3', ':', '<unk>', 'Chronicles', '(',
       'Japanese', ':', '戦場のヴァルキュリア3', ',', 'lit', '.', 'Valkyria', 'of',
       'the', 'Battlefield', '3', ')', ',', 'commonly', 'referred', 'to',
       'as', 'Valkyria', 'Chronicles', 'III', 'outside', 'Japan', ',',
       'is', 'a', 'tactical', 'role', '@-@', 'playing', 'video', 'game',
       'developed', 'by', 'Sega', 'and'], dtype='<U12')

array(['she', 'signed', 'with', 'the', 'Creative', 'Artists', 'Agency',
       '(', 'CAA', ')', 'having', 'previously', 'been', 'signed', 'with',
       'William', 'Morris', 'Endeavor', '(', '<unk>', ')', '.', 'Gaga',
       'told', 'Entertainment', 'Weekly', 'that', 'the', 'experience',
       'with', 'American', 'Horror', 'Story', 'will', 'influence', 'the',
       'creative', 'process', 'of', 'her', 'fifth', 'studio', 'album',
       ',', 'claiming', ':', '"', 'I', 'have', 'returned'], dtype='<U13')

## Run data through model

In [12]:
output = model(*iterdata[0])
{k:v.shape if hasattr(v, 'shape') else len(v) for k,v in output.items()}

{'output': torch.Size([150, 8, 410]), 'mems': 17, 'loss': torch.Size([150, 8])}

In [13]:
output['mems']

[tensor([[[-0.0000,  0.0050, -0.0196,  ..., -0.0086,  0.0000, -0.0000],
          [ 0.0128, -0.0262,  0.0291,  ...,  0.0000,  0.0016, -0.0010],
          [-0.0035,  0.0067,  0.0010,  ..., -0.0252, -0.0080, -0.0100],
          ...,
          [ 0.0000, -0.0337, -0.0067,  ...,  0.0391,  0.0267,  0.0013],
          [-0.0111, -0.0017, -0.0010,  ..., -0.0036, -0.0086,  0.0134],
          [-0.0361,  0.0050, -0.0000,  ..., -0.0086,  0.0205, -0.0056]],
 
         [[ 0.0001,  0.0187, -0.0295,  ...,  0.0138,  0.0011,  0.0289],
          [ 0.0157,  0.0271,  0.0281,  ...,  0.0316, -0.0276,  0.0127],
          [-0.0161,  0.0252, -0.0382,  ...,  0.0154, -0.0000,  0.0241],
          ...,
          [-0.0353, -0.0334, -0.0121,  ...,  0.0236,  0.0225, -0.0239],
          [ 0.0090,  0.0366,  0.0055,  ...,  0.0278,  0.0135,  0.0206],
          [-0.0000, -0.0238, -0.0128,  ...,  0.0103, -0.0242,  0.0235]],
 
         [[-0.0000,  0.0075, -0.0000,  ..., -0.0066,  0.0011, -0.0079],
          [-0.0038,  0.0224,