# Prelims

In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
# from fastai import *
from fastai.text import *
# from fastai.callbacks.tracker import *
import pdb
import textwrap
import sentencepiece as spm

In [None]:
PATH = Path('data/IAM_handwriting')

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

# Setup Data

In [None]:
def remove_newline_spaces(x):
    "Remove one space directly before/after each newline, then collapse runs of newlines into a single one."
    for padded in (' \n', '\n '):
        x = x.replace(padded, '\n')
    return re.sub(r'(\n)+','\n',x)

def remove_equals(x):
    "Delete every ' =' and then every '= ' sequence from `x` (the sign goes together with the adjacent space)."
    without_leading = x.replace(' =', '')
    return without_leading.replace('= ', '')

# turn a spaced-out quotation such as `" strings "` into `"strings"`
def despace_quotes(m):
    "Regex substitution callback: return the whole match with the space hugging each double quote removed."
    quoted = m.group(0)   # entire matched string
    for spaced in ('" ', ' "'):
        quoted = quoted.replace(spaced, '"')
    return quoted

def cleanup(x):
    "Clean raw corpus text: html fixes, newline/equals cleanup, re-attach punctuation, tighten quotes, drop non-ascii."
    x = remove_equals(remove_newline_spaces(fix_html(x)))
    # (old, new) pairs are applied strictly in this order
    reattach = [(" \'", "\'"), (' ,', ','), (' .', '.'), (' :', ':'), (' ;', ';'), ('( ', '('),
                (' )', ')'), ('[ ', '['), (' ]', ']'), (' %', '%'), ('$ ','$'), ('# ','#')]
    for old, new in reattach:
        x = x.replace(old, new)
    x = re.sub(r'\"(.+?)\"', despace_quotes, x)
    return re.sub(r'[^\x00-\x7F]+','', x)   # remove all non ascii characters

## WikiText

In [None]:
# wiki_path = Path('data/wikitext/wikitext-2-raw')
wiki_path = Path('data/wikitext/wikitext-103-raw')

In [None]:
# Read the three raw splits. The encoding is given explicitly so that decoding does not depend on
# the platform's default locale encoding.
with open(wiki_path/'wiki.train.raw', encoding='utf-8') as file:
    trn = file.read()
with open(wiki_path/'wiki.valid.raw', encoding='utf-8') as file:
    val = file.read()
with open(wiki_path/'wiki.test.raw', encoding='utf-8') as file:
    tst = file.read()

In [None]:
len(trn)
# 2:    10918892
# 103: 539566975

In [None]:
full = trn + val + tst
# len(trn)

In [None]:
full = cleanup(full)

In [None]:
full[0:5000]

In [None]:
lines = full.split('\n')
len(lines)

In [None]:
wiki_df = DataFrame({'text': lines[1:-1]})
wiki_df.head()

In [None]:
wiki_df.to_csv(PATH/'clean_wiki_103.csv', index=False)

## IMDB

In [None]:
imdb_path = untar_data(URLs.IMDB)
imdb_path

In [None]:
CSV = 'texts.csv'
imdb_df = pd.read_csv(imdb_path/CSV)
len(imdb_df)

In [None]:
imdb_df['text'] = imdb_df.text.apply(lambda x: cleanup(x))

In [None]:
len(imdb_df)

In [None]:
imdb_df.text.values[0]

In [None]:
imdb_df.to_csv(PATH/'clean_imdb.csv', index=False)

## Book texts

In [None]:
FPATH = Path('data/fonts/')

a = pd.read_csv(FPATH/'knot.csv', usecols=['filename', 'label'])
c = pd.read_csv(FPATH/'adrift.csv', usecols=['filename', 'label'])
d = pd.read_csv(FPATH/'zane.csv', usecols=['filename', 'label'])
e = pd.read_csv(FPATH/'american.csv', usecols=['filename', 'label'])
f = pd.read_csv(FPATH/'age.csv', usecols=['filename', 'label'])
g = pd.read_csv(FPATH/'room.csv', usecols=['filename', 'label'])
book_df = pd.concat([a,c,d,e,f,g], ignore_index=True)
len(book_df)

In [None]:
book_df['text'] = book_df.label.apply(lambda x: x.replace('\n',' '))

In [None]:
book_df.head()

In [None]:
book_df.to_csv(PATH/'clean_books.csv', columns=['text'], index=False)

## Paragraphs (w/out test)

In [None]:
pg_df = pd.read_csv(PATH/'edited_pg.csv')
pg_df.head()

In [None]:
pg_df['text'] = pg_df.text.apply(lambda x: x.replace('\n',' '))

In [None]:
pg_df.to_csv(PATH/'clean_pg.csv', columns=['text'], index=False)

## Combine and process (remove caps)

In [None]:
a = pd.read_csv(PATH/'clean_pg.csv')
b = pd.read_csv(PATH/'clean_books.csv')
c = pd.read_csv(PATH/'clean_imdb.csv')
d = pd.read_csv(PATH/'clean_wiki_103.csv')

In [None]:
full = pd.concat([a, b, c, d], ignore_index=True)

In [None]:
full.dropna(inplace=True)
full.reset_index(inplace=True, drop=True)
len(full)

In [None]:
def add_cap_tokens(text):  # before encode
    "Replace each run of capital letters A-Z in `text` by a marker token followed by the lower-cased run."
    return re.sub(r'[A-Z]+', _replace_caps, text)

def _replace_caps(m):
    "Callback for `add_cap_tokens`: '[MAJ]' marks a single capital, '[UP]' a run of two or more."
    caps = m.group()
    marker = '[MAJ]' if len(caps) == 1 else '[UP]'
    return marker + caps.lower()

In [None]:
full['text'] = full.text.apply(lambda x: rm_useless_spaces(x))

In [None]:
full['text'] = full.text.apply(lambda x: add_cap_tokens(x))

In [None]:
full.tail()

## Write to raw

In [None]:
name = str(PATH/'spm_full')
# name = str(PATH/'spm_pg')
fname = name + '.txt'

In [None]:
# writes the dataset in the format expected by sentencepiece:
# a .txt file whose entries are each terminated by \n
def write_text(texts, filename=fname):
    "Write every string in `texts` to `filename` (UTF-8), one entry per line."
    with open(filename, 'w', encoding='utf-8') as out:
        out.writelines(text + "\n" for text in texts)

In [None]:
write_text(full.text.values)
# write_text(pg_df.text.values)

# Train SPM

In [None]:
name = str(PATH/'spm_full')
fname = name + '.txt'

In [None]:
symbols = "\n,[UP],[MAJ],▁,:,;,!,?,(,),[,],{,},<,>,@,#,$,%,^,&,*,-,_,+,=,/,~"

In [None]:
spm.SentencePieceTrainer.Train(
    f"--unk_id=3 --pad_id=0 --input={fname} --model_prefix={name+'_30k'} --vocab_size=30000 --user_defined_symbols={symbols} --input_sentence_size=1500000 --shuffle_input_sentence=True"
)

In [None]:
sp = spm.SentencePieceProcessor()
sp.Load(name+'_30k.model')

In [None]:
df = pd.read_csv(fname, sep='\n', header=None, names=['text'])
st = df.text.values[-2]; st

In [None]:
pieces = sp.encode_as_pieces(st)

In [None]:
st2 = sp.decode_pieces(pieces[1:])

In [None]:
st == st2

In [None]:
pieces

In [None]:
sp.EncodeAsPieces('adaptability')

In [None]:
for n in range(5):
    print(sp.SampleEncodeAsPieces('adaptability', -1, .1))

In [None]:
vocab = {i: sp.id_to_piece(i) for i in range(len(sp))}
vocab

# DataBunch

In [None]:
def char_label_text(pred, sp_model=None):
    """Decode a sequence of sentencepiece ids into text.

    pred: object exposing `.tolist()` that yields the ids (e.g. a tensor or an array).
    sp_model: object whose `DecodeIds` is used; when omitted, the notebook-level `sp` is used.

    The previous body read `self.sp`, but this is a plain function with no `self`, so every call
    raised NameError.
    """
    model = sp if sp_model is None else sp_model
    return model.DecodeIds(pred.tolist())

In [None]:
class SPMProcessor(PreProcessor):
    "Processor that encodes each item to sentencepiece ids with the model carried by the `ItemList`."
    def __init__(self, ds:ItemList=None):
        # the sentencepiece model is borrowed from the list being processed, when one is given
        self.sp = None if ds is None else ds.sp

    def process_one(self,item):
        "Encode a single item to ids."
        return self.sp.EncodeAsIds(item)

    def process(self, ds):
        "Defer to the parent implementation."
        super().process(ds)
    
class SPMList(ItemList):
    "`ItemList` whose items are handled through a sentencepiece model loaded from `sp_model_path`."
    _processor = [SPMProcessor]

    def __init__(self, items:Iterator, sp_model_path, **kwargs):
        super().__init__(items, **kwargs)
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(sp_model_path)
        # encoding is configured with the "eos" option only; decoding with both "bos" and "eos"
        self.sp.SetEncodeExtraOptions("eos")
        self.sp.SetDecodeExtraOptions("bos:eos")
        # NOTE(review): presumably meant to make fastai hand `sp` to lists derived from this one, but
        # `__init__` requires `sp_model_path` and has no `sp` argument — confirm that copies work
        self.copy_new += ['sp']

    def get(self, i):
        "Return item `i` as a `Text` pairing the stored item with its sentencepiece decoding."
        o = super().get(i)
        return Text(o, self.sp.DecodeIds(o))

    def reconstruct(self, t:Tensor):
        "Build a `Text` from tensor `t` by decoding `t.tolist()`."
        return Text(t, self.sp.DecodeIds(t.tolist()))

In [None]:
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load(str(PATH/'spm_train.model'))
sp.SetEncodeExtraOptions("bos:eos")

In [None]:
data = (ImageList.from_df(df, path=PATH, folder=FOLDER)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=SPMList, sp_processor=sp)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
        .normalize()
       )

In [None]:
def label_collater(samples:BatchSamples, pad_idx:int=0):
    """Collate `samples` into a stacked image batch and a label batch padded at the end.

    samples: batch of (image, label) pairs; labels are expected to be numpy arrays of ids.
    pad_idx: value used to fill the label tensor past the end of each label.
    Returns `(imgs, labels)`; `labels` has one column more than the longest label.
    """
    data = to_data(samples)
    ims, lbls = zip(*data)
    imgs = torch.stack(list(ims))
    # predict: a single sample whose label is the scalar placeholder 0. Values are compared with `==`
    # (`is` on int literals tests identity and is a SyntaxWarning on Python 3.8+); the scalar check
    # keeps array labels from being compared element-wise.
    if len(data) == 1 and np.isscalar(lbls[0]) and lbls[0] == 0:   #predict
        labels = torch.zeros(1,1).long()
        return imgs, labels
    max_len = max(len(s) for s in lbls)
    labels = torch.zeros(len(data), max_len+1).long() + pad_idx  # add 1 to max_len to account for bos token
    for i,lbl in enumerate(lbls):
        labels[i,:len(lbl)] = torch.from_numpy(lbl)  #padding end
    return imgs, labels

In [None]:
class SequenceList(TextList):
    "`TextList` that tokenizes with the given function (no rules, no special cases, no bos) and numericalizes with `vocab`."
    def __init__(self, items:Iterator, vocab:Vocab, tokenizer:Tokenizer, **kwargs):
        # NOTE(review): `tokenizer` is forwarded as `tok_func`, so despite the annotation it presumably
        # has to be a tokenizer class/function rather than a `Tokenizer` instance — confirm
        toknizr = Tokenizer(tok_func=tokenizer, pre_rules=[], post_rules=[], special_cases=[])
        procs = [TokenizeProcessor(tokenizer=toknizr, include_bos=False), NumericalizeProcessor(vocab=vocab)]
        # NOTE(review): `**kwargs` is accepted but not passed on to the parent constructor
        super().__init__(items, vocab, sep='', pad_idx=0, processor=procs)
    
    def analyze_pred(self, pred:Tensor):
        "Return the argmax of `pred` over its last dimension."
        return torch.argmax(pred, dim=-1)

In [None]:
def remove_cap_tokens(text):  # after decode
    "Undo the capitalisation markers: '[UP]word' -> 'WORD', '[MAJ]word' -> 'Word'."
    def _upper_after(marker):
        # substitution callback that drops `marker` from the front of the match and upper-cases the rest
        return lambda m: m.group()[len(marker):].upper()
    whole_words = re.sub(r'\[UP\]\w+', _upper_after('[UP]'), text)      #cap entire word
    return re.sub(r'\[MAJ\]\w?', _upper_after('[MAJ]'), whole_words)    #cap first letter

def remove_special_toks(text):
    "Strip the `<s>` (bos) and `</s>` (eos) markers from `text`, together with the whitespace attached to them."
    text = re.sub(r'<s>\s*', '', text)    #bos (w/ following whitespace)
    # the eos pattern used to end with a stray `]`, so a plain `</s>` was never removed
    text = re.sub(r'\s*</s>', '', text)   #eos (w/ preceding whitespace)
    return text

In [None]:
class BertTokenizer(BaseTokenizer):
    "Tokenizer that delegates to `bert_tok.tokenize` (`bert_tok` is defined outside this section)."
    # "[SEP]" is appended to the token list of every text
    def tokenizer(self, t:str) -> List[str]: return bert_tok.tokenize(t) + ["[SEP]"]

class BertVocab(Vocab):
    "`Vocab` made of the keys of `bert_tok.vocab` followed by four extra tokens: newline, space, '[UP]' and '[MAJ]'."
    def __init__(self):
        self.itos = list(bert_tok.vocab.keys()) + ['\n',' ','[UP]','[MAJ]']
        # tokens missing from the vocabulary map to index 100 (presumably BERT's unknown token — confirm)
        self.stoi = collections.defaultdict(lambda: 100, {v:k for k,v in enumerate(self.itos)})

    def textify(self, nums:Collection[int], sep=''):
        "Join the tokens for `nums` with `sep`, then pass the string through the three cleanup helpers below."
        st = sep.join([self.itos[i] for i in nums])
        st = remove_wordpiece_toks(st)
        st = remove_cap_tokens(st)
        st = remove_special_toks(st)
        return st

In [None]:
class MultiTokenizeProcessor(PreProcessor):
    "Tokenize the items of an `ItemList` in chunks of `chunksize` using `BertTokenizer`."
    def __init__(self, ds:ItemList=None, chunksize:int=10000):
        # pre-rules: collapse useless spaces, then mark capitals with '[UP]' / '[MAJ]' tokens
        self.toknizr = Tokenizer(tok_func=BertTokenizer, pre_rules=[rm_useless_spaces, add_cap_tokens],
                                 post_rules=[], special_cases=[])
        self.chunksize = chunksize

    def process_one(self, item):
        "Single-item processing is not supported."
        # NotImplementedError states the intent and is still an `Exception` for existing handlers
        raise NotImplementedError("This isn't implemented!  I didn't think it was necessary...")

    def process(self, ds):
        "Replace `ds.items` with their tokenized form, one chunk at a time."
        tokens = []
        for i in progress_bar(range(0,len(ds),self.chunksize), leave=False):
            tokens += self.toknizr.process_all(ds.items[i:i+self.chunksize])
        ds.items = tokens
        
class MultiNumericalizeProcessor(PreProcessor):
    "Turn token lists into int64 id arrays using the vocab of the `ItemList` being processed."
    def __init__(self, ds:ItemList=None):
        self.vocab = ds.vocab

    def process_one(self,item):
        "Numericalize one item into an int64 array."
        ids = self.vocab.numericalize(item)
        return np.array(ids, dtype=np.int64)

    def process(self, ds):
        "Numericalize every item of `ds` in place."
        ds.items = array(list(map(self.process_one, ds.items)))
        

class MultiSequenceList(TextList):
    "`TextList` processed by `MultiTokenizeProcessor` followed by `MultiNumericalizeProcessor`."
    _processor = [MultiTokenizeProcessor, MultiNumericalizeProcessor]

    def get(self, i):
        "Return item `i` as a `Text` made of the stored ids and their textified form."
        ids = self.items[i]
        return Text(ids, self.vocab.textify(ids))

    def reconstruct(self, t:Tensor):
        "Build a `Text` from the slice of `t` between its first and last non-pad positions."
        kept = (t != self.pad_idx).nonzero()
        first, last = kept.min(), kept.max()
        span = t[first:last+1]
        return Text(span, self.vocab.textify(span))

    def analyze_pred(self, pred:Tensor):
        "Return the argmax of `pred` over its last dimension."
        return torch.argmax(pred, dim=-1)

In [None]:
data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=MultiSequenceList, vocab=BertVocab(), pad_idx=0)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
       )

In [None]:
data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

# DataBunch

In [None]:
class CustomVocab(Vocab):
    "`Vocab` built from an explicit `itos` list; tokens that are not in it map to index 3."
    def __init__(self, itos:Collection[str]):
        self.itos = itos
        index_of = {tok:idx for idx,tok in enumerate(self.itos)}
        self.stoi = collections.defaultdict(lambda: 3, index_of)

    def textify(self, nums:Collection[int], sep=''):
        "Return the tokens for `nums` joined with `sep`, or as a list when `sep` is None."
        toks = [self.itos[i] for i in nums]
        if sep is None: return toks
        return sep.join(toks)

In [None]:
class CustomTokenizer(BaseTokenizer):
    "Split words but keep original spacing"
    def tokenizer(self, t:str) -> List[str]:
        "Return 'xxbos', then each alphanumeric run as one token and every other character on its own, then 'xxeos'."
        from itertools import groupby
        res = []
        for is_word, run in groupby(t, key=str.isalnum):
            if is_word: res.append(''.join(run))   # consecutive alphanumeric characters form one token
            else:       res.extend(run)            # spaces/punctuation stay as single-character tokens
        return ['xxbos'] + res + ['xxeos']

In [None]:
def label_text(pred, sep=''):
    "Look up the non-zero ids of `pred` in the notebook-level `itos` and join the tokens with `sep`."
    ids = to_np(pred).astype(int)
    kept = ids[np.nonzero(ids)]   # zeros are dropped; a trailing `[:-1]` here would also drop the eos token
    toks = [itos[i] for i in kept]
    return sep.join(toks)

In [None]:
itos = pickle.load(open(PATH/'combo_itos_60k.pkl', 'rb'))
itos = itos[:30000]
vocab = CustomVocab(itos)

### words

In [None]:
toknizr = Tokenizer(tok_func=CustomTokenizer, pre_rules=[rm_useless_spaces],
                    special_cases=['xxbos','xxeos','xxmask','xxunk','xxpad','xxmaj','xxup','\n'])

### characters

In [None]:
def characterize(x:Collection[str], keep:Collection[str]=()) -> Collection[str]:
    """Separate word tokens into letters.

    Every token of `x` is split into its characters, except tokens listed in `keep`, which are passed
    through whole (e.g. `keep=('xxmaj','xxup')` to preserve those modifiers). With the default empty
    `keep` every token is split, special modifiers included, which is what this function has always done.
    """
    res = []
    for t in x:
        if t in keep: res.append(t)
        else:         res.extend(t)
    return res

toknizr = Tokenizer(tok_func=CustomTokenizer, pre_rules=[rm_useless_spaces],
                    post_rules=[replace_all_caps, deal_caps, characterize],
                    special_cases=['xxbos','xxeos','xxmask','xxunk','xxpad','xxmaj','xxup'])

### Databunch

In [None]:
CSV = 'wiki2.csv' #'wiki103_imdb.csv' #'wiki103.csv' #'wiki2.csv'
data = TextLMDataBunch.from_csv(PATH, CSV, tokenizer=toknizr, vocab=vocab, include_bos=False, device=device)

In [None]:
data.show_batch()

# Metrics

In [None]:
import Levenshtein as Lev

class CER(Callback):
    "Metric callback that averages the per-batch score returned by `cer` over an epoch."
    def __init__(self, ignore_index=-1):
        super().__init__()
        self.name = 'cer'
        self.ignore_index = ignore_index

    def on_epoch_begin(self, **kwargs):
        "Reset the running sums."
        self.errors = 0
        self.total = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Score the batch on the positions whose target differs from `ignore_index`."
        keep = (last_target!=self.ignore_index)
        error,size = cer(last_output[keep], last_target[keep])
        self.errors += error
        self.total += size

    def on_epoch_end(self, last_metrics, **kwargs):
        "Report accumulated error divided by accumulated size."
        return add_metrics(last_metrics, self.errors/self.total)

def cer(preds, targs):
    """Return `(error, 1)` for one batch.

    `preds` is arg-maxed over its last dimension; predictions and `targs` are each turned into one
    string with `label_text`. `error` is the Levenshtein distance between the two strings divided by
    the length of the target string.
    """
    res = torch.argmax(preds, dim=-1)
    p = label_text(res)   #.replace(' ', '')
    t = label_text(targs) #.replace(' ', '')
    # `label_text` drops zero ids, so the target string can be empty; normalise by at least one
    # character instead of dividing by zero
    return Lev.distance(t, p)/max(len(t), 1), 1

# def cer(preds, targs):
#     bs = targs.size(0)
#     res = torch.argmax(preds, dim=-1)
#     error = 0
#     for i in range(bs):
#         p = label_text(res[i])   #.replace(' ', '')
#         t = label_text(targs[i]) #.replace(' ', '')
#         error += Lev.distance(t, p)/len(t)
#     return error, bs

# v1 ULMFit

In [None]:
config = dict(emb_sz=512, n_hid=1400, n_layers=3, pad_token=0, qrnn=False, bidir=False, output_p=0.2,
              hidden_p=0.2, input_p=0.5, embed_p=0.1, weight_p=0.4, tie_weights=True, out_bias=True)

learn = language_model_learner(data, AWD_LSTM, config=config, drop_mult=0.5,
                               pretrained=False, metrics=[accuracy, CER()])

In [None]:
# true number of trainable params
sum(p.numel() for p in learn.model.parameters() if p.requires_grad)

# Total trainable params: 30,378,720

# v1 Transformer

## Transformer Modules

In [None]:
LayerNorm = partial(nn.LayerNorm, eps=1e-4)  # accommodates mixed precision training

In [None]:
class SublayerConnection(nn.Module):
    "Residual wrapper with the norm applied first: returns x + dropout(sublayer(norm(x)))."
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch

In [None]:
def clones(module, N):
    "Return an `nn.ModuleList` holding N independent deep copies of `module`."
    copies = nn.ModuleList()
    for _ in range(N):
        copies.append(deepcopy(module))
    return copies

In [None]:
class Encoder(nn.Module):
    "Stack of N copies of `layer` followed by a final layer norm."
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask=None):
        out = x
        for block in self.layers:
            out = block(out, mask)
        return self.norm(out)

In [None]:
class EncoderLayer(nn.Module):
    "Encoder: self-attn and feed forward"
    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)

    def forward(self, x, mask=None):
        attn_sub, ff_sub = self.sublayer
        def _self_attend(h):
            # query, key and value are all the same (normed) input
            return self.self_attn(h, h, h, mask)
        attended = attn_sub(x, _self_attend)
        return ff_sub(attended, self.feed_forward)

In [None]:
def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'; returns (weighted values, attention weights)."
    scale = math.sqrt(query.size(-1))
    scores = query @ key.transpose(-2, -1) / scale
    if mask is not None:
        # positions where the mask is 0 get a large negative score before the softmax
        scores = scores.masked_fill(mask == 0, -1e4)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return weights @ value, weights

In [None]:
class SingleHeadedAttention(nn.Module):
    "Attention with a single head: linear projections of query/key/value, `attention`, then an output projection."
    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        proj_q, proj_k, proj_v, proj_out = self.linears
        query, key, value = proj_q(query), proj_k(key), proj_v(value)
        # the attention weights of the latest call are kept on the module
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        return proj_out(x)

In [None]:
class MultiHeadedAttention(nn.Module):
    "Multi-head attention: `d_model` is split across `h` heads of size `d_model // h`."
    def __init__(self, d_model, h=8, dropout=0.2):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h        # assume d_v always equals d_k
        self.h = h
        # four d_model -> d_model projections: the first three for q/k/v, the last for the output
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None               # attention weights of the latest forward call
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, q, k, v, mask=None):
        "Project q/k/v, attend per head, merge the heads and apply the output projection."
        # extra dim at position 1, presumably so the same mask broadcasts across heads
        if mask is not None: mask = mask.unsqueeze(1)
        bs = q.size(0)
        
        # 1) Do all the linear projections in batch from d_model => h x d_k 
        #    (each projection is viewed as (bs, -1, h, d_k) and transposed to (bs, h, -1, d_k))
        q, k, v = [l(x).view(bs, -1, self.h, self.d_k).transpose(1,2) for l, x in zip(self.linears, (q, k, v))]
        
        # 2) Apply attention on all the projected vectors in batch. 
        x, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)
        
        # 3) "Concat" using a view and apply a final linear. 
        x = x.transpose(1, 2).contiguous().view(bs, -1, self.h * self.d_k)
        return self.linears[-1](x)

In [None]:
class PositionwiseFeedForward(nn.Module):
    "Feed-forward block d_model -> 4*d_model -> d_model with GELU and dropout between the two linears."
    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        d_hidden = d_model*4
        self.w_1 = nn.Linear(d_model, d_hidden)
        self.w_2 = nn.Linear(d_hidden, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.gelu(self.w_1(x))
        return self.w_2(self.dropout(hidden))

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    d_model: size of the last dimension of the input (odd sizes are supported).
    dropout: dropout probability applied after the addition.
    max_len: number of positions for which encodings are precomputed.
    """
    def __init__(self, d_model, dropout=0.1, max_len=2000):
        super().__init__()

        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0.0, max_len).unsqueeze(1)
        log_increment = math.log(1e4) / d_model
        div_term = torch.exp(torch.arange(0.0, d_model, 2) * -log_increment)
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # with an odd d_model there is one cosine column fewer than sine columns, so the angles are
        # trimmed to fit (for an even d_model this slice is the whole tensor)
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe.unsqueeze_(0)

        self.register_buffer('pe', pe)    #(1,max_len,d_model)
        # registered as a buffer: not a parameter, but still wanted in the state_dict

    def forward(self, x):
        "Add the encodings of the first `x.size(1)` positions to `x` and apply dropout."
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [None]:
class Embeddings(nn.Module):
    "Embedding lookup whose output is multiplied by sqrt(d_model)."
    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale

## Architecture

In [None]:
class TransformerLM(Module):
    "Language-model encoder: scaled embeddings + positional encoding followed by a stack of `N` encoder layers."
    def __init__(self, vocab, d_model=512, N=4, drops=0.2, attn_type='multi', attn_heads=8):        
        # attn_type 'multi' selects multi-head attention; any other value selects the single-head variant.
        # NOTE(review): `drops` is not passed to the attention module, which keeps its own default dropout
        if attn_type=='multi':
            attn = MultiHeadedAttention(d_model, attn_heads)
        else:
            attn = SingleHeadedAttention(d_model)
        ff = PositionwiseFeedForward(d_model, drops)

        # positional encodings are precomputed for 2000 positions
        self.tgt_embed = nn.Sequential(Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000))
        self.encoder = Encoder(EncoderLayer(d_model, attn, ff, drops), N)
        
    # nothing to reset in this model; the method exists but does nothing
    def reset(self): pass

    def forward(self, x):
        "Encode the (bs, x_len) input `x`; the result is returned twice, each wrapped in a list."
        bs,x_len = x.size()
        mask = self.subsequent_mask(x_len)
        inp = self.encoder(self.tgt_embed(x), mask=mask)
        return ([inp],[inp]) #For the LinearDecoder
    
    def subsequent_mask(self, size):
        "Return the attention mask for a sequence of length `size`; currently always None (no masking)."
#         return torch.tril(torch.ones((size,size), device=device)).byte()[None]  # original...
        return None #MLM

        # only next output is masked
#         upper = torch.triu(torch.ones((size,size), device=device), diagonal=1).bool()
#         lower = torch.tril(torch.ones((size,size), device=device), diagonal=1).bool()
#         mask = upper & lower
#         return (~mask).byte()[None]

In [None]:
def init_tfmr_lm(m):
    "Xavier-uniform initialise every parameter of `m` that has more than one dimension."
    matrices = (p for p in m.parameters() if p.dim() > 1)
    for w in matrices:
        nn.init.xavier_uniform_(w)
            
#     classname = m.__class__.__name__
#     if classname.find('Linear') != -1:
#         if hasattr(m, 'weight') and m.weight is not None: nn.init.normal_(m.weight, 0., 0.02)
#         if hasattr(m, 'bias') and m.bias is not None:     nn.init.constant_(m.bias, 0.)
#     elif classname.find('LayerNorm') != -1:
#         if hasattr(m, 'weight') and m.weight is not None: nn.init.normal_(m.weight, 1., 0.02)
#         if hasattr(m, 'bias') and m.bias is not None:     nn.init.constant_(m.bias, 0.)

In [None]:
def make_learner(data, d_model=512, N=6, drops=0.2, attn_type='multi', attn_heads=8, **learn_kwargs):
    "Build a `LanguageLearner` around a `TransformerLM` encoder and a `LinearDecoder` sized from `data.vocab`."
    vocab_sz = len(data.vocab.itos)
    encoder = TransformerLM(vocab_sz, d_model, N=N, drops=drops, attn_type=attn_type, attn_heads=attn_heads)
    # the decoder is tied to the embedding table of the encoder
    embedding = encoder.tgt_embed[0].lut
    decoder = LinearDecoder(vocab_sz, d_model, drops, tie_encoder=embedding, bias=True)
    model = SequentialRNN(encoder, decoder)
    model.apply(init_tfmr_lm)
    return LanguageLearner(data, model, **learn_kwargs)

In [None]:
learn = make_learner(data, 512, 6, metrics=[accuracy, CER()])

In [None]:
# true number of trainable params
sum(p.numel() for p in learn.model.parameters() if p.requires_grad)

# Total trainable params: 12,659,808

# TransformerXL

In [None]:
def init_transformer(m):
    "Initialise module `m` according to its class name: Linear, LayerNorm or TransformerXL (checked in that order)."
    classname = m.__class__.__name__
    def _present(attr):
        # the attribute exists and is not None
        return hasattr(m, attr) and getattr(m, attr) is not None
    if 'Linear' in classname:
        if _present('weight'): nn.init.normal_(m.weight, 0., 0.02)
        if _present('bias'):   nn.init.constant_(m.bias, 0.)
    elif 'LayerNorm' in classname:
        if _present('weight'): nn.init.normal_(m.weight, 1., 0.02)
        if _present('bias'):   nn.init.constant_(m.bias, 0.)
    elif 'TransformerXL' in classname:
        if hasattr(m, 'u'): nn.init.normal_(m.u, 0., 0.02)
        if hasattr(m, 'v'): nn.init.normal_(m.v, 0., 0.02)

## fastai

In [None]:
config = dict(ctx_len=150, n_layers=6, n_heads=8, d_model=512, d_head=64, d_inner=2048, resid_p=0.1, attn_p=0.1,
              ff_p=0.1, embed_p=0.1, output_p=0.1, bias=False, scale=True, act=Activation.GeLU, double_drop=True,
              tie_weights=True, out_bias=True, init=init_transformer, mem_len=150, mask=True)

learn = language_model_learner(data, TransformerXL, config=config, drop_mult=1,
                               pretrained=False, metrics=[accuracy, CER()], callback_fns=[MLM_Mask])

In [None]:
# true number of trainable params
sum(p.numel() for p in learn.model.parameters() if p.requires_grad)

# Total trainable params: 30,378,720

## manual

In [None]:
# NOTE(review): this redefines the name `PositionalEncoding`, shadowing the `nn.Module`-based class
# defined earlier in this file for any cell run afterwards
class PositionalEncoding(Module):
    "Encode the position with a sinusoid."
    def __init__(self, d_model:int):
        # inverse frequencies 1 / 10000^(k/d_model) for k = 0, 2, 4, ... below d_model
        self.register_buffer('freq', 1 / (10000 ** (torch.arange(0., d_model, 2.) / d_model)))

    def forward(self, pos:Tensor):
        "Return the sines then the cosines of the outer product of `pos` and the frequencies."
        inp = torch.ger(pos, self.freq)
        enc = torch.cat([inp.sin(), inp.cos()], dim=-1)
        return enc

In [None]:
class FeedForward(Module):
    "Feed-forward block d_model -> 4*d_model -> d_model with a residual connection, followed by layer norm."
    def __init__(self, d_model:int, drops=0.1):
        d_inner = d_model*4
        expand  = [nn.Linear(d_model, d_inner), nn.ReLU(inplace=True), nn.Dropout(drops)]
        project = [nn.Linear(d_inner, d_model), nn.Dropout(drops)]
        self.core = nn.Sequential(*expand, *project)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x):
        return self.ln(x + self.core(x))

In [None]:
class MultiHeadRelativeAttention(Module):
    "Multi-head attention with relative positional terms, a residual connection and a final layer norm."
    def __init__(self, n_heads:int, d_model:int, drops=0.1, bias=False):
        d_head = d_model//n_heads
        self.n_heads, self.d_head = n_heads, d_head
        # a single linear produces queries, keys and values at once (3 * n_heads * d_head outputs)
        self.attention = nn.Linear(d_model, 3 * n_heads * d_head, bias=bias)
        self.out = nn.Linear(n_heads * d_head, d_model, bias=bias)
        self.drop_att,self.drop_res = nn.Dropout(drops),nn.Dropout(drops)
        self.ln = nn.LayerNorm(d_model)
        # projection applied to the relative position encodings `r`
        self.r_attn = nn.Linear(d_model, n_heads * d_head, bias=bias)

    def forward(self, x:Tensor, mask:Tensor=None, **kwargs):
        "Return layer_norm(x + dropout(out(attention(x)))); `kwargs` are forwarded to `_mhra`."
        return self.ln(x + self.drop_res(self.out(self._mhra(x, mask=mask, **kwargs))))

    def _mhra(self, x:Tensor, r:Tensor=None, u:Tensor=None, v:Tensor=None, mask:Tensor=None, mem:Tensor=None):
        "Compute the attention output for `x`, viewed as (bs, x_len, n_heads * d_head) before returning."
        #Notations from the paper:
        #x: input, r: vector of relative distance between two elements
        #u,v: learnable parameters of the model common between layers
        #mask: to avoid cheating, mem: previous hidden states
                
        #x: [bs, sl, d_model]
        #u/v: [n_heads, 1, d_head]
        #r: [sl, d_model]
        #mem: 1st:[0]; 2nd:[bs, sl, d_model]; nth: sl*i-1 up to mem_len
        bs,x_len,seq_len = x.size(0),x.size(1),r.size(0)
        context = x if mem is None else torch.cat([mem, x], dim=1)
        # keys and values are computed over memory + input; queries only for the last x_len positions
        wq,wk,wv = torch.chunk(self.attention(context), 3, dim=-1)
        wq = wq[:,-x_len:]
        wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))
        # [bs, sl, n_heads, d_head]
        wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3)   #wk: transposed(-2,-1)
        wkr = self.r_attn(r)
        wkr = wkr.view(seq_len, self.n_heads, self.d_head)
        wkr = wkr.permute(1,2,0)  #transposed ala wk w/out bs
        #### compute attention score (AC is (a) + (c) and BD is (b) + (d) in the paper)
        AC = torch.matmul(wq+u,wk)
        BD = _line_shift(torch.matmul(wq+v, wkr))
        attn_score = (AC + BD).div_(self.d_head ** 0.5)  #scale
        if mask is not None:
            # positions where `mask` is True are filled with -inf (the fill is done in float)
            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)
        attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))
        attn_vec = torch.matmul(attn_prob, wv)
        return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1)
    
def _line_shift(x:Tensor):
    "Shift the line i of `x` by p-i elements to the left, is `mask` puts 0s on the diagonal."
    bs,nh,n,p = x.size()
    x_pad = torch.cat([x.new_zeros(bs,nh,n,1), x], dim=3)
    x_shift = x_pad.view(bs,nh,p + 1,n)[:,:,1:].view_as(x)
    return x_shift

In [None]:
class DecoderLayer(Module):
    "One decoder layer: relative multi-head attention followed by the feed-forward block."
    def __init__(self, n_heads:int, d_model:int, drops=0.1):
        self.mhra = MultiHeadRelativeAttention(n_heads, d_model, drops=drops)
        self.ff   = FeedForward(d_model, drops=drops)

    def forward(self, x:Tensor, mask:Tensor=None, **kwargs):
        attended = self.mhra(x, mask=mask, **kwargs)
        return self.ff(attended)

In [None]:
class TransformerXL(Module):
    "TransformerXL model: https://arxiv.org/abs/1901.02860."
    def __init__(self, vocab_sz:int, d_model:int, n_layers:int, n_heads:int, drops:float=0.1, mem_len:int=150):
        d_head = d_model//n_heads
        self.encoder = nn.Embedding(vocab_sz, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        self.drop_emb = nn.Dropout(drops)
        # `u` and `v` are created without initial values here; `init_transformer` (defined earlier in
        # this file) fills them for classes whose name contains 'TransformerXL'
        self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.mem_len,self.n_layers,self.d_model = mem_len,n_layers,d_model
        self.init = False   # becomes True once the memory has been created in `forward`
        self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, drops) for k in range(n_layers)])

    def reset(self):
        "Reset the internal memory."
        # one empty tensor per layer plus one for the embedding output
        self.hidden = [next(self.parameters()).data.new(0) for i in range(self.n_layers+1)]

    def _update_mems(self, hids):
        "Append `hids` to the memory and keep only the last `mem_len` positions of each entry."
        if not getattr(self, 'hidden', False): return None
        assert len(hids) == len(self.hidden), 'len(hids) != len(self.hidden)'
        with torch.no_grad():
            for i in range(len(hids)):
                cat = torch.cat([self.hidden[i], hids[i]], dim=1)
                self.hidden[i] = cat[:,-self.mem_len:].detach()

    # keep only the memory entries at batch indices `idxs`
    def select_hidden(self, idxs): self.hidden = [h[idxs] for h in self.hidden]

    def forward(self, x):
        "Run the (bs, x_len) input through all layers; returns (memory or [output], [output])."
        #The hidden state has to be initialized in the forward pass for nn.DataParallel
        if self.mem_len > 0 and not self.init:
            self.reset()
            self.init = True
        bs,x_len = x.size()
        inp = self.drop_emb(self.encoder(x)) #.mul_(self.d_model ** 0.5)
        # memory length is 0 while the memory is still the empty 1-d placeholder created by `reset`
        m_len = self.hidden[0].size(1) if hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1 else 0
        seq_len = m_len + x_len
        
        mask = None  #MLM
#         upper = torch.tril(x.new_ones(x_len, seq_len), diagonal=1+m_len).bool()[None,None]
#         mask = torch.triu(x.new_ones(x_len, seq_len), diagonal=1+m_len).bool()[None,None]  # regular tfmrXL
#         mask = upper & mask
        
        hids = []
        # positions counted down from seq_len-1 to 0
        pos = torch.arange(seq_len-1, -1, -1, device=inp.device, dtype=inp.dtype) #[len, len-1, len-2, len-3,...]
        pos_enc = self.pos_enc(pos)
        hids.append(inp)
        for i, layer in enumerate(self.layers):
            mem = self.hidden[i] if self.mem_len > 0 else None
            inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem)
            hids.append(inp)
        core_out = inp[:,-x_len:]
        if self.mem_len > 0 : self._update_mems(hids)
        return (self.hidden if self.mem_len > 0 else [core_out]),[core_out]

In [None]:
vocab_sz = len(data.vocab.itos)
encoder = TransformerXL(vocab_sz, 512, 4, n_heads=8, drops=0.1)
decoder = LinearDecoder(vocab_sz, 512, 0.1, tie_encoder=encoder.encoder, bias=True)
model = SequentialRNN(encoder, decoder)
model.apply(init_transformer)
learn = LanguageLearner(data, model, split_func=tfmerXL_lm_split, metrics=[accuracy, CER()])

In [None]:
# true number of trainable params
sum(p.numel() for p in learn.model.parameters() if p.requires_grad)

# Total trainable params: 20,525,152

# BERT

In [None]:
class MLM_Mask(LearnerCallback):
    """Callback that turns each batch into a masked-language-modelling batch.

    mlm_probability: probability with which each position is selected for prediction.
    mask_tok: id written into the input at the positions that are replaced by the mask token.
    """
    def __init__(self, learn:Learner, mlm_probability=0.3, mask_tok=7):
        super().__init__(learn)
        self.mask_tok = mask_tok
        self.mlm_probability = mlm_probability
        self.itos = learn.data.vocab.itos
        
    def on_batch_begin(self, last_input, last_target, **kwargs):
        "Replace the batch input/target by the masked input and the labels built from it."
        new_input,new_target = self.mask_tokens(last_input)
        return {'last_input':new_input, 'last_target':new_target}
    
    def mask_tokens(self, inputs):
        """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
        labels = inputs.clone()
        # every random tensor is created on the device of `inputs` (not on the CPU or on a notebook-level
        # `device`), with an explicit float dtype for the probabilities
        dev = inputs.device
        def _draw(p):
            return torch.bernoulli(torch.full(labels.shape, p, dtype=torch.float, device=dev)).bool()

        # We sample a few tokens in each sequence for masked-LM training
        masked_indices = _draw(self.mlm_probability)
        labels[~masked_indices] = -1  # We only compute loss on masked tokens
        
        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = _draw(0.8) & masked_indices
        inputs[indices_replaced] = self.mask_tok   # note: `inputs` is modified in place

        # 10% of the time, we replace masked input tokens with random word
        # (half of the remaining 20%)
        indices_random = _draw(0.5) & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.itos), labels.shape, dtype=torch.long, device=dev)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

In [None]:
def bert_acc(input:Tensor, targs:Tensor, ignore_index=-1)->Rank0Tensor:
    "Accuracy of the arg-max prediction, computed only where `targs` differs from `ignore_index`."
    keep = targs.ne(ignore_index)
    predicted = torch.argmax(input, dim=-1)
    # Select the scored positions first, then compare prediction and target.
    correct = torch.eq(predicted[keep], targs[keep])
    return correct.float().mean()

In [None]:
# tfmr w/ MLM
# Learner built by `make_learner` (defined elsewhere in this notebook) with the
# MLM_Mask callback; the loss ignores label -1, which is the value MLM_Mask
# writes at positions that were not selected for prediction.
# NOTE(review): 512 and 6 are presumably the model width and layer count — confirm
# against make_learner's signature.
learn = make_learner(data, 512, 6, metrics=[bert_acc, CER()],
                     callback_fns=[MLM_Mask], loss_func=CrossEntropyFlat(ignore_index=-1))

In [None]:
# tfmrXL w/ MLM
# TransformerXL encoder + linear decoder, trained with the MLM_Mask callback.
vocab_sz = len(data.vocab.itos)
encoder = TransformerXL(vocab_sz, 512, 6, n_heads=8, drops=0.1, mem_len=30)
# The decoder receives the encoder's `encoder` submodule through `tie_encoder`
# (presumably the embedding, for weight tying — TODO confirm).
decoder = LinearDecoder(vocab_sz, 512, 0.1, tie_encoder=encoder.encoder, bias=True)
model = SequentialRNN(encoder, decoder)
model.apply(init_transformer)
# ignore_index=-1 matches the label MLM_Mask assigns to unselected positions,
# and bert_acc uses the same default, so loss and accuracy score the same tokens.
learn = LanguageLearner(data, model, split_func=tfmerXL_lm_split, metrics=[bert_acc, CER()],
                        callback_fns=[MLM_Mask], loss_func=CrossEntropyFlat(ignore_index=-1))

# Train

In [None]:
# Load the saved 'wiki2_bert_tfmr' weights into the current learner.
# The bare `None` makes the cell's last expression None, so the notebook
# does not echo whatever `load` returns.
learn.load('wiki2_bert_tfmr')#, strict=False)
None

In [None]:
# Run the learner's LR finder, then plot what the recorder captured so a
# learning rate can be chosen for the training cell.
learn.lr_find()
learn.recorder.plot()

In [None]:
# One-cycle training: fit_one_cycle(5, lr) with a SaveModelCallback that
# writes its checkpoint under the name 'wiki2_bert_tfmr8'.
lr=1e-3
learn.fit_one_cycle(5, lr, callbacks=[SaveModelCallback(learn, name='wiki2_bert_tfmr8')])
### Words: combo_60k ###
# wiki2

# MLM(.15),N=4,multi(8)
# 2.953938	2.823303	0.088051	0.907012	03:24   tfmr  'wiki2_lm_bert'
# 2.149526	2.129630	0.101498	0.869657	03:34   tfmrXL  'wiki2_lm_bertXL'
# 0.204440	0.175675	0.979842	0.041045	04:13   diagonal+1  'wiki2_eye_tfmrXL'
# 2.724019	2.668955	0.092031	0.874019	03:43   tfmrXL pretrained on wiki2_eye_tfmrXL

# 5cycle(1e-3); 30k; N:6
# 2.815044	2.807196	0.537950	0.632317	02:50   vanilla tfmr,  'wiki2_bert_tfmr'
# 2.558876	2.362679	0.095183	0.885061	02:50   preload wiki2_bert_tfmr w/ MLM  'wiki2_bert_tfmr2'
# 2.436609	2.309809	0.188842	0.819951	02:50   preload wiki2_bert_tfmr2 w/ MLM(.30)  'wiki2_bert_tfmr3'
# 3.051376	2.938654	0.280082	0.772065	02:51   MLM(.50)  'wiki2_bert_tfmr4'
# 2.513130	2.432869	0.185478	0.823640	02:51   preload wiki2_bert_tfmr4 w/ MLM(.30)  'wiki2_bert_tfmr5'

# fixed acc & CER
# 3.077021	2.917398	0.574049	0.593222	02:50   tfmr MLM(.15)
# 2.528980	2.358310	0.629164	0.541393	02:51   preload wiki2_bert_tfmr w/ MLM(.15)  'wiki2_bert_tfmr6'
# 2.575874	2.412832	0.614977	0.549122	02:54   preload wiki2_bert_tfmr w/ MLM(.30)  'wiki2_bert_tfmr7'
# 2.165149	2.158344	0.659073	0.503470	03:06   tfmrXL MLM(.15)   'wiki2_bert_tfmr8'

### Chars ###
# wiki2: 10cycle, 1e-3

# AWD-LSTM
# 1.144401	1.092210	0.671697	0.327961   (512/1400) 'wiki2_lm'

# Tfmr 
# 1.432683	1.367244	0.594961	0.405437   N=4,multi(8) 'wiki2_lm_tfmr'
# 1.371769	1.318286	0.607220	0.394283   N=6,multi(8) 'wiki2_lm_tfmr2'

# TfmrXL
# 1.174197	1.160675	0.652781	0.349375   N=6,multi(8)  'wiki2_lm_tfmrXL'
# 1.168891	1.129729	0.659086	0.343158   5cycle,1e-6   'wiki2_lm_tfmrXL_v2'
# 1.171422	1.126684	0.660261	0.342019   pretrained on above; N=10, 3cycle,1e-4   wiki2_lm_tfmrXL_v3
# 1.204486	1.190764	0.645694	0.356218   N=10,multi(8)   wiki2_lm_tfmrXL_v4

# 1cycle, 1e-3
# 1.466067	1.401443	0.587846	0.413188   fastai TransformerXL
# 1.462828	1.403346	0.587787	0.413295   manual TransformerXL

# wiki103: 3cycle, 1e-4, (stopped after 1st cycle)
# 1.216455	1.101982	0.666687	0.334717   'wiki103_lm'


In [None]:
# Pull the first batch from the validation dataloader for manual inspection.
x,y = next(iter(data.valid_dl))

In [None]:
# Display the item at index 2 of the batch: its input and its target.
x[2],y[2]

In [None]:
# Call the model directly on that single item; `[None]` adds a leading axis of size 1.
preds = learn.model(x[2][None])

In [None]:
# Arg-max over the last dimension of preds[1][0][0].
# NOTE(review): this indexes element [1] of the model output rather than [0];
# confirm that [1][0][0] really holds the vocabulary scores for this model and
# not a hidden-state tensor.
pred = torch.argmax(preds[1][0][0], dim=-1)

In [None]:
# Display the predicted indices.
pred

In [None]:
# Map each predicted index to its token through `itos` (defined elsewhere in the notebook).
[itos[word.item()] for word in pred]

In [None]:
def predict(self:Learner, text:str, n_words:int=1, no_unk:bool=True, temperature:float=1., min_p:float=None, sep:str=' ',
            decoder=decode_spec_tokens)->str:
    """Return `text` and the `n_words` that come after.

    Written as a free function taking the learner as `self`, so it is called
    as `predict(learn, ...)`.

    Args:
        self: learner providing `model`, `data` and `pred_batch`.
        text: prompt to continue.
        n_words: number of tokens to sample.
        no_unk: if True, the UNK token's score is zeroed before sampling.
        temperature: scores are raised to `1 / temperature` before sampling
            (skipped when it equals 1).
        min_p: if given, scores below it are zeroed; when no score reaches it
            a warning is issued and nothing is zeroed.
        sep: separator placed between `text` and the generated tokens and used
            to join the generated tokens.
        decoder: callable applied to the list of generated tokens before joining.

    Returns:
        `text` followed by the decoded generated tokens.
    """
    self.model.reset()
    xb,yb = self.data.one_item(text)

    # Drop the last position of the encoded prompt (per the original author,
    # the eos token that is added automatically).
    xb = xb[:,:-1]

    new_idx = []
    for _ in range(n_words):
        # Scores for the final position of the current input.
        res = self.pred_batch(batch=(xb,yb))[0][-1]
        # NOTE: the hidden-state selection step was disabled by the original author:
        #if len(new_idx) == 0: self.model[0].select_hidden([0])
        if no_unk: res[self.data.vocab.stoi[UNK]] = 0.
        if min_p is not None:
            if (res >= min_p).float().sum() == 0:
                warn(f"There is no item with probability >= {min_p}, try a lower value.")
            else: res[res < min_p] = 0.
        if temperature != 1.: res.pow_(1 / temperature)
        idx = torch.multinomial(res, 1).item()
        new_idx.append(idx)
        # Only the freshly sampled token is fed on the next step; earlier
        # context has to come from whatever state the model keeps internally.
        xb = xb.new_tensor([idx])[None]
    return text + sep + sep.join(decoder(self.data.vocab.textify(new_idx, sep=None)))

In [None]:
# Sample 3 more tokens after the prompt; sep='' joins the pieces with no separator.
predict(learn, "This is a wonderful", n_words=3, sep='')

In [None]:
# Save only the encoder part of the model under the name 'wiki2_lm_enc'.
learn.save_encoder('wiki2_lm_enc')

# Test

In [None]:
# Put the model in eval mode, then display its `training` flag to confirm.
# NOTE(review): this section refers to `learner`, while the cells above use
# `learn` — confirm which object is intended.
learner.model.eval()
learner.model.training

In [None]:
def next_with_creativity(preds, k=5, thresh=.05):
    """Sample one index among the `k` highest-scoring entries of `preds`.

    The scores are soft-maxed over the last dimension, the top `k` are kept,
    and one of them is drawn with probability proportional to its softmax
    value. The candidates and their rounded probabilities are printed.

    Args:
        preds: 1-D tensor of scores over the vocabulary (each top-k index is
            used directly to look up `itos`).
        k: number of top candidates to sample from.
        thresh: currently unused; kept so existing calls keep working.

    Returns:
        The chosen index as a NumPy integer scalar (supports `.item()`).
    """
    probs, idxs = torch.topk(F.softmax(preds, dim=-1), k, dim=-1)
    cand_ids, cand_probs = idxs.tolist(), probs.tolist()
    d = {itos[i]: round(p, 3) for i,p in zip(cand_ids, cand_probs)}
    print(d)

    # Exact weighted draw: no percentage bucketing, so it cannot end up with an
    # empty pool when every candidate is below 1%, and no repeated np.append.
    chosen = random.choices(cand_ids, weights=cand_probs)[0]
    # np.int64 instead of the np.long alias, which NumPy 1.24-1.26 do not provide.
    return np.int64(chosen)
    
#     return{k:v if v>=thresh else None for k,v in d}
#     mask = [probs >= thresh] 
#     m_probs, m_idxs = probs[mask], idxs[mask]
    
#     if len(m_idxs) > 0:
#         # simple weighted choice
#         seq = 
#         random.choice(seq)
#         idx = random.randint(0,len(m_idxs))
#         return m_idxs[idx]
#     else:
#         return idxs[0]

In [None]:
def get_next(inp):
    "Encode `inp` with `stoi`, run the model on it, and return the token sampled for the last position."
    token_ids = np.array([stoi[ch] for ch in inp])
    batch = T(token_ids).unsqueeze(0)   # leading axis of size 1
    out = learner.model(Variable(batch))
    # Alternatives tried by the original author: plain arg-max, or
    # torch.multinomial over out[0].exp(); top-k sampling is what is used.
    choice = next_with_creativity(out[0][-1])
    return itos[choice.item()]

In [None]:
# Sample the token that follows 'whe'.
get_next('whe')

In [None]:
def get_next_n(inp, n):
    "Extend `inp` by calling `get_next` `n` times, feeding the growing result back in each time."
    text = inp
    remaining = n
    while remaining > 0:
        text += get_next(text)
        remaining -= 1
    return text

In [None]:
# Grow the prompt 'th' by 10 sampled tokens.
get_next_n('th', 10)