In [1]:
from IPython.core.interactiveshell import InteractiveShell

# Echo the value of every top-level expression in a cell, not only the last
# one, so a single cell can show several results.
setattr(InteractiveShell, "ast_node_interactivity", "all")

## Import external dependencies

In [2]:
import argparse
import time
import math
import os
import itertools

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
# import devastator.blur as blur

## Import Blur modules

In [3]:
from blur import init_config_run, Config, Trainer, ScheduledOptimizer
from blur import Blur, DecoderXL, AdaptiveInput, AdaptiveLogSoftmaxWithLoss
from blur.processing import get_lm_corpus

## Define model and training configuration

In [4]:
###############################################################################
# Locate the JSON configs for the "xl" preset
###############################################################################

config_dir = os.path.join('configs', 'xl')


def _config_file(filename):
    """Return the path of `filename` inside `config_dir`."""
    return os.path.join(config_dir, filename)


config_run_path = _config_file('config_run.json')
config_encoder_path = _config_file('config_encoder.json')
config_decoder_path = _config_file('config_decoder.json')

###############################################################################
# Read the configs
###############################################################################

# Sequence-length settings handed to init_config_run together with the run
# config. Names follow Transformer-XL conventions; their exact meaning is
# defined inside `blur`.
config_model = Config(tgt_len=150, mem_len=150, ext_len=0)

raw_run_config = Config.from_json(config_run_path)
config_run = init_config_run(config_run=raw_run_config, config_model=config_model)

config_encoder = Config.from_json(config_encoder_path)
config_decoder = Config.from_json(config_decoder_path)

## Override parameters for a quick CPU sandbox run

In [5]:
# Sandbox overrides on top of the loaded run config: single device, and a
# small max_step (presumably caps the number of training steps — confirm in
# Trainer).
config_run.multi_gpu = False
config_run.max_step = 100

# All tensors and the model live on the CPU for this sandbox.
device = 'cpu'

## Load data and set vocabulary length

In [6]:
###############################################################################
# Read the language-modelling corpus
###############################################################################

corpus = get_lm_corpus(config_run.data, config_run.dataset)

# The encoder config's class count must match the corpus vocabulary size.
vocab_size = len(corpus.vocab)
config_encoder.n_classes = vocab_size

Loading cached dataset...


## Build model

In [7]:
###############################################################################
# Build the model
###############################################################################

# Read the sequence lengths from `config_model`, the same object handed to
# init_config_run. Repeating the literals here would let the model and the
# run config drift apart whenever one of them is edited.
model = Blur(
    tgt_len=config_model.tgt_len,
    mem_len=config_model.mem_len,
    ext_len=config_model.ext_len,
    encoder=AdaptiveInput(**config_encoder.parameters()),
    decoder=DecoderXL(**config_decoder.parameters()),
    # The loss head is built from the same encoder config as the input layer.
    lm_loss=AdaptiveLogSoftmaxWithLoss(**config_encoder.parameters()),
)
# Assign the result instead of leaving a bare `model.to(device)` expression.
# Under "all" interactivity a bare expression would echo the full model repr,
# which previously needed a trailing `pass;` to suppress. Assumes Blur is a
# torch nn.Module, whose .to() returns the module itself.
model = model.to(device)

# NOTE(review): unused in the visible cells; presumably a placeholder for a
# multi-GPU wrapper (multi_gpu is disabled for this sandbox). Kept in case
# later code expects the name.
para_model = None
trainer = Trainer(config_run=config_run, device=device)

## Load a few example training batches

In [8]:
# How many training batches to keep for the inspection cells that follow.
N_EXAMPLE_BATCHES = 4

train_iter = trainer.get_train_iter(corpus, model)
# Map token ids back to vocabulary symbols, element-wise over an array.
detokenize_fn = np.vectorize(lambda x: corpus.vocab.idx2sym[x])

# No forward pass runs in this cell, so every stored batch is paired with the
# same empty `mems` tuple. Each entry is (data, target, mems), ready to be
# splatted into `model(...)`.
mems = tuple()
iterdata = {}

# islice makes the batch count explicit (batches 0 .. N_EXAMPLE_BATCHES-1) and
# stops pulling from the iterator once that many have been taken.
for batch, (data, target, _seq_len) in enumerate(itertools.islice(train_iter, N_EXAMPLE_BATCHES)):
    iterdata[batch] = (data, target, mems)

## Inspect example data (first 50 target tokens of two sequences)

In [9]:
# Targets of the first stored batch; show the first 50 detokenized symbols of
# rows 0 and 1. Both expressions are echoed because interactivity is "all".
example_targets = iterdata[0][1]
detokenize_fn(example_targets[0, :])[:50]
detokenize_fn(example_targets[1, :])[:50]

array(['=', 'Valkyria', 'Chronicles', 'III', '=', '<eos>', '<eos>',
       'Senjō', 'no', 'Valkyria', '3', ':', '<unk>', 'Chronicles', '(',
       'Japanese', ':', '戦場のヴァルキュリア3', ',', 'lit', '.', 'Valkyria', 'of',
       'the', 'Battlefield', '3', ')', ',', 'commonly', 'referred', 'to',
       'as', 'Valkyria', 'Chronicles', 'III', 'outside', 'Japan', ',',
       'is', 'a', 'tactical', 'role', '@-@', 'playing', 'video', 'game',
       'developed', 'by', 'Sega', 'and'], dtype='<U12')

array(['working', 'with', 'Major', 'General', 'Leslie', 'R.', 'Groves',
       ',', 'Jr', '.', "'", 's', 'Manhattan', 'Project', ',', 'Tibbets',
       'selected', 'Wendover', 'for', 'his', 'training', 'base', 'over',
       'Great', 'Bend', 'Army', 'Air', 'Field', ',', 'Kansas', ',', 'and',
       'Mountain', 'Home', 'Army', 'Airfield', ',', 'Idaho', ',',
       'because', 'of', 'its', 'remoteness', '.', 'On', '14', 'September',
       '1944', ',', 'the'], dtype='<U11')

## Run one example batch through the model

In [10]:
output = model(*iterdata[0])

# Summarise the model's output dict: entries with a `.shape` are reported by
# shape, everything else (e.g. the memory sequence) by its length.
output_summary = {}
for name, value in output.items():
    output_summary[name] = value.shape if hasattr(value, 'shape') else len(value)
output_summary

{'output': torch.Size([150, 24, 410]),
 'mems': 17,
 'loss': torch.Size([150, 24])}