# Prelims

In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
from fastai import *
from fastai.vision import *
from fastai.text import *

import pdb

In [None]:
# Root data directory: every CSV and the image folders used below are resolved relative to it
# (CSV = PATH/fname, ImageList.from_df(path=PATH, folder=FOLDER)).
PATH = Path('data/IAM_handwriting')
# Scratch directory; the vocabulary pickle (itos.pkl) is loaded from here in the ModelData section.
TMP_PATH = PATH/'tmp'

In [None]:
# # Quick test of export.pkl
# from fastai.tfmr_extensions import *

# learn = load_learner(PATH)

# fname = PATH/'test/edit_test4.png'
# im = open_image(fname)

# seq,res,preds = learn.predict(im)
# print(seq)

In [None]:
# Notebook-global device: read by rshift / subsequent_mask / the greedy-decode loops and the DataBunch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

## Loss and Metrics

In [None]:
def loss_prep(input, target):
    """Make logits and labels the same sequence length, then flatten batch and sequence dims.

    input:  logits of shape (bs, sl, vocab).
    target: label ids of shape (bs, tsl).
    Whichever side is shorter along the sequence dimension is right-padded with zeros.
    Labels get id 0; logits get all-zero rows.
    Returns (pred, targ) with shapes (bs*L, vocab) and (bs*L,) as int64, where L = max(sl, tsl).
    """
    _, tsl = target.shape
    _, sl, vocab = input.shape

    gap = sl - tsl
    if gap > 0:
        # decoder produced more steps than there are labels: pad labels with 0
        target = F.pad(target, (0, gap))
    elif gap < 0:
        # labels are longer than the decoder output: pad the sequence dim of the logits.
        # The padding rows are all-zero logits, i.e. a uniform distribution after softmax.
        # (Truncating the target instead, target[:, :sl], was only meant for quick tests.)
        input = F.pad(input, (0, 0, 0, -gap))

    pred = input.contiguous().view(-1, vocab)
    targ = target.contiguous().view(-1).long()
    return pred, targ

In [None]:
class LabelSmoothing(nn.Module):
    """Label-smoothed KL-divergence loss for sequence outputs.

    The target token receives probability `1 - smoothing`. Every other class receives
    `smoothing / vocab`. The summed KL divergence is divided by the batch size.
    """
    def __init__(self, smoothing=0.1):
        super(LabelSmoothing, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, pred, target):
        # Normalize by this batch's real size. The notebook-global `bs` was used before, and it is
        # wrong for a short final batch or whenever the DataBunch batch size differs from it.
        bs = target.size(0)
        pred, targ = loss_prep(pred, target)           # (bs*L, vocab), (bs*L,)
        pred = F.log_softmax(pred, dim=-1)             # KLDivLoss expects log-probabilities
        # smoothed target distribution: smoothing/vocab everywhere, confidence at the label
        true_dist = torch.full_like(pred, self.smoothing / pred.size(1))
        true_dist.scatter_(1, targ.unsqueeze(1), self.confidence)
        return F.kl_div(pred, true_dist, reduction='sum') / bs

In [None]:
import Levenshtein as Lev

class CER(Callback):
    """Character error rate metric.

    For each sample: Levenshtein distance between the decoded target and the argmax prediction,
    divided by the target length. The per-sample values are averaged over the epoch.
    """
    def __init__(self, itos):
        super().__init__()
        self.name = 'cer'
        self.itos = itos

    def on_epoch_begin(self, **kwargs):
        self.errors, self.total = 0, 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        error, size = self._cer(last_output, last_target)
        self.errors += error
        self.total += size

    def on_epoch_end(self, last_metrics, **kwargs):
        # max(..., 1) guards against an epoch in which no batch was seen
        return add_metrics(last_metrics, self.errors / max(self.total, 1))

    def _cer(self, preds, targs):
        "Sum of per-sample normalized edit distances for one batch, and the number of samples."
        res = torch.argmax(preds, dim=2)          # preds: (bs, sl, vocab) logits
        error = 0
        for pred_ids, targ_ids in zip(res, targs):
            p = self._char_label_text(pred_ids)   #.replace(' ', '')
            t = self._char_label_text(targ_ids)   #.replace(' ', '')
            # An all-padding target decodes to '' (e.g. the 1x1 zero label custom_collater
            # returns for single-item batches). Dividing by max(len, 1) avoids ZeroDivisionError.
            error += Lev.distance(t, p) / max(len(t), 1)
        return error, targs.size(0)

    def _char_label_text(self, pred):
        "Decode a 1-D tensor of ids to text, dropping every 0 (padding) id."
        ints = to_np(pred).astype(int)
        nonzero = ints[np.nonzero(ints)]
        return ''.join([self.itos[i] for i in nonzero])

In [None]:
def rshift(tgt, bos_token=1):
    """Shift y to the right by prepending `bos_token` and dropping the last position.

    tgt: 2-D tensor (bs, sl). Returns a tensor of the same shape, dtype and device.
    """
    # new_full takes tgt's dtype and device, so this no longer depends on the global `device`
    # (and no longer allocates there and then converts).
    bos = tgt.new_full((tgt.size(0), 1), bos_token)
    return torch.cat((bos, tgt[:, :-1]), dim=-1)

def subsequent_mask(size):
    """Causal mask of shape (1, size, size), uint8, on the global `device`.

    Entry [0, i, j] is 1 when j <= i, so position i may attend only to itself and earlier positions.
    """
    ones = torch.ones((size, size), device=device).byte()
    return ones.tril().unsqueeze(0)

class TeacherForce(LearnerCallback):
    """Teacher forcing for training.

    Replaces each batch input with (images, right-shifted targets, causal mask) and leaves the
    target the loss compares against unchanged.
    """
    def __init__(self, learn:Learner, bos_token:int=1):
        super().__init__(learn)
        self.bos_token = bos_token

    def on_batch_begin(self, last_input, last_target, **kwargs):
        decoder_in = rshift(last_target, self.bos_token).long()
        causal = subsequent_mask(decoder_in.size(-1))
        new_input = (last_input, decoder_in, causal)
        return dict(last_input=new_input, last_target=last_target)

# Data

## Small synth dataset

In [None]:
# Each dataset section sets CSV (label file) and FOLDER (image sub-folder of PATH). The config cell
# that follows sets image size `sz`, batch size `bs`, max decode length `seq_len`, and
# per-channel normalization `stats` (mean, std). The last section run wins.
fname = 'sm_synth.csv' #'small_synth_words.csv'
CSV = PATH/fname
FOLDER = 'small_synth_words'

csv = pd.read_csv(CSV)
len(csv)

In [None]:
csv.head()

In [None]:
sz,bs = 128,100
# sz,bs = 256,60

seq_len = 45
# NOTE(review): labelled "inception_stats", but these values are not fastai's inception_stats (0.5/0.5);
# presumably computed on this dataset -- confirm.
stats = (np.array([0.903, 0.903, 0.903], dtype=np.float32), np.array([0.197, 0.197, 0.197], dtype=np.float32))  # inception_stats

## 3x2

In [None]:
# 3x2 synthetic word grids; images live in PATH/'multi_synth_words'.
fname = '3x2_synth.csv' #'multi_synth_words.csv'
CSV = PATH/fname
FOLDER = 'multi_synth_words'

csv = pd.read_csv(CSV)
len(csv)

In [None]:
# Alternative (size, batch) pairs tried on this dataset:
# sz,bs = 512,30
# sz,bs = 400,35
# sz,bs = 256,60
sz,bs = 128,100

seq_len = 75
stats = (np.array([0.931, 0.931, 0.931], dtype=np.float32), np.array([0.175, 0.175, 0.175], dtype=np.float32))

## Large

In [None]:
# Large synthetic word images; images live in PATH/'large_synth_words'.
fname = 'large_synth_words.csv'
CSV = PATH/fname
FOLDER = 'large_synth_words'

csv = pd.read_csv(CSV)
len(csv)

In [None]:
sz,bs = 512,30  #256,60  #512,30  #400,45
seq_len = 250
# float32, like every other section's stats. np.array defaults to float64, and float64 normalization
# stats can promote the float32 image batch to float64 when the batch is normalized.
stats = (np.array([0.93186, 0.93186, 0.93186], dtype=np.float32), np.array([0.17579, 0.17579, 0.17579], dtype=np.float32))

## Concat Lines

In [None]:
# Concatenated-line images: `nums` picks which cat_lines_<nums>.csv file to load.
nums = '3-6'   #'7-10'  #'11-14'
CSV = PATH/f'cat_lines_{nums}.csv'
FOLDER = 'resized_cat_lines'

csv = pd.read_csv(CSV)
# held-out paragraph rows written by the "Concat All" cells below
test = pd.read_csv(PATH/'test_pg.csv')

len(csv), len(test)

In [None]:
# Number of space-separated ids per label; the max is a guide for choosing seq_len.
lengths = np.array([len(i.split(' ')) for i in csv.char_ids.values])
lengths.max()

In [None]:
sz,bs = 512,20   #1000,5  #800,8   #512,20
seq_len = 600 #600 #450  #300
stats = (np.array([0.941, 0.941, 0.941], dtype=np.float32), np.array([0.128, 0.128, 0.128], dtype=np.float32))

### Concat All

In [None]:
# 1000 random rows from each cat_lines_N.csv, N = 3..12.
# NOTE(review): no random_state, so these samples differ between runs (unlike the seeded split below).
a = pd.read_csv(PATH/'cat_lines_3.csv').sample(1000)
b = pd.read_csv(PATH/'cat_lines_4.csv').sample(1000)
c = pd.read_csv(PATH/'cat_lines_5.csv').sample(1000)
d = pd.read_csv(PATH/'cat_lines_6.csv').sample(1000)
e = pd.read_csv(PATH/'cat_lines_7.csv').sample(1000)
f = pd.read_csv(PATH/'cat_lines_8.csv').sample(1000)
g = pd.read_csv(PATH/'cat_lines_9.csv').sample(1000)
h = pd.read_csv(PATH/'cat_lines_10.csv').sample(1000)
i = pd.read_csv(PATH/'cat_lines_11.csv').sample(1000)
j = pd.read_csv(PATH/'cat_lines_12.csv').sample(1000)

In [None]:
k = pd.read_csv(PATH/'paragraph_chars.csv')

In [None]:
# 15% of the paragraph rows (seed 42) are held out as a test set.
val_idxs = np.array(k.sample(frac=0.15, random_state=42).index)

In [None]:
trn = k[~k.index.isin(val_idxs)]
test = k[k.index.isin(val_idxs)]

In [None]:
# Training mix: the sampled concatenated lines plus the non-held-out paragraphs.
new = pd.concat([a,b,c,d,e,f,g,h,i,j,trn], ignore_index=True)
len(new)

In [None]:
# new.to_csv(PATH/'cat_lines_pg.csv', index=False)
# test.to_csv(PATH/'test_pg.csv', index=False)
# test.to_csv(PATH/'test_pg.csv', index=False)

## Paragraphs

In [None]:
# Full paragraph images; images live in PATH/'paragraphs'.
fname = 'pg.csv' #'paragraphs.csv'
CSV = PATH/fname
FOLDER = 'paragraphs'

csv = pd.read_csv(CSV)
len(csv)

In [None]:
# sz,bs = 1000,5  #1024,5  #1000,5   #~2000x1000 full size
# sz,bs = 800,8
sz,bs = 512,10

# NOTE(review): seq_len (700) is below the noted max label length (705), so greedy decoding
# stops before the longest labels end -- confirm this is acceptable.
seq_len = 700   #~400 chars/paragraph - max: 705
stats = (np.array([0.941, 0.941, 0.941], dtype=np.float32), np.array([0.128, 0.128, 0.128], dtype=np.float32))

In [None]:
# NOTE(review): pg_vals is not referenced anywhere else in this notebook; presumably a saved list of
# row indices (e.g. a validation subset) -- confirm before relying on it.
pg_vals = np.array([1057,  144,   99,  412, 1008,   86, 1338, 1474,  196,  807, 1164,   32, 1140,  568,  655, 1010,  730,
831, 1047,  970, 1331, 1224, 1262,   48,  121,  870, 1044, 1051,  327, 1459, 1073, 1247,  465,  601,
745,  254, 1056,  701,  771,   74, 1246, 1080, 1258,  170,  861,  898,  356,  133,  642,  723,  382,
267,   94,  565, 1450,  347, 1086, 1121, 1198, 1506,  350,  262, 1259, 1195,  518, 1367,  623, 1497,
778,  633,  962,  390,  669, 1529,  197,  872, 1003, 1053, 1457,  618,  596,  800,   46,  289,  688,
841,  265, 1069,  616,  300, 1099, 1241,  317,   78,   54,  348, 1495,  570,  766,  115, 1404,  612,
844,   76,   61,  389, 1451,  435,  564,  206,   70,  354, 1268,  780,  470, 1155,  644,   31,  907,
278,  991,  885, 1291, 1394, 1145,  494, 1297,  240,  741, 1127,  901,  826, 1440,  516, 1431,  261,
1112, 1416, 1355,  720, 1285,   16, 1147, 1325,  969, 1389,  441,  194, 1372,  593,  236, 1334, 1079,
976,  802,  713,    1, 1135, 1032,  546,  956, 1512, 1310,  746,  888,  445,  799, 1308,  505,  402,
1114,  276,  160,  378,  641,  571,  590,   17,  437,  648,  630,  490, 1486,   44, 1027, 1527,  411,
1452,  514,  560,  958, 1243, 1120,  397,  599,  668, 1311, 1514,  264, 1084, 1107,   38,   50,  930,
909, 1077,  239,  769,  455, 1170, 1211,  452])

## Downloaded Images

In [None]:
# Downloaded handwriting images; images live in PATH/'downloaded_images'.
CSV = PATH/'downloaded_images.csv'
FOLDER = 'downloaded_images'

csv = pd.read_csv(CSV)
len(csv)

In [None]:
# One-off rename that prefixed filenames with "dl_" and rewrote the CSV in place:
# csv['filename'] = csv['filename'].apply(lambda x: f"dl_{x}")
# csv.head()
# csv.to_csv(CSV, index=False)

In [None]:
# sz,bs = 1000,5  #1024,5  #1000,5   #~2000x1000 full size
# sz,bs = 800,8
sz,bs = 512,20

seq_len = 700   #~400 chars/paragraph - max: 705
stats = (np.array([0.941, 0.941, 0.941], dtype=np.float32), np.array([0.128, 0.128, 0.128], dtype=np.float32))

## Mix

### Create CSV

In [None]:
# a = pd.read_csv(PATH/'small_synth_words.csv')
# b = pd.read_csv(PATH/'large_synth_words.csv')
# c = pd.read_csv(PATH/'cat_lines.csv')
# d = pd.read_csv(PATH/'paragraphs.csv')
e = pd.read_csv(PATH/'downloaded_images.csv')

In [None]:
# new = pd.concat([a,b,c,d], ignore_index=True)
# NOTE(review): `csv` is whatever was loaded last. If the Downloaded Images cells ran just before this,
# `csv` is downloaded_images.csv too and every row appears twice. Presumably the mix_words CSV
# was loaded into `csv` first -- confirm.
new = pd.concat([csv,e], ignore_index=True)
len(new)

In [None]:
new.to_csv(PATH/'mix_words_dl.csv', index=False)

### Fix Bad Data

In [None]:
# Locate the single row whose label starts with this id sequence (a known bad transcription).
row = csv[csv['char_ids'].str.startswith('32 70 69 74 75 64 75 76 75 64 70 69 1 75 70 78 56 73 59 1 71')]
row

In [None]:
# .item() raises unless exactly one row matched
c_ids = csv.loc[row.index.item()].char_ids
c_ids

In [None]:
# NOTE(review): `itos` is loaded in the ModelData section further down; that cell must run first.
stoi = {v:k for k,v in enumerate(itos)}

In [None]:
# decode the space-separated ids into text
string = ''.join([itos[int(i)] for i in c_ids.split(' ')])
string

In [None]:
string.index('aluminium')

In [None]:
# replace only the first occurrence
string = string.replace('aluminium', 'of the', 1)
string

In [None]:
# Re-encode the text. The last 5 characters are dropped (reason not visible here) and id 3 is appended,
# presumably the end-of-sequence token -- confirm.
ids = np.array([stoi[letter] for letter in string[:-5]] + [3])
ids

In [None]:
str_ids = ' '.join([str(l) for l in ids])
str_ids

In [None]:
# Write the repaired ids back into the frames. `frame.loc[idx].char_ids = ...` (chained assignment)
# sets the value on a temporary row Series and may leave the frame unchanged. Index row and column
# in a single .loc call instead.
# NOTE(review): `samp` is not defined anywhere in this notebook, and 103388 is hard-coded rather than
# taken from `row.index` -- confirm both.
samp.loc[103388, 'char_ids'] = str_ids
df.loc[103388, 'char_ids'] = str_ids

### Load

In [None]:
# Final mixed dataset; images live in PATH/'mix_words'.
fname = 'full_mix.csv'  #'mix_words_dl.csv' #'mix_words.csv'   #53539
CSV = PATH/fname
FOLDER = 'mix_words'

csv = pd.read_csv(CSV)
len(csv)

In [None]:
sz,bs = 512,15
# sz,bs = 400,20

seq_len = 800
stats = (np.array([0.931, 0.931, 0.931], dtype=np.float32), np.array([0.175, 0.175, 0.175], dtype=np.float32))

## ModelData

In [None]:
# Frame the DataBunch is built from; CSV is whichever dataset section was run last.
df = pd.read_csv(CSV)

In [None]:
# Index -> token vocabulary. The `with` block closes the file handle that a bare open() left dangling.
# pickle can execute code while loading, so only load trusted files.
with open(TMP_PATH/'itos.pkl', 'rb') as itos_file:  #'char_itos.pkl'
    itos = pickle.load(itos_file)

In [None]:
def custom_collater(samples:BatchSamples, pad_idx:int=0):
    """Collate (image, label-id array) samples into a batch.

    Images are stacked. Labels are right-padded with `pad_idx` to the longest label in the batch,
    giving an int64 tensor of shape (bs, max_len).
    """
    data = to_data(samples)
    ims, lbls = zip(*data)
    imgs = torch.stack(list(ims))
    # `==`, not `is`: identity comparison with an int literal is unreliable and emits a SyntaxWarning.
    if len(data) == 1:
        # Single-sample batch (presumably learn.predict / one_item): the real label is replaced by a
        # 1x1 zero placeholder.
        # NOTE(review): a validation set whose size is 1 mod bs would also lose its last label here.
        return imgs, torch.zeros(1, 1).long()
    max_len = max(len(lbl) for lbl in lbls)
    labels = torch.full((len(data), max_len), pad_idx, dtype=torch.long)
    for i, lbl in enumerate(lbls):
        labels[i, :len(lbl)] = torch.from_numpy(lbl)  # fill the front; the tail stays pad_idx
    return imgs, labels

In [None]:
class SequenceItem(ItemBase):
    "Item wrapping an array of token ids; it renders as text through `vocab` and omits the final id."
    def __init__(self, data, vocab):
        self.data = data
        self.vocab = vocab

    def __str__(self):
        return self.textify(self.data)

    def __hash__(self):
        return hash(str(self))

    def textify(self, data):
        # data[:-1]: the last id is not rendered
        return ''.join(self.vocab[idx] for idx in data[:-1])

class ArrayProcessor(PreProcessor):
    "Convert df column (string of ints) into np.array"
    def __init__(self, ds:ItemList=None):
        pass  # no state to keep; PreProcessor.__init__ is intentionally not called

    def process_one(self, item):
        tokens = item.split()
        return np.array(tokens, dtype=np.int64)

    def process(self, ds):
        super().process(ds)
        
class ItosProcessor(PreProcessor):
    "Copy the source list's `itos` vocabulary onto every list it processes."
    def __init__(self, ds:ItemList=None):
        # NOTE(review): despite the None default, `ds` must be provided (its .itos is read here).
        self.itos = ds.itos

    def process(self, ds:ItemList):
        ds.itos = self.itos
        
class SequenceList(ItemList):
    "Label list of token-id sequences; ArrayProcessor parses each CSV string into an int64 array."
    _processor = [ItosProcessor, ArrayProcessor]

    def __init__(self, items:Iterator, itos:List[str]=None, **kwargs):
        super().__init__(items, **kwargs)
        self.itos = itos
        self.copy_new += ['itos']
        # NOTE(review): `c` normally holds the number of classes/outputs, but here it is the number of
        # items. Nothing visible in this notebook reads it -- confirm before relying on it.
        self.c = len(self.items)

    def get(self, i):
        o = super().get(i)
        return SequenceItem(o, self.itos)

    def reconstruct(self, t):
        "Convert a padded id tensor back into a SequenceItem, dropping every 0 (padding) id."
        # detach().cpu() first so GPU tensors and tensors that require grad also work
        o = t.detach().cpu().numpy()
        o = o[np.nonzero(o)]
        return SequenceItem(o, self.itos)

    def analyze_pred(self, pred):
        # Called in learn.predict() / learn.show_results() to turn model logits into the id tensor
        # that `reconstruct` expects.
        return torch.argmax(pred, dim=-1)

In [None]:
# Mild augmentation for text images: no flips, no zoom, small rotations and warps.
tfms = get_transforms(do_flip=False, max_zoom=1, max_rotate=2.0, max_warp=0.1)

# Images come from df's filename column under PATH/FOLDER, with a seeded 15% validation split.
# Labels are SequenceList id arrays sharing `itos`. Images are squished to sz x sz.
# custom_collater right-pads the labels, and `stats` is used for normalization.
data = (ImageList.from_df(df, path=PATH, folder=FOLDER)
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        .label_from_df(label_cls=SequenceList, itos=itos)
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=custom_collater)
        .normalize(stats)
       )

In [None]:
# data.show_batch(rows=3, ds_type=DatasetType.Train)

# Transformer Modules

In [None]:
# nn.LayerNorm with eps=1e-4 (PyTorch's default is 1e-5) to accommodate mixed-precision (fp16) training.
LayerNorm = partial(nn.LayerNorm, eps=1e-4)

In [None]:
class SublayerConnection(nn.Module):
    "Residual wrapper, pre-norm: x + dropout(sublayer(norm(x))). Norm comes first for code simplicity."
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        update = self.dropout(sublayer(self.norm(x)))
        return x + update

In [None]:
def clones(module, N):
    "Return an nn.ModuleList of N independent deep copies of `module` (no parameters are shared)."
    copies = nn.ModuleList()
    for _ in range(N):
        copies.append(deepcopy(module))
    return copies

In [None]:
class Encoder(nn.Module):
    "Stack of N deep-copied encoder layers followed by a final LayerNorm over `layer.size` features."
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x):
        hidden = x
        for enc_layer in self.layers:
            hidden = enc_layer(hidden)
        return self.norm(hidden)

In [None]:
class EncoderLayer(nn.Module):
    "Encoder block: residual self-attention, then residual feed-forward (each wrapped in SublayerConnection)."
    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)

    def forward(self, x, mask=None):
        attn_block, ff_block = self.sublayer
        x = attn_block(x, lambda h: self.self_attn(h, h, h, mask))
        return ff_block(x, self.feed_forward)

In [None]:
class Decoder(nn.Module):
    "N deep-copied decoder layers, each given the encoder output `src` and `tgt_mask`, then a final LayerNorm."
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, src, tgt_mask=None):
        hidden = x
        for dec_layer in self.layers:
            hidden = dec_layer(hidden, src, tgt_mask)
        return self.norm(hidden)

In [None]:
class DecoderLayer(nn.Module):
    "Decoder block: masked self-attention, attention over encoder output (unmasked), feed-forward."
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # three residual/dropout/norm wrappers, one per sub-step
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, src, tgt_mask=None):
        self_block, cross_block, ff_block = self.sublayer
        x = self_block(x, lambda h: self.self_attn(h, h, h, tgt_mask))
        x = cross_block(x, lambda h: self.src_attn(h, src, src))
        return ff_block(x, self.feed_forward)

In [None]:
def attention(query, key, value, mask=None, dropout=None):
    """Compute 'Scaled Dot Product Attention'.

    Returns (weighted values, attention weights). Positions where `mask == 0` are filled with -1e4
    before the softmax.
    """
    scale = math.sqrt(query.size(-1))
    logits = query @ key.transpose(-2, -1) / scale
    if mask is not None:
        # -1e4 rather than -1e9 so the fill value stays representable under mixed precision
        logits = logits.masked_fill(mask == 0, -1e4)
    weights = logits.softmax(dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return weights @ value, weights

In [None]:
class SingleHeadedAttention(nn.Module):
    "One-head attention: separate query/key/value linear projections, attention, then an output linear."
    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None   # last attention weights, kept for inspection
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        q_proj, k_proj, v_proj, out_proj = self.linears
        x, self.attn = attention(q_proj(query), k_proj(key), v_proj(value), mask=mask, dropout=self.dropout)
        return out_proj(x)

In [None]:
class MultiHeadedAttention(nn.Module):
    "h-head attention; d_model must be divisible by h. Each head has width d_k = d_model // h."
    def __init__(self, d_model, h=8, dropout=0.2):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h        # assume d_v always equals d_k
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)   # add a head dim so the same mask applies to every head
        bs = q.size(0)

        def split_heads(proj, x):
            # (bs, len, d_model) -> (bs, h, len, d_k)
            return proj(x).view(bs, -1, self.h, self.d_k).transpose(1, 2)

        q, k, v = [split_heads(proj, x) for proj, x in zip(self.linears, (q, k, v))]
        x, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)

        # (bs, h, len, d_k) -> (bs, len, h * d_k), then the output projection
        merged = x.transpose(1, 2).contiguous().view(bs, -1, self.h * self.d_k)
        return self.linears[-1](merged)

In [None]:
class GeLU(nn.Module):
    "GELU, tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))."
    def forward(self, x):
        inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))

class PositionwiseFeedForward(nn.Module):
    "Per-position MLP: Linear(d_model -> 4*d_model), GeLU, dropout, Linear(4*d_model -> d_model)."
    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        d_hidden = d_model * 4
        self.w_1 = nn.Linear(d_model, d_hidden)
        self.w_2 = nn.Linear(d_hidden, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = GeLU()  # nn.ReLU(inplace=True) was the earlier choice

    def forward(self, x):
        hidden = self.dropout(self.activation(self.w_1(x)))
        return self.w_2(hidden)

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    Even feature columns hold sin(pos * w_i) and odd columns hold cos(pos * w_i),
    with w_i = 1e4 ** (-2i / d_model). Encodings are precomputed for `max_len` positions.
    """
    def __init__(self, d_model, dropout=0.1, max_len=2000):
        super(PositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0.0, max_len).unsqueeze(1)                     # (max_len, 1)
        log_increment = math.log(1e4) / d_model
        div_term = torch.exp(torch.arange(0.0, d_model, 2) * -log_increment)   # (ceil(d_model/2),)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns. Without the slice, the
        # assignment fails with a shape mismatch. Even d_model is unchanged.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe.unsqueeze_(0)

        # (1, max_len, d_model): a buffer, so it is saved in the state_dict and moved with the module,
        # but it is not a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (bs, sl, d_model); add the encodings for the first sl positions
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [None]:
class Embeddings(nn.Module):
    "Token embedding lookup scaled by sqrt(d_model)."
    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale

# Full Arch

In [None]:
class EncoderDecoder(nn.Module):
    "Transformer wrapper: `encoder` on the source features; `decoder` on embedded targets plus encoder output."
    def __init__(self, encoder, decoder, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_embed = tgt_embed
        self.generator = generator   # final projection to vocabulary logits

    def forward(self, src, tgt, tgt_mask=None):
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask)

    def encode(self, src):
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, src, tgt_mask)

    def generate(self, outs):
        return self.generator(outs)

In [None]:
class ResnetBase(nn.Module):
    "Truncated pretrained resnet34; its spatial feature map becomes a sequence of d_model vectors."
    def __init__(self, em_sz, d_model):
        super().__init__()
        # How many trailing resnet34 children to drop for each supported em_sz (KeyError otherwise).
        # em_sz is also the in-feature count of the projection below.
        cut = {128: -4, 256: -3, 512: -2}[em_sz]
        backbone = models.resnet34(True)
        self.base = nn.Sequential(*list(backbone.children())[:cut])
        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        # (bs, C, H, W) -> (bs, H*W, C): one feature vector per spatial position
        seq = self.base(x).flatten(2, 3).permute(0, 2, 1)
        # fixed output scale of 8 (value taken as-is)
        return self.linear(seq) * 8

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2):
    """Build the transformer part of Img2Seq.

    vocab: number of target tokens (embedding rows and generator outputs).
    d_model: model width. N: number of encoder layers and of decoder layers.
    drops: dropout probability for attention, feed-forward, sublayer and positional-encoding dropout.
    Every parameter with more than one dimension is re-initialised with xavier_uniform_.
    """
    c = deepcopy
    # Pass `drops` to attention too. Before, attention dropout stayed at SingleHeadedAttention's
    # default (0.2) whatever `drops` was; with the default drops=0.2 nothing changes.
    attn = SingleHeadedAttention(d_model, drops)
    # attn = MultiHeadedAttention(d_model, 4, drops)   # multi-head variant tried in the experiments
    ff = PositionwiseFeedForward(d_model, drops)

    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        nn.Sequential(
            Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000)
        ),
        nn.Linear(d_model, vocab),
    )

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return model

In [None]:
class Img2Seq(nn.Module):
    "Image-to-sequence model: a CNN feature extractor feeding a transformer encoder-decoder."
    def __init__(self, img_encoder, transformer, seq_len=500):
        super(Img2Seq, self).__init__()
        self.img_enc = img_encoder
        self.transformer = transformer
        self.seq_len = seq_len      # maximum number of greedy-decoding steps

    def forward(self, src, tgt=None, tgt_mask=None):
        """Training (`tgt` given): teacher-forced logits, shape (bs, sl, vocab).

        Inference (`tgt` is None): greedy decoding that starts from token 1 and stops after
        `seq_len` steps, or earlier once every sequence predicts token 0. Returns the per-step
        logits stacked to (bs, steps, vocab).
        """
        # inference (greedy decode)
        if tgt is None:
            with torch.no_grad():
                # Build decoder inputs on the images' device instead of the global `device`, so a
                # model moved elsewhere (e.g. to CPU for prediction) still works.
                dev = src.device
                feats = self.transformer.encode(self.img_enc(src))
                bs = src.size(0)
                tgt = torch.ones((bs,1), dtype=torch.long, device=dev)

                res = []
                for _ in progress_bar(range(self.seq_len)):
                    mask = subsequent_mask(tgt.size(-1)).to(dev)
                    dec_outs = self.transformer.decode(feats, tgt, mask)
                    prob = self.transformer.generate(dec_outs[:,-1])   # logits for the newest position
                    res.append(prob)
                    pred = torch.argmax(prob, dim=-1, keepdim=True)
                    if (pred==0).all(): break
                    tgt = torch.cat([tgt,pred], dim=-1)
                out = torch.stack(res).transpose(1,0).contiguous()

        #training
        else:
            feats = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt, tgt_mask)    # ([bs, sl, d_model])
            out = self.transformer.generate(dec_outs)            # ([bs, sl, vocab])
        return out

In [None]:
d_model = 512
# em_sz selects how much of resnet34 ResnetBase keeps (256 -> children()[:-3]).
em_sz = 256
img_encoder = ResnetBase(em_sz, d_model)
transformer = make_full_model(len(itos), d_model)
# seq_len (from the dataset config) caps the greedy-decode length
net = Img2Seq(img_encoder, transformer, seq_len)

# AdamW16 = partial(optim.Adam, betas=(0.9,0.99), eps=1e-4)  ->  #.to_fp16(max_scale=256)
# partial: way to always call a function with a given set of arguments or keywords

# TeacherForce turns each batch input into (images, shifted targets, mask), so Img2Seq takes its
# training path; CER is reported as the metric.
learn = Learner(data, net, loss_func=LabelSmoothing(smoothing=0.1), metrics=[CER(itos)], callback_fns=TeacherForce)
learn.clip_grad(0.25)
# bare None keeps the notebook from echoing the return value above
None

# LM Arch

In [None]:
class EncoderDecoder(nn.Module):
    "Transformer wrapper: `encoder` on the source features; `decoder` on embedded targets plus encoder output."
    def __init__(self, encoder, decoder, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_embed = tgt_embed
        self.generator = generator   # final projection to vocabulary logits

    def forward(self, src, tgt, tgt_mask=None):
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask)

    def encode(self, src):
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, src, tgt_mask)

    def generate(self, outs):
        return self.generator(outs)

In [None]:
class ResnetBase(nn.Module):
    "Truncated pretrained resnet34; its spatial feature map becomes a sequence of d_model vectors."
    def __init__(self, em_sz, d_model):
        super().__init__()
        # How many trailing resnet34 children to drop for each supported em_sz (KeyError otherwise).
        # em_sz is also the in-feature count of the projection below.
        cut = {128: -4, 256: -3, 512: -2}[em_sz]
        backbone = models.resnet34(True)
        self.base = nn.Sequential(*list(backbone.children())[:cut])
        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        # (bs, C, H, W) -> (bs, H*W, C): one feature vector per spatial position
        seq = self.base(x).flatten(2, 3).permute(0, 2, 1)
        # fixed output scale of 8 (value taken as-is)
        return self.linear(seq) * 8

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2):
    """Build the transformer part of Img2Seq.

    vocab: number of target tokens (embedding rows and generator outputs).
    d_model: model width. N: number of encoder layers and of decoder layers.
    drops: dropout probability for attention, feed-forward, sublayer and positional-encoding dropout.
    Every parameter with more than one dimension is re-initialised with xavier_uniform_.
    """
    c = deepcopy
    # Pass `drops` to attention too. Before, attention dropout stayed at SingleHeadedAttention's
    # default (0.2) whatever `drops` was; with the default drops=0.2 nothing changes.
    attn = SingleHeadedAttention(d_model, drops)
    # attn = MultiHeadedAttention(d_model, 4, drops)   # multi-head variant tried in the experiments
    ff = PositionwiseFeedForward(d_model, drops)

    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        nn.Sequential(
            Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000)
        ),
        nn.Linear(d_model, vocab),
    )

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return model

In [None]:
class Img2Seq(nn.Module):
    "Img2Seq variant: transformer logits are added to a language model's logits and passed through `mixer`."
    def __init__(self, img_encoder, transformer, lm, mixer, seq_len=500):
        super(Img2Seq, self).__init__()
        self.img_enc = img_encoder
        self.transformer = transformer
        self.lm = lm
        self.mixer = mixer
        self.seq_len = seq_len      # maximum number of greedy-decoding steps

    def _reset_lm(self):
        "Clear the LM's recurrent state (when it exposes `reset`) so the next call starts fresh."
        reset = getattr(self.lm, 'reset', None)
        if callable(reset): reset()

    def forward(self, src, tgt=None, tgt_mask=None, lm=None):
        """Training (`tgt` given): mixer(transformer logits + LM logits), shape (bs, sl, vocab).

        Inference (`tgt` is None): greedy decoding from token 1, stopping after `seq_len` steps or
        once every sequence predicts token 0. The `lm` argument is only a flag: any non-None value
        enables LM mixing during inference.
        NOTE(review): by default inference skips both the LM and `mixer`, unlike training -- confirm intended.
        """
        # inference (greedy decode)
        if tgt is None:
            with torch.no_grad():
                dev = src.device   # follow the inputs' device rather than the global `device`
                feats = self.transformer.encode(self.img_enc(src))
                bs = src.size(0)
                tgt = torch.ones((bs,1), dtype=torch.long, device=dev)

                res = []
                for _ in progress_bar(range(self.seq_len)):
                    mask = subsequent_mask(tgt.size(-1)).to(dev)
                    dec_outs = self.transformer.decode(feats, tgt, mask)
                    prob = self.transformer.generate(dec_outs[:,-1])
                    if lm is not None:
                        # The whole prefix `tgt` is fed again at every step, so the LM must not carry
                        # hidden state over from the previous step.
                        self._reset_lm()
                        lm_prob = self.lm(tgt)[0][:,-1]
                        prob = self.mixer(prob+lm_prob)
                    res.append(prob)
                    pred = torch.argmax(prob, dim=-1, keepdim=True)
                    if (pred==0).all(): break
                    tgt = torch.cat([tgt,pred], dim=-1)
                out = torch.stack(res).transpose(1,0).contiguous()

        #training
        else:
            feats = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt, tgt_mask)    # ([bs, sl, d_model])
            outs = self.transformer.generate(dec_outs)           # ([bs, sl, vocab])
            # `tgt` is already the right-shifted decoder input (TeacherForce applies rshift). Feeding it
            # to the LM unchanged means position t of `outs` and `lm_outs` predicts the same token, as
            # in inference. Calling rshift(tgt) again would offset the LM logits by one position.
            self._reset_lm()   # each batch holds independent sequences
            lm_outs = self.lm(tgt)[0]
            out = self.mixer(outs+lm_outs)
        return out

In [None]:
d_model = 512
em_sz = 256
img_encoder = ResnetBase(em_sz, d_model)
transformer = make_full_model(len(itos), d_model)
# AWD-LSTM language model over the same vocabulary size. Img2Seq adds its logits to the transformer's
# and passes the sum through `mixer` (a vocab x vocab linear layer).
lm = get_language_model(AWD_LSTM, len(itos))
mixer = nn.Linear(len(itos), len(itos))
net = Img2Seq(img_encoder, transformer, lm, mixer, seq_len)

# AdamW16 = partial(optim.Adam, betas=(0.9,0.99), eps=1e-4)  ->  #.to_fp16(max_scale=256)
# partial: way to always call a function with a given set of arguments or keywords

learn = Learner(data, net, loss_func=LabelSmoothing(smoothing=0.1), metrics=[CER(itos)], callback_fns=TeacherForce)
learn.clip_grad(0.25)
None

# Add LM to model state_dict

In [None]:
# Trained transformer checkpoint (fastai save format: weights under the 'model' key).
# torch.load unpickles the file, so only load trusted checkpoints.
sd = torch.load(PATH/'models/v1_gelu_512.pth', map_location=device)

In [None]:
lm_sd = torch.load('data/wikitext/wikitext-2-raw/models/wiki2_base.pth', map_location=device)

# Prefix every LM key with 'lm.' to match the Img2Seq.lm submodule's names in the combined model.
from collections import OrderedDict
new_lm_sd = OrderedDict()
for k, v in lm_sd['model'].items():
    name = 'lm.'+k
    new_lm_sd[name] = v

In [None]:
sd['model'].update(new_lm_sd)

In [None]:
# strict=False: tolerate keys missing from / unexpected in the checkpoint (e.g. lm.* and mixer.*)
learn.load('v1_gelu_512', strict=False)
None

In [None]:
learn.model.load_state_dict(sd['model'], strict=False)

In [None]:
# four layer groups: image encoder, transformer, LM, mixer
learn.split([learn.model.img_enc, learn.model.transformer, learn.model.lm, learn.model.mixer])

In [None]:
# freeze every layer group except the last one (mixer)
learn.freeze_to(-1)

In [None]:
# sanity check that the mixer is trainable
# learn.model.lm[1].decoder.weight.requires_grad
learn.model.mixer.weight.requires_grad

In [None]:
learn.save('v1_gelu_512_wiki2_base_lm')

In [None]:
learn.fit_one_cycle(1, max_lr=2e-5)

### Training - GeLU, singlehead attn
# Mix
# with LM
# 53.998314	32.622627	0.024337   last layer only

In [None]:
# Presumably the LM-augmented weights saved after training only the last layer group (see the log above).
learn.load('v1_gelu_512_wiki2_base_lm_last_layer')
None

In [None]:
# freeze layer group 0 (img_enc in the split above); the remaining groups are trainable
learn.freeze_to(1)

# Experiment

In [None]:
learn.load('v1_gelu_512', strict=False)
None

In [None]:
# swap in the current DataBunch (e.g. after changing sz/bs)
learn.data = data

In [None]:
learn.lr_find()
learn.recorder.plot()

In [None]:
# def convert_h5_to_pth(fname, learner):
#     sd = torch.load(PATH/f"models/{fname}.h5")#, map_location='cpu')
#     learner.model.load_state_dict(sd)
#     learner.save(f"v1_{fname}")

# convert_h5_to_pth('tfmr_mix_words_512', learn)

In [None]:
# One-cycle training run; the comment log below records results of earlier runs
# (columns: train loss, valid loss, CER).
learn.fit_one_cycle(1, max_lr=5e-3)
# 'v1_exp_baseline'
# 0	92.709564	85.651016	0.706222	05:29
# 1	61.329845	50.974323	0.388731	04:58
# 2	47.806641	41.577847	0.311466	05:10
# 3	42.097706	36.231873	0.265740	05:03
# 4	40.572376	34.968216	0.255079	05:06

# 3x2, sz:128, bs:100, fit_one_cycle(5, 1e-4)
# 43.886196	36.105576	0.265459   baseline
# 40.572376	34.968216	0.255079   baseline    'v1_exp_baseline'          ***
# 55.031853	49.186695	0.373046   remove encoder; multihead(8) attn
# 40.953770	36.203377	0.269332   multihead(8) attn
# 40.215485	35.141953	0.259177   multihead(4) attn  'v1_exp_multiattn'   ***

# lr: 1e-3
# 12.597528	10.658405	0.074232   baseline   'v1_exp_mask'   ***
# 13.165689	11.060272	0.075740   nn.ReLU(True) (new itos.pkl)   'v1_exp_relu'
# 22.972273	20.039871	0.134025   AdamW16 optimizer
# 14.912671	12.835587	0.088028   d_model=400, hidden=1152
# 12.332767	10.326035	0.070659   GeLU activation  'v1_exp_gelu'   ***
# 13.439913	11.265744	0.077952   Swish activation
# 10.887406	9.458515	0.064375   GeLU, multihead(4) attn   'v1_exp_gelu_multihead' ***

# different architectures (no pretraining)
# 22.601816	17.765850	0.122246   xresnet34 [256]  'v1_exp_xres34'
# 49.068211	43.858349	0.337089   xresnet34 [512]
# 19.258896	15.660585	0.107812   resnet34           ***
# 22.787096	18.036461	0.124185   resnet50
# 23.279938	18.810421	0.132146   xception [728]   'v1_exp_xception'

# mixed precision
# 13.537477	11.407567	0.078925   mixed precision  fp_16(max_scale=256), AdamW16     'v1_exp_fp16'
# 18.175873	15.319780	0.105136   mixed precision bs: 200 (5749 GPU MB)


### Training - GeLU, multihead(4) attn
# 3x2
# 10.887406	9.458515	0.064375   sz:128, bs:100, lr:1e-3   'v1_exp_gelu_multihead'
# 2.238591	2.254460	0.012747   sz:256, bs:60, lr:1e-3    'v1_gelu_multi_256'
# Mix
#    sz:400, bs:20, lr:1e-3

### Training - GeLU, singlehead attn
# 3x2
# 12.332767	10.326035	0.070659   sz:128, bs:100, lr:1e-3   'v1_exp_gelu'   
# 2.631910	2.435502	0.015014   sz:256, bs:60, lr:1e-3    'v1_gelu_256'
# Mix
# 29.340302	22.766819	0.023261   sz:400, bs:20, lr:1e-3    'v1_gelu_400'
# 14.096607	11.613899	0.014052   sz:512, bs:15, lr:5e-5    'v1_gelu_512'

# with LM
# 53.998314	32.622627	0.024337   last layer only

In [None]:
# Save weights as models/v1_gelu_512.pth (the checkpoint the LM cells above start from).
learn.save('v1_gelu_512')

In [None]:
# Serialize the Learner for inference; it is read back with load_learner(PATH).
learn.export()

In [None]:
# learn = load_learner(PATH)

# Word + Char Arch

In [None]:
class EncoderDecoder(nn.Module):
    """Work-in-progress word + char encoder-decoder.

    `src_embed` adapts image features before the encoder; `tgt_embed` embeds target ids for the decoder.
    """
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder

        # NOTE(review): both names refer to the SAME module object, so word and char decoding share
        # weights. deepcopy one of them if they are meant to be independent.
        self.word_decoder = decoder
        self.char_decoder = decoder

        self.src_embed = src_embed

        self.tgt_embed = tgt_embed

        self.generator = generator

    def forward(self, src, tgt, tgt_mask=None):
        return self.decode(self.encode(src), tgt, tgt_mask)

    def encode(self, src):
        return self.encoder(self.src_embed(src))

    def decode(self, src, tgt, tgt_mask=None):
        # This class has no `self.decoder` attribute (referencing it raised AttributeError).
        # Decode with the char decoder; it is currently the same module as word_decoder.
        return self.char_decoder(self.tgt_embed(tgt), src, tgt_mask)

    def generate(self, outs):
        return self.generator(outs)

In [None]:
class Embeddings(nn.Module):
    "Token embedding lookup multiplied by a fixed factor of 18 (d_model is stored but not used for scaling)."
    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        embedded = self.lut(x)
        return embedded * 18

In [None]:
class ImageAdaptor(nn.Module):
    """Project a CNN feature map into d_model space, one token per spatial location.

    em_sz:   channel count of the incoming feature map.
    d_model: output width.
    drop:    dropout applied after projection and scaling.
    scale:   multiplier on the projection (the original hard-coded 4).
    """
    # class-level default so instances pickled before `scale` existed still work
    scale = 4

    def __init__(self, em_sz, d_model, drop=0.2, scale=4):
        super().__init__()
        self.linear = nn.Linear(em_sz, d_model)
        self.dropout = nn.Dropout(drop)
        self.scale = scale

    def forward(self, x):
        # x is (bs, em_sz, h, w): flatten spatial dims and move channels last -> (bs, h*w, em_sz)
        x = x.flatten(2, 3).permute(0, 2, 1)
        return self.dropout(self.linear(x) * self.scale)

In [None]:
def make_full_model(c_vocab, w_vocab, em_sz=256, d_model=512, N=4, drop=0.2):
    """Build the word + char transformer on top of image features.

    c_vocab: character vocabulary size; used for the target embedding and the
             generator output (previously `nn.Linear(d_model, vocab)`, an undefined name).
    w_vocab: word vocabulary size; its embedding is built and attached as
             `model.word_embed`, but EncoderDecoder does not consume it yet
             (passing it positionally gave EncoderDecoder six arguments -> TypeError).
    em_sz:   channel count of the image features fed to ImageAdaptor.
    d_model: transformer width.
    N:       number of encoder and decoder layers.
    drop:    dropout used throughout.

    Every parameter with more than one dimension is Xavier-uniform initialised.
    """
    c = deepcopy
    attn = SingleHeadedAttention(d_model)
#     attn = MultiHeadedAttention(d_model, 4)
    ff = PositionwiseFeedForward(d_model, drop)

    char_embed = nn.Sequential(
        Embeddings(d_model, c_vocab), PositionalEncoding(d_model, 1, drop, 2000)
    )
    word_embed = nn.Sequential(
        Embeddings(d_model, w_vocab), PositionalEncoding(d_model, 1, drop, 2000)
    )

    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drop), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drop), N),
        ImageAdaptor(em_sz, d_model),
        char_embed,
        nn.Linear(d_model, c_vocab),
    )
    # registered as a submodule so it is initialised, saved and moved with the model
    model.word_embed = word_embed

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return model

In [None]:
class Img2Seq(nn.Module):
    """Image-to-sequence model: an image encoder feeding an encoder-decoder transformer."""
    def __init__(self, img_encoder, transformer):
        super(Img2Seq, self).__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt=None, tgt_mask=None):
        """Teacher-forced pass: generator outputs for every target position."""
        feats = self.img_enc(src)                            # ([bs, h*w, d_model])
        dec_outs = self.transformer(feats, tgt, tgt_mask)    # ([bs, sl, d_model])
        out = self.transformer.generate(dec_outs)            # ([bs, sl, vocab])
        return out

    def greedy_decode(self, src, seq_len):
        """Autoregressively decode up to `seq_len` steps with argmax picks.

        Decoding starts from token 1 for every sample. It stops early once every
        sample predicts 0 in the same step. Returns the stacked per-step generator
        outputs as (bs, steps, vocab).
        """
        with torch.no_grad():
            feats = self.transformer.encode(self.img_enc(src))
            bs = src.size(0)
            # allocate on the input's device rather than the notebook-global `device`
            tgt = torch.ones((bs, 1), dtype=torch.long, device=src.device)

            res = []
            for _ in tqdm(range(seq_len)):
                # causal mask, as in the teacher-forced pass (subsequent_mask), so
                # intermediate decoder layers see the same context they saw in training
                tgt_mask = subsequent_mask(tgt.size(-1))
                dec_outs = self.transformer.decode(feats, tgt, tgt_mask)
                prob = self.transformer.generate(dec_outs[:, -1])
                res.append(prob)
                pred = torch.argmax(prob, dim=-1, keepdim=True)
                if (pred == 0).all(): break
                tgt = torch.cat([tgt, pred], dim=-1)
            return torch.stack(res).transpose(1, 0).contiguous()

In [None]:
d_model = 512
em_sz = 256
img_encoder = ResnetBase(em_sz)
# NOTE(review): make_full_model's signature is (c_vocab, w_vocab, em_sz, d_model, ...),
# so these positional args bind w_vocab=em_sz (256) and em_sz=d_model (512), while
# d_model falls back to its default. Confirm the intended word-vocab size and pass
# em_sz/d_model by keyword.
transformer = make_full_model(len(itos), em_sz, d_model)
net = Img2Seq(img_encoder, transformer)

opt_fn = partial(optim.Adam, betas=(0.7, 0.99)) #, lr=0, betas=(0.9, 0.98), eps=1e-9)
# partial: way to always call a function with a given set of arguments or keywords

learn = RNN_Learner(data, BasicModel(to_gpu(net)), opt_fn=opt_fn)
learn.clip = 0.25  # gradient clipping value
learn.crit = LabelSmoothing(smoothing=0.1)   #XE_loss
learn.metrics = [cer, acc]

# Train

### load state dict for changes

In [None]:
#LM
LM_PATH = Path('data/wikitext/wikitext-2-raw')
sd = torch.load(LM_PATH/'models'/'LM.h5', map_location=lambda storage, loc: storage)

In [None]:
# load the checkpoint with every tensor kept on the CPU (map_location returns storage unchanged)
sd = torch.load(PATH/'models'/'tfmr_paragraph.h5', map_location=lambda storage, loc: storage)
sd.pop('img_enc.linear.bias')   # drops only the bias; if 'img_enc.linear.weight' also mismatches, load_state_dict(strict=False) still raises on size mismatch -- NOTE(review) confirm

In [None]:
learn.model.load_state_dict(sd, strict=False)

### load model

In [None]:
learn.load('tfmr_mix_words_512')

## LR find

In [None]:
learn.lr_find(stepper=TfmrStepper)
learn.sched.plot(n_skip=0, n_skip_end=2)

## Experimentation

In [None]:
#gpu
lr=1e-4
learn.fit(lr, 5, cycle_len=1, stepper=TfmrStepper, use_clr=(20,5))
# 3x1: XE/bs, kaiming_normal, no pos_enc, emb/out weight tying, no encoder, 2 layers, 256/256, single attn
# 52.199187  50.1736    0.730092    lr:1e-4
# 17.318725  13.803769  0.205331    2nd run, lr: 1e-3
# 11.96539   9.991971   0.151522    3rd run, lr: 1e-4     **tfmr

# Loss fns
# 43.177135  41.245482  1.185836    LabelSmoothing (KL/bs)
# 4215.8745  4012.3099  0.957795    LabelSmoothing (KL)
# 45.200728  43.662235  0.753881    LabelSmoothing (KL/bs w/out padding logic)  perplexity < BLEU/acc
# 5340.5770  5129.8346  0.738853    XE                    --- scaling by BS not a factor
# 52.199187  50.1736    0.730092    XE/bs  **

# Initialization/Activation
# 52.199187  50.1736    0.730092    kaiming_normal  (ff activation: leaky_relu)
# 52.987157  50.976083  0.758013    kaiming_uniform
# 49.553779  47.270616  0.702353    xavier_normal
# 49.380884  47.345175  0.673508    xavier_uniform  (ff activation: leaky_relu)   **
# 49.849071  47.369221  0.708807    xavier_uniform  (ff activation: relu)

# Positional Encoding
# 35.426869  30.449724  0.423573    target only; no embed/out weight tying   **
# 35.390268  31.216677  0.45097     "" ; w/ scaling factor [* math.sqrt(d_model)]
# 45.596748  42.213515  0.572662    target only; w/ embed/out weight tying

# Encoder
# 34.46771   29.207974  0.408425    **

# Attention
# 32.087941  26.431674  0.362717    SingleHead - linear layers for (q,k,v)
# 37.739475  32.962685  0.454753    MultiHead (8)
# 36.09057   31.346036  0.432342    MultiHead (4)
# 31.069687  26.00539   0.360032    MultiHead (1)     **

# N layers (6)
# 51.119606  48.809905  0.645302

# d_model = 512
# 21.032752  17.234559  0.230763    ~11:35  **

# em_sz
# 27.589121  22.698359  0.307399    128  ~14:09
# 28.320952  24.183105  0.338207    512  ~11:31

# include layer_norm in encoder/decoder
# 24.035412  19.824303  0.270881   modified: self.norm(x + self.dropout(sublayer(x)))
# 20.347157  17.23408   0.228328   original:  x + self.dropout(sublayer(self.norm(x)))   **tfmr_experiment
# 10.732795  9.875141   0.12633    2nd run

In [None]:
learn.save('tfmr_experiment')

## Freezing

In [None]:
learn.freeze_to(1)

In [None]:
lr=1e-4
learn.fit(lr, 3, cycle_len=1, stepper=TfmrStepper, use_clr=(20,5))

In [None]:
learn.unfreeze()
lr = np.array([lr/10, lr])

## Initial size

In [None]:
lr=1e-5
learn.fit(lr, 5, cycle_len=1, stepper=TfmrStepper, use_clr=(20,8), best_save_name=f'best_{FOLDER}_{sz}')
# 3x1, sz: 128, bs: 100
# 20.347157  17.23408   0.228328
# 10.732795  9.875141   0.12633     2nd run

# 6.753466   6.750452   0.085309    Full Transformer (4N/4h), 10cycles(20,10)   ~33m   'tfmr_full_3x1'
# 7.0736     6.798651   0.085855    "", single attn, 10cycles(20,10)            ~38m

# 7.161385   8.286392   0.102471   0.92869    exp_3x1

# 3x2; sz: 128, bs: 100
# 65.659014  59.8375    0.492443   Transformer (4N/8h), 5cycles(20,8)    'tfmr_8head_128'
#     PosEnc2d instead of Encoder (4N/4h) 5cycles(20,5)  'posenc2d'

# 3x2; sz: 256, bs: 60
# 16.186042  14.892255  0.101064   Transformer (4N/8h), 5cycles(20,8)    'tfmr_8head_256'


# 3x2; sz: 256, bs: 60
# 14.65408   13.10881   0.090017    preload tfmr_3x1_256
# 4.464029   5.072127   0.037011    Full Transformer (4N/0h), 10cycles(20,10)   ~1h 46m

# 6.853619   7.136484   0.050812   0.962716    exp_3x2

# Mixed Synth, sz: 128, bs: 100
# 15.684915  14.209854  0.133399    enc/dec/mixer, N:4, pre-loaded LM  ~1h

# lg; sz: 256, bs: 60
# 161.081839 137.861294 0.231951    enc/dec/mixer, N:4, pre-loaded LM  ~1h 45m

# lg; sz: 400, bs: 30
# 98.491235  89.396089  0.123849    preload tfmr_3x2      'tfmr_lg'
# 101.903571 88.239827  0.122516    -same as above-       'tfmr_lg2'
# bs: 45
# 161.90945  169.713365 0.207615    dropout:0.5    'tfmr_lg_tmp'
# 29.706235  25.84234   0.039261    Full Transformer (4N/0h), 10cycles(20,10)   ~3h 47m

# concat lines, sz: 512, bs: 20
# LMmixer
# 68.669186  56.487096  0.123718    345, preload 'tfmr_lg_LM_mixer'
# 144.601336 116.728641 0.153257    678, preload 345
#      9-12, preload 678, lr: 1e-5

# Full
# 26.620874  24.083611  0.05163     345, preload 'tfmr_full_lg'  ~2hr  'tfmr_cat345_full'
# 46.148448  37.998259  0.04748     678, preload 345   ~2h 25m  'tfmr_cat678_full'
# 56.235635  42.104473  0.03309     9-12, preload 678   ~3h 25m   'tfmr_cat9-12_full'

# paragraph; sz:512, bs:30
# 289.464371 202.553455 0.175994     preload tfmr_lg2, 4 cycles
# 178.432805 145.702918 0.128241     2nd run, 5 cycles
# 144.489454 130.875442 0.116076     3rd run              'tfmr_paragraph'
# bs: 20
# 66.923881  70.373825  0.069164     Full Transformer (4N/0h), 10cycles(20,10)     'tfmr_full_paragraph'
# 81.810724  78.974658  0.078329     "", reversed src_attn/self_attn

# 9-12; sz: 512, bs: 30
# 139.165167 102.821483 0.089816     preload 'tfmr_full_paragraph'      '9-12_512'

# paragraph; sz: 800, bs: 8
# 24.8       45.8       0.0407       preload tfmr_cat9-12_full_800, 3cycles    'tfmr_pg_full_800'

# concat lines + pg, sz: 800, bs: 8
# 9.978522   9.815259   0.006313     preload tfmr_cat9-12_full_800, 1cycle     'tfmr_catpg_800'


# mix_words (new dataset); sz: 400, bs: 30
# 60.475961  48.045776  0.059005     Tfmr (4N/0h); preload 'tfmr_full_lg' (512)     'tfmr_mix_words_400'

# mix_words (new dataset); sz: 400, bs: 10 
# 83.258275  66.635521  0.08263      Transformer (4N/8h), 5cycles(20,8)
# 42.608748  38.862213  0.047246     2nd run                                 'tfmr_8head_400_mix'

# downloaded_images; sz: 512, bs: 10
# 119.754323 31.507531  0.05412      'best_downloaded_images_512'

In [None]:
learn.save('tfmr_8head_400_mix')

## Increase size

In [None]:
# sz,bs = 512,20
# sz,bs = 1024,5
# sz,bs = 256,60
sz,bs = 400,45

In [None]:
learn.set_data(data)

In [None]:
lr=1e-4
learn.fit(lr, 5, cycle_len=1, stepper=TfmrStepper, use_clr=(20,5), best_save_name=f'best_{FOLDER}_{sz}')
# 3x1; sz: 256, bs: 60
# 8.169301   7.199782   0.087933    fresh start
# 4.305115   4.585243   0.054132    2nd run (increase dropout 0.2)     'tfmr_3x1_256'
# 3.6729     3.756866   0.047393    resize from 128

# 1.489701   2.297418   0.027936   0.981084     exp_3x1_256

# 2.450041   2.695453   0.032867    Full Transformer (4N/4h)   ~43m 48s       'tfmr_full_3x1'
# 2.577439   2.748884   0.033926    Full Transformer (4N/0h)   ~45m 54s       'tfmr_full_3x1_single_attn'

# 3x2; sz: 400, bs: 45
# 5.701717   5.83973    0.039572    resize from 256 (increase dropout 0.3)     'tfmr_3x2'
# 2.24673    2.478204   0.017335    Full Transformer (4N/0h)   ~2h 31m     'tfmr_full_3x2'

# 6.739319   6.472689   0.042743    Transformer (4N/8h), 3cycles(20,5)    'tfmr_8head_400'

# Mixed Synth, sz: 256, bs: 60
# 3.584531   3.645265   0.031744    enc/dec/mixer, N:4, pre-loaded LM  ~2h 30m

# lg; sz: 512, bs: 30
# 38.786294  45.805161  0.061061    resize from 400  -quit after 3 epochs-
# 8.317326   8.491839   0.010576    Full Transformer (4N/0h)   ~2h 31m     'tfmr_full_lg'
# 48.016887  37.537244  0.054766    enc/dec/mixer, N:4, pre-loaded LM  ~6h 53m   'tfmr_lg_LM_mixer'

# preload tfmr_lg2; increase dropout to 0.3 in attention and ff; wd; lower lr
# 98.273312  60.214675  0.08372    lr: 1.5e-5, wd: 1e-5, 3 cycles(20,4)                    'tfmr_lg_wd'
# 81.961758  53.715966  0.075622   lr: 2e-5, no wd, 4 cycles(20,5)  -quit after 2 epochs-  'tfmr_lg_wd_tmp'
# 45.696747  37.355512  0.053213   "", {drops}, preloaded tfmr_lg_wd_tmp                   'tfmr_lg2'

# concat lines, sz: 800, bs: 8
# Full
# 3.880624   4.531283   0.005706    345, preload 'tfmr_cat9-12_full'  ~  'tfmr_cat345_full_800'
# 5.569003   6.855545   0.004001    678, preload above, 5cycles  ~2h 45m  'tfmr_cat678_full_800'
# 9.628103   10.939916  0.00433     9-12, preload above, 3cycles  ~2h 18m  'tfmr_cat9-12_full_800'

# paragraph; sz:800, bs:10
# 100.395217 86.058532  0.074781     resize from 512
# 78.234521  81.089716  0.067993     2nd run           'tfmr_paragraph_800'

# paragraph; sz: 1024, bs:5
# 55.707705  56.274219  0.05457     Full Transformer (4N/0h), 10cycles(20,10)   ~23m   'tfmr_full_paragraph2'

# sz: 1000, bs:5
# 22.210477  47.557047  0.041825    pretrained on tfmr_catpg_1000    'tfmr_pg_1000'

# concat lines + pg, sz: 1000, bs: 5
# 8.495862   10.158758  0.0068      3cycles     'tfmr_catpg_1000'


# mix_words (new dataset); sz: 512, bs: 20
# 25.864771  20.991407  0.028696    3/5 iterations -> tfmr_mix_words_400     'tfmr_mix_words_512'
# 20.654874  17.808776  0.024131    5cycles(20,5)     'tfmr_mix_words_512_2'

# mix_words (new dataset); sz: 512, bs: 8
# 15.195948  14.417563  0.019355    Transformer (4N/8h), 6cycles(20,8)    'tfmr_8head_512_mix'

In [None]:
learn.save('tfmr_mix_words_512_2')

# Experiment

In [None]:
learn.load('tfmr_cat9-12_full_800')

In [None]:
learn.freeze_groups([0])
learn.model.img_enc.trainable, learn.model.transformer.trainable

In [None]:
lr=1e-4
learn.fit(lr, 5, cycle_len=1, stepper=TfmrStepper, use_clr=(20,5))

# sz: 128
# input: 3x1
# 7.298699   7.093821   0.091642   0.936467    BASELINE - Full tfmr   ~29m
# 12.254474  10.917412  0.141336   0.900215    adaptor: conv/bn + drop, no scaling


# 28.902502  24.930458  0.398251   0.732572    tfmr (img_enc frozen)   **3x1_base
# 42.132684  35.906375  0.338968   0.770622    "", sz: 256, 3cycle(20,5)   ??resizing not an improvement??




# input: mix (sample: 50000), img_enc frozen
# 62.628027  59.523581  0.679014    Full Transformer -- 4N/0h, img_enc:256 w/ linear, w/ scaling factor  ~13m

# 81.75129   79.767102  0.806713    "", img_enc:256 w/ conv/bn, pos_enc, tgt_emb drop:0.5, w/out scaling factor  ~13m
# 76.193838  73.24591   0.757938    "", tgt_emb drop:0.1  ~13m
# [105.264, 0.655, 13.205]
# 75.780639  72.219551  0.761694    "", no pos_enc  ~14m
# [102.794, 0.679, 10.355]
# 81.270569  78.084324  0.788466    "", linear embedding w/ weight tying, smoothed: 0.7  ~13m
# 72.355454  69.852579  0.733982    2nd run
# [69.245, 0.932, 187.734]
# 95.489009  91.202586  0.791639    "", XE loss


# 1e-4, 5cycle(20,10), img_enc frozen
# input: mix (sample: 50000)
# 53.373692  48.385205  0.544684   0.696936    BASELINE - Full tfmr
# greedy:    134.89598  0.643921   0.458704
# 44.957038  39.8659    0.437462   0.75087     2nd run

# 35.53787   30.969923  0.310315   0.811553    BASELINE - unfrozen img_enc    **base-unfrozen
# greedy:    117.58634  0.392103   0.551786

# 73.066214  70.389176  0.747244   0.572001    4N/0h, img_enc:256 w/ conv/bn, linear: prob dist (smoothed: 0.7)
# 66.118537  63.631764  0.69066    0.605459    2nd run

# input: 3x1
# 28.902502  24.930458  0.398251   0.732572    BASELINE - Full tfmr (img_enc frozen)   **3x1_base
# 42.132684  35.906375  0.338968   0.770622    "", sz: 256, 3cycle(20,5)   ??resizing not an improvement??

In [None]:
learn.save('3x1_base')

In [None]:
sz,bs = 256,60

In [None]:
# call set_data (as in the "Increase size" cell); assigning would overwrite the method
learn.set_data(data)

In [None]:
lr=1e-4
learn.fit(lr, 10, cycle_len=1, stepper=TfmrStepper, use_clr=(20,10))

# sz:256 
# 60.575186  52.815338  0.335582   0.803861     Baseline - full tfmr
# 32.735003  27.785917  0.147773   0.907831     "" , unfrozen    **base
# val:       0.00983    0.591396   0.75641
# greedy:    0.04553    0.495198   0.509295


# 101.034633 95.16612   0.625264   0.642998    4N/0h, img_enc:256 w/ conv/bn, linear: prob dist (smoothed: 0.7)
# 79.018152  70.1724    0.456446   0.735568    2nd run  ~28m
# 64.660601  55.688069  0.352824   0.792006    3rd run    **prob_embed_experiment
# [463.614, 2.677, 1025.052]   FAIL

In [None]:
learn.unfreeze()
learn.model.img_enc.trainable, learn.model.transformer.trainable

lrs = np.array([lr/10, lr])

In [None]:
lr=1e-4
learn.fit(lrs, 5, cycle_len=1, stepper=TfmrStepper, use_clr=(20,10))

# 1e-4, 5cycle(20,10)
# input: 3x1, sz: 256
# 25.84745   21.318996  0.179581   0.874763    BASELINE - Full tfmr, img_enc unfrozen   **3x1_base
# greedy:    115.10662  0.513799   0.356322
# val:        36.26467  0.793716   0.598851

### Decode Stats

In [None]:
v_dl = iter(data.val_dl)
denorm = data.val_ds.denorm

In [None]:
x,y = next(v_dl)
imgs = denorm(x)

shifted_y = rshift(y).long()
tgt_mask = subsequent_mask(shifted_y.size(-1)) #make_tgt_mask(shifted_y)

learn.model.eval()

v_preds = learn.model(x, shifted_y, tgt_mask)
v_res = torch.argmax(v_preds, dim=-1)
v_attn = source_attn()

g_preds = learn.model.greedy_decode(x, seq_len)
g_res = torch.argmax(g_preds, dim=-1)
g_attn = source_attn()

In [None]:
v = [learn.crit(v_preds, y).item(), cer(v_preds, y), acc(v_preds, y).item()]
print(f'valid:     {str(v[0])[:7]}   {str(v[1])[:7]}   {str(v[2])[:7]}')

g = [learn.crit(g_preds, y).item(), cer(g_preds, y), acc(g_preds, y).item()]
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[:7]}   {str(g[2])[:7]}')

In [None]:
# tfmr full
#             loss       cer        acc
# valid:     6.55702   0.08746   0.934839   3x1, 128/100
# greedy:    42.5730   0.10094   0.759062

# previous best models
# valid:     2.42856   0.02981   0.981771   3x1, 256/60, 'tfmr_full_3x1_single_attn'
# greedy:    11.9438   0.02745   0.933854

# valid:     1.47604   0.00797   0.995238   3x2, 400/45,  'tfmr_full_3x2'
# greedy:    21.4682   0.01097   0.937622

# valid:     7.30494   0.01206   0.990868   lg,  512,30,  'tfmr_full_lg'
# greedy:    866.544   0.14109   0.372146

# 


In [None]:
# res = torch.argmax(preds, dim=-1)
denorm = data.val_ds.denorm

fig, axes = plt.subplots(2,2, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    t = char_label_text(res[i])
    ax=show_img(denorm(x[i])[0], ax=ax, title=t)

### Previous Results

In [None]:
# 3x1: XE/bs, kaiming_normal, no pos_enc, emb/out weight tying, no encoder, 2 layers, 256/256, single attn
# 52.199187  50.1736    0.730092    lr:1e-4
# 17.318725  13.803769  0.205331    2nd run, lr: 1e-3
# 11.96539   9.991971   0.151522    3rd run, lr: 1e-4     **tfmr

# Loss fns
# 43.177135  41.245482  1.185836    LabelSmoothing (KL/bs)
# 4215.8745  4012.3099  0.957795    LabelSmoothing (KL)
# 45.200728  43.662235  0.753881    LabelSmoothing (KL/bs w/out padding logic)  perplexity < BLEU/acc
# 5340.5770  5129.8346  0.738853    XE                    --- scaling by BS not a factor
# 52.199187  50.1736    0.730092    XE/bs  **

# Initialization/Activation
# 52.199187  50.1736    0.730092    kaiming_normal  (ff activation: leaky_relu)
# 52.987157  50.976083  0.758013    kaiming_uniform
# 49.553779  47.270616  0.702353    xavier_normal
# 49.380884  47.345175  0.673508    xavier_uniform  (ff activation: leaky_relu)   **
# 49.849071  47.369221  0.708807    xavier_uniform  (ff activation: relu)

# Positional Encoding
# 35.426869  30.449724  0.423573    target only; no embed/out weight tying   **
# 35.390268  31.216677  0.45097     "" ; w/ scaling factor [* math.sqrt(d_model)]
# 45.596748  42.213515  0.572662    target only; w/ embed/out weight tying

# Encoder
# 34.46771   29.207974  0.408425    **

# Attention
# 32.087941  26.431674  0.362717    SingleHead - linear layers for (q,k,v)
# 37.739475  32.962685  0.454753    MultiHead (8)
# 36.09057   31.346036  0.432342    MultiHead (4)
# 31.069687  26.00539   0.360032    MultiHead (1)     **

# N layers (6)
# 51.119606  48.809905  0.645302

# d_model = 512
# 21.032752  17.234559  0.230763    ~11:35  **

# em_sz
# 27.589121  22.698359  0.307399    128  ~14:09
# 28.320952  24.183105  0.338207    512  ~11:31

# include layer_norm in encoder/decoder
# 24.035412  19.824303  0.270881   modified: self.norm(x + self.dropout(sublayer(x)))
# 20.347157  17.23408   0.228328   original:  x + self.dropout(sublayer(self.norm(x)))   **tfmr_experiment
# 10.732795  9.875141   0.12633    2nd run



# 3x2 attention/pos_enc experiments
# 44.391676  40.633241  0.308526    baseline (src before self attn)    ~1h 7m
# 18.111974  18.16437   0.117756    10 cycles                          ~1h 52m      tmp_base
# 57.85619   51.015644  0.418202    pos_enc2d/no encoder; 256 + conv 1   ~44m       
# 27.356499  24.327117  0.17227     2nd run - attention not working well            tmp_2
# 16.581052  15.659028  0.104724    3rd run                                         tmp_3
# 55.944376  52.067722  0.417235    baseline w/ pos_enc2d; 256 + conv1    ~57m      tmp2


# 3x1, 1e-4, 10 cycles, N: 2
# 23.452119  20.750421  0.30148     conv, enc: pos2d, out: Linear    ~20m
# 10.747468  11.945359  0.148952    2nd run

# 22.816471  20.652263  0.300892    w/ STN (32/7/5/32/32) before img_encoder
# 24.154148  21.474449  0.314114    w/ STN (32/7/5/64/64) before img_encoder
# 24.315829  21.928237  0.319867    adding bn after conv (no STN)
# 25.852014  23.393568  0.345979    "", w/ STN(em_sz, d_model, 2) after img_encoder
# 22.586863  20.337497  0.294407    "", w/ STN (32/7/5/32/32) before img_encoder     ~21m
# 11.062696  11.752348  0.146163    2nd run

# 17.356167  15.935936  0.223704    linear, enc: None, out: Linear    ~19m
# 15.614188  14.634603  0.205307    conv/bn, enc: None, out: Linear    ~19m   **
# 17.618474  15.867109  0.220721    ""    ~20m
# 8.634967   9.732577   0.13152     2nd run    'tmp'  [src-attn tracks well]

# 17.213856  16.036616  0.228302    conv/bn, enc: None, out: MLP    ~20m
# 17.269863  16.052437  0.224499    conv/bn, enc: Encoder, out: Linear    ~23m
# 17.999301  16.085523  0.223856    conv/bn, enc: pos2d, out: Linear    ~19m
# 19.352792  17.537506  0.249201    conv/bn, STN(32/7/5), enc: pos2d, out: Linear ~21m
# 21.778735  19.844897  0.288944    conv/bn, STN(32/7/5), enc: pos2d, out: MLP    ~22m

# 15.627235  14.005776  0.18804     dec: w/ self-attn + src-attn
# 7.056934   7.818964   0.10034     2nd run    'tmp2'    [loss decrease but src-attn tracks worse]

# 17.841002  16.682544  0.221875    enc: src-attn, dec: self-attn, mixer: +tanh   ~22m
# 6.         8.         0.12        2nd run

# 15.780708  14.452696  0.205315    enc: src-attn, dec: self-attn, mixer: cat/lin   ~20m    'tmp3'

# N: 4
# 14.322309  13.357375  0.190247    enc: src-attn, dec: self-attn, mixer: cat/lin   ~25m
# 7.48091    7.132944   0.092341    "", dec: preloaded LM

# paragraph, 1e-4, 10 cycles, N: 4
# 691.015346 621.679181 0.645048    enc: src-attn, dec: preloaded LM, mixer: cat/lin

# Test

In [None]:
from scipy.ndimage import gaussian_filter
k=16

def torch_scale_attns(attns, size=None):
    """Reshape flat attention maps into square grids and upsample them.

    attns: tensor indexed as (bs, sl, hw); hw must be a perfect square.
    size:  output spatial size for F.interpolate; defaults to the notebook-global `sz`.
    Returns a (bs, sl, size, size) tensor for an int size.
    Raises ValueError when hw is not a perfect square.
    """
    bs, sl, hw = attns.shape
    num = int(round(math.sqrt(hw)))   # side of the square feature grid (sz // k)
    if num * num != hw:
        raise ValueError(f'expected a square attention map, got hw={hw}')
    # reshape (not view) so sliced, non-contiguous inputs such as attn[:i, :20] also work
    grid = attns.reshape(bs, sl, num, num)
    return F.interpolate(grid, size=sz if size is None else size)  #([bs, sl, h, w])

def g_filter(att, sigma=None):
    """Gaussian-smooth an attention map for display.

    sigma: standard deviation for scipy's gaussian_filter; defaults to the
           notebook-global `k` (the value previously hard-wired here).
    """
    return gaussian_filter(att, sigma=k if sigma is None else sigma)

In [None]:
def _decoder_layer(layer):
    "Decoder layer number `layer` of the notebook-global `learn`'s transformer."
    return learn.model.transformer.decoder.layers[layer]

def self_attn(layer=-1):
    "CPU copy of the `.attn` tensor held by a decoder layer's self-attention (default: last layer)."
    return _decoder_layer(layer).self_attn.attn.data.cpu()

def source_attn(layer=-1):
    "CPU copy of the `.attn` tensor held by a decoder layer's source attention (default: last layer)."
    return _decoder_layer(layer).src_attn.attn.data.cpu()

In [None]:
def cer(preds, targs):
    """Mean character error rate of argmax predictions against targets.

    preds: scores indexed as (bs, sl, vocab); argmax over dim 2 picks characters.
    targs: integer targets with the batch on dim 0.
    Both are decoded with char_label_text, which drops zero indices. Each sample
    contributes Levenshtein distance / target length. An empty decoded target
    divides by 1 instead of raising ZeroDivisionError.
    """
    res = torch.argmax(preds, dim=2)
    bs = targs.size(0)
    error = 0
    for i in range(bs):
        p = char_label_text(res[i])   #.replace(' ', '')
        t = char_label_text(targs[i]) #.replace(' ', '')
        error += Lev.distance(t, p) / max(len(t), 1)
    return error / bs

def char_label_text(pred, chunk=None):
    """Decode a sequence of character indices into a string via the global `itos`.

    Zero entries are skipped.
    chunk: if given, insert a newline every `chunk` characters. Several plotting
           cells call this with chunk=55, presumably to wrap long transcriptions in
           axis titles; without this parameter those calls raised TypeError.
    """
    ints = to_np(pred).astype(int)
    text = ''.join(itos[i] for i in ints[np.nonzero(ints)])
    if chunk:
        text = '\n'.join(text[i:i + chunk] for i in range(0, len(text), chunk))
    return text

In [None]:
def show_img(im, figsize=None, ax=None, alpha=None, title=None):
    """Draw `im` on `ax` (a fresh figure/axes when None) and return the axes.

    im: a 3-channel channel-first image (converted with image2np), or anything
        ax.imshow accepts directly. PIL images have no `.shape` and are passed
        straight through; previously they raised AttributeError.
    """
    if ax is None: fig, ax = plt.subplots(figsize=figsize)
    shape = getattr(im, 'shape', None)
    if shape and shape[0] == 3: im = image2np(im.data)
    ax.imshow(im, alpha=alpha)
    if title: ax.set_title(title)
    return ax

In [None]:
x,y = next(iter(learn.data.valid_dl))
# x,y = learn.data.one_batch(ds_type=DatasetType.Valid)
imgs = learn.data.denorm(x)

## With pretrained LM

In [None]:
learn.load('v1_gelu_512')
None

In [None]:
learn.model.eval()

lm = get_language_model(AWD_LSTM, len(itos))

f = Path('data/wikitext/wikitext-2-raw/models/wiki2_base.pth')
sd = torch.load(f)

lm.load_state_dict(sd['model'])
lm.to(device)
lm.reset()

In [None]:
lm_preds = learn.model(x, lm=lm.eval())

In [None]:
lm_res = torch.argmax(lm_preds, dim=-1)

In [None]:
''.join([itos[i] for i in lm_res[0]])

In [None]:
show_img(imgs[0])

In [None]:
#lm
fig, axes = plt.subplots(2,1, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(lm_res[i])
    ax=show_img(imgs[i], ax=ax, title=p)

## Uploaded Images

In [None]:
def thresh_edit(fname, thresh=100, bg=245):
    """Load `fname` as grayscale and paint every pixel brighter than `thresh` with `bg`.

    Returns a new mode-'L' PIL image. The source file is closed before returning;
    previously only the converted copy was closed and the opened file was left to GC.
    """
    with Image.open(fname) as src:
        np_im = np.array(src.convert('L'))  #grayscale
    np_im[np_im > thresh] = bg
    return Image.fromarray(np_im, 'L')

In [None]:
# define fname / edited_im: the next cells build 'edit_<name>' from fname and save edited_im
fname = PATH/'test3.png'
edited_im = thresh_edit(fname)
show_img(np.array(edited_im), figsize=(15,15))

In [None]:
e = 'edit_'+str(fname.name)
edited_fname = PATH/e

In [None]:
edited_im.save(edited_fname)

In [None]:
fname = PATH/'test/edit_test4.png'
im = open_image(fname)

In [None]:
seq,res,preds = learn.predict(im)
print(seq)

In [None]:
# r = torch.tensor([g_res], dtype=torch.long, device=device)
truth = "This is a test letter. I hope this\nworks but I'm not sure it will.\nMy handwriting is not very good."
Lev.distance(truth, str(seq))/len(truth)

## Results

In [None]:
# x,y = next(v_dl)
# imgs = denorm(x)

learn.model.eval()

shifted_y = rshift(y).long()
tgt_mask = subsequent_mask(shifted_y.size(-1))
v_preds = learn.model(x, shifted_y, tgt_mask)
v_res = torch.argmax(v_preds, dim=-1)
# v_attn = source_attn()

g_preds = learn.model(x)
g_res = torch.argmax(g_preds, dim=-1)
# g_attn = source_attn()

In [None]:
v = [learn.loss_func(v_preds, y).item(), cer(v_preds, y)]
print(f'valid:     {str(v[0])[:7]}   {str(v[1])[:7]}')

g = [learn.loss_func(g_preds, y).item(), cer(g_preds, y)]
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[:7]}')

In [None]:
#valid
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(v_res[i])
    ax=show_img(imgs[i], ax=ax, title=p)

In [None]:
#greedy
fig, axes = plt.subplots(3,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(g_res[i])
    ax=show_img(imgs[i], ax=ax, title=p)

In [None]:
# tfmr full
#             loss       cer       acc
# valid:     2.42856   0.02981   0.98177   3x1, 256/60, 'tfmr_full_3x1_single_attn'
# greedy:    11.9438   0.02745   0.93385

# valid:     1.98831   0.01858   0.98505   3x1, 256/60, 'exp_3x1_256'
# greedy:    16.3832   0.02167   0.90944

# valid:     1.47604   0.00797   0.99523   3x2, 400/45,  'tfmr_full_3x2'
# greedy:    21.4682   0.01097   0.93762

# valid:     7.30494   0.01206   0.99086   lg,  512/30,  'tfmr_full_lg'
# greedy:    434.610   0.02398   0.69553

# valid:     25.4131   0.04500   0.96773   lg, 512/30,   'tfmr_lg_LM_mixer'
# greedy:    734.834   0.15256   0.48912

# valid:     10.1481   0.00501   0.99602   cat9-12,  800/8,  'tfmr_cat9-12_full_800'
# greedy:    1330.63   0.04525   0.53181

# valid:     62.0793   0.05716   0.96214   pg,  512/20  'tfmr_full_paragraph'  (cpu)
# greedy:    1739.62   0.08347   0.43309
# beam:                0.07958
# valid:     71.9120   0.06819   0.95457   "", 2nd batch (gpu)
# greedy:    2027.24   0.11823   0.39238
# beam:                0.10697
# valid:     36.3009   0.03397   0.97633   "", 3rd batch (gpu)
# greedy:    1759.59   0.07070   0.40949
    
# valid:     50.1414   0.04137   0.97004   pg,  1000/5,  'tfmr_pg_1000'
# greedy:    2579.24   0.34178   0.18623

# valid:     56.2722   0.04167   0.96923   pg,  1000/5,  'tfmr_catpg_1000'
# greedy:    2432.54   0.37192   0.26174


# valid:     3.43745   0.05444   0.98735   mix(new)  'tfmr_mix_words_400'
# greedy:    43.6545   0.06126   0.91407

# valid:     2.96891   0.03338   0.98932   mix(new)  'tfmr_mix_words_512'
# greedy:    28.9982   0.03949   0.94500
# valid:     2.02411   0.02128   0.99302   5 more cycles
# greedy:    19.4735   0.02450   0.96104

# valid:     41.9800   0.03879   0.97535   pg, 'tfmr_mix_words_512'
# greedy:    1602.68   0.06259   0.49110


# valid:     53.9049   0.04822   0.96622   pg, 'tfmr_8head_512_mix'
# greedy:    1384.72   0.06304   0.51696


# valid:     5.24461   0.00634   mix   'v1_gelu_512'
# greedy:    217.438   0.02028

## 3x1 256

### Images

In [None]:
#valid
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(v_res[i])
    ax=show_img(imgs[i], ax=ax, title=p)

In [None]:
#greedy
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(g_res[i])
    ax=show_img(imgs[i], ax=ax, title=p)

In [None]:
#valid
fig, axes = plt.subplots(3,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(v_res[i])
    ax=show_img(imgs[i], ax=ax, title=p)

In [None]:
#greedy
fig, axes = plt.subplots(3,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(g_res[i])
    ax=show_img(imgs[i], ax=ax, title=p)

### Source Attn

In [None]:
idx = 1
img = imgs[idx]

v_chars = v_res[idx]
v_attns = to_np(torch_scale_attns(v_attn)[idx])

g_chars = g_res[idx]
g_attns = to_np(torch_scale_attns(g_attn)[idx])

In [None]:
#valid
fig, axes = plt.subplots(5,4, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    a = g_filter(v_attns[i])
    ax.imshow(img, alpha=None)
    ax.imshow(a, cmap='Blues', interpolation='nearest', alpha=0.3)
    ax.set_title(itos[v_chars[i].item()])

In [None]:
#greedy
fig, axes = plt.subplots(5,4, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    a = g_filter(g_attns[i])
    ax.imshow(img, alpha=None)
    ax.imshow(a, cmap='Blues', interpolation='nearest', alpha=0.3)
    ax.set_title(itos[g_chars[i].item()])

## 3x2 400

### Images

In [None]:
#valid
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(v_res[i])
    ax=show_img(imgs[i], ax=ax, title=p)

In [None]:
#greedy
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(g_res[i])
    ax=show_img(imgs[i], ax=ax, title=p)

### Source Attn

In [None]:
idx = 1
img = imgs[idx]

v_chars = v_res[idx]
v_attns = to_np(torch_scale_attns(v_attn)[idx])

g_chars = g_res[idx]
g_attns = to_np(torch_scale_attns(g_attn)[idx])

In [None]:
#valid
fig, axes = plt.subplots(5,4, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    a = g_filter(v_attns[i])
    ax.imshow(img, alpha=None)
    ax.imshow(a, cmap='Blues', interpolation='nearest', alpha=0.3)
    ax.set_title(itos[v_chars[i].item()])

In [None]:
#greedy
fig, axes = plt.subplots(5,4, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    a = g_filter(g_attns[i])
    ax.imshow(img, alpha=None)
    ax.imshow(a, cmap='Blues', interpolation='nearest', alpha=0.3)
    ax.set_title(itos[g_chars[i].item()])

## lg 512

### Images

In [None]:
#valid
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(v_res[i], chunk=55)
    ax=show_img(imgs[i], ax=ax, title=p)

In [None]:
#greedy
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(g_res[i], chunk=55)
    ax=show_img(imgs[i], ax=ax, title=p)

### Source Attn

In [None]:
idx = 1
img = imgs[idx]

v_chars = v_res[idx]
v_attns = to_np(torch_scale_attns(v_attn)[idx])

g_chars = g_res[idx]
g_attns = to_np(torch_scale_attns(g_attn)[idx])

In [None]:
#valid
fig, axes = plt.subplots(5,4, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    a = g_filter(v_attns[i])
    ax.imshow(img, alpha=None)
    ax.imshow(a, cmap='Blues', interpolation='nearest', alpha=0.3)
    ax.set_title(itos[v_chars[i].item()])

In [None]:
#greedy
fig, axes = plt.subplots(5,4, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    a = g_filter(g_attns[i])
    ax.imshow(img, alpha=None)
    ax.imshow(a, cmap='Blues', interpolation='nearest', alpha=0.3)
    ax.set_title(itos[g_chars[i].item()])

## cat 9-12

### Images

In [None]:
#valid
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(v_res[i], chunk=55)
    ax=show_img(imgs[i], ax=ax, title=p)

In [None]:
#greedy
# first 3 images titled with their greedy-decode text, wrapped every 55 chars
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(g_res[i], chunk=55)
    ax=show_img(imgs[i], ax=ax, title=p)

### Source Attn

In [None]:
# example used by the attention plots below
idx = 1
img = imgs[idx]

In [None]:
#valid
# source attention of example `idx` for the first 20 validation-decode steps.
# Slice [idx:idx+1] (keeping the batch dim) so the maps belong to the same example
# as v_chars and img; scaling only this slice keeps memory down.
v_chars = v_res[idx]
v_attns = to_np(torch_scale_attns(v_attn[idx:idx+1, :20]))[0]

fig, axes = plt.subplots(5,4, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    a = g_filter(v_attns[i])
    ax.imshow(img, alpha=None)
    ax.imshow(a, cmap='Blues', interpolation='nearest', alpha=0.3)
    ax.set_title(itos[v_chars[i].item()])

# free the numpy copy of the maps
del v_attns

In [None]:
#greedy
# source attention of example `idx` for the first 20 greedy-decode steps.
# Slice [idx:idx+1] (keeping the batch dim) so the maps belong to the same example as g_chars and img.
g_chars = g_res[idx]
g_attns = to_np(torch_scale_attns(g_attn[idx:idx+1, :20]))[0]

fig, axes = plt.subplots(5,4, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    a = g_filter(g_attns[i])
    ax.imshow(img, alpha=None)
    ax.imshow(a, cmap='Blues', interpolation='nearest', alpha=0.3)
    ax.set_title(itos[g_chars[i].item()])

## pg

### Images

In [None]:
#valid
# first 3 images titled with their validation-decode text, wrapped every 55 chars
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(v_res[i], chunk=55)
    ax=show_img(imgs[i], ax=ax, title=p)

In [None]:
#greedy
# first 3 images titled with their greedy-decode text, wrapped every 55 chars
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(g_res[i], chunk=55)
    ax=show_img(imgs[i], ax=ax, title=p)

### Source Attn

In [None]:
# example used by the attention plots below
idx = 1
img = imgs[idx]

In [None]:
#valid
# source attention of example `idx` for the first 20 validation-decode steps.
# Slice [idx:idx+1] (keeping the batch dim) so the maps belong to the same example
# as v_chars and img; scaling only this slice keeps memory down.
v_chars = v_res[idx]
v_attns = to_np(torch_scale_attns(v_attn[idx:idx+1, :20]))[0]

fig, axes = plt.subplots(5,4, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    a = g_filter(v_attns[i])
    ax.imshow(img, alpha=None)
    ax.imshow(a, cmap='Blues', interpolation='nearest', alpha=0.3)
    ax.set_title(itos[v_chars[i].item()])

# free the numpy copy of the maps
del v_attns

In [None]:
#greedy
# source attention of example `idx` for the first 20 greedy-decode steps.
# Slice [idx:idx+1] (keeping the batch dim) so the maps belong to the same example as g_chars and img.
g_chars = g_res[idx]
g_attns = to_np(torch_scale_attns(g_attn[idx:idx+1, :20]))[0]

fig, axes = plt.subplots(5,4, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    a = g_filter(g_attns[i])
    ax.imshow(img, alpha=None)
    ax.imshow(a, cmap='Blues', interpolation='nearest', alpha=0.3)
    ax.set_title(itos[g_chars[i].item()])

# Beam Search

## Multi-Batch Beam Decode

In [None]:
def beam_cer(preds, targs):
    "Average `_cer(target_text, predicted_text)` over the rows of `preds`/`targs`."
    n_rows = preds.size(0)

    def _row_error(k):
        # decode the prediction first, then the target, then score them
        pred_text = char_label_text(preds[k])
        targ_text = char_label_text(targs[k])
        return _cer(targ_text, pred_text)

    total = sum(_row_error(k) for k in range(n_rows))
    return total/n_rows

In [None]:
def repeat_interleave(tensor,n):
    """Repeat every slice along dim 0 `n` times, keeping the copies of a slice adjacent.

    [a, b] -> [a, a, b, b] for n=2, i.e. the same result as
    torch.repeat_interleave(tensor, n, dim=0). Always returns a new tensor.
    An empty leading dimension or n == 0 yields an empty result instead of raising.
    """
    # (bs, ...) -> (bs, 1, ...) -> (bs, n, ...) -> (bs*n, ...)
    reps = (1, n) + (1,) * (tensor.dim() - 1)
    return tensor.unsqueeze(1).repeat(*reps).flatten(0, 1)

In [None]:
# def normalize_score(score, i, alpha=0.6):
#     length_penalty = math.pow((5 + i)/6, alpha)
#     return score/length_penalty

### Beam Object

In [None]:
import heapq

class Beam(object):
    "Bounded min-heap of (score, complete, seq) candidates that retains the `beam_width` highest scores."

    def __init__(self, beam_width=10):
        self.heap = []
        self.beam_width = beam_width
        self.best_score = None  # not updated by this class; kept so the attribute exists

    def add(self, score, complete, seq):
        "Insert a candidate; if capacity is exceeded, evict the lowest-ranked entry."
        heapq.heappush(self.heap, (score, complete, seq))
        over_capacity = len(self.heap) > self.beam_width
        if over_capacity:
            heapq.heappop(self.heap)

    def get_seq(self):
        "Sequences of all kept candidates, in heap (storage) order."
        return [seq for _, _, seq in self.heap]

    def __iter__(self):
        return iter(self.heap)

In [None]:
class BeamSearch(nn.Module):
    """Batched beam search that keeps one heap-based `Beam` per batch element.

    Args:
        net: model exposing `img_enc` and `transformer` (with `encode`/`decode`/`generate`).
        beam_width: candidates kept per batch element (also the top-k taken per step).
        seq_len: maximum number of decoding steps.
        end_tok: token id that marks a candidate as complete.
    """
    def __init__(self, net, beam_width, seq_len, end_tok=0):
        super(BeamSearch, self).__init__()
        self.img_enc = net.img_enc
        self.transformer = net.transformer
        self.bw = beam_width
        self.seq_len = seq_len
        self.end_tok = end_tok
        self.feats = None
        self.beams = []

    def forward(self, src):
        "Decode batch `src`; returns each element's best sequence (start token dropped) as a (bs, steps) tensor."
        with torch.no_grad():
            bs = src.size(0)

            # fresh beams on every call (seeded with start token 1), so a second call
            # does not also carry the beams left over from the previous batch
            self.beams = []
            for _ in range(bs):
                beam = Beam(self.bw)
                beam.add(0.0, False, [1])
                self.beams.append(beam)

            # encode src once
            self.feats = self.transformer.encode(self.img_enc(src))

            for i in tqdm(range(self.seq_len)):
                # gather sequences from beams; combine into one (bs*candidates, i+1) tensor
                prev_seq = torch.from_numpy(np.stack([b.get_seq() for b in self.beams])).view(-1,i+1)

                # top-k next chars per candidate; rows follow each beam's heap order,
                # which is the same order the loop below iterates the beam in
                log_probs, chars = self.prob_func(prev_seq.to(device))
                log_probs, chars = log_probs.view(bs,-1,self.bw), chars.view(bs,-1,self.bw)

                for j in range(bs):
                    curr_beam = Beam(self.bw)
                    for k,(score, complete, seq) in enumerate(self.beams[j]):
                        for l,c in zip(log_probs[j,k],chars[j,k]):
                            log_prob,char = l.item(), c.item()
                            curr_beam.add((score+log_prob), (char==self.end_tok), seq+[char])
                    self.beams[j] = curr_beam

                # stop once every batch element's best candidate is complete
                if self.top_complete().all(): break

                # after the first step every element holds bw candidates: expand feats to match
                if i==0: self.feats = repeat_interleave(self.feats, self.bw)

            return self.top_seq()

    def top_complete(self):
        "Per batch element: whether its highest-ranked candidate is complete."
        return np.stack([max(b)[1] for b in self.beams])

    def top_seq(self):
        "Highest-ranked sequence of each batch element, without the start token."
        return torch.from_numpy(np.stack([max(b)[-1] for b in self.beams]))[:,1:]

    def prob_func(self, tgt):
        "Top-`bw` next-token log-probs and ids for each prefix row in `tgt`."
        mask = subsequent_mask(tgt.size(-1))
        dec_outs = self.transformer.decode(self.feats, tgt, mask)
        logits = self.transformer.generate(dec_outs[:,-1])
        log_probs = logits - torch.logsumexp(logits, -1, keepdim=True) # more stable than F.softmax(logits,-1).log()
        return torch.topk(log_probs, self.bw, dim=-1)

In [None]:
# run the heap-based batched beam search (beam width 3) over batch x
search = BeamSearch(learn.model, 3, seq_len)
b_res = search(x)

# bw: 3, sl: 350, ~56s, 0.02491    # no normalized_score

In [None]:
# mean CER of the beam results against targets y (cell output)
beam_cer(b_res,y)

### Tensors

In [None]:
class BeamSearch(nn.Module):
    """Batched, tensor-only beam search over `net`'s image encoder + transformer.

    State layout: the candidates of batch element b occupy rows b*bw .. b*bw+bw-1 of
    `self.beam` (token ids) and `self.scores` (cumulative log-probs). They are ordered
    best-first because `torch.topk` returns sorted results, so row b*bw is b's current best.

    Args:
        net: model exposing `img_enc` and `transformer`; it is put in eval mode.
        beam_width: candidates kept per batch element.
        seq_len: maximum number of decoding steps.
        end_tok: token id whose appearance as the last token of every top candidate stops decoding.
    """
    def __init__(self, net, beam_width, seq_len, end_tok=0):
        super(BeamSearch, self).__init__()
        self.img_enc = net.img_enc
        self.transformer = net.transformer
        self.bw = beam_width
        self.seq_len = seq_len
        self.end_tok = end_tok

        self.feats = None
        self.beam = None
        self.scores = None

        net.eval()

    def forward(self, src):
        "Decode batch `src`; returns (best sequences without start token, their scores)."
        with torch.no_grad():
            bs = src.size(0)

            # encode src
            self.feats = self.transformer.encode(self.img_enc(src))

            # initialize globals (1 candidate per element for the first step; bw thereafter)
            self.beam = torch.ones((bs,1), device=device, dtype=torch.long)
            self.scores = torch.zeros((bs,1), device=device, dtype=torch.float)

            for i in tqdm(range(self.seq_len)):
                cur_len = i + 1  # tokens per candidate so far, including the start token

                # top-bw next chars for every candidate: (bs*candidates, bw)
                log_probs, chars = self.prob_func(self.beam)

                # cumulative score of every (candidate, char) extension
                scores = self.scores + log_probs

                # best bw extensions per batch element, sorted best-first: (bs, bw)
                new_scores, idxs = torch.topk(scores.view(bs,-1), self.bw, dim=-1)
                self.scores = new_scores.view(-1,1)

                # flat index -> chosen char, and parent candidate = index // bw
                nxt = chars.view(bs,-1).gather(1, idxs).view(-1,1)
                parent = (idxs // self.bw).unsqueeze(-1).expand(-1, -1, cur_len)
                pre = self.beam.view(bs,-1,cur_len).gather(1, parent).view(-1,cur_len)

                # update globals
                self.beam = torch.cat([pre,nxt], dim=1)

                # end when the top candidate of every element ends with end_tok
                if self.top_complete(): break

                # expand feats to match beam size (only after the first step)
                if i==0: self.feats = repeat_interleave(self.feats, self.bw)

            return self.top_sequences(), self.top_scores()

    def top_complete(self):
        "True when every batch element's best candidate currently ends with `end_tok`."
        return (self.top_sequences()[:,-1]==self.end_tok).all().item()

    def top_sequences(self):
        # index rows directly: squeeze() would collapse the tensor when bs*bw == 1
        return self.beam[0::self.bw, 1:]

    def top_scores(self):
        return self.scores.view(-1)[0::self.bw]

    def prob_func(self, tgt):
        "Top-`bw` next-token log-probs and ids for each prefix row in `tgt`."
        mask = subsequent_mask(tgt.size(-1))
        dec_outs = self.transformer.decode(self.feats, tgt, mask)
        logits = self.transformer.generate(dec_outs[:,-1])
        log_probs = logits - torch.logsumexp(logits, -1, keepdim=True) # more stable than F.softmax(logits,-1).log()
        return torch.topk(log_probs, self.bw, dim=-1)

In [None]:
# run the tensor beam search (beam width 3) over batch x; returns best sequences and their scores
search = BeamSearch(learn.model, 3, seq_len)
b_res,score = search(x)

# recorded runs (timing, CER):
# lg, sl: 250
# bw: 1, ~12s, 0.02398
# bw: 3, ~28s, 0.02378
# bw: 5, ~46s, 0.02378

# pg: 1000,5
# bs: 3, ~40s, 0.33227  (greedy: 0.28---)
# ''   , ~30s, 0.22635  (greedy: 0.23726)

# pg: 800,8
# bs: 3, ~51s, 0.25424  (greedy: 0.28348)

In [None]:
# mean CER of the beam results against targets y (cell output)
beam_cer(b_res,y)

In [None]:
#beam
# first 3 images titled with their beam-search text, wrapped every 55 chars
fig, axes = plt.subplots(1,3, gridspec_kw={'hspace': 0.4}, figsize=(20, 10))
for i,ax in enumerate(axes.flat):
    p = char_label_text(b_res[i], chunk=55)
    ax=show_img(imgs[i], ax=ax, title=p)

## Single Beam Decode

In [None]:
# https://geekyisawesome.blogspot.com/2016/10/using-beam-search-to-generate-most.html

import heapq

class Beam(object):
    '''
    Bounded min-heap of (score, complete, seq) candidates.

    Tuples compare on score first and then on `complete`, so with equal scores a
    complete candidate ranks above an incomplete one, since (0.5, False) < (0.5, True).
    A candidate is only admitted when its score is within `beam_width` of the best
    score offered so far, and at most `beam_width` candidates are kept.
    '''

    def __init__(self, beam_width=10):
        self.heap = []
        self.beam_width = beam_width
        self.best_score = None

    def add(self, score, complete, seq):
        # running maximum over every score offered, including rejected ones
        self.best_score = score if self.best_score is None else max(self.best_score, score)

        # the width doubles as the admission margin below the best score
        admission_floor = self.best_score - self.beam_width
        if score > admission_floor:
            heapq.heappush(self.heap, (score, complete, seq))

        # evict the lowest-ranked entry when over capacity
        if len(self.heap) > self.beam_width:
            heapq.heappop(self.heap)

    def __iter__(self):
        return iter(self.heap)

In [None]:
def beamsearch(prob_fn, seq_len, beam_width=5, start_tok=1, end_tok=3):
    """Beam search for a single sequence.

    Args:
        prob_fn: called with a token-id list (the prefix); returns (log_probs, chars),
            two sequences of tensor scalars holding the candidate next-token log-probs and ids.
        seq_len: maximum number of decoding steps.
        beam_width: passed to `Beam` as its capacity.
        start_tok: token that seeds the search (stripped from the result).
        end_tok: token that marks a candidate as complete.

    Returns:
        (best_seq, best_score): best sequence without the start token, and its summed log-prob.
        Returns as soon as the highest-ranked candidate is complete, otherwise the best
        candidate after `seq_len` steps (([], 0.0) when seq_len == 0).
    """
    prev_beam = Beam(beam_width)
    prev_beam.add(0.0, False, [start_tok])

    for _ in tqdm(range(seq_len)):
        curr_beam = Beam(beam_width)

        # iterate over each candidate
        for score, complete, seq in prev_beam:
            # completed candidates are not carried forward; the search returns as soon
            # as the best candidate of a step is complete
            if complete:
                continue
            # iterate through topk chars, scoring each extension and adding it to the beam
            log_probs, chars = prob_fn(seq)
            for log_prob, char in zip(log_probs, chars):
                log_prob, char = log_prob.item(), char.item()
                # each child extends its own parent's score: log probabilities are additive
                curr_beam.add(score + log_prob, char == end_tok, seq + [char])

        best_score, best_complete, best_seq = max(curr_beam)
        if best_complete:
            return (best_seq[1:], best_score)

        prev_beam = curr_beam

    # prev_beam is the last completed step (or the seed when seq_len == 0)
    best_score, _, best_seq = max(prev_beam)
    return (best_seq[1:], best_score)

In [None]:
def beam_decode(net, src, beam_width, seq_len):
    "Put `net` in eval mode, encode `src` once, then beam-search a sequence from its decoder."
    net.eval()
    with torch.no_grad():
        encoded = net.transformer.encode(net.img_enc(src))
        next_token_fn = partial(prob_func, net=net, feats=encoded, beam_width=beam_width)
        return beamsearch(next_token_fn, seq_len, beam_width)
    
def prob_func(tgt, net=None, feats=None, beam_width=5):
    """Top-`beam_width` next-token log-probs and ids for a single prefix.

    Args:
        tgt: token-id list for one prefix; wrapped into a (1, len(tgt)) tensor.
        net: model exposing `transformer.decode`/`generate`.
        feats: encoder output for the source.
        beam_width: number of candidates returned.

    Returns:
        torch.topk result (values, indices) over the vocabulary, batch dim removed.
    """
    tgt = torch.tensor([tgt], dtype=torch.long, device=device)
    mask = subsequent_mask(tgt.size(-1))
    dec_outs = net.transformer.decode(feats, tgt, mask)
    logits = net.transformer.generate(dec_outs[:,-1])

    # log-softmax via logsumexp (more numerically stable than softmax().log());
    # keepdim so the normaliser broadcasts per row, as in BeamSearch.prob_func
    log_probs = logits - torch.logsumexp(logits, -1, keepdim=True)

    return torch.topk(log_probs.squeeze(0), beam_width, dim=-1)

def score_func(log_probs, i, alpha=0.6):
    "Length-normalise a summed log-prob: divide by ((5 + i) / 6) ** alpha (GNMT-style penalty)."
    penalty = math.pow((5 + i)/6, alpha)
    normalised = log_probs/penalty
    return normalised

In [None]:
# pick one example and keep a leading batch dim of 1 for single-sequence decoding
idx = 2
x1 = x[idx][None]
y1 = y[idx][None]

In [None]:
# beam-decode the single example prepared above (x1 = x[idx][None]); `image` was not defined here
b_res, score = beam_decode(learn.model, x1, 3, seq_len)    #294, 3m18s
# 294 - 1m40s
# 294 - 1m45s (w/ score_func)
# 295 - 22s (beam_width=1 ~ greedy)

In [None]:
# CER of the single-example beam result against that example's ground truth (y1)
r = torch.tensor([b_res], dtype=torch.long, device=device)
p = char_label_text(r)
truth = char_label_text(y1[0])

_cer(truth, p)

In [None]:
# valid
# CER of the validation-mode decode of example idx against its target
p = char_label_text(v_res[idx][None])
t = char_label_text(y1[0])
_cer(t,p)

In [None]:
# greedy
# CER of the greedy decode of example idx against its target
p = char_label_text(g_res[idx][None])
t = char_label_text(y1[0])
_cer(t,p)

In [None]:
# beam (sz=3)
# CER of the single-example beam result (wrapped as a batch of 1) against its target
r = torch.tensor([b_res], dtype=torch.long, device=device)
p = char_label_text(r)
t = char_label_text(y1[0])
_cer(t,p)

In [None]:
# inverse vocabulary: token string -> id
stoi = {k:i for i,k in enumerate(itos)}

In [None]:
# greedy-decoded text of example idx
print(char_label_text(g_res[idx][None]))

In [None]:
# show the single example with its beam-decoded text wrapped at 70 chars as the title
st = ''.join([itos[i] for i in b_res])
p = '\n'.join(textwrap.wrap(st, 70))
show_img(denorm(x1)[0], figsize=(10,10), title=p)

## Source Attn

In [None]:
# select example `idx`: its image, decoded chars and scaled source-attention maps (valid and greedy)
idx = 0
img = imgs[idx]

v_chars = v_res[idx]
v_attns = to_np(torch_scale_attns(v_attn)[idx])

g_chars = g_res[idx]
g_attns = to_np(torch_scale_attns(g_attn)[idx])

In [None]:
#valid
# 5x4 grid: source attention of validation-decode steps 0..19 over img, titled with each char
fig, axes = plt.subplots(5,4, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    a = g_filter(v_attns[i])
    ax.imshow(img, alpha=None)
    ax.imshow(a, cmap='Blues', interpolation='nearest', alpha=0.3)
    ax.set_title(itos[v_chars[i].item()])

In [None]:
#greedy
# 6x4 grid: source attention of greedy-decode steps 0..23 over img, titled with each char
# NOTE(review): the matching validation plot uses a 5x4 grid -- confirm 6x4 is intended
fig, axes = plt.subplots(6,4, gridspec_kw={'hspace': 0.3}, figsize=(20, 20))
for i,ax in enumerate(axes.flat):
    a = g_filter(g_attns[i])
    ax.imshow(img, alpha=None)
    ax.imshow(a, cmap='Blues', interpolation='nearest', alpha=0.3)
    ax.set_title(itos[g_chars[i].item()])

# Attention Visualizations

In [None]:
def transparent_cmap(cmap, N=5):
    "Return an N-colour copy of `cmap` whose alpha channel ramps linearly from 0 to 0.6."
    resampled = matplotlib.cm.get_cmap(cmap, N)
    resampled._init()  # build the private lookup table so its alpha column can be edited
    # the lookup table holds N+3 rows (the N colours plus matplotlib's under/over/bad entries)
    alpha_ramp = np.linspace(0, 0.6, N+3)
    resampled._lut[:,-1] = alpha_ramp
    return resampled

#Use base cmap to create transparent
# mycmap = transparent_cmap(plt.cm.Reds)

In [None]:
def show_attn(img, attns, chars, ax, color, showChars=True):
    """Overlay per-character attention maps on `img` in `ax`.

    Args:
        img: image drawn over the overlays at alpha 0.6.
        attns: per-step attention maps; map i belongs to chars[i].
        chars: decoded token ids (`.item()` is called on each element).
        ax: matplotlib axes to draw on.
        color: colormap passed to transparent_cmap.
        showChars: if True, label each map near its centre of mass with its character.
    """
    # loop-invariant: build the transparent colormap once instead of once per character
    cmap = transparent_cmap(color)
    skip_ids = {0, 1, 2, 3}  # special-token ids that get no overlay

    for i in range(attns.shape[0]):
        c = chars[i].item()
        if c in skip_ids:
            continue
        a = g_filter(attns[i])
        y,x = scipy.ndimage.center_of_mass(a)  # (row, col) of the attention mass
        ax.imshow(a, cmap=cmap, interpolation='nearest')
        if showChars: ax.text(x-8,y-10,itos[c], fontsize=15)

    ax.set_title(char_label_text(chars))
    ax.imshow(img, alpha=0.6)

In [None]:
def thresh_attn(attn, thresh=0.1):
    """Zero attention weights below `thresh`, but always keep each row's largest weight.

    A "row" is the last dimension, so any rank is accepted.
    Returns a new tensor; `attn` itself is not modified.
    """
    new = torch.where(attn >= thresh, attn, torch.zeros_like(attn))

    # rows with no value >= thresh would be all zeros: restore every row's top-1 value
    # at its original index (a no-op for rows whose max already survived)
    vals, idxs = torch.topk(attn, 1, dim=-1)
    return new.scatter_(-1, idxs, vals)

## Decoder Self-Attention

In [None]:
sns.set_context(context="notebook")

def draw(data, x, y, ax):
    "Heatmap of `data` on `ax` (columns labelled `x`, rows `y`) with the colour range fixed to [0, 1]."
    return sns.heatmap(data, xticklabels=x, square=True, yticklabels=y, vmin=0.0, vmax=1.0,
                       cmap='YlOrRd', linewidths=0.05, cbar=False, ax=ax)

# loop-invariant: rshift-ed copy of the greedy outputs, computed once for all layers/panels
shifted_y = rshift(g_res.float()).long()

for layer in range(4):
    print("Decoder Self-Attention Layer", layer+1)

    fig, axes = plt.subplots(1,4, figsize=(20, 10))
    for i,ax in enumerate(axes.flat):
        # greedy decoding (no access to true values); steps 20..39 of example i
        pred = char_split_text(g_res[i])[20:40]
        true = char_split_text(shifted_y[i])[20:40]
        g = draw(self_attn(layer)[i].data[20:40, 20:40], true, pred, ax=ax)
        g.set_yticklabels(g.get_yticklabels(), rotation=0)
        g.set_xticklabels(g.get_xticklabels(), rotation=0)
    plt.show()

## Decoder Source-Attention

In [None]:
# greedy-decode source attention: one row per example, one column per decoder layer.
# squeeze=False keeps `axes` 2-D so axes[idx,h] also works with a single row
# (with the default squeeze=True a 1x4 grid gives a 1-D array and axes[idx,h] raises IndexError)
fig, axes = plt.subplots(1,4, gridspec_kw={'hspace': 0.5}, figsize=(20, 10), squeeze=False)

for idx in range(len(axes.flat)//4):
    img = imgs[idx]
    g_chars = g_res[idx]

    # 4 attn layers
    for h in range(4):
        attn = source_attn(h)
        g_attns = to_np(torch_scale_attns(attn)[idx])

        show_attn(img, g_attns, g_chars, axes[idx,h], 'YlGn', showChars=False)
        axes[idx,h].set_title(f'layer {h+1}')

## Validation vs Greedy (final layer src-attn)

In [None]:
# threshold, then scale, the final-layer source attention of the valid and greedy decodes
v_scaled_attns = torch_scale_attns(thresh_attn(v_attn))
g_scaled_attns = torch_scale_attns(thresh_attn(g_attn))

In [None]:
# 2x2 grid: row = example, column 0 = validation decode, column 1 = greedy decode
fig, axes = plt.subplots(2,2, gridspec_kw={'hspace': 0.5}, figsize=(20, 20))
for idx in range(len(axes.flat)//2):
    img = imgs[idx]

    v_chars = v_res[idx]
    v_attns = to_np(v_scaled_attns[idx])

    g_chars = g_res[idx]
    g_attns = to_np(g_scaled_attns[idx])
    
    # valid
    show_attn(img, v_attns, v_chars, axes[idx,0], 'YlOrRd')
    # greedy
    show_attn(img, g_attns, g_chars, axes[idx,1], 'YlGn')

# Individual Examination

In [None]:
# greedy-decode source attention of one example, drawn large
idx=0
img = imgs[idx]

g_chars = g_res[idx]
# NOTE(review): slices batch element 0 regardless of idx -- matches only while idx == 0
g_scaled_attns = torch_scale_attns(thresh_attn(g_attn)[0:1])  # passing in bs of 1
g_attns = to_np(g_scaled_attns[0])  # removing bs

fig, ax = plt.subplots(1,1, figsize=(20, 20))
show_attn(img, g_attns, g_chars, ax, 'YlGn')

In [None]:
sns.set_context(context="notebook")

def draw(data, x, y, ax):
    "Heatmap of `data` on `ax` with cells above the diagonal masked; colour range fixed to [0, 1]."
    mask = np.zeros_like(data)
    mask[np.triu_indices_from(mask, k=1)] = True
    return sns.heatmap(data, xticklabels=x, square=True, yticklabels=y, vmin=0.0, vmax=1.0,
                       mask=mask, cmap='YlOrRd', linewidths=0.05, cbar=False, ax=ax)

# loop-invariant: example, tick labels and rshift-ed outputs are the same for every layer
i = 1
# greedy decoding (no access to true values); steps 230..279 of example i
shifted_y = rshift(g_res.float()).long()
pred = char_split_text(g_res[i])[230:280]
true = char_split_text(shifted_y[i])[230:280]

for layer in range(4):
    print("Decoder Self-Attention Layer", layer+1)

    fig, ax = plt.subplots(1,1, figsize=(20, 10))
    g = draw(self_attn(layer)[i].data[230:280, 230:280], true, pred, ax=ax)
    g.set_yticklabels(g.get_yticklabels(), rotation=0)
    g.set_xticklabels(g.get_xticklabels(), rotation=0)
    plt.show()

# Memory utility methods

In [None]:
# print type and size of every GC-tracked object whose `.data` is a tensor (tensors, Parameters, ...)
import gc
for obj in gc.get_objects():
    try:
        if hasattr(obj, 'data') and torch.is_tensor(obj.data):
            # tensors have no `.name` attribute: printing it raised AttributeError,
            # which the old bare except swallowed, so nothing was ever printed
            print(type(obj), obj.size())
    except Exception:  # some tracked objects raise on attribute access; skip them (best effort)
        pass

In [None]:
def tensor_size(tensor):
    "Bytes occupied by `tensor`'s elements: bytes per element times number of elements."
    return tensor.element_size() * tensor.nelement()

# bytes held by the validation predictions (cell output)
tensor_size(v_res)