# Prelims

In [None]:
# IPython notebook setup magics: inline matplotlib output and module autoreload.
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
from fastai import *
from fastai.vision import *
from fastai.text import *
from fastai.callbacks.tracker import *

import pdb

In [None]:
# Root directory for the CSVs, image folders and pickled vocabularies used below.
PATH = Path('data/IAM_handwriting')

In [None]:
# GPU when available, else CPU. Several helpers below (e.g. subsequent_mask, Img2Seq) read this global.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

## Loss, Metrics, Callbacks

In [None]:
def tensor2im(x):
    """Convert an image tensor with values in [0, 1] into an 'L'-mode PIL image.

    Only the first entry along axis 0 is kept (``[0]`` below), so a
    channel-first (C, H, W) tensor yields an H x W grayscale image
    (assumes channel-first layout -- TODO confirm for all callers).
    """
    # .cpu(): .numpy() raises on CUDA tensors.
    arr = x.detach().cpu().numpy()
    # Clip so out-of-range pixels saturate instead of wrapping when cast to uint8.
    arr = np.clip(arr, 0, 1) * 255
    arr = np.uint8(arr)[0]
    return PIL.Image.fromarray(arr, mode='L')

In [None]:
class LabelSmoothing(nn.Module):
    """Label-smoothed KL-divergence loss for sequence predictions.

    The target distribution is ``smoothing / vocab`` everywhere, with the true
    class overwritten by ``confidence = 1 - smoothing``. The loss is the summed
    KL divergence divided by the batch size.
    """
    def __init__(self, smoothing=0.1):
        super(LabelSmoothing, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, pred, target):
        """Return the label-smoothed loss.

        Args:
            pred: logits of shape (bs, sl, vocab).
            target: class ids of shape (bs, tsl); lengths are equalized by loss_prep.
        """
        # Use the real batch size; the previous version read the notebook-global `bs`,
        # which is wrong for a short final batch and undefined outside the notebook.
        bs = target.size(0)
        pred, targ = self.loss_prep(pred, target)
        pred = F.log_softmax(pred, dim=-1)  # F.kl_div expects log-probabilities as input
        true_dist = torch.full_like(pred, self.smoothing / pred.size(1))  # not tracked by autograd
        true_dist.scatter_(1, targ.unsqueeze(1), self.confidence)
        return F.kl_div(pred, true_dist, reduction='sum') / bs

    def loss_prep(self, pred, target):
        """Equalize prediction/target sequence lengths, then flatten batch and sequence axes.

        Returns:
            pred: (bs * L, vocab) logits; targ: (bs * L,) long ids, where L = max(sl, tsl).
        """
        bs, tsl = target.shape
        _, sl, vocab = pred.shape

        # Target shorter than prediction: right-pad target ids with 0.
        if sl > tsl: target = F.pad(target, (0, sl - tsl))

        # this should only be used when testing for small seq_lens
        # if tsl>sl: target = target[:,:sl]

        # Prediction shorter than target: right-pad with all-zero logits (not ideal, but keeps shapes aligned).
        if tsl > sl: pred = F.pad(pred, (0, 0, 0, tsl - sl))

        targ = target.contiguous().view(-1).long()
        pred = pred.contiguous().view(-1, vocab)
        return pred, targ

In [None]:
def cer(preds, targs):
    """Sum of per-sample character error rates for one batch.

    Args:
        preds: model logits; argmax over the last axis gives predicted ids
            (presumably (bs, sl, vocab) -- TODO confirm).
        targs: target ids, one row per sample.
    Returns:
        (summed_error, batch_size) so the caller can average over an epoch.
    """
    res = torch.argmax(preds, dim=-1)
    error = 0
    for pred_ids, targ_ids in zip(res, targs):
        p = char_label_text(pred_ids)
        t = char_label_text(targ_ids)
        # An all-padding target decodes to '' and would otherwise divide by zero.
        error += Lev.distance(t, p) / max(len(t), 1)
    return error, targs.size(0)

def char_label_text(pred, sep=''):
    """Decode a tensor of character ids into a string.

    Every 0 id (padding) is removed; the remaining ids are looked up in the
    module-level ``itos`` list and joined with ``sep``.
    """
    ids = to_np(pred).astype(int)
    kept = ids[ids != 0]  # boolean mask selects the same (flattened) elements as np.nonzero
    pieces = [itos[idx] for idx in kept]
    return sep.join(pieces)

In [None]:
import Levenshtein as Lev

class CER(Callback):
    "fastai callback that reports the epoch's mean character error rate as metric 'cer'."

    def __init__(self):
        super().__init__()
        self.name = 'cer'

    def on_epoch_begin(self, **kwargs):
        # running totals are reset at the start of every epoch
        self.errors = 0
        self.total = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        batch_error, batch_size = cer(last_output, last_target)
        self.errors, self.total = self.errors + batch_error, self.total + batch_size

    def on_epoch_end(self, last_metrics, **kwargs):
        mean_cer = self.errors / self.total
        return add_metrics(last_metrics, mean_cer)

In [None]:
def rshift(tgt, bos_token=1):
    """Shift y to the right by prepending token.

    Drops the last column of ``tgt`` and puts ``bos_token`` in the first, so
    decoder input position t holds the target from position t-1.

    Args:
        tgt: (bs, sl) tensor of token ids.
        bos_token: id written into the first column.
    Returns:
        Tensor with the same shape, dtype and device as ``tgt``.
    """
    # Build the BOS column directly on tgt's device/dtype rather than via the global `device`.
    bos = torch.full((tgt.size(0), 1), bos_token, dtype=tgt.dtype, device=tgt.device)
    return torch.cat((bos, tgt[:, :-1]), dim=-1)

def subsequent_mask(size):
    """Causal attention mask of shape (1, size, size), dtype uint8, on the global `device`.

    Entry [0, i, j] is 1 when position i may attend to position j (j <= i), else 0.
    For the "complex batches" layout a (1, 1, size, size) shape would be needed instead.
    """
    ones = torch.ones((1, size, size), device=device)
    return ones.byte().tril()

In [None]:
class TeacherForce(LearnerCallback):
    """Replace each batch's input with ``(images, rshift(target), causal_mask)``.

    This lets the model consume the ground-truth previous tokens; the target
    itself is passed through unchanged.
    """
    def __init__(self, learn:Learner):
        super().__init__(learn)

    def on_batch_begin(self, last_input, last_target, **kwargs):
        dec_input = rshift(last_target).long()
        causal = subsequent_mask(dec_input.size(-1))
        new_input = (last_input, dec_input, causal)
        return {'last_input': new_input, 'last_target': last_target}

# Data

## sm synth dataset

In [None]:
# Dataset 'edited_sm_synth'. Every dataset section in this notebook rebinds the same
# globals (CSV, FOLDER, df, sz, bs, ...): run only the section you want before building `data`.
fname = 'edited_sm_synth.csv' #'small_synth_words.csv'
CSV = PATH/fname
FOLDER = 'edited_sm_synth'

df = pd.read_csv(CSV)
len(df)

In [None]:
# sz: image size passed to .transform(size=sz); bs: DataBunch batch size.
# num_lines is read by split_1d_array_to_2d_tensor (complex batches).
# NOTE(review): the global seq_len is not read by any visible cell -- presumably meant for Img2Seq.forward(seq_len=...).
# sz,bs = 128,100
sz,bs = 256,100

num_lines,seq_len = 4,50

## Lines

In [None]:
# Dataset: single text lines (images in 'square_lines').
fname = 'edited_line2.csv'
CSV = PATH/fname
FOLDER = 'square_lines'

df = pd.read_csv(CSV)
len(df)

In [None]:
df.head()

In [None]:
# image size / batch size for the line dataset
sz,bs = 256,100
# sz,bs = 256,100

seq_len = 60

## Concat Lines

In [None]:
# Dataset: concatenated lines ('edited_cat_lines').
CSV = PATH/f'edited_cat_lines.csv'
FOLDER = 'edited_cat_lines'

df = pd.read_csv(CSV)
len(df)

In [None]:
df[:2000]

In [None]:
# Alternative: one CSV per range `nums` (presumably number of concatenated lines), plus the held-out test CSV.
# NOTE(review): this loads into `csv`, not `df`; the DataBunch cells read `df`, so they will use a stale frame.
nums = '3-6'   #'7-10'  #'11-14'
CSV = PATH/f'cat_lines_{nums}.csv'

FOLDER = 'resized_cat_lines'

csv = pd.read_csv(CSV)
test = pd.read_csv(PATH/'test_pg.csv')

len(csv), len(test)

In [None]:
# Longest label in tokens (char_ids holds space-separated ids).
lengths = np.array([len(i.split(' ')) for i in csv.char_ids.values])
lengths.max()

In [None]:
sz,bs = 512,20   #1000,5  #800,8   #512,20
seq_len = 600 #600 #450  #300
# Per-channel (mean, std) tuple; not passed to any .normalize() call in the visible cells.
stats = (np.array([0.941, 0.941, 0.941], dtype=np.float32), np.array([0.128, 0.128, 0.128], dtype=np.float32))

### Concat All

In [None]:
# Build a mixed training CSV: 1000 random rows from each cat_lines_{3..12}.csv.
# NOTE(review): .sample() has no random_state, so the mix is not reproducible.
a = pd.read_csv(PATH/'cat_lines_3.csv').sample(1000)
b = pd.read_csv(PATH/'cat_lines_4.csv').sample(1000)
c = pd.read_csv(PATH/'cat_lines_5.csv').sample(1000)
d = pd.read_csv(PATH/'cat_lines_6.csv').sample(1000)
e = pd.read_csv(PATH/'cat_lines_7.csv').sample(1000)
f = pd.read_csv(PATH/'cat_lines_8.csv').sample(1000)
g = pd.read_csv(PATH/'cat_lines_9.csv').sample(1000)
h = pd.read_csv(PATH/'cat_lines_10.csv').sample(1000)
i = pd.read_csv(PATH/'cat_lines_11.csv').sample(1000)
j = pd.read_csv(PATH/'cat_lines_12.csv').sample(1000)

In [None]:
k = pd.read_csv(PATH/'paragraph_chars.csv')

In [None]:
# Reproducible 15% of the paragraph rows held out as the test set.
val_idxs = np.array(k.sample(frac=0.15, random_state=42).index)

In [None]:
trn = k[~k.index.isin(val_idxs)]
test = k[k.index.isin(val_idxs)]

In [None]:
# Training CSV = sampled concatenated lines + the non-held-out paragraphs.
new = pd.concat([a,b,c,d,e,f,g,h,i,j,trn], ignore_index=True)
len(new)

In [None]:
# Uncomment to write the combined training CSV and the held-out test CSV.
# new.to_csv(PATH/'cat_lines_pg.csv', index=False)
# test.to_csv(PATH/'test_pg.csv', index=False)

## Paragraphs

In [None]:
# Dataset: full paragraphs.
fname = 'edited_pg.csv' #'paragraphs.csv'
CSV = PATH/fname
FOLDER = 'paragraphs'

df = pd.read_csv(CSV)
len(df)

In [None]:
# Debug setting: keep only the first 10 rows, with small images and short sequences.
df = df[:10]
sz,bs,seq_len = 256,10,50

In [None]:
# Full-size setting (overrides the debug cell above when run after it).
# sz,bs = 1000,5  #1024,5  #1000,5   #~2000x1000 full size
# sz,bs = 800,8
sz,bs = 512,10

seq_len = 700   #~400 chars/paragraph - max: 705
# stats = (np.array([0.941, 0.941, 0.941], dtype=np.float32), np.array([0.128, 0.128, 0.128], dtype=np.float32))

## Downloaded Images

In [None]:
# Dataset: downloaded images.
CSV = PATH/'downloaded_images.csv'
FOLDER = 'downloaded_images'

df = pd.read_csv(CSV)
len(df)

In [None]:
# One-off (kept commented): prefixed every filename with 'dl_' and rewrote the CSV.
# csv['filename'] = csv['filename'].apply(lambda x: f"dl_{x}")
# csv.head()
# csv.to_csv(CSV, index=False)

In [None]:
# sz,bs = 1000,5  #1024,5  #1000,5   #~2000x1000 full size
# sz,bs = 800,8
sz,bs = 512,41

seq_len = 700   #~400 chars/paragraph - max: 705
# Per-channel (mean, std); not passed to any .normalize() call in the visible cells.
stats = (np.array([0.941, 0.941, 0.941], dtype=np.float32), np.array([0.128, 0.128, 0.128], dtype=np.float32))

## Fonts

In [None]:
# Dataset: rendered fonts.
fname = 'edited_font.csv'
CSV = PATH/fname
FOLDER = 'fonts_resize'

df = pd.read_csv(CSV)
len(df)

In [None]:
# keep only samples with fewer than 5 lines
df = df[df.num_lines<5]
len(df)

In [None]:
# sz,bs = 256,20
# sz,bs = 400,10
sz,bs = 512,50

## Mix

In [None]:
# Dataset: 'mix'.
fname = 'mix.csv'
CSV = PATH/fname
FOLDER = 'mix'

df = pd.read_csv(CSV)
len(df)

In [None]:
sz,bs = 512,10
# sz,bs = 800,5

seq_len = 800

## Combo/Cat Lines

In [None]:
# Run ONE of the next three cells (each sets fname, sz, bs, seq_len), then the two cells that read `fname`.
# 6 and fewer
fname = 'combo_cat6lines.csv'
sz,bs = 512,30
seq_len = 600

In [None]:
# 6 and greater of combo/cat lines + all paragraph (2-13 lines)
fname = 'combo_cat_pg.csv'
sz,bs = 512,10
seq_len = 750

In [None]:
# full mix sorted by num_lines
fname = 'combo_cat_pg_dl_sorted.csv'
sz,bs = 512,10
seq_len = 750

In [None]:
# All three combo CSVs share one image folder.
CSV = PATH/fname
FOLDER = 'combo_cat'

In [None]:
df = pd.read_csv(CSV)
len(df)

## Test

In [None]:
# Evaluation sets (pick one): user uploads, or the held-out paragraph split written to test_pg.csv.
FOLDER = 'uploads'
df = pd.read_csv(PATH/'uploads.csv')
len(df)

sz,bs = 512,14
seq_len = 700

In [None]:
FOLDER = 'paragraphs'
df = pd.read_csv(PATH/'test_pg.csv')
len(df)

sz,bs = 512,15
seq_len = 700

# ModelData

In [None]:
# Augmentation: no flips, zoom or rotation; only a small perspective warp.
tfms = get_transforms(do_flip=False, max_rotate=0, max_zoom=1, max_warp=0.1)

def force_gray(image):
    "Collapse an image to single-channel grayscale, then expand it back to 3 RGB channels."
    gray = image.convert('L')
    return gray.convert('RGB')

def label_collater(samples:BatchSamples, pad_idx:int=0):
    """Function that collect samples and pads ends of labels.

    Returns:
        imgs: stacked image tensors.
        labels: LongTensor of shape (bs, max_label_len + 1), right-padded with
            ``pad_idx``; the extra column accounts for the bos token.
    """
    data = to_data(samples)
    ims, lbls = zip(*data)
    imgs = torch.stack(list(ims))
    # `==`, not `is`: identity comparison with an int literal is a SyntaxWarning
    # and only worked because CPython caches small ints.
    if len(data) == 1:
        return imgs, torch.zeros(1, 1).long()  # single-sample batch gets a dummy (1, 1) label
    max_len = max(len(lbl) for lbl in lbls)
    labels = torch.full((len(data), max_len + 1), pad_idx, dtype=torch.long)
    for row, lbl in enumerate(lbls):
        labels[row, :len(lbl)] = torch.from_numpy(lbl)  # padding end
    return imgs, labels

## Chars

In [None]:
# Character vocabulary (index -> string). The `with` closes the file handle the old one-liner leaked.
# NOTE: pickle can execute arbitrary code -- only load this trusted local file.
with open(PATH/'itos.pkl', 'rb') as itos_file:
    itos = pickle.load(itos_file)

### Simple batches (BS, Seq Len)

In [None]:
class CharTokenizer(BaseTokenizer):
    "Character-level tokenizer: each character of the string is one token."
    def tokenizer(self, t:str) -> List[str]:
        return [ch for ch in t]

class CharVocab(Vocab):
    "Vocab over a fixed `itos` list; strings missing from it map to id 3."
    def __init__(self, itos:Collection[str]):
        self.itos = itos
        lookup = {tok: idx for idx, tok in enumerate(self.itos)}
        self.stoi = collections.defaultdict(lambda: 3, lookup)

    def textify(self, nums:Collection[int], sep=''):
        "Map ids to tokens; join with `sep`, or return the token list when `sep` is None."
        tokens = [self.itos[n] for n in nums]
        if sep is None:
            return tokens
        return sep.join(tokens)

class SequenceList(TextList):
    "TextList of character labels: char-tokenized, EOS appended (no BOS), numericalized with `vocab`."
    def __init__(self, items:Iterator, vocab:Vocab, **kwargs):
        char_tok = Tokenizer(tok_func=CharTokenizer, pre_rules=[], post_rules=[],
                             special_cases=[BOS,EOS,UNK,PAD])
        processors = [TokenizeProcessor(tokenizer=char_tok, include_bos=False, include_eos=True),
                      NumericalizeProcessor(vocab=vocab)]
        super().__init__(items, vocab, sep='', pad_idx=0, processor=processors)

    def analyze_pred(self, pred:Tensor):
        "Greedy decode: the highest-scoring id at every position."
        return torch.argmax(pred, dim=-1)

In [None]:
# Character-level DataBunch: grayscale images, reproducible 15% validation split,
# labels through SequenceList/CharVocab, images squished to sz x sz, labels padded by label_collater.
data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
#         .split_none()
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        #.label_from_df(label_cls=TextList, sep='', pad_idx=0, vocab=vocab, processor=procs)
        .label_from_df(label_cls=SequenceList, vocab=CharVocab(itos))
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        #.transform(tfms, size=sz, resize_method=ResizeMethod.PAD, padding_mode='border')
        # maintains aspect ratio but too small for good results => mostly whitespace
        .databunch(bs=bs, device=device, collate_fn=label_collater)
        #.normalize()
        # this sets x values to an odd range (~.3,-6)
       )

In [None]:
## Test dataset only!!!
# No validation split and no augmentation (empty transform list) -- for evaluation sets.

data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_none()
        .label_from_df(label_cls=SequenceList, vocab=CharVocab(itos))
        .transform([], size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
       )

### Complex Batches (BS, Lines, Char Sequence)

In [None]:
def split_1d_array_to_2d_tensor(a, split_idx=4):
    """Split a flat id array into lines and lay the lines out column-wise.

    The array is cut in front of every occurrence of ``split_idx`` (so each
    piece after the first starts with that id). Piece ``i`` fills column ``i``
    of a zero float tensor of shape (longest_piece, num_lines).

    Requires the notebook-global ``num_lines``; more pieces than ``num_lines``
    raises IndexError.
    """
    pieces = np.split(a, np.where(a == split_idx)[0])
    longest = max(len(piece) for piece in pieces)
    out = torch.zeros((longest, num_lines))
    for col, piece in enumerate(pieces):
        out[:len(piece), col] = torch.from_numpy(piece)
    return out

def batch_line_collater(samples:BatchSamples, pad_idx:int=0):
    """Function that collect samples and pads ends of labels.

    Each label is split into lines by split_1d_array_to_2d_tensor, giving a
    (line_len, num_lines) tensor. These are padded along line_len with
    ``pad_idx`` and returned as (bs, num_lines, max_line_len).
    """
    data = to_data(samples)
    ims, lbls = zip(*data)
    imgs = torch.stack(list(ims))
    # `==`, not `is`: identity comparison with an int literal is a SyntaxWarning.
    if len(data) == 1:
        return imgs, torch.zeros(1, 1, 1).long()  # single-sample batch gets a dummy label

    per_sample = [split_1d_array_to_2d_tensor(lbl) for lbl in lbls]
    # pad_idx was previously ignored (pad_sequence always padded with 0); default keeps that behaviour.
    labels = torch.nn.utils.rnn.pad_sequence(per_sample, batch_first=True, padding_value=pad_idx)
    return imgs, labels.permute(0, 2, 1)

In [None]:
class CharTokenizer():
    "Split text into single characters and append an 'xxeos' token; optionally across processes."
    def __init__(self, n_cpus:int=None):
        self.n_cpus = ifnone(n_cpus, defaults.cpus)

    def tokenize(self, t:str):
        return [*t, 'xxeos']

    def _process_all_1(self, texts:Collection[str]) -> List[List[str]]:
        "Tokenize `texts` in the current process."
        return [self.tokenize(str(text)) for text in texts]

    def process_all(self, texts:Collection[str]) -> List[List[str]]:
        "Tokenize `texts`, splitting the work over `n_cpus` worker processes when more than one."
        if self.n_cpus <= 1:
            return self._process_all_1(texts)
        with ProcessPoolExecutor(self.n_cpus) as pool:
            chunks = pool.map(self._process_all_1, partition_by_cores(texts, self.n_cpus))
            return sum(chunks, [])

class CharVocab(Vocab):
    "Fixed character vocab; unknown characters map to id 3."
    def __init__(self, itos:Collection[str]):
        self.itos = itos
        self.stoi = collections.defaultdict(lambda: 3, {ch: i for i, ch in enumerate(itos)})

    def textify(self, nums:Collection[int]):
        # the final id (the appended eos token) is dropped before decoding
        body = nums[:-1].astype(int)
        return ''.join(self.itos[i] for i in body)

In [None]:
class SequenceList(ItemList):
    "ItemList of character sequences, tokenized per character and decoded through a CharVocab built from `itos`."
    _processor = [
        partial(TokenizeProcessor, tokenizer=CharTokenizer(), include_bos=False),
        NumericalizeProcessor,
    ]

    def __init__(self, items:Iterator, itos:List[str]=None, **kwargs):
        super().__init__(items, **kwargs)
        self.vocab = CharVocab(itos)
        self.pad_idx = 0
        self.copy_new += ['vocab', 'pad_idx']

    def get(self, i):
        ids = super().get(i)
        if self.vocab is None:
            return ids
        return Text(ids, self.vocab.textify(ids))

    def reconstruct(self, t:Tensor):
        ids = t.numpy()
        ids = ids[np.nonzero(ids)].flatten()  # strip zero padding
        return Text(ids, self.vocab.textify(ids))

In [None]:
# DataBunch for complex batches: labels collated to (bs, num_lines, line_len) by batch_line_collater.
# .normalize() is called without stats (see the batch_stats cell for the resulting range).
data = (ImageList.from_df(df, path=PATH, folder=FOLDER)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=SequenceList, itos=itos)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=batch_line_collater)
        .normalize()
       )

### SequenceList (old)

In [None]:
# class CharTokenizer():
#     def __init__(self, n_cpus:int=None):
#         self.n_cpus = ifnone(n_cpus, defaults.cpus)

#     def tokenize(self, t:str): return list(t)+['xxeos']

#     def _process_all_1(self, texts:Collection[str]) -> List[List[str]]:
#         "Process a list of `texts` in one process."
#         return [self.tokenize(str(t)) for t in texts]

#     def process_all(self, texts:Collection[str]) -> List[List[str]]:
#         "Process a list of `texts`."
#         if self.n_cpus <= 1: return self._process_all_1(texts)
#         with ProcessPoolExecutor(self.n_cpus) as e:
#             return sum(e.map(self._process_all_1, partition_by_cores(texts, self.n_cpus)), [])

In [None]:
class SequenceItem(ItemBase):
    "Label item wrapping an id array; str() decodes it by indexing `vocab` (an itos list)."
    def __init__(self,data,vocab): self.data,self.vocab = data,vocab        
    def __str__(self): return self.textify(self.data)
    def __hash__(self): return hash(str(self))
    # Drops the final id (presumably the eos token) before decoding.
    def textify(self, data): return ''.join([self.vocab[i] for i in data[:-1]])
        
class ArrayProcessor(PreProcessor):
    "Convert df column (string of ints) into np.array"
    # NOTE(review): does not call super().__init__ or keep `ds`; the bare `None` is a no-op body.
    def __init__(self, ds:ItemList=None): None
    def process_one(self,item): return np.array(item.split(), dtype=np.int64)
    def process(self, ds): super().process(ds)
        
class ItosProcessor(PreProcessor):
    "Copy `itos` from the dataset the processor was created from onto every dataset it processes."
    # NOTE(review): raises AttributeError if constructed with ds=None.
    def __init__(self, ds:ItemList=None): self.itos = ds.itos
    def process(self, ds:ItemList): ds.itos = self.itos
        
class SequenceList(ItemList):
    "ItemList of space-separated id strings, stored as int64 arrays and shown as SequenceItems."
    _processor = [ItosProcessor, ArrayProcessor]
    
    def __init__(self, items:Iterator, itos:List[str]=None, **kwargs):
        super().__init__(items, **kwargs)
        self.itos = itos
        self.copy_new += ['itos']
        # NOTE(review): `c` is set to the number of items, not the vocab size -- confirm this is intended.
        self.c = len(self.items)

    def get(self, i):
        o = super().get(i)
        return SequenceItem(o, self.itos)

    def reconstruct(self,t):
        # Converting padded tensor back into np.array
        o = t.numpy()
        o = o[np.nonzero(o)]                  # remove 0 padding
        return SequenceItem(o, self.itos)
    
    def analyze_pred(self,pred):
        return torch.argmax(pred, dim=-1)
        # method called in learn.predict() or learn.show_results()
        # to transform predictions in an output tensor suitable for reconstruct

In [None]:
# # This slows down training...
# class CustomSampler(Sampler):
#     "sort dataset by longest y sequences"
#     def __init__(self, dataset):
#         self.dataset = dataset
#         self.lengths = [len(i) for i in dataset.y.items]
#         self.sorted_idxs = np.flip(np.argsort(self.lengths))
#     def __len__(self): return len(self.dataset)
#     def __iter__(self): return iter(self.sorted_idxs)

# ds = data.train_ds
# tfms = data.train_dl.tfms
# sampler = CustomSampler(ds)
# dl = DataLoader(ds, bs, shuffle=False, sampler=sampler, num_workers=num_cpus(), collate_fn=custom_collater, drop_last=True)
# ddl = DeviceDataLoader(dl, device, tfms, custom_collater)
# data.train_dl = ddl

In [None]:
# Per-channel (mean, std) of a batch; values observed under different normalization choices are recorded below.
data.batch_stats()
# no normalization: [tensor([0.9403, 0.9403, 0.9403]), tensor([0.1604, 0.1604, 0.1604])]
# imagenet_stats:   [tensor([1.9973, 2.1714, 2.3839]), tensor([0.6974, 0.7130, 0.7098])]
# normalize():      [tensor([0.0225, 0.0225, 0.0225]), tensor([0.9676, 0.9676, 0.9676])]

## SentencePiece

In [None]:
def char_label_text(pred, sp_processor=None):
    """Decode a tensor of SentencePiece ids into text (SentencePiece variant used by `cer`).

    The previous body referenced ``self.sp`` in a module-level function, which
    raises NameError. It now decodes with ``sp_processor``, falling back to the
    notebook-global ``sp`` SentencePieceProcessor when none is given.

    Args:
        pred: tensor of token ids.
        sp_processor: object providing ``DecodeIds(list_of_ints)``.
    """
    if sp_processor is None:
        sp_processor = sp
    return sp_processor.DecodeIds(pred.tolist())

In [None]:
class SPMProcessor(PreProcessor):
    "Encode each raw label string into SentencePiece ids using the processor stored on the dataset."
    def __init__(self, ds:ItemList=None):
        self.sp = None if ds is None else ds.sp

    def process_one(self, item):
        return self.sp.EncodeAsIds(item)

    def process(self, ds):
        super().process(ds)

class SPMList(ItemList):
    "ItemList of SentencePiece id sequences, displayed/decoded as text with `sp`."
    _processor = [SPMProcessor]

    def __init__(self, items:Iterator, sp_processor, **kwargs):
        super().__init__(items, **kwargs)
        self.sp = sp_processor
        self.copy_new += ['sp']

    def get(self, i):
        ids = super().get(i)
        text = self.sp.DecodeIds(ids)
        return Text(ids, text)

    def reconstruct(self, t:Tensor):
        text = self.sp.DecodeIds(t.tolist())
        return Text(t, text)

In [None]:
import sentencepiece as spm

# Trained SentencePiece model. "bos:eos" is expected to make EncodeAsIds add BOS/EOS ids to every label.
sp = spm.SentencePieceProcessor()
sp.Load(str(PATH/'spm_train.model'))
sp.SetEncodeExtraOptions("bos:eos")

In [None]:
# SentencePiece-labelled DataBunch (labels padded by label_collater).
data = (ImageList.from_df(df, path=PATH, folder=FOLDER)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=SPMList, sp_processor=sp)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
        .normalize()
       )

## Words

In [None]:
# Keep the 10k entries at the front of the word vocab.
# NOTE(review): requires word_itos to exist already -- run the pickle-loading cell below first.
word_itos = word_itos[:10000]

In [None]:
# 60k-word vocabulary (index -> word). The `with` closes the file handle the old one-liner leaked.
# NOTE: pickle can execute arbitrary code -- only load this trusted local file.
with open(PATH/'word_itos_60k.pkl', 'rb') as word_itos_file:
    word_itos = pickle.load(word_itos_file)

In [None]:
# Word-level labels with the fixed vocab loaded above (EOS appended, no BOS).
vocab = Vocab(word_itos)
procs = [TokenizeProcessor(include_bos=False, include_eos=True), NumericalizeProcessor()]

In [None]:
data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=TextList, pad_idx=0, vocab=vocab, processor=procs)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
       )

In [None]:
# Word-level labels; NumericalizeProcessor builds the vocab from the labels (capped at 10k words).
# PAD was listed twice in special_cases; the duplicate is removed.
toknizr = Tokenizer(special_cases=[PAD,UNK,BOS,EOS,TK_MAJ,TK_UP,TK_REP,TK_WREP])
procs = [TokenizeProcessor(tokenizer=toknizr, include_bos=False, include_eos=True), NumericalizeProcessor(max_vocab=10000)]

data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=TextList, pad_idx=0, processor=procs)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
       )

In [None]:
# Take the vocab that the label processors built from the training set.
word_itos = data.train_ds.vocab.itos

## Display

In [None]:
data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

# Transformer Modules

In [None]:
# LayerNorm used throughout the transformer: nn.LayerNorm with eps raised to 1e-4
# (per the original note, to accommodate mixed precision training).
# Alternatives tried earlier: plain nn.LayerNorm, or partial(nn.BatchNorm2d, eps=1e-4).
_MIXED_PRECISION_EPS = 1e-4
LayerNorm = partial(nn.LayerNorm, eps=_MIXED_PRECISION_EPS)

In [None]:
class SublayerConnection(nn.Module):
    """Residual wrapper that normalizes first: returns ``x + dropout(sublayer(norm(x)))``.

    The norm is applied before the sublayer (rather than after the residual add) for code simplicity.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        normed = self.norm(x)
        residual = self.dropout(sublayer(normed))
        return x + residual

In [None]:
def clones(module, N):
    "Return an nn.ModuleList of N independent deep copies of `module` (no shared parameters)."
    copies = []
    for _ in range(N):
        copies.append(deepcopy(module))
    return nn.ModuleList(copies)

In [None]:
class Encoder(nn.Module):
    "Stack of N deep-copied encoder layers followed by a final LayerNorm."
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x):
        out = x
        for enc_layer in self.layers:
            out = enc_layer(out)
        return self.norm(out)

In [None]:
class EncoderLayer(nn.Module):
    "One encoder layer: self-attention, then the feed-forward net, each inside a residual sublayer."
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)

    def forward(self, x, mask=None):
        attn_block, ff_block = self.sublayer
        x = attn_block(x, lambda h: self.self_attn(h, h, h, mask))
        return ff_block(x, self.feed_forward)

In [None]:
class Decoder(nn.Module):
    "Stack of N deep-copied decoder layers (optionally masked) followed by a final LayerNorm."
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, src, tgt_mask=None):
        "Run every layer on the embedded target `x`, attending to the encoder output `src`."
        h = x
        for dec_layer in self.layers:
            h = dec_layer(h, src, tgt_mask)
        return self.norm(h)

In [None]:
class DecoderLayer(nn.Module):
    "One decoder layer: masked self-attention, attention over `src`, then feed-forward; each residual."
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)  # residual + dropout + norm wrappers

    def forward(self, x, src, tgt_mask=None):
        self_block, cross_block, ff_block = self.sublayer
        # self-attention over target positions (acts as a weak LM)
        x = self_block(x, lambda h: self.self_attn(h, h, h, tgt_mask))
        # attention over the encoder output, no mask
        x = cross_block(x, lambda h: self.src_attn(h, src, src))
        return ff_block(x, self.feed_forward)

In [None]:
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Returns ``(weights @ value, weights)``, where
    ``weights = softmax(query @ key^T / sqrt(d))`` and d is the size of the
    last dimension of ``query``. Scores where ``mask == 0`` are set to -1e4
    (rather than -1e9, to stay representable under mixed precision).
    ``dropout``, when given, is applied to the weights.
    """
    scale = math.sqrt(query.size(-1))
    scores = query @ key.transpose(-2, -1) / scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e4)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return weights @ value, weights

In [None]:
class SingleHeadedAttention(nn.Module):
    "Single-head attention with learned query/key/value projections and an output projection."
    def __init__(self, d_model, dropout=0.2):
        super(SingleHeadedAttention, self).__init__()
        self.linears = clones(nn.Linear(d_model, d_model), 4)  # q, k, v projections + output
        self.attn = None  # attention weights from the most recent forward call
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        q_proj, k_proj, v_proj, out_proj = self.linears
        q, k, v = q_proj(query), k_proj(key), v_proj(value)
        ctx, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)
        return out_proj(ctx)

In [None]:
class MultiHeadedAttention(nn.Module):
    "Multi-head attention: project into h heads of size d_model // h, attend per head, merge, project."
    def __init__(self, d_model, h=8, dropout=0.2):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h   # per-head size; values use the same size
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, proj, x, bs):
        "Project x, then reshape (bs, len, d_model) -> (bs, h, len, d_k)."
        return proj(x).view(bs, -1, self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)  # add a head axis so one mask broadcasts over all heads
        bs = q.size(0)
        q, k, v = [self._split_heads(proj, x, bs) for proj, x in zip(self.linears, (q, k, v))]
        x, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)
        merged = x.transpose(1, 2).contiguous().view(bs, -1, self.h * self.d_k)
        return self.linears[-1](merged)

In [None]:
class PositionwiseFeedForward(nn.Module):
    "Per-position MLP: Linear(d_model, 4*d_model) -> GELU -> dropout -> Linear(4*d_model, d_model)."
    def __init__(self, d_model, dropout=0.2):
        super(PositionwiseFeedForward, self).__init__()
        hidden = 4 * d_model
        self.w_1 = nn.Linear(d_model, hidden)
        self.w_2 = nn.Linear(hidden, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = F.gelu(self.w_1(x))
        h = self.dropout(h)
        return self.w_2(h)

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input (position on axis 1), then dropout.

    pe[pos, 2i] = sin(pos / 10000^(2i/d_model)); pe[pos, 2i+1] = cos(pos / 10000^(2i/d_model)).
    Kept as a (1, max_len, d_model) registered buffer: saved in state_dict, not a trainable parameter.
    """
    def __init__(self, d_model, dropout=0.1, max_len=2000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0.0, max_len).unsqueeze(1)  # (max_len, 1)
        freqs = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(1e4) / d_model))
        angles = positions * freqs
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        seq_len = x.size(1)
        return self.dropout(x + self.pe[:, :seq_len])

In [None]:
class Embeddings(nn.Module):
    "Token embedding lookup (vocab x d_model) scaled by sqrt(d_model)."
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return scale * self.lut(x)

# Full Arch

In [None]:
class EncoderDecoder(nn.Module):
    "Glue module: encoder, target embedding + decoder, and a separately applied generator."
    def __init__(self, encoder, decoder, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, tgt_mask=None):
        "Encode `src` and decode `tgt` against it; returns decoder states (generator NOT applied)."
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask)

    def encode(self, src):
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        "Embed `tgt`, then run the decoder with `src` (the encoder output) as memory."
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, src, tgt_mask)

    def generate(self, outs):
        "Apply the generator (output projection) to decoder states."
        return self.generator(outs)

In [None]:
class ResnetBase(nn.Module):
    """Image feature extractor: truncated ResNet-34 (models.resnet34(True)) + linear projection.

    ``em_sz`` selects how many trailing resnet children are dropped; since the
    projection is Linear(em_sz, d_model) over channels, em_sz must equal the
    kept feature map's channel count (presumably 128/256/512 -- TODO confirm).
    The (bs, C, H, W) map is flattened to a (bs, H*W, C) sequence and projected to d_model.
    """
    _CUTOFFS = {128: -4, 256: -3, 512: -2}

    def __init__(self, em_sz, d_model):
        super().__init__()
        cut = self._CUTOFFS[em_sz]  # looked up first so a bad em_sz fails before building the backbone
        backbone = models.resnet34(True)
        self.base = nn.Sequential(*list(backbone.children())[:cut])
        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        feats = self.base(x)                          # (bs, C, H, W)
        seq = feats.flatten(2, 3).permute(0, 2, 1)    # (bs, H*W, C)
        return self.linear(seq) * 8                   # fixed output scale (reason not documented)

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', attn_heads=8, attn_drop=0.2):
    """Build the transformer encoder-decoder that runs on top of the image features.

    Args:
        vocab: vocabulary size (embedding rows and generator outputs).
        d_model: model width.
        N: number of encoder layers and of decoder layers.
        drops: dropout for residual sublayers, feed-forward and positional encoding.
        attn_type: 'multi' -> MultiHeadedAttention, anything else -> SingleHeadedAttention.
        attn_heads: head count when attn_type == 'multi'.
        attn_drop: dropout on attention weights. This used to be silently fixed at
            the attention classes' default (0.2) whatever `drops` was; it is now
            exposed with the same default, so existing calls are unchanged.
    Returns:
        EncoderDecoder whose parameters with more than one dimension are Xavier-uniform initialised.
    """
    c = deepcopy

    if attn_type == 'multi':
        attn = MultiHeadedAttention(d_model, attn_heads, dropout=attn_drop)
    else:
        attn = SingleHeadedAttention(d_model, dropout=attn_drop)

    ff = PositionwiseFeedForward(d_model, drops)

    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        nn.Sequential(
            Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000)
        ),
        nn.Linear(d_model, vocab),
    )

    for p in model.parameters():
        if p.dim() > 1:  # weight matrices only; biases and norm parameters keep their defaults
            nn.init.xavier_uniform_(p)

    return model

In [None]:
class Img2Seq(nn.Module):
    """Image-to-sequence model: image encoder -> EncoderDecoder transformer -> token logits.

    With ``tgt`` given, returns teacher-forced logits (bs, sl, vocab). With
    ``tgt=None``, greedily decodes up to ``seq_len`` steps without gradients.
    """
    def __init__(self, img_encoder, transformer):
        super(Img2Seq, self).__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        if tgt is not None:
            feats = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt, tgt_mask)   # (bs, sl, d_model)
            return self.transformer.generate(dec_outs)          # (bs, sl, vocab)
        return self._greedy_decode(src, seq_len)

    def _greedy_decode(self, src, seq_len):
        "Start from id 1 and append each step's argmax; stop early once every sequence predicts 0."
        with torch.no_grad():
            memory = self.transformer.encode(self.img_enc(src))
            tgt = torch.ones((src.size(0), 1), dtype=torch.long, device=device)
            step_logits = []
            for _ in progress_bar(range(seq_len)):
                mask = subsequent_mask(tgt.size(-1))
                dec_outs = self.transformer.decode(memory, tgt, mask)
                logits = self.transformer.generate(dec_outs[:, -1])
                step_logits.append(logits)
                next_tok = torch.argmax(logits, dim=-1, keepdim=True)
                if (next_tok == 0).all():
                    break
                tgt = torch.cat([tgt, next_tok], dim=-1)
            return torch.stack(step_logits).transpose(1, 0).contiguous()

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.2, attn_type='multi', attn_heads=8):
    "fastai Learner for Img2Seq: label-smoothed loss, CER metric, teacher forcing; vocab size = len(itos)."
    # The image encoder is built first, as before, so random-number consumption is unchanged.
    encoder = ResnetBase(em_sz, d_model)
    transformer = make_full_model(len(itos), d_model, N=N, drops=drops,
                                  attn_type=attn_type, attn_heads=attn_heads)
    model = Img2Seq(encoder, transformer)
    return Learner(data, model,
                   loss_func=LabelSmoothing(smoothing=0.1),
                   metrics=[CER()],
                   callback_fns=[TeacherForce])

# Experiment - Chars

In [None]:
# Character model: d_model=512, em_sz=512, N=6 layers, dropout 0.1, multi-head attention.
learn = make_learner(data, 512, 512, N=6, drops=0.1, attn_type='multi')

In [None]:
# Warm-start from the previous checkpoint; the trailing `None` keeps the cell from echoing the Learner.
learn.load('combo_512_9')
None

In [None]:
# Learning-rate range test.
learn.lr_find()
learn.recorder.plot(suggestion=True)

In [None]:
# Log below: presumably train loss, valid loss, CER -- then greedy-decode / test loss and CER per checkpoint.
learn.fit_one_cycle(5, max_lr=1e-3, callbacks=[SaveModelCallback(learn, name='combo_512_10')])
#sm
#5cycle,1e-3

# 2.474220	4.081600	0.036482  N:4,F.gelu,sz:256,em_sz:512,single,drop:0  'sm_256_1'
# greedy:    1.36697   .066
# 4.430638	4.402001	0.034730  2nd run, lr:1.5e-5, add tfms, drop:0.2   'sm_256_2'
# greedy:    0.37206   .04145

# 5.158808	4.964441	0.042544  "", w/ tfms, drop:0.2   'sm_256_3'
# greedy:    0.50646   .04103
# 4.266788	4.694098	0.039822  2nd run, lr:1.5e-5   'sm_256_4'
# greedy:    0.38972   .03879

# 3.678168	3.962710	0.035466  N:8,tfms   'sm_256_5'
# greedy:    0.38097   .04405

# 2.840161	3.429177	0.030243  N:4,tfms,multi(8)  'sm_256_6'
# greedy:    0.32775   .04201

# 3.377438	3.786088	0.033931   N:6,tfms,drop:0.1,multi(8)   'sm_256_7'
# greedy:    0.36201   .04266

# combo_cat6 - preload 'sm_256_7'
# 7.570516	6.670641	0.015070    'combo_512_7'
# greedy:    6.29406   .02186

# combo_cat_pg - preload 'combo_512_7'
# 25.674347	21.161726	0.015869    'combo_512_8'
# greedy:    114.878   .03654
#   test:    115.247   .05014

# combo_cat_pg_dl_sorted - preload 'combo_512_8', 5cycle, 1.5e-4
# 9.800321	5.676855	0.006939    'combo_512_9'
# greedy:    34.8757   .01523

#   test:    91.3226   .05102
#   test:    104.685   .05483
#   test:    116.467   .05315

# upload:    115.149   .66801

In [None]:
# learn.save('combo_512_8')

# Experiment - Words

In [None]:
# Word model (same architecture as the character experiment).
# NOTE(review): make_learner sizes the output layer with len(itos) and CER decodes with the
# character itos; for word labels itos must be the word vocab -- confirm before training.
learn = make_learner(data, 512, 512, N=6, drops=0.1, attn_type='multi')

In [None]:
# NOTE(review): 'combo_512_9' is the character checkpoint; loading it needs matching vocab sizes.
learn.load('combo_512_9')
None

In [None]:
learn.lr_find()
learn.recorder.plot(suggestion=True)

In [None]:
learn.fit_one_cycle(15, max_lr=1e-3, callbacks=[SaveModelCallback(learn, name='word_sm_2')])

# word 60k
# sm, 5cycle, 1e-3

# 22.903589	22.401066	0.520125   N:6,em_sz:512,tfms,drop:0.1,multi(8)  'word_sm_1'
# 32.389709	31.828810	0.746749   N:4,em_sz:512,tfms,drop:0.1,single
# 34.492157	32.719650	0.776098   "",em_sz:256

# word 7k (auto-tokenize from data)
# sm, 5cycle, 1e-3
# 16.987526	16.639275	0.421535   N:4,em_sz:512,tfms,drop:0.1,single
# 10.224819	10.781199	0.297889   2nd run
# 6.486265	7.934221	0.225055   3rd run    'word_sm_4'

# word 10k (itos from larger txt files)
# sm, 15cycle, 1e-3
# 3.893518	6.037587	0.209027   N:4,em_sz:512,tfms,drop:0.1,single   'word_sm_2'

# nn.Transformer

In [None]:
class Img2Seq(nn.Module):
    """Image-to-sequence model built on `torch.nn.Transformer`.

    `nn.Transformer` is sequence-first (`(seq, bs, d_model)`), whereas the image-encoder
    output, the target embedding and the returned logits are batch-first, so activations
    are permuted at those boundaries.

    Args:
        vocab: vocabulary size.
        d_model: transformer width.
        em_sz: feature size forwarded to `ResnetBase`.
        seq_len: maximum number of greedy-decoding steps when `tgt` is None.
        **kwargs: forwarded to `nn.Transformer` (e.g. `nhead`, `num_encoder_layers`).
    """
    def __init__(self, vocab, d_model, em_sz, seq_len=500, **kwargs):
        super(Img2Seq, self).__init__()
        # NOTE(review): positional order kept from the original call; the ResnetBase variants
        # defined in later cells take (em_sz, d_model) -- confirm which definition is in scope.
        self.img_enc = ResnetBase(d_model, em_sz)
        self.transformer = nn.Transformer(d_model, **kwargs)
        self.tgt_embed = nn.Sequential(Embeddings(d_model, vocab), PositionalEncoding(d_model))
        self.generator = nn.Linear(d_model, vocab)
        # previously never set, so greedy decoding raised AttributeError
        self.seq_len = seq_len

    def _causal_mask(self, size, dev):
        "Additive (size, size) causal mask in nn.Transformer's convention, on device `dev`."
        return self.transformer.generate_square_subsequent_mask(size).to(dev)

    def _greedy_decode(self, feats):
        "Argmax decoding from start token 1; stops once every sequence predicts token 0."
        bs = feats.size(1)
        with torch.no_grad():
            memory = self.transformer.encoder(feats)
            tgt = torch.ones((bs, 1), dtype=torch.long, device=feats.device)   # batch-first ids

            res = []
            for _ in progress_bar(range(self.seq_len)):
                mask = self._causal_mask(tgt.size(1), feats.device)
                tgt_emb = self.tgt_embed(tgt).permute(1, 0, 2)                  # (t, bs, d_model)
                dec_outs = self.transformer.decoder(tgt_emb, memory, tgt_mask=mask)
                prob = self.generator(dec_outs[-1])                             # last step: (bs, vocab)
                res.append(prob)
                pred = torch.argmax(prob, dim=-1, keepdim=True)                 # (bs, 1)
                if (pred == 0).all():
                    break
                tgt = torch.cat([tgt, pred], dim=-1)
        return torch.stack(res, dim=1)                                          # (bs, steps, vocab)

    def forward(self, src, tgt=None, tgt_mask=None):
        """Teacher-forced logits when `tgt` is given, greedy decoding otherwise.

        Args:
            src: image batch fed to `ResnetBase`.
            tgt: target token ids `(bs, sl)`, or None for inference.
            tgt_mask: accepted for compatibility with the TeacherForce callback but not used.
                nn.Transformer needs its own additive `(sl, sl)` mask, so the causal mask is
                always built here. (It used to be passed positionally into `src_mask`.)

        Returns:
            logits `(bs, sl, vocab)`; in greedy mode `sl` is the number of decoded steps.
        """
        # assumes img_enc returns (bs, n_feats, d_model) -- made sequence-first here
        feats = self.img_enc(src).permute(1, 0, 2)
        if tgt is None:
            return self._greedy_decode(feats)

        mask = self._causal_mask(tgt.size(1), feats.device)
        tgt_emb = self.tgt_embed(tgt).permute(1, 0, 2)               # (sl, bs, d_model)
        dec_outs = self.transformer(feats, tgt_emb, tgt_mask=mask)   # (sl, bs, d_model)
        return self.generator(dec_outs).permute(1, 0, 2)             # (bs, sl, vocab)

In [None]:
def make_learner(data, vocab, d_model, em_sz, **kwargs):
    """Build a fastai `Learner` around the `nn.Transformer`-based `Img2Seq`.

    Args:
        data: the fastai DataBunch to train on.
        vocab: vocabulary size.
        d_model, em_sz: forwarded to `Img2Seq`.
        **kwargs: forwarded to `Img2Seq` (and from there to `nn.Transformer`).

    Returns:
        a `Learner` using label-smoothed loss, the CER metric and the TeacherForce callback.
    """
    net = Img2Seq(vocab, d_model, em_sz, **kwargs)
    # CER gets the notebook-global `itos` vocabulary, as in every other Learner built here
    return Learner(data, net, loss_func=LabelSmoothing(smoothing=0.1),
                   metrics=[CER(itos)], callback_fns=[TeacherForce])

In [None]:
# vocab = len(itos), d_model = em_sz = 512; no extra kwargs, so nn.Transformer uses its defaults.
learn = make_learner(data, len(itos), 512, 512)

In [None]:
# Short 5-epoch one-cycle run at max_lr=1e-3.
# NOTE(review): the results log below is headed "50, 5e-3", i.e. it was recorded with a
# different epoch count / LR than this call uses.
learn.fit_one_cycle(5, max_lr=1e-3)

# 50, 5e-3
# 58.284210  d_model:512, singlehead
# 32.917747  d_model:256
# 36.486694  see below...
# 32.104549  d_model:128
# 33.139687  d_model:256, BN instead of linear
# 41.315910  d_model:256, no BN/linear/*8
# 39.416367  d_model:256, LayerNorm(1e-5)
# 36.486694  d_model:256, LayerNorm(1e-4)  same as above...
# 33.884861  d_model:256, LayerNorm(1e-3)

In [None]:
# learn.save('tmp_overfit_7.45')

# Spatial Encoding (fixed, after flatten - 1d)

In [None]:
class EncoderDecoder(nn.Module):
    """Generic encoder/decoder wrapper.

    The source goes through `src_embed` then `encoder`; the target goes through `tgt_embed`
    and is decoded against the encoder memory; `generator` turns decoder states into logits.
    """
    def __init__(self, encoder, decoder, tgt_embed, src_embed, generator):
        super(EncoderDecoder, self).__init__()
        # names and registration order kept: they define the state_dict keys
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_embed = tgt_embed
        self.src_embed = src_embed
        self.generator = generator

    def encode(self, src):
        "Embed `src` and run it through the encoder."
        embedded_src = self.src_embed(src)
        return self.encoder(embedded_src)

    def decode(self, src, tgt, tgt_mask=None):
        "Decode the embedded `tgt` against the encoder memory `src`."
        embedded_tgt = self.tgt_embed(tgt)
        return self.decoder(embedded_tgt, src, tgt_mask)

    def forward(self, src, tgt, tgt_mask=None):
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask)

    def generate(self, outs):
        "Project decoder outputs through the generator head."
        return self.generator(outs)

In [None]:
class ResnetBase(nn.Module):
    """Pretrained ResNet-34 prefix that turns an image into a sequence of `d_model` vectors.

    `em_sz` selects how many ResNet-34 children are kept (slice -4/-3/-2 for 128/256/512);
    the kept trunk's `em_sz` channels are linearly projected to `d_model`.
    """
    def __init__(self, em_sz, d_model):
        super().__init__()
        cut = {128: -4, 256: -3, 512: -2}[em_sz]
        trunk = list(models.resnet34(True).children())[:cut]
        self.base = nn.Sequential(*trunk)
        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        feats = self.base(x)                            # (bs, em_sz, h, w)
        seq = feats.flatten(2, 3).permute(0, 2, 1)      # (bs, h*w, em_sz)
        # fixed x8 scale on the projected features (kept from the original)
        return self.linear(seq) * 8

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2):
    """Assemble the 1-D encoder/decoder transformer and Xavier-initialise its weight matrices.

    `src_embed` is positional encoding only; `tgt_embed` is token embedding + positional
    encoding; the generator is `nn.Linear(d_model, vocab)`.
    """
    # prototypes deep-copied into each layer; construction order kept so seeded runs reproduce
    attn_proto = MultiHeadedAttention(d_model, 8)   # SingleHeadedAttention(d_model) was the alternative
    ff_proto = PositionwiseFeedForward(d_model, drops)
    pos_enc = PositionalEncoding(d_model, drops, 2000)

    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn_proto), deepcopy(ff_proto), drops), N)
    decoder = Decoder(DecoderLayer(d_model, deepcopy(attn_proto), deepcopy(attn_proto),
                                   deepcopy(ff_proto), drops), N)
    # the same PositionalEncoding instance serves as tgt positional encoding and as src_embed
    model = EncoderDecoder(encoder, decoder,
                           nn.Sequential(Embeddings(d_model, vocab), pos_enc),
                           pos_enc,
                           nn.Linear(d_model, vocab))

    for weight in (p for p in model.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)
    return model

In [None]:
class Img2Seq(nn.Module):
    """Image encoder + `EncoderDecoder` transformer for sequence recognition.

    With `tgt` (teacher forcing) returns logits for every target position. Without it, greedily
    decodes up to `seq_len` steps from start token 1, stopping early once every sequence in the
    batch predicts token 0, and returns the per-step logits `(bs, steps, vocab)`.
    """
    def __init__(self, img_encoder, transformer, seq_len=500):
        super(Img2Seq, self).__init__()
        self.img_enc = img_encoder
        self.transformer = transformer
        self.seq_len = seq_len

    def _greedy(self, src):
        "Autoregressive argmax decoding (caller wraps this in no_grad)."
        memory = self.transformer.encode(self.img_enc(src))
        seq = torch.ones((src.size(0), 1), dtype=torch.long, device=device)
        step_logits = []
        for _ in progress_bar(range(self.seq_len)):
            causal = subsequent_mask(seq.size(-1))
            hidden = self.transformer.decode(memory, seq, causal)
            logits = self.transformer.generate(hidden[:, -1])
            step_logits.append(logits)
            next_tok = logits.argmax(dim=-1, keepdim=True)
            if (next_tok == 0).all():
                break
            seq = torch.cat([seq, next_tok], dim=-1)
        return torch.stack(step_logits).transpose(1, 0).contiguous()

    def forward(self, src, tgt=None, tgt_mask=None):
        if tgt is None:
            with torch.no_grad():
                return self._greedy(src)
        hidden = self.transformer(self.img_enc(src), tgt, tgt_mask)   # ([bs, sl, d_model])
        return self.transformer.generate(hidden)                      # ([bs, sl, vocab])

In [None]:
# Build the 1-D spatial-encoding model: ResNet features (em_sz=256) projected to d_model=512.
d_model = 512
em_sz = 256
img_encoder = ResnetBase(em_sz, d_model)
transformer = make_full_model(len(itos), d_model)  #len(itos)
net = Img2Seq(img_encoder, transformer, seq_len)  # NOTE: seq_len is a global defined outside this cell

# AdamW16 = partial(optim.Adam, betas=(0.9,0.99), eps=1e-4)  ->  #.to_fp16(max_scale=256)
# partial: way to always call a function with a given set of arguments or keywords

learn = Learner(data, net, loss_func=LabelSmoothing(smoothing=0.1),
                metrics=[CER(itos)], callback_fns=TeacherForce)
learn.clip_grad(0.25)  # gradient clipping at 0.25
None  # suppress notebook output

# Complex Batches (lines x char_len)

In [None]:
def rshift(tgt, bos_token=1):
    """Shift `tgt` one step right along its last dimension, prepending `bos_token`.

    Position t of the result holds token t-1 of `tgt`; the last token is dropped.
    Works for any rank (e.g. `(bs, sl)` or `(bs, lines, sl)`). The result has the same shape,
    dtype and device as `tgt`, so no global `device` is needed.

    Args:
        tgt: integer token tensor.
        bos_token: value written into the first position.
    """
    bos = tgt.new_full((*tgt.shape[:-1], 1), bos_token)
    return torch.cat((bos, tgt[..., :-1]), dim=-1)

In [None]:
class PositionalEncoding(nn.Module):
    """2-D sinusoidal positional encoding, added to inputs laid out as `(bs, h, w, d_model)`.

    The first `d_model//2` channels carry a sin/cos encoding of the position along dim 1 (h);
    the last `d_model//2` channels carry the same encoding of the position along dim 2 (w).
    Dropout is applied after the addition.

    Requires `d_model % 4 == 0`: `d_model//2` must be even for the sin/cos column split, and
    `2*(d_model//2)` must equal `d_model` for the broadcast addition in `forward`.

    NOTE(review): the precomputed buffer has shape `(1, max_len, max_len, 2*(d_model//2))`,
    i.e. max_len**2 * d_model float32 values (~8 GB for max_len=2000, d_model=512). As a
    registered buffer it is also stored in every saved checkpoint. Building the two 1-D tables
    and broadcasting them in `forward` would avoid this, but it changes the persisted `pe`
    buffer, so existing checkpoints would need migrating.
    """
    def __init__(self, d_model, dropout=0.1, max_len=2000):
        super(PositionalEncoding, self).__init__()
        
        self.dropout = nn.Dropout(p=dropout)
        
        # standard 1-D sin/cos table over `channels` features
        channels = d_model//2 #dims
        pe = torch.zeros(max_len, channels)
        position = torch.arange(0.0, max_len).unsqueeze(1)
        log_increment = math.log(1e4) / channels
        div_term = torch.exp(torch.arange(0.0, channels, 2) * -log_increment)  
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe.unsqueeze_(0)  # (1, max_len, channels)
        
        # NOTE: despite the names, `w_pe` varies along dim 1 (h) and fills the first half of the
        # channels, while `h_pe` varies along dim 2 (w) and fills the second half
        w_pe = F.pad(pe, (0,channels)).unsqueeze(2)  # (1, max_len, 1, 2*channels)
        h_pe = F.pad(pe, (channels,0)).unsqueeze(1)  # (1, 1, max_len, 2*channels)
        pe = w_pe + h_pe  # broadcast -> (1, max_len, max_len, 2*channels)

        self.register_buffer('pe', pe)    #(1,max_len,max_len,d_model)
        
    def forward(self, x):
        # ([bs, h, w, d_model]) -- add the (h, w) slice of the table, then dropout
        x = x + self.pe[:, :x.size(1), :x.size(2)]   #addition
#         x = torch.cat([x, self.pe[:, :x.size(1), :x.size(2)]])  #concatenation
        return self.dropout(x)

In [None]:
class EncoderDecoder(nn.Module):
    """Encoder/decoder wrapper whose encoder consumes `src` directly (no source embedding).

    `tgt_embed` embeds target tokens before decoding against the encoder memory;
    `generator` maps decoder states to logits.
    """
    def __init__(self, encoder, decoder, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        # attribute names / order define the state_dict keys -- keep them
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, tgt_mask=None):
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask=tgt_mask)

    def encode(self, src):
        "Run the encoder on the raw `src` features."
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        "Embed `tgt` and decode it against encoder memory `src`."
        tgt_vectors = self.tgt_embed(tgt)
        return self.decoder(tgt_vectors, src, tgt_mask)

    def generate(self, outs):
        "Apply the generator head to decoder outputs."
        return self.generator(outs)

In [None]:
class Encoder(nn.Module):
    "Stack of `N` clones of `layer`, followed by a BatchNorm2d over `layer.size` channels."
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = nn.BatchNorm2d(layer.size)

    def forward(self, x):
        out = x
        for enc_layer in self.layers:
            out = enc_layer(out)
        return self.norm(out)

In [None]:
class EncoderLayer(nn.Module):
    """Pre-norm encoder layer for 2-D feature maps `(bs, d_model, h, w)`.

    Sub-layer 1: BatchNorm -> ImageSelfAttention -> dropout, added residually.
    Sub-layer 2: BatchNorm -> 1x1 conv (d_model -> 4*d_model) -> ReLU -> dropout ->
    1x1 conv (back to d_model) -> dropout, added residually.

    Args:
        d_model: number of channels.
        dropout: dropout probability used throughout the layer.
    """
    def __init__(self, d_model, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.size = d_model
        self.self_attn = ImageSelfAttention(d_model, dropout=dropout)
        
        self.conv1 = nn.Conv2d(d_model, d_model*4, 1)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(d_model*4, d_model, 1)

        self.norm1 = nn.BatchNorm2d(d_model)
        self.norm2 = nn.BatchNorm2d(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        
    def forward(self, x):
        x2 = self.norm1(x)
        x2 = self.self_attn(x2, x2, x2)
        x = x + self.dropout1(x2)
        
        x2 = self.conv2(self.dropout(F.relu(self.conv1(self.norm2(x)))))
        # bug fix: the residual must add the feed-forward output `x2`. The previous
        # `x + self.dropout2(x)` discarded the feed-forward branch entirely.
        # NOTE(review): checkpoints trained before this fix never trained conv1/conv2.
        x = x + self.dropout2(x2)
        return x

In [None]:
class DecoderLayer(nn.Module):
    """Pre-norm decoder layer: masked self-attention, attention into image rows, feed-forward.

    Each sub-layer is LayerNorm -> sub-layer -> dropout, added residually; the feed-forward
    widens to 4*d_model with ReLU and inner dropout.
    """
    def __init__(self, d_model, dropout):
        super(DecoderLayer, self).__init__()
        # creation order kept: it fixes parameter order (init RNG order, optimizer groups)
        self.size = d_model
        self.self_attn = SingleHeadedAttention(d_model)
        self.src_attn = SingleHeadedAttention(d_model)

        self.linear1 = nn.Linear(d_model, d_model*4)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_model*4, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, src, tgt_mask=None):
        # 1) masked self-attention over the target sequence
        normed = self.norm1(tgt)
        tgt = tgt + self.dropout1(self.self_attn(normed, normed, normed, tgt_mask))

        # 2) attention into the encoder features, moved channels-last: bs,h,w,d_model
        memory = src.permute(0, 2, 3, 1)
        tgt = tgt + self.dropout2(self.src_attn(self.norm2(tgt), memory, memory))

        # 3) position-wise feed-forward
        hidden = F.relu(self.linear1(self.norm3(tgt)))
        tgt = tgt + self.dropout3(self.linear2(self.dropout(hidden)))
        return tgt

In [None]:
class ResnetBase(nn.Module):
    """ResNet-34 prefix producing a `(bs, d_model, num_lines, w)` feature map.

    A 1x1 conv maps the trunk's `em_sz` channels to `d_model`; adaptive max-pooling then
    reduces the height to `num_lines` rows and leaves the width as is.
    """
    def __init__(self, em_sz, d_model, num_lines):
        super().__init__()
        cut = {128: -4, 256: -3, 512: -2}[em_sz]
        self.base = nn.Sequential(*list(models.resnet34(True).children())[:cut])
        self.conv = nn.Conv2d(em_sz, d_model, 1)
        self.pool = nn.AdaptiveMaxPool2d((num_lines, None))

    def forward(self, x):
        feats = self.base(x)
        feats = self.conv(feats)
        return self.pool(feats)

In [None]:
class ImageSelfAttention(nn.Module):
    """Attention over 2-D feature maps with 1x1 convolutions as projections.

    `convs[0..2]` project query/key/value and `convs[3]` projects the result. The 4-D conv
    outputs go straight into `attention` (defined elsewhere); which axes it attends over
    depends on that helper -- TODO confirm. The last attention weights are kept in `self.attn`.
    """
    def __init__(self, d_model, dropout=0.2):
        super(ImageSelfAttention, self).__init__()
        self.convs = clones(nn.Conv2d(d_model, d_model, 1), 4)
        self.attn = None
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        q = self.convs[0](query)
        k = self.convs[1](key)
        v = self.convs[2](value)
        attended, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)
        return self.convs[3](attended)

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2):
    """Assemble the multi-line ('complex batch') transformer and Xavier-initialise it.

    Targets are embedded with `Embeddings` + the 2-D `PositionalEncoding` (max_len 2000);
    `nn.Linear(d_model, vocab)` is the generator. Every parameter with more than one
    dimension gets `xavier_uniform_`.
    """
    encoder = Encoder(EncoderLayer(d_model, drops), N)
    decoder = Decoder(DecoderLayer(d_model, drops), N)
    tgt_embed = nn.Sequential(Embeddings(d_model, vocab),
                              PositionalEncoding(d_model, drops, 2000))
    generator = nn.Linear(d_model, vocab)
    model = EncoderDecoder(encoder, decoder, tgt_embed, generator)

    matrices = [p for p in model.parameters() if p.dim() > 1]
    for w in matrices:
        nn.init.xavier_uniform_(w)
    return model

In [None]:
class Img2Seq(nn.Module):
    """Multi-line image-to-sequence model: decodes `num_lines` text lines per image in parallel.

    Training (`tgt` given, `(bs, lines, sl)`) returns teacher-forced logits
    `(bs, lines, sl, vocab)`. Inference (`tgt` None) greedily decodes up to `seq_len` steps for
    every line, starting from token 1. It stops early once every line of every image predicts
    token 0, and returns logits in the same `(bs, lines, steps, vocab)` layout.

    Args:
        img_encoder: module mapping images to encoder input features.
        transformer: EncoderDecoder-style module with encode/decode/generate.
        num_lines: number of text lines decoded per image.
        seq_len: maximum greedy-decoding steps.
    """
    def __init__(self, img_encoder, transformer, num_lines, seq_len):
        super(Img2Seq, self).__init__()
        self.img_enc = img_encoder
        self.transformer = transformer
        self.num_lines = num_lines
        self.seq_len = seq_len
        
    def forward(self, src, tgt=None, tgt_mask=None):
        # inference (greedy decode)
        if tgt is None:
            with torch.no_grad():
                feats = self.transformer.encode(self.img_enc(src))
                bs = src.size(0)
                tgt = torch.ones((bs, self.num_lines, 1), dtype=torch.long, device=src.device)

                res = []
                for i in progress_bar(range(self.seq_len)):
                    mask = subsequent_mask(tgt.size(-1))
                    dec_outs = self.transformer.decode(feats, tgt, mask)   # (bs, lines, t, d_model)
                    prob = self.transformer.generate(dec_outs[:, :, -1])   # (bs, lines, vocab)
                    res.append(prob)
                    pred = torch.argmax(prob, dim=-1, keepdim=True)        # (bs, lines, 1)
                    if (pred == 0).all():
                        break
                    tgt = torch.cat([tgt, pred], dim=-1)
                # stack along the step axis -> (bs, lines, steps, vocab), the training layout.
                # (stack().transpose(1, 0) gave (bs, steps, lines, vocab).)
                out = torch.stack(res, dim=2)

        # training
        else:
            feats = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt, tgt_mask)    # (bs, lines, sl, d_model)
            out = self.transformer.generate(dec_outs)            # (bs, lines, sl, vocab)
        return out

In [None]:
# Multi-line ("complex batch") model: ResNet trunk pooled to `num_lines` rows + 2-D transformer.
# NOTE(review): `num_lines` and `seq_len` are globals defined outside this cell.
# NOTE(review): LabelSmoothing.loss_prep (top of the notebook) unpacks 3-D predictions
# (`_ ,sl,vocab = pred.shape`), but this model emits 4-D (bs, lines, sl, vocab) logits.
# Confirm the 4-D `loss_prep` from the next cell is actually wired into the loss.
d_model = 512
em_sz = 256
img_encoder = ResnetBase(em_sz, d_model, num_lines)
transformer = make_full_model(len(itos), d_model)
net = Img2Seq(img_encoder, transformer, num_lines, seq_len)

learn = Learner(data, net, loss_func=LabelSmoothing(smoothing=0.1), metrics=[CER(itos)], callback_fns=TeacherForce)
learn.clip_grad(0.25)  # gradient clipping at 0.25
None  # suppress notebook output

In [None]:
def loss_prep(preds, target):
    """Align multi-line predictions with targets and flatten both for a token-level loss.

    Whichever side is shorter along the line or sequence axis is zero-padded to match.
    Padded targets become token 0 and padded predictions become all-zero logits (not ideal,
    since they still contribute to the loss).

    Args:
        preds: logits shaped `(bs, lines, sl, vocab)`.
        target: token ids shaped `(bs, tlines, tsl)`.

    Returns:
        `(pred, targ)`: `pred` of shape `(N, vocab)` and `targ` a flat long tensor of length
        `N`, where `N = bs * max(lines, tlines) * max(sl, tsl)`.
    """
    _, tlines, tsl = target.shape
    _, lines, sl, vocab = preds.shape

    # F.pad takes (front, back) pairs starting from the LAST dimension
    if sl > tsl:        # target sequence axis (dim 2)
        target = F.pad(target, (0, sl - tsl))
    if lines > tlines:  # target line axis (dim 1)
        target = F.pad(target, (0, 0, 0, lines - tlines))
    if tsl > sl:        # prediction sequence axis (dim 2)
        preds = F.pad(preds, (0, 0, 0, tsl - sl))
    if tlines > lines:  # prediction line axis (dim 1)
        # bug fix: this used to pad `target` with a 6-tuple, which grew its *batch* axis and
        # left pred/target lengths mismatched; the predictions are what need padding here
        preds = F.pad(preds, (0, 0, 0, 0, 0, tlines - lines))

    targ = target.contiguous().view(-1).long()
    pred = preds.contiguous().view(-1, vocab)
    return pred, targ

In [None]:
# learn.fit_one_cycle(3, 1e-3)
#  18.494501	15.946078	0.178671	03:18  3cycle,1e-3
#  11.922337	11.197309	0.118325	03:18  2nd run

In [None]:
# Restore the multi-line model checkpoint (uncomment the save line to write it instead).
# learn.save('complex_batches_128')
learn.load('complex_batches_128')
None  # suppress notebook output

# Different size Convs

In [None]:
class EncoderDecoder(nn.Module):
    """Encoder/decoder container: raw `src` -> encoder; embedded `tgt` -> decoder; generator -> logits."""
    def __init__(self, encoder, decoder, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        # keep names/order: they are the state_dict keys of saved checkpoints
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_embed = tgt_embed
        self.generator = generator

    def encode(self, src):
        "Encoder pass over the source features."
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        "Decoder pass: embedded `tgt` attending to memory `src`."
        return self.decoder(self.tgt_embed(tgt), src, tgt_mask)

    def forward(self, src, tgt, tgt_mask=None):
        encoded = self.encode(src)
        decoded = self.decode(encoded, tgt, tgt_mask)
        return decoded

    def generate(self, outs):
        "Logits from decoder outputs."
        return self.generator(outs)

In [None]:
class ResnetBase(nn.Module):
    """ResNet-34 prefix with its stem max-pool removed, flattened to `(bs, h*w, d_model)`.

    Child index 3 of the trunk (the max-pool) is replaced by an identity, so the feature map
    is not downsampled by that layer. Other variants tried (stride-(2,1) stem conv / pool /
    layer2 first conv) are not used.
    """
    def __init__(self, em_sz, d_model):
        super().__init__()
        cut = {128: -4, 256: -3, 512: -2}[em_sz]
        self.base = nn.Sequential(*list(models.resnet34(True).children())[:cut])
        self.base[3] = Lambda(lambda x: x)   # drop the max-pool
        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        seq = self.base(x).flatten(2, 3).permute(0, 2, 1)   # (bs, h*w, em_sz)
        return self.linear(seq) * 8

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2):
    """Assemble the 1-D encoder/decoder transformer with `MultiHeadedAttention(d_model, 4)`.

    Targets are embedded with `Embeddings` + `PositionalEncoding` (max_len 2000), the
    generator is `nn.Linear(d_model, vocab)`, and every weight matrix is Xavier-initialised.
    """
    # prototypes deep-copied into each layer (SingleHeadedAttention was the alternative)
    attn_proto = MultiHeadedAttention(d_model, 4)
    ff_proto = PositionwiseFeedForward(d_model, drops)

    enc_layer = EncoderLayer(d_model, deepcopy(attn_proto), deepcopy(ff_proto), drops)
    encoder = Encoder(enc_layer, N)
    dec_layer = DecoderLayer(d_model, deepcopy(attn_proto), deepcopy(attn_proto),
                             deepcopy(ff_proto), drops)
    decoder = Decoder(dec_layer, N)
    tgt_embed = nn.Sequential(Embeddings(d_model, vocab),
                              PositionalEncoding(d_model, drops, 2000))
    model = EncoderDecoder(encoder, decoder, tgt_embed, nn.Linear(d_model, vocab))

    for weight in model.parameters():
        if weight.dim() <= 1:
            continue
        nn.init.xavier_uniform_(weight)
    return model

In [None]:
class Img2Seq(nn.Module):
    """Image encoder + transformer wrapper with teacher-forced training and greedy inference.

    Inference starts every sequence at token 1, runs at most `seq_len` argmax steps, stops
    early once the whole batch predicts token 0, and returns `(bs, steps, vocab)` logits.
    """
    def __init__(self, img_encoder, transformer, seq_len=500):
        super(Img2Seq, self).__init__()
        self.img_enc = img_encoder
        self.transformer = transformer
        self.seq_len = seq_len

    def forward(self, src, tgt=None, tgt_mask=None):
        if tgt is not None:
            # training: teacher forcing over the whole target
            hidden = self.transformer(self.img_enc(src), tgt, tgt_mask)   # ([bs, sl, d_model])
            return self.transformer.generate(hidden)                      # ([bs, sl, vocab])

        with torch.no_grad():
            memory = self.transformer.encode(self.img_enc(src))
            tokens = torch.ones((src.size(0), 1), dtype=torch.long, device=device)
            logits_per_step = []
            for _ in progress_bar(range(self.seq_len)):
                hidden = self.transformer.decode(memory, tokens, subsequent_mask(tokens.size(-1)))
                step = self.transformer.generate(hidden[:, -1])
                logits_per_step.append(step)
                best = torch.argmax(step, dim=-1, keepdim=True)
                if (best == 0).all():
                    break
                tokens = torch.cat([tokens, best], dim=-1)
            return torch.stack(logits_per_step).transpose(1, 0).contiguous()

In [None]:
# Build the model with the max-pool-free ResnetBase (em_sz=512, i.e. trunk slice -2) and d_model=512.
d_model = 512
em_sz = 512 #256
# img_encoder = ResnetBase(em_sz, d_model)
# transformer = make_full_model(len(itos), d_model)  #len(itos)
net = Img2Seq(ResnetBase(em_sz, d_model), make_full_model(len(itos), d_model), seq_len)  # seq_len: global from an earlier cell

# AdamW16 = partial(optim.Adam, betas=(0.9,0.99), eps=1e-4)  ->  #.to_fp16(max_scale=256)
# partial: way to always call a function with a given set of arguments or keywords

learn = Learner(data, net, loss_func=LabelSmoothing(smoothing=0.1),
                metrics=[CER(itos)], callback_fns=TeacherForce)
learn.clip_grad(0.25)  # gradient clipping at 0.25
None  # suppress notebook output

# Deformable Conv

In [None]:
# https://github.com/oeway/pytorch-deform-conv/blob/master/torch_deform_conv

from scipy.ndimage.interpolation import map_coordinates as sp_map_coordinates


def th_flatten(a):
    """Return `a` as a contiguous 1-D tensor holding all `a.nelement()` values in row-major order."""
    contiguous = a.contiguous()
    return contiguous.view(contiguous.nelement())


def th_repeat(a, repeats, axis=0):
    """Repeat every element of the 1-D tensor `a` `repeats` times, like `np.repeat` on 1-D input.

    `axis` is accepted for signature compatibility only and is ignored.
    """
    assert len(a.size()) == 1
    columns = a.repeat(repeats, 1).t().contiguous()   # (n, repeats): row i holds a[i] repeated
    return columns.view(columns.nelement())


def np_repeat_2d(a, repeats):
    """Stack `repeats` copies of the 2-D numpy array `a` along a new leading axis.

    Returns an array of shape `(repeats, *a.shape)`.
    """
    assert len(a.shape) == 2
    return np.repeat(a[np.newaxis, :, :], repeats, axis=0)


def th_gather_2d(input, coords):
    """Return `input[coords[i, 0], coords[i, 1]]` for every row of the integer `(n, 2)` tensor `coords`.

    The result is a 1-D tensor of length `n`.
    """
    row_major_idx = coords[:, 0] * input.size(1) + coords[:, 1]
    picked = th_flatten(input).index_select(0, row_major_idx)
    return picked.view(coords.size(0))


def th_map_coordinates(input, coords, order=1):
    """Bilinearly sample one 2-D tensor at fractional (row, col) coordinates.

    Torch counterpart of `scipy.ndimage.map_coordinates(..., order=1)` for a single 2-D map.
    Note that `coords` is transposed relative to scipy: `(n_points, 2)` instead of
    `(2, n_points)`.

    Args:
        input: tensor `(h, w)`; need not be square.
        coords: float tensor `(n_points, 2)` of (row, col) positions. Rows are clamped to
            `[0, h-1]` and columns to `[0, w-1]` (previously both axes used `h`).
        order: interpolation order; only 1 (bilinear) is supported.

    Returns:
        1-D tensor of `n_points` interpolated values.
    """
    assert order == 1
    height, width = input.size(0), input.size(1)

    coords = torch.stack([coords[:, 0].clamp(0, height - 1),
                          coords[:, 1].clamp(0, width - 1)], 1)
    coords_lt = coords.floor().long()
    coords_rb = coords.ceil().long()
    coords_lb = torch.stack([coords_lt[:, 0], coords_rb[:, 1]], 1)
    coords_rt = torch.stack([coords_rb[:, 0], coords_lt[:, 1]], 1)

    flat = input.contiguous().view(-1)

    def gather(c):
        # row-major flat index of each integer (row, col) pair
        return flat.index_select(0, c[:, 0] * width + c[:, 1])

    vals_lt = gather(coords_lt)
    vals_rb = gather(coords_rb)
    vals_lb = gather(coords_lb)
    vals_rt = gather(coords_rt)

    coords_offset_lt = coords - coords_lt.to(coords.dtype)

    # interpolate along rows at both columns, then along columns
    vals_t = vals_lt + (vals_rt - vals_lt) * coords_offset_lt[:, 0]
    vals_b = vals_lb + (vals_rb - vals_lb) * coords_offset_lt[:, 0]
    mapped_vals = vals_t + (vals_b - vals_t) * coords_offset_lt[:, 1]
    return mapped_vals


def sp_batch_map_coordinates(inputs, coords):
    """Reference (scipy) batched bilinear sampling, used to check the torch implementations.

    Args:
        inputs: array `(b, h, w)`.
        coords: array `(b, n_points, 2)` of (row, col) positions; rows are clipped to
            `[0, h-1]` and columns to `[0, w-1]`.

    Returns:
        array `(b, n_points)` of interpolated values.
    """
    assert (coords.shape[2] == 2)
    height = coords[:, :, 0].clip(0, inputs.shape[1] - 1)
    width = coords[:, :, 1].clip(0, inputs.shape[2] - 1)
    # bug fix: the clipped coordinates were built but the result was discarded
    coords = np.stack([height, width], axis=2)

    mapped_vals = np.array([
        sp_map_coordinates(img, coord.T, mode='nearest', order=1)
        for img, coord in zip(inputs, coords)
    ])
    return mapped_vals


def th_batch_map_coordinates(input, coords, order=1):
    """Batched bilinear sampling of 2-D maps (torch counterpart of `sp_batch_map_coordinates`).

    Args:
        input: tensor `(b, h, w)`.
        coords: float tensor `(b, n_points, 2)` of (row, col) positions. Rows are clamped to
            `[0, h-1]` and columns to `[0, w-1]`.
        order: unused; only bilinear interpolation is implemented.

    Returns:
        tensor `(b, n_points)` of interpolated values.
    """
    batch_size = input.size(0)
    input_height = input.size(1)
    input_width = input.size(2)
    n_coords = coords.size(1)

    coords = torch.cat((torch.clamp(coords.narrow(2, 0, 1), 0, input_height - 1),
                        torch.clamp(coords.narrow(2, 1, 1), 0, input_width - 1)), 2)

    coords_lt = coords.floor().long()
    coords_rb = coords.ceil().long()
    coords_lb = torch.stack([coords_lt[..., 0], coords_rb[..., 1]], 2)
    coords_rt = torch.stack([coords_rb[..., 0], coords_lt[..., 1]], 2)

    # start offset of each batch item in the flattened input, created on input's own device
    # (the old `.cuda()` used the *current* CUDA device, which breaks on multi-GPU setups)
    batch_offset = torch.arange(batch_size, device=input.device).view(-1, 1) * (input_height * input_width)
    flat_input = input.contiguous().view(-1)

    def _get_vals_by_coords(c):
        # flat index = batch * h * w + row * w + col, shape (b, n_coords)
        inds = batch_offset + c[..., 0] * input_width + c[..., 1]
        return flat_input.index_select(0, inds.reshape(-1)).view(batch_size, n_coords)

    vals_lt = _get_vals_by_coords(coords_lt)
    vals_rb = _get_vals_by_coords(coords_rb)
    vals_lb = _get_vals_by_coords(coords_lb)
    vals_rt = _get_vals_by_coords(coords_rt)

    coords_offset_lt = coords - coords_lt.to(coords.dtype)
    vals_t = coords_offset_lt[..., 0] * (vals_rt - vals_lt) + vals_lt
    vals_b = coords_offset_lt[..., 0] * (vals_rb - vals_lb) + vals_lb
    mapped_vals = coords_offset_lt[..., 1] * (vals_b - vals_t) + vals_t
    return mapped_vals


def sp_batch_map_offsets(input, offsets):
    """Reference (scipy) implementation of `th_batch_map_offsets`.

    Each pixel (row, col) of `input` (shape `(b, h, w)`) is displaced by its entry in
    `offsets` (reshaped to `(b, h*w, 2)`), and the map is bilinearly resampled there.
    Returns an array `(b, h*w)`.
    """
    b, h, w = input.shape[0], input.shape[1], input.shape[2]
    displacements = offsets.reshape(b, -1, 2)
    base_grid = np.stack(np.mgrid[:h, :w], -1).reshape(-1, 2)   # (h*w, 2) row-major (row, col)
    tiled_grid = np.repeat([base_grid], b, axis=0)
    return sp_batch_map_coordinates(input, displacements + tiled_grid)


def th_generate_grid(batch_size, input_height, input_width, dtype, cuda):
    """Build the identity sampling grid `(batch_size, h*w, 2)` of row-major (row, col) pairs.

    `dtype` is a tensor type string (e.g. `x.data.type()`); when `cuda` is true the grid is
    moved to the GPU.
    """
    rows, cols = np.meshgrid(range(input_height), range(input_width), indexing='ij')
    pairs = np.stack([rows, cols], axis=-1).reshape(-1, 2)
    grid = torch.from_numpy(np_repeat_2d(pairs, batch_size)).type(dtype)
    return grid.cuda() if cuda else grid


def th_batch_map_offsets(input, offsets, grid=None, order=1):
    """Resample each map of `input` at its own pixel positions shifted by `offsets`.

    Args:
        input: tensor `(b, h, w)`.
        offsets: tensor viewable as `(b, h*w, 2)` (e.g. `(b, h, w, 2)`) of (row, col) shifts.
        grid: optional identity grid from `th_generate_grid`; built here when None.
        order: unused.

    Returns:
        tensor `(b, h*w)` of bilinearly interpolated values.
    """
    b, h, w = input.size(0), input.size(1), input.size(2)
    shifts = offsets.view(b, -1, 2)
    if grid is None:
        grid = th_generate_grid(b, h, w, shifts.data.type(), shifts.data.is_cuda)
    return th_batch_map_coordinates(input, shifts + grid)

In [None]:
class ConvOffset2D(nn.Conv2d):
    """Deformable-sampling layer (adapted from oeway/pytorch-deform-conv).

    A 3x3 conv without bias predicts 2 offsets per input channel and pixel. The input feature
    map is then bilinearly resampled at the shifted positions. No convolution is applied to
    the resampled map, so follow this layer with a regular conv.

    NOTE(review): `_to_bc_h_w_2` reinterprets the `(b, 2c, h, w)` conv output with a plain
    `view` to `(b*c, h, w, 2)`. Each offset pair is therefore taken from adjacent memory
    positions (neighbouring columns of one conv channel), not from two dedicated channels.
    This matches upstream, and the offsets are learned, but it is not the layout the shape
    names suggest.
    """
    def __init__(self, filters, init_normal_stddev=0.01, **kwargs):
        """
        Args:
            filters: number of channels of the input feature map.
            init_normal_stddev: std of the normal init for the offset-conv weights.
            **kwargs: forwarded to `nn.Conv2d`.
        """
        self.filters = filters
        self._grid_param = None
        super(ConvOffset2D, self).__init__(self.filters, self.filters*2, 3, padding=1, bias=False, **kwargs)
        self._grid = None  # cached sampling grid matching `_grid_param`
        self.weight.data.copy_(self._init_weights(self.weight, init_normal_stddev))

    def forward(self, x):
        """Return the deformed feature map, same shape as `x` `(b, c, h, w)`."""
        x_shape = x.size()
        offsets = super(ConvOffset2D, self).forward(x)

        offsets = self._to_bc_h_w_2(offsets, x_shape)   # (b*c, h, w, 2)
        x = self._to_bc_h_w(x, x_shape)                 # (b*c, h, w)

        # (b*c, h*w)
        x_offset = th_batch_map_offsets(x, offsets, grid=self._get_grid(x))

        return self._to_b_c_h_w(x_offset, x_shape)     # (b, c, h, w)

    def _get_grid(self, x):
        """Identity sampling grid for `x` `(b*c, h, w)`, rebuilt only when size/type/device-kind change."""
        batch_size, input_height, input_width = x.size(0), x.size(1), x.size(2)
        dtype, cuda = x.data.type(), x.data.is_cuda
        key = (batch_size, input_height, input_width, dtype, cuda)
        if self._grid_param != key:
            self._grid_param = key
            self._grid = th_generate_grid(batch_size, input_height, input_width, dtype, cuda)
        return self._grid

    @staticmethod
    def _init_weights(weights, std):
        """Normal(0, std) values shaped like `weights` (drawn with numpy, returned as a tensor)."""
        fan_out = weights.size(0)
        fan_in = weights.size(1) * weights.size(2) * weights.size(3)
        w = np.random.normal(0.0, std, (fan_out, fan_in))
        return torch.from_numpy(w.reshape(weights.size()))

    @staticmethod
    def _to_bc_h_w_2(x, x_shape):
        """(b, 2c, h, w) -> (b*c, h, w, 2) by plain view (see class note)."""
        x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]), 2)
        return x

    @staticmethod
    def _to_bc_h_w(x, x_shape):
        """(b, c, h, w) -> (b*c, h, w)"""
        x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))
        return x

    @staticmethod
    def _to_b_c_h_w(x, x_shape):
        """(b*c, h, w) or (b*c, h*w) -> (b, c, h, w)"""
        x = x.contiguous().view(-1, int(x_shape[1]), int(x_shape[2]), int(x_shape[3]))
        return x

In [None]:
class ResnetBase(nn.Module):
    """ResNet-34 prefix with deformable sampling before selected 3x3 convs, output as a sequence.

    The `conv2` of blocks base[4][2] (64 ch), base[5][3] (128 ch) and base[6][5] (256 ch)
    becomes `ConvOffset2D(c)` followed by a new 3x3 conv. Blocks whose stage was cut off by
    `em_sz` are skipped; previously em_sz=128 raised IndexError on `self.base[6]`.

    NOTE(review): the replacement nn.Conv2d layers are freshly initialised and use the
    default bias=True, so the pretrained weights of the conv2 layers they replace are discarded.
    """
    def __init__(self, em_sz, d_model):
        super().__init__()

        slices = {128: -4, 256: -3, 512: -2}
        s = slices[em_sz]

        net = models.resnet34(True)
        modules = list(net.children())[:s]
        self.base = nn.Sequential(*modules)

        # (stage index in self.base, block index, channels); order kept for RNG reproducibility
        for stage, block, ch in ((4, 2, 64), (5, 3, 128), (6, 5, 256)):
            if stage < len(self.base):
                self.base[stage][block].conv2 = nn.Sequential(
                    ConvOffset2D(ch), nn.Conv2d(ch, ch, 3, stride=1, padding=1))

        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        x = self.base(x)
        x = x.flatten(2, 3).permute(0, 2, 1)   # (bs, h*w, em_sz)
        x = self.linear(x) * 8
        return x

In [None]:
# Build the deformable-conv model (em_sz=256, d_model=512).
d_model = 512
em_sz = 256
# img_encoder = ResnetBase(em_sz, d_model)
# transformer = make_full_model(len(itos), d_model)  #len(itos)
net = Img2Seq(ResnetBase(em_sz, d_model), make_full_model(len(itos), d_model), seq_len)  # seq_len: global from an earlier cell

# AdamW16 = partial(optim.Adam, betas=(0.9,0.99), eps=1e-4)  ->  #.to_fp16(max_scale=256)
# partial: way to always call a function with a given set of arguments or keywords

learn = Learner(data, net, loss_func=LabelSmoothing(smoothing=0.1),
                metrics=[CER(itos)], callback_fns=TeacherForce)
learn.clip_grad(0.25)  # gradient clipping at 0.25
None  # suppress notebook output

# TPS

In [None]:
# https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/transformation.py

class TPS_SpatialTransformerNetwork(nn.Module):
    """Rectification network of RARE: a thin-plate-spline spatial transformer.

    The localization network predicts `F` fiducial points per image. The grid generator turns
    them into a sampling grid of size `I_r_size`, and the input is resampled with `grid_sample`
    (border padding).

    Args:
        I_size: (height, width) of the input image (stored; not used by this class).
        I_r_size: (height, width) of the rectified output image.
        I_channel_num: number of input channels.
        F: number of fiducial points.

    Input `(bs, I_channel_num, I_height, I_width)` -> output
    `(bs, I_channel_num, I_r_height, I_r_width)`.
    """

    def __init__(self, I_size, I_r_size, I_channel_num=1, F=20):
        super(TPS_SpatialTransformerNetwork, self).__init__()
        self.F = F
        self.I_size = I_size
        self.I_r_size = I_r_size  # = (I_r_height, I_r_width)
        self.I_channel_num = I_channel_num
        self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
        self.GridGenerator = GridGenerator(self.F, self.I_r_size)

    def forward(self, batch_I):
        ctrl_pts = self.LocalizationNetwork(batch_I)              # bs x K x 2
        sample_pts = self.GridGenerator.build_P_prime(ctrl_pts)  # bs x (I_r_width * I_r_height) x 2
        out_h, out_w = self.I_r_size[0], self.I_r_size[1]
        sampling_grid = sample_pts.reshape([sample_pts.size(0), out_h, out_w, 2])
        # `F` here is the module-level torch.nn.functional, not the fiducial count
        return F.grid_sample(batch_I, sampling_grid, padding_mode='border')


class LocalizationNetwork(nn.Module):
    """Localization network of RARE: regresses the F fiducial points C' (F x 2) from an image.

    Four conv-BN-ReLU stages (64/128/256/512 channels, a 2x2 max-pool after each
    of the first three) and a global average pool feed two fully-connected
    layers. The last layer starts with zero weights and a bias equal to the
    RARE Fig. 6(a) control points, so at initialisation the predicted points are
    that fixed pattern regardless of the input.
    """

    def __init__(self, F, I_channel_num):
        super(LocalizationNetwork, self).__init__()
        self.F = F
        self.I_channel_num = I_channel_num

        def conv_bn_relu(c_in, c_out):
            # 3x3 conv, stride 1, padding 1 (spatial size preserved), no bias before BN
            return [nn.Conv2d(c_in, c_out, 3, 1, 1, bias=False), nn.BatchNorm2d(c_out), nn.ReLU(True)]

        self.conv = nn.Sequential(
            *conv_bn_relu(self.I_channel_num, 64), nn.MaxPool2d(2, 2),  # bs x 64 x H/2 x W/2
            *conv_bn_relu(64, 128), nn.MaxPool2d(2, 2),                 # bs x 128 x H/4 x W/4
            *conv_bn_relu(128, 256), nn.MaxPool2d(2, 2),                # bs x 256 x H/8 x W/8
            *conv_bn_relu(256, 512),
            nn.AdaptiveAvgPool2d(1),                                    # bs x 512 x 1 x 1
        )

        self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
        self.localization_fc2 = nn.Linear(256, self.F * 2)

        # Zero weights + control-point bias: the initial output is the RARE
        # paper Fig. 6 (a) pattern.
        self.localization_fc2.weight.data.fill_(0)
        n_side = int(F / 2)  # points along each of the top and bottom rows
        xs = np.linspace(-1.0, 1.0, n_side)
        top = np.stack([xs, np.linspace(0.0, -1.0, num=n_side)], axis=1)
        bottom = np.stack([xs, np.linspace(1.0, 0.0, num=n_side)], axis=1)
        initial_bias = np.concatenate([top, bottom], axis=0)
        self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1)

    def forward(self, batch_I):
        """Predict fiducial points.

        input:  batch_I [batch_size x I_channel_num x I_height x I_width]
        output: batch_C_prime [batch_size x F x 2]
        """
        n = batch_I.size(0)
        pooled = self.conv(batch_I).view(n, -1)
        hidden = self.localization_fc1(pooled)
        return self.localization_fc2(hidden).view(n, self.F, 2)


class GridGenerator(nn.Module):
    """Grid generator of RARE: produces the sampling grid P' by multiplying T with P_hat.

    Everything that depends only on F and the output size is precomputed once:
    the target fiducial points C, the output pixel grid P, inv_delta_C and
    P_hat. The last two are registered as buffers so they follow the module
    when it is moved between devices.
    """

    def __init__(self, F, I_r_size):
        """Precompute C, P, inv_delta_C and P_hat.

        Args:
            F: number of fiducial points (F/2 on the top edge, F/2 on the bottom edge).
            I_r_size: (height, width) of the rectified image.
        """
        super(GridGenerator, self).__init__()
        self.eps = 1e-6  # keeps log() finite when a grid point coincides with a fiducial point
        self.I_r_height, self.I_r_width = I_r_size
        self.F = F
        self.C = self._build_C(self.F)  # F x 2
        self.P = self._build_P(self.I_r_width, self.I_r_height)  # n x 2
        self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float())  # F+3 x F+3
        self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float())  # n x F+3

    def _build_C(self, F):
        """Return the fiducial points C of I_r: F/2 evenly spaced on y=-1, F/2 on y=+1."""
        ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
        ctrl_pts_y_top = -1 * np.ones(int(F / 2))
        ctrl_pts_y_bottom = np.ones(int(F / 2))
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        return np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)  # F x 2

    def _build_inv_delta_C(self, F, C):
        """Return inv(delta_C), needed to compute the TPS parameters T.

        delta_C = [[1, C, R], [0, C^T], [0, 1^T]]  (F+3 x F+3), where R holds
        r^2 * log(r) of the pairwise fiducial distances r (0 on the diagonal).
        """
        hat_C = np.linalg.norm(C[:, None, :] - C[None, :, :], axis=2)  # F x F pairwise distances
        np.fill_diagonal(hat_C, 1)  # r = 1 gives r^2 log r = 0 and avoids log(0)
        hat_C = (hat_C ** 2) * np.log(hat_C)
        delta_C = np.concatenate(  # F+3 x F+3
            [
                np.concatenate([np.ones((F, 1)), C, hat_C], axis=1),  # F x F+3
                np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1),  # 2 x F+3
                np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1)  # 1 x F+3
            ],
            axis=0
        )
        return np.linalg.inv(delta_C)  # F+3 x F+3

    def _build_P(self, I_r_width, I_r_height):
        """Return the pixel-centre coordinates of I_r, normalised to [-1, 1], as n x 2 (x, y)."""
        I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width
        I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height
        P = np.stack(  # I_r_height x I_r_width x 2 (meshgrid uses 'xy' indexing)
            np.meshgrid(I_r_grid_x, I_r_grid_y),
            axis=2
        )
        return P.reshape([-1, 2])  # n (= I_r_height x I_r_width) x 2

    def _build_P_hat(self, F, C, P):
        """Return P_hat = [1, P, r^2 log r] with r the distance of each grid point to each fiducial."""
        n = P.shape[0]  # n (= I_r_height x I_r_width)
        P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1))  # n x 2 -> n x 1 x 2 -> n x F x 2
        C_tile = np.expand_dims(C, axis=0)  # 1 x F x 2
        P_diff = P_tile - C_tile  # n x F x 2
        rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False)  # n x F
        rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps))  # n x F
        P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1)
        return P_hat  # n x F+3

    def build_P_prime(self, batch_C_prime):
        """Generate the sampling grid from predicted fiducial points.

        input:  batch_C_prime [batch_size x F x 2]
        output: batch_P_prime [batch_size x n x 2], n = I_r_height * I_r_width
        """
        batch_size = batch_C_prime.size(0)
        batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)  # batch_size x F+3 x F+3
        batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)              # batch_size x n x F+3
        # The 3 zero rows match the affine-constraint rows of delta_C. Allocate
        # them on C''s own device/dtype (not a hard-coded .cuda()) so this also
        # runs on CPU.
        zeros = batch_C_prime.new_zeros(batch_size, 3, 2)
        batch_C_prime_with_zeros = torch.cat((batch_C_prime, zeros), dim=1)  # batch_size x F+3 x 2
        batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros)    # batch_size x F+3 x 2
        batch_P_prime = torch.bmm(batch_P_hat, batch_T)  # batch_size x n x 2
        return batch_P_prime  # batch_size x n x 2

# char_len resnet + base Arch

In [None]:
class EncoderDecoder(nn.Module):
    """Generic encoder/decoder container.

    The source goes through `encoder`. The target goes through `tgt_embed` and
    is then decoded against the encoded source. `generate` applies `generator`
    to decoder outputs, and `forward` does not apply it.
    """

    def __init__(self, encoder, decoder, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder, self.decoder = encoder, decoder
        self.tgt_embed, self.generator = tgt_embed, generator

    def forward(self, src, tgt, tgt_mask=None):
        """Encode `src`, then decode `tgt` against the encoding."""
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask)

    def encode(self, src):
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        """Decode the embedded `tgt`, attending to the already-encoded `src`."""
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, src, tgt_mask)

    def generate(self, outs):
        return self.generator(outs)

In [None]:
class CNNBase(nn.Module):
    """ResNet-34 image encoder that also regresses one scalar per image (the char-length head).

    The trunk is every resnet34 child except the last three. For torchvision's
    resnet34, those three are layer4, avgpool and fc. Its 256-channel feature
    map is flattened into a sequence and projected to `d_model`. The following
    two children (layer4, avgpool), a Flatten and Linear(512, 1) form the length head.
    """
    def __init__(self, d_model):
        super().__init__()

        resnet = models.resnet34(False)
        layers = list(resnet.children())
        self.base = nn.Sequential(*layers[:-3])
        self.cLen = nn.Sequential(*layers[-3:-1], Flatten(), nn.Linear(512, 1))
        self.linear = nn.Linear(256, d_model)

    def forward(self, x):
        """Return (features [bs, H*W, d_model], length prediction from the cLen head)."""
        feature_map = self.base(x)
        length_pred = self.cLen(feature_map)
        seq = feature_map.flatten(2, 3).permute(0, 2, 1)  # bs x (H*W) x 256
        return self.linear(seq), length_pred

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2):
    """Build the encoder/decoder transformer used inside Img2Seq.

    Args:
        vocab: vocabulary size, given to Embeddings and used as the generator's output width.
        d_model: model width.
        N: depth argument passed to both Encoder and Decoder.
        drops: presumably the dropout probability, passed to every layer and to PositionalEncoding.

    Each layer gets its own deep copy of a single attention / feed-forward
    prototype, so these sub-modules do not share weights. Every parameter
    with more than one dimension is then re-initialised with Xavier-uniform.
    """
    attn = SingleHeadedAttention(d_model)
    # alternative tried: MultiHeadedAttention(d_model, 4)
    ff = PositionwiseFeedForward(d_model, drops)

    # Construction order matches the original nested call (matters for RNG state).
    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    decoder = Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    tgt_embed = nn.Sequential(Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000))
    generator = nn.Linear(d_model, vocab)
    model = EncoderDecoder(encoder, decoder, tgt_embed, generator)

    weight_tensors = (p for p in model.parameters() if p.dim() > 1)
    for p in weight_tensors:
        nn.init.xavier_uniform_(p)
    return model

In [None]:
class Img2Seq(nn.Module):
    """Image-to-sequence model: an image encoder followed by an encoder/decoder transformer.

    `img_encoder(src)` must return `(features, cLen)`. The features go through
    the transformer, and `cLen` (an auxiliary per-image length output) is
    returned alongside the token logits.
    """
    def __init__(self, img_encoder, transformer, seq_len=500):
        super(Img2Seq, self).__init__()
        self.img_enc = img_encoder
        self.transformer = transformer
        self.seq_len = seq_len

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=None):
        """Teacher-forced decoding when `tgt` is given, greedy decoding otherwise.

        Args:
            src: batch of images.
            tgt: target token ids for teacher forcing; None selects greedy decoding.
            tgt_mask: decoder self-attention mask used together with `tgt`.
            seq_len: maximum number of greedy steps; defaults to `self.seq_len`.

        Returns:
            (logits [bs, sl, vocab], cLen)
        """
        if tgt is not None:
            feats, cLen = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt, tgt_mask)   # ([bs, sl, d_model])
            out = self.transformer.generate(dec_outs)           # ([bs, sl, vocab])
            return (out, cLen)
        return self._greedy_decode(src, self.seq_len if seq_len is None else seq_len)

    def _greedy_decode(self, src, max_len):
        """Greedily decode up to `max_len` steps, starting every sequence from token id 1."""
        with torch.no_grad():
            # Unpack the image encoder exactly like the teacher-forced branch:
            # only the features (not the (feats, cLen) tuple) are transformer-encoded.
            feats, cLen = self.img_enc(src)
            memory = self.transformer.encode(feats)
            tgt = torch.ones((src.size(0), 1), dtype=torch.long, device=src.device)

            res = []
            for _ in progress_bar(range(max_len)):
                mask = subsequent_mask(tgt.size(-1))
                dec_outs = self.transformer.decode(memory, tgt, mask)
                prob = self.transformer.generate(dec_outs[:, -1])
                res.append(prob)
                pred = torch.argmax(prob, dim=-1, keepdim=True)
                # stop as soon as every sequence in the batch predicts token 0
                if (pred == 0).all(): break
                tgt = torch.cat([tgt, pred], dim=-1)
            out = torch.stack(res).transpose(1, 0).contiguous()  # [bs, steps, vocab]
        return (out, cLen)

In [None]:
class LabelSmoothing(nn.Module):
    """Label-smoothed KL-divergence loss plus a summed MSE loss on the predicted length.

    Expects the model output `(logits, cLen)`. The smoothed target puts
    `smoothing / vocab` on every class and then overwrites the true class with
    `1 - smoothing`. The length target is the number of non-zero (non-pad)
    tokens per target row. Both terms are summed over the batch and divided by
    the batch size.
    """
    def __init__(self, smoothing=0.1):
        super(LabelSmoothing, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, pred, target):
        pred, cLen = pred
        # This batch's size, not a notebook-global `bs`: a short final batch or a
        # single-image prediction must be averaged over its own rows.
        bs = target.size(0)
        # NOTE(review): relies on `loss_prep` defined elsewhere to equalise sequence
        # lengths and flatten to (bs*sl, vocab) / (bs*sl,) — confirm it is in scope.
        logits, targ = loss_prep(pred, target)
        log_probs = F.log_softmax(logits, dim=-1)  # kl_div expects log-probabilities
        true_dist = torch.full_like(log_probs, self.smoothing / log_probs.size(1))
        true_dist.scatter_(1, targ.unsqueeze(1), self.confidence)
        c_loss = F.kl_div(log_probs, true_dist, reduction='sum') / bs
        lengths = (target != 0).sum(dim=1).float()  # non-pad tokens per row
        l_loss = F.mse_loss(cLen.view(-1), lengths, reduction='sum') / bs
        return c_loss + l_loss

In [None]:
import Levenshtein as Lev

class CER(Callback):
    """Character error rate metric (fastai v1 callback metric).

    For each scored batch, every sample contributes
    Levenshtein(target_text, argmax_prediction_text) / len(target_text). The
    epoch value is the mean over all scored samples.
    """
    def __init__(self, itos):
        super().__init__()
        self.name = 'cer'
        self.itos = itos

    def on_epoch_begin(self, **kwargs):
        self.errors, self.total = 0, 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        # The model returns (logits, cLen); only the logits are scored.
        error, size = self._cer(last_output[0], last_target)
        self.errors += error
        self.total += size

    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics, self.errors/self.total)

    def _cer(self, preds, targs):
        """Return (sum of per-sample normalised edit distances, number of samples)."""
        res = torch.argmax(preds, dim=2)
        error = 0
        for pred_ids, targ_ids in zip(res, targs):
            p = self._char_label_text(pred_ids)
            t = self._char_label_text(targ_ids)
            # An all-padding target would divide by zero: normalise by at least 1,
            # so each predicted character then counts as one insertion.
            error += Lev.distance(t, p) / max(len(t), 1)
        return error, targs.size(0)

    def _char_label_text(self, pred):
        """Decode token ids to text, dropping id 0 (the pad index)."""
        ints = to_np(pred).astype(int)
        return ''.join(self.itos[i] for i in ints[np.nonzero(ints)])

In [None]:
# Build the model from the CNNBase encoder (which also predicts length) and a fresh transformer.
d_model = 512

# NOTE(review): Img2Seq unpacks `(feats, cLen)` from its image encoder, which matches the
# two-output CNNBase defined above but not the later single-output CNNBase variant —
# check which CNNBase was last executed.
img_encoder = CNNBase(d_model)
transformer = make_full_model(len(itos), d_model)
net = Img2Seq(img_encoder, transformer, seq_len)

# AdamW16 = partial(optim.Adam, betas=(0.9,0.99), eps=1e-4)  ->  #.to_fp16(max_scale=256)
# partial: way to always call a function with a given set of arguments or keywords

# Label-smoothed loss, CER metric, TeacherForce callback, gradient clipping at 0.25.
learn = Learner(data, net, loss_func=LabelSmoothing(smoothing=0.1),
                metrics=[CER(itos)], callback_fns=TeacherForce)
learn.clip_grad(0.25)
None  # last expression is None so the cell displays nothing

# ResNet for char_len

In [None]:
# NOTE(review): the cells below expect `net` to be a torchvision resnet34 (uncomment the
# next line); otherwise `net` is whatever model was last assigned.
# net = models.resnet34(False)
net

In [None]:
# Keep every child except the last three (for torchvision's resnet34: layer4, avgpool, fc).
modules = list(net.children())[:-3]
m = nn.Sequential(*modules)
m

In [None]:
# Inspect the first op of the downsample branch of block 0 in child 6 (layer3 for resnet34).
m[6][0].downsample[0]

In [None]:
def _print_shape(x):
    "Print the activation's shape and pass the activation through unchanged."
    print(x.shape)
    return x

# Debug aid: wrap every child of `m` so its input shape is printed during a forward pass.
# The wrapped function must return `x`: `print` returns None, which would otherwise
# be fed to the next layer.
shape = Lambda(_print_shape)

for i,layer in enumerate(m):
    m[i] = nn.Sequential(shape, layer)
#         m[i] = nn.AvgPool2d((2,1))
#         m[i] = nn.AvgPool2d((2,1))

In [None]:
# Replace sub-module 3 of child 5 with a (2,1) average pool (stride defaults to the kernel,
# so height is halved and width kept).
# NOTE(review): after the shape-printing wrapper cell, m[5] is a 2-element Sequential and
# index 3 does not exist — confirm which `m` this targets.
m[5][3] = nn.AvgPool2d((2,1))
# bs, 1920, 4, 32 => [5,7,9][3]

In [None]:
# Stand-alone length regressor: resnet34 trunk, then Flatten and Linear(256, 512).
# NOTE(review): Flatten over (bs, 256, H, W) gives 256*H*W features, which matches
# Linear(256, ...) only when H == W == 1; CustomMSE also compares the flattened
# 512-wide output with one length per sample — confirm the intended head.
net = models.resnet34(False)
m = nn.Sequential(*list(children(net))[:-3], Flatten(), nn.Linear(256,512))

In [None]:
class CustomMSE(nn.MSELoss):
    """MSE between a flattened length prediction and the number of non-zero tokens per target row."""
    def forward(self, pred:Tensor, target:Tensor) -> Rank0Tensor:
        n_chars = target.ne(0).sum(dim=1).float()  # char_len: non-pad tokens per row
        flat_pred = pred.view(-1)
        return super().forward(flat_pred, n_chars)

In [None]:
l = Learner(data, m, loss_func=CustomMSE())
l.fit_one_cycle(1, 1e-3)

In [None]:
sd = l.model.state_dict

In [None]:
class CNNBase(nn.Module):
    """ResNet-34 trunk (all children but the last three) projected to a `d_model` feature sequence.

    Single-output variant: child 3 of the trunk (the stem max-pool in
    torchvision's resnet34) is replaced by `nn.MaxPool2d(3)`.
    """
    def __init__(self, d_model):
        super().__init__()

        trunk = list(models.resnet34(False).children())[:-3]
        # NOTE(review): nn.MaxPool2d(3) has stride = kernel_size = 3 and no padding,
        # so it downsamples more than resnet's usual stem pool — confirm intent.
        trunk[3] = nn.MaxPool2d(3)
        self.base = nn.Sequential(*trunk)
        self.linear = nn.Linear(256, d_model)

    def forward(self, x):
        """Return encoder features of shape [bs, H*W, d_model]."""
        fmap = self.base(x)
        return self.linear(fmap.flatten(2, 3).permute(0, 2, 1))

# Experiment

In [None]:
# One epoch of one-cycle training at max_lr=1e-3; results of earlier runs are logged below.
learn.fit_one_cycle(1, max_lr=1e-3)

# 33.423637	28.382051	0.238357	03:43   baseline - no modifications (gelu, multi(4), pre-resnet34)
# 33.606434	28.509262	0.684435	05:25   modified CER  ...something wrong...
# 33.010765	27.659348	0.399081	03:41   remove rshift in TeacherForcing
# 33.992241	28.652754	0.412570	03:36   ""
# 35.508675	30.228336	0.305193	03:37   keep rshift in TF; remove bos token from CharTokenizer

# 36.611359	30.429012	0.305808	03:34   single attn
# 37.610634	31.228436	0.395280	03:27   "", remove eos from CER
# 37.335598	31.324074	0.325508	03:36   multi(1)
# 34.636471	29.268423	0.292452	03:43   multi(4)
# 33.925339	28.797222	0.285785	03:42   multi(8)
# 34.031483	29.092997	0.361159	03:37   "", remove eos from CER
# 35.286541	30.412006	0.306993	03:50   multi(16)
# 36.340981	31.675489	0.326111	04:04   multi(32)

# sm_synth, sz:128, bs:100
# 42.141075	37.683475	0.463989	03:35   multi(8)
# 41.568932	36.599640	0.440583	03:32   multi(8); add PE to flattened feats

#   (3cycle, 1e-3)
# 21.168943	20.402149	0.221798	03:34   multi(8)
# valid:     17.9091   22.6781
# greedy:    179.729   27.0221
# 22.707590	20.752419	0.227430	03:36   multi(8); add PE to flattened feats
# valid:     18.6755   24.6373
# greedy:    174.812   25.5949

# Train

In [None]:
# Resume from the saved '8head512mix' weights.
learn.load('8head512mix')
None  # suppress display of the returned Learner

In [None]:
# Learning-rate range test and plot with the suggested learning rate marked.
learn.lr_find()
learn.recorder.plot(suggestion=True)

In [None]:
# Fine-tune for 2 one-cycle epochs; SaveModelCallback writes checkpoints named '8head512mix2'.
# NOTE(review): max_lr=1.58e-6 presumably comes from the lr_find suggestion above.
learn.fit_one_cycle(2, max_lr=1.58e-6, callbacks=[SaveModelCallback(learn, name='8head512mix2')])
# 3 cycles; 1e-3
### sm_synth, sz:128, bs:100
# 15.269251	13.785835	0.106348   base arch - (gelu, single attn, pre-resnet34)
# 15.385436	13.587833	0.104421   base arch (multi attn in encoder, single attn in decoder)
# 14.136460	12.912922	0.097931   "", multi attn (no tfms)
# 16.347567	14.436943	0.111617   "", reversed images - doesn't perform well from pretraining???
# 23.472956	19.903946	0.164833   normal images, resnet34 (no pretraining)
# 25.128386	21.791134	0.185044   reversed_images, resnet34 (no pretraining) - initialization???
# --normal images (black on white) function better as inputs regardless of pre-training
# 31.361557	27.413263	0.241301   densenet201 (no cut - bs,1920,4,4) (no pretraining)
# 29.163410	25.511379	0.222596   densenet201 ([:-3] - bs,1792,8,8) (no pretraining)
# 45.432755	41.811367	0.416656   resnet50 ([:-4] - bs,512,16,16; add BatchNorm2d) (no pretraining)
# 39.219181	35.872307	0.331677   resnet34(bs,256,8,8); add BatchNorm2d (no pretraining)
# --batchnorm doesn't help before final linear
# --resnet34 has best results and speed
# 28.031528	24.709288	0.217898   resnet34 (no pretraining)
# 24.876863	21.345222	0.181399   "" * 8

# Char_len, char
# 54.786438	46.906456	0.410596   full resnet -> char_len; [:-3] -> feats
# 48.422413	41.097393	0.353167   2nd run

# DeformableConv Layers
# 14.071175	12.978416	0.098672	05:00   4.2, 5.3, 6.5, multi(4), gelu, pretrained
# --no improvement, much slower

# 7.596963	7.282169	0.054113	07:24   remove base[3] (maxpool) => 16x16 features
# 8.450560	8.299768	0.060255	04:53   remove base[3] (maxpool); em_sz:512 (add 512 layers) => 8x8 features
# --additional computation layer seems to improve performance over same features w/out
# ...but not as much as 16x16 feat map output

# modify stride in convs; stride(2,1) => 8x16 feature maps
# 13.378371	12.241001	0.094132	04:29   base[6][0], multi(4), gelu, pretrained
# 11.961518	10.865971	0.081814	04:34   base[5][0]
# 9.808911	9.199651	0.068632	04:44   base[3]
# 10.367452	9.416987	0.070494	04:44   base[0]


### sm_synth, sz:256, bs:100
# 5.859333	5.713445	0.043088   base, multi-attn(4), gelu, pretrained    'v1_multi_gelu_256'    ****
# --pretraining on smaller sizes does not improve performance
# --16x16 feature map gives better results -> finer grained attn??

# 5.635329	5.135679	0.043417	29:09    bs: 20, remove base[3] (maxpool) => 32x32 features
# --potential slight improvement but much slower!!

### mix(10k), sz:512, bs:10, lr:1e-4
# 219.762802	215.301453	0.540762	08:16   base, multi-attn(4), gelu, 32x32 features
#    multi-attn(8)
# 256.292297	230.259369	0.665115	12:14(1st cycle)   multi-attn(16), bs: 5
# 206.822525	217.673355	0.555488	07:08   base, single-attn, bs: 15

# 203.974548	216.947418	0.561815	21:56(2nd cycle)   sz:800, bs:5, base, multi-attn(4), 50x50 features

### edited_fonts(74k) - base, multi(4), gelu, preload 'v1_multi_gelu_256'
# sz:256, bs: 20, lr:1e-4
# 234.073456	210.412842	0.207377	39:16   resize: squish     'v1_multi_gelu_256_fonts'
# 389.043976	349.327209	0.410615	40:07   resize: pad/border   'best256fonts'
# sz:400, bs: 10, lr:2e-5
# sz:512, bs: 10, lr:2e-5
# 39.1337090	33.233494	0.029211	1:14:35   'v1_multi_gelu_512_fonts'
# 38.2521550	32.473492	0.028432	1:13:44    2nd run,  lr: 1e-6  'best512fonts'


# sz:512, bs: 10, lr:1e-4, 8head, no preload, remove bos from CharTokenizer, remove eos from CER
# 56.5093610	44.473030	0.039634	1:27:05  *** '8head512fonts' 
# valid:     22.1713   0.02428
# greedy:    572.330   0.03317
# mixed data
# 35.9120830	32.567410	0.088160	59:11    *** '8head512mix'
# valid:     4.80507   0.15863
# greedy:    23.6504   0.12724
# test:      2282.64   0.13363
# 41.1803780	32.238232	0.087566	57:54   added tfms, preload above, lr: 1.58e-6, 2cycle   ***'8head512mix2'


### Training
# 128, 100
# 14.136460	12.912922	0.097931   "", multi attn (no tfms)
# 8.069926	8.673322	0.064796   2nd run    *** 'v1_multi_gelu_128'
# 256, 100
# 5.880878	5.727545	0.043746   
# 2.970708	3.868662	0.029560   2nd run

### Training - Edited Data - GeLU, singlehead attn
# sm
# 10.841042	9.550449	0.071475   sz:128, bs:100, lr:1e-3  'v1_gelu_128'
# 2.631370	2.669811	0.020308   sz:256, bs:100, lr:1e-3  'v1_gelu_256'
# mix
# 18.652050	14.330537	0.021167   sz:400, bs:20, lr:1e-3   'v1_edited_gelu_400'
# 11.513125	8.577786	0.016521   sz:512, bs:15, lr:5e-5, 3cycle   'v1_edited_gelu_512'

# Greedy

In [None]:
def show_img(im, figsize=None, ax=None, alpha=None, title=None):
    """Draw the tensor-backed image `im` (via `image2np(im.data)`) on `ax`, creating a figure if needed."""
    if not ax:
        _, ax = plt.subplots(figsize=figsize)
    pixels = image2np(im.data)
    ax.imshow(pixels, alpha=alpha)
    if title:
        ax.set_title(title)
    return ax

In [None]:
def full_test(learn, sl, dl=None, batches=20):
    """Greedy-decode `batches` batches from `dl` and return [mean loss, mean CER].

    Args:
        learn: Learner whose model accepts `seq_len=` for greedy decoding.
        sl: maximum decode length passed to the model.
        dl: dataloader to evaluate. Defaults to `learn.data.valid_dl`, resolved when
            called; a `data.valid_dl` default would be frozen when the cell ran.
        batches: number of batches; None means `len(dl.dl.dataset)//bs` (global `bs`).

    Each batch's loss and CER are divided by that batch's size. The results are
    averaged over the batches actually evaluated, so a short `dl` ends the
    loop early instead of raising StopIteration.
    """
    if dl is None: dl = learn.data.valid_dl
    if batches is None:
        batches = len(dl.dl.dataset)//bs
    learn.model.eval()
    g_loss, g_cer, n = 0, 0, 0
    with torch.no_grad():
        for _, (x, y) in zip(progress_bar(range(batches)), dl):
            g_preds = learn.model(x, seq_len=sl)
            n_items = y.size(0)
            g_loss += learn.loss_func(g_preds, y).item()/n_items
            g_cer += cer(g_preds, y)[0]/n_items
            n += 1
    return [g_loss/max(n, 1), g_cer/max(n, 1)]

In [None]:
# Greedy evaluation with full_test's defaults (20 validation batches): [mean loss, mean CER].
g = full_test(learn, seq_len)

In [None]:
# str(g[1])[1:7] drops the first character, i.e. the leading '0' of a CER below 1.
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
# losses = np.array([learn.loss_func(g_preds[i:i+1],y[i:i+1]).item() for i in range(bs)])
# cers = np.array([cer(g_preds[i:i+1],y[i:i+1])[0] for i in range(bs)])

In [None]:
# Greedy-decode (at most 20 steps) one training batch and score it.
# NOTE(review): assumes the model's forward accepts `seq_len=` and returns a logits
# tensor; Img2Seq as defined above returns (logits, cLen) — confirm the model version.
x,y = next(iter(learn.data.train_dl))

g_preds = learn.model(x, seq_len=20)
g_res = torch.argmax(g_preds, dim=-1)
g = [learn.loss_func(g_preds, y).item()/bs, cer(g_preds, y)[0]/bs]

In [None]:
#greedy
# Plot the first 6 images with their greedy transcriptions as titles.
fig, axes = plt.subplots(2,3, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    p = char_label_text(g_res[i], sep=' ')
    ax=show_img(x[i], ax=ax, title=p)

## Test Set

In [None]:
# NOTE(review): this "Test Set" section reads from learn.data.train_dl — confirm that the
# intended databunch (with the test images in its train split) is loaded.
x,y = next(iter(learn.data.train_dl))

# Greedy-decode with the full seq_len budget; per-sample loss and CER via the global bs.
g_preds = learn.model(x, seq_len=seq_len)
g_res = torch.argmax(g_preds, dim=-1)
g = [learn.loss_func(g_preds, y).item()/bs, cer(g_preds, y)[0]/bs]

print(f'  test:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
#test
# Plot 4 images with their greedy transcriptions.
fig, axes = plt.subplots(2,2, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    #i +=8
    p = char_label_text(g_res[i])
    ax=show_img(x[i], ax=ax, title=p)

## Single Test Image (cpu)

In [None]:
# Pull one batch and inspect sample 8 on its own.
xs,ys = next(iter(learn.data.train_dl))

In [None]:
# Ground-truth text of sample 8.
char_label_text(ys[8])

In [None]:
# Keep a batch dimension of 1 via [None].
i = 8
x = xs[i][None]
y = ys[i][None]

In [None]:
# Single-sample greedy decode: loss and CER are not divided by bs here.
g_preds = learn.model(x, seq_len=seq_len)
g_res = torch.argmax(g_preds, dim=-1)
g = [learn.loss_func(g_preds, y).item(), cer(g_preds, y)[0]]

print(f'  test:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
p = char_label_text(g_res[0])
show_img(x[0], figsize=(18,10), title=p)

In [None]:
# Predict on an uploaded image via fastai's learn.predict (first item shown as the title).
im = PATH/'uploads'/'test2.png'
img = open_image(im)
prediction = learn.predict(img)[0]
show_img(img, title=prediction)

In [None]:
img

# Test

In [None]:
# def cer(preds, targs):
#     bs = targs.size(0)
#     res = torch.argmax(preds, dim=-1)
#     error = 0
#     for i in range(bs):
#         p = char_label_text(res[i])   #.replace(' ', '')
#         t = char_label_text(targs[i]) #.replace(' ', '')
#         error += Lev.distance(t, p)/len(t)
#     return error

# def char_label_text(pred):
#     ints = to_np(pred).astype(int)
#     nonzero = ints[np.nonzero(ints)] #[:-1]  #remove bos/eos token
#     return ''.join([itos[i] for i in nonzero])

In [None]:
def self_attn(layer=-1):
    """Return decoder layer `layer`'s `self_attn.attn` tensor (presumably the last attention weights) on the CPU."""
    weights = learn.model.transformer.decoder.layers[layer].self_attn.attn
    return weights.data.cpu()
def source_attn(layer=-1):
    """Return decoder layer `layer`'s `src_attn.attn` tensor (presumably the last source-attention weights) on the CPU."""
    weights = learn.model.transformer.decoder.layers[layer].src_attn.attn
    return weights.data.cpu()

In [None]:
def show_img(im, figsize=None, ax=None, alpha=None, title=None):
    """Show an image on `ax` (a new figure when None) and return the axes.

    A fastai Image or torch tensor in (C, H, W) layout with 1 or 3 channels is
    converted with `image2np`. Anything else (e.g. a PIL image or an
    (H, W[, C]) numpy array) is passed to `ax.imshow` unchanged.
    """
    if not ax: fig,ax = plt.subplots(figsize=figsize)
    # Convert only tensor-backed images: PIL images have no `.shape` (e.g. the
    # output of thresh_edit), and a numpy array's `.data` is a raw buffer, not a tensor.
    data = getattr(im, 'data', None)
    if isinstance(data, torch.Tensor) and data.dim() == 3 and data.size(0) in (1, 3):
        im = image2np(data)
    ax.imshow(im, alpha=alpha)
    if title: ax.set_title(title)
    return ax

In [None]:
# One validation batch; `imgs` is used for plotting (un-denormalised, see trailing comment).
x,y = next(iter(learn.data.valid_dl))
# x,y = learn.data.one_batch(ds_type=DatasetType.Train)
imgs = x #learn.data.denorm(x)

## Uploaded Images

In [None]:
def thresh_edit(fname, thresh=100, bg=245):
    """Load `fname` as grayscale and flatten its light background.

    Every pixel brighter than `thresh` is set to `bg`; darker pixels (the ink)
    are kept.

    Returns:
        A new mode-'L' PIL image.
    """
    # Close the file-backed image itself. Closing only the converted copy, as
    # before, left the object returned by Image.open unclosed (Pillow keeps the
    # file open for multi-frame formats).
    with Image.open(fname) as src:
        np_im = np.array(src.convert('L'))  # grayscale
    np_im[np_im > thresh] = bg
    return Image.fromarray(np_im, 'L')

In [None]:
im = thresh_edit(PATH/'test3.png')
show_img(im, figsize=(15,15))

In [None]:
e = 'edit_'+str(fname.name)
edited_fname = PATH/e

In [None]:
edited_im.save(edited_fname)

In [None]:
# Predict on an already-thresholded upload.
fname = PATH/'test/edit_test4.png'
im = open_image(fname)

In [None]:
# NOTE(review): assumes learn.predict returns a 3-tuple whose first item is the decoded text.
seq,res,preds = learn.predict(im)
print(seq)

In [None]:
# r = torch.tensor([g_res], dtype=torch.long, device=device)
# CER against a hand-typed transcription: edit distance / reference length.
truth = "This is a test letter. I hope this\nworks but I'm not sure it will.\nMy handwriting is not very good."
Lev.distance(truth, str(seq))/len(truth)

## Results

In [None]:
# x,y = next(v_dl)
# imgs = denorm(x)

# Compare teacher-forced and greedy decoding on the batch (x, y) loaded earlier.
# NOTE(review): Img2Seq as defined above returns (logits, cLen); torch.argmax on that
# tuple would fail — these cells presumably predate the length head. Confirm.
learn.model.eval()

# Teacher forcing: feed the (presumably right-shifted) target with a causal mask.
shifted_y = rshift(y).long()
tgt_mask = subsequent_mask(shifted_y.size(-1))
v_preds = learn.model(x, shifted_y, tgt_mask)
v_res = torch.argmax(v_preds, dim=-1)
v_attn = source_attn()  # last decoder layer's source-attention tensor after this pass

# Greedy decoding (no target given).
g_preds = learn.model(x)
g_res = torch.argmax(g_preds, dim=-1)
g_attn = source_attn()

In [None]:
# [loss, cer(...)] for both modes; only the first element of the cer result is printed.
v = [learn.loss_func(v_preds, y).item(), cer(v_preds, y)]
print(f'valid:     {str(v[0])[:7]}   {str(v[1][0])[1:7]}')

g = [learn.loss_func(g_preds, y).item(), cer(g_preds, y)]
print(f'greedy:    {str(g[0])[:7]}   {str(g[1][0])[1:7]}')

In [None]:
#valid
# First 3 images titled with their teacher-forced predictions.
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(v_res[i])
    ax=show_img(imgs[i], ax=ax, title=p)

In [None]:
#greedy
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(18, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(g_res[i])
    ax=show_img(imgs[i], ax=ax, title=p)

In [None]:
# Test databunch: split_none() puts every item in the train split, so test_data.train_dl
# iterates the test images; images are squished to `sz` and normalised.
test_data = (ImageList.from_df(test_df, path=PATH, folder=TEST_FOLDER)
             .split_none()
             .label_from_df(label_cls=TextList, sep='', pad_idx=0, vocab=vocab, processor=procs)
             .transform([], size=sz, resize_method=ResizeMethod.SQUISH)
             .databunch(bs=bs, device=device, collate_fn=label_collater)
             .normalize()
            )

In [None]:
tx,ty = next(iter(test_data.train_dl))
t_imgs = test_data.denorm(tx)  # undo normalisation for display

In [None]:
t_preds = learn.model(tx)
t_res = torch.argmax(t_preds, dim=-1)

In [None]:
# NOTE(review): unlike the valid/greedy prints, this slices the whole cer(...) result,
# not its first element — confirm the intended formatting.
g = [learn.loss_func(t_preds, ty).item(), cer(t_preds, ty)]
print(f'test:      {str(g[0])[:7]}   {str(g[1])[:7]}')

In [None]:
#test
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(18, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(t_res[i])
    ax=show_img(t_imgs[i], ax=ax, title=p)

In [None]:
# tfmr full
#             loss       cer       acc
# valid:     2.42856   0.02981   0.98177   3x1, 256/60, 'tfmr_full_3x1_single_attn'
# greedy:    11.9438   0.02745   0.93385

# valid:     1.98831   0.01858   0.98505   3x1, 256/60, 'exp_3x1_256'
# greedy:    16.3832   0.02167   0.90944

# valid:     1.47604   0.00797   0.99523   3x2, 400/45,  'tfmr_full_3x2'
# greedy:    21.4682   0.01097   0.93762

# valid:     7.30494   0.01206   0.99086   lg,  512/30,  'tfmr_full_lg'
# greedy:    434.610   0.02398   0.69553

# valid:     25.4131   0.04500   0.96773   lg, 512/30,   'tfmr_lg_LM_mixer'
# greedy:    734.834   0.15256   0.48912

# valid:     10.1481   0.00501   0.99602   cat9-12,  800/8,  'tfmr_cat9-12_full_800'
# greedy:    1330.63   0.04525   0.53181

# valid:     62.0793   0.05716   0.96214   pg,  512/20  'tfmr_full_paragraph'  (cpu)
# greedy:    1739.62   0.08347   0.43309
# beam:                0.07958
# valid:     71.9120   0.06819   0.95457   "", 2nd batch (gpu)
# greedy:    2027.24   0.11823   0.39238
# beam:                0.10697
# valid:     36.3009   0.03397   0.97633   "", 3rd batch (gpu)
# greedy:    1759.59   0.07070   0.40949
    
# valid:     50.1414   0.04137   0.97004   pg,  1000/5,  'tfmr_pg_1000'
# greedy:    2579.24   0.34178   0.18623

# valid:     56.2722   0.04167   0.96923   pg,  1000/5,  'tfmr_catpg_1000'
# greedy:    2432.54   0.37192   0.26174


# valid:     3.43745   0.05444   0.98735   mix(new)  'tfmr_mix_words_400'
# greedy:    43.6545   0.06126   0.91407

# valid:     2.96891   0.03338   0.98932   mix(new)  'tfmr_mix_words_512'
# greedy:    28.9982   0.03949   0.94500
# valid:     2.02411   0.02128   0.99302   5 more cycles
# greedy:    19.4735   0.02450   0.96104

# valid:     41.9800   0.03879   0.97535   pg, 'tfmr_mix_words_512'
# greedy:    1602.68   0.06259   0.49110


# valid:     53.9049   0.04822   0.96622   pg, 'tfmr_8head_512_mix'
# greedy:    1384.72   0.06304   0.51696


# valid:     5.24461   0.00634   mix   'v1_gelu_512'
# greedy:    217.438   0.02028
# valid:     24.3609   0.02405   pg
# greedy:    1935.30   0.05970

# valid:     12.7074   0.00954   mix   'v1_gelu_512_wiki103_base_lm_last_layer'
# greedy:    553.201   0.02814

# valid:     11.4251   0.00474   mix   'v1_gelu_512_wiki103_base_lm_full'
# greedy:    358.030   0.01187

# valid:     12.1104   0.01533   mix   'v1_gelu_multi_512'
# greedy:    506.505   0.04229


# edited data
# valid:     8.21437   0.06576   sm   'v1_gelu_128'
# greedy:    185.981   0.12890

# valid:     2.25060   0.01731   sm   'v1_gelu_256'
# greedy:    149.525   0.06003

# valid:     2.37294   0.01383   mix  'v1_edited_gelu_400'
# greedy:    30.3637   0.01732
# greedy:    1810.14   0.11409   test

# valid:     9.09986   0.01656   mix  'v1_edited_gelu_512'
# greedy:    529.308   0.03104
# greedy:    1675.20   0.10819   test

# Beam Search

## Multi-Batch Beam Decode

In [None]:
def beam_cer(preds, targs):
    """Mean of `_cer(target_text, predicted_text)` over the rows of `preds`."""
    n = preds.size(0)

    def sample_error(i):
        pred_text = char_label_text(preds[i])
        targ_text = char_label_text(targs[i])
        return _cer(targ_text, pred_text)

    return sum(sample_error(i) for i in range(n)) / n

In [None]:
def repeat_interleave(tensor,n):
    """Repeat each entry of `tensor` along dim 0 `n` times, keeping the copies adjacent.

    Rows [a, b] with n=2 become [a, a, b, b]. Used to give every beam
    hypothesis its own copy of the per-image encoder features.

    Vectorised with unsqueeze/expand/reshape rather than stacking rows in a
    Python loop. This also avoids needing `torch.repeat_interleave`, which is
    only available in newer PyTorch releases.
    """
    rest = tuple(tensor.shape[1:])
    expanded = tensor.unsqueeze(1).expand(tensor.size(0), n, *rest)  # view, no copy yet
    return expanded.reshape(-1, *rest)  # materialises a new contiguous tensor

In [None]:
# def normalize_score(score, i, alpha=0.6):
#     length_penalty = math.pow((5 + i)/6, alpha)
#     return score/length_penalty

### Beam Object

In [None]:
import heapq

class Beam(object):
    """Fixed-width beam holding at most `beam_width` `(score, complete, seq)` entries.

    Entries are stored in a min-heap, so adding to a full beam evicts the
    lowest-scoring entry. Iteration and `get_seq` follow the heap's internal
    order, not score order.
    """

    def __init__(self, beam_width=10):
        self.beam_width = beam_width
        self.heap = []
        self.best_score = None  # NOTE(review): never updated by this class

    def add(self, score, complete, seq):
        """Insert a hypothesis, then drop the worst one if the beam overflowed."""
        heapq.heappush(self.heap, (score, complete, seq))
        overflowed = len(self.heap) > self.beam_width
        if overflowed:
            heapq.heappop(self.heap)

    def get_seq(self):
        """Return the token sequences in heap order."""
        return [entry[-1] for entry in self.heap]

    def __iter__(self):
        return iter(self.heap)

In [None]:
class BeamSearch(nn.Module):
    """Batched beam-search decoder built from a trained model's `img_enc` and `transformer`.

    Each batch element has its own `Beam` of `beam_width` hypotheses, all
    starting from token id 1. At each step, every hypothesis is extended with
    its `beam_width` most likely next tokens, and the beam keeps the best
    `beam_width` cumulative log-probabilities. Decoding stops after `seq_len`
    steps, or once the top hypothesis of every beam ends with `end_tok`.
    """
    def __init__(self, net, beam_width, seq_len, end_tok=0):
        super(BeamSearch, self).__init__()
        self.img_enc = net.img_enc
        self.transformer = net.transformer
        self.bw = beam_width
        self.seq_len = seq_len
        self.end_tok = end_tok
        self.feats = None
        self.beams = []

    def forward(self, src):
        """Return the best sequence per image, without the leading start token."""
        with torch.no_grad():
            bs = src.size(0)

            # Fresh beams on every call; appending to beams left over from a
            # previous call would misalign hypotheses with this batch.
            self.beams = []
            for _ in range(bs):
                beam = Beam(self.bw)
                beam.add(0.0, False, [1])
                self.beams.append(beam)

            # encode src
            self.feats = self.transformer.encode(self.img_enc(src))

            # Bounded by this instance's seq_len, not a notebook-global.
            for i in tqdm(range(self.seq_len)):
                # gather sequences from beams; combine into tensor (bs*bw) x (i+1)
                prev_seq = torch.from_numpy(np.stack([b.get_seq() for b in self.beams])).view(-1,i+1)

                # top-bw next tokens per hypothesis
                log_probs, chars = self.prob_func(prev_seq.to(src.device))
                log_probs, chars = log_probs.view(bs,-1,self.bw), chars.view(bs,-1,self.bw)

                for j in range(bs):
                    curr_beam = Beam(self.bw)
                    # k follows heap order, the same order get_seq used to build prev_seq
                    for k,(score, complete, seq) in enumerate(self.beams[j]):
                        for l,c in zip(log_probs[j,k],chars[j,k]):
                            log_prob,char = l.item(), c.item()
                            curr_beam.add((score+log_prob), (char==self.end_tok), seq+[char])
                    self.beams[j] = curr_beam

                # stop once the best hypothesis of every beam is complete
                if self.top_complete().all(): break

                # after the first step every beam holds bw hypotheses, so repeat each
                # image's features bw times to line up with the rows of prev_seq
                if i==0: self.feats = repeat_interleave(self.feats, self.bw)

            return self.top_seq()

    def top_complete(self): return np.stack([max(b)[1] for b in self.beams])
    def top_seq(self): return torch.from_numpy(np.stack([max(b)[-1] for b in self.beams]))[:,1:]

    def prob_func(self, tgt):
        """Return (log-probs, token ids) of the `bw` most likely next tokens for each row of `tgt`."""
        mask = subsequent_mask(tgt.size(-1))
        dec_outs = self.transformer.decode(self.feats, tgt, mask)
        logits = self.transformer.generate(dec_outs[:,-1])
        log_probs = logits - torch.logsumexp(logits, -1, keepdim=True) # more stable than F.softmax(logits,-1).log()
        return torch.topk(log_probs, self.bw, dim=-1)

In [None]:
# Beam search (width 3, up to seq_len steps) over the batch `x`, then score against `y`.
search = BeamSearch(learn.model, 3, seq_len)
b_res = search(x)

# bw: 3, sl: 350, ~56s, 0.02491    # no normalized_score

In [None]:
beam_cer(b_res,y)

### Tensors

In [None]:
class BeamSearch(nn.Module):
    """Batched beam search whose hypotheses live entirely in tensors.

    Hypotheses are stored flattened: ``self.beam`` holds one row per (batch element,
    hypothesis) pair, rows of the same batch element are contiguous, and after every step
    they are ordered best-first (``torch.topk`` returns sorted values).

    Args:
        net: model exposing ``img_enc`` and ``transformer`` (``encode``/``decode``/``generate``).
        beam_width: hypotheses kept per batch element; also the number of next tokens proposed
            per hypothesis.
        seq_len: maximum number of decoding steps.
        start_tok: token id every hypothesis starts with; stripped from the returned sequences.
        end_tok: token id checked at the end of each example's best hypothesis to stop early.
            The default 0 matches the id previously hard-coded in ``top_complete``.
    """
    def __init__(self, net, beam_width, seq_len, start_tok=1, end_tok=0):
        super(BeamSearch, self).__init__()
        self.img_enc = net.img_enc
        self.transformer = net.transformer
        self.bw = beam_width
        self.seq_len = seq_len
        self.start_tok = start_tok
        self.end_tok = end_tok

        self.feats = None   # encoder output; repeated bw times per example after the first step
        self.beam = None    # (rows, cur_len) token ids
        self.scores = None  # (rows, 1) cumulative log-probabilities

        net.eval()

    def forward(self, src):
        """Beam-decode a batch of images.

        Returns:
            ``(top_sequences, top_scores)``: the best hypothesis of every batch element (start
            token removed) and its cumulative log-probability.
        """
        with torch.no_grad():
            bs = src.size(0)
            dev = src.device

            # encode src
            self.feats = self.transformer.encode(self.img_enc(src))

            # one hypothesis (the start token) per example for the first step; bw thereafter
            self.beam = torch.full((bs, 1), self.start_tok, device=dev, dtype=torch.long)
            self.scores = torch.zeros((bs, 1), device=dev, dtype=torch.float)
            # (bs, 1) row selector, broadcast against (bs, bw) index tensors below
            rows = torch.arange(bs, device=dev).unsqueeze(1)

            for step in tqdm(range(self.seq_len)):
                cur_len = step + 1

                # top-bw next tokens for every live hypothesis: each (bs*beam, bw)
                log_probs, chars = self.prob_func(self.beam)

                # cumulative score of every (hypothesis, next token) extension
                scores = self.scores + log_probs

                # keep the bw best extensions per example; idxs index the flattened beam*bw axis
                new_scores, idxs = torch.topk(scores.view(bs, -1), self.bw, dim=-1)  # (bs, bw)
                self.scores = new_scores.view(-1, 1)

                # gather the chosen next tokens and their parent prefixes (parent = idx // bw)
                nxt = chars.view(bs, -1)[rows, idxs].view(-1, 1)
                parents = idxs // self.bw
                pre = self.beam.view(bs, -1, cur_len)[rows, parents].view(-1, cur_len)
                self.beam = torch.cat([pre, nxt], dim=1)

                # end when the top hypothesis of every example is complete
                if self.top_complete():
                    break

                # from the second step on there are bw hypotheses per example: expand feats to match
                if step == 0:
                    self.feats = repeat_interleave(self.feats, self.bw)

            return self.top_sequences(), self.top_scores()

    def top_complete(self):
        """True when every example's best hypothesis currently ends with ``end_tok``."""
        return (self.top_sequences()[:, -1] == self.end_tok).all().item()

    def top_sequences(self):
        """Best hypothesis per example (first row of each bw-row group), start token removed."""
        # index the 2-D beam directly: squeeze() would collapse it to 1-D when bs*bw == 1
        return self.beam[0::self.bw, 1:]

    def top_scores(self):
        """Cumulative log-probability of each example's best hypothesis."""
        return self.scores.view(-1)[0::self.bw]

    def prob_func(self, tgt):
        """Top-``bw`` next-token (log-probs, ids) for every prefix row of ``tgt``."""
        mask = subsequent_mask(tgt.size(-1))
        dec_outs = self.transformer.decode(self.feats, tgt, mask)
        logits = self.transformer.generate(dec_outs[:,-1])
        log_probs = logits - torch.logsumexp(logits, -1, keepdim=True) # more stable than F.softmax(logits,-1).log()
        return torch.topk(log_probs, self.bw, dim=-1)

In [None]:
# Beam-decode batch `x` (beam width 3) with the tensor-based BeamSearch; returns best sequences and scores.
search = BeamSearch(learn.model, 3, seq_len)
b_res,score = search(x)

# Notes from earlier runs: beam width, wall time, and (presumably) the beam_cer result.
# lg, sl: 250
# bw: 1, ~12s, 0.02398
# bw: 3, ~28s, 0.02378
# bw: 5, ~46s, 0.02378

# pg: 1000,5
# bs: 3, ~40s, 0.33227  (greedy: 0.28---)
# ''   , ~30s, 0.22635  (greedy: 0.23726)

# pg: 800,8
# bs: 3, ~51s, 0.25424  (greedy: 0.28348)

In [None]:
# Error rate of the tensor beam-search output against targets `y`.
beam_cer(b_res,y)

In [None]:
#beam
# Show the first three images titled with their beam-search transcriptions
# (chunk=55 presumably wraps the text every 55 characters).
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(b_res[i], chunk=55)
    ax=show_img(imgs[i], ax=ax, title=p)

## Single Beam Decode

In [None]:
# https://geekyisawesome.blogspot.com/2016/10/using-beam-search-to-generate-most.html

import heapq

class Beam(object):
    '''
    Bounded min-heap of (score, complete, seq) hypotheses.

    Tuples compare lexicographically, so with equal scores a complete hypothesis ranks above
    an incomplete one: (0.5, False) < (0.5, True). The lowest-ranked entry sits at heap[0]
    and is evicted whenever more than beam_width entries are held. beam_width also serves
    as the admission margin: a candidate must score within beam_width of the best score seen.
    '''

    def __init__(self, beam_width=10):
        self.heap = []
        self.beam_width = beam_width
        self.best_score = None

    def add(self, score, complete, seq):
        # the first candidate always becomes the best; later ones only if strictly higher
        if self.best_score is None or self.best_score < score:
            self.best_score = score

        # admit only candidates within beam_width of the best score so far
        lower_bound = self.best_score - self.beam_width
        if lower_bound < score:
            heapq.heappush(self.heap, (score, complete, seq))

        # over capacity: drop the lowest-ranked hypothesis
        over_capacity = len(self.heap) > self.beam_width
        if over_capacity:
            heapq.heappop(self.heap)

    def __iter__(self):
        # heap order, not sorted order
        return iter(self.heap)

In [None]:
def beamsearch(prob_fn, seq_len, beam_width=5, start_tok=1, end_tok=3):
    """Decode a single sequence with a list/heap based beam search.

    Args:
        prob_fn: callable taking the current token list and returning ``(log_probs, chars)``,
            parallel iterables whose elements support ``.item()`` (e.g. the output of
            ``torch.topk``) describing candidate next tokens.
        seq_len: maximum number of decoding steps.
        beam_width: hypotheses kept per step (``Beam`` also uses it as a score margin).
        start_tok: token that seeds the search; stripped from the returned sequence.
        end_tok: token that marks a hypothesis as complete.

    Returns:
        ``(tokens, score)`` without the leading start token. The search returns as soon as the
        top-scoring hypothesis of a step is complete, otherwise after ``seq_len`` steps.
    """
    prev_beam = Beam(beam_width)
    prev_beam.add(0.0, False, [start_tok])

    for _ in tqdm(range(seq_len)):
        curr_beam = Beam(beam_width)

        for score, complete, seq in prev_beam:
            # completed hypotheses are not carried forward; only a complete *best* one ends the search
            if complete:
                continue
            log_probs, chars = prob_fn(seq)
            for log_prob, char in zip(log_probs, chars):
                log_prob, char = log_prob.item(), char.item()
                # each candidate extends its parent's score independently (log-probs are additive);
                # accumulating into `score` here would also charge it for its sibling candidates
                curr_beam.add(score + log_prob, char == end_tok, seq + [char])

        best_score, best_complete, best_seq = max(curr_beam)
        if best_complete:
            return best_seq[1:], best_score

        prev_beam = curr_beam

    # prev_beam is the last completed step (just the start token when seq_len == 0)
    best_score, _, best_seq = max(prev_beam)
    return best_seq[1:], best_score

In [None]:
def beam_decode(net, src, beam_width, seq_len):
    """Beam-decode ``src`` with ``beamsearch``, using ``prob_func`` bound to ``net`` and its encoding.

    ``net`` is switched to eval mode and decoding runs without autograd. Returns whatever
    ``beamsearch`` returns: ``(tokens, score)``.
    """
    net.eval()
    with torch.no_grad():
        encoded = net.transformer.encode(net.img_enc(src))
        step_fn = partial(prob_func, net=net, feats=encoded, beam_width=beam_width)
        return beamsearch(step_fn, seq_len, beam_width)
    
def prob_func(tgt, net=None, feats=None, beam_width=5):
    """Return the top-``beam_width`` next-token log-probabilities for a single prefix.

    Args:
        tgt: token ids decoded so far (a list); wrapped into a batch of one.
        net: model exposing ``transformer.decode`` and ``transformer.generate``.
        feats: encoder output for the image being decoded.
        beam_width: number of candidate tokens to return.

    Returns:
        ``(values, indices)`` from ``torch.topk`` after squeezing the batch dimension
        (each presumably of length ``beam_width``).
    """
    tgt = torch.tensor([tgt], dtype=torch.long, device=device)
    mask = subsequent_mask(tgt.size(-1))
    dec_outs = net.transformer.decode(feats, tgt, mask)
    logits = net.transformer.generate(dec_outs[:,-1])

    # normalise over the vocabulary in log space (more stable than F.softmax(logits, -1).log());
    # keepdim=True makes the normaliser broadcast per row for any batch size, not only 1
    log_probs = logits - torch.logsumexp(logits, -1, keepdim=True)

    return torch.topk(log_probs.squeeze(0), beam_width, dim=-1)

def score_func(log_probs, i, alpha=0.6):
    """Length-normalise a cumulative log-probability.

    Divides ``log_probs`` by ``((5 + i) / 6) ** alpha`` (the length penalty of Wu et al., 2016),
    where ``i`` is the sequence length (the commented-out call in ``beamsearch`` passes ``len(seq)``).
    """
    return log_probs / math.pow((5 + i) / 6, alpha)

In [None]:
# Pick one example from the batch, keeping a leading batch dim of 1 ([None]).
idx = 2
x1 = x[idx][None]
y1 = y[idx][None]

In [None]:
# Beam-decode a single image (beam width 3) with the list/heap based beam_decode.
# NOTE(review): `image` is not defined in this chunk; the previous cell prepares `x1` and later
# cells score the result against `y1` -- confirm which input is intended here.
b_res, score = beam_decode(learn.model, image, 3, seq_len)    #294, 3m18s
# 294 - 1m40s
# 294 - 1m45s (w/ score_func)
# 295 - 22s (beam_width=1 ~ greedy)

In [None]:
# Wrap the decoded token list into a (1, len) tensor so char_label_text can render it, then score it.
# NOTE(review): `truth` is not defined in this chunk; the "beam (sz=3)" cell uses char_label_text(y1[0]) as the target.
r = torch.tensor([b_res], dtype=torch.long, device=device)
p = char_label_text(r)

_cer(truth, p)

In [None]:
# valid
# Error rate of the stored validation prediction for example idx (v_res presumably from an earlier validation pass).
p = char_label_text(v_res[idx][None])
t = char_label_text(y1[0])
_cer(t,p)

In [None]:
# greedy
# Error rate of the stored greedy prediction for example idx (g_res presumably from an earlier greedy decode).
p = char_label_text(g_res[idx][None])
t = char_label_text(y1[0])
_cer(t,p)

In [None]:
# beam (sz=3)
# Error rate of the single-example beam output; the token list is wrapped into a (1, len) tensor first.
r = torch.tensor([b_res], dtype=torch.long, device=device)
p = char_label_text(r)
t = char_label_text(y1[0])
_cer(t,p)

In [None]:
# inverse vocabulary: token string -> index (if itos contained duplicates, the last index wins)
stoi = dict(zip(itos, range(len(itos))))

In [None]:
# Print the greedy transcription of example idx.
print(char_label_text(g_res[idx][None]))

In [None]:
# Show the (de-normalised) single example titled with its beam transcription, wrapped at 70 characters.
st = ''.join([itos[i] for i in b_res])
p = '\n'.join(textwrap.wrap(st, 70))
show_img(denorm(x1)[0], figsize=(10,10), title=p)

## Source Attn

In [None]:
# Select example idx and upsample its validation (v_*) and greedy (g_*) source-attention maps.
# NOTE(review): torch_scale_attns is defined in a later cell ("Attention Visualizations"); run that first.
idx = 0
img = imgs[idx]

v_chars = v_res[idx]
v_attns = to_np(torch_scale_attns(v_attn)[idx])

g_chars = g_res[idx]
g_attns = to_np(torch_scale_attns(g_attn)[idx])

In [None]:
#valid
# Smoothed validation-run attention for the first 20 decoded characters, overlaid on img.
fig, axes = plt.subplots(5,4, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    # v_attns is prepared by the attention-selection cell (`vv_attns` was an undefined name)
    a = g_filter(v_attns[i])
    ax.imshow(img, alpha=None)
    ax.imshow(a, cmap='Blues', interpolation='nearest', alpha=0.3)
    ax.set_title(itos[v_chars[i].item()])

In [None]:
#greedy
# Smoothed greedy-run attention for the first 24 decoded characters (6x4 grid), overlaid on img.
fig, axes = plt.subplots(6,4, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    a = g_filter(g_attns[i])
    ax.imshow(img, alpha=None)
    ax.imshow(a, cmap='Blues', interpolation='nearest', alpha=0.3)
    ax.set_title(itos[g_chars[i].item()])

# Attention Visualizations

In [None]:
from scipy.ndimage import gaussian_filter
# NOTE(review): k is the Gaussian sigma used by g_filter; the `# sz // k` note in torch_scale_attns
# suggests it also equals the image encoder's downsampling factor -- confirm.
k=16

def torch_scale_attns(attns, size=None):
    """Reshape flattened attention maps into square grids and upsample them.

    Args:
        attns: tensor of shape ``(bs, sl, hw)`` where ``hw`` is a flattened square grid.
        size: output height/width; defaults to the global ``sz``.

    Returns:
        Tensor of shape ``(bs, sl, size, size)``, upsampled with ``F.interpolate``'s default
        (nearest-neighbour) mode.

    Raises:
        ValueError: if ``hw`` is not a perfect square.
    """
    if size is None:
        size = sz
    bs, sl, hw = attns.shape
    num = int(math.sqrt(hw))   # sz // k
    if num * num != hw:
        raise ValueError(f"attention length {hw} is not a square grid")
    # reshape (not view) also accepts non-contiguous inputs
    grid = attns.reshape(bs, sl, num, num)
    return F.interpolate(grid, size=size)  #([bs, sl, h, w])

def g_filter(att, sigma=None):
    """Gaussian-smooth an attention map for display.

    Args:
        att: array-like attention map.
        sigma: standard deviation passed to ``gaussian_filter``; defaults to the global ``k``.
    """
    if sigma is None:
        sigma = k
    return gaussian_filter(att, sigma=sigma)

In [None]:
def self_attn(layer=-1):
    """Return the ``.attn`` tensor stored on decoder layer ``layer``'s self-attention, detached (``.data``) and on CPU.

    The tensor is presumably populated by the most recent forward pass of ``learn.model``.
    """
    decoder_layer = learn.model.transformer.decoder.layers[layer]
    return decoder_layer.self_attn.attn.data.cpu()
def source_attn(layer=-1):
    """Return the ``.attn`` tensor stored on decoder layer ``layer``'s source attention, detached (``.data``) and on CPU.

    The tensor is presumably populated by the most recent forward pass of ``learn.model``.
    """
    decoder_layer = learn.model.transformer.decoder.layers[layer]
    return decoder_layer.src_attn.attn.data.cpu()

In [None]:
def transparent_cmap(cmap, N=5):
    """Return an ``N``-colour version of ``cmap`` whose alpha ramps linearly from 0 to 0.6.

    Args:
        cmap: colormap name (e.g. ``'YlGn'``) or a ``Colormap`` instance.
        N: number of discrete colours.

    NOTE(review): when ``cmap`` is already a ``Colormap`` instance, ``get_cmap`` may return that
    same object, in which case it is modified in place -- confirm before reusing it elsewhere.
    """
    # pyplot.get_cmap replaces plt.cm.get_cmap (deprecated in Matplotlib 3.7, removed in 3.9)
    mycmap = plt.get_cmap(cmap, N)
    mycmap._init()   # build the private lookup table `_lut` so its alpha column can be edited
    # _lut holds N colour rows plus 3 extra (under/over/bad), hence N+3 alpha values
    mycmap._lut[:,-1] = np.linspace(0, 0.6, N+3)
    return mycmap

#Use base cmap to create transparent
# mycmap = transparent_cmap(plt.cm.Reds)

In [None]:
def show_attn(img, attns, chars, ax, color, showChars=True):
    """Overlay smoothed per-character attention maps on ``img`` in ``ax``.

    Args:
        img: image tensor; displayed via ``permute(1,2,0)`` (so presumably channels-first).
        attns: attention maps indexed by decoding step along axis 0.
        chars: token ids, one per step; ids 0-3 are skipped (presumably special tokens -- confirm).
        ax: matplotlib axes to draw on.
        color: base colormap passed to ``transparent_cmap``.
        showChars: if True, write each character near its attention map's centre of mass.
    """
    for i in range(attns.shape[0]):
        c = chars[i].item()
        if c not in [0,1,2,3]:
            a = g_filter(attns[i])
            y,x = scipy.ndimage.center_of_mass(a)
            #sns.heatmap(a, cmap=mycmap, cbar=False, ax=ax)
            ax.imshow(a, cmap=transparent_cmap(color), interpolation='nearest')
            # NOTE(review): other cells map char ids with `itos`; `word_itos` may be a leftover from a word vocab -- confirm
            if showChars: ax.text(x-8,y-10,word_itos[c], fontsize=15)

    ax.set_title(char_label_text(chars))
    # the image is drawn after the overlays, at 60% opacity
    ax.imshow(img.permute(1,2,0), alpha=0.6)

In [None]:
def thresh_attn(attn, thresh=0.1):
    """Zero attention weights below ``thresh`` while always keeping each row's peak.

    Entries ``>= thresh`` are kept and everything else becomes 0. A row whose values are all
    below the threshold would otherwise be blank, so the maximum of every row (last dim) is
    written back.

    Args:
        attn: attention tensor, thresholded along its last dimension (e.g. ``(bs, sl, hw)``).
        thresh: minimum weight to keep.

    Returns:
        A new tensor with the same shape as ``attn``; ``attn`` itself is not modified.
    """
    kept = torch.where(attn >= thresh, attn, torch.zeros_like(attn))
    # re-insert every row's top value in one vectorised scatter (a no-op where it already passed)
    peak_vals, peak_idxs = torch.topk(attn, 1, dim=-1)
    return kept.scatter(-1, peak_idxs, peak_vals)

## Decoder Self-Attention

In [None]:
sns.set_context(context="notebook")

def draw(data, x, y, ax):
    """Heatmap of an attention matrix with x/y tick labels on a fixed [0, 1] colour scale."""
    return sns.heatmap(data, xticklabels=x, square=True, yticklabels=y, vmin=0.0, vmax=1.0,
                       cmap='YlOrRd', linewidths=0.05, cbar=False, ax=ax)

window = slice(20, 40)   # decoder positions shown in every subplot
# decoder inputs = greedy outputs shifted right; identical for every layer/subplot, so compute once
shifted_y = rshift(g_res.float()).long()

for layer in range(4):
    print("Decoder Self-Attention Layer", layer+1)
    attn = self_attn(layer)   # fetched once per layer instead of once per subplot

    fig, axes = plt.subplots(1,4, figsize=(20, 10))
    for i,ax in enumerate(axes.flat):
        # greedy decoding (no access to true values)
        pred = char_split_text(g_res[i])[window]
        true = char_split_text(shifted_y[i])[window]
        g = draw(attn[i].data[window, window], true, pred, ax=ax)
        g.set_yticklabels(g.get_yticklabels(), rotation=0)
        g.set_xticklabels(g.get_xticklabels(), rotation=0)
    plt.show()

## Decoder Source-Attention

In [None]:
# Source attention of the first examples (rows) for each decoder layer (columns).
fig, axes = plt.subplots(4,4, gridspec_kw={'hspace': 0.5}, figsize=(20, 10))
n_rows, n_layers = axes.shape

for h in range(n_layers):
    # fetch and upsample each layer once, and only for the examples being plotted
    attn = source_attn(h)
    scaled = torch_scale_attns(attn[:n_rows])
    for idx in range(n_rows):
        img = x[idx]
        g_chars = g_res[idx]
        g_attns = to_np(scaled[idx])

        show_attn(img, g_attns, g_chars, axes[idx,h], 'YlGn', showChars=False)
        # replaces the transcription title set by show_attn
        axes[idx,h].set_title(f'layer {h+1}')

## Validation vs Greedy (final layer src-attn)

In [None]:
# Threshold, then upsample, the stored source attention of the validation (v_) and greedy (g_) runs.
v_scaled_attns = torch_scale_attns(thresh_attn(v_attn))
g_scaled_attns = torch_scale_attns(thresh_attn(g_attn))

In [None]:
# Side by side for the first two examples: validation attention (left) vs greedy attention (right).
fig, axes = plt.subplots(2,2, gridspec_kw={'hspace': 0.5}, figsize=(20, 20))
for idx in range(len(axes.flat)//2):
    img = imgs[idx]

    v_chars = v_res[idx]
    v_attns = to_np(v_scaled_attns[idx])

    g_chars = g_res[idx]
    g_attns = to_np(g_scaled_attns[idx])
    
    # valid
    show_attn(img, v_attns, v_chars, axes[idx,0], 'YlOrRd')
    # greedy
    show_attn(img, g_attns, g_chars, axes[idx,1], 'YlGn')

# Individual Examination

In [None]:
idx=1
img = x[idx]

g_chars = g_res[idx]
# slice example idx (not example 0) so the attention matches img/g_chars; the slice keeps a bs of 1
g_scaled_attns = torch_scale_attns(thresh_attn(g_attn[idx:idx+1]))  # passing in bs of 1
g_attns = to_np(g_scaled_attns[0])  # removing bs

fig, ax = plt.subplots(1,1, figsize=(20, 20))
show_attn(img, g_attns, g_chars, ax, 'YlGn')

In [None]:
sns.set_context(context="notebook")

def draw(data, x, y, ax):
    """Heatmap of an attention matrix with its strict upper triangle masked out."""
    # hide cells above the diagonal (future positions under the causal decoder mask)
    mask = np.triu(np.ones(data.shape, dtype=bool), k=1)
    return sns.heatmap(data, xticklabels=x, square=True, yticklabels=y, vmin=0.0, vmax=1.0,
                       mask=mask, cmap='YlOrRd', linewidths=0.05, cbar=False, ax=ax)

i = 1                        # example to inspect
window = slice(230, 280)     # decoder positions to show
# labels are identical for every layer, so build them once
# greedy decoding (no access to true values)
shifted_y = rshift(g_res.float()).long()
pred = char_split_text(g_res[i])[window]
true = char_split_text(shifted_y[i])[window]

for layer in range(4):
    print("Decoder Self-Attention Layer", layer+1)

    fig, ax = plt.subplots(1,1, figsize=(20, 10))
    g = draw(self_attn(layer)[i].data[window, window], true, pred, ax=ax)
    g.set_yticklabels(g.get_yticklabels(), rotation=0)
    g.set_xticklabels(g.get_xticklabels(), rotation=0)
    plt.show()

# Backprop - check gradient dependencies between batch examples (batch leakage)

In [None]:
# Grab one training batch to probe whether gradients leak between examples.
xb,yb = next(iter(learn.data.train_dl))

In [None]:
# In train mode, batchnorm normalises with batch statistics, which couples examples; eval mode avoids that.
learn.model.eval()   # this is important!!!  otherwise batchnorm will mess things up
learn.model.zero_grad()

In [None]:
# Track gradients w.r.t. the input images.
xb.requires_grad_(True)
# .grad stays None until the first backward(); only clear it when a previous run populated it
if xb.grad is not None:
    xb.grad.zero_()
None

In [None]:
# Forward pass with the right-shifted targets as decoder input and a causal (subsequent) mask.
shifted_y = rshift(yb).long()
tgt_mask = subsequent_mask(shifted_y.size(-1))
pb = learn.model(xb, shifted_y, tgt_mask)

In [None]:
# loss = learn.loss_func(pb, yb)

In [None]:
# Backprop only example 2's output: if examples are independent, only xb[2] receives gradient.
loss = pb[2].sum()
loss.backward()
assert (xb.grad[2] != 0).any()
assert (xb.grad[1] == 0.).all()

In [None]:
# Inspect the input gradient.
xb.grad