# Prelims

In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
from fastai import *
from fastai.vision import *
from fastai.text import *
from fastai.callbacks.tracker import *
from fastai.callbacks.hooks import *

import pdb

In [None]:
from transformers import DistilBertTokenizer, DistilBertForMaskedLM

In [None]:
# Root directory of the dataset; label CSVs and image folders below are resolved relative to it.
PATH = Path('data/IAM_handwriting')

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device  # shown as the cell output

## Helpers

In [None]:
def show_img(im, figsize=None, ax=None, alpha=None, title=None):
    """Draw `im` on a matplotlib axis and return that axis.

    `im.data` is converted with `image2np` before plotting. A new figure of
    `figsize` is created only when no `ax` is passed; `title` is set when truthy.
    """
    # Test for "no axis supplied" explicitly rather than relying on truthiness.
    if ax is None: _, ax = plt.subplots(figsize=figsize)
    ax.imshow(image2np(im.data), alpha=alpha)
    if title: ax.set_title(title)
    return ax

In [None]:
def rshift(tgt, bos_token=1):
    """Shift `tgt` one step to the right by prepending `bos_token`.

    `tgt` is indexed as (batch, seq); the last column is dropped so the result
    keeps the same shape.
    """
    # Build the BOS column with the dtype/device of `tgt` itself, so the function
    # does not depend on the module-level `device`.
    bos = torch.full((tgt.size(0), 1), bos_token, dtype=tgt.dtype, device=tgt.device)
    return torch.cat((bos, tgt[:, :-1]), dim=-1)

def subsequent_mask(size):
    "Lower-triangular byte mask of shape (1, size, size), created on the module-level `device`."
    all_ones = torch.ones((1, size, size), device=device)
    return torch.tril(all_ones.byte())
    
def parallelogram_mask(size, diagonal):
    """Banded lower-triangular byte mask of shape (1, size, size).

    Entry (i, j) is 1 when `i - diagonal <= j <= i`, i.e. the main diagonal and
    the `diagonal` diagonals below it; everything else is 0.
    """
    ones = torch.ones((1, size, size), device=device).byte()
    # Keep what is on/above the `-diagonal` diagonal, then drop what is above the main one.
    band = torch.triu(ones, diagonal=-diagonal)
    return torch.tril(band)

## Loss, Metrics, Callbacks

In [None]:
class LabelSmoothing(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    Every class receives `smoothing / vocab` probability mass and the target
    class is set to `1 - smoothing`. The summed KL divergence is divided by the
    batch size of the current batch. Padded positions are not masked out.
    """
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, pred, target):
        "`pred`: (bs, sl, vocab) raw scores; `target`: (bs, tsl) class ids. Returns a scalar loss."
        # Normalise by the size of *this* batch instead of a module-level `bs`,
        # which is wrong for a smaller final batch and undefined outside the notebook.
        bs = target.size(0)
        pred, targ = self.loss_prep(pred, target)
        pred = F.log_softmax(pred, dim=-1)  # kl_div expects log-probabilities as input
        # Smoothed target distribution; created without gradient history.
        true_dist = torch.full_like(pred, self.smoothing / pred.size(1))
        true_dist.scatter_(1, targ.unsqueeze(1), self.confidence)
        return F.kl_div(pred, true_dist, reduction='sum') / bs

    def loss_prep(self, pred, target):
        "Equalize input/target sequence lengths, then merge the batch and sequence dimensions."
        bs, tsl = target.shape
        _, sl, vocab = pred.shape

        # F.pad takes (front, back) pairs starting from the last dimension.
        if sl > tsl: target = F.pad(target, (0, sl - tsl))       # pad target ids with 0
        if tsl > sl: pred = F.pad(pred, (0, 0, 0, tsl - sl))     # pad the sequence dim of pred with 0

        targ = target.contiguous().view(-1).long()
        pred = pred.contiguous().view(-1, vocab)
        return pred, targ

In [None]:
import Levenshtein as Lev

class CER(Callback):
    "Metric callback: accumulates `cer` results per batch and reports errors/total at epoch end."
    def __init__(self, fn):
        super().__init__()
        self.name, self.fn = 'cer', fn   # `fn` turns a row of token ids into something `str()`-able

    def on_epoch_begin(self, **kwargs):
        "Reset the running sums."
        self.errors = self.total = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Add this batch's summed error and sample count."
        batch_error, batch_size = cer(last_output, last_target, self.fn)
        self.errors = self.errors + batch_error
        self.total = self.total + batch_size

    def on_epoch_end(self, last_metrics, **kwargs):
        "Append the epoch average to the metrics."
        epoch_cer = self.errors / self.total
        return add_metrics(last_metrics, epoch_cer)

In [None]:
def cer(preds, targs, fn):
    """Return (summed normalised edit distance, number of samples) for one batch.

    `preds` is arg-maxed over its last dimension; `fn` converts a row of ids into
    an object whose `str()` is the text. Each distance is divided by the target
    length (or 1 when the target text is empty).
    """
    n_samples = targs.size(0)
    pred_ids = torch.argmax(preds, dim=-1)

    def _sample_error(i):
        hyp = str(fn(pred_ids[i]))
        ref = str(fn(targs[i]))
        return Lev.distance(ref, hyp) / (len(ref) or 1)

    return sum(_sample_error(i) for i in range(n_samples)), n_samples

In [None]:
class TeacherForce(LearnerCallback):
    "Pass the target to the model together with the input, so the model's forward receives (input, target)."
    def __init__(self, learn:Learner):
        super().__init__(learn)

    def on_batch_begin(self, last_input, last_target, **kwargs):
        "Replace the model input with the pair (input, target); the target itself is unchanged."
        paired_input = (last_input, last_target)
        return dict(last_input=paired_input, last_target=last_target)

# Data

## paragraphs

In [None]:
# Paragraph dataset: label CSV under PATH, images in the 'paragraphs' folder.
fname = 'edited_pg.csv'
CSV = PATH/fname
FOLDER = 'paragraphs'

df = pd.read_csv(CSV)
len(df)  # number of rows, shown as the cell output

In [None]:
# Image size and batch size for the paragraph dataset.
sz = 512
bs = 10
# Sequence-length settings for this dataset.
seq_len = 700
word_len = 300

## Small synthetic dataset

In [None]:
# Small synthetic dataset: label CSV under PATH, images in the 'edited_sm_synth' folder.
fname = 'edited_sm_synth.csv' #'small_synth_words.csv'
CSV = PATH/fname
FOLDER = 'edited_sm_synth'

df = pd.read_csv(CSV)
len(df)  # number of rows, shown as the cell output

In [None]:
# Image size and batch size for the small synthetic dataset (a batch size of 100 was tried before).
sz = 256
bs = 50
# Sequence-length settings for this dataset.
seq_len = 100
word_len = 50

## combo_cat

In [None]:
# Font-generated samples: label file and image folder.
fname, FOLDER = 'font_mix_129k.csv', 'combo_cat'

In [None]:
# Handwriting-only samples: label file and image folder.
fname, FOLDER = 'hand_mix_25k.csv', 'combo_cat'

In [None]:
# Load whichever combo_cat label file was selected by `fname` in the cell above.
CSV = PATH/fname
df = pd.read_csv(CSV)

# Image size / batch size and sequence-length settings for the combo_cat data.
sz,bs = 512,15
seq_len,word_len = 750,300
len(df)  # number of rows, shown as the cell output

# ModelData

In [None]:
# Augmentations: flipping disabled, max_zoom=1, small rotation/warp limits and lighting changes up to 0.5.
tfms = get_transforms(do_flip=False, max_zoom=1, max_rotate=2, max_warp=0.1, max_lighting=0.5)

def force_gray(image):
    "Convert `image` to mode 'L' and back to 'RGB', discarding colour but keeping three channels."
    single_channel = image.convert('L')
    return single_channel.convert('RGB')

## Char or Word

In [None]:
# Overrides the batch size chosen in the dataset cells above.
bs = 16

In [None]:
def label_collater(samples:BatchSamples, pad_idx:int=0):
    """Collate (image, label) samples into a batch, padding the end of each label.

    Returns `(imgs, labels)`: `imgs` is the stacked images and `labels` is a
    long tensor of shape (batch, max_len + 1) filled with `pad_idx` after each
    label. A single sample whose label is the scalar 0 is treated as a
    prediction request and gets a (1, 1) zero label.
    """
    data = to_data(samples)
    ims, lbls = zip(*data)
    imgs = torch.stack(list(ims))
    # Prediction path. Compare by value, not identity (`is` on int literals only
    # works by accident of CPython's small-int cache). The ndim guard keeps a
    # real label array from being compared element-wise when the batch has one sample.
    if len(data) == 1 and np.ndim(lbls[0]) == 0 and lbls[0] == 0:
        return imgs, torch.zeros(1, 1).long()
    max_len = max(len(s) for s in lbls)
    labels = torch.zeros(len(data), max_len+1).long() + pad_idx  # add 1 to max_len to account for bos token
    for i, lbl in enumerate(lbls):
        labels[i, :len(lbl)] = torch.from_numpy(lbl)  # padding stays at the end
    return imgs, labels

In [None]:
# Pretrained uncased DistilBERT tokenizer, extended with newline, space and the two capitalisation markers.
bert_tok = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
num_added_tokens = bert_tok.add_tokens(['\n',' ','[UP]','[MAJ]'])

In [None]:
def add_cap_tokens(text):  # before encode
    "Lower-case every run of capital letters and prefix it with a capitalisation marker."
    return re.sub(r'[A-Z]+', _replace_caps, text)

def _replace_caps(m):
    "Regex callback: '[MAJ]' marks a single capital, '[UP]' a run of two or more."
    run = m.group()
    marker = '[MAJ]' if len(run) <= 1 else '[UP]'
    return marker + run.lower()

def remove_cap_tokens(text):  # after decode
    "Undo the capitalisation markers: '[UP]word' -> 'WORD', '[MAJ]w' -> 'W'."
    def _upper_after(marker):
        # Build a regex callback that drops `marker` and upper-cases the rest of the match.
        skip = len(marker)
        return lambda m: m.group()[skip:].upper()

    whole_words_done = re.sub(r'\[UP\]\w+', _upper_after('[UP]'), text)     # cap entire word
    return re.sub(r'\[MAJ\]\w?', _upper_after('[MAJ]'), whole_words_done)   # cap first letter

def remove_special_toks(text):
    "Strip '[CLS]' with any whitespace after it and '[SEP]' with any whitespace before it."
    without_cls = re.sub(r'\[CLS\]\s*', '', text)
    return re.sub(r'\s*\[SEP\]', '', without_cls)

def remove_wordpiece_toks(text):
    "Delete every '##' wordpiece continuation marker."
    return text.replace('##', '')

In [None]:
class BertTokenizer(BaseTokenizer):
    "Tokenizer wrapper around the module-level `bert_tok` that appends a closing '[SEP]'."
    def tokenizer(self, t:str) -> List[str]:
        return [*bert_tok.tokenize(t), "[SEP]"]

class BertVocab(Vocab):
    "Vocab built from the keys of `bert_tok.vocab` followed by the four extra tokens."
    def __init__(self):
        extra_toks = ['\n',' ','[UP]','[MAJ]']
        self.itos = [*bert_tok.vocab.keys(), *extra_toks]
        tok_to_id = {tok: idx for idx, tok in enumerate(self.itos)}
        # Unknown strings map to id 100 (presumably '[UNK]' in the BERT vocab — confirm).
        self.stoi = collections.defaultdict(lambda: 100, tok_to_id)

    def textify(self, nums:Collection[int], sep=''):
        "Join the tokens for `nums`, then strip wordpiece, capitalisation and special markers in that order."
        joined = sep.join(self.itos[i] for i in nums)
        return remove_special_toks(remove_cap_tokens(remove_wordpiece_toks(joined)))

In [None]:
class MultiNumericalizeProcessor(PreProcessor):
    "Turn each tokenized item into an int64 array of ids using the item list's vocab."
    def __init__(self, ds:ItemList=None):
        self.vocab = ds.vocab

    def process_one(self,item):
        ids = self.vocab.numericalize(item)
        return np.array(ids, dtype=np.int64)

    def process(self, ds):
        numericalized = [self.process_one(tokens) for tokens in ds.items]
        ds.items = array(numericalized)

### Words (bert_tok)

In [None]:
class MultiTokenizeProcessor(PreProcessor):
    "Tokenize all items with `BertTokenizer`, `chunksize` items at a time (word-level tokens)."
    def __init__(self, ds:ItemList=None, chunksize:int=10000):
        self.toknizr = Tokenizer(tok_func=BertTokenizer, pre_rules=[rm_useless_spaces, add_cap_tokens],
                                 post_rules=[], special_cases=[])
        self.chunksize = chunksize

    def process_one(self, item):
        raise Exception("This isn't implemented!  I didn't think it was necessary...")

    def process(self, ds):
        "Replace `ds.items` with the tokenized items."
        chunk_starts = range(0, len(ds), self.chunksize)
        tokens = []
        for start in progress_bar(chunk_starts, leave=False):
            chunk = ds.items[start:start+self.chunksize]
            tokens.extend(self.toknizr.process_all(chunk))
        ds.items = tokens

In [None]:
class MultiSequenceList(TextList):
    "Text label list processed by the tokenize and numericalize processors defined above."
    _processor = [MultiTokenizeProcessor, MultiNumericalizeProcessor]

    def get(self, i):
        "Return item `i` as a `Text` carrying both the ids and the decoded string."
        w = self.items[i]
        return Text(w, self.vocab.textify(w))

    def reconstruct(self, t:Tensor):
        "Trim padding from both ends of `t` and wrap the remaining ids and their text in a `Text`."
        nonpad = (t != self.pad_idx).nonzero()
        if len(nonpad) > 0:
            idx_min, idx_max = nonpad.min(), nonpad.max()
        else:
            # All-padding sequence (e.g. an early prediction): min()/max() of an empty
            # tensor would raise, so keep just the first element instead.
            idx_min = idx_max = 0
        kept = t[idx_min:idx_max+1]
        return Text(kept, self.vocab.textify(kept))

    def analyze_pred(self, pred:Tensor):
        "Pick the highest-scoring id along the last dimension."
        return torch.argmax(pred, dim=-1)

In [None]:
# Word-level DataBunch: images listed in `df` (forced to gray), 15% random validation split with a
# fixed seed, labels handled by MultiSequenceList with a BertVocab, images squished to `sz`,
# batches assembled by `label_collater`.
data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=MultiSequenceList, vocab=BertVocab(), pad_idx=0)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
       )

In [None]:
# Visual check of a few training samples with their labels.
data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

### Chars (bert_tok)

In [None]:
def characterize(x:Collection[str]) -> Collection[str]:
    """Separate word tokens into letters.

    Whitespace, capitalisation and BERT special tokens are kept whole; a leading
    '##' wordpiece marker is dropped before the token is split into characters.
    """
    keep_whole = {'\n',' ','[UP]','[MAJ]','[UNK]', '[CLS]', '[MASK]', '[PAD]', '[SEP]'}
    res = []
    for t in x:
        if t in keep_whole:
            res.append(t)
        else:
            # extend() adds the characters directly instead of building a throwaway list of None
            res.extend(t[2:] if t.startswith('##') else t)
    return res

In [None]:
class MultiTokenizeProcessor(PreProcessor):
    "Tokenize all items with `BertTokenizer`, then split the tokens into characters via `characterize`."
    def __init__(self, ds:ItemList=None, chunksize:int=10000):
        self.toknizr = Tokenizer(tok_func=BertTokenizer, pre_rules=[rm_useless_spaces, add_cap_tokens],
                                 post_rules=[characterize], special_cases=[])
        self.chunksize = chunksize

    def process_one(self, item):
        raise Exception("This isn't implemented!  I didn't think it was necessary...")

    def process(self, ds):
        "Replace `ds.items` with the tokenized items, working `chunksize` items at a time."
        chunk_starts = range(0, len(ds), self.chunksize)
        tokens = []
        for start in progress_bar(chunk_starts, leave=False):
            chunk = ds.items[start:start+self.chunksize]
            tokens.extend(self.toknizr.process_all(chunk))
        ds.items = tokens

In [None]:
class MultiSequenceList(TextList):
    "Text label list processed by the character-level tokenize and numericalize processors."
    _processor = [MultiTokenizeProcessor, MultiNumericalizeProcessor]

    def get(self, i):
        "Return item `i` as a `Text` carrying both the ids and the decoded string."
        w = self.items[i]
        return Text(w, self.vocab.textify(w))

    def reconstruct(self, t:Tensor):
        "Trim padding from both ends of `t` and wrap the remaining ids and their text in a `Text`."
        nonpad = (t != self.pad_idx).nonzero()
        if len(nonpad) > 0:
            idx_min, idx_max = nonpad.min(), nonpad.max()
        else:
            # All-padding sequence (e.g. an early prediction): min()/max() of an empty
            # tensor would raise, so keep just the first element instead.
            idx_min = idx_max = 0
        kept = t[idx_min:idx_max+1]
        return Text(kept, self.vocab.textify(kept))

    def analyze_pred(self, pred:Tensor):
        "Pick the highest-scoring id along the last dimension."
        return torch.argmax(pred, dim=-1)

In [None]:
# Character-level DataBunch: same pipeline as the word-level one, but using the
# MultiTokenizeProcessor / MultiSequenceList definitions from this section.
data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=MultiSequenceList, vocab=BertVocab(), pad_idx=0)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
       )

In [None]:
# Visual check of a few training samples with their labels.
data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

## Char and Word (bert_tok)

In [None]:
def multi_label_collater(samples:BatchSamples):
    "Collate samples whose label is a (chars, words) pair; each label stream is padded separately with `c_pad`."
    images, label_pairs = zip(*to_data(samples))
    chars, words = zip(*label_pairs)
    batch_imgs = torch.stack(list(images))
    batch_lbls = (c_pad(chars), c_pad(words))
    return batch_imgs, batch_lbls
    
def c_pad(lbls, pad_idx=0):
    "Stack numpy id arrays into one long tensor, filling the end of each row with `pad_idx`."
    width = max(map(len, lbls)) + 1   # one extra column to account for the bos token
    labels = torch.zeros(len(lbls), width, dtype=torch.long) + pad_idx
    for row, seq in enumerate(lbls):
        labels[row, :len(seq)] = torch.from_numpy(seq)
    return labels

In [None]:
# Pretrained uncased DistilBERT tokenizer, extended with newline, space and the two capitalisation markers.
bert_tok = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
num_added_tokens = bert_tok.add_tokens(['\n',' ','[UP]','[MAJ]'])

In [None]:
def add_cap_tokens(text):  # before encode
    "Lower-case every run of capital letters and prefix it with a capitalisation marker."
    return re.sub(r'[A-Z]+', _replace_caps, text)

def _replace_caps(m):
    "Regex callback: '[MAJ]' marks a single capital, '[UP]' a run of two or more."
    run = m.group()
    marker = '[MAJ]' if len(run) <= 1 else '[UP]'
    return marker + run.lower()

def remove_cap_tokens(text):  # after decode
    "Undo the capitalisation markers: '[UP]word' -> 'WORD', '[MAJ]w' -> 'W'."
    def _upper_after(marker):
        # Build a regex callback that drops `marker` and upper-cases the rest of the match.
        skip = len(marker)
        return lambda m: m.group()[skip:].upper()

    whole_words_done = re.sub(r'\[UP\]\w+', _upper_after('[UP]'), text)     # cap entire word
    return re.sub(r'\[MAJ\]\w?', _upper_after('[MAJ]'), whole_words_done)   # cap first letter

def remove_special_toks(text):
    "Strip '[CLS]' with any whitespace after it and '[SEP]' with any whitespace before it."
    without_cls = re.sub(r'\[CLS\]\s*', '', text)
    return re.sub(r'\s*\[SEP\]', '', without_cls)

def remove_wordpiece_toks(text):
    "Delete every '##' wordpiece continuation marker."
    return text.replace('##', '')

In [None]:
def characterize(x:Collection[str]) -> Collection[str]:
    """Separate word tokens into letters.

    Whitespace, capitalisation and BERT special tokens are kept whole; a leading
    '##' wordpiece marker is dropped before the token is split into characters.
    """
    keep_whole = {'\n',' ','[UP]','[MAJ]','[UNK]', '[CLS]', '[MASK]', '[PAD]', '[SEP]'}
    res = []
    for t in x:
        if t in keep_whole:
            res.append(t)
        else:
            # extend() adds the characters directly instead of building a throwaway list of None
            res.extend(t[2:] if t.startswith('##') else t)
    return res

In [None]:
class BertTokenizer(BaseTokenizer):
    "Tokenizer returning both views of a text: [character tokens, word tokens]."
    def tokenizer(self, t:str) -> List[str]:
        word_toks = [*bert_tok.tokenize(t), "[SEP]"]
        char_toks = characterize(word_toks)
        return [char_toks, word_toks]

In [None]:
class BertVocab(Vocab):
    "Vocab built from the keys of `bert_tok.vocab` followed by the four extra tokens."
    def __init__(self):
        extra_toks = ['\n',' ','[UP]','[MAJ]']
        self.itos = [*bert_tok.vocab.keys(), *extra_toks]
        tok_to_id = {tok: idx for idx, tok in enumerate(self.itos)}
        # Unknown strings map to id 100 (presumably '[UNK]' in the BERT vocab — confirm).
        self.stoi = collections.defaultdict(lambda: 100, tok_to_id)

    def textify(self, nums:Collection[int], sep=''):
        "Join the tokens for `nums`, then strip wordpiece, capitalisation and special markers in that order."
        joined = sep.join(self.itos[i] for i in nums)
        return remove_special_toks(remove_cap_tokens(remove_wordpiece_toks(joined)))

In [None]:
class MultiTokenizeProcessor(PreProcessor):
    "Tokenize all items with the two-view `BertTokenizer`, `chunksize` items at a time."
    def __init__(self, ds:ItemList=None, chunksize:int=10000):
        self.tokenizer = Tokenizer(tok_func=BertTokenizer, pre_rules=[rm_useless_spaces, add_cap_tokens],
                                 post_rules=[], special_cases=[])
        self.chunksize = chunksize

    def process_one(self, item):
        raise Exception("This isn't implemented!  I didn't think it was necessary...")

    def process(self, ds):
        "Replace `ds.items` with the tokenized items."
        chunk_starts = range(0, len(ds), self.chunksize)
        tokens = []
        for start in progress_bar(chunk_starts, leave=False):
            chunk = ds.items[start:start+self.chunksize]
            tokens.extend(self.tokenizer.process_all(chunk))
        ds.items = tokens
        
class MultiNumericalizeProcessor(PreProcessor):
    "Numericalize [char_tokens, word_tokens] items into a pair of int64 id arrays."
    def __init__(self, ds:ItemList=None):
        self.vocab = ds.vocab

    def process_one(self,item):
        char_toks, word_toks = item[0], item[1]
        chars = np.array(self.vocab.numericalize(char_toks), dtype=np.int64)
        words = np.array(self.vocab.numericalize(word_toks), dtype=np.int64)
        return [chars, words]

    def process(self, ds):
        numericalized = [self.process_one(pair) for pair in ds.items]
        ds.items = array(numericalized)
        

class MultiSequenceList(TextList):
    "Label list whose items are (char_ids, word_ids) pairs for the same text."
    _processor = [MultiTokenizeProcessor, MultiNumericalizeProcessor]

    def get(self, i):
        "Return item `i` as [char Text, word Text]."
        chars, words = self.items[i]
        return [Text(ids, self.vocab.textify(ids)) for ids in (chars, words)]

    def reconstruct(self, t:Tensor):
        "Rebuild both label streams of `t` with `reconstruct_one`."
        chars, words = t
        return [self.reconstruct_one(chars), self.reconstruct_one(words)]

    def analyze_pred(self, pred:Tensor):
        "Pick the highest-scoring id along the last dimension."
        return torch.argmax(pred, dim=-1)

    def reconstruct_one(self, x):
        "Keep `x` from its start through the last non-pad id (only the first element when all is padding)."
        nonpad = (x != self.pad_idx).nonzero()
        last = nonpad.max() if len(nonpad) > 0 else 0
        kept = x[:last+1]
        return Text(kept, self.vocab.textify(kept))

In [None]:
class ImageMultiList(ImageList):
    "ImageList that can display samples whose label is a (chars, words) pair."
    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(9,10), **kwargs):
        "Show the `xs` and `ys` on a figure of `figsize`. `kwargs` are passed to the show method."
        side = int(math.sqrt(len(xs)))   # square grid: floor(sqrt(n)) rows and columns
        fig, axs = plt.subplots(side, side, figsize=figsize)
        axes = [axs] if side <= 1 else axs.flatten()
        for i, ax in enumerate(axes):
            chars, words = ys[i]
            # Both labels in one caption, separated by a blank line.
            caption = Text([], '\n\n'.join((str(chars), str(words))))
            xs[i].show(ax=ax, y=caption, **kwargs)
        plt.tight_layout()

In [None]:
# Char-and-word DataBunch: each label is a (chars, words) pair, so batches are assembled
# by `multi_label_collater` and displayed through ImageMultiList.
data = (ImageMultiList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=MultiSequenceList, vocab=BertVocab(), pad_idx=0)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=multi_label_collater)
       )

In [None]:
# data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

## Char and Word (SentencePiece)

In [None]:
import sentencepiece as spm

# Load the SentencePiece model stored under PATH and configure its encode/decode options.
sp = spm.SentencePieceProcessor()
sp.Load(str(PATH/'spm_full_10k.model'))
sp.SetEncodeExtraOptions("eos")       # presumably appends an end-of-sentence id when encoding — confirm
sp.SetDecodeExtraOptions("bos:eos")

In [None]:
# id -> piece string for every entry of the SentencePiece model
itos = {idx: sp.id_to_piece(idx) for idx in range(len(sp))}

# id -> list of ids of the piece's individual characters; ids below 7 map to themselves
# (presumably the reserved/special pieces — confirm against the trained model)
c_itos = {
    idx: [idx] if idx < 7 else [sp.piece_to_id(ch) for ch in piece]
    for idx, piece in itos.items()
}

In [None]:
def characterize(t:Collection[int]) -> Collection[int]:
    "Expand each piece id into its character-level ids (via `c_itos`) and return them as one flat list."
    return [char_id for piece_id in t for char_id in c_itos[piece_id]]
    # flattening a list of lists:
    # https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-list-of-lists

In [None]:
def multi_label_collater(samples:BatchSamples):
    "Collate samples whose label is a (chars, words) pair; each label stream is padded separately with `c_pad`."
    images, label_pairs = zip(*to_data(samples))
    chars, words = zip(*label_pairs)
    batch_imgs = torch.stack(list(images))
    batch_lbls = (c_pad(chars), c_pad(words))
    return batch_imgs, batch_lbls
    
def c_pad(lbls, pad_idx=0):
    "Stack numpy id arrays into one long tensor, filling the end of each row with `pad_idx`."
    width = max(map(len, lbls)) + 1   # one extra column to account for the bos token
    labels = torch.zeros(len(lbls), width, dtype=torch.long) + pad_idx
    for row, seq in enumerate(lbls):
        labels[row, :len(seq)] = torch.from_numpy(seq)
    return labels

In [None]:
def add_cap_tokens(text):  # before encode
    "Lower-case every run of capital letters and prefix it with a capitalisation marker."
    return re.sub(r'[A-Z]+', _replace_caps, text)

def _replace_caps(m):
    "Regex callback: '[MAJ]' marks a single capital, '[UP]' a run of two or more."
    run = m.group()
    marker = '[MAJ]' if len(run) <= 1 else '[UP]'
    return marker + run.lower()

def remove_cap_tokens(text):  # after decode
    "Undo the capitalisation markers: '[UP]word' -> 'WORD', '[MAJ]w' -> 'W'."
    def _upper_after(marker):
        # Build a regex callback that drops `marker` and upper-cases the rest of the match.
        skip = len(marker)
        return lambda m: m.group()[skip:].upper()

    whole_words_done = re.sub(r'\[UP\]\w+', _upper_after('[UP]'), text)     # cap entire word
    return re.sub(r'\[MAJ\]\w?', _upper_after('[MAJ]'), whole_words_done)   # cap first letter

In [None]:
class SPMTokenizer(BaseTokenizer):
    "Tokenizer returning [character ids, piece ids] produced by the module-level SentencePiece model `sp`."
    def tokenizer(self, t:str) -> List[int]:
        encoded = sp.EncodeAsIds(t)
        piece_ids = encoded[1:]   # the first encoded id is dropped
        char_ids = characterize(piece_ids)
        return [char_ids, piece_ids]

class SPMProcessor(PreProcessor):
    "Encode all items with `SPMTokenizer`, `chunksize` items at a time."
    def __init__(self, ds:ItemList=None, chunksize:int=10000):
        self.toknizr = Tokenizer(tok_func=SPMTokenizer, pre_rules=[rm_useless_spaces, add_cap_tokens],
                                 post_rules=[], special_cases=[])
        self.chunksize = chunksize

    def process_one(self, item):
        raise Exception("This isn't implemented!  I didn't think it was necessary...")

    def process(self, ds):
        "Replace `ds.items` with the encoded items."
        chunk_starts = range(0, len(ds), self.chunksize)
        tokens = []
        for start in progress_bar(chunk_starts, leave=False):
            chunk = ds.items[start:start+self.chunksize]
            tokens.extend(self.toknizr.process_all(chunk))
        ds.items = tokens

In [None]:
class SPMMultiList(ItemList):
    "Label list of (char_ids, piece_ids) pairs decoded with a SentencePiece processor."
    _processor = [SPMProcessor]

    def __init__(self, items:Iterator, sp, **kwargs):
        super().__init__(items, **kwargs)
        self.vocab, self.pad_idx = sp, 0
        # NOTE(review): 'vocab' is registered for copying, but the constructor takes the
        # processor as `sp` — confirm that copies made via `new()` are built correctly.
        self.copy_new += ['vocab']

    def get(self, i):
        "Return item `i` as [char Text, piece Text]."
        chars, pieces = self.items[i]
        return [Text(ids, self.textify(ids)) for ids in (chars, pieces)]

    def reconstruct(self, t:Tensor):
        "Rebuild both label streams of `t` with `reconstruct_one`."
        chars, pieces = t
        return [self.reconstruct_one(chars), self.reconstruct_one(pieces)]

    def analyze_pred(self, pred:Tensor):
        "Pick the highest-scoring id along the last dimension."
        return torch.argmax(pred, dim=-1)

    def reconstruct_one(self, x):
        "Keep `x` from its start through the last non-pad id (only the first element when all is padding)."
        nonpad = (x != self.pad_idx).nonzero()
        last = nonpad.max() if len(nonpad) > 0 else 0
        kept = x[:last+1]
        return Text(kept, self.textify(kept))

    def textify(self, ids):
        "Decode `ids` (tensor or list) to text and restore capitalisation."
        id_list = ids.tolist() if isinstance(ids, torch.Tensor) else ids
        decoded = self.vocab.DecodeIds(id_list)
        return remove_cap_tokens(decoded)

In [None]:
class ImageMultiList(ImageList):
    "ImageList that can display samples whose label is a (chars, words) pair."
    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(9,10), **kwargs):
        "Show the `xs` and `ys` on a figure of `figsize`. `kwargs` are passed to the show method."
        side = int(math.sqrt(len(xs)))   # square grid: floor(sqrt(n)) rows and columns
        fig, axs = plt.subplots(side, side, figsize=figsize)
        axes = [axs] if side <= 1 else axs.flatten()
        for i, ax in enumerate(axes):
            chars, words = ys[i]
            # Both labels in one caption, separated by a blank line.
            caption = Text([], '\n\n'.join((str(chars), str(words))))
            xs[i].show(ax=ax, y=caption, **kwargs)
        plt.tight_layout()

In [None]:
# SentencePiece char-and-word DataBunch: labels are built by SPMMultiList from the
# loaded `sp` model and batches are assembled by `multi_label_collater`.
data = (ImageMultiList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=SPMMultiList, sp=sp)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=multi_label_collater)
       )

In [None]:
#data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

# Transformer Modules

In [None]:
# LayerNorm = nn.LayerNorm
# Alias used by every module below, with eps fixed at 1e-4.
LayerNorm = partial(nn.LayerNorm, eps=1e-4)  # accommodates mixed precision training
# LayerNorm = partial(nn.BatchNorm2d, eps=1e-4)

In [None]:
class SublayerConnection(nn.Module):
    "Pre-norm residual wrapper: returns x + dropout(sublayer(norm(x)))."
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "`sublayer` is any callable applied to the normalised input."
        transformed = sublayer(self.norm(x))
        return x + self.dropout(transformed)

In [None]:
def clones(module, N):
    "Return an nn.ModuleList holding N independent deep copies of `module`."
    copies = (deepcopy(module) for _ in range(N))
    return nn.ModuleList(copies)

In [None]:
class Encoder(nn.Module):
    "Stack of N copies of `layer`, followed by a final layer norm."
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        return self.norm(out)

In [None]:
class EncoderLayer(nn.Module):
    "Encoder block: self-attention then feed-forward, each inside a SublayerConnection."
    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)

    def forward(self, x, mask=None):
        def attend(t):
            # query, key and value are all the (normalised) layer input
            return self.self_attn(t, t, t, mask)

        attn_wrap, ff_wrap = self.sublayer
        attended = attn_wrap(x, attend)
        return ff_wrap(attended, self.feed_forward)

In [None]:
class Decoder(nn.Module):
    "Stack of N copies of a decoder `layer` with target masking, followed by a final layer norm."
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, src, tgt_mask=None):
        "`src` is the encoder output every layer attends to; `tgt_mask` is passed to each layer."
        out = x
        for block in self.layers:
            out = block(out, src, tgt_mask)
        return self.norm(out)

In [None]:
class DecoderLayer(nn.Module):
    "Decoder block: masked self-attention, attention over `src`, then feed-forward."
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # one residual/dropout/norm wrapper per sub-layer
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, src, tgt_mask=None):
        def self_attend(t):
            return self.self_attn(t, t, t, tgt_mask)   # acts as a weak LM over the target so far

        def src_attend(t):
            return self.src_attn(t, src, src)          # queries from the target, keys/values from `src`

        wrap_self, wrap_src, wrap_ff = self.sublayer
        x = wrap_self(x, self_attend)
        x = wrap_src(x, src_attend)
        return wrap_ff(x, self.feed_forward)

In [None]:
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Returns `(output, weights)`: `weights` is the softmax over the last dimension
    of the scaled scores (after optional dropout) and `output` is `weights @ value`.
    Positions where `mask == 0` get a score of -1e4 before the softmax.
    """
    d_k = query.size(-1)
    scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        # -1e4 rather than -1e9 so the fill value stays representable in mixed precision
        scores = scores.masked_fill(mask == 0, -1e4)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return weights @ value, weights

In [None]:
class SingleHeadedAttention(nn.Module):
    "Attention with one head: linear projections of q/k/v, scaled dot-product attention, output projection."
    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        self.linears = clones(nn.Linear(d_model, d_model), 4)   # q, k, v and output projections
        self.attn = None                                        # attention weights of the last call
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        inputs = (query, key, value)
        # zip stops after three pairs, so the fourth linear is left for the output projection
        q_proj, k_proj, v_proj = (lin(t) for lin, t in zip(self.linears, inputs))
        context, self.attn = attention(q_proj, k_proj, v_proj, mask=mask, dropout=self.dropout)
        return self.linears[-1](context)

In [None]:
class MultiHeadedAttention(nn.Module):
    "Multi-head attention with `h` heads of size d_model // h."
    def __init__(self, d_model, h=8, dropout=0.2):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h        # d_v is assumed to equal d_k
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)   # q, k, v and output projections
        self.attn = None                                        # attention weights of the last call
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)   # extra dim so the mask broadcasts over heads
        bs = q.size(0)

        def split_heads(proj, t):
            # project, then reshape d_model into (h, d_k) and move heads ahead of the sequence dim
            return proj(t).view(bs, -1, self.h, self.d_k).transpose(1, 2)

        # zip stops after three pairs, so the fourth linear is left for the output projection
        q, k, v = (split_heads(lin, t) for lin, t in zip(self.linears, (q, k, v)))

        context, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)

        # merge the heads back into one d_model-wide vector per position
        merged = context.transpose(1, 2).contiguous().view(bs, -1, self.h * self.d_k)
        return self.linears[-1](merged)

In [None]:
class PositionwiseFeedForward(nn.Module):
    "Two-layer feed-forward block: d_model -> 4*d_model -> d_model with GELU and dropout in between."
    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        hidden = d_model * 4
        self.w_1 = nn.Linear(d_model, hidden)
        self.w_2 = nn.Linear(hidden, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        expanded = F.gelu(self.w_1(x))
        return self.w_2(self.dropout(expanded))

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    The table has shape (1, max_len, d_model): even columns hold sines and odd
    columns hold cosines. It is registered as the buffer `pe`, so it is saved in
    the state_dict without being a parameter.
    """
    def __init__(self, d_model, dropout=0.1, max_len=2000):
        super().__init__()

        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0.0, max_len).unsqueeze(1)
        log_increment = math.log(1e4) / d_model
        div_term = torch.exp(torch.arange(0.0, d_model, 2) * -log_increment)
        angles = position * div_term                      # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns;
        # slicing keeps the shapes aligned (a no-op when d_model is even).
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))       # (1, max_len, d_model)

    def forward(self, x):
        "`x` is indexed as (batch, seq, d_model); the first `seq` table rows are added to it."
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [None]:
class Embeddings(nn.Module):
    "Embedding lookup whose output is multiplied by sqrt(d_model)."
    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale

# Bert_tok architectures

## Word Arch

In [None]:
class LearnedPositionalEmbeddings(nn.Module):
    """Token embedding plus a learned 2-D (row, column) positional embedding.

    Rows and columns are derived from the token ids themselves: every occurrence
    of `nl_tok` ends a row and columns restart at 1 on the next row. The row and
    column embeddings are concatenated, so their widths must sum to `d_model`.
    Index 0 is left for positions after the last non-zero token.
    """
    def __init__(self, d_model, v_embed, row_embed, col_embed, dropout=0.1):
        super().__init__()
        self.nl_tok  = 30522   # id that ends a row (presumably the added '\n' token — confirm)
        self.d_model = d_model

        self.embed = v_embed
        self.dropout = nn.Dropout(p=dropout)

        self.rows = row_embed
        self.cols = col_embed

    def forward(self, x, spatial_postions=None):
        "`x`: (batch, seq) token ids. `spatial_postions` is accepted but not used."
        rows, cols = self.encode_spatial_positions(x)

        # Clamp both indices to the last entry of their tables, so texts with more rows
        # or longer rows than the tables hold do not raise an index error.
        row_t = self.rows(torch.clamp(rows, max=self.rows.num_embeddings-1))
        col_t = self.cols(torch.clamp(cols, max=self.cols.num_embeddings-1))
        pos_enc = torch.cat((row_t, col_t), dim=-1)

        x = (self.embed(x) + pos_enc) * math.sqrt(self.d_model)
        return self.dropout(x)

    def encode_spatial_positions(self, x):
        "Return `(rows, cols)`, two tensors shaped like `x` holding 1-based row and column numbers."
        rows, cols = torch.zeros_like(x), torch.zeros_like(x)
        for b, seq in enumerate(x.unbind()):
            nonzero = torch.nonzero(seq).flatten()
            if nonzero.numel() == 0: continue     # sequence of zeros only: positions stay 0
            newlines = torch.nonzero(seq == self.nl_tok).flatten()
            # Each row ends at a newline; the final row ends at the last non-zero token.
            row_ends = torch.cat([newlines, nonzero[-1:]]).tolist()

            start = 0
            for row, end in enumerate(row_ends, start=1):
                rows[b, start:end+1] = row
                cols[b, start:end+1] = torch.arange(1, end-start+2, dtype=x.dtype, device=x.device)
                start = end + 1
        return rows, cols

In [None]:
class ResnetBase(nn.Module):
    "Truncated resnet34 (constructed with `True` as its first argument) used as the image feature extractor."
    def __init__(self, em_sz):
        super().__init__()

        # How many trailing child modules to cut off for each supported feature width.
        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        backbone = models.resnet34(True)
        kept = list(backbone.children())[:cut]
        self.base = nn.Sequential(*kept)                  #32x32 : 256

    def forward(self, x):
        return self.base(x)

class Adaptor(nn.Module):
    "Turn a feature map into a sequence of d_model-wide vectors for the transformer."
    def __init__(self, em_sz, d_model):
        super().__init__()

        self.pool = nn.AdaptiveMaxPool2d((16,None))   # height pooled to 16, width left as is
        self.conv = conv_layer(em_sz,em_sz)
        self.lin = nn.Linear(em_sz, d_model)

    def forward(self, x):
        feats = self.conv(self.pool(x))
        # merge dims 2 and 3, then put that merged dim before the channel dim
        # (assumes a 4-D (batch, channels, h, w) input — TODO confirm)
        seq = feats.flatten(2,3).permute(0,2,1)
        return self.lin(seq) * 8

In [None]:
class EncoderDecoder(nn.Module):
    "Encoder/decoder pair with a target embedding and an output generator."
    def __init__(self, encoder, decoder, embed, generator):
        super().__init__()
        self.encoder, self.w_decoder = encoder, decoder
        self.w_embed, self.generator = embed, generator

    def forward(self, src, tgt):
        "Decode the right-shifted `tgt` against the encoded `src`; returns decoder outputs (not yet generated)."
        shifted = rshift(tgt, 101).long()    # prepend token id 101 (CLS tok)
        band_mask = parallelogram_mask(shifted.size(-1), 20)

        memory = self.encode(src)
        return self.w_decoder(self.w_embed(shifted), memory, band_mask)

    def encode(self, src):
        return self.encoder(src)

    def generate(self, outs):
        return self.generator(outs)

In [None]:
def make_full_model(vocab, d_model, em_sz, N=4, drops=0, heads=8):
    """Build the transformer: N encoder layers, N decoder layers, learned 2-D
    positional embeddings (15 rows, 60 columns) and a linear generator over `vocab`.

    `em_sz` is accepted but not used here. `drops` is applied to the feed-forward
    and sublayer dropouts; the attention modules keep their own default dropout.
    """
    c = deepcopy

    attn = MultiHeadedAttention(d_model, heads)
    ff = PositionwiseFeedForward(d_model, drops)

    v_embed = nn.Embedding(vocab, d_model, 0)
    row_emb = nn.Embedding(15, d_model//2, 0)

    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        LearnedPositionalEmbeddings(d_model, v_embed, row_emb, nn.Embedding(60,  d_model//2, 0)),  #word
        nn.Linear(d_model, vocab)
    )

    # Xavier-initialise every weight matrix (biases and norm parameters are 1-D and skipped).
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    # The loop above also overwrote the padding rows of the embeddings created with
    # padding_idx=0. Those rows receive no gradient, so restore them to zero here.
    with torch.no_grad():
        for m in model.modules():
            if isinstance(m, nn.Embedding) and m.padding_idx is not None:
                m.weight[m.padding_idx].zero_()

    return model

In [None]:
class Img2Seq(nn.Module):
    "Full model: image encoder -> adaptor -> transformer -> generator."
    def __init__(self, img_encoder, adaptor, transformer):
        super().__init__()
        self.img_enc, self.adaptor = img_encoder, adaptor
        self.transformer = transformer

    def forward(self, src, tgt):
        "`src` is the image batch and `tgt` the target ids passed on to the transformer."
        img_feats = self.img_enc(src)
        seq_feats = self.adaptor(img_feats)
        decoded = self.transformer(seq_feats, tgt)
        return self.transformer.generate(decoded)

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, heads=8, smoothing=0.1):
    "Build the image-to-sequence network and wrap it in a `Learner` with label-smoothing loss, a CER metric and teacher forcing."
    vocab_sz = len(data.vocab.itos)
    net = Img2Seq(
        ResnetBase(em_sz),
        Adaptor(em_sz, d_model),
        make_full_model(vocab_sz, d_model, em_sz, N, drops, heads),
    )
    return Learner(data, net, loss_func=LabelSmoothing(smoothing),
                   metrics=[CER(data.y.reconstruct)], callback_fns=[TeacherForce])

In [None]:
# build the learner with d_model=512 and em_sz=256 (2nd and 3rd positional args of make_learner)
learn = make_learner(data, 512, 256, N=4, drops=0.1, heads=8)

In [None]:
# number of trainable parameters: total element count over parameters with requires_grad=True
sum(p.numel() for p in learn.model.parameters() if p.requires_grad)

## Word Arch w/ MLM

Load above Word Arch modules plus the following:

In [None]:
class Img2Seq(nn.Module):
    "Image-to-sequence model whose arg-max token predictions are passed through a language model."
    def __init__(self, img_encoder, adaptor, transformer, lm):
        super().__init__()
        self.img_enc = img_encoder
        self.adaptor = adaptor
        self.transformer = transformer
        self.lm = lm

    def forward(self, src, tgt):
        "Return the first element of `lm(preds)`, where `preds` are the arg-max ids of the generator output."
        # everything before the language model runs without gradient tracking
        with torch.no_grad():
            dec_outs = self.transformer(self.adaptor(self.img_enc(src)), tgt)
            logits = self.transformer.generate(dec_outs)
            preds = logits.argmax(dim=-1)
        lm_out = self.lm(preds)
        return lm_out[0]

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, heads=8, smoothing=0.1):
    "Build the image-to-sequence network with a DistilBERT masked LM on top and wrap it in a `Learner`."
    vocab_sz = len(data.vocab.itos)
    img_encoder, adaptor = ResnetBase(em_sz), Adaptor(em_sz, d_model)

    # pretrained LM, with its token embeddings resized to the data vocabulary
    lm = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
    lm.resize_token_embeddings(vocab_sz)

    transformer = make_full_model(vocab_sz, d_model, em_sz, N, drops, heads)
    net = Img2Seq(img_encoder, adaptor, transformer, lm)
    return Learner(data, net, loss_func=LabelSmoothing(smoothing),
                   metrics=[CER(data.y.reconstruct)], callback_fns=[TeacherForce])

In [None]:
# build the MLM-augmented learner with d_model=512 and em_sz=256
learn = make_learner(data, 512, 256, N=4, drops=0.1, heads=8)

In [None]:
# number of trainable parameters: total element count over parameters with requires_grad=True
sum(p.numel() for p in learn.model.parameters() if p.requires_grad)

In [None]:
def greedy_decode(src, model, seq_len, kind='char', bos_tok=1):
    """Greedily decode up to `seq_len` tokens for `src`, then return `model.lm(tokens)[0]`.

    `kind` selects the char ('char') or word (anything else) decoder/embedding
    of `model.transformer`. Decoding stops early once every sequence in the
    batch predicts token 0.
    """
    model.eval()
    tfmr, img_enc, adaptor, lm = model.transformer, model.img_enc, model.adaptor, model.lm

    use_char = kind=='char'
    decoder = tfmr.c_decoder if use_char else tfmr.w_decoder
    embed = tfmr.c_embed if use_char else tfmr.w_embed
    window = 20   # diagonal passed to parallelogram_mask

    with torch.no_grad():
        feats = tfmr.encode(adaptor(img_enc(src)))

        # start every sequence with the begin-of-sequence token
        tgt = torch.zeros((src.size(0),1), dtype=torch.long, device=device) + bos_tok

        step_logits = []   # per-step generator outputs (collected; the returned value comes from the LM)
        for _ in progress_bar(range(seq_len)):
            # alternative mask: subsequent_mask(tgt.size(-1))
            mask = parallelogram_mask(tgt.size(-1), window)

            dec_outs = decoder(embed(tgt), feats, mask)
            logits = tfmr.generator(dec_outs[:,-1])   # only the newest position
            step_logits.append(logits)
            nxt = logits.argmax(dim=-1, keepdim=True)
            if (nxt==0).all(): break
            tgt = torch.cat([tgt,nxt], dim=-1)

        return lm(tgt)[0]
    
# def encode_spatial_positions(x, nl_tok=30522):
#     rows,cols = torch.zeros_like(x),torch.zeros_like(x)
#     for ii,batch in enumerate(x.unbind()):
#         nls = torch.nonzero(batch==nl_tok).flatten()
#         last = torch.nonzero(batch).flatten()[-1][None]
#         splits = torch.cat([nls,last])

#         p=0
#         for i,n in enumerate(splits, start=1):
#             rows[ii,p:n+1] = i
#             cols[ii,p:n+1] = torch.arange(1,n-p+2)
#             p = n+1
#     return rows,cols

## Char Arch

In [None]:
class LearnedPositionalEmbeddings(nn.Module):
    """Token embedding plus a learned 2-D (row, column) positional encoding.

    Every non-padding token gets a 1-based row (line number) and a 1-based
    column (offset within its line); a line ends at `nl_tok`. Row and column
    embeddings are concatenated, added to the token embedding, scaled by
    sqrt(d_model) and passed through dropout. Index 0 marks padding.

    Args:
        d_model: model width; the row and column embeddings concatenated must match it.
        v_embed: token embedding module.
        row_embed: row (line) embedding module.
        col_embed: column embedding module.
        dropout: dropout probability applied to the result.
    """
    def __init__(self, d_model, v_embed, row_embed, col_embed, dropout=0.1):
        super(LearnedPositionalEmbeddings, self).__init__()
        self.nl_tok  = 30522 # token id treated as the line break
        self.d_model = d_model

        self.embed = v_embed
        self.dropout = nn.Dropout(p=dropout)

        self.rows = row_embed
        self.cols = col_embed

    def forward(self, x):
        "Embed token ids `x` and add the row/column positional encoding."
        rows,cols = self.encode_spatial_positions(x)
        # clamp both indices so extra lines / long lines reuse the last learned
        # position instead of indexing past the embedding table
        row_t = self.rows(torch.clamp(rows, max=self.rows.num_embeddings-1))
        col_t = self.cols(torch.clamp(cols, max=self.cols.num_embeddings-1))
        pos_enc = torch.cat((row_t, col_t), dim=-1)

        x = self.embed(x)
        x = (x + pos_enc) * math.sqrt(self.d_model)
        return self.dropout(x)

    def encode_spatial_positions(self, x):
        "Return `(rows, cols)` index tensors shaped like `x`; positions after the last non-zero token stay 0."
        rows,cols = torch.zeros_like(x),torch.zeros_like(x)
        for b,seq in enumerate(x.unbind()):
            nonpad = torch.nonzero(seq).flatten()
            if nonpad.numel() == 0: continue   # all-padding sequence: nothing to number
            # a line ends at each newline token and at the last non-padding token
            line_ends = torch.nonzero(seq==self.nl_tok).flatten().tolist()
            line_ends.append(nonpad[-1].item())

            start = 0
            for row,end in enumerate(line_ends, start=1):
                width = end - start + 1
                if width > 0:   # width is 0 when the sequence ends on a newline
                    rows[b,start:end+1] = row
                    cols[b,start:end+1] = torch.arange(1, width+1, dtype=x.dtype, device=x.device)
                start = end + 1
        return rows,cols

In [None]:
class EncoderDecoder(nn.Module):
    "Bundle a source adaptor, an encoder, a char-level decoder, target embeddings and an output generator."
    def __init__(self, encoder, decoder, src_adapt, embed, generator):
        super().__init__()
        self.encoder = encoder
        self.src_adapt = src_adapt

        self.c_decoder = decoder
        self.c_embed = embed

        self.generator = generator

    def encode(self, src):
        "Adapt `src`, then run it through the encoder."
        adapted = self.src_adapt(src)
        return self.encoder(adapted)

    def generate(self, outs):
        "Apply the generator to decoder outputs."
        return self.generator(outs)

    def forward(self, src, tgt):
        "Decode the right-shifted `tgt` against the encoded `src`; returns decoder outputs (generator not applied)."
        shifted = rshift(tgt, 101).long()   # 101 is prepended as the start token
        attn_mask = parallelogram_mask(shifted.size(-1), 25)

        memory = self.encode(src)
        embedded = self.c_embed(shifted)
        return self.c_decoder(embedded, memory, attn_mask)

In [None]:
def make_full_model(vocab, d_model, em_sz, N=4, drops=0, heads=8):
    "Assemble the char-level `EncoderDecoder` (with an `em_sz` -> `d_model` source adaptor) and Xavier-initialise every parameter with more than one dim."
    # sub-modules are created in the same order as before so seeded initialisation is unchanged
    attn = MultiHeadedAttention(d_model, heads)
    ff = PositionwiseFeedForward(d_model, drops)
    v_embed = nn.Embedding(vocab, d_model, 0)
    row_emb = nn.Embedding(15, d_model//2, 0)

    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    decoder = Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    # linear projection followed by an in-place x8 scale (original note: increases gradients on weights by 8!)
    src_adapt = nn.Sequential(nn.Linear(em_sz, d_model), Lambda(lambda x: x.mul_(8)))
    col_emb = nn.Embedding(100,  d_model//2, 0)
    embed = LearnedPositionalEmbeddings(d_model, v_embed, row_emb, col_emb)
    generator = nn.Linear(d_model, vocab)
    model = EncoderDecoder(encoder, decoder, src_adapt, embed, generator)

    matrices = [p for p in model.parameters() if p.dim() > 1]
    for w in matrices:
        nn.init.xavier_uniform_(w)
    return model

In [None]:
class ResnetBase(nn.Module):
    "Truncated ResNet-34 followed by a conv layer and an adaptive max-pool to height 16; output is flattened to a sequence."
    def __init__(self, em_sz):
        super().__init__()

        # negative slice end: how many trailing ResNet children to drop for each supported em_sz
        # (original note: 512,2,16;  256,4,32;  128,8,64 -- presumably channels,height,width; TODO confirm)
        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        backbone = models.resnet34(True)
        self.base = nn.Sequential(*list(backbone.children())[:cut])

        self.conv = conv_layer(em_sz,em_sz)
        self.pool = nn.AdaptiveMaxPool2d((16,None))

    def forward(self, x):
        "Return features with dims 2-3 flattened into one axis that is moved before the channel axis."
        feats = self.conv(self.base(x))
        pooled = self.pool(feats)
        return pooled.flatten(2,3).permute(0,2,1)

In [None]:
class Img2Seq(nn.Module):
    "Image encoder -> transformer; returns the transformer's generator output."
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt):
        "Encode `src`, decode against `tgt` and apply the generator."
        dec_outs = self.transformer(self.img_enc(src), tgt)
        return self.transformer.generate(dec_outs)

In [None]:
def cer(preds, targs, fn):
    "Return `(summed error, batch size)`: per item, the Levenshtein distance between decoded target and decoded arg-max prediction, divided by the target length (or 1 if empty)."
    n = targs.size(0)
    pred_ids = torch.argmax(preds, dim=-1)

    def _item_error(i):
        # `fn` turns a tensor of ids into something whose str() is the text to compare
        p = str(fn(pred_ids[i]))
        t = str(fn(targs[i]))
        return Lev.distance(t, p)/(len(t) or 1)

    return sum(_item_error(i) for i in range(n)), n

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, heads=8, smoothing=0.1):
    "Build the char-level image-to-sequence network and wrap it in a `Learner` with teacher forcing, `BnFreeze` and 10-step gradient accumulation."
    vocab_sz = len(data.vocab.itos)
    net = Img2Seq(ResnetBase(em_sz), make_full_model(vocab_sz, d_model, em_sz, N, drops, heads))
    callbacks = [TeacherForce, BnFreeze, partial(AccumulateScheduler, n_step=10)]
    return Learner(data, net, loss_func=LabelSmoothing(smoothing),
                   metrics=[CER(data.y.reconstruct)], callback_fns=callbacks)

In [None]:
# build the char-level learner with d_model=512 and em_sz=256
learn = make_learner(data, 512, 256, N=4, drops=0.1, heads=8)

## Combo Arch

In [None]:
class LearnedPositionalEmbeddings(nn.Module):
    """Shared token embedding with learned 2-D (row, column) positions for char and word sequences.

    Every non-padding token gets a 1-based row (line number) and a 1-based
    column (offset within its line); a line ends at `nl_tok`. Row and column
    embeddings are concatenated, added to the token embedding, scaled by
    sqrt(d_model) and passed through dropout. Index 0 marks padding.

    Args:
        d_model: model width.
        vocab: size of the token embedding table.
        rows: size of the row (line) embedding table.
        cols: pair `(char_cols, word_cols)` of column-table sizes.
        dropout: dropout probability applied to the result.
    """
    def __init__(self, d_model, vocab, rows, cols, dropout=0.1):
        super(LearnedPositionalEmbeddings, self).__init__()
        self.nl_tok  = 30522 # token id treated as the line break
        self.d_model = d_model

        self.embed = nn.Embedding(vocab, d_model, 0)
        self.rows = nn.Embedding(rows, d_model//2, 0)   # honour `rows` rather than a fixed 15
        self.c_cols = nn.Embedding(cols[0], d_model//2, 0)
        self.w_cols = nn.Embedding(cols[1], d_model//2, 0)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, char, word):
        "Return the pair `(char embedding, word embedding)`."
        return self.encode_one(char, 'char'), self.encode_one(word, 'word')

    def encode_one(self, x, kind):
        "Embed token ids `x`; `kind=='char'` selects the char column table, anything else the word one."
        rows,cols = self.encode_spatial_positions(x)

        x_cols = self.c_cols if kind=='char' else self.w_cols

        # clamp both indices so extra lines / long lines reuse the last learned
        # position instead of indexing past the embedding table
        row_t = self.rows(torch.clamp(rows, max=self.rows.num_embeddings-1))
        col_t = x_cols(torch.clamp(cols, max=x_cols.num_embeddings-1))
        pos_enc = torch.cat((row_t, col_t), dim=-1)

        x = self.embed(x)
        x = (x + pos_enc) * math.sqrt(self.d_model)
        return self.dropout(x)

    def encode_spatial_positions(self, x):
        "Return `(rows, cols)` index tensors shaped like `x`; positions after the last non-zero token stay 0."
        rows,cols = torch.zeros_like(x),torch.zeros_like(x)
        for b,seq in enumerate(x.unbind()):
            nonpad = torch.nonzero(seq).flatten()
            if nonpad.numel() == 0: continue   # all-padding sequence: nothing to number
            # a line ends at each newline token and at the last non-padding token
            line_ends = torch.nonzero(seq==self.nl_tok).flatten().tolist()
            line_ends.append(nonpad[-1].item())

            start = 0
            for row,end in enumerate(line_ends, start=1):
                width = end - start + 1
                if width > 0:   # width is 0 when the sequence ends on a newline
                    rows[b,start:end+1] = row
                    cols[b,start:end+1] = torch.arange(1, width+1, dtype=x.dtype, device=x.device)
                start = end + 1
        return rows,cols

In [None]:
class ResnetBase(nn.Module):
    "ResNet-18 (built with `True` as its first argument) with its trailing children removed; the cut depends on `em_sz`."
    def __init__(self, em_sz):
        super().__init__()

        # negative slice end: how many trailing ResNet children to drop for each supported em_sz
        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        backbone = models.resnet18(True)
        kept = list(backbone.children())[:cut]
        self.base = nn.Sequential(*kept)   # original note: 32x32 : 256

    def forward(self, x):
        "Run `x` through the truncated backbone."
        return self.base(x)

In [None]:
class Adaptor(nn.Module):
    "Flatten dims 2-3 of the input into one axis, move that axis before the channel axis and scale by 8."
    def forward(self, x):
        seq = torch.flatten(x, 2, 3).transpose(1, 2)
        return seq * 8

In [None]:
class WordCharTransformer(nn.Module):
    "One encoder feeding two decoders (char and word) that share an embedding module and a generator."
    def __init__(self, encoder, c_dec, w_dec, embeddings, generator):
        super().__init__()
        self.encoder = encoder

        self.c_decoder = c_dec
        self.w_decoder = w_dec

        self.embed = embeddings

        self.generator = generator

    def forward(self, src, mix_tgt):
        "Return `(char decoder outputs, word decoder outputs)` for the `(char, word)` target pair `mix_tgt`."
        c_tgt,w_tgt,c_mask,w_mask = self.shift_with_masks(mix_tgt)
        # the shared embedding module is applied to each target separately
        # (the LearnedPositionalEmbeddings variant would take both targets in one call)
        c_emb, w_emb = self.embed(c_tgt), self.embed(w_tgt)

        memory = self.encoder(src)
        c_dec_out = self.c_decoder(c_emb, memory, c_mask)
        w_dec_out = self.w_decoder(w_emb, memory, w_mask)
        return c_dec_out, w_dec_out

    def generate(self, c_outs, w_outs):
        "Apply the shared generator to both decoder outputs."
        return self.generator(c_outs), self.generator(w_outs)

    def shift_with_masks(self, mix_tgt):
        "Right-shift both targets with start token 101 and build a `subsequent_mask` for each."
        chars, words = mix_tgt
        chars = rshift(chars, 101).long()
        words = rshift(words, 101).long()

        # alternative: parallelogram_mask(size, 20) to limit context (char needs word context, word needs sentence context)
        return chars, words, subsequent_mask(chars.size(-1)), subsequent_mask(words.size(-1))

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0, heads=8):
    "Assemble the combined `WordCharTransformer` (token embeddings + `PositionalEncoding`) and Xavier-initialise every parameter with more than one dim."
    # sub-modules are created in the same order as before so seeded initialisation is unchanged
    attn = MultiHeadedAttention(d_model, heads)
    ff = PositionwiseFeedForward(d_model, drops)

    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    char_dec = Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    word_dec = Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    # positional-encoding variant; the learned alternative would be
    # LearnedPositionalEmbeddings(d_model, vocab, rows=15, cols=[100,60])
    embeddings = nn.Sequential(Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000))
    generator = nn.Linear(d_model, vocab)
    model = WordCharTransformer(encoder, char_dec, word_dec, embeddings, generator)

    matrices = [p for p in model.parameters() if p.dim() > 1]
    for w in matrices:
        nn.init.xavier_uniform_(w)
    return model

In [None]:
class Img2Seq(nn.Module):
    "Image encoder -> adaptor -> two-headed transformer; returns the `(char, word)` generator outputs."
    def __init__(self, img_encoder, adaptor, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.adaptor = adaptor
        self.transformer = transformer

    def forward(self, src, tgt):
        "Encode `src`, decode both targets in `tgt` and apply the generator to each."
        img_feats = self.adaptor(self.img_enc(src))
        c_dec, w_dec = self.transformer(img_feats, tgt)
        return self.transformer.generate(c_dec, w_dec)

In [None]:
class MultiCER(LearnerCallback):
    "Callback that reports the epoch-level `cer` of the char and word outputs as metrics 'char' and 'word'."
    _order=-20 # Needs to run before the recorder
    def __init__(self, learn):
        super().__init__(learn)
        self.recon = learn.data.y.reconstruct_one

    def on_train_begin(self, **kwargs):
        "Register the two extra metric names."
        self.learn.recorder.add_metric_names(['char', 'word'])

    def on_epoch_begin(self, **kwargs):
        "Reset the running sums."
        self.c_errors = self.w_errors = self.total = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Accumulate the summed errors and the item count for this batch."
        c_out, w_out = last_output
        c_targ, w_targ = last_target
        c_err, n_items = cer(c_out, c_targ, self.recon)
        w_err = cer(w_out, w_targ, self.recon)[0]
        self.c_errors += c_err
        self.w_errors += w_err
        self.total += n_items

    def on_epoch_end(self, last_metrics, **kwargs):
        "Append the mean char and word error to the epoch metrics."
        means = [err/self.total for err in (self.c_errors, self.w_errors)]
        return add_metrics(last_metrics, means)

In [None]:
class MultiLabelSmoothing(nn.Module):
    "Sum of `LabelSmoothing` losses over the char and word predictions."
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, c_targ, w_targ):
        "`pred[0]` is scored against `c_targ` and `pred[1]` against `w_targ`; the two losses are added."
        criterion = LabelSmoothing(self.smoothing)
        char_loss = criterion(pred[0], c_targ)
        word_loss = criterion(pred[1], w_targ)
        return char_loss + word_loss

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, heads=8, smoothing=0.1):
    "Build the combined char/word network and wrap it in a `Learner` with the two-part loss, teacher forcing and `MultiCER`."
    vocab_sz = len(data.vocab.itos)
    net = Img2Seq(
        ResnetBase(em_sz),
        Adaptor(),
        make_full_model(vocab_sz, d_model, N, drops, heads),
    )
    return Learner(data, net, loss_func=MultiLabelSmoothing(smoothing),
                   callback_fns=[TeacherForce, MultiCER])

In [None]:
# build the combined learner with d_model=512 and em_sz=512
learn = make_learner(data, 512, 512, N=4, drops=0.1, heads=8)

In [None]:
# number of trainable parameters: total element count over parameters with requires_grad=True
sum(p.numel() for p in learn.model.parameters() if p.requires_grad)

# SentencePiece Architectures

## Word Arch w/ LM

In [None]:
class LearnedPositionalEmbeddings(nn.Module):
    """Token embedding plus a learned 2-D (row, column) positional encoding.

    Every non-padding token gets a 1-based row (line number) and a 1-based
    column (offset within its line); a line ends at `nl_tok`. Row and column
    embeddings are concatenated, added to the token embedding, scaled by
    sqrt(d_model) and passed through dropout. Index 0 marks padding.

    Args:
        d_model: model width; the row and column embeddings concatenated must match it.
        v_embed: token embedding module.
        row_embed: row (line) embedding module.
        col_embed: column embedding module.
        dropout: dropout probability applied to the result.
    """
    def __init__(self, d_model, v_embed, row_embed, col_embed, dropout=0.1):
        super(LearnedPositionalEmbeddings, self).__init__()
        self.nl_tok  = 4   # token id treated as the line break
        self.d_model = d_model

        self.embed = v_embed
        self.rows = row_embed
        self.cols = col_embed

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, spatial_postions=None):
        "Embed token ids `x` and add the row/column positional encoding. `spatial_postions` is accepted but not used."
        rows,cols = self.encode_spatial_positions(x)

        # clamp both indices so extra lines / long lines reuse the last learned
        # position instead of indexing past the embedding table
        row_t = self.rows(torch.clamp(rows, max=self.rows.num_embeddings-1))
        col_t = self.cols(torch.clamp(cols, max=self.cols.num_embeddings-1))
        pos_enc = torch.cat((row_t, col_t), dim=-1)

        x = self.embed(x)
        x = (x + pos_enc) * math.sqrt(self.d_model)
        return self.dropout(x)

    def encode_spatial_positions(self, x):
        "Return `(rows, cols)` index tensors shaped like `x`; positions after the last non-zero token stay 0."
        rows,cols = torch.zeros_like(x),torch.zeros_like(x)
        for b,seq in enumerate(x.unbind()):
            nonpad = torch.nonzero(seq).flatten()
            if nonpad.numel() == 0: continue   # all-padding sequence: nothing to number
            # a line ends at each newline token and at the last non-padding token
            line_ends = torch.nonzero(seq==self.nl_tok).flatten().tolist()
            line_ends.append(nonpad[-1].item())

            start = 0
            for row,end in enumerate(line_ends, start=1):
                width = end - start + 1
                if width > 0:   # width is 0 when the sequence ends on a newline
                    rows[b,start:end+1] = row
                    cols[b,start:end+1] = torch.arange(1, width+1, dtype=x.dtype, device=x.device)
                start = end + 1
        return rows,cols

In [None]:
class ResnetBase(nn.Module):
    "ResNet-18 (built with `True` as its first argument) with its trailing children removed; the cut depends on `em_sz`."
    def __init__(self, em_sz):
        super().__init__()

        # negative slice end: how many trailing ResNet children to drop for each supported em_sz
        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        backbone = models.resnet18(True)
        kept = list(backbone.children())[:cut]
        self.base = nn.Sequential(*kept)   # original note: 32x32 : 256

    def forward(self, x):
        "Run `x` through the truncated backbone."
        return self.base(x)

In [None]:
class Adaptor(nn.Module):
    "Flatten dims 2-3 of the input into one axis, move that axis before the channel axis and scale by 8."
    def forward(self, x):
        seq = torch.flatten(x, 2, 3).transpose(1, 2)
        return seq * 8

In [None]:
class WordCharTransformer(nn.Module):
    "Encoder plus a single word-level decoder with target embeddings and an output generator."
    def __init__(self, encoder, decoder, embeddings, generator):
        super().__init__()
        self.encoder = encoder
        self.w_decoder = decoder
        self.embed = embeddings
        self.generator = generator

    def generate(self, outs):
        "Apply the generator to decoder outputs."
        return self.generator(outs)

    def forward(self, src, tgt):
        "Decode the right-shifted `tgt` (start token 1) against the encoded `src` under a `subsequent_mask`."
        shifted = rshift(tgt, 1).long()
        attn_mask = subsequent_mask(shifted.size(-1))
        embedded = self.embed(shifted)
        memory = self.encoder(src)
        return self.w_decoder(embedded, memory, attn_mask)

In [None]:
def make_full_model(vocab, d_model, em_sz, N=4, drops=0, heads=8):
    """Build the SentencePiece word-level transformer.

    Returns a `WordCharTransformer` (the four-argument class of this section:
    encoder, decoder, embeddings, generator) with every parameter of more than
    one dimension Xavier-initialised.

    Args:
        vocab: vocabulary size (embedding rows and generator outputs).
        d_model: model width.
        em_sz: accepted for signature compatibility; not used here.
        N: number of encoder and decoder layers.
        drops: dropout probability passed to the layers.
        heads: number of attention heads.
    """
    c = deepcopy
    attn = MultiHeadedAttention(d_model, heads)
    ff = PositionwiseFeedForward(d_model, drops)
    v_embed = nn.Embedding(vocab, d_model, 0)
    row_emb = nn.Embedding(15, d_model//2, 0)

    # WordCharTransformer (not EncoderDecoder): it matches this argument list and
    # right-shifts targets with this section's start token instead of BERT's 101
    model = WordCharTransformer(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        LearnedPositionalEmbeddings(d_model, v_embed, row_emb, nn.Embedding(60,  d_model//2, 0)),  #word
        nn.Linear(d_model, vocab),
    )

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return model

In [None]:
class Img2Seq(nn.Module):
    """Image-to-sequence model whose arg-max token predictions are passed through a language model.

    Args:
        img_encoder: module producing image features from `src`.
        adaptor: module reshaping those features for the transformer.
        transformer: module called as `transformer(feats, tgt)` and exposing `generate`.
        lm: language model called on token ids; its first output is returned.
    """
    def __init__(self, img_encoder, adaptor, transformer, lm):
        super(Img2Seq, self).__init__()
        self.img_enc = img_encoder
        self.adaptor = adaptor
        self.transformer = transformer
        self.lm = lm

    def forward(self, src, tgt):
        "Return the first element of `lm(ids)`, where `ids` are the arg-max token ids of the generator output."
        feats = self.adaptor(self.img_enc(src))
        outs = self.transformer(feats, tgt)
        x = self.transformer.generate(outs)
        with torch.no_grad():
            # arg-max is not differentiable, so no graph is needed for it
            x = torch.argmax(x, dim=-1)

        # index [0] so the loss receives a tensor rather than the LM's whole output
        return self.lm(x)[0]

In [None]:
# DistilBERT tokenizer extended with the extra tokens used by this notebook
bert_tok = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
num_added_tokens = bert_tok.add_tokens(['\n',' ','[UP]','[MAJ]'])
vocab_len = len(bert_tok)

# pretrained masked LM; its token embeddings are resized to the extended tokenizer length
bert_lm = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased")
bert_lm.resize_token_embeddings(vocab_len)

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, heads=8, smoothing=0.1):
    """Build the SentencePiece word model with the module-level `bert_lm` on top and wrap it in a `Learner`.

    Args:
        data: data bunch providing `vocab` and `y.reconstruct`.
        d_model: transformer width.
        em_sz: image-encoder feature width (selects the ResNet cut).
        N: number of encoder and decoder layers.
        drops: dropout probability.
        heads: number of attention heads.
        smoothing: label-smoothing factor for the loss.
    """
    img_encoder = ResnetBase(em_sz)
    adaptor = Adaptor()
    # make_full_model takes em_sz as its third positional argument; leaving it out
    # shifts N, drops and heads one slot to the left
    transformer = make_full_model(len(data.vocab), d_model, em_sz, N, drops, heads)
    # the language model is the module-level `bert_lm` created in the cell above
    net = Img2Seq(img_encoder, adaptor, transformer, bert_lm)
    learn = Learner(data, net, loss_func=LabelSmoothing(smoothing),
                    metrics=[CER(data.y.reconstruct)], callback_fns=[TeacherForce])
    return learn

In [None]:
# build the SentencePiece word learner with d_model=512 and em_sz=256
learn = make_learner(data, 512, 256, N=4, drops=0.1, heads=8)

In [None]:
# number of trainable parameters: total element count over parameters with requires_grad=True
sum(p.numel() for p in learn.model.parameters() if p.requires_grad)

## Combo Arch

In [None]:
class LearnedPositionalEmbeddings(nn.Module):
    """Shared token embedding with learned 2-D (row, column) positions for char and word sequences.

    Every non-padding token gets a 1-based row (line number) and a 1-based
    column (offset within its line); a line ends at `nl_tok`. Row and column
    embeddings are concatenated, added to the token embedding, scaled by
    sqrt(d_model) and passed through dropout. Index 0 marks padding.

    Args:
        d_model: model width.
        vocab: size of the token embedding table.
        rows: size of the row (line) embedding table.
        cols: pair `(char_cols, word_cols)` of column-table sizes.
        dropout: dropout probability applied to the result.
    """
    def __init__(self, d_model, vocab, rows, cols, dropout=0.1):
        super(LearnedPositionalEmbeddings, self).__init__()
        self.nl_tok  = 4   # token id treated as the line break
        self.d_model = d_model

        self.embed = nn.Embedding(vocab, d_model, 0)
        self.rows = nn.Embedding(rows, d_model//2, 0)   # honour `rows` rather than a fixed 15
        self.c_cols = nn.Embedding(cols[0], d_model//2, 0)
        self.w_cols = nn.Embedding(cols[1], d_model//2, 0)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, char, word):
        "Return the pair `(char embedding, word embedding)`."
        return self.encode_one(char, 'char'), self.encode_one(word, 'word')

    def encode_one(self, x, kind):
        "Embed token ids `x`; `kind=='char'` selects the char column table, anything else the word one."
        rows,cols = self.encode_spatial_positions(x)

        x_cols = self.c_cols if kind=='char' else self.w_cols

        # clamp both indices so extra lines / long lines reuse the last learned
        # position instead of indexing past the embedding table
        row_t = self.rows(torch.clamp(rows, max=self.rows.num_embeddings-1))
        col_t = x_cols(torch.clamp(cols, max=x_cols.num_embeddings-1))
        pos_enc = torch.cat((row_t, col_t), dim=-1)

        x = self.embed(x)
        x = (x + pos_enc) * math.sqrt(self.d_model)
        return self.dropout(x)

    def encode_spatial_positions(self, x):
        "Return `(rows, cols)` index tensors shaped like `x`; positions after the last non-zero token stay 0."
        rows,cols = torch.zeros_like(x),torch.zeros_like(x)
        for b,seq in enumerate(x.unbind()):
            nonpad = torch.nonzero(seq).flatten()
            if nonpad.numel() == 0: continue   # all-padding sequence: nothing to number
            # a line ends at each newline token and at the last non-padding token
            line_ends = torch.nonzero(seq==self.nl_tok).flatten().tolist()
            line_ends.append(nonpad[-1].item())

            start = 0
            for row,end in enumerate(line_ends, start=1):
                width = end - start + 1
                if width > 0:   # width is 0 when the sequence ends on a newline
                    rows[b,start:end+1] = row
                    cols[b,start:end+1] = torch.arange(1, width+1, dtype=x.dtype, device=x.device)
                start = end + 1
        return rows,cols

In [None]:
class ResnetBase(nn.Module):
    "ResNet-18 (built with `True` as its first argument) with its trailing children removed; the cut depends on `em_sz`."
    def __init__(self, em_sz):
        super().__init__()

        # negative slice end: how many trailing ResNet children to drop for each supported em_sz
        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        backbone = models.resnet18(True)
        kept = list(backbone.children())[:cut]
        self.base = nn.Sequential(*kept)   # original note: 32x32 : 256

    def forward(self, x):
        "Run `x` through the truncated backbone."
        return self.base(x)

In [None]:
class Adaptor(nn.Module):
    "Flatten dims 2-3 of the input into one axis, move that axis before the channel axis and scale by 8."
    def forward(self, x):
        seq = torch.flatten(x, 2, 3).transpose(1, 2)
        return seq * 8

In [None]:
class WordCharTransformer(nn.Module):
    "One encoder feeding two decoders (char and word) that share an embedding module and a generator."
    def __init__(self, encoder, c_dec, w_dec, embeddings, generator):
        super().__init__()
        self.encoder = encoder
        self.c_decoder = c_dec
        self.w_decoder = w_dec
        self.embed = embeddings
        self.generator = generator

    def forward(self, src, mix_tgt):
        "Return `(char decoder outputs, word decoder outputs)` for the `(char, word)` target pair `mix_tgt`."
        c_tgt,w_tgt,c_mask,w_mask = self.shift_with_masks(mix_tgt)
        c_emb,w_emb = self.embed(c_tgt, w_tgt)   # one call embeds both targets
        memory = self.encoder(src)
        c_dec_out = self.c_decoder(c_emb, memory, c_mask)
        w_dec_out = self.w_decoder(w_emb, memory, w_mask)
        return c_dec_out, w_dec_out

    def generate(self, c_outs, w_outs):
        "Apply the shared generator to both decoder outputs."
        return self.generator(c_outs), self.generator(w_outs)

    def shift_with_masks(self, mix_tgt):
        "Right-shift both targets with start token 1 and build a `subsequent_mask` for each."
        chars, words = mix_tgt
        chars = rshift(chars, 1).long()
        words = rshift(words, 1).long()

        # alternative: parallelogram_mask(size, 20) to limit context (char needs word context, word needs sentence context)
        return chars, words, subsequent_mask(chars.size(-1)), subsequent_mask(words.size(-1))

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0, heads=8):
    "Assemble the combined `WordCharTransformer` with `LearnedPositionalEmbeddings` and Xavier-initialise every parameter with more than one dim."
    # sub-modules are created in the same order as before so seeded initialisation is unchanged
    attn = MultiHeadedAttention(d_model, heads)
    ff = PositionwiseFeedForward(d_model, drops)

    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    char_dec = Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    word_dec = Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    embeddings = LearnedPositionalEmbeddings(d_model, vocab, rows=15, cols=[100,60])
    generator = nn.Linear(d_model, vocab)
    model = WordCharTransformer(encoder, char_dec, word_dec, embeddings, generator)

    matrices = [p for p in model.parameters() if p.dim() > 1]
    for w in matrices:
        nn.init.xavier_uniform_(w)
    return model

In [None]:
class Img2Seq(nn.Module):
    "Image encoder -> adaptor -> two-headed transformer; returns the `(char, word)` generator outputs."
    def __init__(self, img_encoder, adaptor, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.adaptor = adaptor
        self.transformer = transformer

    def forward(self, src, tgt):
        "Encode `src`, decode both targets in `tgt` and apply the generator to each."
        img_feats = self.adaptor(self.img_enc(src))
        c_dec, w_dec = self.transformer(img_feats, tgt)
        return self.transformer.generate(c_dec, w_dec)

In [None]:
class MultiCER(LearnerCallback):
    "Callback that reports the epoch-level `cer` of the char and word outputs as metrics 'char' and 'word'."
    _order=-20 # Needs to run before the recorder
    def __init__(self, learn):
        super().__init__(learn)
        self.recon = learn.data.y.reconstruct_one

    def on_train_begin(self, **kwargs):
        "Register the two extra metric names."
        self.learn.recorder.add_metric_names(['char', 'word'])

    def on_epoch_begin(self, **kwargs):
        "Reset the running sums."
        self.c_errors = self.w_errors = self.total = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Accumulate the summed errors and the item count for this batch."
        c_out, w_out = last_output
        c_targ, w_targ = last_target
        c_err, n_items = cer(c_out, c_targ, self.recon)
        w_err = cer(w_out, w_targ, self.recon)[0]
        self.c_errors += c_err
        self.w_errors += w_err
        self.total += n_items

    def on_epoch_end(self, last_metrics, **kwargs):
        "Append the mean char and word error to the epoch metrics."
        means = [err/self.total for err in (self.c_errors, self.w_errors)]
        return add_metrics(last_metrics, means)

In [None]:
class MultiLabelSmoothing(nn.Module):
    "Sum of `LabelSmoothing` losses over the char and word predictions."
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, c_targ, w_targ):
        "`pred[0]` is scored against `c_targ` and `pred[1]` against `w_targ`; the two losses are added."
        criterion = LabelSmoothing(self.smoothing)
        char_loss = criterion(pred[0], c_targ)
        word_loss = criterion(pred[1], w_targ)
        return char_loss + word_loss

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, heads=8, smoothing=0.1):
    "Build the combined char/word network and wrap it in a `Learner` with the two-part loss, teacher forcing and `MultiCER`."
    net = Img2Seq(
        ResnetBase(em_sz),
        Adaptor(),
        make_full_model(len(data.vocab), d_model, N, drops, heads),
    )
    return Learner(data, net, loss_func=MultiLabelSmoothing(smoothing),
                   callback_fns=[TeacherForce, MultiCER])

In [None]:
# build the combined SentencePiece learner with d_model=512 and em_sz=512
learn = make_learner(data, 512, 512, N=4, drops=0.1, heads=8)

In [None]:
# number of trainable parameters: total element count over parameters with requires_grad=True
sum(p.numel() for p in learn.model.parameters() if p.requires_grad)

## Combo Arch w/ integrated LM

In [None]:
class LearnedPositionalEmbeddings(nn.Module):
    """Shared token embedding with learned 2-D (row, column) positions for char and word sequences.

    Every non-padding token gets a 1-based row (line number) and a 1-based
    column (offset within its line); a line ends at `nl_tok`. Row and column
    embeddings are concatenated, added to the token embedding, scaled by
    sqrt(d_model) and passed through dropout. Index 0 marks padding.

    Args:
        d_model: model width.
        vocab: size of the token embedding table.
        rows: size of the row (line) embedding table.
        cols: pair `(char_cols, word_cols)` of column-table sizes.
        dropout: dropout probability applied to the result.
    """
    def __init__(self, d_model, vocab, rows, cols, dropout=0.1):
        super(LearnedPositionalEmbeddings, self).__init__()
        self.nl_tok  = 4   # token id treated as the line break
        self.d_model = d_model

        self.embed = nn.Embedding(vocab, d_model, 0)
        self.rows = nn.Embedding(rows, d_model//2, 0)   # honour `rows` rather than a fixed 15
        self.c_cols = nn.Embedding(cols[0], d_model//2, 0)
        self.w_cols = nn.Embedding(cols[1], d_model//2, 0)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, char, word):
        "Return the pair `(char embedding, word embedding)`."
        return self.encode_one(char, 'char'), self.encode_one(word, 'word')

    def encode_one(self, x, kind):
        "Embed token ids `x`; `kind=='char'` selects the char column table, anything else the word one."
        rows,cols = self.encode_spatial_positions(x)

        x_cols = self.c_cols if kind=='char' else self.w_cols

        # clamp both indices so extra lines / long lines reuse the last learned
        # position instead of indexing past the embedding table
        row_t = self.rows(torch.clamp(rows, max=self.rows.num_embeddings-1))
        col_t = x_cols(torch.clamp(cols, max=x_cols.num_embeddings-1))
        pos_enc = torch.cat((row_t, col_t), dim=-1)

        x = self.embed(x)
        x = (x + pos_enc) * math.sqrt(self.d_model)
        return self.dropout(x)

    def encode_spatial_positions(self, x):
        "Return `(rows, cols)` index tensors shaped like `x`; positions after the last non-zero token stay 0."
        rows,cols = torch.zeros_like(x),torch.zeros_like(x)
        for b,seq in enumerate(x.unbind()):
            nonpad = torch.nonzero(seq).flatten()
            if nonpad.numel() == 0: continue   # all-padding sequence: nothing to number
            # a line ends at each newline token and at the last non-padding token
            line_ends = torch.nonzero(seq==self.nl_tok).flatten().tolist()
            line_ends.append(nonpad[-1].item())

            start = 0
            for row,end in enumerate(line_ends, start=1):
                width = end - start + 1
                if width > 0:   # width is 0 when the sequence ends on a newline
                    rows[b,start:end+1] = row
                    cols[b,start:end+1] = torch.arange(1, width+1, dtype=x.dtype, device=x.device)
                start = end + 1
        return rows,cols

In [None]:
class ResnetBase(nn.Module):
    "ResNet-18 (built with `True` as its first argument) with its trailing children removed; the cut depends on `em_sz`."
    def __init__(self, em_sz):
        super().__init__()

        # negative slice end: how many trailing ResNet children to drop for each supported em_sz
        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        backbone = models.resnet18(True)
        kept = list(backbone.children())[:cut]
        self.base = nn.Sequential(*kept)   # original note: 32x32 : 256

    def forward(self, x):
        "Run `x` through the truncated backbone."
        return self.base(x)

In [None]:
class Adaptor(nn.Module):
    "Flatten dims 2-3 of the input into one axis, move that axis before the channel axis and scale by 8."
    def forward(self, x):
        seq = torch.flatten(x, 2, 3).transpose(1, 2)
        return seq * 8

In [None]:
class WordCharTransformer(nn.Module):
    "One encoder feeding a char and a word decoder (shared embeddings and generator), plus an LM encoder that re-scores generator output."
    def __init__(self, encoder, c_dec, w_dec, embeddings, generator, lm_enc):
        super().__init__()
        self.encoder = encoder
        self.c_decoder = c_dec
        self.w_decoder = w_dec
        self.embed = embeddings
        self.generator = generator
        self.lm_enc = lm_enc

    def forward(self, src, mix_tgt):
        "Return `(char decoder outputs, word decoder outputs)` for the `(char, word)` target pair `mix_tgt`."
        c_tgt,w_tgt,c_mask,w_mask = self.shift_with_masks(mix_tgt)
        c_emb,w_emb = self.embed(c_tgt, w_tgt)   # one call embeds both targets
        memory = self.encoder(src)
        c_dec_out = self.c_decoder(c_emb, memory, c_mask)
        w_dec_out = self.w_decoder(w_emb, memory, w_mask)
        return c_dec_out, w_dec_out

    def generate(self, c_outs, w_outs):
        "Apply the shared generator to both decoder outputs."
        return self.generator(c_outs), self.generator(w_outs)

    def lm(self, x):
        "Take generator output `x` ([bs, sl, vocab]), embed its arg-max ids with the token embedding, run `lm_enc` and apply the generator again."
        with torch.no_grad():
            token_ids = x.argmax(dim=-1)
        # a differentiable alternative would multiply softmax(x) by the embedding weights
        embedded = self.embed.embed(token_ids)
        return self.generator(self.lm_enc(embedded))

    def shift_with_masks(self, mix_tgt):
        "Right-shift both targets with start token 1 and build a `subsequent_mask` for each."
        chars, words = mix_tgt
        chars = rshift(chars, 1).long()
        words = rshift(words, 1).long()

        # alternative: parallelogram_mask(size, 20) to limit context (char needs word context, word needs sentence context)
        return chars, words, subsequent_mask(chars.size(-1)), subsequent_mask(words.size(-1))

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0, heads=8):
    "Assemble the combined `WordCharTransformer` with an extra LM encoder stack and Xavier-initialise every parameter with more than one dim."
    # sub-modules are created in the same order as before so seeded initialisation is unchanged
    attn = MultiHeadedAttention(d_model, heads)
    ff = PositionwiseFeedForward(d_model, drops)

    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    char_dec = Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    word_dec = Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    embeddings = LearnedPositionalEmbeddings(d_model, vocab, rows=15, cols=[100,60])
    generator = nn.Linear(d_model, vocab)
    # LM branch: in-place scale by sqrt(d_model), positional encoding, then an N-layer encoder
    lm_enc = nn.Sequential(
        Lambda(lambda x: x.mul_(math.sqrt(d_model))),
        PositionalEncoding(d_model, drops),
        Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N),
    )
    model = WordCharTransformer(encoder, char_dec, word_dec, embeddings, generator, lm_enc)

    matrices = [p for p in model.parameters() if p.dim() > 1]
    for w in matrices:
        nn.init.xavier_uniform_(w)
    return model

In [None]:
class Img2Seq(nn.Module):
    "Image encoder -> adaptor -> two-headed transformer; returns char logits, word logits and the LM-refined word logits."
    def __init__(self, img_encoder, adaptor, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.adaptor = adaptor
        self.transformer = transformer

    def forward(self, src, tgt):
        "Return `(c_outs, w_outs, transformer.lm(w_outs))`."
        img_feats = self.adaptor(self.img_enc(src))
        c_dec, w_dec = self.transformer(img_feats, tgt)
        c_outs, w_outs = self.transformer.generate(c_dec, w_dec)
        lm_outs = self.transformer.lm(w_outs)
        return c_outs, w_outs, lm_outs

In [None]:
class MultiCER(LearnerCallback):
    """Report char, word and word-LM error rates as extra recorder metrics.

    Error counts returned by `cer` are accumulated per epoch and divided by the
    accumulated `size` that `cer` returns for the character output. Only
    non-training (validation) batches are counted, so the three columns
    ('char', 'word', 'word_lm') are not diluted by training batches.
    """
    _order=-20 # Needs to run before the recorder

    def __init__(self, learn):
        super().__init__(learn)
        # Passed to `cer` as its third argument for every comparison.
        self.recon = learn.data.y.reconstruct_one

    def on_train_begin(self, **kwargs):
        "Register the three metric column names with the recorder."
        self.learn.recorder.add_metric_names(['char', 'word', 'word_lm'])

    def on_epoch_begin(self, **kwargs):
        "Reset the per-epoch accumulators."
        self.c_errors, self.w_errors, self.lm_errors, self.total = 0, 0, 0, 0

    def on_batch_end(self, last_output, last_target, train=False, **kwargs):
        "Accumulate error counts for one validation batch."
        # fastai calls this hook for training batches too (with train=True).
        # Counting those would mix training and validation results in the
        # reported metric and run `cer` three times on every training step.
        if train: return
        c_out, w_out, w_lm = last_output
        c_targ, w_targ = last_target
        c_error,size = cer(c_out, c_targ, self.recon)
        w_error,_    = cer(w_out, w_targ, self.recon)
        lm_error,_   = cer(w_lm, w_targ, self.recon)
        self.c_errors += c_error
        self.w_errors += w_error
        self.lm_errors += lm_error
        # All three rates share the denominator reported for the char output.
        self.total += size

    def on_epoch_end(self, last_metrics, **kwargs):
        "Append the three error rates to the epoch's metrics."
        if self.total:
            rates = [errs/self.total for errs in (self.c_errors, self.w_errors, self.lm_errors)]
        else:
            # No batch was counted this epoch: report NaN rather than raising ZeroDivisionError.
            rates = [float('nan')] * 3
        return add_metrics(last_metrics, rates)

In [None]:
class MultiLabelSmoothing(nn.Module):
    """Sum of `LabelSmoothing` losses over the model's three outputs.

    `pred[0]` is scored against `c_targ`; `pred[1]` and `pred[2]` are both
    scored against `w_targ`. A fresh `LabelSmoothing(self.smoothing)` is built
    on every call, so changing `self.smoothing` later takes effect immediately.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, c_targ, w_targ):
        criterion = LabelSmoothing(self.smoothing)
        # Target used for each of the three predictions, in order.
        targets = (c_targ, w_targ, w_targ)
        losses = [criterion(pred[idx], targ) for idx, targ in enumerate(targets)]
        return losses[0] + losses[1] + losses[2]

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, heads=8, smoothing=0.1):
    """Build the `Img2Seq` network and wrap it in a fastai `Learner`.

    The network is `ResnetBase(em_sz)` -> `Adaptor()` -> the transformer from
    `make_full_model`, whose vocabulary size is `len(data.vocab)`. The learner
    uses `MultiLabelSmoothing(smoothing)` as its loss and attaches the
    `TeacherForce` and `MultiCER` callbacks.
    """
    net = Img2Seq(
        ResnetBase(em_sz),
        Adaptor(),
        make_full_model(len(data.vocab), d_model, N, drops, heads),
    )
    criterion = MultiLabelSmoothing(smoothing)
    return Learner(data, net, loss_func=criterion, callback_fns=[TeacherForce, MultiCER])

In [None]:
# Build the learner with d_model=512 and em_sz=512 (2nd/3rd positional args of make_learner).
learn = make_learner(data, 512, 512, N=4, drops=0.1, heads=8)

In [None]:
# Number of trainable parameters: total element count over all parameters
# that have requires_grad=True.
sum(p.numel() for p in learn.model.parameters() if p.requires_grad)

# Experiments

In [None]:
# Rebind the learner to the current `data` object.
learn.data = data

In [None]:
# Load the 'sm_sp10k_lpe2' checkpoint; the trailing `; None` presumably
# suppresses the notebook's echo of the call's return value.
learn.load('sm_sp10k_lpe2'); None

In [None]:
# Learning rate later passed as max_lr to fit_one_cycle.
lrs = 1e-3

In [None]:
# Run the learning-rate finder, then plot the recorded curve with suggestion=True.
learn.lr_find()
learn.recorder.plot(suggestion=True)

In [None]:
# optimizer settings

# lr_max: 1e-3
# moms: (0.95, 0.85)
# div_factor: 25.0
# pct_start: 0.3
# final_div: 250000.0
# tot_epochs: 3

# wd: 1e-2

# OptimWrapper over Adam (
# Parameter Group 0
#     amsgrad: False
#     betas: (0.95, 0.99)
#     eps: 1e-08
#     lr: 4e-09
#     weight_decay: 0

# Parameter Group 1
#     amsgrad: False
#     betas: (0.95, 0.99)
#     eps: 1e-08
#     lr: 4e-09
#     weight_decay: 0
# ).
# True weight decay: True

In [None]:
# One-cycle training for 5 epochs at max_lr=lrs; SaveModelCallback writes its
# checkpoint under the name 'hw_sp10k'. Results of past runs are logged below.
learn.fit_one_cycle(5, max_lr=lrs, callbacks=[SaveModelCallback(learn, name='hw_sp10k')])
# data: small dataset, sz:256, bs:50, sentence_piece
# train: 5cycle(1e-3)
# arch: (512, 512, N=4, drops=0.1, heads=8); subsequent_masks

# 10k, lpe, 67.7M
# 71.242447	68.054993	0.759380	0.841816	08:04   freeze(), 1cycle(1e-3)
# 12.448058	13.792800	0.082833	0.123208	08:18   unfreeze(), 5cycle(1e-4,1e-3)  'sm_sp10k_lpe'

# 10k, lpe, 67.7M, 5cycle(1e-3)   
# 10.024531	10.917119	0.059539	0.100420	08:22   'sm_sp10k_lpe2'  **
# chars:    1.71376   .10433
# words:    1.56690   .06228

# 30k, lpe, 88.2M, 5cycle(1e-3)
# 11.141268	11.952768	0.070033	0.114384	09:21   'sm_sp30k_lpe'

# 50k, lpe, 108.7M, 5cycle(1e-3)
# 12.930821	12.563131	0.097857	0.161395	11:04   'sm_sp50k_lpe2'


# 10k, lpe, parallelogram_mask, 5cycle(1e-3)
# 10.259628	11.054597	0.059808	0.101713	09:08


# w/ WORD LM
# 10k, lpe, attached word_lm, 80.3M  --  no good for inference:(
# 5.786808	6.193685	0.054422	0.026740	09:50   'sm_sp10k_lm'   **

# word_lm after generator, 80.3M
# 30.129230	29.837000	0.053065	0.672288	09:11

# word_lm after generator from embedding
# 11.316447	12.298356	0.057172	0.115761	0.006366	09:27    'sm_sp10k_lm2'
# chars:    1.79026   .10513
# words:    1.46591   .07781
#    lm:    1.57080   .07781

# word_lm from generator w/ reverse embedding
# 18.072691	18.354738	0.079624	0.202494	0.034407	10:21    'sm_sp10k_lm3'
# chars:    1.92827   .14636
# words:    1.67687   .13378
#    lm:    1.87302   .13378

# word_lm from generator, argmax, embedding, pos_enc
# 17.686855	18.746574	0.059707	0.110811	0.114993	09:31   'sm_sp10k_lm4'
# chars:    2.21205   .10050
# words:    1.32152   .07384
#    lm:    1.37493   .08308


# data: font generated, sz:512, bs:15, sentence_piece10k
# train: 3cycle(1e-3)
# arch: (512, 512, N=4, drops=0.1, heads=8); subsequent_masks

# 10k, lpe, 67.7M, combo
# 31.972498	28.244387	0.017639	0.037072	1:45:40   'font_sp10k'
# chars:    32.2701   .01937
# words:    41.7109   .02917

# word_lm from generator w/ reverse embedding - preload sm_sp10k_lm3
# 66.723541	62.563965	0.039113	0.102822	0.010772	1:59:23   stopped after 2cycles, 'font_sp10k_lm'
# chars:    49.6740   .05171
# words:    51.9776   .03500
#    lm:    59.4986   .03500

# word_lm from generator, argmax, embedding, pos_enc - preload sm_sp10k_lm4
# 84.748764	66.507683	0.034499	0.074262	0.069198	1:58:18   stopped after 2cycles, 'font_sp10k_lm2'
# chars:    73.3652   .10826
# words:    65.1936   .06696
#    lm:    65.5722   .06706


# data: handwriting, sz:512, bs:15, sentence_piece10k
# train: 5cycle(1e-3)
# arch: (512, 512, N=4, drops=0.1, heads=8); subsequent_masks

# 10k, lpe, 67.7M, combo
# 14.077939	11.943507	0.006205	0.003816	21:38   'hw_sp10k'
# chars:    50.1363   .00435
# words:    5.17031   .00043

# test pg:
# chars:    144.698   .03923
# words:    80.1309   .04630

# test upload:
# chars:    129.044   .26124
# words:    61.2923   .30572

### Previous Architecture testing

In [None]:
# small, bs: 50

# maxpool(16,None), original adaptor
# 52.609947	48.825050	0.691070	0.818217	15:10   88M; freeze(), 1e-3   
# 7.000093	8.394917	0.048174	0.052259	14:23   unfreeze(), (1e-4,1e-3)   *sm_combo_bt
# chars:    1.73652   .07256
# words:    0.89602   .03005

# adaptor from multi-resolution; remove src_adapt
# 58.962891	53.468037	0.750514	0.815653	11:11   91M; freeze(), 1e-3
# 8.015573	9.585506	0.054404	0.065465	12:05   unfreeze(), (1e-4,1e-3)   *sm_combo_bt2
# chars:    1.18025   .09011
# words:    0.96089   .05492

# adaptor: conv_layer(em_sz,em_sz,(2,1)) + lin/mul(8)
# 58.688446	54.525970	0.677837	0.841201	12:07   86M; freeze(), 1e-3
# 7.844480	9.105121	0.054534	0.064099	12:44   unfreeze(), (1e-4,1e-3)   *sm_combo_bt3

# preload pg_combo_bt4; maxpool(16,None) -> conv
# 54.845032	50.746662	0.684326	0.817202	12:39   86M; freeze(), 1e-3
# 6.565070	8.298703	0.048370	0.050501	13:18   unfreeze(), (1e-4,1e-3)   *sm_combo_bt4
# chars:    1.22381   .09270
# words:    0.94450   .06052


# pg, bs: 10

# conv -> maxpool(16,None), original + lin/mul(8)
# 1794.547241	1622.572144	1.033109	0.881144	02:04
# 1015.301025	1080.083496	0.564250	0.680124	02:09   *pg_combo_bt

# maxpool(16,None) -> conv
# 1776.183105	1585.937256	1.065259	0.884869	02:05
# 1010.742371	1077.592163	0.563446	0.678190	02:08   *pg_combo_bt4

# adaptor from multi-resolution
# 1789.552856	1590.777710	1.138431	0.911390	02:00
# 1027.244995	1084.395264	0.569967	0.679046	02:03   *pg_combo_bt2

# adaptor: conv(em_sz,em_sz,(2,1))
# 1791.784668	1596.507812	1.053432	0.880950	02:04
# 1043.639893	1096.874268	0.565953	0.697545	02:07   *pg_combo_bt3

In [None]:
# pg, bs: 10, N:4, d_model:512

# alternate LearnedPositionalEmbeddings
# 1767.251709	1570.349365	1.094300	0.893096	02:09   split(adaptor); freeze(1e-3)
# 1018.636475	1082.605713	0.556753	0.688827	02:10   unfreeze(1e-4,1e-3) alt_pg_combo_bt

# resnet18 (512, 512, N=4, drops=0.1, heads=8), 88M
# 1819.735596	1640.198608	1.139695	0.910257	01:53
# 978.301208	1062.761963	0.558057	0.660668	01:56   alt_pg_combo_bt2

# resnet34 (512, 512, N=4, drops=0.1, heads=8), 98M
# 1821.668213	1645.444458	1.102799	0.891206	01:59
# 982.210693	1061.083496	0.554228	0.664686	02:03    alt_pg_combo_bt3


# resnet18 (512, 512, N=4, drops=0.1, heads=16), 88M
# 1810.988281	1641.318848	1.107212	0.897255	01:59
# 965.121765	1060.464966	0.552585	0.658893	02:01   alt_pg_combo_bt4

# resnet18 + lin adaptor (768, 512, N=4, drops=0.1, heads=12), 162M
# 1725.002075	1577.129517	1.024268	0.871293	02:18
# 977.914246	1056.783325	0.542184	0.670891	02:22   alt_pg_combo_bt5

# resnet18 (512, 512, N=4, drops=0.1, heads=16), 88M, subsequent_masks
# 1813.572388	1627.146973	1.082598	0.888579	01:58
# 1042.711060	1122.876343	0.601146	0.677406	02:01   alt_pg_combo_bt6


# resnet18 (512, 512, N=4, drops=0.1, heads=16), 88M [mod tfms]
# 1825.002075	1646.828247	1.200644	0.911882	01:59
# 986.249390	1067.287598	0.556860	0.663619	02:01   alt_pg_combo_bt7



# SentencePiece
# resnet18 (512, 512, N=4, drops=0.1, heads=16), 88M [mod tfms]
# 1807.717773	1669.616699	1.265745	0.939794	01:35
# 944.019165	1058.646973	0.548149	0.669054	01:38   alt_pg_combo_sp

# N=6, 111M
# 1814.066284	1672.442993	1.256078	0.922056	01:54
# 961.308472	1061.681641	0.546759	0.687847	01:56   alt_pg_combo_sp2

In [None]:
# small dataset, sz:256, bs:50, 3cycle(1e-3)
# arch: (512, 512, N=4, drops=0.1, heads=8); subsequent_masks

# bert_tok, learned positional embedding
# 23.695314	22.344717	0.163726	0.301154	09:50

# bert_tok, pos_enc
# 30.034079	27.572166	0.206749	0.367589	08:39

# sentence_piece, learned positional embedding
# 23.044443	21.780771	0.161438	0.296342	09:04

# sentence_piece, pos_enc
# 31.692009	28.524887	0.213417	0.392681	07:51

# sentence_piece, lpe, 50k, 108.7M
# 24.660820	23.094442	0.174975	0.321934	10:13

# sentence_piece, lpe, 10k, 67.7M
# 17.922606	17.049728	0.113365	0.216181	08:12

# Train

In [None]:
# Load 'font_combo_sp2' with strict=False (presumably so mismatched keys
# between checkpoint and model do not raise — confirm).
learn.load('font_combo_sp2', strict=False); None

In [None]:
# Split the model into layer groups at the adaptor, then display how many
# groups resulted.
learn.split(lambda m: [m.adaptor]); None
len(learn.layer_groups)

In [None]:
# Frozen stage: freeze the learner and use a single learning rate.
# (Run either this cell or the unfreeze cell before fitting.)
learn.freeze()
lrs = 1e-3

In [None]:
# Unfrozen stage: train all layer groups with a learning-rate slice
# (1e-5 .. 1e-4) instead of a single value.
learn.unfreeze()
lrs = slice(1e-5,1e-4)

In [None]:
# Run the learning-rate finder for the current freeze state and plot it.
learn.lr_find()
learn.recorder.plot(suggestion=True)

In [None]:
# One-cycle training for 3 epochs; the checkpoint name is the same
# 'font_combo_sp2' that was loaded above, so a save replaces that file.
learn.fit_one_cycle(3, max_lr=lrs, callbacks=[SaveModelCallback(learn, name='font_combo_sp2')])

# small, bs: 50

# preload pg_combo_bt4; maxpool(16,None) -> conv
# 54.845032	50.746662	0.684326	0.817202	12:39   86M; freeze(), 1e-3
# 6.565070	8.298703	0.048370	0.050501	13:18   unfreeze(), (1e-4,1e-3)   *sm_combo_bt4
# chars:    1.22381   .09270
# words:    0.94450   .06052

# sentencepiece, modified tfms, resnet18 (512, 512, N=4, drops=0.1, heads=16), 88M
# 74.434914	70.777931	0.816833	0.876581	09:21   freeze(), 1cycle(1e-3)
# 12.942133	15.643773	0.080849	0.165932	09:28   unfreeze(), 5cycle(1e-4,1e-3)   *sm_combo_sp
# chars:    2.22918   .17094
# words:    1.13729   .17206

# sentencepiece, "", preload alt_pg_combo_sp
# 61.517582	57.855293	0.514289	0.772304	09:09   freeze()
# 11.257002	14.372008	0.071675	0.139939	09:30   unfreeze(),    *sm_combo_sp2


# font_mix, bs: 10, preload 'sm_combo_bt4'

# 75.363541	62.632046	0.038649	0.099920	2:56:03   2epochs (1e-4,1e-3)   *font_combo_bt
# 42.600681	39.815552	0.015061	0.045635	3:04:19   2epochs (1e-5,1e-4)   *font_combo_bt2
# chars:    41.8378   .00182
# words:    83.1500   .00275

# 53.768700	45.393501	0.020888	0.074038	2:12:17   4epochs(1e-4,1e-3)    sentencepiece  *font_combo_sp
# 50.356300	44.293865	0.016115	0.061809	2:18:10   1epoch(3e-6,2e-5)
# 49.612297	42.464523	0.014771	0.061766	2:17:11   2epochs(1e-5,1e-4)    *font_combo_sp2
# chars:    27.5662   .00351
# words:    90.0695   .00650
#TEST pg
# chars:    216.829   .19460
# words:    89.0265   .18346

# handwriting_mix, bs: 10, preload 'font_combo_bt2'

# 11.269932	10.436209	0.004604	0.002646	35:18   5cycle(1e-4,1e-3)    *hw_combo_bt
# chars:    85.4927   .00315
# words:    0.61776   .00065

#TEST pg
# chars:    138.678   .04029
# words:    69.4958   .05652
#TEST upload
# chars:    128.604   .23494
# words:    56.7417   .29080

# word only, preload 'hw_combo_bt'
# 4.775697	5.477561	0.000647	18:20    2cycle(2e-5)    *hw_word_bt
#     pg:    68.1078   .05712
# upload:    50.6502   .29197

# w/ pretrained LM: 'distilbert_mlm_xtra', preload: 'hw_word_bt'
# 2.109078	1.383242	0.000374	15:17     *hw_word_bt_lm
#     pg:    94.0378   .05911
# upload:    64.4596   .30308

# word_only, preload 'font_combo_bt2', bs: 16
# 5.143809	5.627481	0.000719	15:40   5cycle(1e-4,1e-3)    *hw_word_bt_alt
#     pg:    64.6734   .06081
# upload:    50.5408   .25785

# View Model Telemetry

In [None]:
class FullStats(HookCallback):
    """Record per-module activation and gradient statistics while training.

    Every flattened module that has a `weight` attribute gets a forward hook
    and a backward hook. After each training batch the (mean, std) pairs
    currently stored by the hooks are appended to `acts` / `grads`. When
    training ends the hooks are removed and both lists become tensors
    permuted to (stat, module, batch), where stat 0 is the mean and 1 the std.
    """

    def on_train_begin(self, **kwargs):
        "Select the modules to watch, register both hook sets, reset the logs."
        has_weight = lambda m: hasattr(m, 'weight')
        self.modules = list(filter(has_weight, flatten_model(self.learn.model)))
        # Backward hooks are registered before forward hooks, as before.
        self.g_hooks = Hooks(self.modules, self.g_hook, is_forward=False)
        self.a_hooks = Hooks(self.modules, self.a_hook)
        self.grads = []
        self.acts = []

    def g_hook(self, m:nn.Module, i:Tensors, o:Tensors)->Tuple[Rank0Tensor,Rank0Tensor]:
        "Mean and std of the first item yielded by `o`, as Python floats."
        first = next(o)
        mean, std = first.mean().item(), first.std().item()
        return mean, std

    def a_hook(self, m:nn.Module, i:Tensors, o:Tensors)->Tuple[Rank0Tensor,Rank0Tensor]:
        "Mean and std of the module output `o`, as Python floats."
        mean = o.mean().item()
        std = o.std().item()
        return mean, std

    def on_batch_end(self, train, **kwargs):
        "Snapshot the hooks' stored statistics on training batches only."
        if not train:
            return
        self.acts.append(self.a_hooks.stored)
        self.grads.append(self.g_hooks.stored)

    def on_train_end(self, **kwargs):
        "Detach the hooks and convert the logs to (stat, module, batch) tensors."
        self.a_hooks.remove()
        self.g_hooks.remove()
        axes = (2, 1, 0)
        self.acts = tensor(self.acts).permute(*axes)
        self.grads = tensor(self.grads).permute(*axes)

In [None]:
# Train one epoch at lr=1e-5 with FullStats attached to collect telemetry.
learn.fit(1, 1e-5, callbacks=[FullStats(learn)])#, StopAfterNBatches(n_batches=2)])

In [None]:
# Pull the tensors produced by FullStats.on_train_end; each is laid out as
# (stat, module, batch) with stat 0 = mean and stat 1 = std.
acts,grads = learn.full_stats.acts, learn.full_stats.grads
acts.shape,grads.shape

In [None]:
# Collect the names of all parameters whose name ends in 'weight'; these are
# used as labels for the per-module statistics.
names=[]
for name, param in learn.model.named_parameters():
    if name.endswith('weight'):
        names.append(name)

# NOTE(review): two names are inserted at hard-coded positions 193/194,
# presumably to keep `names` aligned with the modules hooked by FullStats.
# These indices are tied to the current architecture — verify after any change.
names.insert(193, 'transformer.w_embed.embed.weight')
names.insert(194, 'transformer.w_embed.rows.weight')
len(names)

In [None]:
# :64     img_enc
# 64:67   adaptor
# 67:84   encoder
# 84:137  c_decoder
# 137:190 w_decoder
# 190:    embeddings/generator

# Plot the activation std (stat index 1) across batches for modules 137..189,
# one line per module, labelled with the matching entries of `names`.
plt.figure(figsize=(20,10))
for l in acts[1,137:190]:
    plt.plot(l)
plt.legend(names[137:190])

In [None]:
# Average the std statistic (index 1) over the last axis (batches), giving
# one value per hooked module.
avg_act_stds_by_layer = acts[1,:].mean(-1)
avg_grad_stds_by_layer = grads[1,:].mean(-1)

In [None]:
# Batch-averaged activation std, one point per hooked module.
plt.plot(avg_act_stds_by_layer)

In [None]:
# Batch-averaged gradient std, one point per hooked module.
plt.plot(avg_grad_stds_by_layer)

In [None]:
# Print one row per name: index, name (text before any '('), average
# activation std and average gradient std, each rounded to 5 decimals.
# The backslash continuation sits inside the f-string, so the next line's
# leading spaces become part of the printed row.
for (i,mod),a,g in zip(enumerate(names), avg_act_stds_by_layer, avg_grad_stds_by_layer):
    mod_name = str(mod).split('(')[0]
    print(f"{str(i).ljust(3)} {mod_name.ljust(60)} \
            {str(round(a.item(),5)).ljust(6)} {str(round(g.item(),5)).ljust(6)}")

In [None]:
# Last batch activations by layer: for each name, print the activation mean
# (acts[0]) and std (acts[1]) taken at the final recorded batch (index -1).
# As above, the continuation line's leading spaces are part of the output.

for (i,mod),m,s in zip(enumerate(names), acts[0,:,-1], acts[1,:,-1]):
    mod_name = str(mod).split('(')[0]
    print(f"{str(i).ljust(3)} {mod_name.ljust(50)} \
            {str(round(m.item(),5)).ljust(6)}  {str(round(s.item(),5)).ljust(6)}")

# Adjust State Dict and Split

## Add LM to model state_dict

In [None]:
# Display the model (the bound `modules` method's repr includes the module tree).
learn.model.modules

In [None]:
# Load the raw 'hw_word_bt' checkpoint file onto the current device; its
# 'model' entry is edited in the cells below.
sd = torch.load(PATH/'models/hw_word_bt.pth', map_location=device)

In [None]:
# Load the separately saved language-model checkpoint onto the current device.
lm_sd = torch.load(PATH/'models/distilbert_mlm_xtra.pth', map_location=device)

In [None]:
# Inspect the parameter names stored in the language-model checkpoint.
lm_sd['model'].keys()

In [None]:
from collections import OrderedDict
# Renamed copy of the language-model weights, merged into `sd['model']` later.
model_sd = OrderedDict()

for k,v in lm_sd['model'].items():
    # NOTE(review): str.replace rewrites EVERY occurrence of 'model' in the key,
    # not only a leading prefix — confirm no key contains 'model' elsewhere.
    lm_k = k.replace('model', 'lm')
    model_sd[lm_k] = v
#     if k.startswith('transformer.encoder'):
#         c_k = k.replace('encoder', 'c_encoder')
#         w_k = k.replace('encoder', 'w_encoder')
#         model_sd[c_k] = v
#         model_sd[w_k] = v
#     else:
#         model_sd[k] = v

In [None]:
# Merge the renamed LM weights into the checkpoint's state dict; entries with
# the same key are overwritten.
sd['model'].update(model_sd)

In [None]:
# Inspect the merged key set.
sd['model'].keys()

In [None]:
# Expose the checkpoint's `c_generator` weight and bias under the `generator`
# keys as well (same tensors, not copies).
sd['model']['transformer.generator.weight'] = sd['model']['transformer.c_generator.weight']
sd['model']['transformer.generator.bias']  = sd['model']['transformer.c_generator.bias']

In [None]:
# Load the edited state dict with strict=False so missing or unexpected keys
# do not raise; the call's return value shows which keys did not match.
learn.model.load_state_dict(sd['model'], strict=False)

## load and split learner

In [None]:
# Split the model into layer groups at the transformer, then display how
# many groups resulted.
learn.split(lambda m: [m.transformer]); None
len(learn.layer_groups)

In [None]:
# psutil.virtual_memory()  #41.7

In [None]:
# Freeze layer groups up to index 1, then display whether the generator's
# weight is still trainable as a sanity check.
learn.freeze_to(1)
learn.model.transformer.generator.weight.requires_grad #.conv[0].weight.requires_grad

In [None]:
# Make every layer group trainable again.
learn.unfreeze()

# Char/Word Greedy results

In [None]:
def greedy_decode(src, model, seq_len, kind='char', bos_tok=1, lm=False):
    """Decode `src` one step at a time, always feeding back the argmax token.

    The model is switched to eval mode (and left there). Image features are
    computed once; then, for up to `seq_len` steps, the running target is
    embedded with `embed.encode_one(tgt, kind)`, decoded by the char decoder
    (`kind == 'char'`) or the word decoder (any other value), and the last
    position is passed through the generator. Decoding stops early once every
    item in the batch predicts token 0 (presumably the padding id); that final
    step's scores are still included in the result.

    Returns the per-step generator outputs stacked with the step axis at
    position 1, passed through `transformer.lm` when `lm` is true.
    """
    model.eval()
    tfmr, img_enc, adaptor = model.transformer, model.img_enc, model.adaptor
    if kind == 'char':
        decoder = tfmr.c_decoder
    else:
        decoder = tfmr.w_decoder

    with torch.no_grad():
        feats = tfmr.encoder(adaptor(img_enc(src)))
        n_items = src.size(0)
        # Every sequence starts with the begin-of-sequence token.
        tgt = torch.zeros((n_items, 1), dtype=torch.long, device=device) + bos_tok

        step_scores = []
        for _ in progress_bar(range(seq_len)):
            causal = subsequent_mask(tgt.size(-1))
            embedded = tfmr.embed.encode_one(tgt, kind)

            # Only the newest position is needed to pick the next token.
            newest = decoder(embedded, feats, causal)[:, -1]
            scores = tfmr.generator(newest)
            step_scores.append(scores)
            next_tok = scores.argmax(dim=-1, keepdim=True)
            if (next_tok == 0).all():
                break
            tgt = torch.cat((tgt, next_tok), dim=-1)

        stacked = torch.stack(step_scores).transpose(0, 1).contiguous()
        return tfmr.lm(stacked) if lm else stacked

In [None]:
# Iterator over the validation dataloader; re-run to restart from the first batch.
vdl = iter(learn.data.valid_dl)

In [None]:
# Fetch the next validation batch; re-run to advance.
x,y = next(vdl)

### Single Word

In [None]:
# Greedy word decoding with 101 as the start token, scored with the learner's
# own loss function and `cer`; both numbers are divided by the batch size.
g_preds = greedy_decode(x, learn.model, word_len, 'word', 101)
g_res = torch.argmax(g_preds, dim=-1)
g = [learn.loss_func(g_preds, y).item()/bs, cer(g_preds, y, data.y.reconstruct)[0]/bs]
# [1:7] drops the first character of the error string (the leading '0' when
# the value is below 1).
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
# Greedy predictions: show the first four batch images in a 2x2 grid, each
# titled with its reconstructed prediction.
fig, axes = plt.subplots(2,2, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    #i += 8
    p = data.y.reconstruct(g_res[i])
    ax=show_img(x[i], ax=ax, title=p)

### Single Char

In [None]:
# Greedy char decoding (default start token 1), scored with the learner's own
# loss function and `cer`; both numbers are divided by the batch size.
g_preds = greedy_decode(x, learn.model, seq_len, 'char')
g_res = torch.argmax(g_preds, dim=-1)
g = [learn.loss_func(g_preds, y).item()/bs, cer(g_preds, y, data.y.reconstruct)[0]/bs]
# [1:7] drops the first character of the error string (the leading '0' when
# the value is below 1).
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
# Greedy predictions: show the first four batch images in a 2x2 grid, each
# titled with its reconstructed prediction.
fig, axes = plt.subplots(2,2, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    p = data.y.reconstruct(g_res[i])
    ax=show_img(x[i], ax=ax, title=p)

### Combo Chars

In [None]:
# Combo model, char branch: targets come as a pair, so y[0] is scored with a
# plain LabelSmoothing loss instead of the learner's multi-output loss.
g_preds = greedy_decode(x, learn.model, seq_len, 'char', 1)
g_res = torch.argmax(g_preds, dim=-1)
loss_func = LabelSmoothing()
g = [loss_func(g_preds, y[0]).item()/bs, cer(g_preds, y[0], data.y.reconstruct_one)[0]/bs]
# [1:7] drops the first character of the error string (the leading '0' when
# the value is below 1).
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
# Greedy predictions: show the first six batch images in a 3x2 grid, each
# titled with its reconstructed prediction.
fig, axes = plt.subplots(3,2, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    p = data.y.reconstruct_one(g_res[i])
    ax=show_img(x[i], ax=ax, title=p)

### Combo Words

In [None]:
# Combo model, word branch: y[1] is scored with a plain LabelSmoothing loss;
# lm=False skips the extra pass through the language-model head.
g_preds = greedy_decode(x, learn.model, word_len, 'word', 1, lm=False)
g_res = torch.argmax(g_preds, dim=-1)
loss_func = LabelSmoothing()
g = [loss_func(g_preds, y[1]).item()/bs, cer(g_preds, y[1], data.y.reconstruct_one)[0]/bs]
# [1:7] drops the first character of the error string (the leading '0' when
# the value is below 1).
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
# Greedy predictions: show the first six batch images in a 3x2 grid, each
# titled with its reconstructed prediction.
fig, axes = plt.subplots(3,2, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    p = data.y.reconstruct_one(g_res[i])
    ax=show_img(x[i], ax=ax, title=p)

# Test

In [None]:
# Test set option 1: uploaded images. Run this cell OR the 'paragraphs' cell;
# both set the same globals (FOLDER, df, sz, bs, seq_len).
FOLDER = 'uploads'
df = pd.read_csv(PATH/'uploads.csv')
# Not the cell's last expression, so this value is not displayed.
len(df)

# Image size and batch size for the test databunch; seq_len caps greedy decoding steps.
sz,bs = 512,14
seq_len = 700

In [None]:
# Test set option 2: paragraph images. Run this cell OR the 'uploads' cell;
# both set the same globals (FOLDER, df, sz, bs, seq_len).
FOLDER = 'paragraphs'
df = pd.read_csv(PATH/'test_pg.csv')
# Not the cell's last expression, so this value is not displayed.
len(df)

# Image size and batch size for the test databunch; seq_len caps greedy decoding steps.
sz,bs = 512,15
seq_len = 700

In [None]:
# combo(bert_tok): test databunch for the combined model using the BERT
# vocabulary. No train/valid split and no augmentation; images are squished
# to `sz`. One of three alternative `test_data` definitions — run only one.
test_data = (ImageMultiList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_none()
        .label_from_df(label_cls=MultiSequenceList, vocab=BertVocab(), pad_idx=0)
        .transform([], size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=multi_label_collater)
       )

In [None]:
# words(bert_tok): test databunch for the word-only model (plain ImageList
# and `label_collater`). No split, no augmentation; images squished to `sz`.
# NOTE(review): label_cls is still MultiSequenceList here — confirm that is
# intended for the single-target collater.
test_data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_none()
        .label_from_df(label_cls=MultiSequenceList, vocab=BertVocab(), pad_idx=0)
        .transform([], size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
       )

In [None]:
# sentencepiece combo: test databunch for the combined model, labelled via
# SPMMultiList with the `sp` object. No split, no augmentation; images
# squished to `sz`.
test_data = (ImageMultiList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_none()
        .label_from_df(label_cls=SPMMultiList, sp=sp)
        .transform([], size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=multi_label_collater)
       )

In [None]:
# split_none() puts every item in the training set, so the test batches are
# read from train_dl.
dl = iter(test_data.train_dl)

In [None]:
# Fetch the next test batch; re-run to advance.
x,y = next(dl)