# Prelims

In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
from fastai import *
from fastai.vision import *
from fastai.text import *

import pdb

In [None]:
PATH = Path('data/IAM_handwriting')
TMP_PATH = PATH/'tmp'

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

## Loss and Metrics

In [None]:
def loss_prep(input, target):
    """Pad `input`/`target` to a common sequence length and flatten them for a loss.

    `input` is unpacked as (batch, seq_len, vocab) and `target` as (batch, seq_len).
    Whichever is shorter along the sequence axis is zero-padded at the end.
    Returns `(pred, targ)` with `pred` of shape (batch*seq_len, vocab) and
    `targ` a flat long tensor of length batch*seq_len.
    """
    _, tgt_len = target.shape
    _, pred_len, n_classes = input.shape
    gap = pred_len - tgt_len

    if gap > 0:
        # predictions are longer: extend the targets with zeros on the right
        target = F.pad(target, (0, gap))
    elif gap < 0:
        # targets are longer: append all-zero logit rows along the sequence axis
        # (not ideal -- every padded step gets a full vocab-sized row of zeros)
        input = F.pad(input, (0, 0, 0, -gap))

    flat_pred = input.contiguous().view(-1, n_classes)
    flat_targ = target.contiguous().view(-1).long()
    return flat_pred, flat_targ

In [None]:
class LabelSmoothing(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    Every class receives `smoothing / n_classes` probability mass and the true
    class receives `1 - smoothing`.  The summed KL divergence is divided by the
    batch size of `target`.
    """
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, pred, target):
        # Normalise by the size of *this* batch rather than a notebook-level `bs`,
        # so a smaller final batch is not under-weighted.
        bs = target.size(0)
        pred, targ = loss_prep(pred, target)
        pred = F.log_softmax(pred, dim=-1)  # F.kl_div expects log-probabilities as input
        # constant (non-differentiable) target distribution, e.g. [0.0012, 0.0012, 0.90, 0.0012]
        true_dist = torch.full_like(pred, self.smoothing / pred.size(1))
        true_dist.scatter_(1, targ.unsqueeze(1), self.confidence)
        return F.kl_div(pred, true_dist, reduction='sum') / bs

In [None]:
import Levenshtein as Lev

class CER(Callback):
    """Character error rate metric, averaged over all samples seen in an epoch.

    Each sample contributes the Levenshtein distance between its decoded
    prediction and decoded target, divided by the target length.  `itos` is
    indexed by integer token id to obtain the text for that id.
    """
    def __init__(self, itos):
        super().__init__()
        self.name = 'cer'
        self.itos = itos

    def on_epoch_begin(self, **kwargs):
        "Reset the running error sum and the sample count."
        self.errors, self.total = 0, 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Add this batch's summed per-sample error and its sample count to the totals."
        error,size = self._cer(last_output, last_target)
        self.errors += error
        self.total += size

    def on_epoch_end(self, last_metrics, **kwargs):
        "Report the mean per-sample error for the epoch."
        return add_metrics(last_metrics, self.errors/self.total)

    def _cer(self, preds, targs):
        "Return (sum of per-sample normalised edit distances, number of samples)."
        bs, _ = targs.size()

        res = torch.argmax(preds, dim=2)
        error = 0
        for i in range(bs):
            p = self._char_label_text(res[i])   #.replace(' ', '')
            t = self._char_label_text(targs[i]) #.replace(' ', '')
            # An all-zero target decodes to '' -- normalise by 1 instead of
            # raising ZeroDivisionError.
            error += Lev.distance(t, p)/max(len(t), 1)
        return error, bs

    def _char_label_text(self, pred):
        "Decode a row of token ids to text, dropping every id equal to 0."
        # SentencePiece alternative: return self.sp.DecodeIds(pred.tolist())
        ints = to_np(pred).astype(int)
        nonzero = ints[np.nonzero(ints)]
        return ''.join([self.itos[i] for i in nonzero])

In [None]:
def rshift(tgt, bos_token=1):
    """Shift `tgt` one step to the right by prepending `bos_token`.

    The last column of `tgt` is dropped so the result has the same shape.
    The BOS column is created from `tgt` itself, so it always shares its dtype
    and device and no notebook-level `device` variable is needed.
    """
    bos = tgt.new_full((tgt.size(0), 1), bos_token)
    return torch.cat((bos, tgt[:, :-1]), dim=-1)

def subsequent_mask(size):
    "Lower-triangular byte mask of shape (1, size, size), created on the global `device`."
    full = torch.ones((size, size), device=device).byte()
    lower = torch.tril(full)
    return lower.unsqueeze(0)

class TeacherForce(LearnerCallback):
    """Callback that turns each batch input into `(input, shifted target, mask)`.

    The target is right-shifted with `bos_token` via `rshift`, and a mask from
    `subsequent_mask` is built for the shifted sequence length.
    """
    def __init__(self, learn:Learner, bos_token:int=1):
        super().__init__(learn)
        self.bos_token = bos_token

    def on_batch_begin(self, last_input, last_target, **kwargs):
        "Return the replacement `last_input` tuple; `last_target` is passed through unchanged."
        shifted = rshift(last_target, self.bos_token).long()
        causal_mask = subsequent_mask(shifted.size(-1))
        model_input = (last_input, shifted, causal_mask)
        return {'last_input': model_input, 'last_target': last_target}

# Data

## sm synth dataset

In [None]:
fname = 'edited_sm_synth.csv' #'small_synth_words.csv'
CSV = PATH/fname
FOLDER = 'edited_sm_synth'

df = pd.read_csv(CSV)
len(df)

In [None]:
sz,bs = 128,100
# sz,bs = 256,100

seq_len = 45

## Concat Lines

In [None]:
nums = '3-6'   #'7-10'  #'11-14'
CSV = PATH/f'cat_lines_{nums}.csv'
FOLDER = 'resized_cat_lines'

csv = pd.read_csv(CSV)
test = pd.read_csv(PATH/'test_pg.csv')

len(csv), len(test)

In [None]:
lengths = np.array([len(i.split(' ')) for i in csv.char_ids.values])
lengths.max()

In [None]:
sz,bs = 512,20   #1000,5  #800,8   #512,20
seq_len = 600 #600 #450  #300
stats = (np.array([0.941, 0.941, 0.941], dtype=np.float32), np.array([0.128, 0.128, 0.128], dtype=np.float32))

### Concat All

In [None]:
a = pd.read_csv(PATH/'cat_lines_3.csv').sample(1000)
b = pd.read_csv(PATH/'cat_lines_4.csv').sample(1000)
c = pd.read_csv(PATH/'cat_lines_5.csv').sample(1000)
d = pd.read_csv(PATH/'cat_lines_6.csv').sample(1000)
e = pd.read_csv(PATH/'cat_lines_7.csv').sample(1000)
f = pd.read_csv(PATH/'cat_lines_8.csv').sample(1000)
g = pd.read_csv(PATH/'cat_lines_9.csv').sample(1000)
h = pd.read_csv(PATH/'cat_lines_10.csv').sample(1000)
i = pd.read_csv(PATH/'cat_lines_11.csv').sample(1000)
j = pd.read_csv(PATH/'cat_lines_12.csv').sample(1000)

In [None]:
k = pd.read_csv(PATH/'paragraph_chars.csv')

In [None]:
val_idxs = np.array(k.sample(frac=0.15, random_state=42).index)

In [None]:
trn = k[~k.index.isin(val_idxs)]
test = k[k.index.isin(val_idxs)]

In [None]:
new = pd.concat([a,b,c,d,e,f,g,h,i,j,trn], ignore_index=True)
len(new)

In [None]:
# new.to_csv(PATH/'cat_lines_pg.csv', index=False)
# test.to_csv(PATH/'test_pg.csv', index=False)

## Paragraphs

In [None]:
fname = 'edited_pg.csv' #'paragraphs.csv'
CSV = PATH/fname
FOLDER = 'paragraphs'

df = pd.read_csv(CSV)
len(df)

In [None]:
# sz,bs = 1000,5  #1024,5  #1000,5   #~2000x1000 full size
# sz,bs = 800,8
sz,bs = 512,10

seq_len = 700   #~400 chars/paragraph - max: 705
# stats = (np.array([0.941, 0.941, 0.941], dtype=np.float32), np.array([0.128, 0.128, 0.128], dtype=np.float32))

## Downloaded Images

In [None]:
CSV = PATH/'downloaded_images.csv'
FOLDER = 'downloaded_images'

csv = pd.read_csv(CSV)
len(csv)

In [None]:
# csv['filename'] = csv['filename'].apply(lambda x: f"dl_{x}")
# csv.head()
# csv.to_csv(CSV, index=False)

In [None]:
# sz,bs = 1000,5  #1024,5  #1000,5   #~2000x1000 full size
# sz,bs = 800,8
sz,bs = 512,20

seq_len = 700   #~400 chars/paragraph - max: 705
stats = (np.array([0.941, 0.941, 0.941], dtype=np.float32), np.array([0.128, 0.128, 0.128], dtype=np.float32))

## Test

In [None]:
fname = 'test_pg.csv'
CSV = PATH/fname
FOLDER = 'paragraphs'

df = pd.read_csv(CSV)
len(df)

In [None]:
sz,bs = 512,15

seq_len = 700

## Mix

In [None]:
fname = 'mix.csv'
CSV = PATH/fname
FOLDER = 'mix'

df = pd.read_csv(CSV)
len(df)

In [None]:
sz,bs = 512,15
# sz,bs = 400,20

seq_len = 800

## ModelData

In [None]:
# tfms = get_transforms(do_flip=False, max_zoom=1, max_rotate=2.0, max_warp=0.1)
tfms = []

In [None]:
def label_collater(samples:BatchSamples, pad_idx:int=0):
    """Collate samples into a stacked image batch and end-padded label batch.

    Labels are padded at the end with `pad_idx` up to the longest label in the
    batch.  Returns `(imgs, labels)` where `labels` is a long tensor.
    """
    data = to_data(samples)
    ims, lbls = zip(*data)
    imgs = torch.stack(list(ims))
    # Compare by value: `is 1` tests object identity and only works by accident
    # of CPython's small-int caching (it is a SyntaxWarning on Python 3.8+).
    if len(data) == 1:
        # NOTE(review): a single-sample batch gets a dummy all-zero label instead of
        # its real one -- presumably meant for inference; confirm this is intended.
        labels = torch.zeros(1,1).long()
        return imgs, labels
    max_len = max(len(s) for s in lbls)
    labels = torch.zeros(len(data), max_len).long() + pad_idx
    for i,lbl in enumerate(lbls):
        labels[i,:len(lbl)] = torch.from_numpy(lbl)  # pad at the end
    return imgs, labels

### SentencePiece

In [None]:
class SPMProcessor(PreProcessor):
    "Processor that encodes each item to ids with the `sp` object taken from the item list."
    def __init__(self, ds:ItemList=None):
        self.sp = None if ds is None else ds.sp

    def process_one(self, item):
        "Encode one item with `sp.EncodeAsIds`."
        return self.sp.EncodeAsIds(item)

    def process(self, ds):
        "Delegate to the parent implementation."
        super().process(ds)
    
class SPMList(ItemList):
    "Item list whose items are id sequences decoded to text with the `sp` processor."
    _processor = [SPMProcessor]

    def __init__(self, items:Iterator, sp_processor, **kwargs):
        super().__init__(items, **kwargs)
        self.sp = sp_processor
        # register `sp` so it is carried over when the list is copied
        self.copy_new += ['sp']

    def get(self, i):
        "Return item `i` wrapped in a `Text` together with its decoded string."
        ids = super().get(i)
        decoded = self.sp.DecodeIds(ids)
        return Text(ids, decoded)

    def reconstruct(self, t:Tensor):
        "Rebuild a `Text` item from a tensor of ids."
        decoded = self.sp.DecodeIds(t.tolist())
        return Text(t, decoded)

In [None]:
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load(str(PATH/'spm_train.model'))
sp.SetEncodeExtraOptions("bos:eos")

In [None]:
data = (ImageList.from_df(df, path=PATH, folder=FOLDER)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=SPMList, sp_processor=sp)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
        .normalize()
       )

### Chars

In [None]:
# Load the id -> token list; the context manager closes the file handle promptly.
with open(PATH/'itos.pkl', 'rb') as itos_file:  #'char_itos.pkl'
    itos = pickle.load(itos_file)

In [None]:
class CharTokenizer():
    "Character-level tokenizer that wraps every text in 'xxbos' / 'xxeos' markers."
    def __init__(self, n_cpus:int=None):
        self.n_cpus = ifnone(n_cpus, defaults.cpus)

    def tokenize(self, t:str):
        "Split `t` into single characters, with 'xxbos' first and 'xxeos' last."
        return ['xxbos', *t, 'xxeos']

    def _process_all_1(self, texts:Collection[str]) -> List[List[str]]:
        "Process a list of `texts` in one process."
        tokenized = []
        for text in texts:
            tokenized.append(self.tokenize(str(text)))
        return tokenized

    def process_all(self, texts:Collection[str]) -> List[List[str]]:
        "Process a list of `texts`, in parallel when more than one CPU is configured."
        if self.n_cpus <= 1:
            return self._process_all_1(texts)
        with ProcessPoolExecutor(self.n_cpus) as e:
            chunks = e.map(self._process_all_1, partition_by_cores(texts, self.n_cpus))
            # flatten the per-worker results back into a single list
            return [toks for chunk in chunks for toks in chunk]

class CharVocab(Vocab):
    "Vocabulary over `itos`; tokens missing from it map to index 3 in `stoi`."
    def __init__(self, itos:Collection[str]):
        self.itos = itos
        lookup = {tok: idx for idx, tok in enumerate(self.itos)}
        self.stoi = collections.defaultdict(lambda: 3, lookup)

    def textify(self, nums:Collection[int], sep=''):
        "Decode `nums` without its first and last id; joined by `sep`, or a list when `sep` is None."
        inner = nums[1:-1]  # remove bos,eos tokens
        tokens = [self.itos[i] for i in inner]
        if sep is None:
            return tokens
        return sep.join(tokens)

In [None]:
def reverse_image(image):
    "Return `image` with its colours inverted via `PIL.ImageOps.invert`."
    inverted = PIL.ImageOps.invert(image)
    return inverted

vocab = CharVocab(itos)
procs = [TokenizeProcessor(tokenizer=CharTokenizer(), include_bos=False),
         NumericalizeProcessor(vocab=vocab)]

data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=reverse_image)
        #.split_none()
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=TextList, sep='', pad_idx=0, vocab=vocab, processor=procs)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
        .normalize()
       )

### SequenceList (old)

In [None]:
# df['lens'] = df.label.map(len)
# df.sort_values('lens', ascending=False, inplace=True)
# df.head(5)

In [None]:
# Load the id -> token list; the context manager closes the file handle promptly.
with open(PATH/'itos.pkl', 'rb') as itos_file:  #'char_itos.pkl'
    itos = pickle.load(itos_file)

In [None]:
class SequenceItem(ItemBase):
    "Item holding a sequence of ids plus the `vocab` used to render it as text."
    def __init__(self, data, vocab):
        self.data = data
        self.vocab = vocab

    def __str__(self):
        return self.textify(self.data)

    def __hash__(self):
        return hash(str(self))

    def textify(self, data):
        "Join the vocab entries for every id in `data` except the last one."
        return ''.join(self.vocab[i] for i in data[:-1])
        
class ArrayProcessor(PreProcessor):
    "Convert df column (string of ints) into np.array"
    def __init__(self, ds:ItemList=None):
        # nothing to configure; `ds` is accepted but unused
        pass

    def process_one(self, item):
        "Parse a whitespace-separated string of ints into an int64 array."
        fields = item.split()
        return np.array(fields, dtype=np.int64)

    def process(self, ds):
        "Delegate to the parent implementation."
        super().process(ds)
        
class ItosProcessor(PreProcessor):
    "Copy the `itos` of the list it was created from onto every list it processes."
    def __init__(self, ds:ItemList=None):
        # reads `ds.itos` unconditionally, so `ds` must be provided in practice
        self.itos = ds.itos

    def process(self, ds:ItemList):
        ds.itos = self.itos
        
class SequenceList(ItemList):
    "Item list of id sequences that are rendered through `itos`."
    _processor = [ItosProcessor, ArrayProcessor]

    def __init__(self, items:Iterator, itos:List[str]=None, **kwargs):
        super().__init__(items, **kwargs)
        self.itos = itos
        # register `itos` so it is carried over when the list is copied
        self.copy_new += ['itos']
        self.c = len(self.items)

    def get(self, i):
        "Return item `i` wrapped in a `SequenceItem`."
        return SequenceItem(super().get(i), self.itos)

    def reconstruct(self, t):
        "Turn a padded tensor back into a `SequenceItem`, dropping the 0 padding."
        arr = t.numpy()
        unpadded = arr[np.nonzero(arr)]
        return SequenceItem(unpadded, self.itos)

    def analyze_pred(self, pred):
        # Called by learn.predict() / learn.show_results() to turn raw
        # predictions into a tensor suitable for `reconstruct`.
        return pred.argmax(dim=-1)

In [None]:
# Old SequenceList-based pipeline: random 85/15 split, labels read through SequenceList.
# NOTE(review): `custom_collater` is not defined anywhere in the visible notebook
# (only `label_collater` is) -- presumably a leftover name from an older version;
# confirm before running this cell.
data = (ImageList.from_df(df, path=PATH, folder=FOLDER)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=SequenceList, itos=itos)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=custom_collater)
        .normalize()
       )

In [None]:
# # This slows down training...
# class CustomSampler(Sampler):
#     "sort dataset by longest y sequences"
#     def __init__(self, dataset):
#         self.dataset = dataset
#         self.lengths = [len(i) for i in dataset.y.items]
#         self.sorted_idxs = np.flip(np.argsort(self.lengths))
#     def __len__(self): return len(self.dataset)
#     def __iter__(self): return iter(self.sorted_idxs)

# ds = data.train_ds
# tfms = data.train_dl.tfms
# sampler = CustomSampler(ds)
# dl = DataLoader(ds, bs, shuffle=False, sampler=sampler, num_workers=num_cpus(), collate_fn=custom_collater, drop_last=True)
# ddl = DeviceDataLoader(dl, device, tfms, custom_collater)
# data.train_dl = ddl

In [None]:
data.batch_stats()
# no normalization: [tensor([0.9403, 0.9403, 0.9403]), tensor([0.1604, 0.1604, 0.1604])]
# imagenet_stats:   [tensor([1.9973, 2.1714, 2.3839]), tensor([0.6974, 0.7130, 0.7098])]
# normalize():      [tensor([0.0225, 0.0225, 0.0225]), tensor([0.9676, 0.9676, 0.9676])]

### Display

In [None]:
data.show_batch(rows=2, ds_type=DatasetType.Train)

# Transformer Modules

In [None]:
LayerNorm = partial(nn.LayerNorm, eps=1e-4)  # accommodates mixed precision training

In [None]:
class SublayerConnection(nn.Module):
    "Pre-norm residual wrapper: returns `x + dropout(sublayer(norm(x)))`."
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch

In [None]:
def clones(module, N):
    "Return an `nn.ModuleList` holding N independent deep copies of `module`."
    copies = []
    for _ in range(N):
        copies.append(deepcopy(module))
    return nn.ModuleList(copies)

In [None]:
class Encoder(nn.Module):
    "Stack of N copies of `layer`, followed by a final layer norm."
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        return self.norm(out)

In [None]:
class EncoderLayer(nn.Module):
    "Encoder layer: self-attention then feed-forward, each wrapped in a `SublayerConnection`."
    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)

    def forward(self, x, mask=None):
        def attend(h):
            # query, key and value are all the (normed) layer input
            return self.self_attn(h, h, h, mask)

        attended = self.sublayer[0](x, attend)
        return self.sublayer[1](attended, self.feed_forward)

In [None]:
class Decoder(nn.Module):
    "Stack of N copies of `layer` with target masking, followed by a final layer norm."
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, src, tgt_mask=None):
        out = x
        for block in self.layers:
            out = block(out, src, tgt_mask)
        return self.norm(out)

In [None]:
class DecoderLayer(nn.Module):
    "Decoder layer: masked self-attention, attention over `src`, then feed-forward."
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # one residual/dropout/norm wrapper per sub-layer
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, src, tgt_mask=None):
        def attend_self(h):
            return self.self_attn(h, h, h, tgt_mask)

        def attend_src(h):
            # queries come from the decoder state, keys/values from `src`
            return self.src_attn(h, src, src)

        x = self.sublayer[0](x, attend_self)
        x = self.sublayer[1](x, attend_src)
        return self.sublayer[2](x, self.feed_forward)

In [None]:
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Returns `(weighted values, attention weights)`.  Positions where `mask == 0`
    get a score of -1e4 before the softmax (not -1e9, to accommodate mixed
    precision).  `dropout`, when given, is applied to the attention weights.
    """
    d_k = query.size(-1)
    scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e4)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return weights @ value, weights

In [None]:
class SingleHeadedAttention(nn.Module):
    "Single-head attention with linear projections for query, key, value and output."
    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # attention weights of the most recent forward pass
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        q_proj, k_proj, v_proj, out_proj = self.linears
        query = q_proj(query)
        key = k_proj(key)
        value = v_proj(value)
        context, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        return out_proj(context)

In [None]:
class MultiHeadedAttention(nn.Module):
    "Multi-head attention over `h` heads of width `d_model // h` (d_v is assumed equal to d_k)."
    def __init__(self, d_model, h=8, dropout=0.2):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # attention weights of the most recent forward pass
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        if mask is not None:
            # add a head axis so the mask is shared by every head
            mask = mask.unsqueeze(1)
        n_batch = q.size(0)

        def split_heads(proj, t):
            # project, then reshape d_model => h x d_k and move the head axis forward
            return proj(t).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)

        q = split_heads(self.linears[0], q)
        k = split_heads(self.linears[1], k)
        v = split_heads(self.linears[2], v)

        # attention is applied to all heads at once
        heads, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)

        # "concatenate" the heads with a view, then apply the output projection
        merged = heads.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        return self.linears[-1](merged)

In [None]:
class GeLU(nn.Module):
    "GELU activation computed with the tanh approximation."
    def forward(self, x):
        inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))

class PositionwiseFeedForward(nn.Module):
    "Two-layer feed-forward block: d_model -> 4*d_model -> d_model with GeLU and dropout."
    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        inner = d_model * 4
        self.w_1 = nn.Linear(d_model, inner)
        self.w_2 = nn.Linear(inner, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = GeLU()  # alternative: nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.activation(self.w_1(x))
        return self.w_2(self.dropout(hidden))

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    Even feature columns hold sines and odd columns hold cosines.  The table is
    precomputed for `max_len` positions and stored as the buffer `pe` with
    shape (1, max_len, d_model), so it is not a parameter but is still saved in
    the state_dict.  Both even and odd `d_model` are supported.
    """
    def __init__(self, d_model, dropout=0.1, max_len=2000):
        super().__init__()

        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0.0, max_len).unsqueeze(1)
        log_increment = math.log(1e4) / d_model
        div_term = torch.exp(torch.arange(0.0, d_model, 2) * -log_increment)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so trim the frequencies to fit; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        pe.unsqueeze_(0)

        self.register_buffer('pe', pe)    #(1,max_len,d_model)

    def forward(self, x):
        # slice the table to the sequence length of `x` (its dimension 1)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [None]:
class Embeddings(nn.Module):
    "Embedding lookup over `vocab` ids, scaled by sqrt(d_model)."
    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale

# Full Arch

In [None]:
class EncoderDecoder(nn.Module):
    "Container wiring an encoder, a decoder, a target embedding and an output generator."
    def __init__(self, encoder, decoder, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, tgt_mask=None):
        "Encode `src`, then decode `tgt` against the encoded result."
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask)

    def encode(self, src):
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        "Embed `tgt` and run the decoder over it with `src` as the attended memory."
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, src, tgt_mask)

    def generate(self, outs):
        return self.generator(outs)

In [None]:
class ResnetBase(nn.Module):
    """Image encoder: truncated resnet34 followed by a linear projection to `d_model`.

    `em_sz` must be 128, 256 or 512 (any other value raises KeyError); it selects
    how many trailing resnet34 children are dropped and is the input width of
    the projection.
    """
    def __init__(self, em_sz, d_model):
        super().__init__()

        # slice end applied to the list of resnet34 children, keyed by em_sz
        cut = {128: -4, 256: -3, 512: -2}[em_sz]

        backbone = models.resnet34(False)
        kept = list(backbone.children())[:cut]
        self.base = nn.Sequential(*kept)

        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        feats = self.base(x)
        # merge dims 2 and 3 into one axis and move it in front of the channel axis
        seq = feats.flatten(2,3).permute(0,2,1)
        return self.linear(seq) * 8

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2):
    """Build an `EncoderDecoder` transformer with N encoder and N decoder layers.

    Every layer gets its own deep copy of the attention and feed-forward
    prototypes.  All parameters with more than one dimension are re-initialised
    with Xavier-uniform.
    """
    attn = SingleHeadedAttention(d_model)
    # attn = MultiHeadedAttention(d_model, 4)
    ff = PositionwiseFeedForward(d_model, drops)

    enc_layer = EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops)
    encoder = Encoder(enc_layer, N)

    dec_layer = DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops)
    decoder = Decoder(dec_layer, N)

    tgt_embed = nn.Sequential(
        Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000)
    )
    generator = nn.Linear(d_model, vocab)

    model = EncoderDecoder(encoder, decoder, tgt_embed, generator)

    for param in model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)

    return model

In [None]:
class Img2Seq(nn.Module):
    """Image-to-sequence model: an image encoder feeding a transformer.

    With `tgt` given, runs one teacher-forced pass.  Without `tgt`, decodes
    greedily for at most `seq_len` steps under `torch.no_grad()`.
    """
    def __init__(self, img_encoder, transformer, seq_len=500):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer
        self.seq_len = seq_len

    def forward(self, src, tgt=None, tgt_mask=None):
        # training
        if tgt is not None:
            feats = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt, tgt_mask)    # ([bs, sl, d_model])
            return self.transformer.generate(dec_outs)           # ([bs, sl, vocab])

        # inference (greedy decode)
        with torch.no_grad():
            feats = self.transformer.encode(self.img_enc(src))
            n_imgs = src.size(0)
            # every sequence starts from token id 1
            tgt = torch.ones((n_imgs,1), dtype=torch.long, device=device)

            step_outputs = []
            for _ in progress_bar(range(self.seq_len)):
                mask = subsequent_mask(tgt.size(-1))
                dec_outs = self.transformer.decode(feats, tgt, mask)
                logits = self.transformer.generate(dec_outs[:,-1])
                step_outputs.append(logits)
                pred = torch.argmax(logits, dim=-1, keepdim=True)
                # stop once every sequence in the batch predicts id 0
                if (pred==0).all():
                    break
                tgt = torch.cat([tgt,pred], dim=-1)
            return torch.stack(step_outputs).transpose(1,0).contiguous()

In [None]:
d_model = 512
em_sz = 256
img_encoder = ResnetBase(em_sz, d_model)
transformer = make_full_model(len(itos), d_model)  #len(itos)
net = Img2Seq(img_encoder, transformer, seq_len)

# AdamW16 = partial(optim.Adam, betas=(0.9,0.99), eps=1e-4)  ->  #.to_fp16(max_scale=256)
# partial: way to always call a function with a given set of arguments or keywords

learn = Learner(data, net, loss_func=LabelSmoothing(smoothing=0.1),
                metrics=[CER(itos)], callback_fns=TeacherForce)
learn.clip_grad(0.25)
None

# DenseNet Arch

In [None]:
class EncoderDecoder(nn.Module):
    "Container wiring an encoder, a decoder, a target embedding and an output generator."
    def __init__(self, encoder, decoder, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, tgt_mask=None):
        "Encode `src`, then decode `tgt` against the encoded result."
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask)

    def encode(self, src):
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        "Embed `tgt` and run the decoder over it with `src` as the attended memory."
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, src, tgt_mask)

    def generate(self, outs):
        return self.generator(outs)

In [None]:
from torchvision.models import mobilenet_v2, densenet201, densenet121

In [None]:
net = densenet201(True)

In [None]:
def replace_pooling_layers(m):
    "Replace the MaxPool2d entries of `m` with AdaptiveAvgPool2d, in place, and return `m`."
    # replace MaxPool2d w/ AdaptiveAvgPool2d
    # Positions of MaxPool2d instances in the `m.modules()` enumeration.
    # NOTE(review): `modules()` walks the module tree recursively, starting with `m`
    # itself, so these positions only match `m[i]` indexing for particular
    # layouts of `m` -- confirm against the model this is applied to.
    pool_layers = [i for i,o in enumerate(m.modules()) if isinstance(o,nn.MaxPool2d)]
    # `out_channels` of the entry three positions before each pool
    # (presumably the preceding conv layer -- TODO confirm)
    conv_outs = [m[i-3].out_channels for i in pool_layers]

    for c,o in zip(pool_layers, conv_outs):
        # NOTE(review): the argument of AdaptiveAvgPool2d is its output size; a
        # channel count is passed here -- confirm this is intended
        m[c] = nn.AdaptiveAvgPool2d(o)
    return m


# find all the layers before maxpool layers -> typically the best representation available at that grid size
block_ends = [i-1 for i,o in enumerate(children(m_vgg)) if isinstance(o,nn.AdaptiveAvgPool2d)]
block_ends

In [None]:
head = Lambda(lambda x: pdb.set_trace())

m = nn.Sequential(*list(children(net)[0]), head)
# m = nn.Sequential(*flatten_model(net)[:-1], head)

In [None]:
for i,layer in enumerate(m):
    if isinstance(layer, nn.AvgPool2d):
        print(layer, i)
#         m[i] = nn.AvgPool2d((2,1))

In [None]:
m[5][3] = nn.AvgPool2d((2,1))
# bs, 1920, 4, 32 => [5,7,9][3]

In [None]:
l = Learner(data, m, loss_func=LabelSmoothing(smoothing=0.1))
l.fit_one_cycle(1, 1e-3)

In [None]:
for i,layer in enumerate(list(net.modules())):
    if isinstance(layer, nn.AvgPool2d): print(layer, i)

In [None]:
for layer in list(net.features):
    if isinstance(layer, nn.AvgPool2d): print(layer)

In [None]:
class CNNBase(nn.Module):
    "Image encoder: the first child of a pretrained densenet201, then a 1920 -> `d_model` projection."
    def __init__(self, d_model):
        super().__init__()

        backbone = densenet201(True)
        # keep only the first child of the network (alternative tried: [:-3])
        # [0] => bs,1920,4,4
        features = list(backbone.children())[0]
        self.base = nn.Sequential(*features)

        self.linear = nn.Linear(1920, d_model)

    def forward(self, x):
        feats = self.base(x)
        # merge dims 2 and 3 into one axis and move it in front of the channel axis
        seq = feats.flatten(2,3).permute(0,2,1)
        return self.linear(seq)  # previously scaled by 2 or 8

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2):
    """Build an `EncoderDecoder` transformer with N encoder and N decoder layers.

    Every layer gets its own deep copy of the attention and feed-forward
    prototypes.  All parameters with more than one dimension are re-initialised
    with Xavier-uniform.
    """
    attn = SingleHeadedAttention(d_model)
    # attn = MultiHeadedAttention(d_model, 4)
    ff = PositionwiseFeedForward(d_model, drops)

    enc_layer = EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops)
    encoder = Encoder(enc_layer, N)

    dec_layer = DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops)
    decoder = Decoder(dec_layer, N)

    tgt_embed = nn.Sequential(
        Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000)
    )
    generator = nn.Linear(d_model, vocab)

    model = EncoderDecoder(encoder, decoder, tgt_embed, generator)

    for param in model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)

    return model

In [None]:
class Img2Seq(nn.Module):
    """Image-to-sequence model: an image encoder feeding a transformer.

    With `tgt` given, runs one teacher-forced pass.  Without `tgt`, decodes
    greedily for at most `seq_len` steps under `torch.no_grad()`.
    """
    def __init__(self, img_encoder, transformer, seq_len=500):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer
        self.seq_len = seq_len

    def forward(self, src, tgt=None, tgt_mask=None):
        # training
        if tgt is not None:
            feats = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt, tgt_mask)    # ([bs, sl, d_model])
            return self.transformer.generate(dec_outs)           # ([bs, sl, vocab])

        # inference (greedy decode)
        with torch.no_grad():
            feats = self.transformer.encode(self.img_enc(src))
            n_imgs = src.size(0)
            # every sequence starts from token id 1
            tgt = torch.ones((n_imgs,1), dtype=torch.long, device=device)

            step_outputs = []
            for _ in progress_bar(range(self.seq_len)):
                mask = subsequent_mask(tgt.size(-1))
                dec_outs = self.transformer.decode(feats, tgt, mask)
                logits = self.transformer.generate(dec_outs[:,-1])
                step_outputs.append(logits)
                pred = torch.argmax(logits, dim=-1, keepdim=True)
                # stop once every sequence in the batch predicts id 0
                if (pred==0).all():
                    break
                tgt = torch.cat([tgt,pred], dim=-1)
            return torch.stack(step_outputs).transpose(1,0).contiguous()

In [None]:
d_model = 512

img_encoder = CNNBase(d_model)
transformer = make_full_model(len(itos), d_model)
net = Img2Seq(img_encoder, transformer, seq_len)

# AdamW16 = partial(optim.Adam, betas=(0.9,0.99), eps=1e-4)  ->  #.to_fp16(max_scale=256)
# AMSGrad = partial(optim.Adam, amsgrad=True, eps=1e-4)
# partial: way to always call a function with a given set of arguments or keywords

learn = Learner(data, net, loss_func=LabelSmoothing(smoothing=0.1),
                metrics=[CER(itos)], callback_fns=TeacherForce)
learn.clip_grad(0.25)
None

# Experiment

In [None]:
learn.lr_find()
learn.recorder.plot()

In [None]:
learn.fit_one_cycle(3, max_lr=1e-3)
# 3 cycles; 1e-3
# sm_synth, sz:128, bs: 100
# 15.269251	13.785835	0.106348   base arch - (gelu, single attn, pre-resnet34)
# 16.347567	14.436943	0.111617   "", reversed images - doesn't perform well from pretraining???
#     normal images, resnet34 (no pretraining)
# 25.128386	21.791134	0.185044   reversed_images, resnet34 (no pretraining) - not sure what the initialization is like yet...
#    densenet201 (no pretraining)
#    modified densenet201
#    pre-densenet201



### Training - Edited Data - GeLU, singlehead attn
# sm
# 10.841042	9.550449	0.071475   sz:128, bs:100, lr:1e-3  'v1_gelu_128'
# 2.631370	2.669811	0.020308   sz:256, bs:100, lr:1e-3  'v1_gelu_256'
# mix
# 18.652050	14.330537	0.021167   sz:400, bs:20, lr:1e-3   'v1_edited_gelu_400'
# 11.513125	8.577786	0.016521   sz:512, bs:15, lr:5e-5, 3cycle   'v1_edited_gelu_512'

In [None]:
learn.save('v1_edited_gelu_512')

In [None]:
learn.export()

In [None]:
# learn = load_learner(PATH)

# Test

In [None]:
from scipy.ndimage import gaussian_filter
k=16

def torch_scale_attns(attns):
    """Reshape flat attention maps to square grids and resize them to the global `sz`.

    `attns` is unpacked as three dimensions; the last one is viewed as a
    `side x side` grid with `side = int(sqrt(last_dim))`.  Returns a 4-D tensor
    ([bs, sl, h, w]).
    """
    n_batch, n_steps, n_cells = attns.shape
    side = int(math.sqrt(n_cells))   # sz // k
    grid = attns.view(n_batch, n_steps, side, side)
    return F.interpolate(grid, size=sz)

def g_filter(att):
    "Smooth `att` with a Gaussian filter whose sigma is the global `k`."
    smoothed = gaussian_filter(att, sigma=k)
    return smoothed

In [None]:
def self_attn(layer=-1):
    "Return a CPU copy of the stored self-attention weights of decoder layer `layer` of `learn.model`."
    dec_layer = learn.model.transformer.decoder.layers[layer]
    return dec_layer.self_attn.attn.data.cpu()
def source_attn(layer=-1):
    "Return a CPU copy of the stored source-attention weights of decoder layer `layer` of `learn.model`."
    dec_layer = learn.model.transformer.decoder.layers[layer]
    return dec_layer.src_attn.attn.data.cpu()

In [None]:
def cer(preds, targs):
    """Mean character error rate of a batch.

    `preds` is arg-maxed over dimension 2 and both predictions and targets are
    decoded with `char_label_text`.  Each sample's Levenshtein distance is
    divided by its target length, and the mean over the batch is returned.
    """
    bs, _ = targs.size()

    res = torch.argmax(preds, dim=2)
    error = 0
    for i in range(bs):
        p = char_label_text(res[i])   #.replace(' ', '')
        t = char_label_text(targs[i]) #.replace(' ', '')
        # An all-zero target decodes to '' -- normalise by 1 instead of
        # raising ZeroDivisionError.
        error += Lev.distance(t, p)/max(len(t), 1)
    return error/bs

def char_label_text(pred):
    "Decode a row of token ids to text with the global `itos`, dropping every id equal to 0."
    # SentencePiece alternative: return sp.DecodeIds(pred.tolist())
    ids = to_np(pred).astype(int)
    kept = ids[np.nonzero(ids)]
    return ''.join(itos[i] for i in kept)

In [None]:
def show_img(im, figsize=None, ax=None, alpha=None, title=None):
    """Draw `im` on `ax` (a new axes is created when `ax` is None) and return the axes.

    Inputs whose first dimension is 3 are converted with `image2np(im.data)`
    first.  Inputs without a `shape` attribute (e.g. PIL images) are passed to
    `imshow` unchanged instead of raising AttributeError.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    shape = getattr(im, 'shape', None)
    if shape and shape[0] == 3:
        im = image2np(im.data)
    ax.imshow(im, alpha=alpha)
    if title:
        ax.set_title(title)
    return ax

In [None]:
x,y = next(iter(learn.data.train_dl))
# x,y = learn.data.one_batch(ds_type=DatasetType.Train)
imgs = learn.data.denorm(x)

## Uploaded Images

In [None]:
def thresh_edit(fname, thresh=100, bg=245):
    """Return a grayscale copy of the image at `fname` with bright pixels flattened.

    Every pixel whose grayscale value is greater than `thresh` is set to `bg`.
    """
    # NOTE(review): relies on `Image` being PIL's Image module/class -- confirm it
    # is not shadowed by another `Image` brought in by a star import.
    # The context manager closes the opened file; the original only closed the
    # converted copy and left the opened image to the garbage collector.
    with Image.open(fname) as src:
        np_im = np.array(src.convert('L'))  # grayscale
    np_im[np_im > thresh] = bg
    return Image.fromarray(np_im, 'L')

In [None]:
im = thresh_edit(PATH/'test3.png')
show_img(im, figsize=(15,15))

In [None]:
e = 'edit_'+str(fname.name)
edited_fname = PATH/e

In [None]:
edited_im.save(edited_fname)

In [None]:
fname = PATH/'test/edit_test4.png'
im = open_image(fname)

In [None]:
seq,res,preds = learn.predict(im)
print(seq)

In [None]:
# r = torch.tensor([g_res], dtype=torch.long, device=device)
truth = "This is a test letter. I hope this\nworks but I'm not sure it will.\nMy handwriting is not very good."
Lev.distance(truth, str(seq))/len(truth)

## Results

In [None]:
# x,y = next(v_dl)
# imgs = denorm(x)

learn.model.eval()

shifted_y = rshift(y).long()
tgt_mask = subsequent_mask(shifted_y.size(-1))
v_preds = learn.model(x, shifted_y, tgt_mask)
v_res = torch.argmax(v_preds, dim=-1)
# v_attn = source_attn()

g_preds = learn.model(x)
g_res = torch.argmax(g_preds, dim=-1)
# g_attn = source_attn()

In [None]:
v = [learn.loss_func(v_preds, y).item(), cer(v_preds, y)]
print(f'valid:     {str(v[0])[:7]}   {str(v[1])[:7]}')

g = [learn.loss_func(g_preds, y).item(), cer(g_preds, y)]
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[:7]}')

In [None]:
#valid
# first three images of the batch, each titled with the text decoded from `v_res`
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(v_res[i])
    ax=show_img(imgs[i], ax=ax, title=p)

In [None]:
#greedy
# first three images of the batch, each titled with the text decoded from `g_res`
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(g_res[i])
    ax=show_img(imgs[i], ax=ax, title=p)

In [None]:
# tfmr full
#             loss       cer       acc
# valid:     2.42856   0.02981   0.98177   3x1, 256/60, 'tfmr_full_3x1_single_attn'
# greedy:    11.9438   0.02745   0.93385

# valid:     1.98831   0.01858   0.98505   3x1, 256/60, 'exp_3x1_256'
# greedy:    16.3832   0.02167   0.90944

# valid:     1.47604   0.00797   0.99523   3x2, 400/45,  'tfmr_full_3x2'
# greedy:    21.4682   0.01097   0.93762

# valid:     7.30494   0.01206   0.99086   lg,  512/30,  'tfmr_full_lg'
# greedy:    434.610   0.02398   0.69553

# valid:     25.4131   0.04500   0.96773   lg, 512/30,   'tfmr_lg_LM_mixer'
# greedy:    734.834   0.15256   0.48912

# valid:     10.1481   0.00501   0.99602   cat9-12,  800/8,  'tfmr_cat9-12_full_800'
# greedy:    1330.63   0.04525   0.53181

# valid:     62.0793   0.05716   0.96214   pg,  512/20  'tfmr_full_paragraph'  (cpu)
# greedy:    1739.62   0.08347   0.43309
# beam:                0.07958
# valid:     71.9120   0.06819   0.95457   "", 2nd batch (gpu)
# greedy:    2027.24   0.11823   0.39238
# beam:                0.10697
# valid:     36.3009   0.03397   0.97633   "", 3rd batch (gpu)
# greedy:    1759.59   0.07070   0.40949
    
# valid:     50.1414   0.04137   0.97004   pg,  1000/5,  'tfmr_pg_1000'
# greedy:    2579.24   0.34178   0.18623

# valid:     56.2722   0.04167   0.96923   pg,  1000/5,  'tfmr_catpg_1000'
# greedy:    2432.54   0.37192   0.26174


# valid:     3.43745   0.05444   0.98735   mix(new)  'tfmr_mix_words_400'
# greedy:    43.6545   0.06126   0.91407

# valid:     2.96891   0.03338   0.98932   mix(new)  'tfmr_mix_words_512'
# greedy:    28.9982   0.03949   0.94500
# valid:     2.02411   0.02128   0.99302   5 more cycles
# greedy:    19.4735   0.02450   0.96104

# valid:     41.9800   0.03879   0.97535   pg, 'tfmr_mix_words_512'
# greedy:    1602.68   0.06259   0.49110


# valid:     53.9049   0.04822   0.96622   pg, 'tfmr_8head_512_mix'
# greedy:    1384.72   0.06304   0.51696


# valid:     5.24461   0.00634   mix   'v1_gelu_512'
# greedy:    217.438   0.02028
# valid:     24.3609   0.02405   pg
# greedy:    1935.30   0.05970

# valid:     12.7074   0.00954   mix   'v1_gelu_512_wiki103_base_lm_last_layer'
# greedy:    553.201   0.02814

# valid:     11.4251   0.00474   mix   'v1_gelu_512_wiki103_base_lm_full'
# greedy:    358.030   0.01187

# valid:     12.1104   0.01533   mix   'v1_gelu_multi_512'
# greedy:    506.505   0.04229


# edited data
# valid:     8.21437   0.06576   sm   'v1_gelu_128'
# greedy:    185.981   0.12890

# valid:     2.25060   0.01731   sm   'v1_gelu_256'
# greedy:    149.525   0.06003

# valid:     2.37294   0.01383   mix  'v1_edited_gelu_400'
# greedy:    30.3637   0.01732
# greedy:    1810.14   0.11409   test

# valid:     9.09986   0.01656   mix  'v1_edited_gelu_512'
# greedy:    529.308   0.03104
# greedy:    1675.20   0.10819   test

# Beam Search

## Multi-Batch Beam Decode

In [None]:
def beam_cer(preds, targs):
    "Average `_cer(target_text, predicted_text)` over the `preds.size(0)` rows of the batch."
    n_rows = preds.size(0)
    total = 0
    for k in range(n_rows):
        decoded_pred = char_label_text(preds[k])
        decoded_targ = char_label_text(targs[k])
        total = total + _cer(decoded_targ, decoded_pred)
    return total/n_rows

In [None]:
def repeat_interleave(tensor,n):
    """Repeat every entry along dim 0 `n` times in a row: [a, b] -> [a, a, b, b] for n=2.

    tensor: tensor with at least one dimension.
    n: number of consecutive copies of each entry.

    Returns a new tensor (a copy, never a view of `tensor`) of shape
    `(tensor.size(0)*n, *tensor.shape[1:])`. An empty batch or `n == 0`
    yields an empty result.
    """
    shape = tensor.shape
    # insert a copy axis right after dim 0, tile along it, then fold it back into dim 0
    reps = [1, n] + [1]*(tensor.dim()-1)
    return tensor.unsqueeze(1).repeat(*reps).view(shape[0]*n, *shape[1:])

In [None]:
# def normalize_score(score, i, alpha=0.6):
#     length_penalty = math.pow((5 + i)/6, alpha)
#     return score/length_penalty

### Beam Object

In [None]:
import heapq

class Beam:
    "Min-heap of `(score, complete, seq)` tuples that keeps at most `beam_width` highest-ranked entries."

    def __init__(self, beam_width=10):
        self.heap = []
        self.beam_width = beam_width
        self.best_score = None

    def add(self, score, complete, seq):
        "Insert a hypothesis; if that overfills the heap, discard the lowest-ranked tuple."
        entry = (score, complete, seq)
        heapq.heappush(self.heap, entry)
        overfull = len(self.heap) > self.beam_width
        if overfull:
            heapq.heappop(self.heap)

    def get_seq(self):
        "Token sequences of all stored hypotheses, in heap (not sorted) order."
        return [seq for _, _, seq in self.heap]

    def __iter__(self):
        # iteration follows the same heap order as `get_seq`
        return iter(self.heap)

In [None]:
class BeamSearch(nn.Module):
    """Batched beam search that keeps one heap-backed `Beam` per input image.

    net: model exposing `img_enc` and `transformer` (the latter with
        `encode`, `decode` and `generate`).
    beam_width: hypotheses kept per image, and candidates expanded per hypothesis.
    seq_len: maximum number of decoding steps.
    end_tok: token id that marks a hypothesis as complete.

    `forward(src)` returns an integer tensor with one row per image: the
    highest-ranked hypothesis without its leading start token (id 1).
    """
    def __init__(self, net, beam_width, seq_len, end_tok=0):
        super(BeamSearch, self).__init__()
        self.img_enc = net.img_enc
        self.transformer = net.transformer
        self.bw = beam_width
        self.seq_len = seq_len
        self.end_tok = end_tok
        self.feats = None
        self.beams = []

    def forward(self, src):
        with torch.no_grad():
            bs = src.size(0)

            # start every call with fresh beams so repeated calls do not pile up
            # hypotheses from earlier batches; each beam begins with the start token
            self.beams = []
            for _ in range(bs):
                beam = Beam(self.bw)
                beam.add(0.0, False, [1])
                self.beams.append(beam)

            # encode src
            self.feats = self.transformer.encode(self.img_enc(src))

            # use the configured length rather than a notebook-level global
            for i in tqdm(range(self.seq_len)):
                # gather sequences from beams; combine into tensor (bs*bw)
                prev_seq = torch.from_numpy(np.stack([b.get_seq() for b in self.beams])).view(-1,i+1)

                # generate new possibilities; convert once to nested lists
                # instead of calling .item() on every element
                log_probs, chars = self.prob_func(prev_seq.to(device))
                log_probs = log_probs.view(bs,-1,self.bw).tolist()
                chars = chars.view(bs,-1,self.bw).tolist()

                for j in range(bs):
                    curr_beam = Beam(self.bw)
                    # beams iterate in the same heap order `get_seq` used above,
                    # so index k lines up with the rows of `prev_seq`
                    # NOTE: hypotheses are expanded even after emitting `end_tok`;
                    # `complete` only reflects the most recent token
                    for k,(score, complete, seq) in enumerate(self.beams[j]):
                        for log_prob,char in zip(log_probs[j][k], chars[j][k]):
                            curr_beam.add((score+log_prob), (char==self.end_tok), seq+[char])
                    self.beams[j] = curr_beam

                # return if all max beams are complete
                if (self.top_complete()==True).all(): break

                # expand feats to match beam size (only on 2nd run)
                if i==0: self.feats = repeat_interleave(self.feats, self.bw)

            return self.top_seq()

    # `complete` flag of the highest-ranked hypothesis of every beam
    def top_complete(self): return np.stack([max(b)[1] for b in self.beams])
    # highest-ranked sequence of every beam, start token stripped
    def top_seq(self): return torch.from_numpy(np.stack([max(b)[-1] for b in self.beams]))[:,1:]

    def prob_func(self, tgt):
        "Top-`bw` next-token log-probabilities and token ids for every row of `tgt`."
        mask = subsequent_mask(tgt.size(-1))
        dec_outs = self.transformer.decode(self.feats, tgt, mask)
        logits = self.transformer.generate(dec_outs[:,-1])
        log_probs = logits - torch.logsumexp(logits, -1, keepdim=True) # more stable than F.softmax(logits,-1).log()
        return torch.topk(log_probs, self.bw, dim=-1)

In [None]:
# decode the current batch `x` with a beam width of 3
search = BeamSearch(learn.model, 3, seq_len)
b_res = search(x)

# bw: 3, sl: 350, ~56s, 0.02491    # no normalized_score

In [None]:
# mean error of the beam-search output against the batch targets
beam_cer(b_res,y)

### Tensors

In [None]:
class BeamSearch(nn.Module):
    """Batched beam search that keeps all hypotheses of a batch in two tensors.

    net: model exposing `img_enc` and `transformer` (the latter with
        `encode`, `decode` and `generate`); it is switched to eval mode here.
    beam_width: hypotheses kept per image, and candidates expanded per hypothesis.
    seq_len: maximum number of decoding steps.
    end_tok: token id that marks the best hypothesis as finished.

    `forward(src)` returns `(sequences, scores)`: for every image the best
    hypothesis without its leading start token (id 1), and its summed
    log-probability.
    """
    def __init__(self, net, beam_width, seq_len, end_tok=0):
        super(BeamSearch, self).__init__()
        self.img_enc = net.img_enc
        self.transformer = net.transformer
        self.bw = beam_width
        self.seq_len = seq_len
        self.end_tok = end_tok

        self.feats = None    # encoder output; repeated `bw` times per image after the first step
        self.beam = None     # (bs*beam, steps+1) token ids, hypotheses of one image stored contiguously
        self.scores = None   # (bs*beam, 1) summed log-probabilities, best hypothesis first per image

        net.eval()

    def forward(self, src):
        with torch.no_grad():
            bs = src.size(0)

            # encode src
            self.feats = self.transformer.encode(self.img_enc(src))

            # initialize globals (beam=1 for first iteration; bw thereafter)
            self.beam = torch.ones((bs,1), device=device, dtype=torch.long)
            self.scores = torch.zeros((bs,1), device=device, dtype=torch.float)
            rows = torch.arange(bs, device=self.beam.device).unsqueeze(1)  # (bs,1) batch index for gathering

            # use the configured length rather than a notebook-level global
            for i in tqdm(range(self.seq_len)):
                # generate new topk chars per beam(bs*bw)
                log_probs, chars = self.prob_func(self.beam)  #(bs*beam, bw)

                # compute local scores: every parent score is broadcast over its bw candidates
                scores = self.scores + log_probs

                # compute new beams per batch; topk returns them best-first
                new_scores, idxs = torch.topk(scores.view(bs,-1), self.bw, dim=-1) #(bs, bw)
                self.scores = new_scores.view(-1,1)

                # set up new beam: `idxs` addresses the flattened (beam*bw) candidates of
                # each image, so idxs//bw is the parent hypothesis of each winner
                nxt = chars.view(bs,-1).gather(1, idxs).view(-1,1)
                pre = self.beam.view(bs,-1,i+1)[rows, idxs//self.bw].reshape(-1,i+1)

                # update globals
                self.beam = torch.cat([pre,nxt], dim=1)

                # end when top of beams are complete
                if self.top_complete(): break

                # expand feats to match beam size (only on 2nd run)
                if i==0: self.feats = repeat_interleave(self.feats, self.bw) #.repeat(self.bw,1,1)

            return self.top_sequences(), self.top_scores()


    # True when the best hypothesis of every image currently ends in `end_tok`
    def top_complete(self): return (self.top_sequences()[:,-1]==self.end_tok).all().item()   #byte tensor
    # every bw-th row is the best hypothesis of one image; indexing explicitly
    # (no squeeze) keeps this valid when bs*bw == 1
    def top_sequences(self): return self.beam[0::self.bw, 1:]
    def top_scores(self): return self.scores[0::self.bw, 0]

    def prob_func(self, tgt):
        "Top-`bw` next-token log-probabilities and token ids for every row of `tgt`."
        mask = subsequent_mask(tgt.size(-1))
        dec_outs = self.transformer.decode(self.feats, tgt, mask)
        logits = self.transformer.generate(dec_outs[:,-1])
        log_probs = logits - torch.logsumexp(logits, -1, keepdim=True) # more stable than F.softmax(logits,-1).log()
        return torch.topk(log_probs, self.bw, dim=-1)

In [None]:
# decode the current batch `x` with a beam width of 3; returns the best sequences and their scores
search = BeamSearch(learn.model, 3, seq_len)
b_res,score = search(x)

# recorded runs (setting, wall time, error):
# lg, sl: 250
# bw: 1, ~12s, 0.02398
# bw: 3, ~28s, 0.02378
# bw: 5, ~46s, 0.02378

# pg: 1000,5
# bs: 3, ~40s, 0.33227  (greedy: 0.28---)
# ''   , ~30s, 0.22635  (greedy: 0.23726)

# pg: 800,8
# bs: 3, ~51s, 0.25424  (greedy: 0.28348)

In [None]:
# mean error of the beam-search output against the batch targets
beam_cer(b_res,y)

In [None]:
#beam
# first three images of the batch, each titled with the text decoded from `b_res`
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    # `char_label_text` takes the prediction only (it has no `chunk` keyword);
    # call it the same way the valid/greedy plots do
    p = char_label_text(b_res[i])
    ax=show_img(imgs[i], ax=ax, title=p)

## Single Beam Decode

In [None]:
# https://geekyisawesome.blogspot.com/2016/10/using-beam-search-to-generate-most.html

import heapq

class Beam:
    '''
    Bounded min-heap of (score, complete, seq) hypotheses.

    Entries are ordered as tuples, so when two prefixes have the same score a
    complete sentence ranks above an incomplete one: (0.5, False) < (0.5, True).
    '''

    def __init__(self, beam_width=10):
        self.heap = []
        self.beam_width = beam_width
        self.best_score = None

    def add(self, score, complete, seq):
        "Offer a hypothesis; it is stored only if its score is within `beam_width` of the best score seen."
        # remember the highest score offered so far
        top = self.best_score
        if top is None or score > top:
            top = self.best_score = score

        # hypotheses scoring `beam_width` or more below the best are not stored
        cutoff = top - self.beam_width
        if score > cutoff:
            heapq.heappush(self.heap, (score, complete, seq))

        # never hold more than `beam_width` entries: drop the lowest-ranked one
        if len(self.heap) > self.beam_width:
            heapq.heappop(self.heap)

    def __iter__(self):
        # entries come out in heap order, not sorted
        return iter(self.heap)

In [None]:
def beamsearch(prob_fn, seq_len, beam_width=5, start_tok=1, end_tok=3):
    """Beam-search decode a single sequence.

    prob_fn: callable that takes the token-id list of one hypothesis and returns
        `(log_probs, chars)`: matching iterables of one-element tensors holding
        the log-probability and id of each candidate next token.
    seq_len: maximum number of tokens to generate.
    beam_width: number of hypotheses kept per step.
    start_tok: id every hypothesis starts with (stripped from the result).
    end_tok: id that marks a hypothesis as complete.

    Returns `(tokens, score)`: the token ids after the start token and the
    summed log-probability of the chosen hypothesis. Decoding stops at the
    first step whose best hypothesis is complete, which is not necessarily
    the best complete hypothesis overall.
    """
    prev_beam = Beam(beam_width)
    prev_beam.add(0.0, False, [start_tok])

    for _ in tqdm(range(seq_len)):
        curr_beam = Beam(beam_width)

        # iterate over each beam
        for (score, complete, seq) in prev_beam:
            # a complete hypothesis that was not the best one is dropped, not carried over
            if complete: continue

            # iterate through topk chars, calculating scores and adding to the beam.
            log_probs, chars = prob_fn(seq)
            for log_prob, char in zip(log_probs, chars):
                log_prob,char = log_prob.item(), char.item()
                # log probabilities are additive; each candidate extends the PARENT's
                # score, so the parent score must not be modified inside this loop
                curr_beam.add(score+log_prob, (char==end_tok), seq+[char])

        (best_score, best_complete, best_seq) = max(curr_beam)
        if best_complete: return (best_seq[1:], best_score)   # returns first complete beam not best...

        prev_beam = curr_beam

    # `prev_beam` is the last beam built (or the initial one when seq_len is 0)
    (best_score, best_complete, best_seq) = max(prev_beam)
    return (best_seq[1:], best_score)

In [None]:
def beam_decode(net, src, beam_width, seq_len):
    "Put `net` in eval mode, encode `src` and run `beamsearch` over the features; returns what `beamsearch` returns."
    net.eval()
    with torch.no_grad():
        encoded = net.transformer.encode(net.img_enc(src))
        # bind the model and features so beamsearch only has to pass the token prefix
        step_fn = partial(prob_func, net=net, feats=encoded, beam_width=beam_width)
        return beamsearch(step_fn, seq_len, beam_width)
    
def prob_func(tgt, net=None, feats=None, beam_width=5):
    """Return the `beam_width` best next tokens for one hypothesis.

    tgt: list of token ids decoded so far (wrapped into a batch of one).
    net: model exposing `transformer.decode` and `transformer.generate`.
    feats: encoder output for the image being decoded.
    beam_width: number of candidates to return.

    Returns `torch.topk`'s `(values, indices)` pair: log-probabilities and
    token ids of the candidates, best first.
    """
    tgt = torch.tensor([tgt], dtype=torch.long, device=device)
    mask = subsequent_mask(tgt.size(-1))
    dec_outs = net.transformer.decode(feats, tgt, mask)
    logits = net.transformer.generate(dec_outs[:,-1])

    # log-softmax over the last dimension; keepdim makes the subtraction
    # broadcast per row instead of relying on the batch being of size 1
    log_probs = logits - torch.logsumexp(logits, -1, keepdim=True)  # more numerically stable
    # log_probs = F.softmax(logits, -1).log()

    return torch.topk(log_probs.squeeze(0), beam_width, dim=-1)

def score_func(log_probs, i, alpha=0.6):
    "Length-normalize a score: divide `log_probs` by ((5 + i)/6) raised to `alpha`."
    base = (5 + i)/6
    return log_probs/math.pow(base, alpha)

In [None]:
# pick one sample of the batch; indexing with [None] keeps a leading dimension of size 1
idx = 2
x1 = x[idx][None]
y1 = y[idx][None]

In [None]:
# single-sequence beam decode with a beam width of 3
# NOTE(review): `image` is not assigned in the visible cells (the sample picked above is `x1`) — confirm the input
b_res, score = beam_decode(learn.model, image, 3, seq_len)    #294, 3m18s
# 294 - 1m40s
# 294 - 1m45s (w/ score_func)
# 295 - 22s (beam_width=1 ~ greedy)

In [None]:
# wrap the decoded ids in a one-row tensor, turn them into text and score against the hand-typed `truth`
r = torch.tensor([b_res], dtype=torch.long, device=device)
p = char_label_text(r)

_cer(truth, p)

In [None]:
# valid
# error of the pass that was given the targets, for sample `idx` only
p = char_label_text(v_res[idx][None])
t = char_label_text(y1[0])
_cer(t,p)

In [None]:
# greedy
# error of the image-only pass, for sample `idx` only
p = char_label_text(g_res[idx][None])
t = char_label_text(y1[0])
_cer(t,p)

In [None]:
# beam (sz=3)
# error of the beam-search result against the target of sample `idx`
r = torch.tensor([b_res], dtype=torch.long, device=device)
p = char_label_text(r)
t = char_label_text(y1[0])
_cer(t,p)

In [None]:
# inverse of `itos`: maps each entry of `itos` to its index
stoi = {k:i for i,k in enumerate(itos)}

In [None]:
# show the decoded text of the image-only pass for sample `idx`
print(char_label_text(g_res[idx][None]))

In [None]:
# decode the beam result, wrap it at 70 characters per line and use it as the title of the sample image
# NOTE(review): other cells call `learn.data.denorm`; confirm a bare `denorm` is in scope here
st = ''.join([itos[i] for i in b_res])
p = '\n'.join(textwrap.wrap(st, 70))
show_img(denorm(x1)[0], figsize=(10,10), title=p)

## Source Attn

In [None]:
# collect image, predicted ids and scaled attention maps of one sample for both passes
# NOTE(review): `v_attn` / `g_attn` are only assigned in commented-out lines of the Results cell — confirm where they are set
idx = 0
img = imgs[idx]

v_chars = v_res[idx]
v_attns = to_np(torch_scale_attns(v_attn)[idx])

g_chars = g_res[idx]
g_attns = to_np(torch_scale_attns(g_attn)[idx])

In [None]:
#valid
# one panel per predicted character (first 5x4 = 20): the image with that character's attention map on top
fig, axes = plt.subplots(5,4, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    # the attention maps of this pass are stored in `v_attns` (set in the previous cell)
    a = g_filter(v_attns[i])
    ax.imshow(img, alpha=None)
    ax.imshow(a, cmap='Blues', interpolation='nearest', alpha=0.3)
    ax.set_title(itos[v_chars[i].item()])

In [None]:
#greedy
# one panel per predicted character (first 6x4 = 24): the image with that character's attention map on top
fig, axes = plt.subplots(6,4, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    a = g_filter(g_attns[i])
    ax.imshow(img, alpha=None)
    ax.imshow(a, cmap='Blues', interpolation='nearest', alpha=0.3)
    ax.set_title(itos[g_chars[i].item()])

# Attention Visualizations

In [None]:
def transparent_cmap(cmap, N=5):
    "Return an `N`-level version of `cmap` whose alpha column ramps linearly from 0 to 0.6."
    faded = matplotlib.cm.get_cmap(cmap, N)
    # private matplotlib call; presumably builds `_lut` so it can be edited below
    faded._init()
    # the last column of the table is alpha; it is filled with N+3 evenly spaced values
    alphas = np.linspace(0, 0.6, N+3)
    faded._lut[:,-1] = alphas
    return faded

#Use base cmap to create transparent
# mycmap = transparent_cmap(plt.cm.Reds)

In [None]:
def show_attn(img, attns, chars, ax, color, showChars=True):
    """Overlay per-character attention maps and the image on `ax`.

    img: image drawn last, at 60% opacity.
    attns: indexable with `.shape[0]` entries, one attention map per character.
    chars: tensor of predicted token ids, aligned with `attns`.
    ax: matplotlib axis to draw on; its title is set to the decoded `chars`.
    color: colormap name handed to `transparent_cmap`.
    showChars: when true, print each character near the centre of mass of its map.
    """
    # the colormap does not depend on the character, so build it once
    cmap = transparent_cmap(color)
    for i in range(attns.shape[0]):
        c = chars[i].item()
        # token ids 0-3 get no overlay
        if c in (0,1,2,3): continue
        a = g_filter(attns[i])
        #sns.heatmap(a, cmap=mycmap, cbar=False, ax=ax)
        ax.imshow(a, cmap=cmap, interpolation='nearest')
        if showChars:
            # centre of mass comes back as (row, col), i.e. (y, x)
            y,x = scipy.ndimage.center_of_mass(a)
            ax.text(x-8,y-10,itos[c], fontsize=15)

    ax.set_title(char_label_text(chars))
    ax.imshow(img, alpha=0.6)

In [None]:
def thresh_attn(attn, thresh=0.1):
    """Zero every attention weight below `thresh`, always keeping each row's maximum.

    attn: tensor whose last dimension holds the weights of one attention
        distribution; any number of leading dimensions is accepted.
    thresh: weights `>= thresh` are kept unchanged.

    Returns a new tensor shaped like `attn`; `attn` itself is not modified.
    """
    new = torch.where(attn >= thresh, attn, torch.zeros_like(attn))

    # some attns will not have a value over the thresh, in which case the
    # top-1 value is written back at its index so every row keeps one entry
    vals, idxs = torch.topk(attn, 1, dim=-1)
    new.scatter_(-1, idxs, vals)
    return new

## Decoder Self-Attention

In [None]:
sns.set_context(context="notebook")

def draw(data, x, y, ax):
    "Heatmap of `data` on `ax` with `x`/`y` as tick labels and a fixed 0-1 colour range; returns the seaborn axis."
    return sns.heatmap(data, xticklabels=x, square=True, yticklabels=y, vmin=0.0, vmax=1.0,
                       cmap='YlOrRd', linewidths=0.05, cbar=False, ax=ax)

# one row of four heatmaps (batch samples 0-3) per decoder layer, restricted to positions 20-39
for layer in range(4):
    print("Decoder Self-Attention Layer", layer+1)

    fig, axes = plt.subplots(1,4, figsize=(20, 10))
    for i,ax in enumerate(axes.flat):
        # greedy decoding (no access to true values)
        pred = char_split_text(g_res[i])[20:40]
        # the x-axis labels come from the right-shifted predictions
        shifted_y = rshift(g_res.float()).long()
        true = char_split_text(shifted_y[i])[20:40]
        g = draw(self_attn(layer)[i].data[20:40, 20:40], true, pred, ax=ax)
        g.set_yticklabels(g.get_yticklabels(), rotation=0)
        g.set_xticklabels(g.get_xticklabels(), rotation=0)
    plt.show()

## Decoder Source-Attention

In [None]:
# squeeze=False keeps `axes` two-dimensional even with a single row, so the
# axes[idx,h] indexing below is valid
fig, axes = plt.subplots(1,4, squeeze=False, gridspec_kw={'hspace': 0.5}, figsize=(20, 10))

# one row per sample (4 panels each): the image overlaid with the attention of each layer
for idx in range(len(axes.flat)//4):
    img = imgs[idx]
    g_chars = g_res[idx]

    # 4 attn layers
    for h in range(4):
        attn = source_attn(h)
        g_attns = to_np(torch_scale_attns(attn)[idx])

        show_attn(img, g_attns, g_chars, axes[idx,h], 'YlGn', showChars=False)
        axes[idx,h].set_title(f'layer {h+1}')

## Validation vs Greedy (final layer src-attn)

In [None]:
# threshold, then scale, the attention of both passes for the comparison plot below
v_scaled_attns = torch_scale_attns(thresh_attn(v_attn))
g_scaled_attns = torch_scale_attns(thresh_attn(g_attn))

In [None]:
# two samples, one per row: left column = pass given the targets, right column = image-only pass
fig, axes = plt.subplots(2,2, gridspec_kw={'hspace': 0.5}, figsize=(20, 20))
for idx in range(len(axes.flat)//2):
    img = imgs[idx]

    v_chars = v_res[idx]
    v_attns = to_np(v_scaled_attns[idx])

    g_chars = g_res[idx]
    g_attns = to_np(g_scaled_attns[idx])

    # valid
    show_attn(img, v_attns, v_chars, axes[idx,0], 'YlOrRd')
    # greedy
    show_attn(img, g_attns, g_chars, axes[idx,1], 'YlGn')

# Individual Examination

In [None]:
# large single-sample view of the image-only pass
idx=0
img = imgs[idx]

g_chars = g_res[idx]
# NOTE(review): the slice below always takes sample 0 regardless of `idx` — confirm this is intended
g_scaled_attns = torch_scale_attns(thresh_attn(g_attn)[0:1])  # passing in bs of 1
g_attns = to_np(g_scaled_attns[0])  # removing bs

fig, ax = plt.subplots(1,1, figsize=(20, 20))
show_attn(img, g_attns, g_chars, ax, 'YlGn')

In [None]:
sns.set_context(context="notebook")

def draw(data, x, y, ax):
    "Heatmap of `data` on `ax` with the cells above the diagonal masked out; returns the seaborn axis."
    # mask entries strictly above the main diagonal (k=1)
    mask = np.zeros_like(data)
    mask[np.triu_indices_from(mask, k=1)] = True
    return sns.heatmap(data, xticklabels=x, square=True, yticklabels=y, vmin=0.0, vmax=1.0,
                       mask=mask, cmap='YlOrRd', linewidths=0.05, cbar=False, ax=ax)

# one heatmap per decoder layer for batch sample 1, restricted to positions 230-279
for layer in range(4):
    print("Decoder Self-Attention Layer", layer+1)

    fig, ax = plt.subplots(1,1, figsize=(20, 10))
    i = 1
    # greedy decoding (no access to true values)
    pred = char_split_text(g_res[i])[230:280]
    # the x-axis labels come from the right-shifted predictions
    shifted_y = rshift(g_res.float()).long()
    true = char_split_text(shifted_y[i])[230:280]
    g = draw(self_attn(layer)[i].data[230:280, 230:280], true, pred, ax=ax)
    g.set_yticklabels(g.get_yticklabels(), rotation=0)
    g.set_xticklabels(g.get_xticklabels(), rotation=0)
    plt.show()