In [None]:
#hide
# Print the status of the GPU(s) visible to this runtime (shell command via IPython `!`).
!nvidia-smi

In [None]:
#hide
import sys
if 'google.colab' in sys.modules:
    !pip install -Uqq fastai einops datasets axial_positional_embedding wandb
    !pip install -qq git+git://github.com/arampacha/reformer_fastai.git

In [None]:
#hide
# Automatically reload imported modules before executing each cell, so edits to
# library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
#all_slow
# NOTE(review): nbdev flag — presumably marks every cell in this notebook as "slow" so
# notebook test runs skip it unless slow tests are requested; confirm against the nbdev version in use.

# enwik8 - shared QK

In [None]:
from fastai.text.all import *
from reformer_fastai.all import *

## Experiment Tracking

Make sure `wandb` is installed and that you are logged in:

In [None]:
# hide
# Authenticate with Weights & Biases; this may prompt interactively for an API key.
!wandb login

Set up experiment tracking with Weights & Biases:

In [None]:
import wandb

# Run metadata handed to `wandb.init` in the Training section.
WANDB_NAME = 'test_n_layers_enwik8'   # run name shown in the W&B UI
GROUP = 'TEST'                        # runs are grouped under this label
NOTES = 'ReformerLM on enwik8 sl 32k' # free-text run description
CONFIG = dict()                       # extra hyperparameters to log (none yet)
TAGS = 'lm lsh enwik8 test'.split()   # searchable run tags

## Download and Unpack enwik8 Data

Download the enwik8 archive and unzip it:

In [None]:
# Download the enwik8 zip archive and extract it under /data; `path` is what untar_data returns.
# NOTE(review): newer fastai releases appear to have replaced untar_data's `dest` argument
# (with `data`); confirm the installed fastai version still accepts `dest`.
# NOTE(review): '/data' is an absolute path and must exist/be writable on the host.
path = untar_data('http://mattmahoney.net/dc/enwik8.zip', dest='/data')

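## Prepare Data

Each line of enwik8 becomes one row of a DataFrame and is byte-tokenized. The data is then split contiguously by line: the leading lines (all but roughly the last 10M tokens) form the training set, and the remaining lines are divided in half between validation and test. Only the training and validation splits are passed to `Datasets`.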

In [None]:
# One DataFrame row per line read from the enwik8 file; preview the first rows.
df = pd.DataFrame(data={'text': read_lines(path)})
df.head()

In [None]:
# Byte-level tokenizer in language-model mode, with BOS/EOS insertion turned off.
btt = ByteTextTokenizer(add_bos=False, add_eos=False, is_lm=True)

In [None]:
%%time
df['toks'] = df['text'].apply(btt)
df['lens'] = df['toks'].apply(len)
df['lens_cum_sum'] = df.lens.cumsum()

In [None]:
# Contiguous split by line. Leading lines whose running token count stays below
# `total - N_HOLDOUT_TOKENS` are used for training; the remaining lines are halved into
# validation and test. Half-open ranges keep the three splits disjoint and cover every line.
N_HOLDOUT_TOKENS = 10_000_000  # tokens kept out of training for validation + test
train_cutoff = df.lens.sum() - N_HOLDOUT_TOKENS

# Lengths are non-negative, so `lens_cum_sum` is non-decreasing and the lines under the
# cutoff form a prefix of df (whose index is the default 0..len-1 from construction).
n_train = int((df['lens_cum_sum'] < train_cutoff).sum())
n_valid = (len(df) - n_train) // 2
assert n_train > 0 and n_valid > 0, 'hold-out size leaves no training or no validation lines'

train_idxs = list(range(n_train))
validation_idxs = list(range(n_train, n_train + n_valid))
test_idxs = list(range(n_train + n_valid, len(df)))  # held out; not passed to Datasets

splits = [train_idxs, validation_idxs]

In [None]:
# A single pipeline per item: take the row's raw `text`, then byte-tokenize it.
tfms = [attrgetter('text'), btt]
dsets = Datasets(items=df, tfms=[tfms], splits=splits, dl_type=LMDataLoader)

In [None]:
%%time
bs, sl = 1, 2**15
# pad_seq2seq = partial(pad_input, pad_idx=bte.pad_token_id, pad_fields=[0,1])
dl_kwargs = [{'lens':df['lens'].values[train_idxs]},
             {'val_lens':df['lens'].values[validation_idxs]}]
dls = dsets.dataloaders(bs=bs, seq_len=sl, dl_kwargs=dl_kwargs, shuffle_train=True, n_workers=2)

In [None]:
#collapse_output
# Decode and display up to two samples from a batch.
dls.show_batch(max_n=2)

In [None]:
# Size of the tokenizer's vocabulary.
# NOTE(review): not referenced by any later cell; NLayersConfig below presumably uses its
# own default vocab size — confirm the two agree.
vocab_sz = btt.vocab_size

In [None]:
# Fetch a single batch to sanity-check the input and target shapes.
xb, yb = dls.one_batch()
xb.shape, yb.shape

In [None]:
#hide
# Drop the sample batch and release unused cached CUDA memory before building the model.
del xb, yb
torch.cuda.empty_cache()

## Training

In [None]:
#hide_output
# Start a new W&B run for this experiment using the metadata constants defined earlier.
wandb.init(project="reformer-fastai", entity="fastai_community",
           name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS,
           config=CONFIG, reinit=True)

In [None]:
# Model config: 12 layers, n_hashes=4, and a maximum sequence length equal to the
# data loader's sequence length `sl`; displayed for reference.
config = NLayersConfig(n_layers=12, n_hashes=4, max_seq_len=sl)
config

In [None]:
# ReformerLM built from `config`, optimized with Adafactor. Gradients are accumulated
# (n_acc=8) and clipped at 1.0. PadBatchCallback receives the config's bucket_size
# (presumably padding each batch to a compatible length — confirm).
learn = Learner(
    dls,
    ReformerLM.from_config(config),
    opt_func=adafactor,
    loss_func=CrossEntropyLossFlat(),
    metrics=[accuracy, perplexity, bpc],
    cbs=[
        GradientAccumulation(n_acc=8),
        GradientClip(1.0),
        PadBatchCallback(bucket_size=config.bucket_size),
    ],
)

In [None]:
#hide
# Opt-in learning-rate finder: set to True to run `learn.lr_find()` before training.
RUN_LR_FIND = False
if RUN_LR_FIND:
    learn.lr_find()

In [None]:
# Train for one epoch, logging to W&B without uploading model files or predictions.
# The `finally` closes the W&B run whether training completes or raises, so the run
# is not left open in the kernel.
try:
    learn.fit(1, cbs=WandbCallback(log_model=False, log_preds=False))
finally:
    wandb.finish()

In [None]:
# Plot the loss curves recorded during training.
learn.recorder.plot_loss()