# Prelims

In [None]:
### RuntimeError: cuda runtime error (59) : device-side assert triggered ###

# import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
from fastai import *
from fastai.vision import *
from fastai.text import *
from fastai.callbacks.tracker import *

import pdb

In [None]:
from transformers import DistilBertTokenizer, DistilBertForMaskedLM

In [None]:
# Root directory of the dataset; the label CSVs, vocabulary pickles and image folders used below are all resolved against it.
PATH = Path('data/IAM_handwriting')

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# The mask helpers and the databunch calls below read this global.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
def show_img(im, figsize=None, ax=None, alpha=None, title=None):
    """Draw `im` on a matplotlib axis and return that axis.

    A new figure of size `figsize` is created only when the caller does not pass `ax`.
    `im.data` is converted with `image2np` before plotting, `alpha` is forwarded to
    `imshow`, and `title` (when truthy) is set as the axis title.
    """
    if ax is None:  # identity test: do not depend on the truthiness of an Axes object
        _, ax = plt.subplots(figsize=figsize)
    ax.imshow(image2np(im.data), alpha=alpha)
    if title: ax.set_title(title)
    return ax

In [None]:
def rshift(tgt, bos_token=1):
    """Shift `tgt` one step to the right by prepending `bos_token` and dropping the last column.

    `tgt` is indexed as (rows, columns); the result has the same shape, dtype and device.
    The BOS column is allocated from `tgt` itself, so the function does not depend on the
    notebook-level `device` global.
    """
    bos = tgt.new_full((tgt.size(0), 1), bos_token)
    return torch.cat((bos, tgt[:, :-1]), dim=-1)

def subsequent_mask(size):
    "Lower-triangular uint8 mask of shape (1, size, size): entry (i, j) is 1 when j <= i."
    ones = torch.ones((1, size, size), device=device)
    return torch.tril(ones.byte())
    
def parallelogram_mask(size, diagonal):
    "Banded uint8 mask of shape (1, size, size): entry (i, j) is 1 when i - diagonal <= j <= i."
    ones = torch.ones((1, size, size), device=device).byte()
    at_or_below_diag = torch.tril(ones).bool()                 # j <= i
    within_band = torch.triu(ones, diagonal=-diagonal).bool()  # j >= i - diagonal
    band = at_or_below_diag & within_band
    return band.byte()

## Loss, Metrics, Callbacks

In [None]:
class LabelSmoothing(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    The target token receives probability `1 - smoothing`; every vocabulary entry is first
    filled with `smoothing / vocab_size`. The summed KL divergence is divided by the number
    of sequences in the batch.
    """
    def __init__(self, smoothing=0.1):
        super(LabelSmoothing, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, pred, target):
        """Return the smoothed loss for `pred` (bs, sl, vocab) and `target` (bs, tsl)."""
        # Normalise by the size of *this* batch rather than a notebook-level `bs`,
        # so a smaller final batch is scaled correctly.
        n_seqs = target.size(0)
        pred, targ = self.loss_prep(pred, target)
        pred = F.log_softmax(pred, dim=-1)  # F.kl_div expects log-probabilities as input
        # Smoothed distribution: uniform floor everywhere, `confidence` at the target index.
        true_dist = torch.full_like(pred, self.smoothing / pred.size(1))
        true_dist.scatter_(1, targ.unsqueeze(1), self.confidence)
        return F.kl_div(pred, true_dist, reduction='sum') / n_seqs

    def loss_prep(self, pred, target):
        "Equalize input/target sequence length, then merge the batch and sequence dimensions."
        bs, tsl = target.shape
        _, sl, vocab = pred.shape

        # Zero-pad whichever of the two is shorter, at the end of its sequence dimension.
        if sl > tsl: target = F.pad(target, (0, sl - tsl))
        if tsl > sl: pred = F.pad(pred, (0, 0, 0, tsl - sl))

        targ = target.contiguous().view(-1).long()
        pred = pred.contiguous().view(-1, vocab)
        return pred, targ

In [None]:
import Levenshtein as Lev

class CER(Callback):
    "Metric callback reporting the epoch mean of the per-sample error ratio computed by `cer`."
    def __init__(self, fn):
        super().__init__()
        self.fn = fn        # decoder handed to `cer` to turn a row of ids into text
        self.name = 'cer'

    def on_epoch_begin(self, **kwargs):
        "Reset the running sums."
        self.errors = 0
        self.total = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Accumulate the summed error ratio and the sample count of this batch."
        batch_error, batch_size = cer(last_output, last_target, self.fn)
        self.errors = self.errors + batch_error
        self.total = self.total + batch_size

    def on_epoch_end(self, last_metrics, **kwargs):
        "Append errors/total to the epoch metrics."
        epoch_value = self.errors / self.total
        return add_metrics(last_metrics, epoch_value)

In [None]:
def cer(preds, targs, fn):
    """Sum the per-sample character error ratios of a batch.

    Each prediction row is the argmax over the last dimension of `preds`; `fn` decodes a row
    of ids to text. The Levenshtein distance is divided by the target length (1 for an empty
    target). Returns `(summed_ratio, batch_size)`.
    """
    n_items = targs.size(0)
    best = torch.argmax(preds, dim=-1)
    total = 0
    for idx in range(n_items):
        hyp = str(fn(best[idx]))
        ref = str(fn(targs[idx]))
        denom = len(ref) if len(ref) else 1
        total += Lev.distance(ref, hyp) / denom
    return total, n_items

def label_text(pred, itos, sep=''):
    "Decode `pred` to text: drop every zero id, map the remaining ids through `itos`, join with `sep`."
    ids = to_np(pred).astype(int)
    keep = np.nonzero(ids)  # positions of non-zero ids (the eos token is not stripped)
    tokens = [itos[i] for i in ids[keep]]
    return sep.join(tokens)

In [None]:
class TeacherForce(LearnerCallback):
    "Feed the target to the model alongside the input: the model receives `(input, target)` as its input."
    def __init__(self, learn:Learner):
        super().__init__(learn)

    def on_batch_begin(self, last_input, last_target, **kwargs):
        "Replace the batch input with the (input, target) pair; the target itself is unchanged."
        paired_input = (last_input, last_target)
        return dict(last_input=paired_input, last_target=last_target)

# Data

## paragraphs

In [None]:
# Paragraph dataset: labels are read from `edited_pg.csv`; FOLDER is the image folder later passed to ImageList.from_df.
fname = 'edited_pg.csv' #'small_synth_words.csv'
CSV = PATH/fname
FOLDER = 'paragraphs'

df = pd.read_csv(CSV)
len(df)

In [None]:
# sz is the image size given to `.transform(size=sz)`; bs is the batch size given to `.databunch(bs=bs)`.
sz,bs = 512,15
# NOTE(review): seq_len/word_len are not read anywhere in the visible part of the notebook — presumably sequence-length limits used by the model code; confirm.
seq_len,word_len = 700,300

## sm synth dataset

In [None]:
# Small synthetic dataset: labels from `edited_sm_synth.csv`, images from the `edited_sm_synth` folder.
# Running this cell rebinds CSV, FOLDER and df set by any earlier dataset cell.
fname = 'edited_sm_synth.csv' #'small_synth_words.csv'
CSV = PATH/fname
FOLDER = 'edited_sm_synth'

df = pd.read_csv(CSV)
len(df)

In [None]:
# Image size / batch size for the small synthetic set (the 128px setting is kept commented out as an alternative).
# sz,bs = 128,100
sz,bs = 256,100
# NOTE(review): seq_len/word_len are not read in the visible part of the notebook — confirm their consumer.
seq_len,word_len = 100,50

## combo

In [None]:
# Combined dataset: labels from `combo_145k.csv`, images from the `combo_cat` folder.
fname = 'combo_145k.csv'
FOLDER = 'combo_cat'

CSV = PATH/fname
df = pd.read_csv(CSV)

In [None]:
# full
# Settings for the unfiltered dataframe: this cell does not filter `df`.
sz,bs = 512,20
seq_len,word_len = 750,300
len(df)

In [None]:
# 6 and fewer
# Keeps only rows with num_lines <= 6. `df` is overwritten in place, so reload the CSV
# before switching to a different subset (running the ">= 6" cell afterwards leaves only num_lines == 6).
df = df[df.num_lines <= 6]
sz,bs = 512,10#15
seq_len,word_len = 600,200
len(df)

In [None]:
# 6 and greater
# Keeps only rows with num_lines >= 6. `df` is overwritten in place, so reload the CSV
# before switching to a different subset.
df = df[df.num_lines >= 6]
sz,bs = 512,15
seq_len,word_len = 750,300
len(df)

## test combo (no fonts)

In [None]:
# Test set: labels from `test_combo.csv`, images from the `test_combo` folder.
FOLDER = 'test_combo'
CSV = PATH/'test_combo.csv'
df = pd.read_csv(CSV)

# Image size / batch size used when building the databunch for this set.
sz,bs = 512,10
seq_len,word_len = 750,300
len(df)

# ModelData

In [None]:
# Augmentation set shared by every databunch below: no flips, max_zoom=1, max_rotate=0, max_warp=0.1.
tfms = get_transforms(do_flip=False, max_zoom=1, max_rotate=0, max_warp=0.1)

def force_gray(image):
    "Convert `image` to mode 'L' and back to 'RGB' (used as the `after_open` hook of the image lists)."
    grey = image.convert('L')
    return grey.convert('RGB')

## Char or Word

In [None]:
def label_collater(samples:BatchSamples, pad_idx:int=0):
    """Collate (image, label) samples into a stacked image tensor and an end-padded label tensor.

    Labels are padded with `pad_idx` to the longest label in the batch plus one extra
    column (room for the bos token). A single sample whose label is the scalar
    placeholder 0 (the predict path) gets a (1, 1) zero label tensor instead.
    """
    data = to_data(samples)
    ims, lbls = zip(*data)
    imgs = torch.stack(list(ims))
    # predict: compare by value; identity tests against int literals are implementation-dependent.
    # The isinstance guard keeps array labels out of the `==` comparison.
    if len(data) == 1 and isinstance(lbls[0], int) and lbls[0] == 0:
        labels = torch.zeros(1,1).long()
        return imgs, labels
    max_len = max(len(s) for s in lbls)
    labels = torch.zeros(len(data), max_len+1).long() + pad_idx  # add 1 to max_len to account for bos token
    for i,lbl in enumerate(lbls):
        labels[i,:len(lbl)] = torch.from_numpy(lbl)  # padding goes at the end
    return imgs, labels

In [None]:
class SequenceList(TextList):    
    "Label list that tokenizes with the given tokenizer class (no rules, no bos token) and numericalizes with `vocab`."
    def __init__(self, items:Iterator, vocab:Vocab, tokenizer:Tokenizer, **kwargs):
        # `tokenizer` becomes the Tokenizer's `tok_func`; pre/post rules and special cases are all switched off.
        toknizr = Tokenizer(tok_func=tokenizer, pre_rules=[], post_rules=[], special_cases=[])
        procs = [TokenizeProcessor(tokenizer=toknizr, include_bos=False), NumericalizeProcessor(vocab=vocab)]
        # NOTE(review): **kwargs is accepted but not forwarded to TextList.__init__ — confirm nothing the caller passes is needed.
        super().__init__(items, vocab, sep='', pad_idx=0, processor=procs)
    
    def analyze_pred(self, pred:Tensor):
        # Greedy decode: index of the largest value along the last dimension.
        return torch.argmax(pred, dim=-1)

### Chars

In [None]:
# Character vocabulary (index -> string). The context manager closes the file handle,
# which the previous bare `open(...)` never did.
with open(PATH/'itos.pkl', 'rb') as f:
    itos = pickle.load(f)

In [None]:
class CharTokenizer(BaseTokenizer):
    "Split text into single characters and append the 'xxeos' marker."
    def tokenizer(self, t:str) -> List[str]:
        return [*t, 'xxeos']
            
class CharVocab(Vocab):
    "Vocabulary over a fixed `itos` list; strings missing from it map to index 3."
    def __init__(self, itos:Collection[str]):
        self.itos = itos
        known = {tok: idx for idx, tok in enumerate(self.itos)}
        # NOTE(review): index 3 is presumably the unknown token — confirm against itos.pkl
        self.stoi = collections.defaultdict(lambda: 3, known)

    def textify(self, nums:Collection[int], sep=''):
        "Join the tokens for `nums` with `sep`, then strip special markers via `remove_special_toks`."
        joined = sep.join(self.itos[n] for n in nums)
        return remove_special_toks(joined)
    
def remove_special_toks(text):
    "Delete every 'xxbos' and 'xxeos' marker from `text`."
    for marker in (r'(xxbos)', r'(xxeos)'):
        text = re.sub(marker, '', text)
    return text

In [None]:
# Character-level databunch: greyscale images, 15% random validation split (fixed seed),
# labels tokenized per character, images squished to sz, labels padded by `label_collater`.
data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=SequenceList, vocab=CharVocab(itos), tokenizer=CharTokenizer)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
#         .transform(tfms, size=sz, resize_method=ResizeMethod.PAD, padding_mode='border')
        # maintains aspect ratio but too small for good results => mostly whitespace
        .databunch(bs=bs, device=device, collate_fn=label_collater)
        #.normalize() # this sets x values to an odd range (~.3,-6)
       )

In [None]:
# Visual sanity check of a training batch from the character-level databunch.
data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

### Words

In [None]:
# Word vocabulary (index -> string). The context manager closes the file handle,
# which the previous bare `open(...)` never did.
with open(PATH/'word_itos_60k_mod.pkl', 'rb') as f:
    word_itos = pickle.load(f)
#word_itos = word_itos[:10000]
len(word_itos)

In [None]:
class WordTokenizer(BaseTokenizer):
    "Lower-case the text, emit each run of alphanumerics as one token and every other character as its own token, then append 'xxeos'."
    def tokenizer(self, t:str) -> List[str]:
        out, word = [], []
        for ch in t.lower():
            if ch.isalnum():
                word.append(ch)
                continue
            # a non-alphanumeric character closes the current word (if any) and is kept as a token
            if word:
                out.append(''.join(word))
                word = []
            out.append(ch)
        if word:
            out.append(''.join(word))
        out.append('xxeos')
        return out

# class WordTokenizer(BaseTokenizer):
#     def tokenizer(self, t:str) -> List[str]: return t.lower().replace('\n', ' \n ').split(' ') + ['xxeos']
            
class WordVocab(Vocab):
    "Vocabulary over a fixed `itos` list; strings missing from it map to index 3."
    def __init__(self, itos:Collection[str]):
        self.itos = itos
        # NOTE(review): index 3 is presumably the unknown token — confirm against the pickle
        self.stoi = collections.defaultdict(lambda: 3, {v:k for k,v in enumerate(self.itos)})
    
    def textify(self, nums:Collection[int], sep=''):
        "Join the tokens for `nums` with `sep` and strip the xxbos/xxeos markers."
        st = sep.join([self.itos[i] for i in nums])
        # Use the class's own helper so decoding does not depend on whichever
        # module-level `remove_special_toks` happens to be defined at call time.
        return self.remove_special_toks(st)
    
    @staticmethod
    def remove_special_toks(text):
        "Delete every 'xxbos' and 'xxeos' marker from `text`."
        text = re.sub(r'(xxbos)', '', text)
        text = re.sub(r'(xxeos)', '', text)
        return text

In [None]:
# w_vocab = Vocab(word_itos)
# w_procs = [TokenizeProcessor(include_bos=False, include_eos=True), NumericalizeProcessor(vocab=w_vocab)]

# Word-level databunch: same pipeline as the character one, but labels are tokenized by
# WordTokenizer and numericalized with the word vocabulary.
words = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=SequenceList, vocab=WordVocab(word_itos), tokenizer=WordTokenizer)
        #.label_from_df(label_cls=TextList, pad_idx=0, vocab=w_vocab, processor=w_procs)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
       )

In [None]:
# Visual sanity check of a training batch from the word-level databunch.
words.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

### Words - Bert tokenizer (with caps)

In [None]:
# Pretrained uncased DistilBERT wordpiece tokenizer; the BertTokenizer/BertVocab classes below read this global.
bert_tok = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

In [None]:
# Register newline, space and the two capitalisation markers as extra tokens.
# BertVocab appends the same four strings, in the same order, to its itos — keep the two lists in sync.
# num_added_tokens = bert_tok.add_tokens(['\n','[UP]','[MAJ]'])
num_added_tokens = bert_tok.add_tokens(['\n',' ','[UP]','[MAJ]'])

In [None]:
def add_cap_tokens(text):  # before encode
    "Lower-case every run of A-Z in `text`, prefixing it with the marker chosen by `_replace_caps`."
    return re.sub(r'[A-Z]+', _replace_caps, text)
    
def _replace_caps(m):
    tok = '[UP]' if m.end()-m.start() > 1 else '[MAJ]'
    return tok + m.group().lower()

def remove_cap_tokens(text):  # after decode
    "Undo the capitalisation markers: '[UP]word' becomes 'WORD', and '[MAJ]' upper-cases the single character that follows it."
    def _upper_after(prefix_len):
        # drop the marker (first `prefix_len` characters of the match) and upper-case the rest
        return lambda m: m.group()[prefix_len:].upper()

    text = re.sub(r'\[UP\]\w+', _upper_after(4), text)   # whole word
    text = re.sub(r'\[MAJ\]\w?', _upper_after(5), text)  # first letter only
    return text

def remove_special_toks(text):
    "Delete '[CLS]' together with the whitespace after it and '[SEP]' together with the whitespace before it."
    for pattern in (r'\[CLS\]\s*', r'\s*\[SEP\]'):
        text = re.sub(pattern, '', text)
    return text

def remove_wordpiece_toks(text):
    "Delete every '##' wordpiece continuation marker from `text`."
    return text.replace('##', '')

In [None]:
class BertTokenizer(BaseTokenizer):
    "Tokenize with the global `bert_tok` and append a '[SEP]' end marker."
    def tokenizer(self, t:str) -> List[str]:
        return [*bert_tok.tokenize(t), "[SEP]"]

In [None]:
class BertVocab(Vocab):
    "Vocabulary made of the `bert_tok` vocab keys followed by the four added tokens; unknown strings map to id 100."
    def __init__(self):
        extra = ['\n',' ','[UP]','[MAJ]']
        self.itos = list(bert_tok.vocab.keys()) + extra
        lookup = {tok: idx for idx, tok in enumerate(self.itos)}
        # NOTE(review): 100 is presumably the id of '[UNK]' in this vocab — confirm
        self.stoi = collections.defaultdict(lambda: 100, lookup)

    def textify(self, nums:Collection[int], sep=''):
        "Join the tokens for `nums`, then drop wordpiece markers, restore capitals and strip [CLS]/[SEP]."
        text = sep.join(self.itos[n] for n in nums)
        text = remove_wordpiece_toks(text)
        text = remove_cap_tokens(text)
        return remove_special_toks(text)

In [None]:
class BertWordList(TextList):    
    "Label list that tokenizes with BertTokenizer (after `rm_useless_spaces` and `add_cap_tokens`) and numericalizes with `vocab`."
    def __init__(self, items:Iterator, vocab, **kwargs):
        # Capital letters are rewritten to [UP]/[MAJ] markers before the uncased tokenizer runs.
        toknizr = Tokenizer(tok_func=BertTokenizer, pre_rules=[rm_useless_spaces, add_cap_tokens],
                            post_rules=[], special_cases=[])
        procs = [TokenizeProcessor(tokenizer=toknizr, include_bos=False), NumericalizeProcessor(vocab=vocab)]
        # NOTE(review): **kwargs is accepted but not forwarded to TextList.__init__ — confirm that is intended.
        super().__init__(items, vocab, sep='', pad_idx=0, processor=procs)
    
    def analyze_pred(self, pred:Tensor):
        # Greedy decode: index of the largest value along the last dimension.
        return torch.argmax(pred, dim=-1)

In [None]:
# Wordpiece-level databunch. Unlike the other pipelines this one passes size=(sz,sz) with
# padding_mode='zeros' and no explicit resize_method.
data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=BertWordList, vocab=BertVocab())
        .transform(tfms, size=(sz,sz), padding_mode='zeros')
        .databunch(bs=bs, device=device, collate_fn=label_collater)
       )

In [None]:
# Visual sanity check of a training batch from the wordpiece-level databunch.
data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

### Words (bert_tok)

In [None]:
# Pretrained uncased DistilBERT wordpiece tokenizer plus four extra tokens (newline, space, [UP], [MAJ]).
# BertVocab appends the same four strings to its itos — keep the two lists in sync.
bert_tok = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
num_added_tokens = bert_tok.add_tokens(['\n',' ','[UP]','[MAJ]'])

In [None]:
def add_cap_tokens(text):  # before encode
    "Lower-case every run of A-Z in `text`, prefixing it with the marker chosen by `_replace_caps`."
    return re.sub(r'[A-Z]+', _replace_caps, text)
    
def _replace_caps(m):
    tok = '[UP]' if m.end()-m.start() > 1 else '[MAJ]'
    return tok + m.group().lower()

def remove_cap_tokens(text):  # after decode
    "Undo the capitalisation markers: '[UP]word' becomes 'WORD', and '[MAJ]' upper-cases the single character that follows it."
    def _upper_after(prefix_len):
        # drop the marker (first `prefix_len` characters of the match) and upper-case the rest
        return lambda m: m.group()[prefix_len:].upper()

    text = re.sub(r'\[UP\]\w+', _upper_after(4), text)   # whole word
    text = re.sub(r'\[MAJ\]\w?', _upper_after(5), text)  # first letter only
    return text

def remove_special_toks(text):
    "Delete '[CLS]' together with the whitespace after it and '[SEP]' together with the whitespace before it."
    for pattern in (r'\[CLS\]\s*', r'\s*\[SEP\]'):
        text = re.sub(pattern, '', text)
    return text

def remove_wordpiece_toks(text):
    "Delete every '##' wordpiece continuation marker from `text`."
    return text.replace('##', '')

In [None]:
class BertTokenizer(BaseTokenizer):
    "Tokenize with the global `bert_tok` and append a '[SEP]' end marker."
    def tokenizer(self, t:str) -> List[str]:
        return [*bert_tok.tokenize(t), "[SEP]"]

class BertVocab(Vocab):
    "Vocabulary made of the `bert_tok` vocab keys followed by the four added tokens; unknown strings map to id 100."
    def __init__(self):
        extra = ['\n',' ','[UP]','[MAJ]']
        self.itos = list(bert_tok.vocab.keys()) + extra
        lookup = {tok: idx for idx, tok in enumerate(self.itos)}
        # NOTE(review): 100 is presumably the id of '[UNK]' in this vocab — confirm
        self.stoi = collections.defaultdict(lambda: 100, lookup)

    def textify(self, nums:Collection[int], sep=''):
        "Join the tokens for `nums`, then drop wordpiece markers, restore capitals and strip [CLS]/[SEP]."
        text = sep.join(self.itos[n] for n in nums)
        text = remove_wordpiece_toks(text)
        text = remove_cap_tokens(text)
        return remove_special_toks(text)

In [None]:
class MultiTokenizeProcessor(PreProcessor):
    "Tokenize all items of a dataset with BertTokenizer, `chunksize` items at a time."
    def __init__(self, ds:ItemList=None, chunksize:int=10000):
        self.chunksize = chunksize
        # Capitals become [UP]/[MAJ] markers before tokenizing; no post rules or special cases.
        self.toknizr = Tokenizer(tok_func=BertTokenizer, pre_rules=[rm_useless_spaces, add_cap_tokens],
                                 post_rules=[], special_cases=[])

    def process_one(self, item):
        raise Exception("This isn't implemented!  I didn't think it was necessary...")

    def process(self, ds):
        "Replace `ds.items` with the token lists of every item."
        starts = range(0, len(ds), self.chunksize)
        tokens = []
        for start in progress_bar(starts, leave=False):
            chunk = ds.items[start:start+self.chunksize]
            tokens.extend(self.toknizr.process_all(chunk))
        ds.items = tokens
        
class MultiNumericalizeProcessor(PreProcessor):
    "Turn token lists into int64 id arrays using the dataset's vocab."
    def __init__(self, ds:ItemList=None):
        # `ds` defaults to None in the signature but is dereferenced here, so a dataset must be supplied.
        self.vocab = ds.vocab

    def process_one(self,item):
        ids = self.vocab.numericalize(item)
        return np.array(ids, dtype=np.int64)

    def process(self, ds):
        "Replace `ds.items` with an array of per-item id arrays."
        numericalized = [self.process_one(tokens) for tokens in ds.items]
        ds.items = array(numericalized)
        

class MultiSequenceList(TextList):
    "Label list of numericalized wordpiece sequences, processed by the two Multi* processors."
    _processor = [MultiTokenizeProcessor, MultiNumericalizeProcessor]

    def get(self, i):
        "Return item `i` as a Text holding its ids and their decoded string."
        w = self.items[i]
        return Text(w, self.vocab.textify(w))

    def reconstruct(self, t:Tensor):
        "Trim `t` to the span between its first and last non-pad id and wrap it in a Text."
        non_pad = (t != self.pad_idx).nonzero()
        if len(non_pad) == 0:
            # Only padding: min()/max() of an empty tensor would raise, so fall back to position 0.
            idx_min = idx_max = 0
        else:
            idx_min, idx_max = non_pad.min(), non_pad.max()
        seq = t[idx_min:idx_max+1]
        return Text(seq, self.vocab.textify(seq))

    def analyze_pred(self, pred:Tensor):
        # Greedy decode: index of the largest value along the last dimension.
        return torch.argmax(pred, dim=-1)

In [None]:
# Wordpiece-level databunch built with MultiSequenceList: the vocab and pad index are passed
# through label_from_df, and labels are padded by `label_collater`.
data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=MultiSequenceList, vocab=BertVocab(), pad_idx=0)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
       )

In [None]:
# Visual sanity check of a training batch from the wordpiece-level databunch.
data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

### Chars (bert_tok)

In [None]:
# Pretrained uncased DistilBERT wordpiece tokenizer plus four extra tokens (newline, space, [UP], [MAJ]).
# BertVocab appends the same four strings to its itos — keep the two lists in sync.
bert_tok = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
num_added_tokens = bert_tok.add_tokens(['\n',' ','[UP]','[MAJ]'])

In [None]:
def add_cap_tokens(text):  # before encode
    "Lower-case every run of A-Z in `text`, prefixing it with the marker chosen by `_replace_caps`."
    return re.sub(r'[A-Z]+', _replace_caps, text)
    
def _replace_caps(m):
    tok = '[UP]' if m.end()-m.start() > 1 else '[MAJ]'
    return tok + m.group().lower()

def remove_cap_tokens(text):  # after decode
    "Undo the capitalisation markers: '[UP]word' becomes 'WORD', and '[MAJ]' upper-cases the single character that follows it."
    def _upper_after(prefix_len):
        # drop the marker (first `prefix_len` characters of the match) and upper-case the rest
        return lambda m: m.group()[prefix_len:].upper()

    text = re.sub(r'\[UP\]\w+', _upper_after(4), text)   # whole word
    text = re.sub(r'\[MAJ\]\w?', _upper_after(5), text)  # first letter only
    return text

def remove_special_toks(text):
    "Delete '[CLS]' together with the whitespace after it and '[SEP]' together with the whitespace before it."
    for pattern in (r'\[CLS\]\s*', r'\s*\[SEP\]'):
        text = re.sub(pattern, '', text)
    return text

def remove_wordpiece_toks(text):
    "Delete every '##' wordpiece continuation marker from `text`."
    return text.replace('##', '')

In [None]:
class BertTokenizer(BaseTokenizer):
    "Tokenize with the global `bert_tok` and append a '[SEP]' end marker."
    def tokenizer(self, t:str) -> List[str]:
        return [*bert_tok.tokenize(t), "[SEP]"]

class BertVocab(Vocab):
    "Vocabulary made of the `bert_tok` vocab keys followed by the four added tokens; unknown strings map to id 100."
    def __init__(self):
        extra = ['\n',' ','[UP]','[MAJ]']
        self.itos = list(bert_tok.vocab.keys()) + extra
        lookup = {tok: idx for idx, tok in enumerate(self.itos)}
        # NOTE(review): 100 is presumably the id of '[UNK]' in this vocab — confirm
        self.stoi = collections.defaultdict(lambda: 100, lookup)

    def textify(self, nums:Collection[int], sep=''):
        "Join the tokens for `nums`, then drop wordpiece markers, restore capitals and strip [CLS]/[SEP]."
        text = sep.join(self.itos[n] for n in nums)
        text = remove_wordpiece_toks(text)
        text = remove_cap_tokens(text)
        return remove_special_toks(text)

In [None]:
def characterize(x:Collection[str]) -> Collection[str]:
    "Split word tokens into letters; the listed special tokens stay whole and a leading '##' is dropped."
    keep_whole = ['\n',' ','[UP]','[MAJ]','[UNK]', '[CLS]', '[MASK]', '[PAD]', '[SEP]']
    chars = []
    for tok in x:
        if tok in keep_whole:
            chars.append(tok)
            continue
        body = tok[2:] if tok.startswith('##') else tok  # strip the wordpiece marker
        chars.extend(body)
    return chars

In [None]:
class MultiTokenizeProcessor(PreProcessor):
    "Tokenize all items with BertTokenizer and split the result into characters, `chunksize` items at a time."
    def __init__(self, ds:ItemList=None, chunksize:int=10000):
        self.chunksize = chunksize
        # Capitals become [UP]/[MAJ] markers before tokenizing; `characterize` runs as the only post rule.
        self.toknizr = Tokenizer(tok_func=BertTokenizer, pre_rules=[rm_useless_spaces, add_cap_tokens],
                                 post_rules=[characterize], special_cases=[])

    def process_one(self, item):
        raise Exception("This isn't implemented!  I didn't think it was necessary...")

    def process(self, ds):
        "Replace `ds.items` with the token lists of every item."
        starts = range(0, len(ds), self.chunksize)
        tokens = []
        for start in progress_bar(starts, leave=False):
            chunk = ds.items[start:start+self.chunksize]
            tokens.extend(self.toknizr.process_all(chunk))
        ds.items = tokens
        
class MultiNumericalizeProcessor(PreProcessor):
    "Turn token lists into int64 id arrays using the dataset's vocab."
    def __init__(self, ds:ItemList=None):
        # `ds` defaults to None in the signature but is dereferenced here, so a dataset must be supplied.
        self.vocab = ds.vocab

    def process_one(self,item):
        ids = self.vocab.numericalize(item)
        return np.array(ids, dtype=np.int64)

    def process(self, ds):
        "Replace `ds.items` with an array of per-item id arrays."
        numericalized = [self.process_one(tokens) for tokens in ds.items]
        ds.items = array(numericalized)
        

class MultiSequenceList(TextList):
    "Label list of numericalized character sequences, processed by the two Multi* processors."
    _processor = [MultiTokenizeProcessor, MultiNumericalizeProcessor]

    def get(self, i):
        "Return item `i` as a Text holding its ids and their decoded string."
        w = self.items[i]
        return Text(w, self.vocab.textify(w))

    def reconstruct(self, t:Tensor):
        "Trim `t` to the span between its first and last non-pad id and wrap it in a Text."
        non_pad = (t != self.pad_idx).nonzero()
        if len(non_pad) == 0:
            # Only padding: min()/max() of an empty tensor would raise, so fall back to position 0.
            idx_min = idx_max = 0
        else:
            idx_min, idx_max = non_pad.min(), non_pad.max()
        seq = t[idx_min:idx_max+1]
        return Text(seq, self.vocab.textify(seq))

    def analyze_pred(self, pred:Tensor):
        # Greedy decode: index of the largest value along the last dimension.
        return torch.argmax(pred, dim=-1)

In [None]:
# Character-level databunch on top of the wordpiece tokenizer: MultiSequenceList here uses the
# processor whose post rule is `characterize`.
data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=MultiSequenceList, vocab=BertVocab(), pad_idx=0)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
       )

In [None]:
# Visual sanity check of a training batch from the character-level (bert_tok) databunch.
data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

## Char and Word

### Trio w/ Bert tokenizer

In [None]:
class CustomVocab(Vocab):
    "Vocabulary over a fixed `itos` list; strings missing from it map to index 3."
    def __init__(self, itos:Collection[str]):
        self.itos = itos
        known = {tok: idx for idx, tok in enumerate(self.itos)}
        # NOTE(review): index 3 is presumably the unknown token — confirm against the pickle
        self.stoi = collections.defaultdict(lambda: 3, known)

    def textify(self, nums:Collection[int], sep=''):
        "Return the tokens for `nums` joined with `sep`, or as a plain list when `sep` is None."
        toks = [self.itos[n] for n in nums]
        return toks if sep is None else sep.join(toks)

In [None]:
class CustomTokenizer(BaseTokenizer):
    "Split on words but keep original spacing: alphanumeric runs become one token, every other character its own token, wrapped in 'xxbos' ... 'xxeos'."
    def tokenizer(self, t:str) -> List[str]:
        out, word = ['xxbos'], []
        for ch in t:
            if ch.isalnum():
                word.append(ch)
                continue
            # a non-alphanumeric character closes the current word (if any) and is kept as a token
            if word:
                out.append(''.join(word))
                word = []
            out.append(ch)
        if word:
            out.append(''.join(word))
        out.append('xxeos')
        return out

In [None]:
# Fresh pretrained tokenizer for this section; no extra tokens are added here, unlike the other sections.
bert_tok = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

In [None]:
class BertTokenizer(BaseTokenizer): 
    "Wrap a wordpiece tokenizer instance and surround its output with '[CLS]' and '[SEP]'."
    def __init__(self, tokenizer=bert_tok, **kwargs):
        # The default is bound to whatever `bert_tok` is when this class statement runs;
        # re-run the cell after reloading `bert_tok`. Extra keyword arguments are ignored.
        self._tok_func = tokenizer
        
    # Calling the instance returns the instance itself.
    # NOTE(review): presumably so that fastai's Tokenizer, which calls `tok_func(...)`, gets this object back — confirm.
    def __call__(self, *args, **kwargs): return self 
    
    def tokenizer(self, t:str) -> List[str]: 
        # Wordpieces of `t` wrapped in the [CLS] ... [SEP] markers.
        return ["[CLS]"] + self._tok_func.tokenize(t) + ["[SEP]"]

In [None]:
# Shared char/word vocabulary, truncated to its first 30000 entries. This rebinds the global
# `itos`. The context manager closes the file handle, which the previous bare `open(...)` never did.
with open(PATH/'combo_itos_60k.pkl', 'rb') as f:
    itos = pickle.load(f)
itos = itos[:30000]
#vocab = CustomVocab(itos)

In [None]:
def characterize(x:Collection[str]) -> Collection[str]:
    """Flatten a list of tokens into the list of their characters.

    NOTE(review): every token is split, including markers such as 'xxmaj'/'xxup'; the
    earlier docstring said those were kept whole — confirm which behaviour is intended.
    """
    return [ch for tok in x for ch in tok]

In [None]:
def multi_label_collater(samples:BatchSamples):
    """Collate (image, [char, word, bert] labels) samples into stacked images and three padded label tensors.

    A single sample whose label is the scalar placeholder 0 (the predict path, as in
    `label_collater`) yields a (1, 1) zero label tensor instead of the three-tensor tuple.
    """
    data = to_data(samples)
    ims, lbls = zip(*data)
    imgs = torch.stack(list(ims))
    # The placeholder must be detected before the unzip below, which cannot iterate over an
    # int. Testing for the placeholder itself (not just a batch of one) keeps the real labels
    # of a genuine single-sample batch.
    if len(data) == 1 and isinstance(lbls[0], int) and lbls[0] == 0:
        labels = torch.zeros(1,1).long()
        return imgs, labels
    char_lbls, word_lbls, bert_lbls = zip(*lbls)
    return imgs, (c_pad(char_lbls), c_pad(word_lbls), c_pad(bert_lbls))
    
def c_pad(lbls, pad_idx=0):
    "Stack numpy id arrays into one long tensor, end-padded with `pad_idx` to the longest length plus one (room for the bos token)."
    widest = max(len(seq) for seq in lbls)
    out = torch.zeros(len(lbls), widest + 1, dtype=torch.long) + pad_idx
    for row, seq in enumerate(lbls):
        out[row, :len(seq)] = torch.from_numpy(seq)  # ids first, padding stays at the end
    return out

In [None]:
class MultiTokenizeProcessor(PreProcessor):
    "Tokenize every item three ways (characters, words, wordpieces) and store the results as triples."
    def __init__(self, ds:ItemList=None, chunksize:int=10000):
        # Character level: word split followed by caps handling, then split into characters.
        self.c_toknizr = Tokenizer(tok_func=CustomTokenizer, pre_rules=[rm_useless_spaces],
                                   post_rules=[replace_all_caps, deal_caps, characterize],
                                   special_cases=['xxbos','xxeos','xxmask','xxunk','xxpad','xxmaj','xxup','\n'])
        # Word level: same word split; post rules are left at the Tokenizer default.
        self.w_toknizr = Tokenizer(tok_func=CustomTokenizer, pre_rules=[rm_useless_spaces],
                                   special_cases=['xxbos','xxeos','xxmask','xxunk','xxpad','xxmaj','xxup','\n'])
        # Wordpiece level. The special cases are two separate strings (previously a single
        # '[CLS], [SEP]' string, which named neither token).
        self.b_toknizr = Tokenizer(tok_func=BertTokenizer(bert_tok),
                           pre_rules=[], post_rules=[], special_cases=['[CLS]', '[SEP]'])
        self.chunksize = chunksize
        
    def process_one(self, item):
        raise Exception("This isn't implemented!  I didn't think it was necessary...")
    
    def process(self, ds):
        "Replace `ds.items` with a list of (char_tokens, word_tokens, bert_tokens) triples."
        ds.items = list(zip(self._p(self.c_toknizr, ds), self._p(self.w_toknizr, ds), self._p(self.b_toknizr, ds)))
    
    def _p(self, toknizr, ds):
        "Run `toknizr` over `ds.items` in chunks of `chunksize` and return all token lists."
        res = []
        for i in progress_bar(range(0,len(ds),self.chunksize), leave=False):
            res += toknizr.process_all(ds.items[i:i+self.chunksize])
        return res
        
class MultiNumericalizeProcessor(PreProcessor):
    "Numericalize (chars, words, bert) token triples: the first two with the shared CustomVocab, the third with the bert vocab."
    def __init__(self, ds:ItemList=None):
        self.bert_vocab = Vocab(list(bert_tok.vocab.keys()))
        self.vocab = CustomVocab(itos)

    def _ids(self, vocab, tokens):
        "int64 array of the ids `vocab` assigns to `tokens`."
        return np.array(vocab.numericalize(tokens), dtype=np.int64)

    def process_one(self,item):
        char_toks, word_toks, bert_toks = item[0], item[1], item[2]
        chars = self._ids(self.vocab, char_toks)
        words = self._ids(self.vocab, word_toks)
        bert_words = self._ids(self.bert_vocab, bert_toks)
        return [chars,words,bert_words]

    def process(self, ds):
        "Replace `ds.items` with an array of numericalized triples."
        numericalized = [self.process_one(triple) for triple in ds.items]
        ds.items = array(numericalized)
        
        
class MultiSequenceList(TextList):
    "Label list whose items are [char ids, word ids, bert ids] triples."
    _processor = [MultiTokenizeProcessor, MultiNumericalizeProcessor]
    
    def __init__(self, items:Iterator, pad_idx=0, sep='', **kwargs):
        super().__init__(items, **kwargs)
        # The vocabularies are built from the notebook globals `itos` and `bert_tok`.
        self.vocab = CustomVocab(itos)
        self.bert_tok = bert_tok
        self.bert_vocab = Vocab(list(bert_tok.vocab.keys()))
        self.pad_idx, self.sep = pad_idx,sep
        self.copy_new += ['vocab', 'bert_vocab', 'bert_tok', 'pad_idx', 'sep']

    def get(self, i):
        "Return item `i` as three Text objects; the bert sequence is joined with a space."
        o = self.items[i]
        return [self._get(o[0]), self._get(o[1]), self._get(o[2], self.bert_vocab, ' ')]
    
    def reconstruct(self, t:Tensor):
        "t: List of tensors -> [char, word, bert]"        
        return [self._recon(t[0]), self._recon(t[1]), self._recon(t[2], self.bert_vocab)]

    def analyze_pred(self, pred:Tensor):
        # Greedy decode: index of the largest value along the last dimension.
        return torch.argmax(pred, dim=-1)
    
    def _get(self, x, vocab=None, sep=None):
        "Wrap ids `x` in a Text decoded with `vocab` (default: self.vocab) and `sep` (default: self.sep)."
        vocab = ifnone(vocab, self.vocab)
        sep = ifnone(sep, self.sep)
        return Text(x, vocab.textify(x, sep))
    
    def _recon(self, x, vocab=None):
        "Trim `x` to the span between its first and last non-pad id and wrap it in a Text."
        vocab = ifnone(vocab, self.vocab)
        non_pad = (x != self.pad_idx).nonzero()
        if len(non_pad) == 0:
            # Only padding: min()/max() of an empty tensor would raise, so fall back to position 0.
            idx_min = idx_max = 0
        else:
            idx_min,idx_max = non_pad.min(), non_pad.max()
        seq = x[idx_min:idx_max+1]
        return Text(seq, vocab.textify(seq))

In [None]:
class ImageMultiList(ImageList):
    "ImageList that shows the char, word and bert labels together under each image."
    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(9,10), **kwargs):
        "Show the `xs` and `ys` on a figure of `figsize`. `kwargs` are passed to the show method."
        rows = int(math.sqrt(len(xs)))
        fig, axs = plt.subplots(rows,rows,figsize=figsize)
        axes = axs.flatten() if rows > 1 else [axs]  # a 1x1 grid gives a bare axis, not an array
        for i, ax in enumerate(axes):
            chars,words,berts = ys[i]
            caption = '\n\n'.join((str(chars), str(words), str(berts)))
            xs[i].show(ax=ax, y=Text([], caption), **kwargs)
        plt.tight_layout()

In [None]:
# Three-label databunch (chars, words, wordpieces): MultiSequenceList builds its own vocabularies,
# so none is passed here; labels are padded per level by `multi_label_collater`.
data = (ImageMultiList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=MultiSequenceList)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=multi_label_collater)
       )

In [None]:
# Visual sanity check of a training batch from the three-label databunch.
data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

### Combo (bert_tok)

In [None]:
def multi_label_collater(samples:BatchSamples):
    "Collate (image, [char, word] labels) samples into stacked images and a pair of end-padded label tensors."
    batch = to_data(samples)
    images, labels = zip(*batch)
    char_seqs, word_seqs = zip(*labels)
    stacked = torch.stack(list(images))
    padded = (c_pad(char_seqs), c_pad(word_seqs))
    return stacked, padded
    
def c_pad(lbls, pad_idx=0):
    "Stack numpy id arrays into one long tensor, end-padded with `pad_idx` to the longest length plus one (room for the bos token)."
    widest = max(len(seq) for seq in lbls)
    out = torch.zeros(len(lbls), widest + 1, dtype=torch.long) + pad_idx
    for row, seq in enumerate(lbls):
        out[row, :len(seq)] = torch.from_numpy(seq)  # ids first, padding stays at the end
    return out

In [None]:
# Pretrained uncased DistilBERT wordpiece tokenizer plus four extra tokens (newline, space, [UP], [MAJ]).
# BertVocab appends the same four strings to its itos — keep the two lists in sync.
bert_tok = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
num_added_tokens = bert_tok.add_tokens(['\n',' ','[UP]','[MAJ]'])

In [None]:
def add_cap_tokens(text):  # before encode
    "Lower-case every run of A-Z in `text`, prefixing it with the marker chosen by `_replace_caps`."
    return re.sub(r'[A-Z]+', _replace_caps, text)
    
def _replace_caps(m):
    tok = '[UP]' if m.end()-m.start() > 1 else '[MAJ]'
    return tok + m.group().lower()

def remove_cap_tokens(text):  # after decode
    "Undo the capitalisation markers: '[UP]word' becomes 'WORD', and '[MAJ]' upper-cases the single character that follows it."
    def _upper_after(prefix_len):
        # drop the marker (first `prefix_len` characters of the match) and upper-case the rest
        return lambda m: m.group()[prefix_len:].upper()

    text = re.sub(r'\[UP\]\w+', _upper_after(4), text)   # whole word
    text = re.sub(r'\[MAJ\]\w?', _upper_after(5), text)  # first letter only
    return text

def remove_special_toks(text):
    "Delete '[CLS]' together with the whitespace after it and '[SEP]' together with the whitespace before it."
    for pattern in (r'\[CLS\]\s*', r'\s*\[SEP\]'):
        text = re.sub(pattern, '', text)
    return text

def remove_wordpiece_toks(text):
    "Delete every '##' wordpiece continuation marker from `text`."
    return text.replace('##', '')

In [None]:
def characterize(x:Collection[str]) -> Collection[str]:
    "Split word tokens into letters; the listed special tokens stay whole and a leading '##' is dropped."
    keep_whole = ['\n',' ','[UP]','[MAJ]','[UNK]', '[CLS]', '[MASK]', '[PAD]', '[SEP]']
    chars = []
    for tok in x:
        if tok in keep_whole:
            chars.append(tok)
            continue
        body = tok[2:] if tok.startswith('##') else tok  # strip the wordpiece marker
        chars.extend(body)
    return chars

In [None]:
class BertTokenizer(BaseTokenizer):
    "Tokenize with the global `bert_tok` and return both granularities as [char_tokens, word_tokens]."
    def tokenizer(self, t:str) -> List[str]:
        word_level = bert_tok.tokenize(t) + ["[SEP]"]
        char_level = characterize(word_level)
        return [char_level, word_level]

In [None]:
class BertVocab(Vocab):
    "Vocabulary made of the `bert_tok` vocab keys followed by the four added tokens; unknown strings map to id 100."
    def __init__(self):
        extra = ['\n',' ','[UP]','[MAJ]']
        self.itos = list(bert_tok.vocab.keys()) + extra
        lookup = {tok: idx for idx, tok in enumerate(self.itos)}
        # NOTE(review): 100 is presumably the id of '[UNK]' in this vocab — confirm
        self.stoi = collections.defaultdict(lambda: 100, lookup)

    def textify(self, nums:Collection[int], sep=''):
        "Join the tokens for `nums`, then drop wordpiece markers, restore capitals and strip [CLS]/[SEP]."
        text = sep.join(self.itos[n] for n in nums)
        text = remove_wordpiece_toks(text)
        text = remove_cap_tokens(text)
        return remove_special_toks(text)

In [None]:
class MultiTokenizeProcessor(PreProcessor):
    "Tokenize all items with BertTokenizer (which returns char and word tokens together), `chunksize` items at a time."
    def __init__(self, ds:ItemList=None, chunksize:int=10000):
        self.chunksize = chunksize
        # Capitals become [UP]/[MAJ] markers before tokenizing; no post rules or special cases.
        self.toknizr = Tokenizer(tok_func=BertTokenizer, pre_rules=[rm_useless_spaces, add_cap_tokens],
                                 post_rules=[], special_cases=[])

    def process_one(self, item):
        raise Exception("This isn't implemented!  I didn't think it was necessary...")

    def process(self, ds):
        "Replace `ds.items` with the tokenizer output of every item."
        starts = range(0, len(ds), self.chunksize)
        tokens = []
        for start in progress_bar(starts, leave=False):
            chunk = ds.items[start:start+self.chunksize]
            tokens.extend(self.toknizr.process_all(chunk))
        ds.items = tokens
        
class MultiNumericalizeProcessor(PreProcessor):
    "Numericalize (char_tokens, word_tokens) pairs with the dataset's vocab."
    def __init__(self, ds:ItemList=None):
        # `ds` defaults to None in the signature but is dereferenced here, so a dataset must be supplied.
        self.vocab = ds.vocab

    def process_one(self,item):
        char_toks, word_toks = item[0], item[1]
        chars = np.array(self.vocab.numericalize(char_toks), dtype=np.int64)
        words = np.array(self.vocab.numericalize(word_toks), dtype=np.int64)
        return [chars,words]

    def process(self, ds):
        "Replace `ds.items` with an array of numericalized pairs."
        numericalized = [self.process_one(pair) for pair in ds.items]
        ds.items = array(numericalized)
        

class MultiSequenceList(TextList):
    "Label list whose items are [char ids, word ids] pairs sharing one vocab."
    _processor = [MultiTokenizeProcessor, MultiNumericalizeProcessor]

    def get(self, i):
        "Return item `i` as [Text(chars), Text(words)]."
        c,w = self.items[i]
        return [Text(ids, self.vocab.textify(ids)) for ids in (c, w)]

    def reconstruct(self, t:Tensor):
        "Rebuild the [chars, words] pair of Text objects from a pair of tensors."
        c,w = t
        return [self.reconstruct_one(seq) for seq in (c, w)]

    def analyze_pred(self, pred:Tensor):
        # Greedy decode: index of the largest value along the last dimension.
        return torch.argmax(pred, dim=-1)

    def reconstruct_one(self, x):
        "Keep `x` from position 0 up to its last non-pad id (position 0 only when everything is padding)."
        non_pad = (x != self.pad_idx).nonzero()
        if len(non_pad) > 0:
            last = non_pad.max()
        else:
            last = 0
        kept = x[0:last+1]
        return Text(kept, self.vocab.textify(kept))

In [None]:
class ImageMultiList(ImageList):
    "ImageList that shows the char and word labels together under each image."
    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(9,10), **kwargs):
        "Show the `xs` and `ys` on a figure of `figsize`. `kwargs` are passed to the show method."
        rows = int(math.sqrt(len(xs)))
        fig, axs = plt.subplots(rows,rows,figsize=figsize)
        axes = axs.flatten() if rows > 1 else [axs]  # a 1x1 grid gives a bare axis, not an array
        for i, ax in enumerate(axes):
            chars,words = ys[i]
            caption = '\n\n'.join((str(chars), str(words)))
            xs[i].show(ax=ax, y=Text([], caption), **kwargs)
        plt.tight_layout()

In [None]:
# Two-label databunch (chars and wordpieces from the same BertVocab); labels are padded per
# level by `multi_label_collater`.
data = (ImageMultiList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=MultiSequenceList, vocab=BertVocab(), pad_idx=0)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=multi_label_collater)
       )

In [None]:
# data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

### Combo (shared itos)

In [None]:
def remove_cap_tokens(text):  # after decode
    "Strip the `xxup`/`xxmaj` markers from `text` and apply the capitalization they encode."
    def _upper_word(m): return m.group()[4:].upper()   # drop 'xxup', upper-case the whole word
    def _upper_first(m): return m.group()[5:].upper()  # drop 'xxmaj', upper-case the next character
    without_up = re.sub(r'xxup\w+', _upper_word, text)
    return re.sub(r'xxmaj\w?', _upper_first, without_up)

In [None]:
class CustomVocab(Vocab):
    "`Vocab` built directly from `itos`; tokens missing from it map to index 3."
    def __init__(self, itos:Collection[str]):
        self.itos = itos
        lookup = {tok: idx for idx, tok in enumerate(self.itos)}
        self.stoi = collections.defaultdict(lambda: 3, lookup)

    def textify(self, nums:Collection[int], handle_cap_tokens:bool=False, sep=''):
        "Map `nums` to tokens, joined with `sep` (or left as a list when `sep` is None)."
        toks = [self.itos[i] for i in nums]
        res = toks if sep is None else sep.join(toks)
        return remove_cap_tokens(res) if handle_cap_tokens else res

In [None]:
class CustomTokenizer(BaseTokenizer):
    "Split on words but keep original spacing"
    def tokenizer(self, t:str) -> List[str]:
        "Emit each alphanumeric run as one token and every other character on its own, wrapped in xxbos/xxeos."
        toks, word = [], []
        for ch in t:
            if ch.isalnum():
                word.append(ch)
                continue
            # A non-alphanumeric character closes the pending word (if any).
            if word: toks.append(''.join(word))
            toks.append(ch)
            word = []
        if word: toks.append(''.join(word))
        return ['xxbos', *toks, 'xxeos']

In [None]:
# Load the pickled token list and keep only its first 30,000 entries.
# NOTE: unpickling can execute arbitrary code; only load a file you trust.
with open(PATH/'combo_itos_60k.pkl', 'rb') as f:  # context manager closes the handle
    itos = pickle.load(f)
itos = itos[:30000]
#vocab = CustomVocab(itos)

In [None]:
def characterize(x:Collection[str]) -> Collection[str]:
    """Flatten a sequence of tokens into a list of their individual characters.

    NOTE(review): every token is split, including multi-character specials such as
    `xxmaj`/`xxup`; the previous docstring said those were kept whole -- confirm intent.
    """
    chars = []
    for tok in x:
        chars.extend(tok)
    return chars

In [None]:
def multi_label_collater(samples:BatchSamples):
    "Collate `samples` into `(images, (char_labels, word_labels))`, padding the ends of both label sets."
    data = to_data(samples)
    ims, lbls = zip(*data)
    char_lbls, word_lbls = zip(*lbls)
    imgs = torch.stack(list(ims))
    # Compare by value: `is 1` tests object identity and only works by interpreter accident.
    if len(data) == 1:
        # NOTE(review): a single-sample batch gets a dummy 1x1 zero label instead of its real labels -- confirm intended.
        labels = torch.zeros(1,1).long()
        return imgs, labels
    return imgs, (c_pad(char_lbls), c_pad(word_lbls))
    
def c_pad(lbls, pad_idx=0):
    "Stack variable-length label arrays into one long tensor, filling each row's tail with `pad_idx`."
    n_rows = len(lbls)
    width = max(len(seq) for seq in lbls) + 1  # one extra column to account for the bos token
    padded = torch.zeros(n_rows, width).long() + pad_idx
    for row, seq in zip(padded, lbls):
        row[:len(seq)] = torch.from_numpy(seq)  # copy the label in; the remainder stays padding
    return padded

In [None]:
class MultiTokenizeProcessor(PreProcessor):
    "Tokenize every item twice: once split into characters and once kept as words."
    def __init__(self, ds:ItemList=None, chunksize:int=10000):
        specials = ['xxbos','xxeos','xxmask','xxunk','xxpad','xxmaj','xxup','\n']
        # The char tokenizer runs `characterize` last so word tokens are split into letters.
        self.c_toknizr = Tokenizer(tok_func=CustomTokenizer, pre_rules=[rm_useless_spaces],
                                   post_rules=[replace_all_caps, deal_caps, characterize],
                                   special_cases=list(specials))
        self.w_toknizr = Tokenizer(tok_func=CustomTokenizer, pre_rules=[rm_useless_spaces],
                                   special_cases=list(specials))
        self.chunksize = chunksize

    def process_one(self, item):
        "Single-item processing is not supported; only `process` is implemented."
        # NotImplementedError is still an Exception, so existing handlers keep working.
        raise NotImplementedError("This isn't implemented!  I didn't think it was necessary...")

    def process(self, ds):
        "Replace `ds.items` with `(char_tokens, word_tokens)` pairs."
        char_toks = self._p(self.c_toknizr, ds)
        word_toks = self._p(self.w_toknizr, ds)
        ds.items = list(zip(char_toks, word_toks))

    def _p(self, toknizr, ds):
        "Run `toknizr` over `ds.items` in slices of `chunksize`, showing a progress bar."
        res = []
        for i in progress_bar(range(0,len(ds),self.chunksize), leave=False):
            res += toknizr.process_all(ds.items[i:i+self.chunksize])
        return res
        
class MultiNumericalizeProcessor(PreProcessor):
    "Convert `(char_tokens, word_tokens)` pairs into int64 id arrays using the dataset's vocab."
    def __init__(self, ds:ItemList=None):
        self.vocab = ds.vocab #CustomVocab(itos)

    def process_one(self,item):
        "Numericalize both token lists of `item` with the shared vocab."
        char_ids = self.vocab.numericalize(item[0])
        chars = np.array(char_ids, dtype=np.int64)
        word_ids = self.vocab.numericalize(item[1])
        words = np.array(word_ids, dtype=np.int64)
        return [chars,words]

    def process(self, ds):
        "Replace `ds.items` with the numericalized pairs."
        numericalized = [self.process_one(entry) for entry in ds.items]
        ds.items = array(numericalized)
        
        
class MultiSequenceList(TextList):
    "Label list of `[char_ids, word_ids]` pairs that share one `CustomVocab` built from `itos`."
    _processor = [MultiTokenizeProcessor, MultiNumericalizeProcessor]

    def __init__(self, items:Iterator, pad_idx=0, sep='', **kwargs):
        super().__init__(items, **kwargs)
        self.vocab = CustomVocab(itos)
        self.pad_idx, self.sep = pad_idx,sep
        self.copy_new += ['vocab', 'pad_idx', 'sep']

    def get(self, i):
        "Return item `i` as a `[chars, words]` pair of `Text` objects."
        o = self.items[i]
        return [self._get(o[0]),self._get(o[1])]

    def reconstruct(self, t:Tensor):
        "t: List of tensors -> [char, word]"
        c,w = t
        return [self._recon(c),self._recon(w)]

    def analyze_pred(self, pred:Tensor):
        "Pick the highest-scoring index along the last dimension."
        return torch.argmax(pred, dim=-1)

    def _get(self, x):
        "Wrap ids `x` in a `Text`, joining the tokens with `self.sep`."
        # `sep` must go by keyword: the second positional parameter of
        # `CustomVocab.textify` is `handle_cap_tokens`, not `sep`.
        return Text(x, self.vocab.textify(x, sep=self.sep))

    def _recon(self, x):
        "Build a `Text` from the span between the first and last non-pad ids of `x`."
        non_pad = (x != self.pad_idx).nonzero()
        if len(non_pad) == 0:
            # All padding: min()/max() on an empty tensor would raise, so keep one element.
            lo, hi = 0, 0
        else:
            lo, hi = non_pad.min(), non_pad.max()
        trimmed = x[lo:hi+1]
        return Text(trimmed, self.vocab.textify(trimmed))

In [None]:
class ImageMultiList(ImageList):
    "`ImageList` that shows each image together with its char and word labels."
    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(9,10), **kwargs):
        "Plot `xs` on a square grid of `figsize`, captioned with `ys`. `kwargs` are passed to the show method."
        side = int(math.sqrt(len(xs)))
        _, axs = plt.subplots(side, side, figsize=figsize)
        # A 1x1 grid comes back as a single Axes rather than an array.
        axes = axs.flatten() if side > 1 else [axs]
        for idx, ax in enumerate(axes):
            chars, words = ys[idx]
            caption = Text([], '\n\n'.join((str(chars), str(words))))
            xs[idx].show(ax=ax, y=caption, **kwargs)
        plt.tight_layout()

In [None]:
# Assemble the DataBunch: images listed in `df`, a seeded 15% random validation split,
# labels built by `MultiSequenceList`, squish-resize to `sz`, and batches collated by
# `multi_label_collater`.
data = (ImageMultiList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=MultiSequenceList)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=multi_label_collater)
       )

In [None]:
# Visual sanity check: display a few training samples together with their labels.
data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

### Both (char itos + bert_tok)

In [None]:
def multi_label_collater(samples:BatchSamples):
    "Collate `samples` into `(images, (char_labels, word_labels))`, padding the ends of both label sets."
    data = to_data(samples)
    ims, lbls = zip(*data)
    char_lbls, word_lbls = zip(*lbls)
    imgs = torch.stack(list(ims))
    # Compare by value: `is 1` tests object identity and only works by interpreter accident.
    if len(data) == 1:
        # NOTE(review): a single-sample batch gets a dummy 1x1 zero label instead of its real labels -- confirm intended.
        labels = torch.zeros(1,1).long()
        return imgs, labels
    return imgs, (c_pad(char_lbls), c_pad(word_lbls))
    
def c_pad(lbls, pad_idx=0):
    "Stack variable-length label arrays into one long tensor, filling each row's tail with `pad_idx`."
    n_rows = len(lbls)
    width = max(len(seq) for seq in lbls) + 1  # one extra column to account for the bos token
    padded = torch.zeros(n_rows, width).long() + pad_idx
    for row, seq in zip(padded, lbls):
        row[:len(seq)] = torch.from_numpy(seq)  # copy the label in; the remainder stays padding
    return padded

In [None]:
# Module-level DistilBERT tokenizer ("distilbert-base-uncased"); used by `BertTokenizer`
# and as the source of the word vocabulary.
bert_tok = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

class BertTokenizer(BaseTokenizer):
    "Tokenize with the module-level `bert_tok` and terminate every sequence with `[SEP]`."
    def tokenizer(self, t:str) -> List[str]:
        pieces = bert_tok.tokenize(t)
        return pieces + ["[SEP]"]

In [None]:
class MultiTokenizeProcessor(PreProcessor):
    "Tokenize every item twice: with `CharTokenizer` and with `BertTokenizer`."
    def __init__(self, ds:ItemList=None, chunksize:int=10000):
        self.c_toknizr = Tokenizer(tok_func=CharTokenizer, pre_rules=[], post_rules=[], special_cases=[])
        self.w_toknizr = Tokenizer(tok_func=BertTokenizer, pre_rules=[], post_rules=[], special_cases=[])
        self.chunksize = chunksize

    def process_one(self, item):
        "Single-item processing is not supported; only `process` is implemented."
        # NotImplementedError is still an Exception, so existing handlers keep working.
        raise NotImplementedError("This isn't implemented!  I didn't think it was necessary...")

    def process(self, ds):
        "Replace `ds.items` with `(char_tokens, word_tokens)` pairs."
        # Both tokenizers start from the same raw items; the aliases stay on `ds` as before.
        ds.chars = ds.items
        ds.words = ds.items

        char_tokens = self._tokenize_all(self.c_toknizr, ds.chars, len(ds))
        word_tokens = self._tokenize_all(self.w_toknizr, ds.words, len(ds))
        ds.items = list(zip(char_tokens, word_tokens))

    def _tokenize_all(self, toknizr, items, n):
        "Run `toknizr` over the first `n` `items` in slices of `chunksize`, showing a progress bar."
        tokens = []
        for i in progress_bar(range(0,n,self.chunksize), leave=False):
            tokens += toknizr.process_all(items[i:i+self.chunksize])
        return tokens
        
class MultiNumericalizeProcessor(PreProcessor):
    "Convert `(char_tokens, word_tokens)` pairs into int64 id arrays, one vocab per sequence."
    def __init__(self, ds:ItemList=None):
        self.char_vocab = ds.char_vocab
        self.word_vocab = ds.word_vocab

    def process_one(self,item):
        "Numericalize the char tokens with `char_vocab` and the word tokens with `word_vocab`."
        char_ids = self.char_vocab.numericalize(item[0])
        chars = np.array(char_ids, dtype=np.int64)
        word_ids = self.word_vocab.numericalize(item[1])
        words = np.array(word_ids, dtype=np.int64)
        return [chars,words]

    def process(self, ds):
        "Replace `ds.items` with the numericalized pairs."
        numericalized = [self.process_one(entry) for entry in ds.items]
        ds.items = array(numericalized)
        
        
class MultiSequenceList(TextList):
    "Label list of `[char_ids, word_ids]` pairs: chars use `CharVocab(itos)`, words use the `bert_tok` vocabulary."
    _processor = [MultiTokenizeProcessor, MultiNumericalizeProcessor]

    def __init__(self, items:Iterator, pad_idx=0, **kwargs):
        super().__init__(items, **kwargs)
        self.char_vocab = CharVocab(itos)
        self.word_vocab = Vocab(list(bert_tok.vocab.keys()))

        self.pad_idx = pad_idx
        self.copy_new += ['char_vocab', 'word_vocab', 'pad_idx']

    def get(self, i):
        "Return item `i` as `[chars, words]`; chars are joined without a separator, words with spaces."
        o = super().get(i)
        chars = Text(o[0], self.char_vocab.textify(o[0], ''))
        words = Text(o[1], self.word_vocab.textify(o[1], ' '))
        return [chars,words]

    def reconstruct(self, t:Tensor):
        "t: List of tensors -> [char, word]"
        c,w = t
        return [self._recon(c, self.char_vocab), self._recon(w, self.word_vocab)]

    def _recon(self, x, vocab):
        "Build a `Text` from the span between the first and last non-pad ids of `x`."
        non_pad = (x != self.pad_idx).nonzero()
        if len(non_pad) == 0:
            # All padding: min()/max() on an empty tensor would raise, so keep one element.
            lo, hi = 0, 0
        else:
            lo, hi = non_pad.min(), non_pad.max()
        trimmed = x[lo:hi+1]
        return Text(trimmed, vocab.textify(trimmed))

    def analyze_pred(self, pred:Tensor):
        "Pick the highest-scoring index along the last dimension."
        return torch.argmax(pred, dim=-1)

In [None]:
class ImageMultiList(ImageList):
    "`ImageList` that shows each image together with its char and word labels."
    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(9,10), **kwargs):
        "Plot `xs` on a square grid of `figsize`, captioned with `ys`. `kwargs` are passed to the show method."
        side = int(math.sqrt(len(xs)))
        _, axs = plt.subplots(side, side, figsize=figsize)
        # A 1x1 grid comes back as a single Axes rather than an array.
        axes = axs.flatten() if side > 1 else [axs]
        for idx, ax in enumerate(axes):
            chars, words = ys[idx]
            caption = Text([], '\n\n'.join((str(chars), str(words))))
            xs[idx].show(ax=ax, y=caption, **kwargs)
        plt.tight_layout()

In [None]:
# Assemble the DataBunch: images listed in `df`, a seeded 15% random validation split,
# labels built by `MultiSequenceList`, squish-resize to `sz`, and batches collated by
# `multi_label_collater`.
data = (ImageMultiList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=MultiSequenceList)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=multi_label_collater)
       )

In [None]:
# Visual sanity check: display a few training samples together with their labels.
data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

### Both (separate itos)

In [None]:
def multi_label_collater(samples:BatchSamples):
    "Collate `samples` into `(images, (char_labels, word_labels))`, padding the ends of both label sets."
    data = to_data(samples)
    ims, lbls = zip(*data)
    char_lbls, word_lbls = zip(*lbls)
    imgs = torch.stack(list(ims))
    # Compare by value: `is 1` tests object identity and only works by interpreter accident.
    if len(data) == 1:
        # NOTE(review): a single-sample batch gets a dummy 1x1 zero label instead of its real labels -- confirm intended.
        labels = torch.zeros(1,1).long()
        return imgs, labels
    return imgs, (c_pad(char_lbls), c_pad(word_lbls))
    
def c_pad(lbls, pad_idx=0):
    "Stack variable-length label arrays into one long tensor, filling each row's tail with `pad_idx`."
    n_rows = len(lbls)
    width = max(len(seq) for seq in lbls) + 1  # one extra column to account for the bos token
    padded = torch.zeros(n_rows, width).long() + pad_idx
    for row, seq in zip(padded, lbls):
        row[:len(seq)] = torch.from_numpy(seq)  # copy the label in; the remainder stays padding
    return padded

In [None]:
class MultiTokenizeProcessor(PreProcessor):
    "Tokenize every item twice: with `CharTokenizer` and with `WordTokenizer`."
    def __init__(self, ds:ItemList=None, chunksize:int=10000):
        self.c_toknizr = Tokenizer(tok_func=CharTokenizer, pre_rules=[], post_rules=[], special_cases=[])
        self.w_toknizr = Tokenizer(tok_func=WordTokenizer, pre_rules=[], post_rules=[], special_cases=[])
        self.chunksize = chunksize

    def process_one(self, item):
        "Single-item processing is not supported; only `process` is implemented."
        # NotImplementedError is still an Exception, so existing handlers keep working.
        raise NotImplementedError("This isn't implemented!  I didn't think it was necessary...")

    def process(self, ds):
        "Replace `ds.items` with `(char_tokens, word_tokens)` pairs."
        # Both tokenizers start from the same raw items; the aliases stay on `ds` as before.
        ds.chars = ds.items
        ds.words = ds.items

        char_tokens = self._tokenize_all(self.c_toknizr, ds.chars, len(ds))
        word_tokens = self._tokenize_all(self.w_toknizr, ds.words, len(ds))
        ds.items = list(zip(char_tokens, word_tokens))

    def _tokenize_all(self, toknizr, items, n):
        "Run `toknizr` over the first `n` `items` in slices of `chunksize`, showing a progress bar."
        tokens = []
        for i in progress_bar(range(0,n,self.chunksize), leave=False):
            tokens += toknizr.process_all(items[i:i+self.chunksize])
        return tokens
        
class MultiNumericalizeProcessor(PreProcessor):
    "Convert `(char_tokens, word_tokens)` pairs into int64 id arrays, one vocab per sequence."
    def __init__(self, ds:ItemList=None):
        # The vocabs come from the module-level `itos` / `word_itos`; `ds` is not consulted.
        self.char_vocab = CharVocab(itos)
        self.word_vocab = WordVocab(word_itos)

    def process_one(self,item):
        "Numericalize the char tokens with `char_vocab` and the word tokens with `word_vocab`."
        char_ids = self.char_vocab.numericalize(item[0])
        chars = np.array(char_ids, dtype=np.int64)
        word_ids = self.word_vocab.numericalize(item[1])
        words = np.array(word_ids, dtype=np.int64)
        return [chars,words]

    def process(self, ds):
        "Replace `ds.items` with the numericalized pairs."
        numericalized = [self.process_one(entry) for entry in ds.items]
        ds.items = array(numericalized)
        
        
class MultiSequenceList(TextList):
    "Label list of `[char_ids, word_ids]` pairs with separate `CharVocab` and `WordVocab`."
    _processor = [MultiTokenizeProcessor, MultiNumericalizeProcessor]

    def __init__(self, items:Iterator, pad_idx=0, **kwargs):
        super().__init__(items, **kwargs)
        self.char_vocab = CharVocab(itos)
        self.word_vocab = WordVocab(word_itos)

        self.pad_idx = pad_idx
        self.copy_new += ['char_vocab', 'word_vocab', 'pad_idx']

    def get(self, i):
        "Return item `i` as a `[chars, words]` pair of `Text` objects, both joined without a separator."
        o = super().get(i)
        char_ids, word_ids = o[0], o[1]
        chars = Text(char_ids, self.char_vocab.textify(char_ids, ''))
        words = Text(word_ids, self.word_vocab.textify(word_ids, ''))
        return [chars,words]

    def reconstruct(self, t:Tensor):
        "Turn a `(chars, words)` pair of id tensors back into `Text` objects."
        char_ids, word_ids = t
        chars = self.reconstruct_one(char_ids, self.char_vocab)
        words = self.reconstruct_one(word_ids, self.word_vocab)
        return [chars, words]

    def reconstruct_one(self, x, vocab):
        "Build a `Text` from `x` using `vocab`, dropping everything after the last non-pad id."
        non_pad = (x != self.pad_idx).nonzero()
        # An all-padding sequence still keeps its first element.
        last = non_pad.max() if len(non_pad) > 0 else 0
        trimmed = x[0:last+1]
        return Text(trimmed, vocab.textify(trimmed))

    def analyze_pred(self, pred:Tensor):
        "Pick the highest-scoring index along the last dimension."
        return torch.argmax(pred, dim=-1)

In [None]:
class ImageMultiList(ImageList):
    "`ImageList` that shows each image together with its char and word labels."
    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(9,10), **kwargs):
        "Plot `xs` on a square grid of `figsize`, captioned with `ys`. `kwargs` are passed to the show method."
        side = int(math.sqrt(len(xs)))
        _, axs = plt.subplots(side, side, figsize=figsize)
        # A 1x1 grid comes back as a single Axes rather than an array.
        axes = axs.flatten() if side > 1 else [axs]
        for idx, ax in enumerate(axes):
            chars, words = ys[idx]
            caption = Text([], '\n\n'.join((str(chars), str(words))))
            xs[idx].show(ax=ax, y=caption, **kwargs)
        plt.tight_layout()

In [None]:
# Assemble the DataBunch: images listed in `df`, a seeded 15% random validation split,
# labels built by `MultiSequenceList`, squish-resize to `sz`, and batches collated by
# `multi_label_collater`.
data = (ImageMultiList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=MultiSequenceList)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=multi_label_collater)
       )

In [None]:
# Visual sanity check: display a few training samples together with their labels.
data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

# Transformer Modules

In [None]:
# Normalization layer used by the transformer modules; the active choice sets eps=1e-4.
# LayerNorm = nn.LayerNorm
LayerNorm = partial(nn.LayerNorm, eps=1e-4)  # accommodates mixed precision training
# LayerNorm = partial(nn.BatchNorm2d, eps=1e-4)

In [None]:
class SublayerConnection(nn.Module):
    "Pre-norm residual wrapper: returns `x + dropout(sublayer(norm(x)))`."
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply `sublayer` to the normalized input and add the result back onto `x`."
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch

In [None]:
def clones(module, N):
    "Return a `ModuleList` holding `N` independent deep copies of `module`."
    copies = [deepcopy(module) for _ in range(N)]
    stack = nn.ModuleList(copies)
    return stack

In [None]:
class Encoder(nn.Module):
    "Stack of `N` copies of `layer` followed by a final norm."
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x):
        "Feed `x` through every layer in order, then normalize."
        out = x
        for block in self.layers:
            out = block(out)
        return self.norm(out)

In [None]:
class EncoderLayer(nn.Module):
    "Encoder: self-attn and feed forward"
    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)

    def forward(self, x, mask=None):
        "Self-attention then feed-forward, each applied through its own sublayer wrapper."
        def attend(h): return self.self_attn(h, h, h, mask)
        attn_wrap, ff_wrap = self.sublayer[0], self.sublayer[1]
        x = attn_wrap(x, attend)
        return ff_wrap(x, self.feed_forward)

In [None]:
class Decoder(nn.Module):
    "Generic N layer decoder with masking."
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, src, tgt_mask=None):
        "Feed `x` through every layer (each also sees `src` and `tgt_mask`), then normalize."
        out = x
        for block in self.layers:
            out = block(out, src, tgt_mask)
        return self.norm(out)

In [None]:
class DecoderLayer(nn.Module):
    "Decoder: self-attn, src-attn, and feed forward"
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)  # wraps layer in residual,dropout,norm

    def forward(self, x, src, tgt_mask=None):
        "Masked self-attention, attention over `src`, then feed-forward, each through its own sublayer wrapper."
        def self_attend(h): return self.self_attn(h, h, h, tgt_mask)  # acts as a weak LM
        def src_attend(h): return self.src_attn(h, src, src)
        self_wrap, src_wrap, ff_wrap = self.sublayer[0], self.sublayer[1], self.sublayer[2]
        x = self_wrap(x, self_attend)
        x = src_wrap(x, src_attend)
        return ff_wrap(x, self.feed_forward)

In [None]:
def attention(query, key, value, mask=None, dropout=None):
    "Scaled dot-product attention; returns `(weighted_values, attention_weights)`."
    scale = math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    if mask is not None:
        # -1e4 instead of -1e9 to accommodate mixed precision
        scores = scores.masked_fill(mask == 0, -1e4)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, value), weights

In [None]:
class SingleHeadedAttention(nn.Module):
    "Single-head attention with linear projections for query, key, value and output."
    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        "Project the inputs, attend, and project the result; the weights are kept in `self.attn`."
        proj_q, proj_k, proj_v, proj_out = self.linears
        query, key, value = proj_q(query), proj_k(key), proj_v(value)
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        return proj_out(x)

In [None]:
class MultiHeadedAttention(nn.Module):
    "Multi-head attention: `h` heads of width `d_model // h` with a final output projection."
    def __init__(self, d_model, h=8, dropout=0.2):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h        # assume d_v always equals d_k
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        "Attend over `k`/`v` with queries `q`; the attention weights are kept in `self.attn`."
        if mask is not None: mask = mask.unsqueeze(1)
        bs = q.size(0)

        def split_heads(proj, t):
            # d_model => h x d_k, with the head dimension moved ahead of the sequence dimension
            return proj(t).view(bs, -1, self.h, self.d_k).transpose(1,2)

        w_q, w_k, w_v, w_out = self.linears
        q, k, v = split_heads(w_q, q), split_heads(w_k, k), split_heads(w_v, v)

        # Attention on all the projected vectors in batch.
        x, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)

        # "Concat" the heads using a view, then apply the final linear.
        merged = x.transpose(1, 2).contiguous().view(bs, -1, self.h * self.d_k)
        return w_out(merged)

In [None]:
class PositionwiseFeedForward(nn.Module):
    "Two-layer feed-forward block: expand to `4*d_model`, GELU, dropout, project back to `d_model`."
    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        hidden = d_model*4
        self.w_1 = nn.Linear(d_model, hidden)
        self.w_2 = nn.Linear(hidden, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        activated = F.gelu(self.w_1(x))
        return self.w_2(self.dropout(activated))

In [None]:
class PositionalEncoding(nn.Module):
    "Add fixed sinusoidal position encodings to the input, then apply dropout."
    def __init__(self, d_model, dropout=0.1, max_len=2000):
        super().__init__()

        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0.0, max_len).unsqueeze(1)
        log_increment = math.log(1e4) / d_model
        div_term = torch.exp(torch.arange(0.0, d_model, 2) * -log_increment)
        angles = position * div_term    #(max_len, ceil(d_model/2))
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer cosine column than sine columns; slicing keeps
        # the assignment shape-compatible (a no-op for even d_model).
        pe[:, 1::2] = torch.cos(angles[:, :d_model//2])
        pe.unsqueeze_(0)

        self.register_buffer('pe', pe)    #(1,max_len,d_model)
        # registered buffers are Tensors (not Variables)
        # not a parameter but still want in the state_dict

    def forward(self, x):
        "Add the first `x.size(1)` position encodings to `x` and apply dropout."
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [None]:
class Embeddings(nn.Module):
    "Token embedding lookup scaled by `sqrt(d_model)`."
    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        looked_up = self.lut(x)
        return looked_up * scale

# Word Arch (bert_tok)

In [None]:
class LearnedPositionalEmbeddings(nn.Module):
    "Token embeddings plus learned (row, column) position embeddings, scaled by `sqrt(d_model)`, then dropout."
    def __init__(self, d_model, v_embed, row_embed, col_embed, dropout=0.1):
        super().__init__()
        self.nl_tok  = 30522 #4
        self.d_model = d_model

        self.embed = v_embed #nn.Embedding(vocab, d_model, 0)
        self.dropout = nn.Dropout(p=dropout)

        self.rows = row_embed #nn.Embedding(15, d_model//2, 0)
        self.cols = col_embed #nn.Embedding(num_cols, d_model//2, 0)

    def forward(self, x):
        "Embed token ids `x` and add the concatenated row/column position embeddings."
        rows,cols = self.encode_spatial_positions(x)
        # Clamp both indices to the last embedding so long inputs cannot index out of range.
        row_t = self.rows(torch.clamp(rows, max=self.rows.num_embeddings-1))
        col_t = self.cols(torch.clamp(cols, max=self.cols.num_embeddings-1))
        pos_enc = torch.cat((row_t, col_t), dim=-1)

        x = self.embed(x)
        x = (x + pos_enc) * math.sqrt(self.d_model)
        return self.dropout(x)

    def encode_spatial_positions(self, x):
        """Return `(rows, cols)` index tensors shaped like `x`.

        Each sequence is split into lines ending at every `nl_tok`; lines and the
        positions inside a line are numbered from 1. Positions after the last
        non-zero token keep index 0.
        """
        rows,cols = torch.zeros_like(x),torch.zeros_like(x)
        for ii,seq in enumerate(x.unbind()):
            nls = torch.nonzero(seq==self.nl_tok).flatten()
            last = torch.nonzero(seq).flatten()[-1][None]
            # Convert once to Python ints: a single host transfer per sequence.
            splits = torch.cat([nls,last]).tolist()

            p=0
            for i,n in enumerate(splits, start=1):
                rows[ii,p:n+1] = i
                cols[ii,p:n+1] = torch.arange(1,n-p+2, device=x.device)
                p = n+1
        return rows,cols

In [None]:
class EncoderDecoder(nn.Module):
    "Bundles the source adapter, encoder, decoder (`w_decoder`), target embedding (`w_embed`) and generator."
    def __init__(self, encoder, decoder, src_adapt, embed, generator):
        super().__init__()
        self.encoder = encoder
        self.src_adapt = src_adapt

        self.w_decoder = decoder
        self.w_embed = embed

        self.generator = generator

    def forward(self, src, tgt):
        "Decode the right-shifted `tgt` (start token id 101) against the encoded `src`."
        shifted = rshift(tgt, 101).long()
        mask = parallelogram_mask(shifted.size(-1), 20)

        memory = self.encode(src)
        return self.w_decoder(self.w_embed(shifted), memory, mask)

    def encode(self, src):
        "Adapt `src` and run it through the encoder."
        adapted = self.src_adapt(src)
        return self.encoder(adapted)

    def generate(self, outs):
        "Apply the generator to decoder outputs."
        return self.generator(outs)

In [None]:
def make_full_model(vocab, d_model, em_sz, N=4, drops=0, heads=8):
    """Assemble the `EncoderDecoder` transformer.

    `N` encoder and `N` decoder layers of width `d_model` with `heads` attention heads,
    a linear adapter from `em_sz`-wide source features, learned row/column positional
    embeddings over a `vocab`-sized token embedding, and a linear generator onto `vocab`.
    """
    c = deepcopy

    attn = MultiHeadedAttention(d_model, heads)
    ff = PositionwiseFeedForward(d_model, drops)
    #pos = PositionalEncoding(d_model, drops, 2000)

    v_embed = nn.Embedding(vocab, d_model, 0)
    row_emb = nn.Embedding(15, d_model//2, 0)

    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        nn.Sequential(
            nn.Linear(em_sz, d_model), Lambda(lambda x: x.mul_(8)) #increases gradients on weights by 8!
        ),
        LearnedPositionalEmbeddings(d_model, v_embed, row_emb, nn.Embedding(65,  d_model//2, 0)),  #word
        nn.Linear(d_model, vocab)
    )

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    # The Xavier pass above also overwrote the all-zero `padding_idx` rows that
    # nn.Embedding set up; restore them so padding embeds to zeros again.
    for m in model.modules():
        if isinstance(m, nn.Embedding) and m.padding_idx is not None:
            with torch.no_grad():
                m.weight[m.padding_idx].zero_()

    return model

In [None]:
class ResnetBase(nn.Module):
    "Truncated `resnet34` backbone plus a conv and height pooling, emitting features as a sequence."
    def __init__(self, em_sz):
        super().__init__()

        # Slice end applied to the resnet children for each supported `em_sz`.
        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        backbone = models.resnet34(True)
        kept = list(backbone.children())[:cut]   #512,2,16;  256,4,32;  128,8,64
        self.base = nn.Sequential(*kept)

        self.conv = conv_layer(em_sz,em_sz)
        self.pool = nn.AdaptiveMaxPool2d((16,None))

    def forward(self, x):
        "Return features with the two trailing dims flattened and moved ahead of the channel dim."
        feats = self.base(x)
        feats = self.conv(feats)
        feats = self.pool(feats)
        return feats.flatten(2,3).permute(0,2,1)

In [None]:
class Img2Seq(nn.Module):
    "Chain an image encoder with a transformer and return the transformer's generated outputs."
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt):
        "Encode `src`, run the transformer on the features and `tgt`, then apply its `generate`."
        decoded = self.transformer(self.img_enc(src), tgt)
        return self.transformer.generate(decoded)

In [None]:
def cer(preds, targs, fn):
    "Sum, over the batch, the edit distance between decoded prediction and target divided by target length; returns `(error_sum, batch_size)`."
    bs = targs.size(0)
    decoded = torch.argmax(preds, dim=-1)
    total = 0
    for i, targ in enumerate(targs):
        p = str(fn(decoded[i]))   #.replace(' ', '')
        t = str(fn(targ)) #.replace(' ', '')
        # `or 1` avoids dividing by zero for an empty target string
        total += Lev.distance(t, p)/(len(t) or 1)
    return total, bs

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, heads=8, smoothing=0.1):
    "Build an `Img2Seq` network sized to `data.vocab` and wrap it in a `Learner` with label smoothing, CER and teacher forcing."
    vocab_sz = len(data.vocab.itos)
    net = Img2Seq(ResnetBase(em_sz),
                  make_full_model(vocab_sz, d_model, em_sz, N, drops, heads))
    return Learner(data, net, loss_func=LabelSmoothing(smoothing),
                   metrics=[CER(data.y.reconstruct)], callback_fns=[TeacherForce])

In [None]:
# def make_learner(data, d_model, em_sz, N=4, drops=0.1, heads=8, smoothing=0.1):
#     itos = data.vocab.itos
#     img_encoder = ResnetBase(em_sz)
#     transformer = make_full_model(len(itos), d_model, em_sz, N, drops, heads)
#     net = Img2Seq(img_encoder, transformer)
#     learn = Learner(data, net, loss_func=LabelSmoothing(smoothing),
#                     metrics=[CER(data.y.reconstruct)],
#                     callback_fns=[TeacherForce, BnFreeze, partial(AccumulateScheduler, n_step=3)])
#     return learn

In [None]:
# Word-level learner: d_model=512, em_sz=256, 6 layers, 8 heads.
learn = make_learner(data, 512, 256, N=6, drops=0.1, heads=8)

# Char Arch (bert_tok)

In [None]:
class LearnedPositionalEmbeddings(nn.Module):
    "Token embeddings plus learned (row, column) position embeddings, scaled by `sqrt(d_model)`, then dropout."
    def __init__(self, d_model, v_embed, row_embed, col_embed, dropout=0.1):
        super().__init__()
        self.nl_tok  = 30522 #4
        self.d_model = d_model

        self.embed = v_embed #nn.Embedding(vocab, d_model, 0)
        self.dropout = nn.Dropout(p=dropout)

        self.rows = row_embed #nn.Embedding(15, d_model//2, 0)
        self.cols = col_embed #nn.Embedding(num_cols, d_model//2, 0)

    def forward(self, x):
        "Embed token ids `x` and add the concatenated row/column position embeddings."
        rows,cols = self.encode_spatial_positions(x)
        # Clamp both indices to the last embedding so long inputs cannot index out of range.
        row_t = self.rows(torch.clamp(rows, max=self.rows.num_embeddings-1))
        col_t = self.cols(torch.clamp(cols, max=self.cols.num_embeddings-1))
        pos_enc = torch.cat((row_t, col_t), dim=-1)

        x = self.embed(x)
        x = (x + pos_enc) * math.sqrt(self.d_model)
        return self.dropout(x)

    def encode_spatial_positions(self, x):
        """Return `(rows, cols)` index tensors shaped like `x`.

        Each sequence is split into lines ending at every `nl_tok`; lines and the
        positions inside a line are numbered from 1. Positions after the last
        non-zero token keep index 0.
        """
        rows,cols = torch.zeros_like(x),torch.zeros_like(x)
        for ii,seq in enumerate(x.unbind()):
            nls = torch.nonzero(seq==self.nl_tok).flatten()
            last = torch.nonzero(seq).flatten()[-1][None]
            # Convert once to Python ints: a single host transfer per sequence.
            splits = torch.cat([nls,last]).tolist()

            p=0
            for i,n in enumerate(splits, start=1):
                rows[ii,p:n+1] = i
                cols[ii,p:n+1] = torch.arange(1,n-p+2, device=x.device)
                p = n+1
        return rows,cols

In [None]:
class EncoderDecoder(nn.Module):
    "Bundles the source adapter, encoder, decoder (`c_decoder`), target embedding (`c_embed`) and generator."
    def __init__(self, encoder, decoder, src_adapt, embed, generator):
        super().__init__()
        self.encoder = encoder
        self.src_adapt = src_adapt

        self.c_decoder = decoder
        self.c_embed = embed

        self.generator = generator

    def forward(self, src, tgt):
        "Decode the right-shifted `tgt` (start token id 101) against the encoded `src`."
        shifted = rshift(tgt, 101).long()
        mask = parallelogram_mask(shifted.size(-1), 25)

        memory = self.encode(src)
        return self.c_decoder(self.c_embed(shifted), memory, mask)

    def encode(self, src):
        "Adapt `src` and run it through the encoder."
        adapted = self.src_adapt(src)
        return self.encoder(adapted)

    def generate(self, outs):
        "Apply the generator to decoder outputs."
        return self.generator(outs)

In [None]:
def make_full_model(vocab, d_model, em_sz, N=4, drops=0, heads=8):
    """Assemble the `EncoderDecoder` transformer.

    `N` encoder and `N` decoder layers of width `d_model` with `heads` attention heads,
    a linear adapter from `em_sz`-wide source features, learned row/column positional
    embeddings over a `vocab`-sized token embedding, and a linear generator onto `vocab`.
    """
    c = deepcopy

    attn = MultiHeadedAttention(d_model, heads)
    ff = PositionwiseFeedForward(d_model, drops)
    #pos = PositionalEncoding(d_model, drops, 2000)

    v_embed = nn.Embedding(vocab, d_model, 0)
    row_emb = nn.Embedding(15, d_model//2, 0)

    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        nn.Sequential(
            nn.Linear(em_sz, d_model), Lambda(lambda x: x.mul_(8)) #increases gradients on weights by 8!
        ),
        LearnedPositionalEmbeddings(d_model, v_embed, row_emb, nn.Embedding(100,  d_model//2, 0)),  # 100 column positions
        nn.Linear(d_model, vocab)
    )

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    # The Xavier pass above also overwrote the all-zero `padding_idx` rows that
    # nn.Embedding set up; restore them so padding embeds to zeros again.
    for m in model.modules():
        if isinstance(m, nn.Embedding) and m.padding_idx is not None:
            with torch.no_grad():
                m.weight[m.padding_idx].zero_()

    return model

In [None]:
class ResnetBase(nn.Module):
    "Truncated `resnet34` backbone plus a conv and height pooling, emitting features as a sequence."
    def __init__(self, em_sz):
        super().__init__()

        # Slice end applied to the resnet children for each supported `em_sz`.
        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        backbone = models.resnet34(True)
        kept = list(backbone.children())[:cut]   #512,2,16;  256,4,32;  128,8,64
        self.base = nn.Sequential(*kept)

        self.conv = conv_layer(em_sz,em_sz)
        self.pool = nn.AdaptiveMaxPool2d((16,None))

    def forward(self, x):
        "Return features with the two trailing dims flattened and moved ahead of the channel dim."
        feats = self.base(x)
        feats = self.conv(feats)
        feats = self.pool(feats)
        return feats.flatten(2,3).permute(0,2,1)

In [None]:
class Img2Seq(nn.Module):
    "Chain an image encoder with a transformer and return the transformer's generated outputs."
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt):
        "Encode `src`, run the transformer on the features and `tgt`, then apply its `generate`."
        decoded = self.transformer(self.img_enc(src), tgt)
        return self.transformer.generate(decoded)

In [None]:
def cer(preds, targs, fn):
    "Sum, over the batch, the edit distance between decoded prediction and target divided by target length; returns `(error_sum, batch_size)`."
    bs = targs.size(0)
    decoded = torch.argmax(preds, dim=-1)
    total = 0
    for i, targ in enumerate(targs):
        p = str(fn(decoded[i]))   #.replace(' ', '')
        t = str(fn(targ)) #.replace(' ', '')
        # `or 1` avoids dividing by zero for an empty target string
        total += Lev.distance(t, p)/(len(t) or 1)
    return total, bs

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, heads=8, smoothing=0.1):
    "Build an `Img2Seq` network sized to `data.vocab` and wrap it in a `Learner` with label smoothing, CER, teacher forcing, BnFreeze and 10-step accumulation."
    vocab_sz = len(data.vocab.itos)
    net = Img2Seq(ResnetBase(em_sz),
                  make_full_model(vocab_sz, d_model, em_sz, N, drops, heads))
    callbacks = [TeacherForce, BnFreeze, partial(AccumulateScheduler, n_step=10)]
    return Learner(data, net, loss_func=LabelSmoothing(smoothing),
                   metrics=[CER(data.y.reconstruct)],
                   callback_fns=callbacks)

In [None]:
# Char-level learner: d_model=512, em_sz=256, 4 layers, 8 heads.
learn = make_learner(data, 512, 256, N=4, drops=0.1, heads=8)

# Char Arch (itos)

In [None]:
class LearnedPositionalEmbeddings(nn.Module):
    """Token embedding plus a learned (row, column) position embedding.

    Each sequence is split into lines at `nl_tok`. Every token up to the last
    non-zero token gets a 1-based row (line number) and a 1-based column
    (offset inside its line); everything after it keeps row = col = 0. The
    row and column embeddings are concatenated on the last dim and added to
    the token embedding, so together they must be as wide as `v_embed`.

    Args:
        d_model:   model width; the summed embedding is scaled by sqrt(d_model).
        v_embed:   token embedding module.
        row_embed: embedding table indexed by row number.
        col_embed: embedding table indexed by column number.
        dropout:   dropout probability applied to the result.
    """
    def __init__(self, d_model, v_embed, row_embed, col_embed, dropout=0.1):
        super(LearnedPositionalEmbeddings, self).__init__()
        self.nl_tok  = 4          # token id that ends a line
        self.d_model = d_model

        self.embed = v_embed #nn.Embedding(vocab, d_model, 0)
        self.dropout = nn.Dropout(p=dropout)

        self.rows = row_embed #nn.Embedding(15, d_model//2, 0)
        self.cols = col_embed #nn.Embedding(num_cols, d_model//2, 0)

    def forward(self, x):
        rows, cols = self.encode_spatial_positions(x)
        # Clamp BOTH indices to the last table entry: an over-long line or a
        # text with more lines than the row table would otherwise be an
        # out-of-range lookup (IndexError on CPU, device-side assert on CUDA).
        row_t = self.rows(torch.clamp(rows, max=self.rows.num_embeddings-1))
        col_t = self.cols(torch.clamp(cols, max=self.cols.num_embeddings-1))
        pos_enc = torch.cat((row_t, col_t), dim=-1)

        x = self.embed(x)
        x = (x + pos_enc) * math.sqrt(self.d_model)
        return self.dropout(x)

    def encode_spatial_positions(self, x):
        """Return `(rows, cols)` index tensors shaped and typed like `x`."""
        rows, cols = torch.zeros_like(x), torch.zeros_like(x)
        for ii, seq in enumerate(x.unbind()):
            nonpad = torch.nonzero(seq).flatten()
            if nonpad.numel() == 0:
                continue  # nothing but zeros: leave row/col at 0 instead of failing on [-1]
            nls = torch.nonzero(seq == self.nl_tok).flatten()
            # line ends = every newline position plus the last non-zero token;
            # one .tolist() replaces a host sync per split
            splits = torch.cat([nls, nonpad[-1:]]).tolist()

            p = 0
            for i, n in enumerate(splits, start=1):
                rows[ii, p:n+1] = i
                cols[ii, p:n+1] = torch.arange(1, n-p+2, device=x.device)
                p = n+1
        return rows, cols

In [None]:
class EncoderDecoder(nn.Module):
    """Encoder / single-decoder wrapper with a source adapter and an output generator.

    forward(src, tgt) right-shifts `tgt` with `rshift`, builds a
    `parallelogram_mask` of width 25 over the target length, encodes `src`
    and returns the decoder output (before `generate`).
    """
    def __init__(self, encoder, decoder, src_adapt, embed, generator):
        super().__init__()
        self.encoder, self.src_adapt = encoder, src_adapt
        self.c_decoder, self.c_embed = decoder, embed
        self.generator = generator

    def forward(self, src, tgt):
        shifted = rshift(tgt).long()
        attn_mask = parallelogram_mask(shifted.size(-1), 25)

        memory = self.encode(src)
        embedded = self.c_embed(shifted)
        return self.c_decoder(embedded, memory, attn_mask)

    def encode(self, src):
        adapted = self.src_adapt(src)
        return self.encoder(adapted)

    def generate(self, outs):
        return self.generator(outs)

In [None]:
def make_full_model(vocab, d_model, em_sz, N=4, drops=0, heads=8):
    """Assemble the single-decoder (char) EncoderDecoder transformer.

    Args:
        vocab:   vocabulary size (embedding rows and generator outputs).
        d_model: transformer width.
        em_sz:   width of the incoming image features, projected to d_model.
        N:       number of encoder layers and of decoder layers.
        drops:   dropout passed to the encoder/decoder layers and feed-forward.
        heads:   number of attention heads.

    Every parameter with more than one dim is Xavier-uniform initialised; the
    `padding_idx` rows of the embeddings are then reset to zero.
    """
    c = deepcopy

    attn = MultiHeadedAttention(d_model, heads)
    ff = PositionwiseFeedForward(d_model, drops)
    #pos = PositionalEncoding(d_model, drops, 2000)

    v_embed = nn.Embedding(vocab, d_model, 0)
    row_emb = nn.Embedding(15, d_model//2, 0)

    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        nn.Sequential(
            nn.Linear(em_sz, d_model), Lambda(lambda x: x.mul_(8)) #increases gradients on weights by 8!
        ),
        LearnedPositionalEmbeddings(d_model, v_embed, row_emb, nn.Embedding(100,  d_model//2, 0)),  #char
        nn.Linear(d_model, vocab)
    )

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    # The blanket Xavier pass above also overwrote the all-zero padding rows
    # that nn.Embedding sets up for padding_idx; those rows never receive a
    # gradient, so restore them or padding keeps a frozen random vector.
    with torch.no_grad():
        for m in model.modules():
            if isinstance(m, nn.Embedding) and m.padding_idx is not None:
                m.weight[m.padding_idx].zero_()

    return model

In [None]:
class ResnetBase(nn.Module):
    """Image encoder: leading resnet34 children + conv_layer + adaptive max pool.

    Only em_sz in {128, 256, 512} is supported; other values raise KeyError
    from the lookup table below.
    """
    def __init__(self, em_sz):
        super().__init__()

        # where to cut resnet34's child list for each em_sz   #512,2,16;  256,4,32;  128,8,64
        cut_for = {128: -4, 256: -3, 512: -2}
        cut = cut_for[em_sz]

        # resnet34(True): presumably pretrained=True -- TODO confirm
        backbone = models.resnet34(True)
        self.base = nn.Sequential(*list(backbone.children())[:cut])

        self.conv = conv_layer(em_sz, em_sz)
        self.pool = nn.AdaptiveMaxPool2d((16, None))  # height -> 16, width unchanged

    def forward(self, x):
        pooled = self.pool(self.conv(self.base(x)))
        # collapse dims 2-3 into one axis and swap it with dim 1
        seq = pooled.flatten(2, 3)
        return seq.permute(0, 2, 1)

In [None]:
class Img2Seq(nn.Module):
    """Image encoder followed by a transformer and its `generate` head."""
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt):
        transformer = self.transformer
        image_feats = self.img_enc(src)
        return transformer.generate(transformer(image_feats, tgt))

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, heads=8, smoothing=0.1):
    """Build the char-architecture Learner.

    Uses ResnetBase(em_sz) as image encoder, make_full_model for the
    transformer (vocabulary size from `data.vocab.itos`), a LabelSmoothing
    loss, a CER metric on `data.y.reconstruct` and the TeacherForce callback.
    """
    n_tokens = len(data.vocab.itos)
    encoder = ResnetBase(em_sz)
    seq_model = make_full_model(n_tokens, d_model, em_sz, N, drops, heads)
    return Learner(data, Img2Seq(encoder, seq_model),
                   loss_func=LabelSmoothing(smoothing),
                   metrics=[CER(data.y.reconstruct)],
                   callback_fns=[TeacherForce])

In [None]:
# def make_learner(data, d_model, em_sz, N=4, drops=0.1, heads=8, smoothing=0.1):
#     itos = data.vocab.itos
#     img_encoder = ResnetBase(em_sz)
#     transformer = make_full_model(len(itos), d_model, em_sz, N, drops, heads)
#     net = Img2Seq(img_encoder, transformer)
#     learn = Learner(data, net, loss_func=LabelSmoothing(smoothing),
#                     metrics=[CER(data.y.reconstruct)],
#                     callback_fns=[TeacherForce, BnFreeze, partial(AccumulateScheduler, n_step=3)])
#     return learn

In [None]:
# Build the learner for this section; positional args map to d_model=512, em_sz=256
# in the make_learner signature defined above.
learn = make_learner(data, 512, 256, N=4, drops=0.1, heads=8)

# Combo Arch (itos)

In [None]:
# Largest (unclamped) column index seen by each forward pass; diagnostic only.
length_results = []

class LearnedPositionalEmbeddings(nn.Module):
    """Token embedding plus a learned (row, column) position embedding.

    Sequences are split into lines at `nl_tok`; tokens up to the last non-zero
    token get a 1-based row and column, the rest keep 0. Row and column
    embeddings are concatenated on the last dim and added to the token
    embedding, then the sum is scaled by sqrt(d_model) and passed through
    dropout. Each forward pass also appends the largest column index to the
    module-level `length_results` list.
    """
    def __init__(self, d_model, v_embed, row_embed, col_embed, dropout=0.1):
        super(LearnedPositionalEmbeddings, self).__init__()
        self.nl_tok  = 4          # token id that ends a line
        self.d_model = d_model

        self.embed = v_embed #nn.Embedding(vocab, d_model, 0)
        self.dropout = nn.Dropout(p=dropout)

        self.rows = row_embed #nn.Embedding(15, d_model//2, 0)
        self.cols = col_embed #nn.Embedding(num_cols, d_model//2, 0)

    def forward(self, x):
        rows, cols = self.encode_spatial_positions(x)

        # Record before any lookup so the offending length is captured even if
        # something further down fails.
        length_results.append(cols.max().item())  # this is only for verification of lengths on cuda error

        # Clamp to the last table entry; unclamped indices past the table size
        # raise IndexError on CPU and a device-side assert on CUDA.
        row_t = self.rows(torch.clamp(rows, max=self.rows.num_embeddings-1))
        col_t = self.cols(torch.clamp(cols, max=self.cols.num_embeddings-1))
        pos_enc = torch.cat((row_t, col_t), dim=-1)

        x = self.embed(x)
        x = (x + pos_enc) * math.sqrt(self.d_model)
        return self.dropout(x)

    def encode_spatial_positions(self, x):
        """Return `(rows, cols)` index tensors shaped and typed like `x`."""
        rows, cols = torch.zeros_like(x), torch.zeros_like(x)
        for ii, seq in enumerate(x.unbind()):
            nonpad = torch.nonzero(seq).flatten()
            if nonpad.numel() == 0:
                continue  # all zeros: keep row/col at 0 rather than failing on [-1]
            nls = torch.nonzero(seq == self.nl_tok).flatten()
            # line ends = newline positions plus the last non-zero token
            splits = torch.cat([nls, nonpad[-1:]]).tolist()

            p = 0
            for i, n in enumerate(splits, start=1):
                rows[ii, p:n+1] = i
                cols[ii, p:n+1] = torch.arange(1, n-p+2, device=x.device)
                p = n+1
        return rows, cols
    

# c_ler = ler[::2]
# w_ler = ler[1::2]
# max(c_ler), max(w_ler)

In [None]:
class WordCharTransformer(nn.Module):
    """One shared encoder feeding a char decoder and a word decoder.

    forward(src, mix_tgt) expects `mix_tgt` to unpack into (char targets,
    word targets) and returns the two decoder outputs; `generate` maps both
    through the single shared generator.
    """
    def __init__(self, encoder, c_dec, w_dec, src_adapt, c_emb, w_emb, generator):
        super().__init__()
        self.encoder, self.src_adapt = encoder, src_adapt
        self.c_decoder, self.w_decoder = c_dec, w_dec
        self.c_embed, self.w_embed = c_emb, w_emb
        self.generator = generator

    def forward(self, src, mix_tgt):
        c_in, w_in, c_mask, w_mask = self.shift_with_masks(mix_tgt)

        memory = self.encode(src)
        char_outs = self.c_decoder(self.c_embed(c_in), memory, c_mask)
        word_outs = self.w_decoder(self.w_embed(w_in), memory, w_mask)
        return char_outs, word_outs

    def encode(self, src):
        adapted = self.src_adapt(src)
        return self.encoder(adapted)

    def generate(self, c_outs, w_outs):
        head = self.generator
        return head(c_outs), head(w_outs)

    def shift_with_masks(self, mix_tgt):
        "Right-shift both targets and build their banded masks (width 25 for char, 10 for word)."
        c_raw, w_raw = mix_tgt
        c_tgt = rshift(c_raw).long()
        w_tgt = rshift(w_raw).long()

        c_mask = parallelogram_mask(c_tgt.size(-1), 25)
        w_mask = parallelogram_mask(w_tgt.size(-1), 10)
#         c_mask = subsequent_mask(c_tgt.size(-1))
#         w_mask = subsequent_mask(w_tgt.size(-1))
        return c_tgt, w_tgt, c_mask, w_mask

In [None]:
def make_full_model(vocab, d_model, em_sz, N=4, drops=0, heads=8):
    """Assemble the WordCharTransformer (shared encoder, char + word decoders).

    Args:
        vocab:   vocabulary size (embedding rows and generator outputs).
        d_model: transformer width.
        em_sz:   width of the incoming image features, projected to d_model.
        N:       number of layers in the encoder and in each decoder.
        drops:   dropout passed to the encoder/decoder layers and feed-forward.
        heads:   number of attention heads.

    The token and row embeddings are shared by the char and word embedders;
    only the column tables differ (100 entries for char, 60 for word). Every
    parameter with more than one dim is Xavier-uniform initialised, after
    which the `padding_idx` rows of the embeddings are reset to zero.
    """
    c = deepcopy

    attn = MultiHeadedAttention(d_model, heads)
    ff = PositionwiseFeedForward(d_model, drops)
    #pos = PositionalEncoding(d_model, drops, 2000)

    v_embed = nn.Embedding(vocab, d_model, 0)
    row_emb = nn.Embedding(15, d_model//2, 0)
    c_col_emb = nn.Embedding(100, d_model//2, 0)
    w_col_emb = nn.Embedding(60, d_model//2, 0)

    model = WordCharTransformer(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        nn.Sequential(
            nn.Linear(em_sz, d_model), Lambda(lambda x: x.mul_(8))
            # .mul_(8) increases the gradients on these weights by 8!
        ),
        LearnedPositionalEmbeddings(d_model, v_embed, row_emb, c_col_emb),
        LearnedPositionalEmbeddings(d_model, v_embed, row_emb, w_col_emb),
        nn.Linear(d_model, vocab)
    )

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    # The blanket Xavier pass also overwrote the all-zero padding rows that
    # nn.Embedding sets up for padding_idx. Those rows get no gradient, so
    # put the zeros back or padding keeps a frozen random vector.
    with torch.no_grad():
        for m in model.modules():
            if isinstance(m, nn.Embedding) and m.padding_idx is not None:
                m.weight[m.padding_idx].zero_()

    return model

In [None]:
class ResnetBase(nn.Module):
    """resnet34 prefix + conv_layer + adaptive max pool, emitting a feature sequence.

    `em_sz` selects the truncation point (128, 256 or 512 only; KeyError
    otherwise) and is the channel count given to `conv_layer`.
    """
    def __init__(self, em_sz):
        super().__init__()

        # slice end into resnet34's children per em_sz   #512,2,16;  256,4,32;  128,8,64
        stop = {128: -4, 256: -3, 512: -2}[em_sz]

        # resnet34(True): presumably pretrained=True -- TODO confirm
        children = list(models.resnet34(True).children())
        self.base = nn.Sequential(*children[:stop])

        self.conv = conv_layer(em_sz, em_sz)
        self.pool = nn.AdaptiveMaxPool2d((16, None))  # fix height at 16, keep width

    def forward(self, x):
        out = self.base(x)
        out = self.pool(self.conv(out))
        # fold dims 2 and 3 together, then swap the folded axis with dim 1
        out = out.flatten(2, 3)
        return out.permute(0, 2, 1)

In [None]:
class Img2Seq(nn.Module):
    """Image encoder + two-headed (char, word) transformer.

    forward returns whatever `transformer.generate(char_outs, word_outs)` returns.
    """
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc, self.transformer = img_encoder, transformer

    def forward(self, src, tgt):
        image_feats = self.img_enc(src)
        decoded = self.transformer(image_feats, tgt)
        c_dec, w_dec = decoded
        return self.transformer.generate(c_dec, w_dec)

In [None]:
class MultiCER(LearnerCallback):
    """Callback that reports per-epoch char and word error rates as metrics.

    Errors from `cer` are summed over every batch of the epoch and divided by
    the summed sample count when the epoch ends.
    """
    _order=-20 # Needs to run before the recorder
    def __init__(self, learn, itos):
        super().__init__(learn)
        self.itos = itos

    def on_train_begin(self, **kwargs):
        self.learn.recorder.add_metric_names(['char', 'word'])

    def on_epoch_begin(self, **kwargs):
        # fresh accumulators for the new epoch
        self.c_errors = self.w_errors = self.total = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        (c_out, w_out), (c_targ, w_targ) = last_output, last_target
        # NOTE(review): the `cer` defined earlier in this notebook *calls* its
        # third argument; `self.itos` looks like a vocabulary list -- confirm
        # which `cer` is meant to be in scope here.
        c_err, n_samples = cer(c_out, c_targ, self.itos)
        w_err, _         = cer(w_out, w_targ, self.itos)
        self.c_errors += c_err
        self.w_errors += w_err
        self.total += n_samples

    def on_epoch_end(self, last_metrics, **kwargs):
        char_rate = self.c_errors/self.total
        word_rate = self.w_errors/self.total
        return add_metrics(last_metrics, [char_rate, word_rate])

In [None]:
class MultiLabelSmoothing(nn.Module):
    """Sum of two LabelSmoothing losses: pred[0] vs c_targ and pred[1] vs w_targ."""
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, c_targ, w_targ):
        # criterion is rebuilt on each call, so it always uses the current self.smoothing
        criterion = LabelSmoothing(self.smoothing)
        char_loss = criterion(pred[0], c_targ)
        word_loss = criterion(pred[1], w_targ)
        #print(f'char loss: {char_loss}  word_loss: {word_loss}')
        return char_loss + word_loss

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, attn_type='multi', heads=8, smoothing=0.1):
    """Build the combo (char + word) Learner with a MultiCER callback attached.

    `attn_type` is accepted but not used by this function.
    """
    itos = data.vocab.itos
    encoder = ResnetBase(em_sz)
    seq_model = make_full_model(len(itos), d_model, em_sz, N, drops, heads)

    learn = Learner(data, Img2Seq(encoder, seq_model),
                    loss_func=MultiLabelSmoothing(smoothing),
                    callback_fns=[TeacherForce])
    # MultiCER needs the learner instance, so it is appended after construction
    learn.callbacks.append(MultiCER(learn, itos))
    return learn

# Fixup Combo Arch (bert_tok)

In [None]:
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 nn.Conv2d with padding 1 and the given stride."""
    # positional order after the channel counts: kernel_size, stride, padding
    return nn.Conv2d(in_planes, out_planes, 3, stride, 1, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 nn.Conv2d with the given stride."""
    # positional order after the channel counts: kernel_size, stride
    return nn.Conv2d(in_planes, out_planes, 1, stride, bias=False)

class FixupBasicBlock(nn.Module):
    """Two-conv residual block with learnable scalar biases and a scalar scale.

    There are no normalisation layers: a scalar bias is added before each
    conv and before the ReLU, and the residual branch is multiplied by
    `scale` before the shortcut is added.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.bias1a = nn.Parameter(torch.zeros(1))
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bias1b = nn.Parameter(torch.zeros(1))
        self.relu = nn.ReLU(inplace=True)

        self.bias2a = nn.Parameter(torch.zeros(1))
        self.conv2 = conv3x3(planes, planes)
        self.scale = nn.Parameter(torch.ones(1))
        self.bias2b = nn.Parameter(torch.zeros(1))

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shifted = x + self.bias1a

        out = self.relu(self.conv1(shifted) + self.bias1b)
        out = self.conv2(out + self.bias2a) * self.scale + self.bias2b

        # the shortcut sees the same bias-shifted input as conv1 when it is projected
        shortcut = x if self.downsample is None else self.downsample(shifted)

        out += shortcut
        return self.relu(out)
    
class FixupBottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block with learnable scalar biases and scale.

    The last conv widens to `planes * expansion` channels. No normalisation
    layers are used; scalar biases precede each conv and each ReLU, and the
    residual branch is multiplied by `scale` before the shortcut is added.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.bias1a = nn.Parameter(torch.zeros(1))
        self.conv1 = conv1x1(inplanes, planes)
        self.bias1b = nn.Parameter(torch.zeros(1))

        self.bias2a = nn.Parameter(torch.zeros(1))
        self.conv2 = conv3x3(planes, planes, stride)
        self.bias2b = nn.Parameter(torch.zeros(1))

        self.bias3a = nn.Parameter(torch.zeros(1))
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.scale = nn.Parameter(torch.ones(1))
        self.bias3b = nn.Parameter(torch.zeros(1))

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shifted = x + self.bias1a

        out = self.relu(self.conv1(shifted) + self.bias1b)
        out = self.relu(self.conv2(out + self.bias2a) + self.bias2b)
        out = self.conv3(out + self.bias3a) * self.scale + self.bias3b

        # the shortcut sees the same bias-shifted input as conv1 when it is projected
        shortcut = x if self.downsample is None else self.downsample(shifted)

        out += shortcut
        return self.relu(out)

In [None]:
class FixupResNet(nn.Module):
    """Normalisation-free ResNet (Fixup-style init) that emits a feature sequence.

    The stem takes 1-channel input. After the four residual stages the
    feature map is adaptively max-pooled to height 16, its two trailing dims
    are merged into one axis that is swapped with the channel axis, and each
    position is projected to `d_model` by `fc`.

    Args:
        block:   residual block class (FixupBasicBlock or FixupBottleneck).
        layers:  number of blocks in each of the four stages.
        d_model: output feature width.
    """
    def __init__(self, block, layers, d_model=512):
        super().__init__()
        self.num_layers = sum(layers)
        self.inplanes = 64
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bias1 = nn.Parameter(torch.zeros(1))
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)

        self.layer4 = self._make_layer(block, 512, layers[3], stride=1)
        #self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Kept under its own name: assigning this to `self.maxpool` replaced
        # the stem pool above, so the adaptive pool ran right after conv1 too.
        self.pool = nn.AdaptiveMaxPool2d((16,None))
        self.bias2 = nn.Parameter(torch.zeros(1))
        self.fc = nn.Linear(512 * block.expansion, d_model)
        self.apply_init()

    def _make_layer(self, block, planes, blocks, stride=1):
        "Stack `blocks` residual blocks; only the first one strides / projects the shortcut."
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = conv1x1(self.inplanes, planes * block.expansion, stride)

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x + self.bias1)
        x = self.maxpool(x)          # stem 3x3 / stride-2 pool

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.pool(x) #self.avgpool(x)   # height -> 16, width unchanged
        #x = x.view(x.size(0), -1)
        # merge dims 2-3, swap the merged axis with channels, then project
        x = self.fc(x.flatten(2,3).permute(0,2,1) + self.bias2)
        return x

    def apply_init(self):
        "Initialise block convs, shortcut projections and Linear layers."
        def he_std(w):
            # sqrt(2 / (out_channels * kernel_area))
            return np.sqrt(2 / (w.shape[0] * np.prod(w.shape[2:])))

        for m in self.modules():
            if isinstance(m, FixupBasicBlock):
                nn.init.normal_(m.conv1.weight, mean=0,
                                std=he_std(m.conv1.weight) * self.num_layers ** (-0.5))
                nn.init.constant_(m.conv2.weight, 0)   # residual branch starts at zero
                if m.downsample is not None:
                    nn.init.normal_(m.downsample.weight, mean=0, std=he_std(m.downsample.weight))
            elif isinstance(m, FixupBottleneck):
                nn.init.normal_(m.conv1.weight, mean=0,
                                std=he_std(m.conv1.weight) * self.num_layers ** (-0.25))
                nn.init.normal_(m.conv2.weight, mean=0,
                                std=he_std(m.conv2.weight) * self.num_layers ** (-0.25))
                nn.init.constant_(m.conv3.weight, 0)   # residual branch starts at zero
                if m.downsample is not None:
                    nn.init.normal_(m.downsample.weight, mean=0, std=he_std(m.downsample.weight))
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.weight, 0)
                nn.init.constant_(m.bias, 0)

def fixup_resnet34(**kwargs):
    """Build a FixupResNet from FixupBasicBlock with stage depths [3, 4, 6, 3].

    Keyword arguments are forwarded to FixupResNet.
    """
    stage_depths = [3, 4, 6, 3]
    return FixupResNet(FixupBasicBlock, stage_depths, **kwargs)

In [None]:
class LearnedPositionalEmbeddings(nn.Module):
    """Token embedding plus a learned (row, column) position embedding.

    Sequences are split into lines at `nl_tok`. Tokens up to the last
    non-zero token get a 1-based row (line number) and 1-based column
    (offset inside the line); later positions keep row = col = 0. Row and
    column embeddings are concatenated on the last dim and added to the
    token embedding, so together they must be as wide as `v_embed`.

    Args:
        d_model:   model width; the summed embedding is scaled by sqrt(d_model).
        v_embed:   token embedding module.
        row_embed: embedding table indexed by row number.
        col_embed: embedding table indexed by column number.
        dropout:   dropout probability applied to the result.
    """
    def __init__(self, d_model, v_embed, row_embed, col_embed, dropout=0.1):
        super(LearnedPositionalEmbeddings, self).__init__()
        self.nl_tok  = 30522 #4   # token id that ends a line
        self.d_model = d_model

        self.embed = v_embed #nn.Embedding(vocab, d_model, 0)
        self.dropout = nn.Dropout(p=dropout)

        self.rows = row_embed #nn.Embedding(15, d_model//2, 0)
        self.cols = col_embed #nn.Embedding(num_cols, d_model//2, 0)

    def forward(self, x):
        rows, cols = self.encode_spatial_positions(x)

        # Clamp rows as well as cols: more lines than the row table holds would
        # otherwise be an out-of-range lookup (device-side assert on CUDA).
        row_t = self.rows(torch.clamp(rows, max=self.rows.num_embeddings-1))
        col_t = self.cols(torch.clamp(cols, max=self.cols.num_embeddings-1))  # clamp to max column value
        pos_enc = torch.cat((row_t, col_t), dim=-1)

        x = self.embed(x)
        x = (x + pos_enc) * math.sqrt(self.d_model)
        return self.dropout(x)

    def encode_spatial_positions(self, x):
        """Return `(rows, cols)` index tensors shaped and typed like `x`."""
        rows, cols = torch.zeros_like(x), torch.zeros_like(x)
        for ii, seq in enumerate(x.unbind()):
            nonpad = torch.nonzero(seq).flatten()
            if nonpad.numel() == 0:
                continue  # all zeros: keep row/col at 0 rather than failing on [-1]
            nls = torch.nonzero(seq == self.nl_tok).flatten()
            # line ends = newline positions plus the last non-zero token;
            # a single .tolist() avoids one host sync per split
            splits = torch.cat([nls, nonpad[-1:]]).tolist()

            p = 0
            for i, n in enumerate(splits, start=1):
                rows[ii, p:n+1] = i
                cols[ii, p:n+1] = torch.arange(1, n-p+2, device=x.device)
                p = n+1
        return rows, cols

In [None]:
class WordCharTransformer(nn.Module):
    """Shared encoder with a char decoder and a word decoder (no source adapter).

    forward(src, mix_tgt) expects `mix_tgt` to unpack into (char targets,
    word targets) and returns both decoder outputs; `generate` maps both
    through the single shared generator.
    """
    def __init__(self, encoder, c_dec, w_dec, c_emb, w_emb, generator):
        super().__init__()
        self.encoder = encoder
        self.c_decoder, self.w_decoder = c_dec, w_dec
        self.c_embed, self.w_embed = c_emb, w_emb
        self.generator = generator

    def forward(self, src, mix_tgt):
        c_in, w_in, c_mask, w_mask = self.shift_with_masks(mix_tgt)

        memory = self.encode(src)
        char_outs = self.c_decoder(self.c_embed(c_in), memory, c_mask)
        word_outs = self.w_decoder(self.w_embed(w_in), memory, w_mask)
        return char_outs, word_outs

    def encode(self, src):
        return self.encoder(src)

    def generate(self, c_outs, w_outs):
        head = self.generator
        return head(c_outs), head(w_outs)

    def shift_with_masks(self, mix_tgt):
        "Right-shift both targets with start token 101 and build their banded masks."
        c_raw, w_raw = mix_tgt
        c_tgt = rshift(c_raw, 101).long()
        w_tgt = rshift(w_raw, 101).long()

        c_mask = parallelogram_mask(c_tgt.size(-1), 25)
        w_mask = parallelogram_mask(w_tgt.size(-1), 10)
        return c_tgt, w_tgt, c_mask, w_mask

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0, heads=8):
    """Assemble the WordCharTransformer that takes d_model-wide features directly.

    Args:
        vocab:   vocabulary size (embedding rows and generator outputs).
        d_model: transformer width (also the expected input feature width,
                 since no source adapter is built here).
        N:       number of layers in the encoder and in each decoder.
        drops:   dropout passed to the encoder/decoder layers and feed-forward.
        heads:   number of attention heads.

    The token and row embeddings are shared by the char and word embedders.
    Every parameter with more than one dim is Xavier-uniform initialised,
    after which the `padding_idx` rows of the embeddings are reset to zero.
    """
    c = deepcopy

    attn = MultiHeadedAttention(d_model, heads)
    ff = PositionwiseFeedForward(d_model, drops)

    v_embed = nn.Embedding(vocab, d_model, 0)
    row_emb = nn.Embedding(15, d_model//2, 0)

    model = WordCharTransformer(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        LearnedPositionalEmbeddings(d_model, v_embed, row_emb, nn.Embedding(100, d_model//2, 0)),  #char
        LearnedPositionalEmbeddings(d_model, v_embed, row_emb, nn.Embedding(60,  d_model//2, 0)),  #word
        nn.Linear(d_model, vocab)
    )

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    # The blanket Xavier pass also overwrote the all-zero padding rows that
    # nn.Embedding sets up for padding_idx. Those rows get no gradient, so
    # put the zeros back or padding keeps a frozen random vector.
    with torch.no_grad():
        for m in model.modules():
            if isinstance(m, nn.Embedding) and m.padding_idx is not None:
                m.weight[m.padding_idx].zero_()

    return model

In [None]:
class Img2Seq(nn.Module):
    """Image encoder feeding a transformer that returns (char, word) decoder outputs."""
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt):
        transformer = self.transformer
        c_dec, w_dec = transformer(self.img_enc(src), tgt)
        return transformer.generate(c_dec, w_dec)

In [None]:
class MultiCER(LearnerCallback):
    """Callback that reports per-epoch char and word error rates as metrics.

    Ids are decoded with `learn.data.y.reconstruct_one`; errors from `cer`
    are summed over the epoch and divided by the summed sample count.
    """
    _order=-20 # Needs to run before the recorder
    def __init__(self, learn):
        super().__init__(learn)
        self.recon = learn.data.y.reconstruct_one

    def on_train_begin(self, **kwargs):
        self.learn.recorder.add_metric_names(['char', 'word'])

    def on_epoch_begin(self, **kwargs):
        # fresh accumulators for the new epoch
        self.c_errors = self.w_errors = self.total = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        (c_out, w_out), (c_targ, w_targ) = last_output, last_target
        c_err, n_samples = cer(c_out, c_targ, self.recon)
        w_err, _         = cer(w_out, w_targ, self.recon)
        self.c_errors += c_err
        self.w_errors += w_err
        self.total += n_samples

    def on_epoch_end(self, last_metrics, **kwargs):
        char_rate = self.c_errors/self.total
        word_rate = self.w_errors/self.total
        return add_metrics(last_metrics, [char_rate, word_rate])

In [None]:
class MultiLabelSmoothing(nn.Module):
    """Char loss + word loss, each computed with LabelSmoothing(smoothing)."""
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, c_targ, w_targ):
        char_pred, word_pred = pred[0], pred[1]
        # built per call, so a later change to self.smoothing takes effect
        smooth = LabelSmoothing(self.smoothing)
        char_loss = smooth(char_pred, c_targ)
        word_loss = smooth(word_pred, w_targ)
        #print(f'char loss: {char_loss}  word_loss: {word_loss}')
        return char_loss + word_loss

In [None]:
class MultiAccumulateScheduler(AccumulateScheduler):
    "AccumulateScheduler variant that reads the batch size from `last_input[0]`."
    def on_batch_begin(self, last_input, last_target, **kwargs):
        "accumulate samples and batches"
        first_input = last_input[0]
        self.acc_samples += first_input.shape[0]
        self.acc_batches += 1

In [None]:
def make_learner(data, d_model, N=4, drops=0.1, attn_type='multi', heads=8, smoothing=0.1):
    """Build the Fixup-ResNet + WordCharTransformer Learner.

    Callbacks: TeacherForce, MultiCER and MultiAccumulateScheduler(n_step=5).
    `attn_type` is accepted but not used by this function.
    """
    n_tokens = len(data.vocab.itos)
    encoder = fixup_resnet34(d_model=d_model) #ResnetBase(em_sz)
    seq_model = make_full_model(n_tokens, d_model, N, drops, heads)

    cb_fns = [TeacherForce, MultiCER, partial(MultiAccumulateScheduler, n_step=5)]
    return Learner(data, Img2Seq(encoder, seq_model),
                   loss_func=MultiLabelSmoothing(smoothing),
                   callback_fns=cb_fns)

In [None]:
# Build the learner for this section; the positional 512 is d_model in the
# make_learner signature defined above (this variant takes no em_sz).
learn = make_learner(data, 512, N=4, drops=0.1, heads=8)

# Integrated Combo Arch (bert_tok)

In [None]:
class LearnedPositionalEmbeddings(nn.Module):
    """Token embedding plus a learned (row, column) position embedding.

    Sequences are split into lines at `nl_tok`. Tokens up to the last
    non-zero token get a 1-based row (line number) and 1-based column
    (offset inside the line); later positions keep row = col = 0. Row and
    column embeddings are concatenated on the last dim and added to the
    token embedding, so together they must be as wide as `v_embed`.

    Args:
        d_model:   model width; the summed embedding is scaled by sqrt(d_model).
        v_embed:   token embedding module.
        row_embed: embedding table indexed by row number.
        col_embed: embedding table indexed by column number.
        dropout:   dropout probability applied to the result.
    """
    def __init__(self, d_model, v_embed, row_embed, col_embed, dropout=0.1):
        super(LearnedPositionalEmbeddings, self).__init__()
        self.nl_tok  = 30522 #4   # token id that ends a line
        self.d_model = d_model

        self.embed = v_embed #nn.Embedding(vocab, d_model, 0)
        self.dropout = nn.Dropout(p=dropout)

        self.rows = row_embed #nn.Embedding(15, d_model//2, 0)
        self.cols = col_embed #nn.Embedding(num_cols, d_model//2, 0)

    def forward(self, x):
        rows, cols = self.encode_spatial_positions(x)

        # Rows are clamped like cols: a text with more lines than the row
        # table holds would otherwise index past the table and fail.
        max_row = self.rows.num_embeddings-1
        max_col = self.cols.num_embeddings-1
        row_t = self.rows(torch.clamp(rows, max=max_row))
        col_t = self.cols(torch.clamp(cols, max=max_col))  # clamp to max column value
        pos_enc = torch.cat((row_t, col_t), dim=-1)

        x = self.embed(x)
        x = (x + pos_enc) * math.sqrt(self.d_model)
        return self.dropout(x)

    def encode_spatial_positions(self, x):
        """Return `(rows, cols)` index tensors shaped and typed like `x`."""
        rows, cols = torch.zeros_like(x), torch.zeros_like(x)
        for ii, seq in enumerate(x.unbind()):
            nonpad = torch.nonzero(seq).flatten()
            if nonpad.numel() == 0:
                continue  # all zeros: keep row/col at 0 rather than failing on [-1]
            nls = torch.nonzero(seq == self.nl_tok).flatten()
            # line ends = newline positions plus the last non-zero token;
            # a single .tolist() avoids one host sync per split
            splits = torch.cat([nls, nonpad[-1:]]).tolist()

            p = 0
            for i, n in enumerate(splits, start=1):
                rows[ii, p:n+1] = i
                cols[ii, p:n+1] = torch.arange(1, n-p+2, device=x.device)
                p = n+1
        return rows, cols

In [None]:
class ResnetBase(nn.Module):
    """Image encoder built from the leading children of resnet34.

    `em_sz` (128, 256 or 512; KeyError otherwise) decides where the resnet34
    child list is cut and is the channel count given to `conv_layer`.
    """
    def __init__(self, em_sz):
        super().__init__()

        # slice end into resnet34's children per em_sz   #512,2,16;  256,4,32;  128,8,64
        end = {128: -4, 256: -3, 512: -2}[em_sz]

        # resnet34(True): presumably pretrained=True -- TODO confirm
        resnet = models.resnet34(True)
        head_layers = list(resnet.children())[:end]
        self.base = nn.Sequential(*head_layers)

        self.conv = conv_layer(em_sz, em_sz)
        self.pool = nn.AdaptiveMaxPool2d((16, None))  # height -> 16, width unchanged

    def forward(self, x):
        h = self.conv(self.base(x))
        h = self.pool(h)
        # merge dims 2 and 3 into one axis and swap it with dim 1
        return h.flatten(2, 3).permute(0, 2, 1)

In [None]:
class WordCharTransformer(nn.Module):
    """Shared encoder (behind a source adapter) with char and word decoders.

    forward(src, mix_tgt) expects `mix_tgt` to unpack into (char targets,
    word targets) and returns both decoder outputs; `generate` maps both
    through the single shared generator.
    """
    def __init__(self, encoder, c_dec, w_dec, src_adapt, c_emb, w_emb, generator):
        super().__init__()
        self.encoder, self.src_adapt = encoder, src_adapt
        self.c_decoder, self.w_decoder = c_dec, w_dec
        self.c_embed, self.w_embed = c_emb, w_emb
        self.generator = generator

    def forward(self, src, mix_tgt):
        c_in, w_in, c_mask, w_mask = self.shift_with_masks(mix_tgt)

        memory = self.encode(src)
        char_outs = self.c_decoder(self.c_embed(c_in), memory, c_mask)
        word_outs = self.w_decoder(self.w_embed(w_in), memory, w_mask)
        return char_outs, word_outs

    def encode(self, src):
        adapted = self.src_adapt(src)
        return self.encoder(adapted)

    def generate(self, c_outs, w_outs):
        head = self.generator
        return head(c_outs), head(w_outs)

    def shift_with_masks(self, mix_tgt):
        "Right-shift both targets with start token 101 and build their banded masks."
        c_raw, w_raw = mix_tgt
        c_tgt = rshift(c_raw, 101).long()
        w_tgt = rshift(w_raw, 101).long()

        c_mask = parallelogram_mask(c_tgt.size(-1), 25)
        w_mask = parallelogram_mask(w_tgt.size(-1), 10)
#         c_mask = subsequent_mask(c_tgt.size(-1))
#         w_mask = subsequent_mask(w_tgt.size(-1))
        return c_tgt, w_tgt, c_mask, w_mask

In [None]:
def make_full_model(vocab, d_model, em_sz, N=4, drops=0, heads=8):
    """Assemble the WordCharTransformer (shared encoder, char + word decoders).

    Args:
        vocab:   vocabulary size (embedding rows and generator outputs).
        d_model: transformer width.
        em_sz:   width of the incoming image features, projected to d_model.
        N:       number of layers in the encoder and in each decoder.
        drops:   dropout passed to the encoder/decoder layers and feed-forward.
        heads:   number of attention heads.

    The token and row embeddings are shared by the char and word embedders.
    Every parameter with more than one dim is Xavier-uniform initialised,
    after which the `padding_idx` rows of the embeddings are reset to zero.
    """
    c = deepcopy

    attn = MultiHeadedAttention(d_model, heads)
    ff = PositionwiseFeedForward(d_model, drops)
    #pos = PositionalEncoding(d_model, drops, 2000)

    v_embed = nn.Embedding(vocab, d_model, 0)
    row_emb = nn.Embedding(15, d_model//2, 0)

    model = WordCharTransformer(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        nn.Sequential(
            nn.Linear(em_sz, d_model), Lambda(lambda x: x.mul_(8)) #increases gradients on weights by 8!
        ),
        LearnedPositionalEmbeddings(d_model, v_embed, row_emb, nn.Embedding(100, d_model//2, 0)),  #char
        LearnedPositionalEmbeddings(d_model, v_embed, row_emb, nn.Embedding(60,  d_model//2, 0)),  #word
        nn.Linear(d_model, vocab)
    )

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    # The blanket Xavier pass also overwrote the all-zero padding rows that
    # nn.Embedding sets up for padding_idx. Those rows get no gradient, so
    # put the zeros back or padding keeps a frozen random vector.
    with torch.no_grad():
        for m in model.modules():
            if isinstance(m, nn.Embedding) and m.padding_idx is not None:
                m.weight[m.padding_idx].zero_()

    return model

In [None]:
class Img2Seq(nn.Module):
    """Run the image encoder, then the two-headed transformer and its generator."""
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc, self.transformer = img_encoder, transformer

    def forward(self, src, tgt):
        encoded = self.img_enc(src)
        c_dec, w_dec = self.transformer(encoded, tgt)
        generated = self.transformer.generate(c_dec, w_dec)
        return generated

In [None]:
class MultiCER(LearnerCallback):
    """Track char and word error rates per epoch and add them to the metrics.

    Ids are decoded with `learn.data.y.reconstruct_one`; `cer` results are
    accumulated batch by batch and averaged over the accumulated sample count.
    """
    _order=-20 # Needs to run before the recorder
    def __init__(self, learn):
        super().__init__(learn)
        self.recon = learn.data.y.reconstruct_one

    def on_train_begin(self, **kwargs):
        self.learn.recorder.add_metric_names(['char', 'word'])

    def on_epoch_begin(self, **kwargs):
        # reset the running sums
        self.c_errors = 0
        self.w_errors = 0
        self.total = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        char_pred, word_pred = last_output
        char_targ, word_targ = last_target
        char_err, batch_sz = cer(char_pred, char_targ, self.recon)
        word_err = cer(word_pred, word_targ, self.recon)[0]
        self.c_errors += char_err
        self.w_errors += word_err
        self.total += batch_sz

    def on_epoch_end(self, last_metrics, **kwargs):
        rates = [self.c_errors/self.total, self.w_errors/self.total]
        return add_metrics(last_metrics, rates)

In [None]:
class MultiLabelSmoothing(nn.Module):
    """Loss for two-headed output: LabelSmoothing on pred[0] plus LabelSmoothing on pred[1]."""
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, c_targ, w_targ):
        # a fresh criterion per call picks up the current self.smoothing
        criterion = LabelSmoothing(self.smoothing)
        char_loss = criterion(pred[0], c_targ)
        word_loss = criterion(pred[1], w_targ)
        #print(f'char loss: {char_loss}  word_loss: {word_loss}')
        total = char_loss + word_loss
        return total

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, attn_type='multi', heads=8, smoothing=0.1):
    """Build the integrated combo Learner (ResnetBase + WordCharTransformer).

    Callbacks: TeacherForce and MultiCER. `attn_type` is accepted but not
    used by this function.
    """
    n_tokens = len(data.vocab.itos)
    encoder = ResnetBase(em_sz)
    seq_model = make_full_model(n_tokens, d_model, em_sz, N, drops, heads)
    return Learner(data, Img2Seq(encoder, seq_model),
                   loss_func=MultiLabelSmoothing(smoothing),
                   callback_fns=[TeacherForce, MultiCER])

In [None]:
# Build the learner for this section; positional args map to d_model=512, em_sz=256
# in the make_learner signature defined above.
learn = make_learner(data, 512, 256, N=4, drops=0.1, heads=8)

In [None]:
# number of trainable parameters: total element count over every parameter
# of learn.model that has requires_grad set
sum(p.numel() for p in learn.model.parameters() if p.requires_grad)

# word char (separate vocabs) Arch

In [None]:
class LearnedPositionalEmbeddings(nn.Module):
    """Token embedding plus learned 2-D (row, column) position embeddings.

    A sequence is read as rows of tokens: a row ends at every occurrence of
    `nl_tok`, and the last non-zero token closes the final row. Row and column
    indices are 1-based; 0 is kept for positions after the last non-zero token
    and is the `padding_idx` of all three embedding tables.

    Args:
        d_model: output width; the row and column tables each supply
            `d_model//2` features, concatenated as (row, col).
        vocab: number of entries in the token table.
        num_cols: number of entries in the column table.
        dropout: dropout probability applied to the result.

    Indices beyond either position table are clamped to its last entry, so an
    unusually long row or an unusually tall sequence cannot index out of range.
    """
    def __init__(self, d_model, vocab, num_cols, dropout=0.1):
        super().__init__()
        self.nl_tok  = 4   # token id that closes a row
        self.d_model = d_model

        self.embed = nn.Embedding(vocab, d_model, 0)
        self.dropout = nn.Dropout(p=dropout)

        self.rows = nn.Embedding(15, d_model//2, 0)
        self.cols = nn.Embedding(num_cols, d_model//2, 0)

    def forward(self, x):
        "Return `dropout((embed(x) + cat(row_emb, col_emb)) * sqrt(d_model))` for integer token ids `x` of shape (bs, sl)."
        rows,cols = self.encode_spatial_positions(x)

        # Clamp both indices to their table size (rows were previously unchecked).
        row_t = self.rows(torch.clamp(rows, max=self.rows.num_embeddings-1))
        col_t = self.cols(torch.clamp(cols, max=self.cols.num_embeddings-1))
        pos_enc = torch.cat((row_t, col_t), dim=-1)

        x = self.embed(x)
        x = (x + pos_enc) * math.sqrt(self.d_model)
        return self.dropout(x)

    def encode_spatial_positions(self, x):
        """Return `(rows, cols)` tensors shaped like `x` holding 1-based row / column numbers.

        Positions after the last non-zero token keep 0. A sequence made only of
        zeros is left entirely at 0 instead of raising.
        """
        rows,cols = torch.zeros_like(x),torch.zeros_like(x)
        for ii,seq in enumerate(x.unbind()):
            nonpad = torch.nonzero(seq).flatten()
            if nonpad.numel() == 0: continue   # nothing but padding
            nls = torch.nonzero(seq==self.nl_tok).flatten()
            # row ends: every newline position, then the last real token
            splits = torch.cat([nls, nonpad[-1:]]).tolist()

            p = 0
            for i,n in enumerate(splits, start=1):
                rows[ii,p:n+1] = i
                # built on x's device/dtype so no cross-device copy is needed
                cols[ii,p:n+1] = torch.arange(1, n-p+2, dtype=x.dtype, device=x.device)
                p = n+1
        return rows,cols

In [None]:
class WordCharTransformer(nn.Module):
    "One shared encoder feeding two decoders (char and word), each with its own embedding and generator."
    def __init__(self, encoder, c_dec, w_dec, src_adapt, c_emb, w_emb, c_gen, w_gen):
        super().__init__()
        self.encoder = encoder
        self.src_adapt = src_adapt

        self.w_decoder = w_dec
        self.c_decoder = c_dec

        self.w_embed = w_emb
        self.c_embed = c_emb
        self.w_generator = w_gen
        self.c_generator = c_gen

    def forward(self, src, mix_tgt):
        "Decode the `(char, word)` targets in `mix_tgt` against the encoded `src`; returns `(char_outs, word_outs)`."
        chars, words, char_mask, word_mask = self.shift_with_masks(mix_tgt)
        memory = self.encode(src)
        char_outs = self.c_decoder(self.c_embed(chars), memory, char_mask)
        word_outs = self.w_decoder(self.w_embed(words), memory, word_mask)
        return char_outs, word_outs

    def encode(self, src):
        "Run `src` through `src_adapt`, then the encoder."
        adapted = self.src_adapt(src)
        return self.encoder(adapted)

    def generate(self, c_outs, w_outs):
        "Apply the char generator to `c_outs` and the word generator to `w_outs`."
        return self.c_generator(c_outs), self.w_generator(w_outs)

    def shift_with_masks(self, mix_tgt):
        "Right-shift both targets (via `rshift`) and build their parallelogram masks (diagonal 25 for chars, 10 for words)."
        chars, words = mix_tgt
        chars = rshift(chars).long()
        words = rshift(words).long()
        char_mask = parallelogram_mask(chars.size(-1), 25)
        word_mask = parallelogram_mask(words.size(-1), 10)
        return chars, words, char_mask, word_mask

In [None]:
def make_full_model(c_vocab, w_vocab, d_model, em_sz, N=4, drops=0.1, heads=8):
    "Build a `WordCharTransformer` with separate char/word vocab sizes, then Xavier-init every parameter with more than one dim."
    attn = MultiHeadedAttention(d_model, heads)
    ff = PositionwiseFeedForward(d_model, drops)

    def decoder_stack():
        # each stack gets its own deep copies of the attention / feed-forward prototypes
        return Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)

    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    c_dec = decoder_stack()
    w_dec = decoder_stack()
    # Linear projection em_sz -> d_model followed by an in-place x8 scaling
    # (author's note: this multiplies the gradients by 8!)
    src_adapt = nn.Sequential(nn.Linear(em_sz, d_model), Lambda(lambda x: x.mul_(8)))
    c_emb = LearnedPositionalEmbeddings(d_model, c_vocab, 100)
    w_emb = LearnedPositionalEmbeddings(d_model, w_vocab, 60)
    c_gen = nn.Linear(d_model, c_vocab)
    w_gen = nn.Linear(d_model, w_vocab)

    model = WordCharTransformer(encoder, c_dec, w_dec, src_adapt, c_emb, w_emb, c_gen, w_gen)

    for param in model.parameters():
        if param.dim() > 1: nn.init.xavier_uniform_(param)

    return model

In [None]:
class ResnetBase(nn.Module):
    "Truncated pretrained resnet34 + `conv_layer` + adaptive max-pool, flattened to a (bs, positions, channels) sequence."
    def __init__(self, em_sz):
        super().__init__()

        # em_sz -> slice end applied to the resnet34 children (KeyError for any other em_sz)
        # author's note per em_sz:  512,2,16;  256,4,32;  128,8,64
        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        backbone = models.resnet34(True)
        self.base = nn.Sequential(*list(backbone.children())[:cut])

        self.conv = conv_layer(em_sz, em_sz)
        self.pool = nn.AdaptiveMaxPool2d((16, None))

    def forward(self, x):
        "Return features with dims 2 and 3 flattened and the last two axes swapped."
        feats = self.pool(self.conv(self.base(x)))
        return feats.flatten(2, 3).permute(0, 2, 1)

In [None]:
class Img2Seq(nn.Module):
    "Chain an image encoder with a two-headed transformer."
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt):
        "Encode `src`, decode against `tgt`, and return whatever `transformer.generate` yields for the char/word outputs."
        char_outs, word_outs = self.transformer(self.img_enc(src), tgt)
        return self.transformer.generate(char_outs, word_outs)

In [None]:
class CERMetric(LearnerCallback):
    "Per-epoch `char` / `word` error rates; each output is scored by `cer` with its own itos list."
    _order = -20  # Needs to run before the recorder

    def __init__(self, learn, c_itos, w_itos):
        super().__init__(learn)
        self.c_itos, self.w_itos = c_itos, w_itos

    def on_train_begin(self, **kwargs):
        "Register the `char` and `word` metric names."
        self.learn.recorder.add_metric_names(['char', 'word'])

    def on_epoch_begin(self, **kwargs):
        "Start the epoch with all counters at zero."
        self.c_errors = self.w_errors = self.total = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Accumulate both error counts; `total` grows by the size returned from the char `cer` call."
        char_out, word_out = last_output
        char_tgt, word_tgt = last_target
        char_err, count = cer(char_out, char_tgt, self.c_itos)
        word_err, _ = cer(word_out, word_tgt, self.w_itos)
        self.c_errors += char_err
        self.w_errors += word_err
        self.total += count

    def on_epoch_end(self, last_metrics, **kwargs):
        "Report errors divided by `total` for both outputs."
        rates = [self.c_errors / self.total, self.w_errors / self.total]
        return add_metrics(last_metrics, rates)

In [None]:
class MultiLabelSmoothing(nn.Module):
    "Char + word label-smoothing loss for a model that returns `(char_pred, word_pred)`."
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, c_targ, w_targ):
        "Score `pred[0]` against `c_targ` and `pred[1]` against `w_targ` with `LabelSmoothing`, and add the two."
        smoothed = LabelSmoothing(self.smoothing)
        c_loss = smoothed(pred[0], c_targ)
        w_loss = smoothed(pred[1], w_targ)
        return c_loss + w_loss

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, attn_type='multi', heads=8, smoothing=0.1):
    """Build the separate-vocab word/char learner.

    Reads the char and word `itos` lists from the training dataset's `y`,
    builds `Img2Seq(ResnetBase(em_sz), make_full_model(...))`, and attaches the
    two-head loss, `TeacherForce`, and a `CERMetric` using both vocabularies.

    `attn_type` is kept in the signature for call compatibility but is not
    forwarded: this architecture's `make_full_model(c_vocab, w_vocab, d_model,
    em_sz, N, drops, heads)` has no such parameter, and passing it positionally
    raised a TypeError (too many positional arguments).
    """
    y = data.train_dl.dl.dataset.y
    c_vocab = y.char_vocab.itos
    w_vocab = y.word_vocab.itos
    img_encoder = ResnetBase(em_sz)
    # keywords keep N / drops / heads bound to the intended parameters
    transformer = make_full_model(len(c_vocab), len(w_vocab), d_model, em_sz,
                                  N=N, drops=drops, heads=heads)
    net = Img2Seq(img_encoder, transformer)
    learn = Learner(data, net, loss_func=MultiLabelSmoothing(smoothing), callback_fns=[TeacherForce])
    learn.callbacks.append(CERMetric(learn, c_vocab, w_vocab))
    return learn

# word char (combo vocab) w/ DistilBert LM

In [None]:
# Debug log: the largest (unclamped) column index seen in each forward pass.
length_results = []

class LearnedPositionalEmbeddings(nn.Module):
    """Token embedding plus learned (row, column) position embeddings built from caller-supplied tables.

    Args:
        d_model: model width, used for the `sqrt(d_model)` scaling.
        v_embed: token embedding module (may be shared between instances).
        row_embed: row-index embedding module.
        col_embed: column-index embedding module.
        dropout: dropout probability applied to the result.

    A row ends at every `nl_tok` and at the last non-zero token; row/column
    numbers are 1-based and 0 is left for trailing zero positions. Indices that
    exceed `row_embed.num_embeddings-1` / `col_embed.num_embeddings-1` are
    clamped to the last entry instead of indexing out of range.
    """
    def __init__(self, d_model, v_embed, row_embed, col_embed, dropout=0.1):
        super().__init__()
        self.nl_tok  = 4   # token id that closes a row
        self.d_model = d_model

        self.embed = v_embed #nn.Embedding(vocab, d_model, 0)
        self.dropout = nn.Dropout(p=dropout)

        self.rows = row_embed #nn.Embedding(15, d_model//2, 0)
        self.cols = col_embed #nn.Embedding(num_cols, d_model//2, 0)

    def forward(self, x):
        "Return `dropout((embed(x) + cat(row_emb, col_emb)) * sqrt(d_model))`; also appends the raw max column index to `length_results`."
        rows,cols = self.encode_spatial_positions(x)

        # clamp so an over-long row / over-tall sequence cannot overrun the tables
        row_t = self.rows(torch.clamp(rows, max=self.rows.num_embeddings-1))
        col_t = self.cols(torch.clamp(cols, max=self.cols.num_embeddings-1))

        length_results.append(cols.max().item())  # this is only for verification of lengths on cuda error
        pos_enc = torch.cat((row_t, col_t), dim=-1)

        x = self.embed(x)
        x = (x + pos_enc) * math.sqrt(self.d_model)
        return self.dropout(x)

    def encode_spatial_positions(self, x):
        "Return `(rows, cols)` shaped like `x` with 1-based row / column numbers; all-zero sequences stay at 0."
        rows,cols = torch.zeros_like(x),torch.zeros_like(x)
        for ii,seq in enumerate(x.unbind()):
            nonpad = torch.nonzero(seq).flatten()
            if nonpad.numel() == 0: continue   # nothing but padding
            nls = torch.nonzero(seq==self.nl_tok).flatten()
            # row ends: every newline position, then the last real token
            splits = torch.cat([nls, nonpad[-1:]]).tolist()

            p = 0
            for i,n in enumerate(splits, start=1):
                rows[ii,p:n+1] = i
                cols[ii,p:n+1] = torch.arange(1, n-p+2, dtype=x.dtype, device=x.device)
                p = n+1
        return rows,cols
    

# c_ler = ler[::2]
# w_ler = ler[1::2]
# max(c_ler), max(w_ler)

In [None]:
class WordCharTransformer(nn.Module):
    "Shared encoder with char and word decoders that project through one common generator."
    def __init__(self, encoder, c_dec, w_dec, src_adapt, c_emb, w_emb, generator):
        super().__init__()
        self.encoder = encoder
        self.src_adapt = src_adapt

        self.c_decoder = c_dec
        self.w_decoder = w_dec

        self.c_embed = c_emb
        self.w_embed = w_emb

        self.generator = generator

    def forward(self, src, mix_tgt):
        "Decode the char and word targets in `mix_tgt` against the encoded `src`; returns `(char_outs, word_outs)`."
        chars, words, char_mask, word_mask = self.shift_with_masks(mix_tgt)
        memory = self.encode(src)
        char_outs = self.c_decoder(self.c_embed(chars), memory, char_mask)
        word_outs = self.w_decoder(self.w_embed(words), memory, word_mask)
        return char_outs, word_outs

    def encode(self, src):
        "Adapt `src` with `src_adapt` and run the encoder."
        adapted = self.src_adapt(src)
        return self.encoder(adapted)

    def generate(self, c_outs, w_outs):
        "Apply the shared generator to each decoder output."
        return self.generator(c_outs), self.generator(w_outs)

    def shift_with_masks(self, mix_tgt):
        "Right-shift the first two targets and build parallelogram masks (diagonal 25 / 10); the third element of `mix_tgt` is ignored."
        chars, words, _ = mix_tgt
        chars = rshift(chars).long()
        words = rshift(words).long()
        # Alternative tried by the author: subsequent_mask(size) for both targets.
        char_mask = parallelogram_mask(chars.size(-1), 25)
        word_mask = parallelogram_mask(words.size(-1), 10)
        return chars, words, char_mask, word_mask

In [None]:
def make_full_model(vocab, d_model, em_sz, N=4, drops=0, attn_type='multi', heads=8):
    "Build the combined-vocab `WordCharTransformer` (token and row embeddings shared by both heads) and Xavier-init its >1-dim parameters."
    # any attn_type other than 'multi' selects single-headed attention
    attn = MultiHeadedAttention(d_model, heads) if attn_type=='multi' else SingleHeadedAttention(d_model)
    ff = PositionwiseFeedForward(d_model, drops)

    # Tables shared between the char and word embeddings; only the column tables differ.
    v_embed = nn.Embedding(vocab, d_model, 0)
    row_emb = nn.Embedding(15, d_model//2, 0)
    c_col_emb = nn.Embedding(100, d_model//2, 0)
    w_col_emb = nn.Embedding(60, d_model//2, 0)

    def decoder_stack():
        return Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)

    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    c_dec = decoder_stack()
    w_dec = decoder_stack()
    # .mul_(8) increases the gradients on these weights by 8!
    src_adapt = nn.Sequential(nn.Linear(em_sz, d_model), Lambda(lambda x: x.mul_(8)))
    c_emb = LearnedPositionalEmbeddings(d_model, v_embed, row_emb, c_col_emb)
    w_emb = LearnedPositionalEmbeddings(d_model, v_embed, row_emb, w_col_emb)
    generator = nn.Linear(d_model, vocab)

    model = WordCharTransformer(encoder, c_dec, w_dec, src_adapt, c_emb, w_emb, generator)

    for param in model.parameters():
        if param.dim() > 1: nn.init.xavier_uniform_(param)

    return model

In [None]:
class ResnetBase(nn.Module):
    "Image encoder: truncated pretrained resnet34, a `conv_layer`, and an adaptive max-pool, emitted as a sequence."
    def __init__(self, em_sz):
        super().__init__()

        # Supported em_sz values and the matching end index into resnet34's children.
        # author's note per em_sz:  512,2,16;  256,4,32;  128,8,64
        slice_end = {128: -4, 256: -3, 512: -2}
        stop = slice_end[em_sz]

        resnet = models.resnet34(True)
        kept = list(resnet.children())[:stop]
        self.base = nn.Sequential(*kept)

        self.conv = conv_layer(em_sz, em_sz)
        self.pool = nn.AdaptiveMaxPool2d((16, None))

    def forward(self, x):
        "Flatten dims 2-3 of the pooled feature map and swap the last two axes."
        out = self.base(x)
        out = self.conv(out)
        out = self.pool(out)
        return out.flatten(2, 3).permute(0, 2, 1)

In [None]:
class DistilBertMLM(nn.Module):
    "Re-tokenise the argmax predictions with the DistilBERT tokenizer and run them through `DistilBertForMaskedLM`."
    def __init__(self, vocab, name='distilbert-base-uncased'):
        super().__init__()
        self.tokenizer = DistilBertTokenizer.from_pretrained(name)
        self.model = DistilBertForMaskedLM.from_pretrained(name)
        self.vocab = vocab
        self.model.train()   # the wrapped LM is put in train mode at construction

    def forward(self, t:tensor):
        "Argmax `t` over its last dim, textify each row with `vocab`, encode with the tokenizer, pad via `c_pad`, and return element 0 of the LM output."
        best_ids = torch.argmax(t, dim=-1)
        lbls = [
            np.array(self.tokenizer.encode(decode_spec_tokens(self.vocab.textify(row)),
                                           add_special_tokens=True))
            for row in best_ids
        ]
        input_ids = c_pad(np.array(lbls))
        lm_out = self.model(input_ids.to(device=device))
        return lm_out[0]

In [None]:
class Img2Seq(nn.Module):
    "Image encoder + two-headed transformer, with a language model applied to the word output."
    def __init__(self, img_encoder, transformer, lm):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer
        self.lm = lm

    def forward(self, src, tgt):
        "Return the generated outputs followed by `lm(outs[1])` as one flat tuple."
        dec_c, dec_w = self.transformer(self.img_enc(src), tgt)
        outs = self.transformer.generate(dec_c, dec_w)
        return (*outs, self.lm(outs[1]))

In [None]:
class CERMetric(LearnerCallback):
    """Per-epoch `char`, `word` and `bert` error rates.

    The char and word outputs are scored by `cer` with `learn.data.vocab.itos`;
    the bert output with `learn.data.bert_vocab.itos`. All three rates are
    divided by the size returned from the char `cer` call.
    """
    _order=-20 # Needs to run before the recorder
    def __init__(self, learn):
        super().__init__(learn)
        self.itos = learn.data.vocab.itos
        self.b_itos = learn.data.bert_vocab.itos

    def on_train_begin(self, **kwargs):
        "Register the three metric names with the recorder."
        self.learn.recorder.add_metric_names(['char', 'word', 'bert'])

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Accumulate the three error counts for this batch."
        # The model in this section returns a flat (char, word, bert) tuple;
        # unpacking it as ((char, word), bert) raised "too many values to unpack".
        # The nested layout is still accepted for backward compatibility.
        if len(last_output) == 2: (c_out, w_out),b_out = last_output
        else: c_out, w_out, b_out = last_output
        c_targ, w_targ, b_targ = last_target
        c_error,size = cer(c_out, c_targ, self.itos)
        w_error,_    = cer(w_out, w_targ, self.itos)
        b_error,_    = cer(b_out, b_targ, self.b_itos)
        self.c_errors += c_error
        self.w_errors += w_error
        self.b_errors += b_error
        self.total += size

    def on_epoch_begin(self, **kwargs):
        "Reset all counters."
        self.c_errors, self.w_errors, self.b_errors, self.total = 0, 0, 0, 0

    def on_epoch_end(self, last_metrics, **kwargs):
        "Append the three error/total ratios to the epoch's metrics."
        mets = [self.c_errors/self.total, self.w_errors/self.total, self.b_errors/self.total]
        return add_metrics(last_metrics, mets)

In [None]:
class MultiLabelSmoothing(nn.Module):
    "Label-smoothing loss summed over the char, word and bert predictions; prints the three parts on every call."
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, c_targ, w_targ, b_targ):
        "Pair `pred[0..2]` with the char/word/bert targets, print each loss, and return their sum."
        criterion = LabelSmoothing(self.smoothing)
        targets = (c_targ, w_targ, b_targ)
        cl, wl, bl = [criterion(pred[i], targets[i]) for i in range(3)]
        print(f'char loss: {cl}  word_loss: {wl}  bert_loss:{bl}')
        return cl + wl + bl

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, attn_type='multi', heads=8, smoothing=0.1):
    "Build the combined-vocab learner whose model also runs a `DistilBertMLM` over the word output."
    vocab = data.vocab
    net = Img2Seq(
        ResnetBase(em_sz),
        make_full_model(len(vocab.itos), d_model, em_sz, N, drops, attn_type, heads),
        DistilBertMLM(vocab),
    )
    learn = Learner(data, net, loss_func=MultiLabelSmoothing(smoothing), callback_fns=[TeacherForce])
    # CERMetric needs the learner instance, so it is appended after construction
    learn.callbacks.append(CERMetric(learn))
    return learn

# Feature Pyramid Network

In [None]:
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
import torchvision

In [None]:
# https://github.com/kuangliu/pytorch-fpn/blob/master/fpn.py

class ResNetFPN(ResNet):
    "ResNet backbone with a three-level top-down feature pyramid, emitted as one concatenated sequence."
    def __init__(self, block, layers, em_sz):
        super().__init__(block, layers)

        self.expansion = block.expansion
        # the parent's pooling / classifier head is dropped
        self.avgpool = None
        self.fc = None

        def smooth():
            return nn.Conv2d(em_sz, em_sz, kernel_size=3, stride=1, padding=1)

        def lateral(width):
            return nn.Conv2d(width*self.expansion, em_sz, kernel_size=1, stride=1, padding=0)

        # Smooth layers - reduce aliasing effects from upsampling
        self.smooth1 = smooth()
        self.smooth2 = smooth()

        # Lateral layers - reduce dimensionality (a fourth level for layer1 was left disabled by the author)
        self.latlayer0 = lateral(512)
        self.latlayer1 = lateral(256)
        self.latlayer2 = lateral(128)

    def _upsample_add(self, x, y):
        "Bilinearly resize `x` to `y`'s spatial size (any output size is allowed) and add `y`."
        height, width = y.size()[2:]
        resized = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=False)
        return resized + y

    def forward(self, x):
        "Return the smoothed p3, p4 and the raw p5 maps, flattened, concatenated on the last axis, doubled, and permuted to (bs, positions, em_sz)."
        # Bottom-up - resnet backbone (author's size notes: c2 128, c3 64, c4 32, c5 16)
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        c2 = self.layer1(stem)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        # Top-down
        p5 = self.latlayer0(c5)
        p4 = self._upsample_add(p5, self.latlayer1(c4))
        p3 = self._upsample_add(p4, self.latlayer2(c3))

        # Smoothing (p3 is built from the unsmoothed p4 above)
        p4 = self.smooth1(p4)
        p3 = self.smooth2(p3)

        flat = [level.flatten(2,3) for level in (p3, p4, p5)]
        cat = torch.cat(flat, dim=-1)
        return cat.mul_(2).permute(0,2,1)

In [None]:
def FPN(arch='resnet34', em_sz=256):
    """Build a `ResNetFPN` and load the matching pretrained torchvision checkpoint.

    Args:
        arch: 'resnet34' (BasicBlock) or 'resnet50' (Bottleneck).
        em_sz: channel width of the pyramid layers.

    Returns:
        The model with backbone weights loaded non-strictly, so keys that do
        not match (the pyramid layers, the removed head) are ignored.

    Raises:
        ValueError: if `arch` is not one of the two supported names.
    """
    # Compare by value: `is` on string literals only works by interning accident.
    if arch == 'resnet34':
        block = BasicBlock
        url = 'https://download.pytorch.org/models/resnet34-333f7ec4.pth'
    elif arch == 'resnet50':
        block = Bottleneck
        url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
    else:
        # raising a bare string is itself a TypeError; use a real exception
        raise ValueError(f'arch must be either resnet34 or resnet50, got {arch!r}')

    # load pretrained -- resnet34: BasicBlock; resnet50: Bottleneck (both use [3,4,6,3])
    model = ResNetFPN(block, [3,4,6,3], em_sz)
    # torch.hub hosts this loader; torchvision.models.utils only re-exported it and no longer exists
    sd = torch.hub.load_state_dict_from_url(url, progress=True)
    model.load_state_dict(sd, strict=False)
    return model

In [None]:
class WordCharTransformer(nn.Module):
    "Shared encoder with char and word decoders that also share one embedding and one generator."
    def __init__(self, encoder, c_dec, w_dec, src_adapt, embed, generator):
        super().__init__()
        self.encoder = encoder
        self.src_adapt = src_adapt

        self.c_decoder = c_dec
        self.w_decoder = w_dec

        self.embed = embed

        self.generator = generator

    def forward(self, src, mix_tgt):
        "Decode the `(char, word)` targets in `mix_tgt` against the encoded `src`; returns `(char_outs, word_outs)`."
        chars, words, char_mask, word_mask = self.shift_with_masks(mix_tgt)
        memory = self.encode(src)
        char_outs = self.c_decoder(self.embed(chars), memory, char_mask)
        word_outs = self.w_decoder(self.embed(words), memory, word_mask)
        return char_outs, word_outs

    def encode(self, src):
        "Adapt `src` with `src_adapt` and run the encoder."
        adapted = self.src_adapt(src)
        return self.encoder(adapted)

    def generate(self, c_outs, w_outs):
        "Apply the shared generator to each decoder output."
        return self.generator(c_outs), self.generator(w_outs)

    def shift_with_masks(self, mix_tgt):
        "Right-shift both targets with token id 101 prepended and build a `subsequent_mask` for each."
        chars, words = mix_tgt
        # NOTE(review): 101 is presumably the BERT tokenizer's start-token id - confirm
        chars = rshift(chars, 101).long()
        words = rshift(words, 101).long()
        return chars, words, subsequent_mask(chars.size(-1)), subsequent_mask(words.size(-1))

In [None]:
def make_full_model(vocab, d_model, em_sz, N=4, drops=0, heads=8):
    "Build the shared-embedding `WordCharTransformer` for the FPN encoder and Xavier-init its >1-dim parameters."
    attn = MultiHeadedAttention(d_model, heads)
    ff = PositionwiseFeedForward(d_model, drops)

    def decoder_stack():
        return Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)

    # The encoder depth is fixed at 2 here; `N` only sets the decoder depth.
    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), 2)
    c_dec = decoder_stack()
    w_dec = decoder_stack()
    src_adapt = nn.Linear(em_sz, d_model)
    embed = nn.Sequential(Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000))
    generator = nn.Linear(d_model, vocab)

    model = WordCharTransformer(encoder, c_dec, w_dec, src_adapt, embed, generator)

    for param in model.parameters():
        if param.dim() > 1: nn.init.xavier_uniform_(param)

    return model

In [None]:
class Img2Seq(nn.Module):
    "Image encoder followed by a two-headed transformer."
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt):
        "Encode the image, decode both targets, and return the result of `transformer.generate`."
        image_feats = self.img_enc(src)
        dec_chars, dec_words = self.transformer(image_feats, tgt)
        return self.transformer.generate(dec_chars, dec_words)

In [None]:
class MultiCER(LearnerCallback):
    "Epoch-level `char` and `word` error rates computed with `cer` and the dataset's `reconstruct_one`."
    _order = -20  # Needs to run before the recorder

    def __init__(self, learn):
        super().__init__(learn)
        self.recon = learn.data.y.reconstruct_one

    def on_train_begin(self, **kwargs):
        "Tell the recorder about the two extra metric columns."
        names = ['char', 'word']
        self.learn.recorder.add_metric_names(names)

    def on_epoch_begin(self, **kwargs):
        "Clear the accumulators."
        self.c_errors = self.w_errors = self.total = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Add the batch's char and word errors; `total` grows by the size from the char `cer` call."
        (c_pred, w_pred), (c_true, w_true) = last_output, last_target
        c_err, batch_size = cer(c_pred, c_true, self.recon)
        w_err, _ = cer(w_pred, w_true, self.recon)
        self.c_errors += c_err
        self.w_errors += w_err
        self.total += batch_size

    def on_epoch_end(self, last_metrics, **kwargs):
        "Append both error/total ratios to `last_metrics`."
        ratios = [self.c_errors/self.total, self.w_errors/self.total]
        return add_metrics(last_metrics, ratios)

In [None]:
class MultiLabelSmoothing(nn.Module):
    "Two-head label-smoothing loss: char loss plus word loss."
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, c_targ, w_targ):
        "Build a `LabelSmoothing` with the stored factor, score both heads, and return the sum."
        loss_fn = LabelSmoothing(self.smoothing)
        char_part = loss_fn(pred[0], c_targ)
        word_part = loss_fn(pred[1], w_targ)
        return char_part + word_part

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, heads=8, smoothing=0.1):
    "Build the FPN-based learner: `Img2Seq(FPN('resnet34', em_sz), make_full_model(...))` with the two-head loss and CER callback."
    vocab_size = len(data.vocab.itos)
    net = Img2Seq(FPN('resnet34', em_sz),
                  make_full_model(vocab_size, d_model, em_sz, N, drops, heads))
    return Learner(data, net, loss_func=MultiLabelSmoothing(smoothing),
                   callback_fns=[TeacherForce, MultiCER])

In [None]:
# Build the FPN learner: d_model=512, em_sz=256, N=4, dropout 0.1, 8 attention heads.
learn = make_learner(data, 512, 256, N=4, drops=0.1, heads=8)

In [None]:
# Number of trainable parameters: total element count over the model's parameters
# whose requires_grad is set (bare expression, so the notebook echoes the value).
sum(p.numel() for p in learn.model.parameters() if p.requires_grad)

# Multi-Resolution Arch

## Image Encoders

### Resnet34

In [None]:
class ImageEncoder(nn.Module):
    "resnet34 image encoder: truncated pretrained backbone, `conv_layer`, adaptive max-pool, sequence output."
    def __init__(self, em_sz):
        super().__init__()

        # em_sz selects where the resnet34 children are cut (other values raise KeyError)
        # author's note per em_sz:  512,2,16;  256,4,32;  128,8,64
        cut_for = {128: -4, 256: -3, 512: -2}
        cut = cut_for[em_sz]

        resnet = models.resnet34(True)
        self.base = nn.Sequential(*list(resnet.children())[:cut])

        self.conv = conv_layer(em_sz, em_sz)
        self.pool = nn.AdaptiveMaxPool2d((16, None))

    def forward(self, x):
        "Flatten dims 2-3 of the pooled features and swap the last two axes."
        pooled = self.pool(self.conv(self.base(x)))
        return pooled.flatten(2, 3).permute(0, 2, 1)

### Xception

In [None]:
from fastai.vision.models.cadene_models import xception_cadene

In [None]:
class SeparableConv2d(nn.Module):
    "Depthwise conv (one group per input channel) followed by a 1x1 pointwise conv."
    def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):
        super().__init__()

        # depthwise: keeps the channel count, groups == in_channels
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride,
                               padding=padding, dilation=dilation, groups=in_channels, bias=bias)
        # pointwise: 1x1 conv mixing channels up/down to out_channels
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1,
                                   padding=0, dilation=1, groups=1, bias=bias)

    def forward(self,x):
        "Apply the depthwise conv, then the pointwise conv."
        return self.pointwise(self.conv1(x))

In [None]:
class Block(nn.Module):
    "Residual block of `reps` (ReLU, SeparableConv2d 3x3, BatchNorm) units, with a projected skip when shape or stride changes."
    def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True):
        super().__init__()

        # A 1x1 conv + BN on the skip path is only needed when channels or stride differ.
        if out_filters != in_filters or strides != 1:
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        def unit(cin, cout):
            return [nn.ReLU(inplace=True),
                    SeparableConv2d(cin, cout, 3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(cout)]

        layers = []
        width = in_filters
        # The channel change happens in the first unit (grow_first) or the last one.
        if grow_first:
            layers += unit(in_filters, out_filters)
            width = out_filters

        for _ in range(reps - 1):
            layers += unit(width, width)

        if not grow_first:
            layers += unit(in_filters, out_filters)

        if start_with_relu:
            # leading ReLU must not be in-place: its input is reused by the skip path
            layers[0] = nn.ReLU(inplace=False)
        else:
            layers = layers[1:]

        if strides != 1:
            layers.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*layers)

    def forward(self,inp):
        "Return `rep(inp)` plus the (optionally projected) skip connection, added in place."
        x = self.rep(inp)

        if self.skip is None:
            shortcut = inp
        else:
            shortcut = self.skipbn(self.skip(inp))

        x += shortcut
        return x

In [None]:
class ImageEncoder(nn.Module):
    "Xception image encoder: truncated `xception_cadene`, `conv_layer`, adaptive max-pool, sequence output."
    def __init__(self, em_sz):
        super().__init__()

        # em_sz -> slice end into the xception model (other values raise KeyError)
        # author's notes per em_sz:
        #   728, 4, 32
        #   1024, 2, 16
        #   2048, 2, 16
        cut = {728: -6, 1024: -5, 2048: -1}[em_sz]

        xception = xception_cadene(True)
        self.base = xception[:cut]

        self.conv = conv_layer(em_sz, em_sz)
        self.pool = nn.AdaptiveMaxPool2d((16, None))

    def forward(self, x):
        "Flatten dims 2-3 of the pooled features and swap the last two axes."
        feats = self.base(x)
        feats = self.pool(self.conv(feats))
        return feats.flatten(2, 3).permute(0, 2, 1)

### SEResnet50

In [None]:
from fastai.vision.models.cadene_models import se_resnet50

In [None]:
class ImageEncoder(nn.Module):
    "SE-ResNet50 image encoder: truncated `se_resnet50`, `conv_layer`, adaptive max-pool, sequence output."
    def __init__(self, em_sz):
        super().__init__()

        # em_sz -> slice end into the model's children (other values raise KeyError)
        # author's notes per em_sz:
        #   512, 8, 64
        #   1024, 4, 32
        #   2048, 2, 16
        cut = {512: -4, 1024: -3, 2048: -2}[em_sz]

        senet = se_resnet50(True)
        kept = list(senet.children())[:cut]
        self.base = nn.Sequential(*kept)

        self.conv = conv_layer(em_sz, em_sz)
        self.pool = nn.AdaptiveMaxPool2d((16, None))

    def forward(self, x):
        "Flatten dims 2-3 of the pooled features and swap the last two axes."
        pooled = self.pool(self.conv(self.base(x)))
        return pooled.flatten(2, 3).permute(0, 2, 1)

## Arch

In [None]:
class ResnetBase(nn.Module):
    "Multi-resolution encoder: truncated resnet34 features plus three stride-2 reductions, concatenated as one sequence."
    def __init__(self, em_sz, d_model):
        super().__init__()

        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        backbone = models.resnet34(True)
        self.base = nn.Sequential(*list(backbone.children())[:cut])   # author's note: 16x16
        # Each stage halves the resolution (author's notes: 8x8, 4x4, 2x2).
        # NOTE(review): the stages take d_model input channels while `base` emits
        # em_sz-dependent channels - looks like em_sz must equal d_model; confirm.
        self.c1 = conv_layer(d_model, d_model, stride=2)
        self.c2 = conv_layer(d_model, d_model, stride=2)
        self.c3 = conv_layer(d_model, d_model, stride=2)

    def forward(self, x):
        "Return the base map and its three reductions, each flattened over dims 2-3, concatenated on the last axis and permuted to (bs, positions, channels)."
        levels = [self.base(x)]
        for reduce in (self.c1, self.c2, self.c3):
            levels.append(reduce(levels[-1]))
        cat = torch.cat([level.flatten(2,3) for level in levels], dim=-1)
        return cat.permute(0,2,1)

In [None]:
class EncoderDecoder(nn.Module):
    "Container wiring a source adapter, encoder, target embedding, decoder and generator together."
    def __init__(self, encoder, decoder, src_adapt, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_adapt = src_adapt
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, tgt_mask=None):
        "Encode `src`, embed `tgt`, and decode; the generator is not applied here."
        memory = self.encode(src)
        embedded = self.embed(tgt)
        return self.decode(memory, embedded, tgt_mask)

    def encode(self, src):
        "Adapt `src`, then run the encoder."
        adapted = self.src_adapt(src)
        return self.encoder(adapted)

    def decode(self, src, tgt, tgt_mask=None):
        "Call the decoder as `decoder(tgt, src, tgt_mask)` - note the argument order swap."
        return self.decoder(tgt, src, tgt_mask)

    def embed(self, tgt):
        "Apply the target embedding."
        return self.tgt_embed(tgt)

    def generate(self, outs):
        "Apply the generator to decoder outputs."
        return self.generator(outs)

In [None]:
def make_full_model(vocab, d_model, em_sz, N=4, drops=0.2, attn_type='multi', heads=8):
    "Build a single-decoder `EncoderDecoder`; no parameter initialisation is done here."
    # any attn_type other than 'multi' selects single-headed attention
    attn = MultiHeadedAttention(d_model, heads) if attn_type=='multi' else SingleHeadedAttention(d_model)

    ff = PositionwiseFeedForward(d_model, drops)
    pos = PositionalEncoding(d_model, drops, 2000)

    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    decoder = Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    # Source side: linear projection and in-place x8 scaling; `pos` is not applied to the source.
    src_adapt = nn.Sequential(nn.Linear(em_sz, d_model), Lambda(lambda x: x.mul_(8)))
    # Target side: token embedding followed by the positional encoding.
    tgt_embed = nn.Sequential(Embeddings(d_model, vocab), pos)
    generator = nn.Linear(d_model, vocab)

    return EncoderDecoder(encoder, decoder, src_adapt, tgt_embed, generator)

In [None]:
class Img2Seq(nn.Module):
    "Image encoder + transformer; teacher-forced when `tgt` is given, greedy decoding otherwise."
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        "Return generator outputs: one per target position in training, or one per greedy step (at most `seq_len`) at inference."
        # training
        if tgt is not None:
            feats = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt, tgt_mask)    # ([bs, sl, d_model])
            return self.transformer.generate(dec_outs)           # ([bs, sl, vocab])

        # inference (greedy decode)
        with torch.no_grad():
            feats = self.transformer.encode(self.img_enc(src))
            # every sequence starts from a single token with id 1
            tgt = torch.ones((src.size(0), 1), dtype=torch.long, device=device)

            steps = []
            for _ in progress_bar(range(seq_len)):
                mask = subsequent_mask(tgt.size(-1))
                embedded = self.transformer.embed(tgt)
                dec_outs = self.transformer.decode(feats, embedded, mask)

                # only the newest position is scored at each step
                prob = self.transformer.generate(dec_outs[:,-1])
                steps.append(prob)
                pred = torch.argmax(prob, dim=-1, keepdim=True)
                # stop once every sequence in the batch predicts token 0
                if (pred==0).all(): break
                tgt = torch.cat([tgt,pred], dim=-1)
            return torch.stack(steps).transpose(1,0).contiguous()

In [None]:
def init_params(model):
    "Xavier-uniform initialise, in place, every parameter of `model` that has more than one dimension."
    matrices = (p for p in model.parameters() if p.dim() > 1)
    for weight in matrices:
        nn.init.xavier_uniform_(weight)

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, attn_type='multi', heads=8, smoothing=0.1):
    """Build the multi-resolution learner.

    Creates `ResnetBase(em_sz, d_model)` and an `EncoderDecoder` sized to the
    training set's target vocabulary, applies `init_params` to the transformer,
    and wraps both in a `Learner` with label-smoothing loss, the CER metric and
    `TeacherForce`.
    """
    img_encoder = ResnetBase(em_sz, d_model)
    voc_len = len(data.train_ds.y.vocab.itos)
    # use the vocabulary size computed above; the previous `len(itos)` relied on
    # an `itos` name that is not defined in this function
    transformer = make_full_model(voc_len, d_model, em_sz, N=N, drops=drops, attn_type=attn_type, heads=heads)
    transformer.apply(init_params)
    net = Img2Seq(img_encoder, transformer)
    return Learner(data, net, loss_func=LabelSmoothing(smoothing=smoothing),
                   metrics=[CER(data.y.reconstruct)], callback_fns=[TeacherForce])

In [None]:
# Build the multi-resolution learner: d_model=512, em_sz=512, N=4, no dropout, multi-head attention with 8 heads.
learn = make_learner(data, 512, 512, N=4, drops=0, attn_type='multi', heads=8)
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

In [None]:
# Restore the checkpoint saved under the name 'combo_512_mix_res9'.
learn.load('combo_512_mix_res9')
# Trailing None - presumably so the cell does not echo the value returned by load().
None

# Adjust State Dict and Split

## Add LM to model state_dict

In [None]:
# Read a saved checkpoint from PATH/models, mapping its tensors onto `device`.
sd = torch.load(PATH/'models/combo<6_char_word2.pth', map_location=device)

In [None]:
from collections import OrderedDict
# Re-keyed copy of the checkpoint: every 'transformer.encoder*' entry is stored twice,
# once under a 'c_encoder' name and once under a 'w_encoder' name; all other entries are copied as-is.
# NOTE(review): the load cell further down passes sd['model'], not `model_sd` -- confirm this copy is still needed.
model_sd = OrderedDict()

for k,v in sd['model'].items():
    if k.startswith('transformer.encoder'):
        # str.replace rewrites every occurrence of 'encoder' in the key, not only the prefix.
        c_k = k.replace('encoder', 'c_encoder')
        w_k = k.replace('encoder', 'w_encoder')
        # Both new keys point at the same tensor object (no copy is made).
        model_sd[c_k] = v
        model_sd[w_k] = v
    else:
        model_sd[k] = v

In [None]:
# Expose the checkpoint's `c_generator` weight/bias under the plain `generator` keys as well
# (same tensors, added in place to sd['model']).
sd['model']['transformer.generator.weight'] = sd['model']['transformer.c_generator.weight']
sd['model']['transformer.generator.bias']  = sd['model']['transformer.c_generator.bias']

In [None]:
# strict=False: keys that do not match between checkpoint and model are skipped instead of raising.
# NOTE(review): this loads sd['model'] directly; the re-keyed `model_sd` built above is not used here.
learn.model.load_state_dict(sd['model'], strict=False)

## Load and split learner

In [None]:
# Split the model into layer groups at `m.transformer`, then display how many groups resulted.
learn.split(lambda m: [m.transformer]); None
len(learn.layer_groups)

In [None]:
# Freeze up to layer group 1, then display whether the generator weight is still trainable.
learn.freeze_to(1)
learn.model.transformer.generator.weight.requires_grad #.conv[0].weight.requires_grad

In [None]:
# Undo the freeze above so all layer groups train again.
learn.unfreeze()

# Experimentation

In [None]:
# Restore the weights saved under the name 'combo_full_bert_tok'.
learn.load('combo_full_bert_tok'); None

In [None]:
# Run the learning-rate finder and plot the recorded curve with a suggested rate marked.
learn.lr_find()
learn.recorder.plot(suggestion=True)

In [None]:
# Learning-rate range handed to `fit_one_cycle` as `max_lr` in the next cell.
lrs = slice(1e-4,1e-3)
# lrs = 1e-3

In [None]:
# Train for 3 one-cycle epochs; SaveModelCallback writes checkpoints under the name 'full_combo_char_itos'.
# The comment log below records results from earlier runs.
learn.fit_one_cycle(3, max_lr=lrs, callbacks=[SaveModelCallback(learn, name='full_combo_char_itos')])
# CHARS / WORDS
# d_model:512, em_sz:256, resnet34, 5cycle(1e-3), N:4, drops:0.1, multi:8

# 4.805066	4.812465	0.051946	07:56   bert_tok char only, p-mask(25)   'sm_char_pmask2'
# chars:     0.56753   .07858

# fit (not one_cycle); bert_tok char only
# 36.014137	29.500330	0.354282	07:09   freeze_to(-1); only tfmr 1e-3
# 6.228029	5.532068	0.062850	07:43   2nd - unfreeze; 4(1e-4,1e-3)
# 3.253963	3.492219	0.039213	07:44   3rd - 2(1e-5,1e-4)    'sm_char_pmask'

# sm dataset, sz:256, bs:100, sl:100, 5cycle(1e-3)
# 2.858166	3.149298	0.034373	07:34   char only (itos), p-mask(25)   'sm_char_itos'
# greedy:    0.22749   .06057

# sm dataset, sz:256, bs:80, seq_len:100, 5cycle(1e-3), word_itos:60k
# 4.847745	5.941104	0.027651	0.029362	08:52   chars/words independent; sharing img features  'sm_char_word'
# chars:    0.346770    0.03947     3.0s
# words:    0.903070    0.03809
# 4.197776	5.396721	0.024970	0.022618	14:57   2d pos, parallel_mask(25,10)  'sm_char_word2'
# chars:    0.336420    0.03665     2.8s
# words:    0.019130    0.02098

# bs: 60, 2d pos, parallel_mask(25,10)
# 6.224502	7.240156	0.025720	0.042237	12:05   combo: itos  'sm_combo'
# chars:    2.10703   .05335
# words:    0.66827   .02541
# 6.176572	7.659454	0.038917	0.039825	11:38   combo: bert_tok  'sm_combo_bert_tok'
# chars:    0.89475   .04761
# words:    0.98421   .02428
# 6.128839	7.527630	0.042009	0.048798	14:26   combo: bert_tok; w/ STN before ResnetBase
# chars:    0.64202   .06120
# words:    0.75339   .02804

# bert_tok(words), bs: 80
# 12.985831	13.813116	0.381876	07:27   2d pos, parallel_mask(10)  'sm_word_bert'   
# words:    1.171370    0.54847

# combo_145<=6, sz:512, bs:15, seq_len:600/200
# 40.977287	36.501549	0.039365	0.133388	34:21   "", train from scratch no preloading  'combo<6_char_word'
# 25.210848	24.398312	0.024259	0.088754	38:35   "", preload 'sm_char_word2'    'combo<6_char_word2'
# chars:    18.64750    0.01549
# words:    7.597820    0.07416

# 50.447319	46.204708	0.045057	0.114878	40:11   preload sm_combo_bert_tok, 5(1e-4),  'combo<6_bert_tok'
# chars:    32.05150    0.03646
# words:    30.36220    0.11244
# --no improvement--   2nd run; 1(1e-5); Freeze img_enc, BnFreeze, AccumulateScheduler(3)
# --no improvement--   2nd run; slice(3e-6,2e-5)

# combo_145>=6, sz:512, bs:5, seq_len:750/250
#  ????     178.2133    0.03660     0.113363            preload combo<6, 3(5e-5)    'combo>6_bert_tok'
# chars:    716.0820    0.06242
# words:    324.4240    0.11281
#  ????     171.1198    0.03438     0.108465           'combo>6_bert_tok2'
# chars:    494.2160    0.05982
# words:    330.0630    0.20758

# word decoder only; bs:15; lrs: slice(6e-6,2e-4)
# 78.159187	66.368530	0.063820	59:51   preload combo>6_bert_tok
# words:    115.2590    0.09386
# 63.500256	54.264275	0.052015	59:35   2nd run   'combo>6_bert_tok_word'
# words:    120.4400    0.08898


# combo full - char itos only, sz:512, bs:20, sl: 700, 3cycle(1e-4,1e-3)
# 17.519344	12.332630	0.012053	1:33:27     preload combo<6_char_word2;  'full_combo_char_itos'
#  chars:    14.1142   .03507
#     pg:    109.775   .06105
# upload:    81.1906   .43663

# combo test - char itos only, sz:512, bs:20, sl:700, 5cycle(1e-5,1e-4)
# 11.622255	8.140586	0.049724	38:43      'test_combo_char_itos'
#  chars:    10.6597   .10455
#     pg:    118.629   .05764
# upload:    83.0442   .28538


# combo full - word decoder only, sz:512, bs:15, word_len:300, lrs: slice(1e-6,2e-5)
# 48.172344	39.704689	0.058319	1:31:44  preload combo>6_bert_tok_word    'combo_full_bert_tok_word'
# words:    49.00830    0.05337
# 50.854053	39.678852	0.058055	1:24:53  2nd run    'combo_full_bert_tok_word2'
# words:    50.46420    0.05224
#    pg:    73.38090    0.06796

# 51.945342 40.016457   0.058763    1:27:13   add 2 additional layers; 1e-6  --failure!!

# combo full - word/char, sz:512, bs:8, 1cycle(2e-5)
# 134.963303	110.691048	0.047529	0.128654	3:29:36   preload combo>6_bert_tok2   'combo_full_bert_tok'
# chars:    96.1589   .07463
# words:    41.3268   .08943

# combo_test - word/char, sz:512, bs:10, 1cycle(1e-4)
# 52.687794	39.834663	0.137030	0.224802	1:05:14   preload combo_full_bert_tok    'test_combo_bert_tok'

In [None]:
# Save the current learner state under the name 'sm_char_itos'.
learn.save('sm_char_itos')

## Save Outputs for LM

In [None]:
class SaveOutput(LearnerCallback):
    "Record (target, prediction) text pairs for every batch and write them to a CSV at the end of each epoch."
    def __init__(self, learn:Learner, filepath):
        super().__init__(learn)
        self.filepath = filepath
        # Function used to turn a tensor of ids back into a displayable item.
        self.recon_fn = self.learn.data.y.reconstruct

    def on_epoch_begin(self, **kwargs):
        "Start the epoch with an empty list of records."
        self.res = []

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Take the argmax over the last dim of `last_output` and store one [target, prediction] row per sample."
        decoded = torch.argmax(last_output, dim=-1)
        for pred_ids, tgt_ids in zip(decoded, last_target):
            pred_txt = str(self.recon_fn(pred_ids))
            tgt_txt = str(self.recon_fn(tgt_ids))
            self.res.append([tgt_txt, pred_txt])

    def on_epoch_end(self, **kwargs):
        "Write all rows gathered this epoch to `self.filepath` (columns 'targets' and 'preds', no index)."
        pd.DataFrame(self.res, columns=['targets', 'preds']).to_csv(self.filepath, index=False)

In [None]:
# Run a validation pass over the *training* dataloader with teacher forcing; SaveOutput
# writes the target/prediction strings to PATH/'output.csv', and CER is reported as the metric.
learn.validate(dl=data.train_dl,
               callbacks=[TeacherForce(learn), SaveOutput(learn, filepath=PATH/'output.csv')],
               metrics=[CER(data.y.reconstruct)])

In [None]:
# Read the saved target/prediction pairs back in and display them.
df = pd.read_csv(PATH/'output.csv'); df

# Char/Word Greedy results

In [None]:
def greedy_decode(src, model, seq_len, kind='char', bos_tok=1):
    """Greedily decode up to `seq_len` tokens for a batch of images.

    `model` must expose `img_enc` and `transformer`; the transformer must expose
    `encode`, `generator` and the `c_decoder`/`c_embed` (kind='char') or
    `w_decoder`/`w_embed` (any other `kind`) pairs.

    Every sequence starts from `bos_tok`. At each step the highest-scoring token is
    appended to the running prefix; decoding stops early once *every* sequence in the
    batch predicts token 0 (the scores of that last step are still returned).

    Returns the stacked per-step generator outputs with the batch axis first
    (presumably [bs, steps, vocab] -- depends on what `generator` returns).
    """
    model.eval()
    tfmr = model.transformer
    img_enc = model.img_enc

    decoder = tfmr.c_decoder if kind=='char' else tfmr.w_decoder
    embed = tfmr.c_embed if kind=='char' else tfmr.w_embed
    # Width handed to `parallelogram_mask`: 25 for characters, 10 for words.
    p_num = 25 if kind=='char' else 10

    with torch.no_grad():
        feats = tfmr.encode(img_enc(src))
        bs = src.size(0)
        tgt = torch.zeros((bs,1), dtype=torch.long, device=device) + bos_tok

        res = []
        for i in progress_bar(range(seq_len)):
#             mask = subsequent_mask(tgt.size(-1))
            mask = parallelogram_mask(tgt.size(-1), p_num)

            # Feed the whole prefix. The previous `embed(tgt)[-p_num:]` sliced the *batch*
            # axis, silently dropping samples whenever bs > p_num while `feats` and `mask`
            # kept their full size. The mask is built for the full prefix length and
            # already bounds the attention window, so no sequence slicing is needed.
            dec_outs = decoder(embed(tgt), feats, mask)
            # NOTE(review): `tfmr.generator` is used for both kinds -- confirm a separate
            # word/char generator is not expected here.
            prob = tfmr.generator(dec_outs[:,-1])
            res.append(prob)
            pred = torch.argmax(prob, dim=-1, keepdim=True)
            if (pred==0).all(): break
            tgt = torch.cat([tgt,pred], dim=-1)
        # Stack over steps, then move the batch axis to the front.
        out = torch.stack(res).transpose(1,0).contiguous()

        return out

In [None]:
# Iterator over the validation dataloader; the next cell pulls batches from it.
vdl = iter(learn.data.valid_dl)

In [None]:
# Fetch the next validation batch (inputs `x`, targets `y`) used by the decoding cells below.
x,y = next(vdl)

### Single Word

In [None]:
# Greedy-decode the batch with the word decoder (start token id 101, at most `word_len` steps).
g_preds = greedy_decode(x, learn.model, word_len, 'word', 101)
g_res = torch.argmax(g_preds, dim=-1)
# g = [loss, CER], each divided by the notebook-global `bs`.
g = [learn.loss_func(g_preds, y).item()/bs, cer(g_preds, y, data.y.reconstruct)[0]/bs]
# Print truncated values; `[1:7]` drops the first character of the CER string.
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
# Show the first four images of the batch in a 2x2 grid, each titled with its greedy prediction.
fig, axes = plt.subplots(2,2, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    p = data.y.reconstruct(g_res[i])
    ax=show_img(x[i], ax=ax, title=p)

### Single Char

In [None]:
# Greedy-decode the batch with the char decoder (default start token, at most `seq_len` steps).
g_preds = greedy_decode(x, learn.model, seq_len, 'char')
g_res = torch.argmax(g_preds, dim=-1)
# g = [loss, CER], each divided by the notebook-global `bs`.
g = [learn.loss_func(g_preds, y).item()/bs, cer(g_preds, y, data.y.reconstruct)[0]/bs]
# Print truncated values; `[1:7]` drops the first character of the CER string.
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
# Show the first four images of the batch in a 2x2 grid, each titled with its greedy prediction.
fig, axes = plt.subplots(2,2, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    p = data.y.reconstruct(g_res[i])
    ax=show_img(x[i], ax=ax, title=p)

### Chars

In [None]:
# Char decoding for the combined model: targets come as a pair and y[0] is scored here
# (presumably the char targets -- confirm against the dataset).
g_preds = greedy_decode(x, learn.model, seq_len, 'char', 101)
g_res = torch.argmax(g_preds, dim=-1)
# A fresh LabelSmoothing with default settings is used rather than learn.loss_func.
loss_func = LabelSmoothing()
g = [loss_func(g_preds, y[0]).item()/bs, cer(g_preds, y[0], data.y.reconstruct_one)[0]/bs]
# Print truncated values; `[1:7]` drops the first character of the CER string.
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
# Show the first four images of the batch in a 2x2 grid, titled via `reconstruct_one` on the greedy ids.
fig, axes = plt.subplots(2,2, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    p = data.y.reconstruct_one(g_res[i])
    ax=show_img(x[i], ax=ax, title=p)

### Words

In [None]:
# Word decoding for the combined model: y[1] is scored here
# (presumably the word targets -- confirm against the dataset).
g_preds = greedy_decode(x, learn.model, word_len, 'word', 101)
g_res = torch.argmax(g_preds, dim=-1)
# A fresh LabelSmoothing with default settings is used rather than learn.loss_func.
loss_func = LabelSmoothing()
g = [loss_func(g_preds, y[1]).item()/bs, cer(g_preds, y[1], data.y.reconstruct_one)[0]/bs]
# Print truncated values; `[1:7]` drops the first character of the CER string.
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
# Show the first four images of the batch in a 2x2 grid, titled via `reconstruct_one` on the greedy ids.
fig, axes = plt.subplots(2,2, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    p = data.y.reconstruct_one(g_res[i])
    ax=show_img(x[i], ax=ax, title=p)

# Image PreProcessing

In [None]:
class ImageProcessor(nn.Module):
    "Turn images into a feature sequence whose length depends on a per-image class predicted by `head`."
    def __init__(self, d_model, max_lines=14):
        super().__init__()

        # NOTE(review): `d_model` is accepted but not used anywhere in this class.
        net = models.resnet34(True)
        # Keep all resnet34 children except the last two.
        modules = list(net.children())[:-2]
        
        self.base = nn.Sequential(*modules)
        # Head with `max_lines` outputs; its argmax is used in `forward` as a row count.
        self.head = create_head(1024, max_lines)
        
    def forward(self, x):
        "Return features permuted to (batch, length, channels); everything runs under `no_grad`."
        with torch.no_grad():
            base = self.base(x)
            n_lines = self.head(base)
            
            b,c,h,w = base.shape
            #res = torch.zeros((b,c,longest))
        
            # Predicted class per image, and the distinct classes present in this batch.
            preds = torch.argmax(n_lines, dim=-1)
            lines = torch.unique(preds)
            
            # Output length is sized for the last entry of `lines`
            # (presumably the largest -- relies on torch.unique returning sorted values).
            longest = lines[-1]*w
            res = torch.zeros((b,c,longest), device=device)
            for l in lines:
                # Indices of the images in the batch that were assigned class `l`.
                idxs = (preds==l).nonzero()[:,0]
                # Pool those feature maps to `l` rows (width unchanged), then flatten rows*width.
                # NOTE(review): `l` is used directly as the row count -- confirm class 0 cannot occur.
                outs = F.adaptive_max_pool2d(base[idxs], (l,None)).flatten(2,3)
                # could add positional encoding here
                # Right-pad with zeros up to the common length before writing into the result.
                pad = longest - outs.size(-1)
                outs = F.pad(outs, (0,pad), "constant", 0)
                res[idxs] = outs
                
        return res.permute(0,2,1)

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', heads=8):
    "Assemble an `EncoderDecoder` with `N` encoder and `N` decoder layers, target embeddings and a linear generator."
    # Prototype modules; each layer receives its own deep copy.
    attn = MultiHeadedAttention(d_model, heads) if attn_type=='multi' else SingleHeadedAttention(d_model)
    ff = PositionwiseFeedForward(d_model, drops)
    pos = PositionalEncoding(d_model, drops, 2000)

    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    decoder = Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    tgt_embed = nn.Sequential(Embeddings(d_model, vocab), pos)
    generator = nn.Linear(d_model, vocab)
    return EncoderDecoder(encoder, decoder, tgt_embed, generator)

In [None]:
class Img2Seq(nn.Module):
    "Image encoder + transformer: teacher-forced scores when `tgt` is given, greedy decoding otherwise."
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        # training
        if tgt is not None:
            feats = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt, tgt_mask)    # ([bs, sl, d_model])
            return self.transformer.generate(dec_outs)           # ([bs, sl, vocab])
        # inference (greedy decode)
        return self._greedy(src, seq_len)

    def _greedy(self, src, seq_len):
        "Decode at most `seq_len` tokens, starting every sequence from token id 1."
        with torch.no_grad():
            feats = self.transformer.encode(self.img_enc(src))
            n_samples = src.size(0)
            prefix = torch.ones((n_samples,1), dtype=torch.long, device=device)

            step_scores = []
            for _ in progress_bar(range(seq_len)):
                # Full causal mask over the prefix (a parallelogram mask of width 10 was an earlier experiment).
                mask = subsequent_mask(prefix.size(-1))
                embedded = self.transformer.embed(prefix)
                dec_outs = self.transformer.decode(feats, embedded, mask)
                scores = self.transformer.generate(dec_outs[:,-1])
                step_scores.append(scores)
                next_tok = torch.argmax(scores, dim=-1, keepdim=True)
                # Stop once every sequence in the batch predicts token 0.
                if (next_tok==0).all(): break
                prefix = torch.cat([prefix,next_tok], dim=-1)
            # Batch axis first, steps second.
            return torch.stack(step_scores, dim=1)

In [None]:
def init_params(model):
    "Re-initialise each parameter of `model` with 2+ dimensions using Xavier-uniform; leave 1-d ones alone."
    for tensor in model.parameters():
        if tensor.dim() <= 1:
            continue
        nn.init.xavier_uniform_(tensor)

In [None]:
def make_learner(data, d_model, N=4, drops=0.2, attn_type='multi', heads=8, smoothing=0.1):
    "Build a `Learner` whose net pairs an `ImageProcessor` encoder with a freshly initialised transformer."
    encoder = ImageProcessor(d_model)
    # Vocabulary size comes from the notebook-global `itos`.
    seq_model = make_full_model(
        len(itos), d_model,
        N=N, drops=drops, attn_type=attn_type, heads=heads,
    )
    seq_model.apply(init_params)
    return Learner(
        data,
        Img2Seq(encoder, seq_model),
        loss_func=LabelSmoothing(smoothing=smoothing),
        metrics=[CER()],
        callback_fns=[TeacherForce],
    )

## Combine/load state_dict 

In [None]:
# Build the learner: d_model=512, N=6 layers, dropout disabled, 8-head attention.
learn = make_learner(data, 512, N=6, drops=0, attn_type='multi', heads=8)

In [None]:
from collections import OrderedDict
# State dict re-keyed for `ImageProcessor`: checkpoint keys starting with '0' are moved under
# 'img_enc.base', keys starting with '1' under 'img_enc.head'; all other keys are dropped.
new_sd = OrderedDict()

line_sd = torch.load(PATH/'models/line_98acc.pth', map_location=device)

for k, v in line_sd['model'].items():
    if k.startswith('0'):
        # Only the leading '0' is replaced (the pattern is anchored with ^).
        name = re.sub("^(0)",'img_enc.base',k)
        new_sd[name] = v
    if k.startswith('1'):
        # NOTE(review): a key such as '10.x' would also match this prefix test -- confirm none exist.
        name = re.sub("^(1)",'img_enc.head',k)
        new_sd[name] = v

In [None]:
# strict=False: only the img_enc entries are present in `new_sd`; the rest of the model keeps its current weights.
learn.model.load_state_dict(new_sd, strict=False)

In [None]:
# sd = torch.load(PATH/'models/combo_512_mix_res9.pth', map_location=device)

# for k, v in sd['model'].items():
#     if not k.startswith('img_enc'):
#         new_sd[k] = v

# Full Arch

In [None]:
class EncoderDecoder(nn.Module):
    "Container wiring an encoder, a decoder, a target embedding and an output generator together."
    def __init__(self, encoder, decoder, tgt_embed, generator):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder
        self.tgt_embed, self.generator = tgt_embed, generator

    def forward(self, src, tgt, tgt_mask=None):
        "Encode `src`, then decode `tgt` against the encoded result."
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask)

    def encode(self, src):
        "Run the encoder on `src`."
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        "Embed `tgt` and run the decoder over it with `src` as the encoded input."
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, src, tgt_mask)

    def generate(self, outs):
        "Apply the generator to decoder outputs."
        return self.generator(outs)

In [None]:
class ResnetBase(nn.Module):
    "Truncated resnet34 followed by a linear projection from `em_sz` to `d_model`."
    def __init__(self, em_sz, d_model):
        super().__init__()

        # `em_sz` selects where the resnet34 child list is cut; unsupported sizes raise KeyError.
        cut = {128: -4, 256: -3, 512: -2}[em_sz]
        backbone = models.resnet34(True)
        self.base = nn.Sequential(*list(backbone.children())[:cut])
        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        "Flatten the two trailing dims of the feature map into a sequence, project it, and scale by 8."
        feats = self.base(x)
        seq = feats.flatten(2,3).permute(0,2,1)
        projected = self.linear(seq)
        return projected * 8

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', heads=8):
    "Assemble an `EncoderDecoder`: `N`-layer encoder and decoder, embedding + positional encoding, linear generator."
    # Prototype attention / feed-forward modules; every layer gets its own deep copy.
    if attn_type=='multi': attn = MultiHeadedAttention(d_model, heads)
    else:                  attn = SingleHeadedAttention(d_model)
    ff = PositionwiseFeedForward(d_model, drops)

    enc_layer = EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops)
    encoder = Encoder(enc_layer, N)
    dec_layer = DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops)
    decoder = Decoder(dec_layer, N)
    tgt_embed = nn.Sequential(Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000))
    generator = nn.Linear(d_model, vocab)

    return EncoderDecoder(encoder, decoder, tgt_embed, generator)

In [None]:
class Img2Seq(nn.Module):
    "Image encoder + transformer: teacher-forced scores when `tgt` is given, greedy decoding otherwise."
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc, self.transformer = img_encoder, transformer

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        if tgt is None:
            # inference (greedy decode)
            return self._decode_greedily(src, seq_len)
        # training
        hidden = self.transformer(self.img_enc(src), tgt, tgt_mask)    # ([bs, sl, d_model])
        return self.transformer.generate(hidden)                       # ([bs, sl, vocab])

    def _decode_greedily(self, src, seq_len):
        "Grow a token prefix (starting from id 1) one argmax step at a time, for at most `seq_len` steps."
        with torch.no_grad():
            feats = self.transformer.encode(self.img_enc(src))
            n_samples = src.size(0)
            prefix = torch.ones((n_samples,1), dtype=torch.long, device=device)

            collected = []
            for _ in progress_bar(range(seq_len)):
                causal = subsequent_mask(prefix.size(-1))
                last_hidden = self.transformer.decode(feats, prefix, causal)[:,-1]
                scores = self.transformer.generate(last_hidden)
                collected.append(scores)
                choice = torch.argmax(scores, dim=-1, keepdim=True)
                # Stop once every sequence in the batch predicts token 0.
                if (choice==0).all(): break
                prefix = torch.cat([prefix,choice], dim=-1)
            # Batch axis first, steps second.
            return torch.stack(collected, dim=1)

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.2, attn_type='multi', heads=8, smoothing=0.1):
    "Build a `Learner` whose net pairs a `ResnetBase` encoder with a freshly initialised transformer."
    encoder = ResnetBase(em_sz, d_model)
    # Vocabulary size comes from the notebook-global `itos`.
    seq_model = make_full_model(
        len(itos), d_model,
        N=N, drops=drops, attn_type=attn_type, heads=heads,
    )
    seq_model.apply(init_params)
    # Weight tying (generator.weight = tgt_embed[0].lut.weight) was tried here and is left disabled.
    return Learner(
        data,
        Img2Seq(encoder, seq_model),
        loss_func=LabelSmoothing(smoothing=smoothing),
        metrics=[CER()],
        callback_fns=[TeacherForce],
    )

# TransformerXL

In [None]:
class PositionalEncoding(Module):
    "Sinusoidal encoding of a vector of positions."
    def __init__(self, d_model:int):
        # One frequency per even index in [0, d_model), stored as a non-trainable buffer.
        exponents = torch.arange(0., d_model, 2.) / d_model
        self.register_buffer('freq', 1 / (10000 ** exponents))

    def forward(self, pos:Tensor):
        "Return the sines followed by the cosines of `pos` x `freq`, concatenated on the last dim."
        angles = torch.ger(pos, self.freq)  # outer product of the two vectors
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

In [None]:
class FeedForward(Module):
    "Two-layer feed-forward block (4x expansion) with a residual connection and layer norm."
    def __init__(self, d_model:int, drops=0.1):
        d_inner = d_model*4
        layers = [
            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True), nn.Dropout(drops),
            nn.Linear(d_inner, d_model), nn.Dropout(drops),
        ]
        self.core = nn.Sequential(*layers)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x):
        "Add the block's output to its input, then normalise."
        return self.ln(x + self.core(x))

In [None]:
class MultiHeadRelativeAttention(Module):
    "Multi-head self-attention with relative positional terms (Transformer-XL style), residual and layer norm."
    def __init__(self, n_heads:int, d_model:int, drops=0.1, bias=False):
        d_head = d_model//n_heads
        self.n_heads, self.d_head = n_heads, d_head
        # Single projection producing queries, keys and values in one matmul.
        self.attention = nn.Linear(d_model, 3 * n_heads * d_head, bias=bias)
        self.out = nn.Linear(n_heads * d_head, d_model, bias=bias)
        self.drop_att,self.drop_res = nn.Dropout(drops),nn.Dropout(drops)
        self.ln = nn.LayerNorm(d_model)
        # Projection applied to the relative position encodings `r`.
        self.r_attn = nn.Linear(d_model, n_heads * d_head, bias=bias)

    def forward(self, x:Tensor, mask:Tensor=None, **kwargs):
        "Attend over `x` (extra kwargs `r`, `u`, `v`, `mem` go to `_mhra`), then add the residual and normalise."
        return self.ln(x + self.drop_res(self.out(self._mhra(x, mask=mask, **kwargs))))

    def _mhra(self, x:Tensor, r:Tensor=None, u:Tensor=None, v:Tensor=None, mask:Tensor=None, mem:Tensor=None):
        "Compute the attended values for `x`; `r` is required in practice because its length is read below."
        #Notations from the paper:
        #x: input, r: vector of relative distance between two elements
        #u,v: learnable parameters of the model common between layers
        #mask: to avoid cheating, mem: previous hidden states
                
        #x: [bs, sl, d_model]
        #u/v: [n_heads, 1, d_head]
        #r: [sl, d_model]
        #mem: 1st:[0]; 2nd:[bs, sl, d_model]; nth: sl*i-1 up to mem_len
        bs,x_len,seq_len = x.size(0),x.size(1),r.size(0)
        # Keys/values see the memory followed by the current input; queries only the current input.
        context = x if mem is None else torch.cat([mem, x], dim=1)
        wq,wk,wv = torch.chunk(self.attention(context), 3, dim=-1)
        wq = wq[:,-x_len:]
        wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))
        # [bs, sl, n_heads, d_head]
        wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3)   #wk: transposed(-2,-1)
        wkr = self.r_attn(r)
        wkr = wkr.view(seq_len, self.n_heads, self.d_head)
        wkr = wkr.permute(1,2,0)  #transposed ala wk w/out bs
        #### compute attention score (AC is (a) + (c) and BD is (b) + (d) in the paper)
        AC = torch.matmul(wq+u,wk)
        # The shift helper defined in this notebook is `_left_shift`; the previous call to
        # `_line_shift` referenced a name that is not defined alongside this class.
        BD = _left_shift(torch.matmul(wq+v, wkr))
        attn_score = (AC + BD).div_(self.d_head ** 0.5)  #scale
        if mask is not None:
            # Masked positions get -inf so they receive zero weight after the softmax.
            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)
        attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))
        attn_vec = torch.matmul(attn_prob, wv)
        # Merge the heads back: [bs, x_len, n_heads * d_head]
        return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1)
    
    
def _left_shift(x:Tensor):
    "Shift the line i of `x` by p-i elements to the left."
    bs,nh,n,p = x.size()
    x_pad = torch.cat([x.new_zeros(bs,nh,n,1), x], dim=3)
    x_shift = x_pad.view(bs,nh,p + 1,n)[:,:,1:].view_as(x)
    return x_shift

In [None]:
class DecoderLayerXL(Module):
    "One Transformer-XL block: relative multi-head attention followed by a feed-forward block."
    def __init__(self, n_heads:int, d_model:int, drops=0.1):
        self.mhra = MultiHeadRelativeAttention(n_heads, d_model, drops=drops)
        self.ff   = FeedForward(d_model, drops=drops)

    def forward(self, x:Tensor, mask:Tensor=None, **kwargs):
        "Extra keyword arguments are forwarded to the attention module unchanged."
        attended = self.mhra(x, mask=mask, **kwargs)
        return self.ff(attended)

In [None]:
class TransformerXL(Module):
    "TransformerXL model: https://arxiv.org/abs/1901.02860."
    def __init__(self, vocab_sz:int, d_model:int, n_layers:int, n_heads:int, drops:float=0.1, mem_len:int=150):
        d_head = d_model//n_heads
        # Stored on the instance: `reset` sizes the memory from `n_layers` and `_update_mems`
        # truncates it to `mem_len`. Previously neither was saved, so both methods failed
        # with an AttributeError on first use.
        self.d_model,self.n_layers,self.mem_len = d_model,n_layers,mem_len
        self.emb = nn.Embedding(vocab_sz, d_model)
        self.drop_emb = nn.Dropout(drops)
        self.pos_enc = PositionalEncoding(d_model)
        # u/v: [n_heads, 1, d_head], shared by every layer. Allocated uninitialised here;
        # the caller is expected to initialise them.
        self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head))
        self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head))
        self.layers = nn.ModuleList([DecoderLayerXL(n_heads, d_model, drops) for k in range(n_layers)])
        self.init = False

    def forward(self, x):
        "Run token ids `x` ([bs, x_len]) through all layers; returns `(self.hidden, [core_out])`."
        if not self.init:
            # Lazily create the (empty) memory on the first call.
            self.reset()
            self.init = True
        bs,x_len = x.size()
        inp = self.drop_emb(self.emb(x)) #.mul_(self.d_model ** 0.5)
        # Length of the stored memory; 0 while the memory still holds the empty placeholders.
        m_len = self.hidden[0].size(1) if hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1 else 0
        seq_len = m_len + x_len
        # True above the (memory-shifted) diagonal: positions a query is not allowed to see.
        mask = torch.triu(x.new_ones(x_len, seq_len), diagonal=1+m_len).bool()[None,None]
        # Relative positions, counting down from seq_len-1 to 0.
        pos = torch.arange(seq_len-1, -1, -1, device=inp.device, dtype=inp.dtype)
        pos_enc = self.pos_enc(pos)
        hids = []
        hids.append(inp)
        for i, layer in enumerate(self.layers):
            mem = self.hidden[i]
            inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem)
            hids.append(inp)
        core_out = inp[:,-x_len:]
        self._update_mems(hids)
        return self.hidden,[core_out]
    
    def _update_mems(self, hids):
        "Append this step's hidden states to the memory, keeping only the last `mem_len` positions."
        if not getattr(self, 'hidden', False): return None
        assert len(hids) == len(self.hidden), 'len(hids) != len(self.hidden)'
        with torch.no_grad():
            for i in range(len(hids)):
                cat = torch.cat([self.hidden[i], hids[i]], dim=1)
                self.hidden[i] = cat[:,-self.mem_len:].detach()

    def reset(self):
        "Reset the internal memory."
        # One empty tensor per layer plus one for the embedding output.
        self.hidden = [next(self.parameters()).data.new(0) for i in range(self.n_layers+1)]

    def select_hidden(self, idxs):
        "Keep only the memory rows at `idxs` along the first dim."
        # used in beam search only...
        self.hidden = [h[idxs] for h in self.hidden]

In [None]:
class EncoderDecoder(nn.Module):
    "Container wiring an encoder, a decoder and a generator; the decoder receives raw targets (no embedding here)."
    def __init__(self, encoder, decoder, generator):
        super().__init__()
        self.encoder, self.decoder, self.generator = encoder, decoder, generator

    def forward(self, src, tgt, tgt_mask=None):
        "Encode `src`, then decode `tgt` against it; `tgt_mask` is accepted but not passed on."
        memory = self.encode(src)
        return self.decode(memory, tgt)

    def encode(self, src):
        "Run the encoder on `src`."
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        "Call the decoder with `(tgt, src)`; `tgt_mask` is ignored."
        return self.decoder(tgt, src)

    def generate(self, outs):
        "Apply the generator to decoder outputs."
        return self.generator(outs)

In [None]:
class ResnetBase(nn.Module):
    "Truncated resnet34 followed by a linear projection from `em_sz` to `d_model`."
    def __init__(self, em_sz, d_model):
        super().__init__()

        # Negative slice end over resnet34's children, keyed by `em_sz` (KeyError for other sizes).
        cut_points = {128: -4, 256: -3, 512: -2}
        cut = cut_points[em_sz]

        children = list(models.resnet34(True).children())
        self.base = nn.Sequential(*children[:cut])

        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        "Feature map -> sequence (spatial positions become steps) -> linear projection, scaled by 8."
        seq = self.base(x).flatten(2,3).permute(0,2,1)
        return self.linear(seq) * 8

In [None]:
class Decoder(nn.Module):
    "Language-model front end followed by N decoder layers and a final norm."
    def __init__(self, lm, layer, N):
        super().__init__()
        self.lm_layer = lm
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, src):
        "Run `x` through the LM layer, then through every decoder layer against `src`, then normalise."
        # NOTE(review): the model builder in this section passes a TransformerXL as `lm`, whose
        # forward returns `(hidden, [core_out])` -- confirm the layers below accept that value.
        hidden = self.lm_layer(x)

        for block in self.layers:
            hidden = block(hidden, src)
        return self.norm(hidden)

In [None]:
class DecoderLayer(nn.Module):
    "Decoder block for this variant: src-attn and feed forward (no self-attention sublayer)."
    def __init__(self, size, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # Two wrappers, one per sublayer (original note: residual, dropout, norm).
        self.sublayer = clones(SublayerConnection(size, dropout), 2)

    def forward(self, x, src):
        "Attend from `x` over `src` (used as both keys and values), then apply the feed-forward net."
        def attend(query):
            return self.src_attn(query, src, src)

        attended = self.sublayer[0](x, attend)
        return self.sublayer[1](attended, self.feed_forward)

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', attn_heads=8):
    "Assemble an `EncoderDecoder` whose decoder is fronted by a `TransformerXL` language model."
    # The LM's depth, head count and dropout are fixed here; `attn_type` is accepted but unused.
    lm = TransformerXL(vocab, d_model, n_layers=10, n_heads=8, drops=0.1)
    attn = MultiHeadedAttention(d_model, attn_heads)
    ff   = PositionwiseFeedForward(d_model, drops)

    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    decoder = Decoder(lm, DecoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    generator = nn.Linear(d_model, vocab)
    model = EncoderDecoder(encoder, decoder, generator)

    # Xavier-uniform initialise every parameter with more than one dimension.
    for weight in (p for p in model.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)

    return model

In [None]:
class Img2Seq(nn.Module):
    "Image encoder + transformer (no target masks in this variant): teacher-forced scores or greedy decoding."
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        # training (`tgt_mask` is accepted but not used)
        if tgt is not None:
            feats = self.img_enc(src)
            return self.transformer.generate(self.transformer(feats, tgt))

        # inference (greedy decode)
        with torch.no_grad():
            feats = self.transformer.encode(self.img_enc(src))
            n_samples = src.size(0)
            # Every sequence starts from token id 1.
            prefix = torch.ones((n_samples,1), dtype=torch.long, device=device)

            per_step = []
            for _ in progress_bar(range(seq_len)):
                hidden = self.transformer.decode(feats, prefix)
                scores = self.transformer.generate(hidden[:,-1])
                per_step.append(scores)
                best = torch.argmax(scores, dim=-1, keepdim=True)
                # Stop once every sequence in the batch predicts token 0.
                if (best==0).all(): break
                prefix = torch.cat([prefix,best], dim=-1)
            # Batch axis first, steps second.
            return torch.stack(per_step, dim=1)

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.2, attn_type='multi', attn_heads=8):
    "Build a `Learner` whose net pairs a `ResnetBase` encoder with the TransformerXL-fronted model."
    encoder = ResnetBase(em_sz, d_model)
    # Vocabulary size comes from the notebook-global `itos`; initialisation happens inside `make_full_model`.
    seq_model = make_full_model(
        len(itos), d_model,
        N=N, drops=drops, attn_type=attn_type, attn_heads=attn_heads,
    )
    return Learner(
        data,
        Img2Seq(encoder, seq_model),
        loss_func=LabelSmoothing(smoothing=0.1),
        metrics=[CER()],
        callback_fns=[TeacherForce],
    )

# w/ Transformer LM

In [None]:
class DecoderLayerWithLM(nn.Module):
    "Decoder layer with a language-model branch: shared self-attn, then an LM path and a src-attn path that are summed."
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        # Zero-argument super(): the previous `super(DecoderLayer, self)` named a class this
        # one does not inherit from, which raises TypeError as soon as the layer is built.
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        # One feed-forward copy each for the LM branch, the decoder branch and the final mix.
        self.feed_forward = clones(feed_forward, 3)
        self.sublayer = clones(SublayerConnection(size, dropout), 5)  # wraps layer in residual,dropout,norm
 
    def forward(self, x, src, tgt_mask=None):
        "Return the feed-forward mix of the LM branch and the src-attending decoder branch."
        # lm
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))  # shared btw lm and decoder
        lm = self.sublayer[1](x, self.feed_forward[0])
        
        # decoder layer
        dec = self.sublayer[2](x, lambda x: self.src_attn(x, src, src))
        dec = self.sublayer[3](dec, self.feed_forward[1])
        # Both branches are summed before the last feed-forward sublayer.
        return self.sublayer[4](dec+lm, self.feed_forward[2])

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', attn_heads=8, weight_tying=False):
    "Assemble an `EncoderDecoder` whose decoder layers carry an LM branch; optionally tie generator and embedding weights."
    # Prototype attention / feed-forward modules; every layer gets its own deep copy.
    attn = MultiHeadedAttention(d_model, attn_heads) if attn_type=='multi' else SingleHeadedAttention(d_model)
    ff = PositionwiseFeedForward(d_model, drops)

    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    decoder = Decoder(DecoderLayerWithLM(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    tgt_embed = nn.Sequential(Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000))
    generator = nn.Linear(d_model, vocab)
    model = EncoderDecoder(encoder, decoder, tgt_embed, generator)

    # Xavier-uniform initialise every parameter with more than one dimension.
    for weight in (p for p in model.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)

    # Tying happens after initialisation, so both modules end up sharing the embedding's tensor.
    if weight_tying:
        model.generator.weight = model.tgt_embed[0].lut.weight

    return model

# w/ AWDLSTM LM integrated

In [None]:
class EncoderDecoder(nn.Module):
    "Container wiring an encoder, a decoder and a positional encoding applied to the targets."
    def __init__(self, encoder, decoder, pos_enc):
        super().__init__()
        self.encoder, self.decoder, self.pos_enc = encoder, decoder, pos_enc

    def forward(self, src, tgt, tgt_mask=None):
        "Encode `src`, then decode `tgt` against the encoded result."
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask)

    def encode(self, src):
        "Run the encoder on `src`."
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        "Apply `pos_enc` to `tgt` and run the decoder over it with `src` as the encoded input."
        positioned = self.pos_enc(tgt)
        return self.decoder(positioned, src, tgt_mask)

In [None]:
class ResnetBase(nn.Module):
    "Truncated resnet34 followed by a linear projection from `em_sz` to `d_model`."
    def __init__(self, em_sz, d_model):
        super().__init__()

        # `em_sz` picks the slice end over resnet34's children; other sizes raise KeyError.
        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        resnet = models.resnet34(True)
        kept = list(resnet.children())[:cut]
        self.base = nn.Sequential(*kept)

        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        "Flatten the feature map into a sequence, project it and scale by the constant 8."
        feats = self.base(x)
        seq = feats.flatten(2,3).permute(0,2,1)
        # Original author's open question about the factor: "(cube root of d_model?)"
        return self.linear(seq) * 8

In [None]:
def lm_mod_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', attn_heads=8):
    "Assemble an `EncoderDecoder` without embedding or generator: encoder, decoder and positional encoding only."
    # `vocab` is accepted but not used by this builder.
    if attn_type=='multi': attn = MultiHeadedAttention(d_model, attn_heads)
    else:                  attn = SingleHeadedAttention(d_model)
    ff = PositionwiseFeedForward(d_model, drops)

    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    decoder = Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    pos_enc = PositionalEncoding(d_model, drops, 2000)
    model = EncoderDecoder(encoder, decoder, pos_enc)

    # Xavier-uniform initialise every parameter with more than one dimension.
    for weight in (p for p in model.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)

    return model

In [None]:
class LM(nn.Module):
    """Thin wrapper around a fastai `AWD_LSTM` that is fed pre-embedded targets.

    The wrapped network is built with `pad_token=0` and fixed dropout rates.
    """

    def __init__(self, vocab, d_model, n_hidden, n_layers):
        super(LM, self).__init__()
        self.lm = AWD_LSTM(vocab, d_model, n_hidden, n_layers, pad_token=0,
                           hidden_p=0.1, input_p=0.25, embed_p=0.05, weight_p=0.2)

    def forward(self, tgts):
        "Return the wrapped LM's output for `tgts`, passed with `from_embeddings=True`."
        # NOTE(review): fastai's AWD_LSTM.forward appears to return a
        # (raw_outputs, outputs) pair rather than a single tensor - confirm
        # that callers index the result accordingly.
        return self.lm(tgts, from_embeddings=True)

In [None]:
class Mixer(nn.Module):
    """Fuse decoder and language-model features and project them to vocab logits.

    The two inputs are averaged element-wise, passed through a position-wise
    feed-forward block followed by `LayerNorm`, and then through a linear
    `generator` of width `vocab`.
    """

    def __init__(self, d_model, vocab, drops=0.2):
        super(Mixer, self).__init__()
        self.mixer = nn.Sequential(PositionwiseFeedForward(d_model, drops), LayerNorm(d_model))
        self.generator = nn.Linear(d_model, vocab)

    def forward(self, dec_outs, lm_outs):
        "Average `dec_outs` and `lm_outs`, mix them, and return the generator output."
        # Both inputs must be addable, i.e. have broadcast-compatible shapes.
        mix = self.mixer((lm_outs + dec_outs) / 2)
        return self.generator(mix)

In [None]:
class Embedding(nn.Module):
    """Token embedding with embedding dropout.

    Holds a plain `nn.Embedding` (`emb`, padding index `pad_tok`) and an
    `EmbeddingDropout` wrapper around that same layer (`emb_drop`).
    """

    def __init__(self, d_model, vocab, drops, pad_tok=0):
        super().__init__()
        lookup = nn.Embedding(vocab, d_model, padding_idx=pad_tok)
        self.emb = lookup
        self.emb_drop = EmbeddingDropout(lookup, drops)

    def forward(self, x):
        "Embed `x` through the dropout wrapper (no extra scaling is applied)."
        embedded = self.emb_drop(x)
        return embedded

In [None]:
class Img2Seq(nn.Module):
    """Image-to-sequence model that mixes transformer-decoder and LM features.

    With `tgt` given, `forward` runs a single teacher-forced pass. With
    `tgt=None` it greedily decodes up to `seq_len` steps under `torch.no_grad()`.
    """

    def __init__(self, img_encoder, transformer, embedding, lm, mixer):
        super().__init__()
        self.img_enc = img_encoder
        self.embedding = embedding
        self.transformer = transformer
        self.lm = lm
        self.mixer = mixer

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        if tgt is not None:
            return self._teacher_forced(src, tgt, tgt_mask)
        with torch.no_grad():
            return self._greedy_decode(src, seq_len)

    def _teacher_forced(self, src, tgt, tgt_mask):
        "Training path: one pass over the whole target sequence."
        img_feats = self.img_enc(src)
        embedded = self.embedding(tgt)
        dec_outs = self.transformer(img_feats, embedded, tgt_mask)    # ([bs, sl, d_model])
        lm_outs = self.lm(embedded)
        return self.mixer(dec_outs, lm_outs)

    def _greedy_decode(self, src, seq_len):
        "Inference path: extend the sequence one argmax token at a time."
        memory = self.transformer.encode(self.img_enc(src))
        # Every sequence starts from token id 1.
        tokens = torch.ones((src.size(0), 1), dtype=torch.long, device=device)

        step_outs = []
        for _ in progress_bar(range(seq_len)):
            causal = subsequent_mask(tokens.size(-1))
            embedded = self.embedding(tokens)
            dec_outs = self.transformer.decode(memory, embedded, causal)
            lm_outs = self.lm(embedded)
            # Only the newest position is scored at each step.
            scores = self.mixer(dec_outs[:, -1], lm_outs[:, -1])
            step_outs.append(scores)
            nxt = torch.argmax(scores, dim=-1, keepdim=True)
            # Stop once every sequence in the batch predicts token 0.
            if (nxt == 0).all():
                break
            tokens = torch.cat([tokens, nxt], dim=-1)
        # Stack per-step outputs, then swap so the batch dimension comes first.
        return torch.stack(step_outs).transpose(1, 0).contiguous()

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.2, attn_type='multi', attn_heads=8):
    """Assemble the LM-integrated `Img2Seq` network and wrap it in a `Learner`.

    The vocabulary size comes from `data.vocab.itos`. The language model is
    always built with 1400 hidden units and 3 layers. The learner uses label
    smoothing (0.1), the `CER` metric and the `TeacherForce` callback.
    """
    n_tokens = len(data.vocab.itos)
    # Keyword arguments are evaluated top to bottom, so sub-modules are
    # constructed in the order: encoder, transformer, embedding, lm, mixer.
    net = Img2Seq(
        img_encoder=ResnetBase(em_sz, d_model),
        transformer=lm_mod_model(n_tokens, d_model, N=N, drops=drops,
                                 attn_type=attn_type, attn_heads=attn_heads),
        embedding=Embedding(d_model, n_tokens, drops, pad_tok=0),
        lm=LM(n_tokens, d_model, 1400, 3),
        mixer=Mixer(d_model, n_tokens, drops),
    )
    criterion = LabelSmoothing(smoothing=0.1)
    return Learner(data, net, loss_func=criterion, metrics=[CER()], callback_fns=[TeacherForce])

In [None]:
learn = make_learner(data, 512, 512, N=6, drops=0.1, attn_type='multi')
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

## Add LM to model state_dict

In [None]:
sd = torch.load(PATH/'models/combo_512_9.pth', map_location=device)

In [None]:
# update existing model according to LM modifications
sd['model']['embedding.emb.weight'] = sd['model']['transformer.tgt_embed.0.lut.weight']
sd['model']['embedding.emb_drop.emb.weight'] = sd['model']['transformer.tgt_embed.0.lut.weight']
sd['model']["transformer.pos_enc.pe"] = sd['model']['transformer.tgt_embed.1.pe']
sd['model']['mixer.generator.weight'] = sd['model']['transformer.generator.weight']
sd['model']['mixer.generator.bias'] = sd['model']['transformer.generator.bias']

del sd['model']['transformer.tgt_embed.1.pe']
del sd['model']['transformer.tgt_embed.0.lut.weight']
del sd['model']['transformer.generator.weight']
del sd['model']['transformer.generator.bias']

In [None]:
lm_sd = torch.load('data/wikitext/models/wiki103_lm_enc.pth', map_location=device)

In [None]:
from collections import OrderedDict
# Re-key the LM encoder weights with the 'lm.lm.' prefix: the model stores its
# LM wrapper as `.lm`, and that wrapper stores the AWD_LSTM as `.lm` again.
new_lm_sd = OrderedDict()
for k, v in lm_sd.items():
    name = 'lm.lm.'+k
    new_lm_sd[name] = v

In [None]:
# Merge the re-keyed LM weights into the checkpoint's model state dict.
sd['model'].update(new_lm_sd)

# strict=False: keys missing from, or unexpected in, the merged dict are
# reported in the return value instead of raising.
learn.model.load_state_dict(sd['model'], strict=False)

In [None]:
# tie weights of transformer embedding and generator w/ lm encodings
# After these assignments the embedding, its dropout wrapper and the output
# generator all reference parameters owned by the language model's encoder.
# NOTE(review): assumes lm.encoder.weight has the same shape as
# mixer.generator.weight (vocab x d_model) - confirm.
learn.model.embedding.emb.weight = learn.model.lm.lm.encoder.weight
learn.model.embedding.emb_drop.emb.weight = learn.model.lm.lm.encoder_dp.emb.weight
learn.model.mixer.generator.weight = learn.model.lm.lm.encoder.weight

In [None]:
learn.save('combo_512_9_wiki103_base_lm')

## load and split learner

In [None]:
learn.load('combo_512_9_wiki103_base_lm')
None

In [None]:
# One layer group per top-level sub-module, in this order:
# 0 img_enc, 1 embedding, 2 transformer, 3 lm, 4 mixer.
learn.split([learn.model.img_enc, learn.model.embedding, learn.model.transformer, learn.model.lm, learn.model.mixer])
# Bare `None` as the last expression; presumably keeps the notebook from echoing the call's return value.
None

In [None]:
learn.layer_groups[4]

In [None]:
# Freeze every layer group before the last three; with the five-way split
# above that presumably leaves transformer, lm and mixer trainable - confirm.
learn.freeze_to(-3)
# Display whether the image encoder's projection is still trainable.
learn.model.img_enc.linear.weight.requires_grad

In [None]:
# NOTE(review): slice(a, b, c) sets start=2e-5, stop=2e-4 and step=2e-3.
# fastai learning-rate slices are normally given as (start, stop); the third
# value looks like it was meant as another per-group rate and is probably
# ignored - confirm the intent before using `lrs`.
lrs = slice(2e-5, 2e-4, 2e-3)

# w/ AWDLSTM added

In [None]:
class EncoderDecoder(nn.Module):
    """Encoder/decoder pair with its own target embedding and output generator.

    `tgt_embed` is applied to `tgt` before decoding; `generator` is exposed via
    `generate` and is not called by `forward`.
    """

    def __init__(self, encoder, decoder, tgt_embed, generator):
        super().__init__()
        # Left-to-right assignment keeps the sub-module registration order.
        self.encoder, self.decoder = encoder, decoder
        self.tgt_embed, self.generator = tgt_embed, generator

    def encode(self, src):
        "Run `src` through the encoder."
        memory = self.encoder(src)
        return memory

    def decode(self, src, tgt, tgt_mask=None):
        "Embed `tgt` and decode it against the encoder output `src`."
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, src, tgt_mask)

    def generate(self, outs):
        "Project decoder outputs through the generator."
        projected = self.generator(outs)
        return projected

    def forward(self, src, tgt, tgt_mask=None):
        "Encode `src`, then decode `tgt` against the result (no generator applied)."
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask)

In [None]:
class ResnetBase(nn.Module):
    """Truncated ResNet-34 plus a linear projection to `d_model`.

    `em_sz` must be 128, 256 or 512 (anything else raises `KeyError`); it picks
    the truncation point and is the input width of the projection.
    """

    def __init__(self, em_sz, d_model):
        super().__init__()
        # em_sz -> negative slice end over resnet34's children.
        n_dropped = {128: -4, 256: -3, 512: -2}[em_sz]
        backbone = models.resnet34(True)
        kept = list(backbone.children())[:n_dropped]
        self.base = nn.Sequential(*kept)
        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        "Flatten the feature map's spatial dims into a sequence and project it."
        fmap = self.base(x)
        seq = fmap.flatten(2, 3).permute(0, 2, 1)   # spatial dims merged, channels last
        projected = self.linear(seq)
        return projected * 8

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', attn_heads=8, weight_tying=False):
    """Build a full transformer `EncoderDecoder` with embedding and generator.

    `attn_type='multi'` selects `MultiHeadedAttention` with `attn_heads` heads;
    any other value selects `SingleHeadedAttention`. All parameters with more
    than one dimension are Xavier-uniform initialised. With `weight_tying` the
    generator's weight is replaced by the target embedding's weight afterwards.
    """
    attn = MultiHeadedAttention(d_model, attn_heads) if attn_type == 'multi' else SingleHeadedAttention(d_model)
    ff = PositionwiseFeedForward(d_model, drops)

    # Each layer receives its own deep copies of the attention / feed-forward prototypes.
    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    decoder = Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    tgt_embed = nn.Sequential(Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000))
    generator = nn.Linear(d_model, vocab)
    model = EncoderDecoder(encoder, decoder, tgt_embed, generator)

    for weight in filter(lambda t: t.dim() > 1, model.parameters()):
        nn.init.xavier_uniform_(weight)

    if weight_tying:
        # Tie after initialisation so both modules share one parameter.
        model.generator.weight = model.tgt_embed[0].lut.weight

    return model

In [None]:
class Img2Seq(nn.Module):
    """Image-to-sequence model that averages decoder and LM outputs at inference.

    With `tgt` given, `forward` runs the transformer alone (the LM is not
    used). With `tgt=None` it greedily decodes up to `seq_len` steps under
    `torch.no_grad()`, averaging the generator output with the LM output.
    """

    def __init__(self, img_encoder, transformer, lm):
        super().__init__()
        self.img_enc, self.transformer, self.lm = img_encoder, transformer, lm

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        if tgt is None:
            with torch.no_grad():
                return self._greedy_decode(src, seq_len)
        # Training: teacher-forced pass through the transformer only.
        dec_outs = self.transformer(self.img_enc(src), tgt, tgt_mask)    # ([bs, sl, d_model])
        return self.transformer.generate(dec_outs)                       # ([bs, sl, vocab])

    def _greedy_decode(self, src, seq_len):
        "Inference path: extend the sequence one argmax token at a time."
        memory = self.transformer.encode(self.img_enc(src))
        # Every sequence starts from token id 1.
        tokens = torch.ones((src.size(0), 1), dtype=torch.long, device=device)

        step_outs = []
        for _ in progress_bar(range(seq_len)):
            causal = subsequent_mask(tokens.size(-1))
            dec_outs = self.transformer.decode(memory, tokens, causal)
            dec_scores = self.transformer.generate(dec_outs[:, -1])
            # The LM call yields three values; only the first is used.
            lm_scores, _, _ = self.lm(tokens)
            scores = (dec_scores + lm_scores[:, -1]) / 2
            step_outs.append(scores)
            nxt = torch.argmax(scores, dim=-1, keepdim=True)
            # Stop once every sequence in the batch predicts token 0.
            if (nxt == 0).all():
                break
            tokens = torch.cat([tokens, nxt], dim=-1)
        # Stack per-step outputs, then swap so the batch dimension comes first.
        return torch.stack(step_outs).transpose(1, 0).contiguous()

In [None]:
# Configuration dict passed to `get_language_model` when building the AWD_LSTM
# (embedding size 512, 1400 hidden units, 3 layers, pad token 0, tied weights).
lm_config = dict(emb_sz=512, n_hid=1400, n_layers=3, pad_token=0, qrnn=False, bidir=False, output_p=0.2,
              hidden_p=0.2, input_p=0.5, embed_p=0.1, weight_p=0.4, tie_weights=True, out_bias=True)

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.2, attn_type='multi', attn_heads=8):
    """Assemble the transformer + AWD_LSTM `Img2Seq` network inside a `Learner`.

    The vocabulary size comes from `data.vocab.itos`. The language model is
    built from the module-level `lm_config` with `drop_mult=0.5`. The learner
    uses label smoothing (0.1), the `CER` metric and the `TeacherForce` callback.
    """
    n_tokens = len(data.vocab.itos)
    transformer_kwargs = dict(N=N, drops=drops, attn_type=attn_type, attn_heads=attn_heads)
    # Positional arguments are evaluated left to right: encoder, transformer, lm.
    net = Img2Seq(
        ResnetBase(em_sz, d_model),
        make_full_model(n_tokens, d_model, **transformer_kwargs),
        get_language_model(AWD_LSTM, n_tokens, lm_config, drop_mult=0.5),
    )
    criterion = LabelSmoothing(smoothing=0.1)
    return Learner(data, net, loss_func=criterion, metrics=[CER()], callback_fns=[TeacherForce])

In [None]:
learn = make_learner(data, 512, 512, N=6, drops=0.1, attn_type='multi')
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

## Add LM to model state_dict

In [None]:
learn.model.lm

In [None]:
sd = torch.load(PATH/'models/combo_512_9.pth', map_location=device)

In [None]:
lm_sd = torch.load(PATH/'models/wiki2_lm.pth', map_location=device)

In [None]:
from collections import OrderedDict
# Re-key the LM checkpoint's weights with the 'lm.' prefix, the attribute name
# under which the combined model stores its language model.
new_lm_sd = OrderedDict()
for k, v in lm_sd['model'].items():
    name = 'lm.'+k
    new_lm_sd[name] = v

In [None]:
lm_sd['model'].keys()

In [None]:
# Merge the re-keyed LM weights into the transformer checkpoint.
sd['model'].update(new_lm_sd)

# strict=False: mismatched keys are reported in the return value instead of raising.
learn.model.load_state_dict(sd['model'], strict=False)

In [None]:
learn.save('combo_512_9_wiki2_lm')

# Num Lines

In [None]:
# Databunch for the number-of-lines task: images listed in `df`, a random 15%
# validation split (fixed seed), labels from column 2, squish-resized to `sz`.
# The validation batch size is twice the training batch size.
data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
                .split_by_rand_pct(valid_pct=0.15, seed=42)
                .label_from_df(cols=2)
                .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
                .databunch(bs=bs, val_bs=bs*2, device=device)
       )

In [None]:
l2 = cnn_learner(data, models.resnet34, metrics=accuracy)
# Note: already split and frozen:  m[0][6], m[1]

In [None]:
# load saved sd into l2

In [None]:
from collections import OrderedDict
new_sd = OrderedDict()

# Take only the image-encoder backbone weights from the seq2seq checkpoint and
# rename their 'img_enc.base' prefix to '0' (presumably the body's index in
# the cnn_learner model - confirm).
sd = torch.load(PATH/'models/combo_512_mix_res9.pth', map_location=device)
for k, v in sd['model'].items():
    if k.startswith('img_enc.base'):
        name = k.replace('img_enc.base', '0')
        new_sd[name] = v

In [None]:
# Add the weights whose keys start with '1.' from the earlier line-count
# checkpoint (presumably the classifier head - confirm); keys are kept as-is.
l2_sd = torch.load(PATH/'models/line_97acc.pth', map_location=device)
for k, v in l2_sd['model'].items():
    if k.startswith('1.'):
        new_sd[k] = v

In [None]:
l2.model.load_state_dict(new_sd, strict=False)

In [None]:
l2.layer_groups[1][47].weight.requires_grad, l2.layer_groups[2][9].weight.requires_grad

In [None]:
l2.lr_find()
l2.recorder.plot(suggestion=True)

In [None]:
l2.fit_one_cycle(1,1e-3)
# 0.171519	0.069522	0.976006	12:57     'line_97acc'
# 0.094188	0.050964	0.988564	12:01  preload img_enc: combo_512_mix_res9 and head: line_97acc   'line_98acc'

In [None]:
l2.save('line_98acc')

In [None]:
# learn = cnn_learner(data, models.resnet34, custom_head=create_head(1024, 1), metrics=accuracy, loss_func=MSELossFlat())
# learn.fit_one_cycle(1,1e-3)
# # 1.229723	0.113453	0.9057	12:43

# preds,y,losses = learn.get_preds(with_loss=True)

# # actual accuracy
# preds.apply_(round)
# (preds.long().squeeze(-1) == y).float().mean()

# Experiments

In [None]:
# `make_learner` names its head-count parameter `attn_heads`; passing `heads`
# raises TypeError (unexpected keyword argument).
learn = make_learner(data, 512, 512, N=4, drops=0, attn_type='multi', attn_heads=8)
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

In [None]:
learn.load('combo_512_mix_res9')
None

In [None]:
learn.lr_find()
learn.recorder.plot(suggestion=True)

In [None]:
learn.data = data

In [None]:
# Five one-cycle epochs at max_lr=1e-5; SaveModelCallback stores its checkpoint as 'sm_256_lines2'.
learn.fit_one_cycle(5, max_lr=1e-5, callbacks=[SaveModelCallback(learn, name='sm_256_lines2')])
# Earlier run noted here: sm, 5cycle, 1e-3 (the call above now uses 1e-5).

# Image Preprocessing by num_lines
# 11.384931	10.208405	0.122443   'sm_256_lines'
# 10.548290	9.845284	0.116961   2nd run; 5cycle(1e-5); 'sm_256_lines2'

# 

# 2.474220	4.081600	0.036482  N:4,F.gelu,sz:256,em_sz:512,single,drop:0  'sm_256_1'
# greedy:    1.36697   .066
# 4.430638	4.402001	0.034730  2nd run, lr:1.5e-5, add tfms, drop:0.2   'sm_256_2'
# greedy:    0.37206   .04145

# 5.158808	4.964441	0.042544  "", w/ tfms, drop:0.2   'sm_256_3'
# greedy:    0.50646   .04103
# 4.266788	4.694098	0.039822  2nd run, lr:1.5e-5   'sm_256_4'
# greedy:    0.38972   .03879

# 3.678168	3.962710	0.035466  N:8,tfms   'sm_256_5'
# greedy:    0.38097   .04405

# 2.840161	3.429177	0.030243  N:4,tfms,multi(8)  'sm_256_6'
# greedy:    0.32775   .04201
# greedy:    1.38865   .06943

# 3.377438	3.786088	0.033931   N:6,tfms,drop:0.1,multi(8)   'sm_256_7'
# greedy:    0.36201   .04266  
# greedy:    1.40557   .06902

# 3.989137	4.353169	0.037717   N:4,weight-tying,multi(16)   'sm_256_8'
# greedy:    1.44291   .07299

# 4.806501	5.007887	0.042339   N:4,weight-tying,multi(8),drops:0.2   'sm_256_9'
# greedy:    1.46753   .07424

# 12.817862	11.842313	0.105486   N:4,no init/weight tying,multi(8),drops:0.1   'sm_256_10'

# 3.515456	3.880955	0.033763   N:4,multi(8),drops:0.1   'sm_256_11'
# greedy:    1.40945   .06964

# combo_cat6 - preload 'sm_256_7'
# 7.570516	6.670641	0.015070    'combo_512_7'
# greedy:    6.29406   .02186


# combo 6lines and under, preload 'sm_256_6', 10cycle(1e-3)
# 10.784849	9.144131	0.024880     'combo_512_13'
# greedy:    9.21715   .02788
# upload:    103.456   .44797
#     pg:    148.473   .69869

# combo 6lines and over, preload 'combo_512_13', 5cycle(2e-5)
# 56.732105	38.604137	0.026153    'combo_512_14'
# greedy:    141.022   .02933
# upload:    179.606   .81724

# combo_cat_pg - preload 'combo_512_7'
# 25.674347	21.161726	0.015869    'combo_512_8'
# greedy:    114.878   .03654
#   test:    115.247   .05014

# combo_cat_pg_dl_sorted - preload 'combo_512_8', 5cycle, 1.5e-4
# 9.800321	5.676855	0.006939    'combo_512_9'
# greedy:    34.8757   .01523

#   test:    91.3226   .05102
#   test:    104.685   .05483
#   test:    116.467   .05315

# upload:    115.149   .66801

# imdb_wiki_combo - preload 'combo_512_9'; split: img_enc, transformer; 3cycle, slice(5e-5, 1e-3)
# 20.390352	16.698141	0.018726    'combo_512_12'
# greedy:    60.8384   .02474
#     pg:    112.755   .05204

# combo_145k - fit:1cycle(1e-3); preload combo_512_9; drops:0.2
# 30.075714	31.024599	0.031703	1:50:33

# combo_145k <= 6 - 5cycle(1e-4); preload sm_256_mix_res
# 19.502026	17.365313	0.044436    'combo_512_mix_res'
# greedy:    14.6870   .04744
# upload:    75.1357   .47955
#     pg:    142.710   .66233

# test_combo 65k (sm,cat,pg,dl)
# 18.262995	15.091379	0.039351   5cycle(1e-5)  'combo_512_mix_res5'
# greedy:    129.968   .03932
# upload:    81.6835   .21794
#     pg:    127.317   .05613

# 16.678631	12.867742	0.035985   1cycle(1e-5), drops:0.1   'combo_512_mix_res6'
# 14.878900	12.797726	0.035834   1cycle(1e-6), drops:0.1   'combo_512_mix_res7'
# greedy:    34.3988   .05552
# upload:    75.5511   .21010
#     pg:    118.963   .05287

# 16.086544	12.320695	0.034263   1cycle(1e-5), drops:0.1   'combo_512_mix_res8'
# 11.640953	11.339241	0.032201   1cycle(1e-5), drops:0     'combo_512_mix_res9'
# greedy:    31.5506   .05410
#   test:    74.4820   .22067
#     pg:    105.182   .05076


# 79.324310	69.801376	0.325578	23:04   preload lines/combo_512_mix_res9;  1cycle(1e-4)

# Greedy

In [None]:
learn = make_learner(data, 512, 512, N=6, drops=0.1, attn_type='multi')
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

In [None]:
sd = torch.load(PATH/'models/combo_512_9.pth', map_location=device)

In [None]:
# Duplicate the image encoder's projection weights under 'transformer.src_adapt.0.*'.
# NOTE(review): none of the model classes in this part of the notebook define
# `src_adapt`; this presumably targets a model variant defined elsewhere - confirm.
sd['model']["transformer.src_adapt.0.weight"] = sd['model']['img_enc.linear.weight']
sd['model']["transformer.src_adapt.0.bias"] = sd['model']['img_enc.linear.bias']

In [None]:
learn.model.load_state_dict(sd['model'], strict=False)

In [None]:
# learn.load('combo_512_9_wiki2_lm')
learn.load('combo_512_mix_res')
None

In [None]:
def full_test(learn, sl, dl=data.valid_dl, batches=10):
    """Greedy-decode `batches` batches from `dl` and return mean loss and CER.

    Args:
        learn: learner whose model is called without targets (greedy decode)
            and whose `loss_func` scores the decoded output.
        sl: maximum number of decoding steps, passed to the model as `seq_len`.
        dl: dataloader to draw batches from. The default is bound to
            `data.valid_dl` once, when this function is defined.
        batches: number of batches to evaluate; `None` means
            `len(dl.dl.dataset) // bs`.

    Returns:
        `[mean_loss, mean_cer]`. Per-batch values are divided by the global
        batch size `bs` before averaging over `batches`.

    Raises:
        StopIteration: if `dl` holds fewer than `batches` batches.
    """
    learn.model.eval()
    batch_iter = iter(dl)
    loss_sum, cer_sum = 0, 0
    if batches is None:
        batches = len(dl.dl.dataset)//bs
    for _ in progress_bar(range(batches)):
        x, y = next(batch_iter)
        g_preds = learn.model(x, seq_len=sl)
        loss_sum += learn.loss_func(g_preds, y).item()/bs
        cer_sum += cer(g_preds, y, itos)[0]/bs
    return [loss_sum/batches, cer_sum/batches]

In [None]:
# set_seed()
g = full_test(learn, seq_len)

In [None]:
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
# losses = np.array([learn.loss_func(g_preds[i:i+1],y[i:i+1]).item() for i in range(bs)])
# cers = np.array([cer(g_preds[i:i+1],y[i:i+1])[0] for i in range(bs)])

In [None]:
# set_seed()
# Greedy-decode a single validation batch and score it.
x,y = next(iter(learn.data.valid_dl))

# Calling the model without targets triggers its greedy-decoding branch.
g_preds = learn.model(x, seq_len=seq_len)
g_res = torch.argmax(g_preds, dim=-1)
# [loss, character error rate], both divided by the batch size `bs`.
g = [learn.loss_func(g_preds, y).item()/bs, cer(g_preds, y, word_itos)[0]/bs]

In [None]:
#greedy
# Show the first six images of the batch, each titled with its decoded text.
fig, axes = plt.subplots(2,3, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    p = char_label_text(g_res[i], word_itos, sep='')
    ax=show_img(x[i], ax=ax, title=p)

# Test

### data

In [None]:
FOLDER = 'uploads'
df = pd.read_csv(PATH/'uploads.csv')
len(df)

sz,bs = 512,14
seq_len = 700

In [None]:
FOLDER = 'paragraphs'
df = pd.read_csv(PATH/'test_pg.csv')
len(df)

sz,bs = 512,15
seq_len = 700

In [None]:
## Test dataset only!!!
# set_seed()  # reproducibility
# Character-level test databunch: no train/valid split and no augmentation
# (empty transform list); labels are tokenised with the char vocab `itos`.
test_data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_none()
        .label_from_df(label_cls=SequenceList, vocab=CharVocab(itos), tokenizer=CharTokenizer)
        .transform([], size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
       )

In [None]:
# words(bert_tok)
# Word-level variant of the test databunch: labels use `BertVocab` with pad
# index 0. Running this cell overwrites the char-level `test_data`.
test_data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_none()
        .label_from_df(label_cls=MultiSequenceList, vocab=BertVocab(), pad_idx=0)
        .transform([], size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
       )

In [None]:
learn.data = test_data

In [None]:
dl = iter(test_data.train_dl)

In [None]:
x,y = next(dl)

### learner

In [None]:
learn = make_learner(data, 512, 512, N=4, drops=0.1, attn_type='multi')
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

In [None]:
# learn.load('combo_512_9_wiki2_lm')
# learn.load('combo_512_mix_res')
learn.load('test_combo_char_itos')
None

In [None]:
# learn.data = data

## Bert tok

In [None]:
# Greedy decode with the word-level tokenizer.
# NOTE(review): 101 is presumably the BERT [CLS] id used as the start token - confirm.
g_preds = greedy_decode(x, learn.model, word_len, 'word', 101)
g_res = torch.argmax(g_preds, dim=-1)
# [loss, character error rate], both divided by the batch size `bs`.
g = [learn.loss_func(g_preds, y).item()/bs, cer(g_preds, y, data.y.reconstruct)[0]/bs]
# The [1:7] slice drops the first character of the CER string, so a value
# such as 0.04145 is printed as ".04145".
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
#greedy
fig, axes = plt.subplots(2,2, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    p = data.y.reconstruct(g_res[i])
    ax=show_img(x[i], ax=ax, title=p)

## GPU batch testing

In [None]:
# Greedy-decode the current batch (no targets passed) and score it.
g_preds = learn.model(x, seq_len=seq_len)
g_res = torch.argmax(g_preds, dim=-1)
# [loss, character error rate], both divided by the batch size `bs`.
g = [learn.loss_func(g_preds, y).item()/bs, cer(g_preds, y)[0]/bs]

print(f'  test:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

# Recorded results (loss, CER, approximate runtime, checkpoint):
# test:    397.121   .29261  ~3:55   combo_512_9_wiki2_lm  (LM isn't aware of '\n'!!)
# test:    101.374   .05139  ~27s    combo_512_9

In [None]:
#test
# Show batch items 4-7 (the loop index is shifted by 4), titled with the decoded text.
fig, axes = plt.subplots(2,2, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    i +=4
    p = char_label_text(g_res[i])
    ax=show_img(x[i], ax=ax, title=p)

## Preprocessing Lines

In [None]:
# test data
x,y = next(iter(data.train_dl))

In [None]:
def lines_into_paragraph(res):
    """Join per-line token sequences into a single paragraph sequence.

    Each sequence in `res` is cut just before its first eos token (2), a
    newline token (4) is appended, and the pieces are concatenated. The final
    newline is then replaced by eos.

    Args:
        res: iterable of 1-D token tensors, one per line. It must be
            non-empty, otherwise `np.concatenate` raises `ValueError`.

    Returns:
        1-D CPU tensor holding the paragraph's tokens.
    """
    pieces = []
    for r in res:
        # Convert once so all slicing below is plain numpy; .cpu() is a no-op
        # for CPU tensors and makes CUDA tensors usable.
        line = r.cpu().numpy()
        eos_at = np.where(line == 2)[0]
        body = line[:eos_at[0]] if len(eos_at) else line
        pieces.append(np.append(body, 4))
    out = np.concatenate(pieces)
    out[-1] = 2  # replace final '\n' with 'eos'
    return torch.from_numpy(out)

In [None]:
def target_into_lines(targ, split_idx=4):
    """Split a padded paragraph target into zero-padded per-line rows.

    Zero entries (padding) are removed, then the sequence is split at every
    occurrence of `split_idx`. Because the split happens *at* the separator,
    every line after the first starts with the separator token.
    NOTE(review): confirm that this leading separator is wanted by callers.

    Args:
        targ: 1-D token tensor, zero-padded.
        split_idx: token id that separates lines.

    Returns:
        Float CPU tensor of shape (n_lines, longest_line), right-padded with 0.
    """
    # .cpu() is a no-op for CPU tensors; converting to numpy once avoids
    # calling numpy routines directly on torch tensors.
    nonzero = targ[targ.nonzero()].flatten().cpu().numpy()
    lines = np.split(nonzero, np.where(nonzero == split_idx)[0])
    maxlen = max(len(line) for line in lines)
    res = torch.zeros((len(lines), maxlen))
    for i, arr in enumerate(lines):
        res[i, :len(arr)] = torch.from_numpy(arr)
    return res

In [None]:
from scipy import signal,ndimage
import statistics

def img_into_line_tensor(img):
    """Cut an image into horizontal strips, one per detected text line.

    Uses the first channel only. The per-row standard deviation is smoothed
    with a Gaussian (sigma 5); each peak of the smoothed profile marks a line,
    and the line spans from the nearest local minimum above the peak to the
    nearest one below it.

    Args:
        img: image tensor indexed as [channel, row, column].

    Returns:
        Tensor of shape (n_lines, 3, tallest_line, width). Each strip is
        copied into all three channels; rows below a shorter strip stay 1.

    Raises:
        ValueError: if no line (peak) is found.
    """
    arr = img[0].cpu().numpy()   # .cpu() is a no-op for CPU tensors
    w = arr.shape[-1]
    stds = arr.std(axis=1)
    g_stds = ndimage.gaussian_filter1d(stds, 5)
    # NOTE(review): `//` floors the prominence threshold (it becomes 0.0
    # whenever stds.std() < 3); `/` may have been intended - confirm.
    peaks, _ = signal.find_peaks(g_stds, prominence=stds.std()//3, distance=20)
    mins = signal.argrelextrema(g_stds, np.less_equal)[0]  # np.less_equal critical for flat minima, edges
    if len(peaks) == 0:
        raise ValueError("no text lines detected in image")
    # For each peak: rows from the last minimum before it up to (excluding) the first minimum after it.
    strips = [arr[mins[mins < p][-1]:mins[mins > p][0]] for p in peaks]
    max_len = max(len(s) for s in strips)
    outs = torch.ones((len(strips), 3, max_len, w))
    for i, strip in enumerate(strips):
        outs[i, :, :len(strip)] = torch.from_numpy(strip)
    return outs

In [None]:
# Split one paragraph image and its target into per-line model inputs and targets.
idx = 1
inps = img_into_line_tensor(x[idx])
targs = target_into_lines(y[idx])

In [None]:
# Greedy-decode the line strips as one batch (one row per detected line).
g_preds = learn.model(inps, seq_len=seq_len)
g_res = torch.argmax(g_preds, dim=-1)

In [None]:
# CER divided by the number of decoded lines; the [1:7] slice drops the first character of the string.
print(f'test cer:   {str(cer(g_preds,targs)[0]/g_preds.size(0))[1:7]}')

## Single Image (cpu)

In [None]:
idx = 5

In [None]:
x,y = test_data.train_ds[idx]

In [None]:
%time pred = learn.predict(x)
# CPU total time: 4min 24s

# tx,_ = test_data.one_item(x)
# %time pred = greedy_decode(tx, learn.model, seq_len, 'char')    p_mask(25)
# CPU total time: 9min 42s

In [None]:
show_img(x, title=str(pred[0]))

In [None]:
probs = F.softmax(pred[2], dim=-1)

In [None]:
scores, idxs = torch.topk(probs, 3, dim=-1)

In [None]:
scores[0]

In [None]:
# For every position, print a {character: probability} dict of its top-3 candidates.
# Note: the inner loop rebinds the notebook-level name `idx`.
for s,i in zip(scores,idxs):
    new = {}
    for score, idx in zip(s,i):
        new[itos[idx]] = score.data
    print(new)

In [None]:
# Map every top-k index to its vocabulary entry in one vectorised lookup.
chars = np.array(itos)[idxs.numpy()]
chars

In [None]:
# NOTE(review): Tensor.apply_ writes the callable's return value back into the
# tensor, so returning an `itos` entry (presumably a string) looks like it
# cannot work - confirm, or look the characters up outside the tensor.
idxs.apply_(lambda x: itos[x])

In [None]:
for i in range(30):
    print(char_label_text(idxs[i], sep=' '))

In [None]:
# learn.show_results(ds_type=data.train_ds, rows=2)

In [None]:
# Take item `i` from a batch and add a leading batch dimension with [None].
# `xs` / `ys` are assigned by the `data.one_batch(...)` cell further down, so
# that cell has to be run before this one.
i = 0
x = xs[i][None]
y = ys[i][None]

In [None]:
batch = data.one_item(xs[0])

In [None]:
xs,ys = data.one_batch(detach=False, denorm=False)

In [None]:
# Greedy-decode and score the single-item batch (no division by batch size here).
g_preds = learn.model(x, seq_len=seq_len)
g_res = torch.argmax(g_preds, dim=-1)
g = [learn.loss_func(g_preds, y).item(), cer(g_preds, y)[0]]

print(f'  test:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
p = char_label_text(g_res[0])
show_img(x[0], figsize=(18,10), title=p)

In [None]:
# Run the learner's predict pipeline on one uploaded image file and show the
# image with the first element of the prediction as its title.
im = PATH/'uploads'/'test1.png'
img = open_image(im)
prediction = learn.predict(img)[0]
show_img(img, title=str(prediction))