# Prelims

In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from fastai.conv_learner import *
from fastai.text import *

  from numpy.core.umath_tests import inner1d


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

## Helpers

In [4]:
def nonzero(pred):
    ints = to_np(pred).astype(int)
    return ints[np.nonzero(ints)]

def char_label_text(pred, chunk=70):
    return ''.join([itos[i] for i in nonzero(pred)])
#     return '\n'.join(textwrap.wrap(st, chunk))

def char_split_text(pred):
    return [itos[i] for i in nonzero(pred)]

def word_label_text(pred, chunk=70):
    return ' '.join([w_itos[i] for i in nonzero(pred)])
#     return '\n'.join(textwrap.wrap(st, chunk))

In [5]:
def show_img(im, figsize=None, ax=None, alpha=None, title=None):
    if not ax: fig,ax = plt.subplots(figsize=figsize)
    ax.imshow(im, alpha=alpha)
    if title: ax.set_title(title)
    return ax

In [6]:
import Levenshtein as Lev

# pulled from Sean Nareen's deepspeech decoder module
# https://github.com/SeanNaren/deepspeech.pytorch/blob/master/decoder.py

def cer(t, p):
    """
    Computes the Character Error Rate, defined as the edit distance.
    Arguments:
        t (string): target space-separated sentence
        p (string): prediction space-separated sentence
    """
    t, p, = t.replace(' ', ''), p.replace(' ', '')
    return Lev.distance(t, p)/len(t)

def wer(s1, s2):
    """
    Computes the Word Error Rate, defined as the edit distance between the
    two provided sentences after tokenizing to words.
    Arguments:
        s1 (string): space-separated sentence
        s2 (string): space-separated sentence
    """

    # build mapping of words to integers
    b = set(s1.split() + s2.split())
    word2char = dict(zip(b, range(len(b))))

    # map the words to a char array (Levenshtein package only accepts strings)
    w1 = [chr(word2char[w]) for w in s1.split()]
    w2 = [chr(word2char[w]) for w in s2.split()]

    return Lev.distance(''.join(w1), ''.join(w2))/len(w1)

In [7]:
def char_error_rate(preds, targs):
    bs,sl = targs.size()      #=> ([bs, sl])
    # preds.size()            #=> ([bs, sl, vs])
        
    res = torch.argmax(preds, dim=2)
    error = 0
    for i in range(bs):
        p = char_label_text(res[i])
        t = char_label_text(targs[i])
        error += cer(t,p)
    return error/bs

def word_error_rate(preds, targs):
    bs,sl = targs.size()      #=> ([bs, sl])
    # preds.size()            #=> ([bs, sl, vs])
        
    res = torch.argmax(preds, dim=2)
    error = 0
    for i in range(bs):
        p = word_label_text(res[i])
        t = word_label_text(targs[i])
        error += wer(t,p)
    return error/bs

# Data

## WikiText

In [8]:
PATH = Path('data/wikitext/wikitext-2-raw')
# PATH = Path('data/wikitext/wikitext-103-raw')

with open(PATH/'wiki.train.raw') as file:  
    trn = file.read()
with open(PATH/'wiki.valid.raw') as file:  
    val = file.read()
with open(PATH/'wiki.test.raw') as file:  
    tst = file.read()

In [9]:
len(trn)
# 2:    10918892
# 103: 539566975

10918892

### clean the text

In [10]:
# convert spaced out " strings " to "strings"
def despace_quotes(m):
    m = m.group(0)   # entire matched string
    m = m.replace('" ','"')
    m = m.replace(' "','"')
    return m

def cleanup(x):
    x = x.replace(' @-@ ', '-').replace(' =', '').replace('\n \n \n', '\n').replace('\n \n', '\n').replace(
        " \'", "\'").replace(' ,', ',').replace(' .', '.').replace(' :', ':').replace(' ;', ';').replace(
        '( ', '(').replace(' )', ')').replace('[ ', '[').replace(' ]', ']').replace(' @.@ ', '.').replace(
        ' @,@ ', ',')
    x = re.sub(r'\"(.+?)\"', despace_quotes, x)
    return x

In [11]:
trn = cleanup(trn)
val = cleanup(val)
tst = cleanup(tst)

In [12]:
trn = trn + tst
# len(trn)

## IAM chars

In [12]:
IAM_PATH = Path('data/IAM_handwriting')

In [13]:
maxTextLen = 32
samples = []
chars = set()

with open(IAM_PATH/'ascii/words.txt') as f:
    for line in f:
        # ignore comment line
        if not line or line[0]=='#':
            continue

        lineSplit = line.strip().split(' ')
        assert len(lineSplit) >= 9

        fileName = lineSplit[0]

        # GT text are columns starting at 9
        gtText = ''.join(lineSplit[8:])[:maxTextLen]
        char_len = len(gtText)
        chars = chars.union(set(list(gtText)))

        # put sample into list
        samples.append([fileName, gtText, char_len])
    
samples = np.stack(samples)
df = pd.DataFrame(samples, columns=['filename', 'word', 'char_len'], )
del samples

In [14]:
df['char_len'] = df.char_len.astype('int32')
# df = df.loc[df['char_len'] > 3]
df = df.loc[df['char_len'] < 20]
df.head()

Unnamed: 0,filename,word,char_len
0,a01-000u-00-00,A,1
1,a01-000u-00-01,MOVE,4
2,a01-000u-00-02,to,2
3,a01-000u-00-03,stop,4
4,a01-000u-00-04,Mr.,3


In [15]:
iam_trn = ' '.join(df.word.values)
len(iam_trn)

590841

In [16]:
iam_val = iam_trn[-60000:]   #last 60000
iam_trn = iam_trn[:-60000]

In [17]:
trn = trn+iam_trn
val = val+iam_val

## Str to idx

In [13]:
itos = pickle.load(open('data/IAM_handwriting/tmp/char_itos.pkl', 'rb'))
stoi = collections.defaultdict(lambda: 0, {v:k for k,v in enumerate(itos)})
len(itos)

82

In [14]:
#convert text into idxs
trn_idx = np.array([stoi[c] for c in trn])
val_idx = np.array([stoi[c] for c in val])

In [15]:
# remove unknown chars
trn_mask = trn_idx.nonzero()
trn_idx = trn_idx[trn_mask]

val_mask = val_idx.nonzero()
val_idx = val_idx[val_mask]

In [16]:
''.join([itos[i] for i in trn_idx[:500]])

' \n Valkyria Chronicles III \n Senj no Valkyria 3: Unrecorded Chronicles (Japanese: 3, lit. Valkyria of the Battlefield 3), commonly referred to as Valkyria Chronicles III outside Japan, is a tactical role-playing video game developed by Sega and Media.Vision for the PlayStation Portable. Released in January 2011 in Japan, it is the third game in the Valkyria series. Employing the same fusion of tactical and real-time gameplay as its predecessors, the story runs parallel to the first game and foll'

# Language Model Loader

## AWD-LSTM Language Model

In [17]:
wd=1e-7
bptt=30  # back prop through time
bs=50
opt_fn = partial(optim.Adam, betas=(0.8, 0.99))

In [18]:
trn_dl = LanguageModelLoader(trn_idx, bs, bptt)
val_dl = LanguageModelLoader(val_idx, bs, bptt)
md = LanguageModelData(PATH, 0, len(itos), trn_dl, val_dl, bs=bs, bptt=bptt)

In [19]:
# overfitting - increase multiplier (0.7)
# underfitting - decrease multiplier (0.7)
drops = np.array([0.25, 0.1, 0.2, 0.02, 0.15])*0.7

In [20]:
em_sz,nh,nl = 400,1150,3
learner= md.get_model(opt_fn, em_sz, nh, nl, 
    dropouti=drops[0], dropout=drops[1], wdrop=drops[2], dropoute=drops[3], dropouth=drops[4])

learner.metrics = [accuracy]

In [36]:
lr=1e-3
learner.fit(lr, 1, wds=wd, use_clr=(10,2), cycle_len=2, best_save_name='wiki103_lm')

# wikitext2  15cycle(20,10), 1e-3
# 1.18086    1.189847   0.646    fastai LM    'wiki2_lm'

# wikitext103
# 1.070876   1.001363   0.696245

HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                             
    0      1.134826   1.115255   0.663817  
    1      1.070876   1.001363   0.696245                             


[1.0013631663798634, 0.6962449800136478]

In [22]:
learner.load('wiki103_lm')

## Transformer Language Model

In [47]:
class LMLoader():
    """ Returns a language model iterator that iterates through batches that are of length N(bptt,5)
    The first batch returned is always bptt+25; the max possible width.  This is done because of the way that pytorch
    allocates cuda memory in order to prevent multiple buffers from being created as the batch width grows.
    """
    def __init__(self, nums, vocab_len, bs, bptt):
        self.bs,self.bptt = bs,bptt
        self.vocab_len = vocab_len
        self.data = self.batchify(nums)
        self.i,self.iter = 0,0
        self.n = len(self.data)
        
    def __iter__(self):
        self.i,self.iter = 0,0
        while self.i < self.n-1 and self.iter<len(self):
            if self.i == 0:
                seq_len = self.bptt + 5 * 5
            else:
                bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
                seq_len = max(5, int(np.random.normal(bptt, 5)))
            res = self.get_pair(self.i, seq_len)
            self.i += seq_len
            self.iter += 1
            yield res

    def __len__(self): return self.n // self.bptt - 1
    
    def batchify(self, data):
        nb = data.shape[0] // self.bs        # integer division into batches
        data = np.array(data[:nb*self.bs])   # remove remainder
        data = data.reshape(self.bs, -1).T   # reshape and transpose
        return T(data)                       # output a tensor
        
    def get_pair(self, i, seq_len):
        source = self.data
        seq_len = min(seq_len, len(source) - 1 - i)
        return source[i:i+seq_len].transpose(1,0), source[i+1:i+seq_len+1].transpose(1,0)

In [24]:
bs, bptt = 50, 100

In [48]:
trn_dl = LMLoader(trn_idx, len(itos), bs, bptt)
val_dl = LMLoader(val_idx, len(itos), bs, bptt)

In [49]:
md = LanguageModelData(PATH, 0, len(itos), trn_dl, val_dl)

In [50]:
ii = iter(md.trn_dl)
x,y = next(ii)
x.shape, y.shape

(torch.Size([50, 125]), torch.Size([50, 125]))

In [51]:
char_label_text(x[2])

'hitecture competition for its design, no construction took place. Instead, the parcel was turned into a parking lot, which it'

In [52]:
char_label_text(y[2])

'itecture competition for its design, no construction took place. Instead, the parcel was turned into a parking lot, which it '

## Denoising AutoEncoder LM

In [15]:
class DenoisingAutoEncoderLoader():
    """ Returns a language model iterator that iterates through batches that are of length N(bptt,5)
    The first batch returned is always bptt+25; the max possible width.  This is done because of the way that pytorch
    allocates cuda memory in order to prevent multiple buffers from being created as the batch width grows.
    """
    def __init__(self, nums, vocab_len, bs, bptt):
        self.bs,self.bptt = bs,bptt
        self.vocab_len = vocab_len
        self.data = self.batchify(nums)
        self.i,self.iter = 0,0
        self.n = len(self.data)
        
    def __iter__(self):
        self.i,self.iter = 0,0
        while self.i < self.n-1 and self.iter<len(self):
            if self.i == 0:
                seq_len = self.bptt + 5 * 5
            else:
                bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
                seq_len = max(5, int(np.random.normal(bptt, 5)))
            res = self.get_pair(self.i, seq_len)
            self.i += seq_len
            self.iter += 1
            yield res

    def __len__(self): return self.n // self.bptt - 1
    
    def batchify(self, data):
        nb = data.shape[0] // self.bs        # integer division into batches
        data = np.array(data[:nb*self.bs])   # remove remainder
        data = data.reshape(self.bs, -1).T   # reshape and transpose
        return data                          # output a tensor

    def get_pair(self, i, seq_len):
        seq_len = min(seq_len, self.n - 1 - i)
        arr = self.data[i:i+seq_len]    # (~125,50)

        # seq: (~bptt, bs)
        # need different scramble for each bs
        res,src = [],[]
        for b in range(arr.shape[1]):
            source = arr[:,b]
            
            # remove partial words from beginning and end
            spaces = np.where(source==1)[0]
            source = source[spaces[0]+1:spaces[-1]]
            src.append(source)
            
            seq = source.copy()
            
            # scramble
            num = random.randint(0, math.floor(self.bptt * 0.15))
            idxs = np.random.randint(len(seq), size=num)
            for i in idxs:
                if seq[i] not in [1,2]:   # don't modify ' ' or '\n'
                    prob = random.random()
                    if prob < 0.3:           #replacement
                        seq[i] = random.randrange(56,82) #self.vocab_len)  # only lowercase letters
                    elif 0.3 <= prob < 0.6:  #removal
                        seq[i] = 0
                    elif 0.6 <= prob < 0.9:  #addition
                        seq = np.insert(seq, i, random.randrange(56,82)) #self.vocab_len))
            mask = seq.nonzero()
            res.append(seq[mask])
        
        # convert res to 2d numpy array w/ zero padding
        out_res = np.zeros([len(res), len(max(res,key = lambda x: len(x)))], dtype=int)
        for i,j in enumerate(res):
            out_res[i][0:len(j)] = j
            
        out_src = np.zeros([len(src), len(max(src,key=lambda x: len(x)))], dtype=int)
        for i,j in enumerate(src):
            out_src[i][0:len(j)] = j
            
        return T(out_res), T(out_src)

In [16]:
bs, bptt = 50, 100

In [17]:
trn_dl = DenoisingAutoEncoderLoader(trn_idx, len(itos), bs, bptt)
val_dl = DenoisingAutoEncoderLoader(val_idx, len(itos), bs, bptt)

In [18]:
md = LanguageModelData(PATH, 0, len(itos), trn_dl, val_dl)

In [19]:
ii = iter(md.trn_dl)
x,y = next(ii)
x.shape, y.shape

(torch.Size([50, 125]), torch.Size([50, 122]))

In [20]:
char_label_text(x[2])

"it had dnissipted over open watrs. \n Preparations and impact \n Upon the cyclone's formatyon, the Bureau of Meteorflogy"

In [21]:
char_label_text(y[2])

"it had dissipated over open waters. \n Preparations and impact \n Upon the cyclone's formation, the Bureau of Meteorology"

# Helpers

## Loss and Metrics

In [22]:
def loss_prep(input, target):
    "equalize input/target sl; combine bs/sl dimensions"
    bs,tsl = target.shape
    _ ,sl,vocab = input.shape
        
    # F.pad( front,back for dimensions: 1,0,2 )
    if sl>tsl: target = F.pad(target, (0,sl-tsl))
        
    # this should only be used when testing for small seq_lens
    # if tsl>sl: target = target[:,:sl]
    
    if tsl>sl: input = F.pad(input, (0,0,0,tsl-sl))
    # not ideal => adds 82 logits all 0s...
        
    targ = target.contiguous().view(-1).long()
    pred = input.contiguous().view(-1, vocab)
    return pred, targ

In [23]:
# def loss_prep(input, target):
#     "equalize input/target sl; combine bs/sl dimensions"
# #     bs,tsl = target.shape
#     _,sl,vocab = input.shape
        
# #     # F.pad( front,back for dimensions: 1,0,2 )
# #     if sl>tsl: target = F.pad(target, (0,sl-tsl))
# #     if tsl>sl: target = target[:,:sl]
# # #     if tsl>sl: input = F.pad(input, (0,0,0,0,0,tsl-sl))
        
#     targ = target.contiguous().view(-1).long()
#     pred = input.contiguous().view(-1, vocab)
#     return pred, targ

In [24]:
class LabelSmoothing(nn.Module):
    def __init__(self, smoothing=0.1):
        super(LabelSmoothing, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        
    def forward(self, pred, target):
        pred,targ = loss_prep(pred, target)
        pred = F.log_softmax(pred, dim=-1)  # need this for KLDivLoss
        true_dist = pred.data.clone()
        true_dist.fill_(self.smoothing / pred.size(1))                  # fill with 0.0012
        true_dist.scatter_(1, targ.data.unsqueeze(1), self.confidence)  # [0.0012, 0.0012, 0.90, 0.0012]
        return F.kl_div(pred, true_dist, reduction='sum')/bs

## Stepper

In [27]:
def subsequent_mask(size):
    attn_shape = torch.ones((size,size), dtype=torch.int, device=device)
    mask = torch.tril(attn_shape).unsqueeze(0)
    return mask

def make_tgt_mask(tgt, pad=0):
    "Create a mask to hide padding and future words."
    tgt_mask = (tgt != pad).unsqueeze(-2)
    tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
    return tgt_mask

In [28]:
def rshift(tgt, token=1):
    "Shift y to the right by prepending token"
    return torch.cat((torch.ones((tgt.size(0),token), device=device, dtype=torch.long), tgt[:,:-1]), dim=-1)

In [29]:
class TfmrStepper(Stepper):
    def step(self, xs, y, epoch):
        xtra = []
        shifted_y = rshift(y).long()
        tgt_mask = subsequent_mask(shifted_y.size(-1)) #make_tgt_mask(shifted_y)
        output = self.m(*xs, shifted_y, tgt_mask)
        
        if isinstance(output,tuple): output,*xtra = output
        self.opt.zero_grad()
        loss = raw_loss = self.crit(output, y)
        if self.reg_fn: loss = self.reg_fn(output, xtra, raw_loss)
        loss.backward()        
        if self.clip:   # Gradient clipping
            nn.utils.clip_grad_norm_(trainable_params_(self.m), self.clip)
        self.opt.step()
        return raw_loss.item()
    
    def evaluate(self, xs, y):
        shifted_y = rshift(y).long()
        tgt_mask = subsequent_mask(shifted_y.size(-1)) #make_tgt_mask(shifted_y)
        preds = self.m(*xs, shifted_y, tgt_mask)
        if isinstance(preds,tuple): preds=preds[0]
        return preds, self.crit(preds, y)

## Transformer Modules

In [30]:
# similar to batchnorm but on a layer level
class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

In [31]:
class SublayerConnection(nn.Module):
    "A residual connection followed by a layer norm.  Note: (for code simplicity) norm is first."
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

In [32]:
def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

In [33]:
class Encoder(nn.Module):
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
        
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)

In [34]:
class EncoderLayer(nn.Module):
    "Encoder: self-attn and feed forward"
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)

    def forward(self, x):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x))
        return self.sublayer[1](x, self.feed_forward)

In [35]:
class Decoder(nn.Module):
    "Generic N layer decoder with masking."
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
        
    def forward(self, x, src, tgt_mask=None):
        for layer in self.layers:
            x = layer(x, src, tgt_mask)
        return self.norm(x)

In [36]:
class DecoderLayer(nn.Module):
    "Decoder: self-attn, src-attn, and feed forward"
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)  # wraps layer in residual,dropout,norm
 
    def forward(self, x, src, tgt_mask=None):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, src, src))
        return self.sublayer[2](x, self.feed_forward)

In [37]:
def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    depth = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(depth)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)    
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

In [38]:
class SingleHeadedAttention(nn.Module):
    def __init__(self, d_model, dropout=0.2):
        super(SingleHeadedAttention, self).__init__()
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, query, key, value, mask=None):
        query, key, value = [l(x) for l, x in zip(self.linears, (query, key, value))]
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        return self.linears[-1](x)

In [39]:
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, dropout=0.2, mult=4):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_model*mult)
        self.w_2 = nn.Linear(d_model*mult, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))

In [40]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.2, max_len=2000):
        super(PositionalEncoding, self).__init__()
        
        self.dropout = nn.Dropout(p=dropout)
    
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0.0, max_len).unsqueeze(1)
        log_increment = math.log(1e4) / d_model
        div_term = torch.exp(torch.arange(0.0, d_model, 2) * -log_increment)  
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe.unsqueeze_(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)
        return self.dropout(x)

## Architecture

In [41]:
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator
        
    def forward(self, src, tgt, tgt_mask=None):
        return self.decode(self.encode(src), tgt, tgt_mask)
    
    def encode(self, src):
        return self.encoder(self.src_embed(src))
    
    def decode(self, src, tgt, tgt_mask=None):
        return self.decoder(self.tgt_embed(tgt), src, tgt_mask)
    
    def generate(self, outs):
        return self.generator(outs)

In [42]:
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * 18

In [43]:
def make_language_model(vocab, d_model=512, N=4, drop=0.2):
    c = copy.deepcopy
    attn = SingleHeadedAttention(d_model)
#     attn = MultiHeadedAttention(d_model, 8)
    ff = PositionwiseFeedForward(d_model, drop)
    pos_enc = PositionalEncoding(d_model, drop, 2000)
    
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drop), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drop), N),
        nn.Sequential(nn.Embedding(vocab, d_model), pos_enc),
        nn.Sequential(nn.Embedding(vocab, d_model), pos_enc),
        nn.Linear(d_model, vocab)
    )
        
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
                    
    return model

In [44]:
# denoising auto-encoder
# Want to predict entire output including masked words

class BetterSpeller(nn.Module):
    def __init__(self, lm):
        super(BetterSpeller, self).__init__()
        self.lm = lm
        
    def forward(self, src, tgt=None, tgt_mask=None):
        return self.lm.generate(self.lm(src, tgt, tgt_mask))

    def greedy_decode(self, src):
        with torch.no_grad():
            feats = self.lm.encode(src)
            bs,sl = src.shape
            tgt = torch.ones((bs,1), dtype=torch.long, device=device)

            res = []                
            for i in tqdm(range(sl+5)):
                mask = subsequent_mask(tgt.size(-1))
                dec_outs = self.lm.decode(feats, Variable(tgt), Variable(mask))
                prob = self.lm.generate(dec_outs[:,-1])
                res.append(prob)
                pred = torch.argmax(prob, dim=-1, keepdim=True)
                if (pred==0).all(): break
                tgt = torch.cat([tgt,pred], dim=-1)
            out = torch.stack(res).transpose(1,0).contiguous()
            return out      

In [45]:
d_model = 512
lm = make_language_model(len(itos), d_model)
net = BetterSpeller(lm)

wd=1e-7
opt_fn = partial(optim.Adam, betas=(0.8, 0.99))

learn = Learner(md, BasicModel(to_gpu(net)), opt_fn=opt_fn)

learn.clip = 0.25
learn.crit = LabelSmoothing(smoothing=0.1)
learn.metrics = [char_error_rate]

# LM

In [None]:
learn.lr_find(stepper=TfmrStepper)
learn.sched.plot(n_skip=0, n_skip_end=2)

In [63]:
lr=1e-3
learn.fit(lr, 1, wds=wd, use_clr=(20,10), cycle_len=1, stepper=TfmrStepper, best_save_name='better_speller_103')

# wikitext2
# 39.696276  38.299323  0.549255   0.535159     15cycles

# 6.512353   5.655003   0.035785   tfmr  15cycles(20,10)   'better_speller'


# wikitext103
# 10.755489  9.274058   0.029507   tfmr  1cycle(20,10)   'LM_103'
# 7.036011   7.017283   0.022698   2nd cycle

HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

epoch      trn_loss   val_loss   char_error_rate                     
    0      7.036011   7.017283   0.022698  



[7.017283363079806, 0.022697553340128872]

In [64]:
learn.save('LM_103')

In [46]:
learn.load('LM_103')

In [65]:
x,y = next(iter(md.val_dl))

In [66]:
learn.model.eval()
preds = learn.model.greedy_decode(x)

100%|██████████| 129/129 [00:02<00:00, 44.02it/s]


In [67]:
probs = torch.argmax(preds, dim=-1)

In [76]:
idx=5

In [77]:
char_label_text(x[idx])

'ater eading boeks about Nazi physician Josef Mengele while on tour with the band: "I remember stopping someplace'

In [78]:
char_label_text(probs[idx])

'after leading books about Nazi physician Josef Mengele while on tour with the band: "I remember stopping someplace'

In [79]:
char_label_text(y[idx])

'after reading books about Nazi physician Josef Mengele while on tour with the band: "I remember stopping someplace'

In [68]:
lr=2e-5
learn.fit(lr, 1, wds=wd, use_clr=(20,10), cycle_len=10)

HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))

epoch      trn_loss   val_loss   char_error_rate accuracy   
    0      43.857646  42.990693  0.627222   0.484561  
    1      40.316483  43.975653  0.621577   0.498255        
    2      40.887276  41.461205  0.623438   0.497235        
    3      39.405278  41.622659  0.617584   0.502438        
    4      40.867689  40.411398  0.625081   0.497289        
    5      39.15356   41.386987  0.618365   0.500589        
    6      39.567341  38.565258  0.62212    0.499334        
    7      40.321067  40.285084  0.621228   0.500227        
    8      39.070562  42.125116  0.619125   0.50083         
    9      39.798026  40.925197  0.617509   0.502518        



[40.92519678009881, 0.617509345910516, 0.5025176637702518]

In [41]:
preds = learn.model(x)

In [42]:
char_error_rate(preds,y)

0.6253905094664887

In [43]:
accuracy(preds,y)

tensor(0.4497, device='cuda:0')

# Test

In [25]:
learner.model.eval()
learner.model.training

False

In [42]:
def next_with_creativity(preds, k=5, thresh=.05):
    probs, idxs = torch.topk(F.softmax(preds, dim=-1), k, dim=-1)
    d = {itos[k]: round(v.item(), 3) for k,v in zip(idxs,probs)}
    print(d)
    
    seq = np.array([], dtype=np.long)
    for p,i in zip(probs,idxs):
        num = int(p * 100)
        seq = np.append(seq, [i.item()] * num)
    
    return random.choice(seq.flatten())
    
#     return{k:v if v>=thresh else None for k,v in d}
#     mask = [probs >= thresh] 
#     m_probs, m_idxs = probs[mask], idxs[mask]
    
#     if len(m_idxs) > 0:
#         # simple weighted choice
#         seq = 
#         random.choice(seq)
#         idx = random.randint(0,len(m_idxs))
#         return m_idxs[idx]
#     else:
#         return idxs[0]

In [47]:
def get_next(inp):
    idxs = T(np.array([stoi[c] for c in inp])).unsqueeze(0)
    p = learner.model(Variable(idxs))
#     i = torch.argmax(p[0][-1], dim=-1)
#     i = torch.multinomial(p[0].exp(), 1)[-1]
    i = next_with_creativity(p[0][-1])
    return itos[i.item()]

In [53]:
get_next('whe')

{' ': 0.196, 'n': 0.121, 's': 0.112, 'l': 0.078, 'd': 0.057}


'd'

In [48]:
def get_next_n(inp, n):
    res = inp
    for i in range(n):
        c = get_next(res)
        res += c
    return res

In [49]:
get_next_n('th', 10)

{'a': 0.296, 'e': 0.189, 'u': 0.12, 'o': 0.103, ' ': 0.09}
{'n': 0.102, 'b': 0.081, 'r': 0.075, 's': 0.056, 'm': 0.056}
{'a': 0.17, 'o': 0.121, 'e': 0.115, ' ': 0.092, 'i': 0.086}
{'a': 0.078, 'w': 0.051, 'i': 0.048, 'o': 0.046, ' ': 0.044}
{'n': 0.039, 'w': 0.039, 'r': 0.031, 'u': 0.028, 'p': 0.027}
{'i': 0.251, 'a': 0.197, 'e': 0.162, 'h': 0.153, 'o': 0.126}
{'n': 0.058, ' ': 0.049, 'm': 0.039, 'r': 0.039, 'l': 0.035}
{'a': 0.078, 'w': 0.051, 'i': 0.048, 'o': 0.046, ' ': 0.044}
{'n': 0.058, ' ': 0.049, 'm': 0.039, 'r': 0.039, 'l': 0.035}
{'g': 0.053, ' ': 0.047, 'i': 0.03, 'e': 0.03, 'a': 0.029}


'thum owa ana'