# Prelims

In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
from fastai import *
from fastai.vision import *
from fastai.text import *
from fastai.callbacks.tracker import *

import pdb

In [None]:
# Root data directory; the CSV files, vocab pickles and image folders below are all resolved against it.
PATH = Path('data/IAM_handwriting')

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
# The mask helpers and greedy decoding below read this module-level name directly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device  # bare expression: shown as the cell output

In [None]:
def show_img(im, figsize=None, ax=None, alpha=None, title=None):
    "Draw `im` on `ax` (a new figure of `figsize` is created when no axis is given) and return the axis."
    if not ax:
        _, ax = plt.subplots(figsize=figsize)
    pixels = image2np(im.data)
    ax.imshow(pixels, alpha=alpha)
    if title:
        ax.set_title(title)
    return ax

In [None]:
def rshift(tgt, bos_token=1):
    """Shift `tgt` one position to the right by prepending `bos_token`.

    The last column of `tgt` is dropped, so the result has the same shape as the input.
    The BOS column is created from `tgt` itself (same dtype and device) rather than from a
    module-level `device`, so the function works for tensors on any device.

    tgt: 2-D tensor indexed as (row, position).
    bos_token: value written into the new first column.
    """
    bos = tgt.new_full((tgt.size(0), 1), bos_token)
    return torch.cat((bos, tgt[:, :-1]), dim=-1)

def subsequent_mask(size):
    "Return a (1, size, size) lower-triangular byte mask (ones on and below the main diagonal)."
    all_ones = torch.ones((1, size, size), device=device).byte()
    return torch.tril(all_ones)
    # variant with an extra leading dimension ("complex batches"):
    #return torch.tril(torch.ones((1,1,size,size), device=device).byte())
    
def parallelogram_mask(size, diagonal):
    "Return a (1, size, size) byte mask that is 1 only on the band from `diagonal` below the main diagonal up to the main diagonal."
    ones = torch.ones((1, size, size), device=device).byte()
    at_or_below_diag = torch.tril(ones).bool()                   # column <= row
    inside_window = torch.triu(ones, diagonal=-diagonal).bool()  # column >= row - diagonal
    band = inside_window & at_or_below_diag
    return band.byte()

## Loss, Metrics, Callbacks

In [None]:
class LabelSmoothing(nn.Module):
    """Label-smoothed KL-divergence loss for sequence predictions.

    `pred` is indexed as (batch, seq_len, vocab) and `target` as (batch, target_len); both
    shapes are unpacked that way in `loss_prep`. The summed KL divergence is divided by the
    batch size of the current `target`.
    """
    def __init__(self, smoothing=0.1):
        super(LabelSmoothing, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, pred, target):
        "Return sum(KL(smoothed_target || softmax(pred))) / batch_size."
        # Read the batch size from this batch instead of a notebook-level `bs`:
        # the last batch of an epoch can be smaller, and the global may not exist at all.
        n_samples = target.size(0)
        pred,targ = self.loss_prep(pred, target)
        pred = F.log_softmax(pred, dim=-1)  # F.kl_div expects log-probabilities as its input
        true_dist = pred.data.clone()
        true_dist.fill_(self.smoothing / pred.size(1))                  # every class gets smoothing/vocab
        true_dist.scatter_(1, targ.data.unsqueeze(1), self.confidence)  # the target class gets `confidence`
        return F.kl_div(pred, true_dist, reduction='sum')/n_samples

    def loss_prep(self, pred, target):
        "Equalize the prediction/target sequence lengths, then merge the batch and sequence dimensions."
        bs,tsl = target.shape
        _ ,sl,vocab = pred.shape

        # target shorter than the prediction: pad the end of its last dimension with zeros
        if sl>tsl: target = F.pad(target, (0,sl-tsl))

        # this should only be used when testing for small seq_lens
        # if tsl>sl: target = target[:,:sl]

        # prediction shorter than the target: append all-zero logit rows along the sequence dimension
        # (not ideal, as the original author notes: the padded steps carry uninformative logits)
        if tsl>sl: pred = F.pad(pred, (0,0,0,tsl-sl))

        targ = target.contiguous().view(-1).long()
        pred = pred.contiguous().view(-1, vocab)
        return pred, targ

In [None]:
import Levenshtein as Lev

class CER(Callback):
    "Metric callback: accumulates the values returned by `cer` per batch and reports errors/total per epoch."
    def __init__(self, itos):
        super().__init__()
        self.itos = itos
        self.name = 'cer'

    def on_epoch_begin(self, **kwargs):
        "Reset the running error sum and the sample count."
        self.errors = 0
        self.total = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Add this batch's summed error and its size to the running totals."
        batch_error, batch_size = cer(last_output, last_target, self.itos)
        self.total += batch_size
        self.errors += batch_error

    def on_epoch_end(self, last_metrics, **kwargs):
        "Append the epoch's mean error to the metrics."
        epoch_rate = self.errors/self.total
        return add_metrics(last_metrics, epoch_rate)

In [None]:
def cer(preds, targs, itos):
    """Return (summed error, batch size) for a batch of predictions.

    Each sample contributes Levenshtein(target_text, predicted_text) / len(target_text),
    with the divisor replaced by 1 when the target text is empty.
    """
    n_samples = targs.size(0)
    decoded = torch.argmax(preds, dim=-1)
    total = 0
    for row in range(n_samples):
        guess = char_label_text(decoded[row], itos)   #.replace(' ', '')
        truth = char_label_text(targs[row], itos)     #.replace(' ', '')
        total += Lev.distance(truth, guess)/(len(truth) or 1)
    return total, n_samples

def char_label_text(pred, itos, sep=''):
    "Map token ids to their `itos` strings, dropping every id equal to 0, and join them with `sep`."
    ids = to_np(pred).astype(int)
    kept = ids[ids != 0]  # the eos token is kept (the original author left its removal commented out)
    return sep.join(itos[i] for i in kept)

In [None]:
class TeacherForce(LearnerCallback):
    "Callback that replaces each batch input with the pair (input, target); the target itself is unchanged."
    def __init__(self, learn:Learner):
        super().__init__(learn)

    def on_batch_begin(self, last_input, last_target, **kwargs):
        paired_input = (last_input, last_target)
        return {'last_input': paired_input, 'last_target': last_target}

# Data

## paragraphs

In [None]:
# Dataset option "paragraphs": labels come from `edited_pg.csv`, images from the `paragraphs` folder.
# Each dataset cell rebinds CSV / FOLDER / df, so only the last one executed takes effect.
fname = 'edited_pg.csv' #'small_synth_words.csv'
CSV = PATH/fname
FOLDER = 'paragraphs'

df = pd.read_csv(CSV)
len(df)  # shown as the cell output

In [None]:
# sz: image size passed to `.transform(size=sz)`; bs: batch size passed to `.databunch(bs=bs)`.
# seq_len: number of decoding steps used by the greedy character-decoding cell.
sz,bs = 512,10
seq_len = 700

## sm synth dataset

In [None]:
# Dataset option "sm synth": labels come from `edited_sm_synth.csv`, images from the `edited_sm_synth` folder.
fname = 'edited_sm_synth.csv' #'small_synth_words.csv'
CSV = PATH/fname
FOLDER = 'edited_sm_synth'

df = pd.read_csv(CSV)
len(df)  # shown as the cell output

In [None]:
# Count the newline-separated lines of every label and store the counts as a `num_lines` column.
num_lines = df.label.apply(lambda x: len(x.split('\n')))
df['num_lines'] = num_lines

# NOTE(review): `num_lines` is rebound here from the per-row Series to the constant 4;
# the DataFrame column keeps the per-row counts. Confirm where the constant is used.
num_lines = 4
df.head()

In [None]:
# df = df[:20000]
# len(df)

In [None]:
# Image size / batch size and decoding length for the small synthetic dataset
# (the commented line is an alternative, smaller-image setting).
# sz,bs = 128,100
sz,bs = 256,80
seq_len = 100

## combo

In [None]:
# Dataset option "combo": labels come from `combo_145k.csv`, images from the `combo_cat` folder.
fname = 'combo_145k.csv'
FOLDER = 'combo_cat'

CSV = PATH/fname
df = pd.read_csv(CSV)

In [None]:
# full
# Settings for the unfiltered DataFrame; `word_len` is the step count used by the greedy word-decoding cell.
sz,bs = 512,5
seq_len,word_len = 750,250
len(df)

In [None]:
# 6 and fewer
# Keep only rows whose `num_lines` is at most 6.
# NOTE(review): `df` is overwritten, so if another `num_lines` filter was already run in this
# session the two filters compound; reload the CSV before switching subsets.
df = df[df.num_lines <= 6]
sz,bs = 512,15
seq_len,word_len = 600,200
len(df)

In [None]:
# 6 and greater
# Keep only rows whose `num_lines` is at least 6.
# NOTE(review): `df` is overwritten, so if another `num_lines` filter was already run in this
# session the two filters compound; reload the CSV before switching subsets.
df = df[df.num_lines >= 6]
sz,bs = 512,5
seq_len,word_len = 750,250
len(df)

## test combo (no fonts)

In [None]:
# Dataset option "test combo": labels come from `test_combo.csv`, images from the `test_combo` folder.
FOLDER = 'test_combo'
CSV = PATH/'test_combo.csv'
df = pd.read_csv(CSV)

# NOTE(review): `word_len` is not set in this cell; the word-decoding cell will use whatever value
# an earlier cell left behind.
sz,bs = 512,10
seq_len = 750
len(df)

# ModelData

In [None]:
# Augmentation set: flipping disabled, max_zoom=1, max_rotate=0, warp magnitude up to 0.1.
tfms = get_transforms(do_flip=False, max_zoom=1, max_rotate=0, max_warp=0.1)

def force_gray(image):
    "Convert `image` to mode 'L' and then back to mode 'RGB'."
    gray = image.convert('L')
    return gray.convert('RGB')

In [None]:
def label_collater(samples:BatchSamples, pad_idx:int=0):
    """Collate (image, label) samples into a batch, padding the end of every label with `pad_idx`.

    Returns `(imgs, labels)`: `imgs` is the stacked image tensor and `labels` is a long tensor
    with one row per sample and `max_label_len + 1` columns (the extra column leaves room for
    the bos token added when the target is shifted).

    A batch holding a single sample returns a dummy (1, 1) zero label instead of the real one.
    """
    data = to_data(samples)
    ims, lbls = zip(*data)
    imgs = torch.stack(list(ims))
    # `==` instead of `is`: identity comparison with an int literal only works by accident of
    # CPython's small-int cache and raises a SyntaxWarning on Python 3.8+.
    if len(data) == 1:
        labels = torch.zeros(1,1).long()
        return imgs, labels
    max_len = max([len(s) for s in lbls])
    labels = torch.zeros(len(data), max_len+1).long() + pad_idx  # add 1 to max_len to account for bos token
    for i,lbl in enumerate(lbls):
        labels[i,:len(lbl)] = torch.from_numpy(lbl)  # pad at the end
    return imgs, labels

In [None]:
class SequenceList(TextList):
    "TextList whose labels are tokenized with a caller-supplied tokenizer class (no rules, no bos) and numericalized with `vocab`."
    def __init__(self, items:Iterator, vocab:Vocab, tokenizer:Tokenizer, **kwargs):
        # `**kwargs` is accepted but not forwarded to the parent constructor.
        wrapped = Tokenizer(tok_func=tokenizer, pre_rules=[], post_rules=[], special_cases=[])
        tokenize = TokenizeProcessor(tokenizer=wrapped, include_bos=False)
        numericalize = NumericalizeProcessor(vocab=vocab)
        super().__init__(items, vocab, sep='', pad_idx=0, processor=[tokenize, numericalize])

    def analyze_pred(self, pred:Tensor):
        "Pick the highest-scoring index along the last dimension."
        return torch.argmax(pred, dim=-1)

## Chars

In [None]:
# Character vocabulary (index -> token). The file is opened in a `with` block so the handle
# is closed right after loading instead of being left to the garbage collector.
# NOTE: unpickling can execute arbitrary code; only load vocab files you trust.
with open(PATH/'itos.pkl', 'rb') as f:
    itos = pickle.load(f)

In [None]:
class CharTokenizer(BaseTokenizer):
    "Tokenizer that emits one token per character of the text, followed by 'xxeos'."
    def tokenizer(self, t:str) -> List[str]:
        return [*t, 'xxeos']
            
class CharVocab(Vocab):
    "Vocab built directly from an `itos` list; tokens missing from it map to index 3."
    def __init__(self, itos:Collection[str]):
        self.itos = itos
        known = {tok: idx for idx, tok in enumerate(self.itos)}
        self.stoi = collections.defaultdict(lambda: 3, known)

    def textify(self, nums:Collection[int], sep=''):
        "Return the tokens for `nums` joined with `sep`, or as a list when `sep` is None."
        toks = [self.itos[i] for i in nums]
        if sep is None:
            return toks
        return sep.join(toks)

In [None]:
# Character-level DataBunch: images opened through `force_gray`, a seeded 15% random validation split,
# labels tokenized per character via SequenceList, no augmentation (empty transform list),
# images squished to `sz`, batches built with `label_collater`.
data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
#         .split_none()
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        #.label_from_df(label_cls=TextList, sep='', pad_idx=0, vocab=vocab, processor=procs)
        .label_from_df(label_cls=SequenceList, vocab=CharVocab(itos), tokenizer=CharTokenizer)
        .transform([], size=sz, resize_method=ResizeMethod.SQUISH)
        #.transform(tfms, size=sz, resize_method=ResizeMethod.PAD, padding_mode='border')
        # maintains aspect ratio but too small for good results => mostly whitespace
        .databunch(bs=bs, device=device, collate_fn=label_collater)
        #.normalize() # this sets x values to an odd range (~.3,-6)
       )

In [None]:
# Visual sanity check of a training batch from the character-level DataBunch.
data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

## Words

In [None]:
# Word vocabulary (index -> token). The file is opened in a `with` block so the handle
# is closed right after loading instead of being left to the garbage collector.
# NOTE: unpickling can execute arbitrary code; only load vocab files you trust.
with open(PATH/'word_itos_60k_mod.pkl', 'rb') as f:
    word_itos = pickle.load(f)
word_itos = word_itos[:10000]  # keep only the first 10,000 entries
len(word_itos)

In [None]:
class WordTokenizer(BaseTokenizer):
    """Tokenizer that lower-cases the text and splits it into alphanumeric runs and single symbols.

    Every maximal run of alphanumeric characters becomes one token; every other character
    (spaces and newlines included) becomes its own token. 'xxeos' is appended at the end.
    """
    def tokenizer(self, t:str) -> List[str]:
        tokens, word = [], []
        for ch in t.lower():
            if ch.isalnum():
                word.append(ch)
                continue
            # a non-alphanumeric character ends the current word (if any) and is a token itself
            if word:
                tokens.append(''.join(word))
                word = []
            tokens.append(ch)
        if word:
            tokens.append(''.join(word))
        tokens.append('xxeos')
        return tokens

# class WordTokenizer(BaseTokenizer):
#     def tokenizer(self, t:str) -> List[str]: return t.lower().replace('\n', ' \n ').split(' ') + ['xxeos']
            
class WordVocab(Vocab):
    "Word-level vocab wrapping an `itos` list; any token not in the list is numericalized as 3."
    def __init__(self, itos:Collection[str]):
        self.itos = itos
        lookup = dict((word, pos) for pos, word in enumerate(self.itos))
        self.stoi = collections.defaultdict(lambda: 3, lookup)

    def textify(self, nums:Collection[int], sep=''):
        "Join the tokens of `nums` with `sep`; with `sep=None` return the token list itself."
        words = [self.itos[n] for n in nums]
        return words if sep is None else sep.join(words)

In [None]:
# Word-level DataBunch: same images and seeded split as the character pipeline, labels tokenized
# by WordTokenizer and numericalized with the truncated `word_itos`, `tfms` augmentation applied.
# w_vocab = Vocab(word_itos)
# w_procs = [TokenizeProcessor(include_bos=False, include_eos=True), NumericalizeProcessor(vocab=w_vocab)]

words = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=SequenceList, vocab=WordVocab(word_itos), tokenizer=WordTokenizer)
        #.label_from_df(label_cls=TextList, pad_idx=0, vocab=w_vocab, processor=w_procs)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
       )

In [None]:
# Visual sanity check of a training batch from the word-level DataBunch.
words.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

## Mix

In [None]:
def multi_label_collater(samples:BatchSamples):
    """Collate (image, (char_label, word_label)) samples into a batch.

    Returns `(imgs, (char_labels, word_labels))`, where each label tensor is end-padded by `c_pad`.

    NOTE(review): a batch holding a single sample returns one dummy (1, 1) zero tensor instead of
    a (char, word) pair; callers that unpack two label tensors will not get them in that case.
    """
    data = to_data(samples)
    ims, lbls = zip(*data)
    char_lbls, word_lbls = zip(*lbls)
    imgs = torch.stack(list(ims))
    # `==` instead of `is`: identity comparison with an int literal only works by accident of
    # CPython's small-int cache and raises a SyntaxWarning on Python 3.8+.
    if len(data) == 1:
        labels = torch.zeros(1,1).long()
        return imgs, labels
    return imgs, (c_pad(char_lbls), c_pad(word_lbls))
    
def c_pad(lbls, pad_idx=0):
    "Stack numpy label arrays into one long tensor, filling the end of each row (plus one extra column) with `pad_idx`."
    width = max([len(lbl) for lbl in lbls]) + 1  # one extra column, to account for the bos token
    padded = torch.zeros(len(lbls), width, dtype=torch.long) + pad_idx
    for row, lbl in enumerate(lbls):
        padded[row, :len(lbl)] = torch.from_numpy(lbl)
    return padded

In [None]:
class MultiTokenizeProcessor(PreProcessor):
    "Tokenize every item twice, with CharTokenizer and with WordTokenizer, and store (char_tokens, word_tokens) pairs."
    def __init__(self, ds:ItemList=None, chunksize:int=10000):
        self.chunksize = chunksize
        self.c_toknizr = Tokenizer(tok_func=CharTokenizer, pre_rules=[], post_rules=[], special_cases=[])
        self.w_toknizr = Tokenizer(tok_func=WordTokenizer, pre_rules=[], post_rules=[], special_cases=[])

    def process_one(self, item):
        raise Exception("This isn't implemented!  I didn't think it was necessary...")

    def process(self, ds):
        "Tokenize `ds.items` in chunks of `chunksize` and replace them with the zipped token pairs."
        ds.chars = ds.items
        ds.words = ds.items
        step = self.chunksize
        chunk_starts = range(0, len(ds), step)

        char_tokens = []
        for start in progress_bar(chunk_starts, leave=False):
            char_tokens.extend(self.c_toknizr.process_all(ds.chars[start:start+step]))

        word_tokens = []
        for start in progress_bar(chunk_starts, leave=False):
            word_tokens.extend(self.w_toknizr.process_all(ds.words[start:start+step]))

        ds.items = list(zip(char_tokens, word_tokens))
        
class MultiNumericalizeProcessor(PreProcessor):
    "Numericalize (char_tokens, word_tokens) pairs with the module-level character and word vocabularies."
    def __init__(self, ds:ItemList=None):
        self.char_vocab = CharVocab(itos)
        self.word_vocab = WordVocab(word_itos)

    def process_one(self,item):
        "Return [char_ids, word_ids] as int64 numpy arrays."
        char_toks, word_toks = item[0], item[1]
        char_ids = np.array(self.char_vocab.numericalize(char_toks), dtype=np.int64)
        word_ids = np.array(self.word_vocab.numericalize(word_toks), dtype=np.int64)
        return [char_ids, word_ids]

    def process(self, ds):
        numericalized = [self.process_one(pair) for pair in ds.items]
        ds.items = array(numericalized)
        
        
class MultiSequenceList(TextList):
    "Label list whose items are [char_ids, word_ids] pairs for the same text."
    _processor = [MultiTokenizeProcessor, MultiNumericalizeProcessor]

    def __init__(self, items:Iterator, pad_idx=0, **kwargs):
        super().__init__(items, **kwargs)
        self.pad_idx = pad_idx
        self.char_vocab = CharVocab(itos)
        self.word_vocab = WordVocab(word_itos)
        # names added to `copy_new` so they are carried along with the list
        self.copy_new.extend(['char_vocab', 'word_vocab', 'pad_idx'])

    def get(self, i):
        "Return item `i` as [char Text, word Text]."
        pair = super().get(i)
        char_ids, word_ids = pair[0], pair[1]
        return [Text(char_ids, self.char_vocab.textify(char_ids, '')),
                Text(word_ids, self.word_vocab.textify(word_ids, ''))]

    def reconstruct(self, t:Tensor):
        "t: List of tensors -> [char, word]; each tensor is cut to the span between its first and last non-pad id."
        c,w = t

        c_keep = (c != self.pad_idx).nonzero()
        c_ids = c[c_keep.min():c_keep.max()+1]
        chars = Text(c_ids, self.char_vocab.textify(c_ids))

        w_keep = (w != self.pad_idx).nonzero()
        w_ids = w[w_keep.min():w_keep.max()+1]
        words = Text(w_ids, self.word_vocab.textify(w_ids))
        return [chars,words]

    def analyze_pred(self, pred:Tensor):
        "Pick the highest-scoring index along the last dimension."
        return torch.argmax(pred, dim=-1)

In [None]:
class ImageMultiList(ImageList):
    "ImageList that shows both the character label and the word label under each image."
    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(9,10), **kwargs):
        "Show the `xs` and `ys` on a square grid of `figsize`. `kwargs` are passed to the show method."
        side = int(math.sqrt(len(xs)))
        fig, axs = plt.subplots(side,side,figsize=figsize)
        axes = axs.flatten() if side > 1 else [axs]
        for idx, ax in enumerate(axes):
            chars,words = ys[idx]
            caption = '\n\n'.join([str(chars), str(words)])
            xs[idx].show(ax=ax, y=Text([], caption), **kwargs)
        plt.tight_layout()

In [None]:
# Mixed DataBunch: every image is labelled with both a character sequence and a word sequence
# (MultiSequenceList); batches are built by `multi_label_collater`. This rebinds `data`.
data = (ImageMultiList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=MultiSequenceList)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=multi_label_collater)
       )

In [None]:
# Visual sanity check of a training batch from the mixed (char + word) DataBunch.
data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

# Transformer Modules

In [None]:
# Normalization factory used by the transformer modules below; the commented lines are alternatives.
# LayerNorm = nn.LayerNorm
LayerNorm = partial(nn.LayerNorm, eps=1e-4)  # accommodates mixed precision training
# LayerNorm = partial(nn.BatchNorm2d, eps=1e-4)

In [None]:
class SublayerConnection(nn.Module):
    "Residual wrapper with the norm applied first: returns x + dropout(sublayer(norm(x)))."
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        normed = self.norm(x)
        update = self.dropout(sublayer(normed))
        return x + update

In [None]:
def clones(module, N):
    "Return an nn.ModuleList holding N independent deep copies of `module`."
    copies = nn.ModuleList()
    for _ in range(N):
        copies.append(deepcopy(module))
    return copies

In [None]:
class Encoder(nn.Module):
    "Stack of N copies of `layer`, followed by a final norm."
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        return self.norm(out)

In [None]:
class EncoderLayer(nn.Module):
    "Encoder layer: self-attention then feed-forward, each wrapped in a SublayerConnection."
    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)

    def forward(self, x, mask=None):
        attn_wrap, ff_wrap = self.sublayer[0], self.sublayer[1]

        def attend(h):
            # query, key and value are all the (normed) layer input
            return self.self_attn(h, h, h, mask)

        x = attn_wrap(x, attend)
        return ff_wrap(x, self.feed_forward)

In [None]:
class Decoder(nn.Module):
    "Stack of N copies of a decoder `layer` (each receives the source features and the target mask), followed by a final norm."
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, src, tgt_mask=None):
        out = x
        for block in self.layers:
            out = block(out, src, tgt_mask)
        return self.norm(out)

In [None]:
class DecoderLayer(nn.Module):
    "Decoder layer: masked self-attention, attention over the source features, then feed-forward."
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # one residual/dropout/norm wrapper per sub-layer
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, src, tgt_mask=None):
        def self_attend(h):
            return self.self_attn(h, h, h, tgt_mask)  # acts as a weak LM

        def src_attend(h):
            return self.src_attn(h, src, src)

        x = self.sublayer[0](x, self_attend)
        x = self.sublayer[1](x, src_attend)
        return self.sublayer[2](x, self.feed_forward)

In [None]:
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Returns `(weighted_values, attention_weights)`. Positions where `mask == 0` get a score
    of -1e4 before the softmax (instead of -1e9, to accommodate mixed precision).
    `dropout`, when given, is applied to the attention weights.
    """
    scale = math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e4)
    weights = F.softmax(scores, dim=-1)
    p_attn = weights if dropout is None else dropout(weights)
    return torch.matmul(p_attn, value), p_attn

In [None]:
class SingleHeadedAttention(nn.Module):
    "Attention with a single head: linear projections of query/key/value, attention, then an output projection."
    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # attention weights of the most recent forward pass
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        q_proj, k_proj, v_proj, out_proj = self.linears
        query, key, value = q_proj(query), k_proj(key), v_proj(value)
        attended, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        return out_proj(attended)

In [None]:
class MultiHeadedAttention(nn.Module):
    "Multi-head attention over `h` heads of size d_model // h."
    def __init__(self, d_model, h=8, dropout=0.2):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h        # assume d_v always equals d_k
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None               # attention weights of the most recent forward pass
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        # add a head dimension to the mask so it broadcasts over all heads
        if mask is not None: mask = mask.unsqueeze(1)
        bs = q.size(0)

        def split_heads(proj, t):
            "Project `t`, then reshape d_model => h x d_k and move the head dimension forward."
            return proj(t).view(bs, -1, self.h, self.d_k).transpose(1,2)

        # 1) project q, k, v with the first three linears
        q, k, v = [split_heads(proj, t) for proj, t in zip(self.linears, (q, k, v))]

        # 2) attend on all heads at once
        x, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)

        # 3) merge the heads back together and apply the final linear
        merged = x.transpose(1, 2).contiguous().view(bs, -1, self.h * self.d_k)
        return self.linears[-1](merged)

In [None]:
class PositionwiseFeedForward(nn.Module):
    "Two-layer feed-forward block: d_model -> 4*d_model -> d_model, with GELU and dropout in between."
    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        hidden = d_model*4
        self.w_1 = nn.Linear(d_model, hidden)
        self.w_2 = nn.Linear(hidden, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        expanded = F.gelu(self.w_1(x))
        return self.w_2(self.dropout(expanded))

In [None]:
class PositionalEncoding(nn.Module):
    "Add fixed sine/cosine position encodings (sine on even channels, cosine on odd ones) to the input, then apply dropout."
    def __init__(self, d_model, dropout=0.1, max_len=2000):
        super().__init__()

        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0.0, max_len).unsqueeze(1)
        log_increment = math.log(1e4) / d_model
        div_term = torch.exp(torch.arange(0.0, d_model, 2) * -log_increment)
        angles = position * div_term

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # a buffer rather than a parameter: not trained, but still stored in the state_dict
        self.register_buffer('pe', table.unsqueeze(0))    #(1,max_len,d_model)

    def forward(self, x):
        n_positions = x.size(1)
        return self.dropout(x + self.pe[:, :n_positions])

In [None]:
class Embeddings(nn.Module):
    "Embedding lookup whose output is multiplied by sqrt(d_model)."
    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale

# img2char2word Arch

In [None]:
# length_results = []

# c_ler = ler[::2]
# w_ler = ler[1::2]
# max(c_ler), max(w_ler)

In [None]:
class LearnedPositionalEmbeddings(nn.Module):
    """Token embedding plus a learned (row, column) position embedding.

    For every sequence the positions of the delimiter ids (`nl_tok`, `end_tok`) split it into
    segments; each token gets a segment number ("row") and an offset inside the segment
    ("column"). The row and column embeddings, each `d_model//2` wide, are concatenated,
    added to the token embedding, scaled by sqrt(d_model) and passed through dropout.

    embedding: token embedding module; its `embedding_dim` is read as `d_model`.
    col_encoding: module applied to the column indices.
    dropout: dropout probability applied to the result.
    """
    def __init__(self, embedding, col_encoding, dropout=0.1):
        super(LearnedPositionalEmbeddings, self).__init__()
        # presumably the ids of the newline and end-of-sequence tokens in the vocab -- TODO confirm
        self.nl_tok  = 4
        self.end_tok = 2
        self.d_model = embedding.embedding_dim
        
        self.embed = embedding
        self.dropout = nn.Dropout(p=dropout)
        
        # 15 row indices (0 is the padding index), so a sequence can have at most 14 segments
        self.rows = nn.Embedding(15, self.d_model//2, 0)
        self.cols = col_encoding

    def forward(self, x):
        # index 0 (the padding index of `self.rows`) is kept for every position after the last delimiter
        rows,cols = torch.zeros_like(x),torch.zeros_like(x)
        for ii,batch in enumerate(x.unbind()):
            # indices of every delimiter token in this sequence
            lines = torch.nonzero((batch==self.nl_tok) + (batch==self.end_tok)).flatten()

            p = 0
            for i,n in enumerate(lines, start=1):
                rows[ii,p:n+1] = i
                cols[ii,p:n+1] = torch.arange(1,n+2-p)
                # NOTE(review): the next segment starts AT the delimiter index (not n+1), so each
                # delimiter except the last is overwritten with the following segment's row and
                # column 1 -- confirm this overlap is intended before changing it.
                p = n
        
        row_t = self.rows(rows)
        col_t = self.cols(cols)
        
        #length_results.append(cols.max().item())
        pos_enc = torch.cat((row_t, col_t), dim=-1)
                
        x = self.embed(x)
        x = (x + pos_enc) * math.sqrt(self.d_model)
        return self.dropout(x)

In [None]:
class WordCharTransformer(nn.Module):
    "One shared encoder feeding two decoders: one over character targets and one over word targets."
    def __init__(self, encoder, c_dec, w_dec, src_adapt, c_emb, w_emb, c_gen, w_gen):
        super().__init__()
        # registration order is kept as-is: it fixes the order of `parameters()`
        self.encoder = encoder
        self.src_adapt = src_adapt

        self.w_decoder = w_dec
        self.c_decoder = c_dec

        self.w_embed = w_emb
        self.c_embed = c_emb
        self.w_generator = w_gen
        self.c_generator = c_gen

    def forward(self, src, mix_tgt):
        "Return the (char, word) decoder outputs for `src` and the (char, word) target pair."
        c_tgt,w_tgt,c_mask,w_mask = self.shift_with_masks(mix_tgt)
        feats = self.encode(src)

        c_in = self.c_embed(c_tgt)
        char_outs = self.c_decoder(c_in, feats, c_mask)
        w_in = self.w_embed(w_tgt)
        word_outs = self.w_decoder(w_in, feats, w_mask)
        return char_outs, word_outs

    def encode(self, src):
        adapted = self.src_adapt(src)
        return self.encoder(adapted)

    def generate(self, c_outs, w_outs):
        "Apply the character and word output heads."
        return self.c_generator(c_outs), self.w_generator(w_outs)

    def shift_with_masks(self, mix_tgt):
        "Right-shift both targets and build their band masks (window 25 for chars, 10 for words)."
        c_tgt,w_tgt = mix_tgt
        c_tgt = rshift(c_tgt).long()
        w_tgt = rshift(w_tgt).long()
        c_len, w_len = c_tgt.size(-1), w_tgt.size(-1)
        c_mask = parallelogram_mask(c_len, 25) #subsequent_mask(c_tgt.size(-1))
        w_mask = parallelogram_mask(w_len, 10) #subsequent_mask(w_tgt.size(-1))
        return c_tgt,w_tgt,c_mask,w_mask

In [None]:
def make_full_model(c_vocab, w_vocab, d_model, em_sz, N=4, drops=0, attn_type='multi', heads=8):
    """Build a WordCharTransformer with one encoder and two decoders (char, word) of N layers each.

    c_vocab, w_vocab: sizes of the character and word vocabularies.
    d_model: width of the transformer layers; em_sz: width of the incoming features, mapped to
        d_model by the source adapter.
    drops: dropout used by the feed-forward block and the sublayer connections.
    attn_type: 'multi' selects MultiHeadedAttention with `heads` heads, anything else
        SingleHeadedAttention.

    NOTE(review): the attention modules are built without a dropout argument, so they keep their
    own default (0.2) regardless of `drops`; LearnedPositionalEmbeddings likewise keeps its
    default -- confirm this is intended.
    """
    c = deepcopy
    
    if attn_type=='multi':
        attn = MultiHeadedAttention(d_model, heads)
    else:
        attn = SingleHeadedAttention(d_model)

    ff = PositionwiseFeedForward(d_model, drops)
#     pos = PositionalEncoding(d_model, drops, 2000)
    
    # positional arguments follow WordCharTransformer's signature:
    # encoder, c_dec, w_dec, src_adapt, c_emb, w_emb, c_gen, w_gen
    model = WordCharTransformer(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        nn.Sequential(
            nn.Linear(em_sz, d_model), Lambda(lambda x: x.mul_(8))
            # .mul_(8) increases the gradients on these weights by 8!
        ),
        # column embeddings: 100 indices for characters, 60 for words (index 0 is the padding index)
        LearnedPositionalEmbeddings(nn.Embedding(c_vocab,d_model,0), nn.Embedding(100, d_model//2, 0)),   
        LearnedPositionalEmbeddings(nn.Embedding(w_vocab,d_model,0), nn.Embedding(60,  d_model//2, 0)),
#         nn.Sequential(
#             Embeddings(d_model, c_vocab), pos
#         ),
#         nn.Sequential(
#             Embeddings(d_model, w_vocab), pos
#         ),
        nn.Linear(d_model, c_vocab),
        nn.Linear(d_model, w_vocab)
    )
        
    # Xavier-initialize every parameter with more than one dimension.
    # NOTE(review): this includes the embedding weight matrices, so their padding rows (index 0)
    # are overwritten by the initializer as well -- confirm that is acceptable.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
                    
    return model

In [None]:
class ResnetBase(nn.Module):
    "Image encoder: a truncated pretrained resnet34, a conv layer and a max-pool to 16 rows, flattened to a feature sequence."
    def __init__(self, em_sz):
        super().__init__()

        # how many trailing resnet children to drop for each supported feature width
        # (original notes: 512,2,16;  256,4,32;  128,8,64)
        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        net = models.resnet34(True)
        kept = list(net.children())[:cut]
        self.base = nn.Sequential(*kept)

        self.conv = conv_layer(em_sz,em_sz)
        self.pool = nn.AdaptiveMaxPool2d((16,None))

    def forward(self, x):
        feats = self.base(x)
        feats = self.conv(feats)
        feats = self.pool(feats)
        # merge the two spatial dimensions, then put that merged dimension before the channels
        return feats.flatten(2,3).permute(0,2,1)

In [None]:
class Img2Seq(nn.Module):
    "Full model: image encoder, transformer, then the transformer's output heads."
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt):
        "Return the (char, word) outputs of `transformer.generate` for images `src` and targets `tgt`."
        image_feats = self.img_enc(src)
        decoded = self.transformer(image_feats, tgt)
        return self.transformer.generate(*decoded)

In [None]:
class CERMetric(LearnerCallback):
    "Reports two extra metrics per epoch, 'char' and 'word': the mean of `cer` over the character and word outputs."
    _order=-20 # Needs to run before the recorder
    def __init__(self, learn, c_itos, w_itos):
        super().__init__(learn)
        self.c_itos, self.w_itos = c_itos, w_itos

    def on_train_begin(self, **kwargs):
        self.learn.recorder.add_metric_names(['char', 'word'])

    def on_epoch_begin(self, **kwargs):
        "Reset the running sums."
        self.c_errors = 0
        self.w_errors = 0
        self.total = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Accumulate the summed errors of both outputs; the sample count is taken from the character call."
        c_out, w_out = last_output
        c_targ, w_targ = last_target
        c_error, n_samples = cer(c_out, c_targ, self.c_itos)
        w_error, _ = cer(w_out, w_targ, self.w_itos)
        self.c_errors += c_error
        self.w_errors += w_error
        self.total += n_samples

    def on_epoch_end(self, last_metrics, **kwargs):
        char_rate = self.c_errors/self.total
        word_rate = self.w_errors/self.total
        return add_metrics(last_metrics, [char_rate, word_rate])

In [None]:
class MultiLabelSmoothing(nn.Module):
    "Sum of the LabelSmoothing losses of the character output (pred[0]) and the word output (pred[1])."
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, c_targ, w_targ):
        char_pred, word_pred = pred[0], pred[1]
        criterion = LabelSmoothing(self.smoothing)
        char_loss = criterion(char_pred, c_targ)
        word_loss = criterion(word_pred, w_targ)
        #print(f'char loss: {char_loss}  word_loss: {word_loss}')
        return char_loss + word_loss

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, attn_type='multi', heads=8, smoothing=0.1):
    "Build the Img2Seq model for `data` and wrap it in a Learner with the TeacherForce and CERMetric callbacks."
    labels = data.train_dl.dl.dataset.y
    c_vocab = labels.char_vocab.itos
    w_vocab = labels.word_vocab.itos

    # the image encoder is built before the transformer, as in the original construction order
    img_encoder = ResnetBase(em_sz)
    transformer = make_full_model(len(c_vocab), len(w_vocab), d_model, em_sz, N, drops, attn_type, heads)
    net = Img2Seq(img_encoder, transformer)

    criterion = MultiLabelSmoothing(smoothing)
    learn = Learner(data, net, loss_func=criterion, callback_fns=[TeacherForce])
    metric = CERMetric(learn, c_vocab, w_vocab)
    learn.callbacks.append(metric)
    return learn

## Greedy results

In [None]:
def greedy_decode(src, model, seq_len, kind='char'):
    """Greedily decode up to `seq_len` tokens for the images in `src`.

    kind: 'char' selects the character decoder/embedding/head, anything else the word ones.
    Returns the stacked per-step output-head scores with the batch dimension first.

    Decoding starts from a column of ones (the bos id used by `rshift`) and stops early once
    every sample in the batch predicts id 0. The attention window of the band mask matches the
    one used in `WordCharTransformer.shift_with_masks` (25 for characters, 10 for words);
    previously 25 was used for both, so the word decoder saw a wider window than in training.
    """
    tfmr = model.transformer
    img_enc = model.img_enc

    if kind == 'char':
        decoder, embed, generator = tfmr.c_decoder, tfmr.c_embed, tfmr.c_generator
        window = 25
    else:
        decoder, embed, generator = tfmr.w_decoder, tfmr.w_embed, tfmr.w_generator
        window = 10

    with torch.no_grad():
        feats = tfmr.encode(img_enc(src))
        bs = src.size(0)
        tgt = torch.ones((bs,1), dtype=torch.long, device=device)

        res = []
        for i in progress_bar(range(seq_len)):
#             mask = subsequent_mask(tgt.size(-1))
            mask = parallelogram_mask(tgt.size(-1), window)  # doesn't work well because model counts '/n'

            dec_outs = decoder(embed(tgt), feats, mask)
            prob = generator(dec_outs[:,-1])   # scores for the newest position only
            res.append(prob)
            pred = torch.argmax(prob, dim=-1, keepdim=True)
            if (pred==0).all(): break          # every sample predicted id 0: stop
            tgt = torch.cat([tgt,pred], dim=-1)
        out = torch.stack(res).transpose(1,0).contiguous()

        return out

In [None]:
# Grab one validation batch: `x` holds the images, `y` the (char, word) label pair.
# NOTE(review): `learn` is not created in the visible cells; presumably `learn = make_learner(...)` ran earlier.
x,y = next(iter(learn.data.valid_dl))

### Chars

In [None]:
# Greedy character decoding of the validation batch, then its loss and error rate.
g_preds = greedy_decode(x, learn.model, seq_len, 'char')
g_res = torch.argmax(g_preds, dim=-1)
loss_func = LabelSmoothing()
# NOTE(review): LabelSmoothing.forward already divides its sum by a batch size, so the extra `/bs`
# on the loss divides twice -- confirm this is the number you want to report.
g = [loss_func(g_preds, y[0]).item()/bs, cer(g_preds, y[0], itos)[0]/bs]
# NOTE(review): `str(g[1])[1:7]` drops the first character of the error rate (the leading "0" when it is below 1).
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
#greedy
# Show the first six images of the batch, each titled with its greedily decoded character text.
fig, axes = plt.subplots(2,3, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    p = char_label_text(g_res[i], itos)
    ax=show_img(x[i], ax=ax, title=p)

### Words

In [None]:
# Greedy word decoding of the validation batch, then its loss and error rate.
g_preds = greedy_decode(x, learn.model, word_len, 'word')
g_res = torch.argmax(g_preds, dim=-1)
loss_func = LabelSmoothing()
# NOTE(review): LabelSmoothing.forward already divides its sum by a batch size, so the extra `/bs`
# on the loss divides twice -- confirm this is the number you want to report.
g = [loss_func(g_preds, y[1]).item()/bs, cer(g_preds, y[1], word_itos)[0]/bs]
# NOTE(review): `str(g[1])[1:7]` drops the first character of the error rate (the leading "0" when it is below 1).
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
#greedy
# Show the first six images of the batch, each titled with its greedily decoded word tokens
# (joined without a separator, since `char_label_text` defaults to sep='').
fig, axes = plt.subplots(2,3, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    p = char_label_text(g_res[i], word_itos)
    ax=show_img(x[i], ax=ax, title=p)

# Multi-Resolution Arch

## Image Encoders

### Resnet34

In [None]:
class ImageEncoder(nn.Module):
    "Image encoder on a truncated pretrained resnet34; a conv layer with stride (2,1) replaces the adaptive pooling."
    def __init__(self, em_sz):
        super().__init__()

        # how many trailing resnet children to drop for each supported feature width
        # (original notes: 512,2,16;  256,4,32;  128,8,64)
        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        net = models.resnet34(True)
        kept = list(net.children())[:cut]
        self.base = nn.Sequential(*kept)

        self.conv = conv_layer(em_sz,em_sz,stride=(2,1))
        #self.pool = nn.AdaptiveMaxPool2d((16,None))

    def forward(self, x):
        feats = self.conv(self.base(x))
        # merge the two spatial dimensions, then put that merged dimension before the channels
        return feats.flatten(2,3).permute(0,2,1)

### Xception

In [None]:
from fastai.vision.models.cadene_models import xception_cadene

In [None]:
class SeparableConv2d(nn.Module):
    "Depthwise-separable convolution: a per-channel (groups=in_channels) conv followed by a 1x1 pointwise conv."
    def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):
        super().__init__()

        # depthwise step: one filter group per input channel
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride,
                               padding, dilation, groups=in_channels, bias=bias)
        # pointwise step: 1x1 conv that mixes the channels
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)

    def forward(self,x):
        return self.pointwise(self.conv1(x))

In [None]:
class Block(nn.Module):
    """Residual block of `reps` (ReLU, SeparableConv2d, BatchNorm2d) units.

    The skip path is a 1x1 conv + batch norm when the channel count or stride changes, and the
    identity otherwise. `grow_first` decides whether the channel change happens in the first or
    the last unit; `start_with_relu=False` drops the leading ReLU; a stride other than 1 adds a
    final max-pool.
    """
    def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True):
        super(Block, self).__init__()

        needs_projection = out_filters != in_filters or strides != 1
        if needs_projection:
            self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        def unit(c_in, c_out):
            "One (ReLU, separable 3x3 conv, batch norm) triple."
            return [nn.ReLU(inplace=True),
                    SeparableConv2d(c_in,c_out,3,stride=1,padding=1,bias=False),
                    nn.BatchNorm2d(c_out)]

        # width of the middle units: already grown when the first unit changes the channels
        width = out_filters if grow_first else in_filters
        layers = unit(in_filters, out_filters) if grow_first else []
        for _ in range(reps-1):
            layers += unit(width, width)
        if not grow_first:
            layers += unit(in_filters, out_filters)

        if start_with_relu:
            # the first ReLU must not be in-place: the block input is reused by the skip path
            layers[0] = nn.ReLU(inplace=False)
        else:
            layers = layers[1:]

        if strides != 1:
            layers.append(nn.MaxPool2d(3,strides,1))
        self.rep = nn.Sequential(*layers)

    def forward(self,inp):
        x = self.rep(inp)
        skip = inp if self.skip is None else self.skipbn(self.skip(inp))
        x += skip
        return x

In [None]:
class ImageEncoder(nn.Module):
    "Truncated pretrained xception followed by a conv layer and a height-16 adaptive max pool."
    def __init__(self, em_sz):
        super().__init__()

        # em_sz -> end index used to truncate the pretrained network.
        # Original notes next to each entry: 728 -> "4, 32"; 1024 -> "2, 16"; 2048 -> "2, 16"
        # (presumably the feature-map height/width at that cut -- TODO confirm).
        cut_for_width = {728:-6, 1024:-5, 2048:-1}
        cut = cut_for_width[em_sz]  # KeyError for any other em_sz

        self.base = xception_cadene(True)[:cut]

        self.conv = conv_layer(em_sz,em_sz)
        # (16, None): pool the height to 16 and leave the width unchanged
        self.pool = nn.AdaptiveMaxPool2d((16,None))

    def forward(self, x):
        "Return features with dims 2 and 3 flattened and moved in front of dim 1."
        feats = self.base(x)
        feats = self.conv(feats)
        feats = self.pool(feats)
        return feats.flatten(2,3).permute(0,2,1)

### SEResnet50

In [None]:
from fastai.vision.models.cadene_models import se_resnet50

In [None]:
class ImageEncoder(nn.Module):
    "Truncated pretrained se_resnet50 followed by a conv layer and a height-16 adaptive max pool."
    def __init__(self, em_sz):
        super().__init__()

        # em_sz -> end index used to truncate the network's children.
        # Original notes next to each entry: 512 -> "8, 64"; 1024 -> "4, 32"; 2048 -> "2, 16"
        # (presumably the feature-map height/width at that cut -- TODO confirm).
        # An older inline note read "512,2,16; 256,4,32; 128,8,64"; it does not match the keys here.
        cut_for_width = {512:-4, 1024:-3, 2048:-2}
        cut = cut_for_width[em_sz]  # KeyError for any other em_sz

        children = list(se_resnet50(True).children())
        self.base = nn.Sequential(*children[:cut])

        self.conv = conv_layer(em_sz,em_sz)
        # (16, None): pool the height to 16 and leave the width unchanged
        self.pool = nn.AdaptiveMaxPool2d((16,None))

    def forward(self, x):
        "Return features with dims 2 and 3 flattened and moved in front of dim 1."
        feats = self.base(x)
        feats = self.conv(feats)
        feats = self.pool(feats)
        return feats.flatten(2,3).permute(0,2,1)

## Arch

In [None]:
class EncoderDecoder(nn.Module):
    "Container wiring a source adapter, encoder, target embedding, decoder and generator together."
    def __init__(self, encoder, decoder, src_adapt, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_adapt = src_adapt
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, tgt_mask=None):
        "Encode `src`, embed `tgt`, and decode; the generator is NOT applied here."
        memory = self.encode(src)
        embedded = self.embed(tgt)
        return self.decode(memory, embedded, tgt_mask)

    def encode(self, src):
        "Apply `src_adapt` and then the encoder."
        adapted = self.src_adapt(src)
        return self.encoder(adapted)

    def decode(self, src, tgt, tgt_mask=None):
        "Call the decoder; note it receives the (already embedded) target first, then the source."
        return self.decoder(tgt, src, tgt_mask)

    def embed(self, tgt):
        "Apply the target embedding."
        return self.tgt_embed(tgt)

    def generate(self, outs):
        "Apply the generator to decoder outputs."
        return self.generator(outs)

In [None]:
def make_full_model(vocab, d_model, em_sz, N=4, drops=0, attn_type='multi', heads=8):
    "Assemble an `EncoderDecoder` with `N` encoder and `N` decoder layers."
    dup = deepcopy  # every layer gets its own copy of the attention / feed-forward prototypes

    # any attn_type other than 'multi' selects single-headed attention
    attn = MultiHeadedAttention(d_model, heads) if attn_type=='multi' else SingleHeadedAttention(d_model)
    ff = PositionwiseFeedForward(d_model, drops)
    pos = PositionalEncoding(d_model, drops, 2000)

    encoder = Encoder(EncoderLayer(d_model, dup(attn), dup(ff), drops), N)
    decoder = Decoder(DecoderLayer(d_model, dup(attn), dup(attn), dup(ff), drops), N)
    # image features: em_sz -> d_model, then scaled in place by 8 (positional encoding left disabled)
    src_adapt = nn.Sequential(
        nn.Linear(em_sz, d_model), Lambda(lambda x: x.mul_(8)) #, pos
    )
    # target tokens: embedding followed by positional encoding
    tgt_embed = nn.Sequential(Embeddings(d_model, vocab), pos)
    generator = nn.Linear(d_model, vocab)

    return EncoderDecoder(encoder, decoder, src_adapt, tgt_embed, generator)

In [None]:
class Img2Seq(nn.Module):
    "Image encoder + transformer; greedy-decodes when called without a target."
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        """With `tgt`: one training pass, returning generator outputs for every target position.
        Without `tgt`: greedy decoding for at most `seq_len` steps, returning the stacked
        per-step generator outputs with the batch dimension first."""
        # training
        if tgt is not None:
            feats = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt, tgt_mask)    # ([bs, sl, d_model])
            return self.transformer.generate(dec_outs)           # ([bs, sl, vocab])

        # inference (greedy decode)
        with torch.no_grad():
            feats = self.transformer.encode(self.img_enc(src))
            n_items = src.size(0)
            # every sequence starts from token id 1
            tgt = torch.ones((n_items,1), dtype=torch.long, device=device)

            step_outs = []
            for _ in progress_bar(range(seq_len)):
                # alternative tried earlier: parallelogram_mask(tgt.size(-1), 10) / decoding only t[:,-11:]
                causal = subsequent_mask(tgt.size(-1))
                embedded = self.transformer.embed(tgt)
                dec_outs = self.transformer.decode(feats, embedded, causal)

                # only the last position is used to pick the next token
                scores = self.transformer.generate(dec_outs[:,-1])
                step_outs.append(scores)
                next_tok = torch.argmax(scores, dim=-1, keepdim=True)
                # stop once every sequence in the batch predicts token id 0
                if (next_tok==0).all(): break
                tgt = torch.cat([tgt,next_tok], dim=-1)
            return torch.stack(step_outs).transpose(1,0).contiguous()

In [None]:
def init_params(model):
    "Xavier-uniform initialise, in place, every parameter of `model` that has more than one dimension."
    multi_dim = (p for p in model.parameters() if p.dim() > 1)
    for weight in multi_dim:
        nn.init.xavier_uniform_(weight)

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.1, attn_type='multi', heads=8, smoothing=0.1):
    "Build an `Img2Seq` network and wrap it in a `Learner` with label smoothing, CER and teacher forcing."
    encoder = ImageEncoder(em_sz)
    # vocabulary of the training targets; its length sets the transformer's output size
    vocab_itos = data.train_ds.y.vocab.itos
    seq_model = make_full_model(len(vocab_itos), d_model, em_sz, N=N, drops=drops,
                                attn_type=attn_type, heads=heads)
    # only the transformer is re-initialised; the image encoder keeps its own weights
    seq_model.apply(init_params)
    net = Img2Seq(encoder, seq_model)
    criterion = LabelSmoothing(smoothing=smoothing)
    return Learner(data, net, loss_func=criterion,
                   metrics=[CER(vocab_itos)], callback_fns=[TeacherForce])

# Experimentation

In [None]:
# Build the learner for this experiment: d_model=512, em_sz=256, N=4 layers, 8 attention heads.
learn = make_learner(data, 512, 256, N=4, drops=0.1, attn_type='multi', heads=8)
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

In [None]:
# Load the checkpoint saved as 'sm_char_word'; the trailing `None` presumably just suppresses the cell output.
learn.load('sm_char_word'); None

In [None]:
# Run the learning-rate finder and plot the recorded curve with a suggested value.
learn.lr_find()
learn.recorder.plot(suggestion=True)  

In [None]:
# Train for 3 epochs with the one-cycle policy; results of earlier runs are logged in the comments below.
learn.fit_one_cycle(3, max_lr=1e-3)
# CHARS / WORDS
# d_model:512, em_sz:256, resnet34, 5cycle(1e-3), N:4, drops:0.1, multi:8

# sm dataset, sz:256, bs:80, seq_len:100, 5cycle(1e-3)
# 4.847745	5.941104	0.027651	0.029362	08:52   chars/words independent; sharing img features  'sm_char_word'
# chars:    0.346770    0.03947
# words:    0.903070    0.03809
# 4.455848	5.087106	0.024507	0.030808	09:54   2d pos, parallel_mask(10), word_itos:10k  'sm_char_word2'



# combo_145<=6, sz:512, bs:15, seq_len:600
# 40.977287	36.501549	0.039365	0.133388	34:21   "", train from scratch no preloading  'combo<6_char_word'

# 1cycle(1e-3)
# 325.525482	315.047577	0.572584	0.707223	34:39   em_sz:256
# 336.021881	324.370392	0.581797	0.726042	34:45   remove .mul_(8)
# 343.716064	331.689667	0.594245	0.735993	30:36   em_sz:512
# 336.650574	330.916840	0.585928	0.727880	35:35   char_decoder:src_attn:heads:16
# 300.974121	307.491211	0.551018	0.712521	36:47   learned 2d positional embeddings, word_itos:60k
# greedy: all 0s  -- b/c word_itos:60k???
# 283.598724	282.288574	0.538124	0.703285	33:38   2d pos embs, word_itos:10k, parallelogram_mask(10)


#### Paragraphs ####
# sz:512, bs:10, resnet34, 5cycle(1e-3), N:4, multi:8
# 713.113770	723.217712	0.640478	01:21   em_sz:256, 32x32
# 722.943359	732.392761	0.650870	01:00   em_sz:512, 16x16
# -----      CUDA memory error     ------       em_sz:128, 64x64             

# 722.326721	732.504883	0.651015	01:01   em_sz:128, conv/pool(8),  8x64
# 695.523865	704.048645	0.623089	01:18   em_sz:128, conv/pool(16), 16x64
# 685.498596	694.236023	0.614935	01:02   em_sz:256, conv/pool(16), 16x32  **
# 694.890808	703.071411	0.623319	00:58   em_sz:512, conv/pool(16), 16x16

# 689.728882	696.015442	0.614972	01:02   em_sz:256, conv/stride(16), 16x32


# 681.818176	690.215942	0.612152	01:26   se_resnet50, em_sz:1024, conv/pool(16), 16x32 **
# 684.835876	694.623779	0.614549	01:21   resnet50, em_sz:1024, conv/pool(16), 16x32
# 699.177185	708.581848	0.627579	01:47   xception, em_sz:728, conv/pool(16), 16x32

# 697.656616	698.180603	0.618370	01:17   em_sz:256, conv/pool(16), 16x32, N:6
# 698.445801	712.330444	0.630357	01:10   em_sz:256, conv/pool(16), 16x32, multi:16

# 693.106262	698.929504	0.619492	01:25   em_sz:256, conv/pool(16), 4res w/ blocks:(16,8,4,2x32)

# words
# 224.618210	295.837616	0.600509	00:59   em_sz:256, conv/stride(16), 16x32

#### LINES ####
# sz:(48,512), bs:100, resnet34, 5cycle(1e-3), em_sz:512, N:4, multi:8

# 16.167061	19.248831	0.130215   maxpool, linear+mul_8   'lines1'
# 16.855366	19.663122	0.134168
# 27.059019	27.176353	0.196770   maxpool, ff
# 19.401184	21.343414	0.147686   maxpool, linear+pos
# 20.828600	22.394470	0.154748   maxpool, linear
# 19.685211	21.401609	0.148922   no maxpool, no linear+mul_8   'lines2'
# 15.595810	18.914587	0.129315   no maxpool, linear+mul_8   'lines3'  ---  linear+mul_8 crucial!

# 12.375642	14.939847	0.103393   em_sz:256, cat convs(3,2,1 x 32)  'lines4'
# 17.737593	20.629824	0.142271   em_sz:512, cat convs(2,1 x 16)
# 8.757012	11.217060	0.084667	07:18   em_sz:128, bs:25, 10 resolutions   'lines5'
# 9.839996	12.166018	0.090686	03:09   em_sz:128, bs:25, 4 resolutions(6,3,2,1)    'lines6'
# 18.749489	19.622530	0.141156	02:20   "", d_model:128

# sz:(64,512), bs:50
# 10.431609	12.825503	0.092731	01:30   em_sz:128, 4 resolutions(8,4,2,1) of 1 line
# 11.442437	13.249302	0.096183	03:35   em_sz:128, 4 resolutions(8,4,2,1)
# greedy:    2.99910   .21647

# bs: 50
# 6.106063	9.794361	0.070278	01:10   em_sz: 256, 4x32
# 12.173990	14.315442	0.096518	01:03   "", bs:100
# 7.294389	10.227886	0.072702	01:29   resnet50, em_sz: 1024, 4x32

# sz:(64,512), bs:100
# 10.622000	13.152803	0.091343	01:21   em_sz:1024, 1 resolution: (1x64)   'lines7'
# 15.201996	17.066908	0.121637	00:51   pool(32,None) before base, em_sz:512, 1x64
# 12.205278	14.841956	0.104063	01:20   pool(48,None) before base, em_sz:1024, 1x64
# 12.217975	13.971797	0.097959	00:57   pool(4,None) after base, em_sz:512, 1x64
# 10.639203	13.198123	0.092142	01:22   pool(8,None) after base, em_sz:1024, 1x64
# 10.091258	12.612871	0.088563	01:36   "", N:6    'lines8'
# 8.227972	12.432637	0.087113	01:35   2nd run, 5cycle(4e-6)
# 10.719064	13.117971	0.090754	01:50   "", N:8
# 11.734365	13.847672	0.096678	01:09   pool(8,None) after, em_sz:512, N:4, 3 resolutions
# 10.949694	13.447043	0.095033	01:11   "", em_sz:1024
# 11.054914	13.662763	0.096324	01:13   "", padding rather than stride 4->2
# 11.858457	13.944701	0.097301	00:59   "", last conv k=2, 1x63
# 10.656734	13.374981	0.093117	01:23   "", last conv k=4, 1x63
# 9.708715	12.423415	0.085666	02:45   pool(8,64) after base, em_sz:1024, w/res_blocks, 1x64   'lines9'
# greedy:    1.28441   .12355

# bs: 50
# xception
# 13.364486	17.758265	0.127069	01:37   xception(2048) 2x16
# 5.375062	10.732620	0.077556	01:49   xception(728) 4x32
# 7.213572	12.425294	0.091167	01:37   xception(728) 1x32 (2 conv layers)
# 6.362609	11.896906	0.086484	01:39   xception(728) 1x32 (2 blocks)
# 10.085477	12.171700	0.087609	03:12   custom xception(256) 8x64

# se_resnet50
# 5.696907	9.483581	0.067969	01:31   em_sz:1024, 4x32
# 9.689358	12.082891	0.088165	02:47   em_sz:512, 8x64
# 7.772936	12.125853	0.083637	01:22   em_sz:2048, 2x16

In [None]:
# Save the current learner state under the name 'sm_char_word2'.
learn.save('sm_char_word2')

# Image PreProcessing

In [None]:
class ImageProcessor(nn.Module):
    "Frozen resnet34 features, pooled per sample to a height chosen by a classification head."
    def __init__(self, d_model, max_lines=14):
        # NOTE(review): `d_model` is accepted but never used in this class.
        super().__init__()

        net = models.resnet34(True)
        # drop the last two children of the pretrained network
        modules = list(net.children())[:-2]
        
        self.base = nn.Sequential(*modules)
        # head producing `max_lines` scores per image (presumably a line-count classifier -- confirm)
        self.head = create_head(1024, max_lines)
        
    def forward(self, x):
        "Return a zero-padded tensor of pooled features with the last two dims swapped; no gradients are tracked."
        with torch.no_grad():
            base = self.base(x)
            n_lines = self.head(base)
            
            b,c,h,w = base.shape
            #res = torch.zeros((b,c,longest))
        
            # predicted class index per sample, and the distinct predictions in the batch
            preds = torch.argmax(n_lines, dim=-1)
            lines = torch.unique(preds)
            
            # torch.unique sorts by default, so lines[-1] is the largest prediction
            # NOTE(review): `longest` is a 0-dim tensor used as a size below -- confirm this works on the target torch version
            longest = lines[-1]*w
            res = torch.zeros((b,c,longest), device=device)
            for l in lines:
                # batch indices of the samples predicted as `l`
                idxs = (preds==l).nonzero()[:,0]
                # pool those samples to height `l` (width unchanged) and flatten to length l*w
                # NOTE(review): a prediction of 0 would request a zero-height pool -- verify that case
                outs = F.adaptive_max_pool2d(base[idxs], (l,None)).flatten(2,3)
                # could add positional encoding here
                # right-pad with zeros so every group has length `longest`
                pad = longest - outs.size(-1)
                outs = F.pad(outs, (0,pad), "constant", 0)
                res[idxs] = outs
                
        return res.permute(0,2,1)

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', heads=8):
    "Assemble an `EncoderDecoder` from encoder, decoder, target embedding and generator (four arguments)."
    dup = deepcopy  # every layer gets its own copy of the attention / feed-forward prototypes

    # any attn_type other than 'multi' selects single-headed attention
    attn = MultiHeadedAttention(d_model, heads) if attn_type=='multi' else SingleHeadedAttention(d_model)
    ff = PositionwiseFeedForward(d_model, drops)
    pos = PositionalEncoding(d_model, drops, 2000)

    encoder = Encoder(EncoderLayer(d_model, dup(attn), dup(ff), drops), N)
    decoder = Decoder(DecoderLayer(d_model, dup(attn), dup(attn), dup(ff), drops), N)
    # target tokens: embedding followed by positional encoding
    tgt_embed = nn.Sequential(Embeddings(d_model, vocab), pos)
    generator = nn.Linear(d_model, vocab)

    # NOTE(review): this passes four components, so it needs an `EncoderDecoder`
    # definition whose constructor takes exactly these -- confirm which one is live.
    return EncoderDecoder(encoder, decoder, tgt_embed, generator)

In [None]:
class Img2Seq(nn.Module):
    "Image encoder + transformer; greedy-decodes when called without a target."
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        """With `tgt`: one training pass, returning generator outputs for every target position.
        Without `tgt`: greedy decoding for at most `seq_len` steps, returning the stacked
        per-step generator outputs with the batch dimension first."""
        # training
        if tgt is not None:
            feats = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt, tgt_mask)    # ([bs, sl, d_model])
            return self.transformer.generate(dec_outs)           # ([bs, sl, vocab])

        # inference (greedy decode)
        with torch.no_grad():
            feats = self.transformer.encode(self.img_enc(src))
            n_items = src.size(0)
            # every sequence starts from token id 1
            tgt = torch.ones((n_items,1), dtype=torch.long, device=device)

            step_outs = []
            for _ in progress_bar(range(seq_len)):
                # alternative tried earlier: parallelogram_mask(tgt.size(-1), 10) / decoding only t[:,-11:]
                causal = subsequent_mask(tgt.size(-1))
                embedded = self.transformer.embed(tgt)
                dec_outs = self.transformer.decode(feats, embedded, causal)

                # only the last position is used to pick the next token
                scores = self.transformer.generate(dec_outs[:,-1])
                step_outs.append(scores)
                next_tok = torch.argmax(scores, dim=-1, keepdim=True)
                # stop once every sequence in the batch predicts token id 0
                if (next_tok==0).all(): break
                tgt = torch.cat([tgt,next_tok], dim=-1)
            return torch.stack(step_outs).transpose(1,0).contiguous()

In [None]:
def init_params(model):
    "Xavier-uniform initialise, in place, every parameter of `model` that has more than one dimension."
    multi_dim = (p for p in model.parameters() if p.dim() > 1)
    for weight in multi_dim:
        nn.init.xavier_uniform_(weight)

In [None]:
def make_learner(data, d_model, N=4, drops=0.2, attn_type='multi', heads=8, smoothing=0.1):
    "Build an `Img2Seq` network around `ImageProcessor` and wrap it in a `Learner`."
    encoder = ImageProcessor(d_model)
    # NOTE(review): `itos` is not a parameter or local -- it is read from the enclosing (notebook) scope.
    n_tokens = len(itos)
    seq_model = make_full_model(n_tokens, d_model, N=N, drops=drops, attn_type=attn_type, heads=heads)
    # only the transformer is re-initialised; the image encoder keeps its own weights
    seq_model.apply(init_params)
    net = Img2Seq(encoder, seq_model)
    criterion = LabelSmoothing(smoothing=smoothing)
    return Learner(data, net, loss_func=criterion,
                   metrics=[CER()], callback_fns=[TeacherForce])

## Combine/load state_dict 

In [None]:
# Build the learner for this section: d_model=512, N=6 layers, no dropout, 8 attention heads.
learn = make_learner(data, 512, N=6, drops=0, attn_type='multi', heads=8)

In [None]:
from collections import OrderedDict
new_sd = OrderedDict()

# checkpoint of the separately trained model; its weights live under the 'model' key
line_sd = torch.load(PATH/'models/line_98acc.pth', map_location=device)

# Rename checkpoint keys so they address this network's image encoder:
# a leading "0" becomes "img_enc.base", a leading "1" becomes "img_enc.head".
# Keys starting with anything else are not copied.
prefix_map = (('0', 'img_enc.base'), ('1', 'img_enc.head'))
for k, v in line_sd['model'].items():
    for old_prefix, new_prefix in prefix_map:
        if k.startswith(old_prefix):
            name = new_prefix + k[len(old_prefix):]
            new_sd[name] = v

In [None]:
# strict=False: `new_sd` only holds the image-encoder weights, so the remaining keys of the model are left as they are.
learn.model.load_state_dict(new_sd, strict=False)

In [None]:
# sd = torch.load(PATH/'models/combo_512_mix_res9.pth', map_location=device)

# for k, v in sd['model'].items():
#     if not k.startswith('img_enc'):
#         new_sd[k] = v

# Full Arch

In [None]:
class EncoderDecoder(nn.Module):
    "Container wiring an encoder, a target embedding, a decoder and a generator together."
    def __init__(self, encoder, decoder, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, tgt_mask=None):
        "Encode `src` and decode `tgt` against it; the generator is NOT applied here."
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask)

    def encode(self, src):
        "Apply the encoder."
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        "Embed `tgt`, then call the decoder with (embedded target, source, mask)."
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, src, tgt_mask)

    def generate(self, outs):
        "Apply the generator to decoder outputs."
        return self.generator(outs)

In [None]:
class ResnetBase(nn.Module):
    "Truncated pretrained resnet34 whose feature map is flattened into a sequence and projected to `d_model`."
    def __init__(self, em_sz, d_model):
        super().__init__()

        # em_sz -> end index used to truncate the network's children (KeyError for other sizes)
        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        trunk = list(models.resnet34(True).children())[:cut]
        self.base = nn.Sequential(*trunk)
        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        "Flatten dims 2 and 3, move them in front of dim 1, project, and scale by 8."
        feats = self.base(x)
        seq = feats.flatten(2,3).permute(0,2,1)
        return self.linear(seq) * 8

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', heads=8):
    "Assemble an `EncoderDecoder` from encoder, decoder, target embedding and generator."
    dup = deepcopy  # every layer gets its own copy of the attention / feed-forward prototypes

    # any attn_type other than 'multi' selects single-headed attention
    attn = MultiHeadedAttention(d_model, heads) if attn_type=='multi' else SingleHeadedAttention(d_model)
    ff = PositionwiseFeedForward(d_model, drops)

    encoder = Encoder(EncoderLayer(d_model, dup(attn), dup(ff), drops), N)
    decoder = Decoder(DecoderLayer(d_model, dup(attn), dup(attn), dup(ff), drops), N)
    # target tokens: embedding followed by positional encoding
    tgt_embed = nn.Sequential(
        Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000)
    )
    generator = nn.Linear(d_model, vocab)

    return EncoderDecoder(encoder, decoder, tgt_embed, generator)

In [None]:
class Img2Seq(nn.Module):
    "Image encoder + transformer; greedy-decodes when called without a target."
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        """With `tgt`: one training pass, returning generator outputs for every target position.
        Without `tgt`: greedy decoding for at most `seq_len` steps, returning the stacked
        per-step generator outputs with the batch dimension first."""
        # training
        if tgt is not None:
            feats = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt, tgt_mask)    # ([bs, sl, d_model])
            return self.transformer.generate(dec_outs)           # ([bs, sl, vocab])

        # inference (greedy decode)
        with torch.no_grad():
            feats = self.transformer.encode(self.img_enc(src))
            n_items = src.size(0)
            # every sequence starts from token id 1
            tgt = torch.ones((n_items,1), dtype=torch.long, device=device)

            step_outs = []
            for _ in progress_bar(range(seq_len)):
                causal = subsequent_mask(tgt.size(-1))
                # raw token ids are passed here; decode() is expected to embed them
                dec_outs = self.transformer.decode(feats, tgt, causal)
                # only the last position is used to pick the next token
                scores = self.transformer.generate(dec_outs[:,-1])
                step_outs.append(scores)
                next_tok = torch.argmax(scores, dim=-1, keepdim=True)
                # stop once every sequence in the batch predicts token id 0
                if (next_tok==0).all(): break
                tgt = torch.cat([tgt,next_tok], dim=-1)
            return torch.stack(step_outs).transpose(1,0).contiguous()

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.2, attn_type='multi', heads=8, smoothing=0.1):
    "Build an `Img2Seq` network around `ResnetBase` and wrap it in a `Learner`."
    encoder = ResnetBase(em_sz, d_model)
    # NOTE(review): `itos` is not a parameter or local -- it is read from the enclosing (notebook) scope.
    n_tokens = len(itos)
    seq_model = make_full_model(n_tokens, d_model, N=N, drops=drops, attn_type=attn_type, heads=heads)
    # only the transformer is re-initialised; the image encoder keeps its own weights
    seq_model.apply(init_params)
    #transformer.generator.weight = transformer.tgt_embed[0].lut.weight
    net = Img2Seq(encoder, seq_model)
    criterion = LabelSmoothing(smoothing=smoothing)
    return Learner(data, net, loss_func=criterion,
                   metrics=[CER()], callback_fns=[TeacherForce])

# TransformerXL

In [None]:
class PositionalEncoding(Module):
    "Sinusoidal encoding of a 1-D tensor of positions."
    def __init__(self, d_model:int):
        # NOTE(review): no explicit super().__init__() -- presumably the `Module` base class handles it; confirm.
        exponents = torch.arange(0., d_model, 2.) / d_model
        # stored as a buffer: saved and moved with the module, but not trained
        self.register_buffer('freq', 1 / (10000 ** exponents))

    def forward(self, pos:Tensor):
        "Return the sines and cosines of `pos` times each frequency, concatenated on the last dim."
        angles = torch.ger(pos, self.freq)  # outer product: one row per position, one column per frequency
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

In [None]:
class FeedForward(Module):
    "Two-layer feed-forward block (4x expansion) with dropout, a residual connection and LayerNorm."
    def __init__(self, d_model:int, drops=0.1):
        # NOTE(review): no explicit super().__init__() -- presumably the `Module` base class handles it; confirm.
        hidden = d_model*4
        expand = [nn.Linear(d_model, hidden), nn.ReLU(inplace=True), nn.Dropout(drops)]
        project = [nn.Linear(hidden, d_model), nn.Dropout(drops)]
        self.core = nn.Sequential(*expand, *project)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x):
        "Return LayerNorm(x + core(x))."
        residual = self.core(x)
        return self.ln(x + residual)

In [None]:
class MultiHeadRelativeAttention(Module):
    "Multi-head attention with relative positional terms, wrapped in dropout, a residual connection and LayerNorm."
    def __init__(self, n_heads:int, d_model:int, drops=0.1, bias=False):
        # NOTE(review): no explicit super().__init__() -- presumably the `Module` base class handles it; confirm.
        d_head = d_model//n_heads
        self.n_heads, self.d_head = n_heads, d_head
        # single projection producing queries, keys and values at once
        self.attention = nn.Linear(d_model, 3 * n_heads * d_head, bias=bias)
        self.out = nn.Linear(n_heads * d_head, d_model, bias=bias)
        self.drop_att,self.drop_res = nn.Dropout(drops),nn.Dropout(drops)
        self.ln = nn.LayerNorm(d_model)
        # projection applied to the relative positional encodings `r`
        self.r_attn = nn.Linear(d_model, n_heads * d_head, bias=bias)

    def forward(self, x:Tensor, mask:Tensor=None, **kwargs):
        "Return LayerNorm(x + dropout(out(attention(x)))); `kwargs` (r, u, v, mem) are passed to `_mhra`."
        return self.ln(x + self.drop_res(self.out(self._mhra(x, mask=mask, **kwargs))))

    def _mhra(self, x:Tensor, r:Tensor=None, u:Tensor=None, v:Tensor=None, mask:Tensor=None, mem:Tensor=None):
        "Compute the attention output for `x` (optionally attending over `mem` too), before the output projection."
        #Notations from the paper:
        #x: input, r: vector of relative distance between two elements
        #u,v: learnable parameters of the model common between layers
        #mask: to avoid cheating, mem: previous hidden states
                
        #x: [bs, sl, d_model]
        #u/v: [n_heads, 1, d_head]
        #r: [sl, d_model]
        #mem: 1st:[0]; 2nd:[bs, sl, d_model]; nth: sl*i-1 up to mem_len
        # NOTE: `r` defaults to None but is required -- r.size(0) below fails without it
        bs,x_len,seq_len = x.size(0),x.size(1),r.size(0)
        context = x if mem is None else torch.cat([mem, x], dim=1)
        wq,wk,wv = torch.chunk(self.attention(context), 3, dim=-1)
        # queries only for the current input, keys/values for memory + input
        wq = wq[:,-x_len:]
        wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))
        # [bs, sl, n_heads, d_head]
        wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3)   #wk: transposed(-2,-1)
        wkr = self.r_attn(r)
        wkr = wkr.view(seq_len, self.n_heads, self.d_head)
        wkr = wkr.permute(1,2,0)  #transposed ala wk w/out bs
        #### compute attention score (AC is (a) + (c) and BD is (b) + (d) in the paper)
        AC = torch.matmul(wq+u,wk)
        # fix: the helper defined in this file is `_left_shift`; the previous
        # call to `_line_shift` referenced a name that is not defined here
        BD = _left_shift(torch.matmul(wq+v, wkr))
        attn_score = (AC + BD).div_(self.d_head ** 0.5)  #scale
        if mask is not None:
            # positions where `mask` is True get -inf, i.e. zero weight after the softmax
            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)
        attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))
        attn_vec = torch.matmul(attn_prob, wv)
        # merge the heads back: [bs, n_heads, x_len, d_head] -> [bs, x_len, n_heads*d_head]
        return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1)
    
def _left_shift(x:Tensor):
    "Shift the line i of `x` by p-i elements to the left."
    bs,nh,n,p = x.size()
    x_pad = torch.cat([x.new_zeros(bs,nh,n,1), x], dim=3)
    x_shift = x_pad.view(bs,nh,p + 1,n)[:,:,1:].view_as(x)
    return x_shift

In [None]:
class DecoderLayerXL(Module):
    "One layer: relative multi-head attention followed by a feed-forward block."
    def __init__(self, n_heads:int, d_model:int, drops=0.1):
        # NOTE(review): no explicit super().__init__() -- presumably the `Module` base class handles it; confirm.
        self.mhra = MultiHeadRelativeAttention(n_heads, d_model, drops=drops)
        self.ff   = FeedForward(d_model, drops=drops)

    def forward(self, x:Tensor, mask:Tensor=None, **kwargs):
        "Apply attention (forwarding `mask` and `kwargs`), then the feed-forward block."
        attended = self.mhra(x, mask=mask, **kwargs)
        return self.ff(attended)

In [None]:
class TransformerXL(Module):
    "TransformerXL model: https://arxiv.org/abs/1901.02860."
    def __init__(self, vocab_sz:int, d_model:int, n_layers:int, n_heads:int, drops:float=0.1, mem_len:int=150):
        # NOTE(review): no explicit super().__init__() -- presumably the `Module` base class handles it; confirm.
        # fix: keep the sizes that reset() (n_layers) and _update_mems() (mem_len)
        # read later; they were previously never stored on the instance
        self.mem_len,self.n_layers,self.d_model = mem_len,n_layers,d_model
        d_head = d_model//n_heads
        self.emb = nn.Embedding(vocab_sz, d_model)
        self.drop_emb = nn.Dropout(drops)
        self.pos_enc = PositionalEncoding(d_model)
        # u, v: learnable biases shared by all layers; allocated uninitialised here,
        # so the caller is expected to initialise them
        self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head))
        self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head))
        self.layers = nn.ModuleList([DecoderLayerXL(n_heads, d_model, drops) for k in range(n_layers)])
        self.init = False

    def forward(self, x):
        "Run token ids `x` ([bs, x_len]) through all layers; returns `(self.hidden, [core_out])`."
        if not self.init:
            # lazily create the (empty) memories on the first call
            self.reset()
            self.init = True
        bs,x_len = x.size()
        inp = self.drop_emb(self.emb(x)) #.mul_(self.d_model ** 0.5)
        # length of the stored memory; 0 while the memory is still the empty placeholder
        m_len = self.hidden[0].size(1) if hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1 else 0
        seq_len = m_len + x_len
        # True marks positions that must not be attended to (those after the current one)
        mask = torch.triu(x.new_ones(x_len, seq_len), diagonal=1+m_len).bool()[None,None]
        # positions counted down from seq_len-1 to 0
        pos = torch.arange(seq_len-1, -1, -1, device=inp.device, dtype=inp.dtype)
        pos_enc = self.pos_enc(pos)
        hids = []
        hids.append(inp)
        for i, layer in enumerate(self.layers):
            mem = self.hidden[i]
            inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem)
            hids.append(inp)
        core_out = inp[:,-x_len:]
        self._update_mems(hids)
        return self.hidden,[core_out]
    
    def _update_mems(self, hids):
        "Append the new hidden states to the memory, keeping the last `mem_len` steps (detached)."
        if not getattr(self, 'hidden', False): return None
        assert len(hids) == len(self.hidden), 'len(hids) != len(self.hidden)'
        with torch.no_grad():
            for i in range(len(hids)):
                cat = torch.cat([self.hidden[i], hids[i]], dim=1)
                self.hidden[i] = cat[:,-self.mem_len:].detach()

    def reset(self):
        "Reset the internal memory."
        # one empty tensor per layer plus one for the embedding output
        self.hidden = [next(self.parameters()).data.new(0) for i in range(self.n_layers+1)]

    def select_hidden(self, idxs):
        "Keep only the memory rows selected by `idxs`."
        # used in beam search only...
        self.hidden = [h[idxs] for h in self.hidden]

In [None]:
class EncoderDecoder(nn.Module):
    "Container wiring an encoder, a decoder and a generator together; target masks are ignored."
    def __init__(self, encoder, decoder, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator

    def forward(self, src, tgt, tgt_mask=None):
        "Encode `src` and decode `tgt` against it; `tgt_mask` is accepted but not used."
        memory = self.encode(src)
        return self.decode(memory, tgt)

    def encode(self, src):
        "Apply the encoder."
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        "Call the decoder with (target, source); `tgt_mask` is accepted but not used."
        return self.decoder(tgt, src)

    def generate(self, outs):
        "Apply the generator to decoder outputs."
        return self.generator(outs)

In [None]:
class ResnetBase(nn.Module):
    "Truncated pretrained resnet34 whose feature map is flattened into a sequence and projected to `d_model`."
    def __init__(self, em_sz, d_model):
        super().__init__()

        # em_sz -> end index used to truncate the network's children (KeyError for other sizes)
        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        trunk = list(models.resnet34(True).children())[:cut]
        self.base = nn.Sequential(*trunk)

        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        "Flatten dims 2 and 3, move them in front of dim 1, project, and scale by 8."
        feats = self.base(x)
        seq = feats.flatten(2,3).permute(0,2,1)
        return self.linear(seq) * 8

In [None]:
class Decoder(nn.Module):
    "Language-model layer followed by N decoder layers and a final LayerNorm."
    def __init__(self, lm, layer, N):
        super().__init__()
        self.lm_layer = lm
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, src):
        "Pass `x` through the LM layer, then through every layer together with `src`, then normalise."
        # NOTE(review): the result of `lm_layer` is fed straight into the layers,
        # so it must be a single tensor -- confirm against the LM actually used.
        hidden = self.lm_layer(x)

        for block in self.layers:
            hidden = block(hidden, src)
        return self.norm(hidden)

In [None]:
class DecoderLayer(nn.Module):
    "Decoder layer with source attention and feed forward only (no self-attention)."
    def __init__(self, size, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # two wrappers (original note: residual, dropout, norm): one per sub-layer
        self.sublayer = clones(SublayerConnection(size, dropout), 2)

    def forward(self, x, src):
        "Attend from `x` to `src` (used as both keys and values), then apply the feed-forward sub-layer."
        attend_to_src = lambda q: self.src_attn(q, src, src)
        attended = self.sublayer[0](x, attend_to_src)
        return self.sublayer[1](attended, self.feed_forward)

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', attn_heads=8):
    "Assemble an `EncoderDecoder` whose decoder starts with a TransformerXL language model."
    # NOTE: `attn_type` is accepted but not used; multi-headed attention is always built.
    dup = deepcopy  # every layer gets its own copy of the attention / feed-forward prototypes

    # the language model's depth, head count and dropout are fixed here, independent of N / drops
    lm = TransformerXL(vocab, d_model, n_layers=10, n_heads=8, drops=0.1)
    attn = MultiHeadedAttention(d_model, attn_heads) 
    ff   = PositionwiseFeedForward(d_model, drops)

    encoder = Encoder(EncoderLayer(d_model, dup(attn), dup(ff), drops), N)
    decoder = Decoder(lm, DecoderLayer(d_model, dup(attn), dup(ff), drops), N)
    generator = nn.Linear(d_model, vocab)
    model = EncoderDecoder(encoder, decoder, generator)

    # Xavier-uniform init for every parameter with more than one dimension
    for weight in (p for p in model.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)

    return model

In [None]:
class Img2Seq(nn.Module):
    "Image encoder + transformer; greedy-decodes when called without a target."
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        """With `tgt`: one training pass, returning generator outputs for every target position.
        Without `tgt`: greedy decoding for at most `seq_len` steps, returning the stacked
        per-step generator outputs with the batch dimension first. `tgt_mask` is not used."""
        # training
        if tgt is not None:
            feats = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt)
            return self.transformer.generate(dec_outs)

        # inference (greedy decode)
        with torch.no_grad():
            feats = self.transformer.encode(self.img_enc(src))
            n_items = src.size(0)
            # every sequence starts from token id 1
            tgt = torch.ones((n_items,1), dtype=torch.long, device=device)

            step_outs = []
            for _ in progress_bar(range(seq_len)):
                dec_outs = self.transformer.decode(feats, tgt)
                # only the last position is used to pick the next token
                scores = self.transformer.generate(dec_outs[:,-1])
                step_outs.append(scores)
                next_tok = torch.argmax(scores, dim=-1, keepdim=True)
                # stop once every sequence in the batch predicts token id 0
                if (next_tok==0).all(): break
                tgt = torch.cat([tgt,next_tok], dim=-1)
            return torch.stack(step_outs).transpose(1,0).contiguous()

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.2, attn_type='multi', attn_heads=8):
    "Build an `Img2Seq` network around `ResnetBase` and wrap it in a `Learner` (label smoothing fixed at 0.1)."
    encoder = ResnetBase(em_sz, d_model)
    # NOTE(review): `itos` is not a parameter or local -- it is read from the enclosing (notebook) scope.
    n_tokens = len(itos)
    seq_model = make_full_model(n_tokens, d_model, N=N, drops=drops,
                                attn_type=attn_type, attn_heads=attn_heads)
    net = Img2Seq(encoder, seq_model)
    criterion = LabelSmoothing(smoothing=0.1)
    return Learner(data, net, loss_func=criterion,
                    metrics=[CER()], callback_fns=[TeacherForce])

# w/ Transformer LM

In [None]:
class DecoderLayerWithLM(nn.Module):
    "Decoder layer with a language-model branch and a source-attention branch that share one self-attention."
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        # fix: this previously called super(DecoderLayer, self).__init__(), which names
        # a different class and fails because `self` is not a DecoderLayer instance
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        # three independent feed-forward copies: LM branch, decoder branch, final mix
        self.feed_forward = clones(feed_forward, 3)
        self.sublayer = clones(SublayerConnection(size, dropout), 5)  # wraps layer in residual,dropout,norm
 
    def forward(self, x, src, tgt_mask=None):
        "Return the feed-forward mix of the LM branch and the source-attention branch."
        # lm
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))  # shared btw lm and decoder
        lm = self.sublayer[1](x, self.feed_forward[0])
        
        # decoder layer
        dec = self.sublayer[2](x, lambda x: self.src_attn(x, src, src))
        dec = self.sublayer[3](dec, self.feed_forward[1])
        # the two branches are summed before the last sub-layer
        return self.sublayer[4](dec+lm, self.feed_forward[2])

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', attn_heads=8, weight_tying=False):
    "Assemble an `EncoderDecoder` whose decoder layers are `DecoderLayerWithLM`."
    dup = deepcopy  # every layer gets its own copy of the attention / feed-forward prototypes

    # any attn_type other than 'multi' selects single-headed attention
    attn = MultiHeadedAttention(d_model, attn_heads) if attn_type=='multi' else SingleHeadedAttention(d_model)
    ff = PositionwiseFeedForward(d_model, drops)

    encoder = Encoder(EncoderLayer(d_model, dup(attn), dup(ff), drops), N)
    decoder = Decoder(DecoderLayerWithLM(d_model, dup(attn), dup(attn), dup(ff), drops), N)
    # target tokens: embedding followed by positional encoding
    tgt_embed = nn.Sequential(
        Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000)
    )
    generator = nn.Linear(d_model, vocab)
    model = EncoderDecoder(encoder, decoder, tgt_embed, generator)

    # Xavier-uniform init for every parameter with more than one dimension
    for weight in (p for p in model.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)

    if weight_tying:
        # share one weight tensor between the output projection and the token embedding
        model.generator.weight = model.tgt_embed[0].lut.weight

    return model

# w/ AWDLSTM LM integrated

In [None]:
class EncoderDecoder(nn.Module):
    "Container wiring an encoder, a decoder and a positional encoding for the target together."
    def __init__(self, encoder, decoder, pos_enc):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pos_enc = pos_enc

    def forward(self, src, tgt, tgt_mask=None):
        "Encode `src` and decode `tgt` against it."
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask)

    def encode(self, src):
        "Apply the encoder."
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        "Apply `pos_enc` to `tgt`, then call the decoder with (target, source, mask)."
        positioned = self.pos_enc(tgt)
        return self.decoder(positioned, src, tgt_mask)

In [None]:
class ResnetBase(nn.Module):
    "Truncated resnet34 feature extractor followed by a linear projection from `em_sz` to `d_model`."
    def __init__(self, em_sz, d_model):
        super().__init__()

        # how many trailing resnet34 children to drop for each supported feature width
        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        backbone = models.resnet34(True)
        kept = list(backbone.children())[:cut]
        self.base = nn.Sequential(*kept)

        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        "Return the feature map flattened over its last two dims, moved to (dim0, positions, channels), projected and scaled by 8."
        feats = self.base(x)
        seq = feats.flatten(2, 3).transpose(1, 2)
        # fixed scale factor of 8 (the original author's note: cube root of d_model?)
        return self.linear(seq) * 8

In [None]:
def lm_mod_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', attn_heads=8):
    "Build the 3-argument `EncoderDecoder` (encoder stack, decoder stack, positional encoding); `vocab` is accepted but not used here."
    attn = MultiHeadedAttention(d_model, attn_heads) if attn_type=='multi' else SingleHeadedAttention(d_model)
    ff = PositionwiseFeedForward(d_model, drops)

    # each layer gets independent deep copies of the attention / feed-forward prototypes
    enc_stack = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    dec_stack = Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    pos_enc = PositionalEncoding(d_model, drops, 2000)
    model = EncoderDecoder(enc_stack, dec_stack, pos_enc)

    # Xavier-uniform init for every parameter with more than one dimension
    for weight in model.parameters():
        if weight.dim() <= 1: continue
        nn.init.xavier_uniform_(weight)

    return model

In [None]:
class LM(nn.Module):
    "Wrap an `AWD_LSTM` so it can be fed already-embedded targets."
    def __init__(self, vocab, d_model, n_hidden, n_layers):
        super(LM, self).__init__()
        self.lm = AWD_LSTM(vocab, d_model, n_hidden, n_layers, pad_token=0,
                           hidden_p=0.1, input_p=0.25, embed_p=0.05, weight_p=0.2)

    def forward(self, tgts):
        "Return the wrapped model's output for `tgts`, called with `from_embeddings=True`."
        # The leftover `pdb.set_trace()` breakpoint was removed: it stopped every forward pass.
        # NOTE(review): callers index this result like a tensor (`lm_outs[:,-1]`); confirm that
        # AWD_LSTM returns a single tensor here rather than a tuple of per-layer outputs.
        return self.lm(tgts, from_embeddings=True)

In [None]:
class Mixer(nn.Module):
    "Average decoder and LM activations, refine them with a feed-forward + LayerNorm, and project to the vocabulary."
    def __init__(self, d_model, vocab, drops=0.2):
        super(Mixer, self).__init__()
        self.mixer = nn.Sequential(PositionwiseFeedForward(d_model, drops), LayerNorm(d_model))
        self.generator = nn.Linear(d_model, vocab)

    def forward(self, dec_outs, lm_outs):
        "Return `generator(mixer((lm_outs + dec_outs) / 2))`."
        # Fixed: the original read an undefined name `outs` instead of the `dec_outs` parameter.
        mix = self.mixer((lm_outs+dec_outs)/2)
        return self.generator(mix)

In [None]:
class Embedding(nn.Module):
    "Embedding table (`emb`) that is always read through an `EmbeddingDropout` wrapper (`emb_drop`)."
    def __init__(self, d_model, vocab, drops, pad_tok=0):
        super().__init__()
        table = nn.Embedding(vocab, d_model, padding_idx=pad_tok)
        self.emb = table
        self.emb_drop = EmbeddingDropout(table, drops)

    def forward(self, x):
        "Look `x` up via `emb_drop`; no sqrt(d_model) scaling is applied."
        dropped = self.emb_drop(x)
        return dropped

In [None]:
class Img2Seq(nn.Module):
    "Image encoder + transformer + LM whose outputs are combined by `mixer`; greedy-decodes when no target is given."
    def __init__(self, img_encoder, transformer, embedding, lm, mixer):
        super().__init__()
        self.img_enc = img_encoder
        self.embedding = embedding
        self.transformer = transformer
        self.lm = lm
        self.mixer = mixer

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        "With `tgt`: one teacher-forced pass. Without: greedy decoding for at most `seq_len` steps under `no_grad`."
        if tgt is not None:
            return self._teacher_forced(src, tgt, tgt_mask)
        with torch.no_grad():
            return self._greedy_decode(src, seq_len)

    def _teacher_forced(self, src, tgt, tgt_mask):
        "Training path: the transformer and the LM both read the same target embeddings."
        feats = self.img_enc(src)
        tgt_emb = self.embedding(tgt)
        dec_outs = self.transformer(feats, tgt_emb, tgt_mask)    # ([bs, sl, d_model])
        lm_outs = self.lm(tgt_emb)
        return self.mixer(dec_outs, lm_outs)

    def _greedy_decode(self, src, seq_len):
        "Inference path: start every sequence with token 1 and append the argmax token each step."
        memory = self.transformer.encode(self.img_enc(src))
        n_items = src.size(0)
        tokens = torch.ones((n_items, 1), dtype=torch.long, device=device)

        steps = []
        for _ in progress_bar(range(seq_len)):
            causal = subsequent_mask(tokens.size(-1))
            emb = self.embedding(tokens)
            dec_outs = self.transformer.decode(memory, emb, causal)
            lm_outs = self.lm(emb)
            # only the newest position is scored
            scores = self.mixer(dec_outs[:, -1], lm_outs[:, -1])
            steps.append(scores)
            nxt = torch.argmax(scores, dim=-1, keepdim=True)
            # stop once every item in the batch predicts token 0
            if (nxt == 0).all():
                break
            tokens = torch.cat([tokens, nxt], dim=-1)
        # stacked per step first; swap so the batch dimension leads
        return torch.stack(steps).transpose(1, 0).contiguous()

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.2, attn_type='multi', attn_heads=8):
    "Build the LM-integrated `Img2Seq` net and wrap it in a `Learner` with label smoothing, CER metric and teacher forcing."
    vocab = len(data.vocab.itos)
    # keyword arguments are evaluated in the order written, matching the original construction order
    net = Img2Seq(
        img_encoder=ResnetBase(em_sz, d_model),
        transformer=lm_mod_model(vocab, d_model, N=N, drops=drops, attn_type=attn_type, attn_heads=attn_heads),
        embedding=Embedding(d_model, vocab, drops, pad_tok=0),
        lm=LM(vocab, d_model, 1400, 3),
        mixer=Mixer(d_model, vocab, drops),
    )
    loss = LabelSmoothing(smoothing=0.1)
    return Learner(data, net, loss_func=loss, metrics=[CER()], callback_fns=[TeacherForce])

In [None]:
# Build the learner with d_model=512, em_sz=512, N=6, drops=0.1 and multi-head attention.
learn = make_learner(data, 512, 512, N=6, drops=0.1, attn_type='multi')
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

## Add LM to model state_dict

In [None]:
# Checkpoint saved from the earlier model; its keys are renamed below to match the current module layout.
sd = torch.load(PATH/'models/combo_512_9.pth', map_location=device)

In [None]:
# update existing model according to LM modifications
# Copy the old target-embedding, positional-encoding and generator tensors to the key names
# used by the `embedding`, `transformer.pos_enc` and `mixer.generator` sub-modules.
sd['model']['embedding.emb.weight'] = sd['model']['transformer.tgt_embed.0.lut.weight']
sd['model']['embedding.emb_drop.emb.weight'] = sd['model']['transformer.tgt_embed.0.lut.weight']
sd['model']["transformer.pos_enc.pe"] = sd['model']['transformer.tgt_embed.1.pe']
sd['model']['mixer.generator.weight'] = sd['model']['transformer.generator.weight']
sd['model']['mixer.generator.bias'] = sd['model']['transformer.generator.bias']

# Drop the old key names now that their tensors are stored under the new ones.
del sd['model']['transformer.tgt_embed.1.pe']
del sd['model']['transformer.tgt_embed.0.lut.weight']
del sd['model']['transformer.generator.weight']
del sd['model']['transformer.generator.bias']

In [None]:
# State dict of the separately trained language-model encoder.
lm_sd = torch.load('data/wikitext/models/wiki103_lm_enc.pth', map_location=device)

In [None]:
from collections import OrderedDict
# Re-key the LM weights under 'lm.lm.' so they line up with `Img2Seq.lm` (an `LM`) and its inner `.lm`.
new_lm_sd = OrderedDict()
for k, v in lm_sd.items():
    name = 'lm.lm.'+k
    new_lm_sd[name] = v

In [None]:
# Merge the LM weights into the main state dict.
sd['model'].update(new_lm_sd)

# strict=False: keys missing from / unexpected in the checkpoint are tolerated instead of raising.
learn.model.load_state_dict(sd['model'], strict=False)

In [None]:
# tie weights of transformer embedding and generator w/ lm encodings
# Point the embedding table, its dropout wrapper and the output projection at the LM encoder's weights.
learn.model.embedding.emb.weight = learn.model.lm.lm.encoder.weight
learn.model.embedding.emb_drop.emb.weight = learn.model.lm.lm.encoder_dp.emb.weight
learn.model.mixer.generator.weight = learn.model.lm.lm.encoder.weight

In [None]:
learn.save('combo_512_9_wiki103_base_lm')

## load and split learner

In [None]:
learn.load('combo_512_9_wiki103_base_lm')
# trailing `None` keeps the notebook from echoing the value of the previous expression
None

In [None]:
# Five layer groups, in this order: img_enc, embedding, transformer, lm, mixer.
learn.split([learn.model.img_enc, learn.model.embedding, learn.model.transformer, learn.model.lm, learn.model.mixer])
None

In [None]:
# inspect the last group (mixer)
learn.layer_groups[4]

In [None]:
learn.freeze_to(-3)
# sanity check: shows whether the image encoder's projection is still trainable after freezing
learn.model.img_enc.linear.weight.requires_grad

In [None]:
# NOTE(review): slice(start, stop, step) -- the third value lands in `.step`; confirm the learner
# actually reads it, or whether three per-group learning rates were intended.
lrs = slice(2e-5, 2e-4, 2e-3)

# w/ AWDLSTM added

In [None]:
class EncoderDecoder(nn.Module):
    "Encoder/decoder pair with its own target embedding (`tgt_embed`) and output projection (`generator`)."
    def __init__(self, encoder, decoder, tgt_embed, generator):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder
        self.tgt_embed, self.generator = tgt_embed, generator

    def encode(self, src):
        "Run `src` through the encoder."
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        "Embed `tgt`, then decode it against the (already encoded) `src`."
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, src, tgt_mask)

    def generate(self, outs):
        "Project decoder outputs with `generator`."
        return self.generator(outs)

    def forward(self, src, tgt, tgt_mask=None):
        "Encode `src`, then decode `tgt` against the result (no `generate` call)."
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask)

In [None]:
class ResnetBase(nn.Module):
    "resnet34 with its trailing children removed, plus a linear map from `em_sz` channels to `d_model`."
    def __init__(self, em_sz, d_model):
        super().__init__()

        # negative index at which the resnet34 child list is cut, keyed by feature width
        stop_at = {128: -4, 256: -3, 512: -2}
        s = stop_at[em_sz]

        children = list(models.resnet34(True).children())
        self.base = nn.Sequential(*children[:s])

        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        "Merge the last two dims of the feature map, put channels last, project, and multiply by 8."
        fmap = self.base(x)
        tokens = fmap.flatten(2, 3).transpose(1, 2)
        projected = self.linear(tokens)
        return projected * 8

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', attn_heads=8, weight_tying=False):
    "Assemble the 4-argument `EncoderDecoder` (encoder, decoder, target embedding, generator); optionally tie generator to embedding."
    use_multi = attn_type=='multi'
    attn = MultiHeadedAttention(d_model, attn_heads) if use_multi else SingleHeadedAttention(d_model)
    ff = PositionwiseFeedForward(d_model, drops)

    def fresh(module):
        "Independent copy of a prototype sub-module."
        return deepcopy(module)

    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, fresh(attn), fresh(ff), drops), N),
        Decoder(DecoderLayer(d_model, fresh(attn), fresh(attn), fresh(ff), drops), N),
        nn.Sequential(Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000)),
        nn.Linear(d_model, vocab),
    )

    # Xavier-uniform init for every parameter with more than one dimension
    for weight in model.parameters():
        if weight.dim() <= 1: continue
        nn.init.xavier_uniform_(weight)

    # tie after init so the shared tensor is the initialised embedding table
    if weight_tying:
        model.generator.weight = model.tgt_embed[0].lut.weight

    return model

In [None]:
class Img2Seq(nn.Module):
    "Image encoder + transformer; at inference the transformer scores are averaged with a language model's scores."
    def __init__(self, img_encoder, transformer, lm):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer
        self.lm = lm

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        "With `tgt`: one teacher-forced pass (the LM is not used). Without: greedy decoding for at most `seq_len` steps under `no_grad`."
        if tgt is not None:
            return self._teacher_forced(src, tgt, tgt_mask)
        with torch.no_grad():
            return self._greedy_decode(src, seq_len)

    def _teacher_forced(self, src, tgt, tgt_mask):
        "Training path: transformer only, followed by its generator."
        feats = self.img_enc(src)
        dec_outs = self.transformer(feats, tgt, tgt_mask)    # ([bs, sl, d_model])
        return self.transformer.generate(dec_outs)           # ([bs, sl, vocab])

    def _greedy_decode(self, src, seq_len):
        "Inference path: start every sequence with token 1 and append the argmax token each step."
        memory = self.transformer.encode(self.img_enc(src))
        n_items = src.size(0)
        tokens = torch.ones((n_items, 1), dtype=torch.long, device=device)

        steps = []
        for _ in progress_bar(range(seq_len)):
            causal = subsequent_mask(tokens.size(-1))
            dec_outs = self.transformer.decode(memory, tokens, causal)
            dec_scores = self.transformer.generate(dec_outs[:, -1])
            # the LM returns three values; only the first is used
            lm_scores, _, _ = self.lm(tokens)
            # plain mean of the two models' scores for the newest position
            scores = (dec_scores + lm_scores[:, -1])/2
            steps.append(scores)
            nxt = torch.argmax(scores, dim=-1, keepdim=True)
            # stop once every item in the batch predicts token 0
            if (nxt == 0).all():
                break
            tokens = torch.cat([tokens, nxt], dim=-1)
        # stacked per step first; swap so the batch dimension leads
        return torch.stack(steps).transpose(1, 0).contiguous()

In [None]:
# Keyword configuration passed to `get_language_model(AWD_LSTM, vocab, lm_config, ...)` in `make_learner`.
lm_config = dict(emb_sz=512, n_hid=1400, n_layers=3, pad_token=0, qrnn=False, bidir=False, output_p=0.2,
              hidden_p=0.2, input_p=0.5, embed_p=0.1, weight_p=0.4, tie_weights=True, out_bias=True)

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.2, attn_type='multi', attn_heads=8):
    "Build the transformer + AWD-LSTM `Img2Seq` net and wrap it in a `Learner` with label smoothing, CER metric and teacher forcing."
    vocab = len(data.vocab.itos)
    # keyword arguments are evaluated in the order written, matching the original construction order
    net = Img2Seq(
        img_encoder=ResnetBase(em_sz, d_model),
        transformer=make_full_model(vocab, d_model, N=N, drops=drops, attn_type=attn_type, attn_heads=attn_heads),
        lm=get_language_model(AWD_LSTM, vocab, lm_config, drop_mult=0.5),
    )
    loss = LabelSmoothing(smoothing=0.1)
    return Learner(data, net, loss_func=loss, metrics=[CER()], callback_fns=[TeacherForce])

In [None]:
# Build the learner with d_model=512, em_sz=512, N=6, drops=0.1 and multi-head attention.
learn = make_learner(data, 512, 512, N=6, drops=0.1, attn_type='multi')
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

## Add LM to model state_dict

In [None]:
# display the language-model sub-module
learn.model.lm

In [None]:
# Checkpoint of the image-to-sequence model.
sd = torch.load(PATH/'models/combo_512_9.pth', map_location=device)

In [None]:
# Checkpoint of the language model (its weights live under the 'model' key).
lm_sd = torch.load(PATH/'models/wiki2_lm.pth', map_location=device)

In [None]:
from collections import OrderedDict
# Prefix every LM key with 'lm.' so it lines up with the `lm` attribute of `Img2Seq`.
new_lm_sd = OrderedDict()
for k, v in lm_sd['model'].items():
    name = 'lm.'+k
    new_lm_sd[name] = v

In [None]:
lm_sd['model'].keys()

In [None]:
# Merge the LM weights into the main state dict.
sd['model'].update(new_lm_sd)

# strict=False: keys missing from / unexpected in the checkpoint are tolerated instead of raising.
learn.model.load_state_dict(sd['model'], strict=False)

In [None]:
learn.save('combo_512_9_wiki2_lm')

# Num Lines

In [None]:
# Classification databunch: labels come from column 2 of `df`, with a fixed-seed 85/15 random split.
# NOTE(review): per the section title the label is presumably the number of lines per image -- confirm.
data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
                .split_by_rand_pct(valid_pct=0.15, seed=42)
                .label_from_df(cols=2)
                .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
                .databunch(bs=bs, val_bs=bs*2, device=device)
       )

In [None]:
l2 = cnn_learner(data, models.resnet34, metrics=accuracy)
# Note: already split and frozen:  m[0][6], m[1]

In [None]:
# load saved sd into l2

In [None]:
from collections import OrderedDict
new_sd = OrderedDict()

# Take the image-encoder weights of the seq2seq checkpoint and rename 'img_enc.base*' to '0*',
# the prefix used for the body of `l2.model`.
sd = torch.load(PATH/'models/combo_512_mix_res9.pth', map_location=device)
for k, v in sd['model'].items():
    if k.startswith('img_enc.base'):
        name = k.replace('img_enc.base', '0')
        new_sd[name] = v

In [None]:
# Add the head weights (keys starting with '1.') from a previously saved classifier.
l2_sd = torch.load(PATH/'models/line_97acc.pth', map_location=device)
for k, v in l2_sd['model'].items():
    if k.startswith('1.'):
        new_sd[k] = v

In [None]:
# strict=False: keys absent from `new_sd` keep their current values.
l2.model.load_state_dict(new_sd, strict=False)

In [None]:
# spot-check which layers are trainable
l2.layer_groups[1][47].weight.requires_grad, l2.layer_groups[2][9].weight.requires_grad

In [None]:
l2.lr_find()
l2.recorder.plot(suggestion=True)

In [None]:
l2.fit_one_cycle(1,1e-3)
# 0.171519	0.069522	0.976006	12:57     'line_97acc'
# 0.094188	0.050964	0.988564	12:01  preload img_enc: combo_512_mix_res9 and head: line_97acc   'line_98acc'

In [None]:
l2.save('line_98acc')

In [None]:
# learn = cnn_learner(data, models.resnet34, custom_head=create_head(1024, 1), metrics=accuracy, loss_func=MSELossFlat())
# learn.fit_one_cycle(1,1e-3)
# # 1.229723	0.113453	0.9057	12:43

# preds,y,losses = learn.get_preds(with_loss=True)

# # actual accuracy
# preds.apply_(round)
# (preds.long().squeeze(-1) == y).float().mean()

# Experiments

In [None]:
# `make_learner` names this keyword `attn_heads`; the previous `heads=8` raised a TypeError.
learn = make_learner(data, 512, 512, N=4, drops=0, attn_type='multi', attn_heads=8)
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

In [None]:
# Resume from a saved checkpoint before fine-tuning.
learn.load('combo_512_mix_res8')
None

In [None]:
learn.lr_find()
learn.recorder.plot(suggestion=True)

In [None]:
# swap in the current `data` object without rebuilding the learner
learn.data = data

In [None]:
# 5 one-cycle epochs at max_lr=1e-5; SaveModelCallback writes its checkpoint under 'sm_256_lines2'.
learn.fit_one_cycle(5, max_lr=1e-5, callbacks=[SaveModelCallback(learn, name='sm_256_lines2')])
#sm, 5cycle, 1e-3

# Image Preprocessing by num_lines
# 11.384931	10.208405	0.122443   'sm_256_lines'
# 10.548290	9.845284	0.116961   2nd run; 5cycle(1e-5); 'sm_256_lines2'

# 

# 2.474220	4.081600	0.036482  N:4,F.gelu,sz:256,em_sz:512,single,drop:0  'sm_256_1'
# greedy:    1.36697   .066
# 4.430638	4.402001	0.034730  2nd run, lr:1.5e-5, add tfms, drop:0.2   'sm_256_2'
# greedy:    0.37206   .04145

# 5.158808	4.964441	0.042544  "", w/ tfms, drop:0.2   'sm_256_3'
# greedy:    0.50646   .04103
# 4.266788	4.694098	0.039822  2nd run, lr:1.5e-5   'sm_256_4'
# greedy:    0.38972   .03879

# 3.678168	3.962710	0.035466  N:8,tfms   'sm_256_5'
# greedy:    0.38097   .04405

# 2.840161	3.429177	0.030243  N:4,tfms,multi(8)  'sm_256_6'
# greedy:    0.32775   .04201
# greedy:    1.38865   .06943

# 3.377438	3.786088	0.033931   N:6,tfms,drop:0.1,multi(8)   'sm_256_7'
# greedy:    0.36201   .04266  
# greedy:    1.40557   .06902

# 3.989137	4.353169	0.037717   N:4,weight-tying,multi(16)   'sm_256_8'
# greedy:    1.44291   .07299

# 4.806501	5.007887	0.042339   N:4,weight-tying,multi(8),drops:0.2   'sm_256_9'
# greedy:    1.46753   .07424

# 12.817862	11.842313	0.105486   N:4,no init/weight tying,multi(8),drops:0.1   'sm_256_10'

# 3.515456	3.880955	0.033763   N:4,multi(8),drops:0.1   'sm_256_11'
# greedy:    1.40945   .06964

# combo_cat6 - preload 'sm_256_7'
# 7.570516	6.670641	0.015070    'combo_512_7'
# greedy:    6.29406   .02186


# combo 6lines and under, preload 'sm_256_6', 10cycle(1e-3)
# 10.784849	9.144131	0.024880     'combo_512_13'
# greedy:    9.21715   .02788
# upload:    103.456   .44797
#     pg:    148.473   .69869

# combo 6lines and over, preload 'combo_512_13', 5cycle(2e-5)
# 56.732105	38.604137	0.026153    'combo_512_14'
# greedy:    141.022   .02933
# upload:    179.606   .81724

# combo_cat_pg - preload 'combo_512_7'
# 25.674347	21.161726	0.015869    'combo_512_8'
# greedy:    114.878   .03654
#   test:    115.247   .05014

# combo_cat_pg_dl_sorted - preload 'combo_512_8', 5cycle, 1.5e-4
# 9.800321	5.676855	0.006939    'combo_512_9'
# greedy:    34.8757   .01523

#   test:    91.3226   .05102
#   test:    104.685   .05483
#   test:    116.467   .05315

# upload:    115.149   .66801

# imdb_wiki_combo - preload 'combo_512_9'; split: img_enc, transformer; 3cycle, slice(5e-5, 1e-3)
# 20.390352	16.698141	0.018726    'combo_512_12'
# greedy:    60.8384   .02474
#     pg:    112.755   .05204

# combo_145k - fit:1cycle(1e-3); preload combo_512_9; drops:0.2
# 30.075714	31.024599	0.031703	1:50:33

# combo_145k <= 6 - 5cycle(1e-4); preload sm_256_mix_res
# 19.502026	17.365313	0.044436    'combo_512_mix_res'
# greedy:    14.6870   .04744
# upload:    75.1357   .47955
#     pg:    142.710   .66233

# test_combo 65k (sm,cat,pg,dl)
# 18.262995	15.091379	0.039351   5cycle(1e-5)  'combo_512_mix_res5'
# greedy:    129.968   .03932
# upload:    81.6835   .21794
#     pg:    127.317   .05613

# 16.678631	12.867742	0.035985   1cycle(1e-5), drops:0.1   'combo_512_mix_res6'
# 14.878900	12.797726	0.035834   1cycle(1e-6), drops:0.1   'combo_512_mix_res7'
# greedy:    34.3988   .05552
# upload:    75.5511   .21010
#     pg:    118.963   .05287

# 16.086544	12.320695	0.034263   1cycle(1e-5), drops:0.1   'combo_512_mix_res8'
# 11.640953	11.339241	0.032201   1cycle(1e-5), drops:0     'combo_512_mix_res9'
# greedy:    31.5506   .05410
#   test:    74.4820   .22067
#     pg:    105.182   .05076


# 79.324310	69.801376	0.325578	23:04   preload lines/combo_512_mix_res9;  1cycle(1e-4)

# Greedy

In [None]:
# Build the learner with d_model=512, em_sz=512, N=6, drops=0.1 and multi-head attention.
learn = make_learner(data, 512, 512, N=6, drops=0.1, attn_type='multi')
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

In [None]:
sd = torch.load(PATH/'models/combo_512_9.pth', map_location=device)

In [None]:
# Duplicate the image-encoder projection under 'transformer.src_adapt.0.*' keys.
# NOTE(review): no `src_adapt` module is defined in the visible model code; with strict=False
# below these extra keys would simply be ignored -- confirm they are still needed.
sd['model']["transformer.src_adapt.0.weight"] = sd['model']['img_enc.linear.weight']
sd['model']["transformer.src_adapt.0.bias"] = sd['model']['img_enc.linear.bias']

In [None]:
learn.model.load_state_dict(sd['model'], strict=False)

In [None]:
# learn.load('combo_512_9_wiki2_lm')
# this load runs after the manual load_state_dict above
learn.load('combo_512_mix_res')
None

In [None]:
def full_test(learn, sl, dl=data.valid_dl, batches=10):
    """Greedy-decode `batches` batches from `dl` and return `[mean loss, mean CER]`.

    learn:   learner whose model is called without targets (greedy decoding) with `seq_len=sl`.
    sl:      maximum number of decoding steps per batch.
    dl:      data loader to draw batches from; the default is bound once, when this function is defined.
    batches: number of batches to evaluate; `None` means `len(dl.dl.dataset)//bs`.

    Relies on the notebook globals `bs`, `itos`, `cer` and `progress_bar`.
    Raises StopIteration if `dl` yields fewer than `batches` batches.
    """
    learn.model.eval()
    iterable = iter(dl)
    if batches is None:
        batches = len(dl.dl.dataset)//bs
    g_loss,g_cer = 0,0
    for _ in progress_bar(range(batches)):
        x,y = next(iterable)
        g_preds = learn.model(x, seq_len=sl)
        # both quantities are divided by the global batch size before being accumulated
        # (the unused per-batch argmax `g_res` of the original was dropped)
        g_loss += learn.loss_func(g_preds, y).item()/bs
        g_cer += cer(g_preds, y, itos)[0]/bs
    return [g_loss/batches, g_cer/batches]

In [None]:
# set_seed()
# Average greedy-decoding loss / CER over the default number of validation batches.
g = full_test(learn, seq_len)

In [None]:
# fixed-width print; `[1:7]` drops the leading '0' of the CER
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
# losses = np.array([learn.loss_func(g_preds[i:i+1],y[i:i+1]).item() for i in range(bs)])
# cers = np.array([cer(g_preds[i:i+1],y[i:i+1])[0] for i in range(bs)])

In [None]:
# set_seed()
# One validation batch, decoded greedily (no targets passed to the model).
x,y = next(iter(learn.data.valid_dl))

g_preds = learn.model(x, seq_len=seq_len)
g_res = torch.argmax(g_preds, dim=-1)
# NOTE(review): this cell passes `word_itos` to `cer`, while `full_test` passes `itos` -- confirm which vocabulary applies.
g = [learn.loss_func(g_preds, y).item()/bs, cer(g_preds, y, word_itos)[0]/bs]

In [None]:
#greedy
# Show the first six images with their decoded text as titles.
fig, axes = plt.subplots(2,3, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    p = char_label_text(g_res[i], word_itos, sep='')
    ax=show_img(x[i], ax=ax, title=p)

# Test

### data

In [None]:
# Option 1: uploaded images. Run either this cell or the next one; whichever runs last sets FOLDER/df/bs.
FOLDER = 'uploads'
df = pd.read_csv(PATH/'uploads.csv')
len(df)

sz,bs = 512,14
seq_len = 700

In [None]:
# Option 2: paragraph test set.
FOLDER = 'paragraphs'
df = pd.read_csv(PATH/'test_pg.csv')
len(df)

sz,bs = 512,15
seq_len = 700

In [None]:
## Test dataset only!!!
# set_seed()  # reproducibility
# No validation split and no augmentation (empty transform list); images are squished to `sz`.
data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_none()
        .label_from_df(label_cls=SequenceList, vocab=CharVocab(itos))
        .transform([], size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
       )

In [None]:
# with split_none() everything is in the training set, so the test batch comes from train_dl
x,y = next(iter(data.train_dl))

### learner

In [None]:
# Build the learner with d_model=512, em_sz=512, N=4, drops=0.1 and multi-head attention.
learn = make_learner(data, 512, 512, N=4, drops=0.1, attn_type='multi')
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

In [None]:
# learn.load('combo_512_9_wiki2_lm')
learn.load('combo_512_mix_res')
None

In [None]:
# learn.data = data

## GPU batch testing

In [None]:
# Greedy-decode the batch fetched above and score it.
# NOTE(review): `learn.model.eval()` is not called in this cell -- confirm the model is already in eval mode.
g_preds = learn.model(x, seq_len=seq_len)
g_res = torch.argmax(g_preds, dim=-1)
g = [learn.loss_func(g_preds, y).item()/bs, cer(g_preds, y)[0]/bs]

print(f'  test:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

# test:    397.121   .29261  ~3:55   combo_512_9_wiki2_lm  (LM isn't aware of '\n'!!)
# test:    101.374   .05139  ~27s    combo_512_9

In [None]:
#test
# Show items 4..7 of the batch with their decoded text as titles.
fig, axes = plt.subplots(2,2, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    i +=4
    p = char_label_text(g_res[i])
    ax=show_img(x[i], ax=ax, title=p)

## Preprocessing Lines

In [None]:
# test data
x,y = next(iter(data.train_dl))

In [None]:
def lines_into_paragraph(res):
    """Join per-line token sequences into one paragraph sequence.

    Each row of `res` is cut just before its first token 2 (the whole row is kept when it has none),
    a token 4 is appended to the kept part, and the pieces are concatenated. The very last element
    of the result is then overwritten with 2.

    res: iterable of 1-D token sequences (tensor rows or array-likes).
    Returns a 1-D tensor built with `torch.from_numpy`.
    Raises ValueError (from `np.concatenate`) when `res` is empty.
    """
    out = []
    for r in res:
        # convert explicitly instead of relying on how np.split treats torch tensors;
        # this also makes tensors on a non-CPU device work
        arr = r.detach().cpu().numpy() if torch.is_tensor(r) else np.asarray(r)
        stops = np.where(arr == 2)[0]
        s = arr[:stops[0]] if len(stops) else arr
        out.append(np.append(s,4))
    out = np.concatenate(out)
    out[-1] = 2  # replace final '\n' with 'eos'
    return torch.from_numpy(out)

In [None]:
def target_into_lines(targ, split_idx=4):
    "Split the non-zero tokens of `targ` at every `split_idx` token and stack the pieces into a zero-padded 2-D tensor."
    # keep only the non-zero entries
    # NOTE(review): presumably 0 is the padding token and `targ` is 1-D -- confirm against the caller.
    nonzero = targ[targ.nonzero()].flatten()
    # np.split cuts *at* each index, so every piece after the first starts with the `split_idx` token itself
    # NOTE(review): `nonzero` is a torch tensor handed to NumPy here; confirm this works for the
    # installed torch/NumPy versions and for tensors that are not on the CPU.
    lines = np.split(nonzero, np.where(nonzero == split_idx)[0])
    maxlen = len(max(lines,key=len))
    # torch.zeros with no dtype argument, so the result uses the default dtype rather than targ's
    res = torch.zeros((len(lines),maxlen))
    for i,arr in enumerate(lines):
        res[i,:len(arr)] = arr
    return res

In [None]:
from scipy import signal,ndimage
import statistics

def img_into_line_tensor(img):
    """Cut an image into horizontal bands around the peaks of its smoothed row-wise standard deviation.

    img: tensor whose first index selects a 2-D (rows, columns) plane; only `img[0]` is used.
    Returns a tensor of shape (n_bands, 3, max_band_height, width), pre-filled with ones; band `i`
    is copied into the top rows of `outs[i]` and repeated across the 3 channels.
    Raises ValueError when no peak (and therefore no band) is found.
    """
    arr = img[0].numpy()
    w = arr.shape[-1]
    stds = arr.std(axis=1)
    # `signal` / `ndimage` are the names this cell imports (`from scipy import signal,ndimage`);
    # the bare `scipy` name used before is not bound by that import.
    g_stds = ndimage.gaussian_filter1d(stds, 5)
    # NOTE(review): `//` is floor division, so this prominence is 0.0 whenever stds.std() < 3 -- confirm `/` was not intended.
    peaks,_ = signal.find_peaks(g_stds, prominence=stds.std()//3, distance=20)
    mins = signal.argrelextrema(g_stds, np.less_equal)[0]  # np.less_equal critical for flat minima, edges
    heights = []
    for p in peaks:
        # band = rows from the closest minimum above the peak up to (not including) the closest one below it
        top, bottom = mins[mins < p][-1], mins[mins > p][0]
        heights.append(arr[top:bottom])
    if not heights:
        raise ValueError('no text lines detected in image')
    max_len = max(len(h) for h in heights)
    outs = torch.ones((len(heights),3 , max_len, w))
    for i,h in enumerate(heights):
        outs[i,:,:len(h)] = torch.from_numpy(h)
    return outs

In [None]:
# Split one test image and its target into per-line inputs / targets.
idx = 1
inps = img_into_line_tensor(x[idx])
targs = target_into_lines(y[idx])

In [None]:
# Greedy-decode every line image as its own batch item.
g_preds = learn.model(inps, seq_len=seq_len)
g_res = torch.argmax(g_preds, dim=-1)

In [None]:
# CER divided by the number of line images; `[1:7]` drops the leading '0'
print(f'test cer:   {str(cer(g_preds,targs)[0]/g_preds.size(0))[1:7]}')

## Single Image (cpu)

In [None]:
# Predict a single dataset item and show it with the predicted text as title.
x,y = data.train_ds[5]
pred = learn.predict(x)
show_img(x, title=str(pred[0]))

In [None]:
pred[2].shape

In [None]:
# softmax over the last dimension of the third element returned by `predict`
probs = F.softmax(pred[2], dim=-1)

In [None]:
# three highest-probability entries per position
scores, idxs = torch.topk(probs, 3, dim=-1)

In [None]:
scores[0]

In [None]:
# Print, per position, a {token: score} dict of the top-3 candidates.
for s,i in zip(scores,idxs):
    new = {}
    for score, idx in zip(s,i):
        new[itos[idx]] = score.data
    print(new)

In [None]:
chars = idxs.numpy()
# NOTE(review): `chars.appl` looks like an unfinished attribute access and would raise AttributeError -- confirm/remove.
chars.appl

In [None]:
# NOTE(review): the callable returns `itos[x]`, presumably a string, which a tensor cannot store -- confirm this cell runs.
idxs.apply_(lambda x: itos[x])

In [None]:
for i in range(30):
    print(char_label_text(idxs[i], sep=' '))

In [None]:
# learn.show_results(ds_type=data.train_ds, rows=2)

In [None]:
# NOTE: `xs`/`ys` are only assigned two cells further down, so these cells were run out of order.
i = 0
x = xs[i][None]
y = ys[i][None]

In [None]:
batch = data.one_item(xs[0])

In [None]:
xs,ys = data.one_batch(detach=False, denorm=False)

In [None]:
# Greedy-decode the single-item batch; no division by `bs` here.
g_preds = learn.model(x, seq_len=seq_len)
g_res = torch.argmax(g_preds, dim=-1)
g = [learn.loss_func(g_preds, y).item(), cer(g_preds, y)[0]]

print(f'  test:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
p = char_label_text(g_res[0])
show_img(x[0], figsize=(18,10), title=p)

In [None]:
# Predict directly from an image file on disk.
im = PATH/'uploads'/'test1.png'
img = open_image(im)
prediction = learn.predict(img)[0]
show_img(img, title=str(prediction))