# Prelims

In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
# from fastai import *
from fastai.text import *
# from fastai.callbacks.tracker import *
import pdb
import textwrap
import sentencepiece as spm

In [None]:
PATH = Path('data/IAM_handwriting')

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

# Setup Data

In [None]:
def remove_newline_spaces(x):
    "Drop one space on either side of each newline, then squeeze runs of newlines into a single newline."
    # Each replacement is a single pass, so at most one space is removed per side.
    for padded_newline in (' \n', '\n '):
        x = x.replace(padded_newline, '\n')
    return re.sub(r'(\n)+','\n',x)

def remove_equals(x):
    "Delete every ' =' and then every '= ' (the heading markers used by the raw WikiText dump)."
    for marker in (' =', '= '):
        x = x.replace(marker, '')
    return x

# Turns a spaced-out quotation such as " strings " into "strings".
def despace_quotes(m):
    "`re.sub` callback: return the matched text with the space after/before each double quote removed."
    quoted = m.group(0)   # the whole match, quotes included
    for spaced_quote in ('" ', ' "'):
        quoted = quoted.replace(spaced_quote, '"')
    return quoted

def cleanup(x):
    """Normalise one raw text blob.

    Steps, in order: `fix_html`, newline/space tidy-up, removal of '=' markers,
    re-attaching punctuation that was tokenised apart (e.g. ' ,' -> ','),
    de-spacing of double-quoted spans, and finally dropping every non-ASCII character.
    """
    x = remove_equals(remove_newline_spaces(fix_html(x)))

    # (spaced form, tight form) -- applied strictly in this order
    reattach = (
        (" \'", "\'"), (' ,', ','), (' .', '.'), (' :', ':'), (' ;', ';'),
        ('( ', '('), (' )', ')'), ('[ ', '['), (' ]', ']'),
        (' %', '%'), ('$ ','$'), ('# ','#'),
    )
    for spaced, tight in reattach:
        x = x.replace(spaced, tight)

    x = re.sub(r'\"(.+?)\"', despace_quotes, x)
    return re.sub(r'[^\x00-\x7F]+','', x)   # remove all non ascii characters

## WikiText

In [None]:
# wiki_path = Path('data/wikitext/wikitext-2-raw')
wiki_path = Path('data/wikitext/wikitext-103-raw')

In [None]:
# Read the three WikiText splits fully into memory.
# The encoding is given explicitly so the result does not depend on the platform's
# default locale encoding; this matches `write_text`, which writes utf-8.
with open(wiki_path/'wiki.train.raw', encoding='utf-8') as file:
    trn = file.read()
with open(wiki_path/'wiki.valid.raw', encoding='utf-8') as file:
    val = file.read()
with open(wiki_path/'wiki.test.raw', encoding='utf-8') as file:
    tst = file.read()

In [None]:
len(trn)
# 2:    10918892
# 103: 539566975

In [None]:
full = trn + val + tst
# len(trn)

In [None]:
full = cleanup(full)

In [None]:
full[0:5000]

In [None]:
lines = full.split('\n')
len(lines)

In [None]:
wiki_df = DataFrame({'text': lines[1:-1]})
wiki_df.head()

In [None]:
wiki_df.to_csv(PATH/'clean_wiki_103.csv', index=False)

## IMDB

In [None]:
imdb_path = untar_data(URLs.IMDB)
imdb_path

In [None]:
CSV = 'texts.csv'
imdb_df = pd.read_csv(imdb_path/CSV)
len(imdb_df)

In [None]:
imdb_df['text'] = imdb_df.text.apply(lambda x: cleanup(x))

In [None]:
len(imdb_df)

In [None]:
imdb_df.text.values[0]

In [None]:
imdb_df.to_csv(PATH/'clean_imdb.csv', index=False)

## Book texts

In [None]:
FPATH = Path('data/fonts/')

a = pd.read_csv(FPATH/'knot.csv', usecols=['filename', 'label'])
c = pd.read_csv(FPATH/'adrift.csv', usecols=['filename', 'label'])
d = pd.read_csv(FPATH/'zane.csv', usecols=['filename', 'label'])
e = pd.read_csv(FPATH/'american.csv', usecols=['filename', 'label'])
f = pd.read_csv(FPATH/'age.csv', usecols=['filename', 'label'])
g = pd.read_csv(FPATH/'room.csv', usecols=['filename', 'label'])
book_df = pd.concat([a,c,d,e,f,g], ignore_index=True)
len(book_df)

In [None]:
book_df['text'] = book_df.label.apply(lambda x: x.replace('\n',' '))

In [None]:
book_df.head()

In [None]:
book_df.to_csv(PATH/'clean_books.csv', columns=['text'], index=False)

## Paragraphs (w/out test)

In [None]:
pg_df = pd.read_csv(PATH/'edited_pg.csv')
pg_df.head()

In [None]:
pg_df['text'] = pg_df.text.apply(lambda x: x.replace('\n',' '))

In [None]:
pg_df.to_csv(PATH/'clean_pg.csv', columns=['text'], index=False)

## Combine and process (remove caps)

In [None]:
a = pd.read_csv(PATH/'clean_pg.csv')
b = pd.read_csv(PATH/'clean_books.csv')
c = pd.read_csv(PATH/'clean_imdb.csv')
d = pd.read_csv(PATH/'clean_wiki_103.csv')

In [None]:
c.text.values[300]

In [None]:
full = pd.concat([a, b, c, d], ignore_index=True)

In [None]:
full.dropna(inplace=True)
full.reset_index(inplace=True, drop=True)
len(full)

In [None]:
def add_cap_tokens(text):  # before encode
    "Lower-case every run of A-Z in `text`, prefixing it with '[MAJ]' (one letter) or '[UP]' (several letters)."
    return re.sub(r'[A-Z]+', _replace_caps, text)

def _replace_caps(m):
    "`re.sub` callback for `add_cap_tokens`: marker token followed by the lower-cased match."
    caps = m.group()
    marker = '[UP]' if len(caps) > 1 else '[MAJ]'
    return marker + caps.lower()

In [None]:
full['text'] = full.text.apply(lambda x: rm_useless_spaces(x))

In [None]:
full['text'] = full.text.apply(lambda x: add_cap_tokens(x))

In [None]:
full.tail()

## Write to raw

In [None]:
name = str(PATH/'spm_full')
# name = str(PATH/'spm_pg')
fname = name + '.txt'

In [None]:
# Dumps the dataset in the layout sentencepiece expects:
# a plain .txt file with one entry per line (each terminated by \n)
def write_text(texts, filename=fname):
    "Write every item of `texts` to `filename` (utf-8), one per line. The default path is `fname` as it was when this cell ran."
    with open(filename, 'w', encoding='utf-8') as out:
        out.writelines(text + "\n" for text in texts)

In [None]:
write_text(full.text.values)
# write_text(pg_df.text.values)

# Train SPM

In [None]:
name = str(PATH/'spm_full')
fname = name + '.txt'

In [None]:
symbols = "\n,[UP],[MAJ],▁,:,;,!,?,(,),[,],{,},<,>,@,#,$,%,^,&,*,-,_,+,=,/,~"

In [None]:
spm.SentencePieceTrainer.Train(
    f"--unk_id=3 --pad_id=0 --input={fname} --model_prefix={name+'_30k'} --vocab_size=30000 --user_defined_symbols={symbols} --input_sentence_size=1500000 --shuffle_input_sentence=True"
)

In [None]:
sp = spm.SentencePieceProcessor()
sp.Load(name+'_30k.model')

In [None]:
df = pd.read_csv(fname, sep='\n', header=None, names=['text'])
st = df.text.values[-2]; st

In [None]:
pieces = sp.encode_as_pieces(st)

In [None]:
st2 = sp.decode_pieces(pieces[1:])

In [None]:
st == st2

In [None]:
pieces

In [None]:
sp.EncodeAsPieces('adaptability')

In [None]:
for n in range(5):
    print(sp.SampleEncodeAsPieces('adaptability', -1, .1))

In [None]:
vocab = {i: sp.id_to_piece(i) for i in range(len(sp))}
vocab

# DataBunch

In [None]:
import sentencepiece as spm

# Module-level sentencepiece processor shared by the tokenizer, the text list and `make_learner`.
sp = spm.SentencePieceProcessor()
# NOTE(review): this loads 'spm_full_10k.model', while the training cell above writes a
# model with the '_30k' prefix -- confirm which vocabulary is intended here.
sp.Load(str(PATH/'spm_full_10k.model'))
# Extra encode/decode options are passed to sentencepiece as option strings.
sp.SetEncodeExtraOptions("eos")
sp.SetDecodeExtraOptions("bos:eos")

In [None]:
# itos = {i:sp.id_to_piece(i) for i in range(len(sp))}

In [None]:
def add_cap_tokens(text):  # before encode
    "Lower-case every run of A-Z in `text`, prefixing it with '[MAJ]' (one letter) or '[UP]' (several letters)."
    return re.sub(r'[A-Z]+', _replace_caps, text)

def _replace_caps(m):
    "`re.sub` callback for `add_cap_tokens`: marker token followed by the lower-cased match."
    caps = m.group()
    marker = '[UP]' if len(caps) > 1 else '[MAJ]'
    return marker + caps.lower()

def remove_cap_tokens(text):  # after decode
    "Undo the capitalisation markers: '[UP]word' -> 'WORD', then '[MAJ]w' -> 'W' (first character only)."
    def _upper_after(marker_len):
        # strip the marker (its length in characters) and upper-case whatever was matched after it
        return lambda m: m.group()[marker_len:].upper()

    text = re.sub(r'\[UP\]\w+', _upper_after(4), text)   # cap entire word
    text = re.sub(r'\[MAJ\]\w?', _upper_after(5), text)  # cap first letter
    return text

In [None]:
class SPMTokenizer(BaseTokenizer):
    "Tokenizer that encodes text to ids with the module-level sentencepiece processor `sp`."
    def tokenizer(self, t:str) -> List[int]:
        ids = sp.EncodeAsIds(t)
        # the first id is dropped (the original comment says it is the initial space) and replaced by id 1
        return [1] + ids[1:]

class SPMProcessor(PreProcessor):
    """Processor that turns the raw texts of an item list into sentencepiece id lists.

    `ds` is forwarded to the base-class initialiser; `chunksize` is the number of
    items handed to the tokenizer per call.
    """
    def __init__(self, ds:ItemList=None, chunksize:int=10000):
        # Run the base initialiser so `ds` is not silently dropped and whatever state
        # `PreProcessor` sets up exists on this instance too.
        super().__init__(ds)
        # pre_rules run on the raw text before `SPMTokenizer` encodes it
        self.toknizr = Tokenizer(tok_func=SPMTokenizer, pre_rules=[rm_useless_spaces, add_cap_tokens],
                                 post_rules=[], special_cases=[])
        self.chunksize = chunksize

    def process(self, ds):
        "Tokenize `ds.items` chunk by chunk and replace `ds.items` with the resulting list."
        tokens = []
        for i in progress_bar(range(0,len(ds),self.chunksize), leave=False):
            tokens += self.toknizr.process_all(ds.items[i:i+self.chunksize])
        ds.items = tokens

In [None]:
class SPMTextList(TextList):
    "`TextList` whose vocabulary is the module-level sentencepiece processor `sp` (pad id 0)."
    _bunch = TextLMDataBunch   # databunch class built by `.databunch()`
    _processor = []
    _is_lm = True

    def __init__(self, items:Iterator, **kwargs):
        super().__init__(items, **kwargs)
        self.vocab, self.pad_idx = sp, 0
        self.copy_new += ['vocab']

    def get(self, i):
        # NOTE(review): `i` itself is wrapped and decoded here rather than `self.items[i]`;
        # confirm that is intended if this is ever called with an integer index.
        return Text(i, self.textify(i))

    def reconstruct(self, x:Tensor):
        "Build a `Text` from `x`, cut after the largest coordinate of any non-pad entry."
        non_pad = (x != self.pad_idx).nonzero()
        last = non_pad.max() if len(non_pad) > 0 else 0
        trimmed = x[0:last+1]
        return Text(trimmed, self.textify(trimmed))

    def analyze_pred(self, pred:Tensor):
        "Pick the highest-scoring id along the last dimension."
        return torch.argmax(pred, dim=-1)

    def textify(self, ids):
        "Decode ids (tensor or list) with the vocab and restore capitalisation markers."
        id_list = ids.tolist() if isinstance(ids, torch.Tensor) else ids
        return remove_cap_tokens(self.vocab.DecodeIds(id_list))

In [None]:
class LMLabelList(EmptyLabelList):
    "Dummy-label `ItemList` that carries a flat cross-entropy loss and defers reconstruction to the inputs."
    def __init__(self, items:Iterator, **kwargs):
        super().__init__(items, **kwargs)
        self.loss_func = CrossEntropyFlat()

    def reconstruct(self, t:Tensor, x:Tensor=None):
        "A zero-dimensional `t` gives an `EmptyLabel`; anything else is rebuilt by `self.x.reconstruct`."
        is_scalar = len(t.size()) == 0
        if is_scalar:
            return EmptyLabel()
        return self.x.reconstruct(t)

In [None]:
CSV = 'clean_pg.csv'

In [None]:
data = (SPMTextList.from_csv(PATH, CSV, cols=0, processor=SPMProcessor())
        .split_by_rand_pct(valid_pct=0.10, seed=42)
        .label_const(0, label_cls=LMLabelList)
        .databunch(bs=64, device=device)
       )

In [None]:
data.show_batch(rows=5, ds_type=DatasetType.Train)

# Metrics

In [None]:
import Levenshtein as Lev

class CER(Callback):
    """Metric callback that averages the per-batch result of `cer` over an epoch.

    `fn` converts a tensor of ids into a string; positions whose target equals
    `ignore_index` are left out of the comparison.
    """
    def __init__(self, fn, ignore_index=-1):
        super().__init__()
        self.fn, self.ignore_index = fn, ignore_index
        self.name = 'cer'

    def on_epoch_begin(self, **kwargs):
        # fresh accumulators for every epoch
        self.errors = 0
        self.total = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        keep = (last_target!=self.ignore_index)
        batch_error, batch_size = cer(last_output[keep], last_target[keep], self.fn)
        self.errors += batch_error
        self.total += batch_size

    def on_epoch_end(self, last_metrics, **kwargs):
        epoch_cer = self.errors/self.total
        return add_metrics(last_metrics, epoch_cer)

def cer(preds, targs, fn):
    """Character error rate of one batch.

    `preds` holds scores whose last dimension is arg-maxed into ids; `targs` holds
    the target ids; `fn` turns ids into a string. Returns `(error_rate, 1)`, where
    error_rate is the Levenshtein distance between both strings divided by the
    length of the target string.
    """
    res = torch.argmax(preds, dim=-1)
    p = fn(res)   #.replace(' ', '')
    t = fn(targs) #.replace(' ', '')
    # An empty target string (e.g. nothing left after masking) used to raise
    # ZeroDivisionError; dividing by at least 1 makes the rate equal the raw edit distance then.
    return Lev.distance(t, p)/max(len(t), 1), 1

In [None]:
def bert_acc(input:Tensor, targs:Tensor, ignore_index=-1)->Rank0Tensor:
    "Accuracy of the arg-max predictions, measured only where `targs` differs from `ignore_index`."
    keep = (targs!=ignore_index)
    hits = input.argmax(dim=-1)[keep] == targs[keep]
    return hits.float().mean()

In [None]:
class MLM_Mask(LearnerCallback):
    """Callback that turns every batch into a masked-language-modelling batch.

    `mlm_probability` is the per-token chance of being selected for prediction,
    `mask_tok` the id written over selected tokens, and the vocabulary size
    (for random replacements) is taken from `learn.data.vocab`.
    """
    def __init__(self, learn:Learner, mlm_probability=0.3, mask_tok=7):
        super().__init__(learn)
        self.mask_tok = mask_tok
        self.mlm_probability = mlm_probability
        self.vocab_sz = len(learn.data.vocab)

    def on_batch_begin(self, last_input, last_target, **kwargs):
        # the incoming target is discarded; labels are derived from the input itself
        new_input,new_target = self.mask_tokens(last_input)
        return {'last_input':new_input, 'last_target':new_target}

    def mask_tokens(self, inputs):
        """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        `inputs` is modified in place and returned together with `labels`, which is -1
        everywhere except at the selected positions (where it holds the original ids).
        """
        labels = inputs.clone()
        # All random tensors are created on the batch's own device instead of the CPU /
        # the notebook-global `device`, so the callback also works when the two differ.
        dev = inputs.device

        def _bernoulli(p):
            return torch.bernoulli(torch.full(labels.shape, p, device=dev)).bool()

        # We sample a few tokens in each sequence for masked-LM training
        masked_indices = _bernoulli(self.mlm_probability)
        labels[~masked_indices] = -1  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = _bernoulli(0.8) & masked_indices
        inputs[indices_replaced] = self.mask_tok

        # 10% of the time, we replace masked input tokens with random word
        # (half of the 20% of selected tokens that were not replaced above)
        indices_random = _bernoulli(0.5) & masked_indices & ~indices_replaced
        random_words = torch.randint(self.vocab_sz, labels.shape, dtype=torch.long, device=dev)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

# BeRT

## Transformer Modules

In [None]:
LayerNorm = partial(nn.LayerNorm, eps=1e-4)  # accommodates mixed precision training

In [None]:
class SublayerConnection(nn.Module):
    "Residual wrapper around a sublayer: the input is normalised first, the sublayer output gets dropout, then the input is added back."
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch

In [None]:
def clones(module, N):
    "Return an `nn.ModuleList` holding N independent deep copies of `module`."
    copies = nn.ModuleList()
    for _ in range(N):
        copies.append(deepcopy(module))
    return copies

In [None]:
class Encoder(nn.Module):
    "Stack of N copies of `layer`, followed by a final layer norm sized by `layer.size`."
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask=None):
        out = x
        for block in self.layers:
            out = block(out, mask)   # every block receives the same mask
        return self.norm(out)

In [None]:
class EncoderLayer(nn.Module):
    "One encoder block: self-attention then feed-forward, each wrapped in its own `SublayerConnection`."
    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)

    def forward(self, x, mask=None):
        def attend(t):
            # query, key and value are all the same tensor
            return self.self_attn(t, t, t, mask)

        attn_wrap, ff_wrap = self.sublayer[0], self.sublayer[1]
        x = attn_wrap(x, attend)
        return ff_wrap(x, self.feed_forward)

In [None]:
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Scores are `query @ key^T / sqrt(last dim of query)`; positions where `mask == 0`
    are set to -1e4 before the softmax. Returns `(weights @ value, weights)`, with
    `dropout` (if given) applied to the weights.
    """
    scale = math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e4)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    attended = torch.matmul(weights, value)
    return attended, weights

In [None]:
class MultiHeadedAttention(nn.Module):
    "Multi-head attention with `h` heads of width `d_model // h`; the last attention weights are kept in `self.attn`."
    def __init__(self, d_model, h=8, dropout=0.2):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h        # assume d_v always equals d_k
        self.h = h
        # three input projections (q, k, v) plus the output projection
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        if mask is not None: mask = mask.unsqueeze(1)
        bs = q.size(0)

        def split_heads(proj, t):
            # project, then reshape d_model => h x d_k and move the head axis to position 1
            return proj(t).view(bs, -1, self.h, self.d_k).transpose(1,2)

        # 1) Project q, k and v with the first three linears (zip stops before the fourth).
        q, k, v = [split_heads(proj, t) for proj, t in zip(self.linears, (q, k, v))]

        # 2) Attend over all heads at once.
        ctx, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)

        # 3) Merge the heads back with a view and apply the output projection.
        merged = ctx.transpose(1, 2).contiguous().view(bs, -1, self.h * self.d_k)
        return self.linears[-1](merged)

In [None]:
class PositionwiseFeedForward(nn.Module):
    "Two-layer feed-forward block: d_model -> 4*d_model (GELU, dropout) -> d_model."
    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        d_hidden = d_model*4
        self.w_1 = nn.Linear(d_model, d_hidden)
        self.w_2 = nn.Linear(d_hidden, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.gelu(self.w_1(x))
        return self.w_2(self.dropout(hidden))

In [None]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to the input, followed by dropout.

    The table `pe` has shape (1, max_len, d_model): even columns hold sines, odd
    columns cosines. `forward` indexes the input's dimension 1 as the sequence
    axis, which must not be longer than `max_len`.
    """
    def __init__(self, d_model, dropout=0.1, max_len=2000):
        super(PositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0.0, max_len).unsqueeze(1)
        log_increment = math.log(1e4) / d_model
        div_term = torch.exp(torch.arange(0.0, d_model, 2) * -log_increment)
        angles = position * div_term            # (max_len, ceil(d_model/2))
        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one cosine column fewer than sine columns;
        # slicing keeps the assignment valid (it is a no-op when d_model is even).
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe.unsqueeze_(0)

        self.register_buffer('pe', pe)    #(1,max_len,d_model)
        # registered as a buffer: not a parameter, but still stored in the state_dict

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [None]:
class Embeddings(nn.Module):
    "Token embedding lookup (`vocab` rows of width `d_model`) whose output is scaled by sqrt(d_model)."
    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale

## Architecture

In [None]:
class TransformerLM(Module):
    "Encoder-only transformer: embedding -> positional encoding -> N encoder layers -> linear projection to `vocab` scores."
    def __init__(self, vocab, d_model=512, N=4, drops=0.2, attn_heads=8):
        # No explicit super().__init__() call, as in the original; presumably fastai's
        # `Module` base takes care of it -- TODO confirm.
        # NOTE(review): `drops` is not passed to the attention module, which therefore
        # keeps its own default dropout -- confirm this is intended.
        attn = MultiHeadedAttention(d_model, attn_heads)
        ff = PositionwiseFeedForward(d_model, drops)
        self.embedding = Embeddings(d_model, vocab)
        self.pos_enc = PositionalEncoding(d_model, drops)
        layer = EncoderLayer(d_model, attn, ff, drops)
        self.encoder = Encoder(layer, N)

        self.generator = nn.Linear(d_model, vocab)

    def forward(self, x):
        hidden = self.encode(x)
        return self.decode(hidden)

    def encode(self, x):
        "Token ids -> encoder output."
        embedded = self.embedding(x)
        return self.encoder(self.pos_enc(embedded))

    def decode(self, x):
        "Encoder output -> vocabulary scores."
        return self.generator(x)

In [None]:
def init_tfmr_lm(m):
    """Initialiser meant for `model.apply`.

    Modules whose class name contains 'Linear' get weights ~ N(0, 0.02); those whose
    name contains 'LayerNorm' get weights ~ N(1, 0.02). In both cases the bias is
    zeroed. Every other module is left untouched.
    """
    classname = type(m).__name__
    if 'Linear' in classname:
        weight_mean = 0.
    elif 'LayerNorm' in classname:
        weight_mean = 1.
    else:
        return

    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if weight is not None: nn.init.normal_(weight, weight_mean, 0.02)
    if bias is not None:   nn.init.constant_(bias, 0.)

In [None]:
def make_learner(data, d_model=512, N=4, drops=0.1, attn_heads=8, tie_weights=True, **learn_kwargs):
    """Build a `TransformerLM` sized by `len(sp)` and wrap it in a `Learner` over `data`.

    The model is initialised with `init_tfmr_lm`; with `tie_weights` the output
    projection then shares the embedding's weight tensor. Remaining keyword
    arguments are passed to `Learner`.
    """
    model = TransformerLM(len(sp), d_model, N=N, drops=drops, attn_heads=attn_heads)
    model.apply(init_tfmr_lm)
    if tie_weights:
        # tying happens after initialisation, so both layers end up with the embedding's weights
        model.generator.weight = model.embedding.lut.weight
    return Learner(data, model, **learn_kwargs)

In [None]:
learn = make_learner(data, 512, 4, metrics=[bert_acc, CER(data.x.textify)],
                     callback_fns=[MLM_Mask], loss_func=CrossEntropyFlat(ignore_index=-1))

In [None]:
# true number of trainable params
sum(p.numel() for p in learn.model.parameters() if p.requires_grad)

# Total trainable params: 22,860,560

# Train

In [None]:
learn.load('wiki2_bert_tfmr')#, strict=False)
None

In [None]:
learn.lr_find()
learn.recorder.plot()

In [None]:
lr=1e-2
learn.fit_one_cycle(5, lr)#, callbacks=[SaveModelCallback(learn, name='pg_bert_tfmr')])
#clean_pg.csv

# 6.215954	4.791131	0.426439	0.768457	00:05

In [None]:
lr=1e-2
learn.fit_one_cycle(5, lr)#, callbacks=[SaveModelCallback(learn, name='pg_bert_tfmr')])
### Words: combo_60k ###
# wiki2

# MLM(.15),N=4,multi(8)
# 2.953938	2.823303	0.088051	0.907012	03:24   tfmr  'wiki2_lm_bert'
# 2.149526	2.129630	0.101498	0.869657	03:34   tfmrXL  'wiki2_lm_bertXL'
# 0.204440	0.175675	0.979842	0.041045	04:13   diagonal+1  'wiki2_eye_tfmrXL'
# 2.724019	2.668955	0.092031	0.874019	03:43   tfmrXL pretrained on wiki2_eye_tfmrXL

# 5cycle(1e-3); 30k; N:6
# 2.815044	2.807196	0.537950	0.632317	02:50   vanilla tfmr,  'wiki2_bert_tfmr'
# 2.558876	2.362679	0.095183	0.885061	02:50   preload wiki2_bert_tfmr w/ MLM  'wiki2_bert_tfmr2'
# 2.436609	2.309809	0.188842	0.819951	02:50   preload wiki2_bert_tfmr2 w/ MLM(.30)  'wiki2_bert_tfmr3'
# 3.051376	2.938654	0.280082	0.772065	02:51   MLM(.50)  'wiki2_bert_tfmr4'
# 2.513130	2.432869	0.185478	0.823640	02:51   preload wiki2_bert_tfmr4 w/ MLM(.30)  'wiki2_bert_tfmr5'

# fixed acc & CER
# 3.077021	2.917398	0.574049	0.593222	02:50   tfmr MLM(.15)
# 2.528980	2.358310	0.629164	0.541393	02:51   preload wiki2_bert_tfmr w/ MLM(.15)  'wiki2_bert_tfmr6'
# 2.575874	2.412832	0.614977	0.549122	02:54   preload wiki2_bert_tfmr w/ MLM(.30)  'wiki2_bert_tfmr7'
# 2.165149	2.158344	0.659073	0.503470	03:06   tfmrXL MLM(.15)   'wiki2_bert_tfmr8'

### Chars ###
# wiki2: 10cycle, 1e-3

# AWD-LSTM
# 1.144401	1.092210	0.671697	0.327961   (512/1400) 'wiki2_lm'

# Tfmr 
# 1.432683	1.367244	0.594961	0.405437   N=4,multi(8) 'wiki2_lm_tfmr'
# 1.371769	1.318286	0.607220	0.394283   N=6,multi(8) 'wiki2_lm_tfmr2'

# TfmrXL
# 1.174197	1.160675	0.652781	0.349375   N=6,multi(8)  'wiki2_lm_tfmrXL'
# 1.168891	1.129729	0.659086	0.343158   5cycle,1e-6   'wiki2_lm_tfmrXL_v2'
# 1.171422	1.126684	0.660261	0.342019   pretrained on above; N=10, 3cycle,1e-4   wiki2_lm_tfmrXL_v3
# 1.204486	1.190764	0.645694	0.356218   N=10,multi(8)   wiki2_lm_tfmrXL_v4

# 1cycle, 1e-3
# 1.466067	1.401443	0.587846	0.413188   fastai TransformerXL
# 1.462828	1.403346	0.587787	0.413295   manual TransformerXL

# wiki103: 3cycle, 1e-4, (stopped after 1st cycle)
# 1.216455	1.101982	0.666687	0.334717   'wiki103_lm'


# Test

In [None]:
x,y = next(iter(data.valid_dl))

In [None]:
x[2],y[2]

In [None]:
preds = learn.model(x[2][None])

In [None]:
pred = torch.argmax(preds[1][0][0], dim=-1)

In [None]:
pred

In [None]:
[itos[word.item()] for word in pred]

In [None]:
def predict(self:learn, text:str, n_words:int=1, no_unk:bool=True, temperature:float=1., min_p:float=None, sep:str=' ',
            decoder=decode_spec_tokens):
    """Return `text` and the `n_words` that come after.

    Samples `n_words` ids one at a time from the model's output for the last
    position and joins their decoded form to `text` with `sep`.

    NOTE(review): the `self` annotation is the `learn` instance rather than a type,
    and `decode_spec_tokens` is not defined in this file -- confirm both resolve
    (the default argument is evaluated when this cell runs).
    """
    # NOTE(review): `TransformerLM` as defined in this file has no `reset` method -- confirm.
    self.model.reset()
    xb,yb = self.data.one_item(text)
    
    # remove the eos token which is automatically added
    xb = xb[:,:-1]
    print(xb)   # debugging output

    
    new_idx = []
    for _ in range(n_words): #progress_bar(range(n_words), leave=False):
        # scores for the last position of the first batch item
        res = self.pred_batch(batch=(xb,yb))[0][-1]
        #if len(new_idx) == 0: self.model[0].select_hidden([0])
        # NOTE(review): assumes `self.data.vocab` exposes `stoi` and that `UNK` is in scope -- confirm.
        if no_unk: res[self.data.vocab.stoi[UNK]] = 0.
        if min_p is not None:
            if (res >= min_p).float().sum() == 0:
                warn(f"There is no item with probability >= {min_p}, try a lower value.")
            else: res[res < min_p] = 0.   # zero out everything below the threshold
        if temperature != 1.: res.pow_(1 / temperature)   # in-place temperature scaling
        idx = torch.multinomial(res, 1).item()
        new_idx.append(idx)
        # the next step is fed only the freshly sampled id
        xb = xb.new_tensor([idx])[None]
    return text + sep + sep.join(decoder(self.data.vocab.textify(new_idx, sep=None)))

In [None]:
predict(learn, "This is a wonderful", n_words=3, sep='')

In [None]:
learn.save_encoder('wiki2_lm_enc')

In [None]:
learner.model.eval()
learner.model.training

In [None]:
def next_with_creativity(preds, k=5, thresh=.05):
    """Sample the next index from the top-`k` softmax candidates of `preds`.

    Each candidate is entered into a pool `int(prob * 100)` times and one pool
    entry is drawn uniformly, so the draw is weighted by probability in whole
    percent. If every candidate is below 1% the pool is empty and the most likely
    candidate is returned. The top-k table is printed as a side effect.
    `thresh` is accepted for compatibility but not used.

    Returns a NumPy int64 scalar.
    """
    probs, idxs = torch.topk(F.softmax(preds, dim=-1), k, dim=-1)
    d = {itos[i]: round(p.item(), 3) for i,p in zip(idxs,probs)}
    print(d)

    ids = np.array([i.item() for i in idxs], dtype=np.int64)
    counts = [int(p * 100) for p in probs]
    # np.repeat keeps the integer dtype. The previous np.append-based loop turned the
    # pool into float64 whenever a candidate contributed zero entries, and np.long
    # is not available in every NumPy release.
    seq = np.repeat(ids, counts)

    if seq.size == 0:
        # nothing reached 1%: fall back to the top candidate instead of failing in random.choice
        return ids[0]
    return random.choice(seq)
    
#     return{k:v if v>=thresh else None for k,v in d}
#     mask = [probs >= thresh] 
#     m_probs, m_idxs = probs[mask], idxs[mask]
    
#     if len(m_idxs) > 0:
#         # simple weighted choice
#         seq = 
#         random.choice(seq)
#         idx = random.randint(0,len(m_idxs))
#         return m_idxs[idx]
#     else:
#         return idxs[0]

In [None]:
def get_next(inp):
    """Return the item predicted to follow `inp`.

    Every element of `inp` is mapped through `stoi`, the model is run on the
    resulting batch of one, and the next index is sampled from the scores of the
    last position with `next_with_creativity`.

    NOTE(review): `T`, `Variable`, `stoi`, `itos` and `learner` are not defined in
    this file (the learner built above is named `learn`) -- confirm where they come from.
    """
    idxs = T(np.array([stoi[c] for c in inp])).unsqueeze(0)
    p = learner.model(Variable(idxs))
#     i = torch.argmax(p[0][-1], dim=-1)
#     i = torch.multinomial(p[0].exp(), 1)[-1]
    i = next_with_creativity(p[0][-1])
    return itos[i.item()]

In [None]:
get_next('whe')

In [None]:
def get_next_n(inp, n):
    "Extend `inp` by `n` predictions, feeding the growing sequence back into `get_next` each time."
    out = inp
    for _ in range(n):
        out += get_next(out)
    return out

In [None]:
get_next_n('th', 10)