In [1]:
%matplotlib inline

# Extending Sequence Prediction Using Character Level Recurrent Networks

We extend the approach outlined in `char_rnn_one_file_code_gen.ipynb` to improve efficiency and to accommodate other recurrent cells. Notably:

* Training uses multiple files instead of a single file
* A validation set is introduced so that we can detect overfitting to the training data
* The model processes mini-batches to speed up training

The task itself remains the same.

**Given a sequence of characters, predict the next likely character in the sequence.**
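
As a minimal illustration of this task (not the notebook's actual data pipeline), each character in a string can be paired with the character that follows it. The helper name `next_char_pairs` is hypothetical and is introduced here only for the example.

```python
def next_char_pairs(text):
    """Pair each character with the character that follows it.

    Hypothetical helper for illustration: the first item of each pair is
    what the model sees, the second is what it should predict.
    """
    return list(zip(text[:-1], text[1:]))


pairs = next_char_pairs("def f")
for seen, target in pairs:
    print("{!r} -> {!r}".format(seen, target))
```

A string of length `n` yields `n - 1` such pairs, since the last character has no successor to predict.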

In [6]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from json import load
from random import shuffle
from math import floor
from time import time

## Preparing the Data

We will be training on a set of 180 preprocessed Python files (and validating on a set of 20 other files) arbitrarily sampled from [GitHub BigQuery Python Extracts](https://bigquery.cloud.google.com/table/fh-bigquery:github_extracts.contents_py_201802snap?pli=1).

We limit the characters that our neural network can produce to a subset of standard ASCII.
* `ORD 2*, 3*, 9, 10, 32-126`
  * `ORD 2` for start of text (special, never predicted)
  * `ORD 3` for end of text (special, prediction ends)
  * `ORD 9` for horizontal tab, `"\t"`
  * `ORD 10` for line feed (new line), `"\n"`
  * `ORD 32-126` for space, punctuation, digits, and English letters

NOTE: The full dataset contains files written using non-standard characters. For the models in this notebook, we ensure that every Python file in our dataset is composed only of the ASCII characters that we accept.
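
A minimal sketch of how such a check could look, assuming each file is available as a plain Python string. The names `ACCEPTED_IDS`, `CHAR_TO_INDEX`, `is_accepted`, and `encode` are hypothetical, and the preprocessing actually used to build the dataset may differ.

```python
# Accepted code points, as listed above: 2, 3, 9, 10, and 32 through 126.
ACCEPTED_IDS = (2, 3, 9, 10) + tuple(range(32, 127))
ACCEPTED_SET = frozenset(ACCEPTED_IDS)

# Assumption for illustration: each accepted character maps to its position
# in ACCEPTED_IDS, giving a compact index suitable for one-hot encoding.
CHAR_TO_INDEX = {chr(uid): i for i, uid in enumerate(ACCEPTED_IDS)}


def is_accepted(source):
    """Return True if every character of `source` is an accepted one."""
    return all(ord(ch) in ACCEPTED_SET for ch in source)


def encode(source):
    """Map an accepted string to a list of integer indices."""
    return [CHAR_TO_INDEX[ch] for ch in source]


print(is_accepted("x = 1\n"))   # only accepted characters
print(is_accepted("caf\u00e9"))  # contains a non-ASCII character
print(encode("x = 1\n"))
```

Files for which `is_accepted` returns `False` would be excluded (or cleaned) before training under this sketch.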

In [16]:
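# Code points the model may produce: control markers, tab, newline, printable ASCII.
VALID_UNICODE_IDS = (2, 3, 9, 10) + tuple(range(32, 127))
for uid in VALID_UNICODE_IDS:
    if uid <= 32:
        # Control and whitespace characters are invisible, so show their repr.
        print("{}: {}".format(uid, repr(chr(uid))), end=", ")
        continue
    elif uid == 33:
        # Start a new line before the visible printable characters.
        print()
    print(chr(uid), end="")
print()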

with open("./data/train.json", "r") as f:
    train_data = load(f)
with open("./data/validate.json", "r") as f:
    validation_data = load(f)

print("Num Training Files: {}".format(len(train_data)))
print("Num Validation Files: {}".format(len(validation_data)))

2: '\x02', 3: '\x03', 9: '\t', 10: '\n', 32: ' ', 
!"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~
Num Training Files: 180
Num Validation Files: 20
