# Prelims

In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
from fastai import *
from fastai.vision import *
from fastai.text import *
from fastai.callbacks.tracker import *

import pdb

In [None]:
# Root directory of the IAM handwriting data: label CSVs, image folders and itos.pkl live under it.
PATH = Path('data/IAM_handwriting')

In [None]:
# GPU when available; several helpers below (masks, greedy decoding) allocate tensors on this global.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
def show_img(im, figsize=None, ax=None, alpha=None, title=None):
    "Draw image `im` (anything whose `.data` `image2np` accepts) on `ax`, creating a figure if `ax` is falsy; return the axes."
    if not ax:
        _, ax = plt.subplots(figsize=figsize)
    pixels = image2np(im.data)
    ax.imshow(pixels, alpha=alpha)
    if title:
        ax.set_title(title)
    return ax

## Loss, Metrics, Callbacks

In [None]:
class LabelSmoothing(nn.Module):
    "Label-smoothed KL-divergence loss over token sequences, summed and divided by the batch size."
    def __init__(self, smoothing=0.1):
        super(LabelSmoothing, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, pred, target):
        """Return the smoothed loss for this batch.

        pred:   (bs, sl, vocab) logits
        target: (bs, tsl) token ids; both are padded to a common length by `loss_prep`.
        """
        # Use this batch's own size. The previous code read the notebook-global `bs`,
        # which is wrong for a smaller final batch or whenever `bs` is reassigned.
        bs = target.size(0)
        pred, targ = self.loss_prep(pred, target)
        pred = F.log_softmax(pred, dim=-1)  # KL-div expects log-probabilities as input
        with torch.no_grad():
            # smoothed one-hot: every class gets smoothing/vocab, the true class gets `confidence`
            true_dist = torch.full_like(pred, self.smoothing / pred.size(1))
            true_dist.scatter_(1, targ.unsqueeze(1), self.confidence)
        return F.kl_div(pred, true_dist, reduction='sum') / bs

    def loss_prep(self, pred, target):
        "equalize input/target sl; combine bs/sl dimensions"
        bs,tsl = target.shape
        _ ,sl,vocab = pred.shape

        # F.pad( front,back for dimensions: 1,0,2 )
        # a shorter target is right-padded with 0 (the pad id) up to the prediction length
        if sl>tsl: target = F.pad(target, (0,sl-tsl))

        # this should only be used when testing for small seq_lens
        # if tsl>sl: target = target[:,:sl]

        # a shorter prediction is padded along the sequence axis with all-zero logits
        if tsl>sl: pred = F.pad(pred, (0,0,0,tsl-sl))
        # not ideal => adds 96 logits all 0s...

        targ = target.contiguous().view(-1).long()   # (bs*sl,)
        pred = pred.contiguous().view(-1, vocab)     # (bs*sl, vocab)
        return pred, targ

In [None]:
def cer(preds, targs):
    "Return (sum of per-sample character error rates, number of samples) for a batch of logits vs targets."
    n = targs.size(0)
    best = preds.argmax(dim=-1)

    def _rate(i):
        # Levenshtein distance normalised by the reference length (guarding empty references)
        ref = char_label_text(targs[i])
        hyp = char_label_text(best[i])
        return Lev.distance(ref, hyp) / (len(ref) or 1)

    return sum(_rate(i) for i in range(n)), n

def char_label_text(pred, sep=''):
    "Decode a tensor of token ids to text via the global `itos`, dropping every 0 (pad) id."
    ids = to_np(pred).astype(int)
    kept = ids[ids != 0]
    return sep.join(itos[idx] for idx in kept)

In [None]:
import Levenshtein as Lev

class CER(Callback):
    "Metric callback: mean per-sample character error rate over all batches of an epoch."
    def __init__(self):
        super().__init__()
        self.name = 'cer'

    def on_epoch_begin(self, **kwargs):
        # reset the running totals for the new epoch
        self.errors = 0
        self.total = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        batch_err, batch_n = cer(last_output, last_target)
        self.errors, self.total = self.errors + batch_err, self.total + batch_n

    def on_epoch_end(self, last_metrics, **kwargs):
        mean_cer = self.errors / self.total
        return add_metrics(last_metrics, mean_cer)

In [None]:
def rshift(tgt, bos_token=1):
    """Shift `tgt` (bs, sl) one step to the right.

    A `bos_token` column is prepended and the last column is dropped, so the result has the
    same shape, dtype and device as `tgt`.
    """
    # Build the BOS column directly with tgt's dtype/device. The old code allocated it on the
    # global `device` and converted it with `type_as`.
    bos = torch.full((tgt.size(0), 1), bos_token, dtype=tgt.dtype, device=tgt.device)
    return torch.cat((bos, tgt[:, :-1]), dim=-1)

def subsequent_mask(size):
    "Causal mask of shape (1, size, size) as uint8 on the global device: 1 where column <= row, else 0."
    ones = torch.ones((1, size, size), device=device)
    return ones.byte().tril()
    #return torch.tril(torch.ones((1,1,size,size), device=device).byte())  # complex batches
    
def parallelogram_mask(size, diagonal):
    "Banded causal mask (1, size, size) as uint8: 1 where row - diagonal <= column <= row."
    base = torch.ones((1, size, size), device=device).byte()
    causal = base.tril().bool()                  # column <= row
    band = base.triu(diagonal=-diagonal).bool()  # column >= row - diagonal
    return (causal & band).byte()

In [None]:
class TeacherForce(LearnerCallback):
    "Teacher forcing: feed the model (image, right-shifted target, causal mask) for every batch."
    def __init__(self, learn:Learner):
        super().__init__(learn)

    def on_batch_begin(self, last_input, last_target, **kwargs):
        shifted = rshift(last_target).long()
        causal = subsequent_mask(shifted.size(-1))
        #causal = parallelogram_mask(shifted.size(-1), 10)
        new_input = (last_input, shifted, causal)
        return {'last_input': new_input, 'last_target': last_target}

# Data

## sm synth dataset

In [None]:
# Small synthetic dataset: label CSV under PATH plus the image folder its rows refer to.
fname = 'edited_sm_synth.csv' #'small_synth_words.csv'
CSV = PATH/fname
FOLDER = 'edited_sm_synth'

df = pd.read_csv(CSV)
len(df)

In [None]:
# sz,bs = 128,100
# sz: square image side passed to .transform; bs: batch size passed to .databunch
sz,bs = 256,100

# NOTE(review): num_lines/seq_len are not read by any code visible in this notebook section
num_lines,seq_len = 4,100

In [None]:
# Keep only the first 20,000 rows.
df = df[:20000]
len(df)

## combo

In [None]:
# Combined dataset (combo_145k.csv) with images in the 'combo_cat' folder.
fname = 'combo_145k.csv'
FOLDER = 'combo_cat'

CSV = PATH/fname
df = pd.read_csv(CSV)

# large images with long targets -> small batches
sz,bs = 512,10
seq_len = 750
len(df)

In [None]:
# 6 and fewer: keep images with at most 6 text lines (paired with bs=30, seq_len=600).
# NOTE: this and the next cell are alternatives for the combo `df`; running both in sequence
# leaves only rows with num_lines == 6.
df = df[df.num_lines <= 6]
sz,bs = 512,30
seq_len = 600
len(df)

In [None]:
# 6 and greater: keep images with at least 6 text lines (paired with bs=10, seq_len=750).
# NOTE: alternative to the previous cell; apply it to the unfiltered combo `df`.
df = df[df.num_lines >= 6]
sz,bs = 512,10
seq_len = 750
len(df)

## test combo (no fonts)

In [None]:
# Test combo dataset (no font-rendered images): label CSV and image folder.
FOLDER = 'test_combo'
CSV = PATH/'test_combo.csv'
df = pd.read_csv(CSV)

sz,bs = 512,10
seq_len = 750
len(df)

# ModelData

In [None]:
# Augmentation: flips, zoom and rotation disabled; only a small warp (max_warp=0.1).
tfms = get_transforms(do_flip=False, max_zoom=1, max_rotate=0, max_warp=0.1)

def force_gray(image):
    "Drop colour: convert to single-channel 'L', then back to 'RGB' so the 3-channel pipeline still applies."
    gray = image.convert('L')
    return gray.convert('RGB')

def label_collater(samples:BatchSamples, pad_idx:int=0):
    """Collate (image, label) samples into a batch.

    Images are stacked. Labels (1-D numpy arrays) are right-padded with `pad_idx` to
    `max_len + 1` columns. A single-sample batch instead gets a dummy (1, 1) zero label.
    """
    data = to_data(samples)
    ims, lbls = zip(*data)
    imgs = torch.stack(list(ims))
    # `==`, not `is`: identity of small ints is a CPython caching detail (SyntaxWarning on 3.8+).
    # NOTE(review): the dummy label presumably serves single-item prediction, where no real
    # label exists — confirm before relying on batches of size 1 during training.
    if len(data) == 1:
        labels = torch.zeros(1,1).long()
        return imgs, labels
    max_len = max(len(s) for s in lbls)
    # add 1 to max_len to account for the bos token (the right shift drops the last column)
    labels = torch.full((len(data), max_len + 1), pad_idx, dtype=torch.long)
    for i, lbl in enumerate(lbls):
        labels[i, :len(lbl)] = torch.from_numpy(lbl)  # real tokens first, padding at the end
    return imgs, labels

In [None]:
# index -> token vocabulary list used by CharVocab and char_label_text.
# `with` closes the file handle (the bare open() leaked it).
# NOTE: pickle can execute arbitrary code while loading — only load a trusted itos.pkl.
with open(PATH/'itos.pkl', 'rb') as itos_file:
    itos = pickle.load(itos_file)

In [None]:
class CharTokenizer(BaseTokenizer):
    "Split text into single characters and append the 'xxeos' end-of-sequence token."
    def tokenizer(self, t:str) -> List[str]:
        return [*t, 'xxeos']
            
class CharVocab(Vocab):
    "Vocab over a fixed `itos` list; tokens missing from it map to index 3."
    def __init__(self, itos:Collection[str]):
        self.itos = itos
        lookup = {tok: idx for idx, tok in enumerate(itos)}
        self.stoi = collections.defaultdict(lambda: 3, lookup)

    def textify(self, nums:Collection[int], sep=''):
        "Map ids back to tokens: joined with `sep`, or returned as a list when `sep` is None."
        tokens = [self.itos[n] for n in nums]
        if sep is None:
            return tokens
        return sep.join(tokens)

class SequenceList(TextList):
    "TextList of character sequences: char-level tokenization (no BOS) then numericalization with `vocab`."
    def __init__(self, items:Iterator, vocab:Vocab, **kwargs):
        # NOTE(review): extra **kwargs are accepted but not forwarded to TextList.__init__ — confirm intended.
        char_tok = Tokenizer(tok_func=CharTokenizer, pre_rules=[], post_rules=[], special_cases=[])
        processors = [
            TokenizeProcessor(tokenizer=char_tok, include_bos=False),
            NumericalizeProcessor(vocab=vocab),
        ]
        super().__init__(items, vocab, sep='', pad_idx=0, processor=processors)

    def analyze_pred(self, pred:Tensor):
        "Greedy decode: index of the highest score at each position."
        return pred.argmax(dim=-1)

In [None]:
# set_seed()  # reproducibility
# DataBunch: images listed in `df` (forced to grayscale), random 15% validation split,
# character-level labels, squish-resize to `sz`, and label padding via `label_collater`.
data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
#         .split_none()
        .split_by_rand_pct(valid_pct=0.15, seed=42)
        #.label_from_df(label_cls=TextList, sep='', pad_idx=0, vocab=vocab, processor=procs)
        .label_from_df(label_cls=SequenceList, vocab=CharVocab(itos))
        .transform(tfms, size=sz, resize_method=ResizeMethod.SQUISH)
        #.transform(tfms, size=sz, resize_method=ResizeMethod.PAD, padding_mode='border')
        # maintains aspect ratio but too small for good results => mostly whitespace
        .databunch(bs=bs, device=device, collate_fn=label_collater)
        #.normalize() # this sets x values to an odd range (~.3,-6)
       )

In [None]:
# Visual sanity check of a few training samples.
data.show_batch(rows=2, ds_type=DatasetType.Train, figsize=(18,10))

# Transformer Modules

In [None]:
# LayerNorm = nn.LayerNorm
LayerNorm = partial(nn.LayerNorm, eps=1e-4)  # eps=1e-4 accommodates mixed precision (fp16) training
# LayerNorm = partial(nn.BatchNorm2d, eps=1e-4)

In [None]:
class SublayerConnection(nn.Module):
    "Pre-norm residual wrapper: x + dropout(sublayer(norm(x))). Norm comes first for code simplicity."
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        normed = self.norm(x)
        return x + self.dropout(sublayer(normed))

In [None]:
def clones(module, N):
    "Return an nn.ModuleList of `N` independent deep copies of `module` (no shared parameters)."
    layers = nn.ModuleList()
    for _ in range(N):
        layers.append(deepcopy(module))
    return layers

In [None]:
class Encoder(nn.Module):
    "Stack of `N` deep-copied encoder layers followed by a final LayerNorm."
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x):
        out = x
        for enc_layer in self.layers:
            out = enc_layer(out)
        return self.norm(out)

In [None]:
class EncoderLayer(nn.Module):
    "Encoder layer: residual self-attention, then residual feed-forward (each via SublayerConnection)."
    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)

    def forward(self, x, mask=None):
        attn_block, ff_block = self.sublayer
        x = attn_block(x, lambda h: self.self_attn(h, h, h, mask))
        return ff_block(x, self.feed_forward)

In [None]:
class Decoder(nn.Module):
    "Generic N layer decoder with masking: N deep-copied layers over (x, src, mask), then LayerNorm."
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, src, tgt_mask=None):
        h = x
        for dec_layer in self.layers:
            h = dec_layer(h, src, tgt_mask)
        return self.norm(h)

In [None]:
class DecoderLayer(nn.Module):
    "Decoder layer: masked self-attention, attention over encoder output `src`, then feed-forward."
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # one residual/dropout/norm wrapper per sub-layer
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, src, tgt_mask=None):
        self_block, src_block, ff_block = self.sublayer
        # self-attention over the target, restricted by tgt_mask (acts as a weak LM)
        x = self_block(x, lambda h: self.self_attn(h, h, h, tgt_mask))
        # unmasked attention from target positions to the encoder output
        x = src_block(x, lambda h: self.src_attn(h, src, src))
        return ff_block(x, self.feed_forward)

In [None]:
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention; returns (weighted values, attention probabilities).

    Where `mask == 0` the score is set to -1e4 rather than -1e9, so it stays representable
    under mixed precision.
    """
    scale = math.sqrt(query.size(-1))
    scores = query @ key.transpose(-2, -1) / scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e4)
    probs = F.softmax(scores, dim=-1)
    if dropout is not None:
        probs = dropout(probs)
    return probs @ value, probs

In [None]:
class SingleHeadedAttention(nn.Module):
    "Single-head attention: separate q/k/v projections, scaled dot-product attention, output projection."
    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # attention probabilities from the latest forward call
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        q_proj, k_proj, v_proj, out_proj = self.linears
        q, k, v = q_proj(query), k_proj(key), v_proj(value)
        x, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)
        return out_proj(x)

In [None]:
class MultiHeadedAttention(nn.Module):
    "Multi-head attention with `h` heads of width d_model // h; heads are split, attended and re-merged."
    def __init__(self, d_model, h=8, dropout=0.2):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h        # assume d_v always equals d_k
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # latest attention probabilities, (bs, h, q_len, k_len)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t, bs):
        # (bs, len, h*d_k) -> (bs, h, len, d_k)
        return t.view(bs, -1, self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)  # same mask broadcast over every head
        bs = q.size(0)
        q, k, v = (self._split_heads(proj(t), bs) for proj, t in zip(self.linears, (q, k, v)))
        x, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)
        # merge heads: (bs, h, len, d_k) -> (bs, len, h*d_k)
        x = x.transpose(1, 2).contiguous().view(bs, -1, self.h * self.d_k)
        return self.linears[-1](x)

In [None]:
class PositionwiseFeedForward(nn.Module):
    "Per-position MLP: Linear(d_model, 4*d_model) -> GELU -> dropout -> Linear(4*d_model, d_model)."
    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        hidden = d_model * 4
        self.w_1 = nn.Linear(d_model, hidden)
        self.w_2 = nn.Linear(hidden, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = F.gelu(self.w_1(x))
        return self.w_2(self.dropout(h))

In [None]:
class PositionalEncoding(nn.Module):
    "Add fixed sinusoidal position encodings (even channels sine, odd channels cosine) and apply dropout."
    def __init__(self, d_model, dropout=0.1, max_len=2000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0.0, max_len).unsqueeze(1)             # (max_len, 1)
        inv_freq = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(1e4) / d_model))
        angles = positions * inv_freq                                   # (max_len, d_model/2)
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # a buffer, not a parameter: saved in the state_dict but never trained
        self.register_buffer('pe', table.unsqueeze(0))                 # (1, max_len, d_model)

    def forward(self, x):
        return self.dropout(x + self.pe[:, :x.size(1)])

In [None]:
class Embeddings(nn.Module):
    "Token embedding lookup scaled by sqrt(d_model)."
    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale

# Multi-Resolution Arch

In [None]:
class EncoderDecoder(nn.Module):
    "Transformer wrapper: adapt+encode source features, embed targets, decode; `generate` applies the output head."
    def __init__(self, encoder, decoder, src_adapt, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_adapt = src_adapt
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, tgt_mask=None):
        memory = self.encode(src)
        return self.decode(memory, self.embed(tgt), tgt_mask)

    def encode(self, src):
        adapted = self.src_adapt(src)
        return self.encoder(adapted)

    def decode(self, src, tgt, tgt_mask=None):
        # Decoder expects (target, encoder output, mask)
        return self.decoder(tgt, src, tgt_mask)

    def embed(self, tgt):
        return self.tgt_embed(tgt)

    def generate(self, outs):
        return self.generator(outs)

In [None]:
class ResnetBase(nn.Module):
    "ResNet34 trunk plus three stride-2 convs; all four feature maps are flattened into one position sequence."
    def __init__(self, em_sz, d_model):
        super().__init__()
        # number of trailing resnet34 children to drop for each supported em_sz
        cut = {128: -4, 256: -3, 512: -2}[em_sz]
        trunk = list(models.resnet34(True).children())[:cut]
        self.base = nn.Sequential(*trunk)                   #16x16
        self.c1 = conv_layer(d_model, d_model, stride=2)    #8x8
        self.c2 = conv_layer(d_model, d_model, stride=2)    #4x4
        self.c3 = conv_layer(d_model, d_model, stride=2)    #2x2

    def forward(self, x):
        feats = [self.base(x)]
        for conv in (self.c1, self.c2, self.c3):
            feats.append(conv(feats[-1]))
        # each (bs, C, H, W) -> (bs, C, H*W), concatenated along positions, then -> (bs, positions, C)
        cat = torch.cat([f.flatten(2, 3) for f in feats], dim=-1)
        return cat.permute(0, 2, 1)

In [None]:
class ResnetBase(nn.Module):
    """ResNet34 trunk (first 4 children dropped from the end) followed by height-only downsampling.

    Every extra conv uses stride (2,1). The outputs of c2..c5 (all with d_model channels) are
    concatenated along the height axis and flattened into a (bs, positions, d_model) sequence.
    """
    def __init__(self, em_sz, d_model):
        super().__init__()

        slices = {128: -4, 256: -3, 512: -2}
        # em_sz is only validated here; the trunk cut below is fixed at -4
        if em_sz not in slices: raise KeyError(em_sz)

        net = models.resnet34(True)
        modules = list(net.children())[:-4]
        self.base = nn.Sequential(*modules)                  #bs,128,64,64

        self.c1 = conv_layer(128, 256, stride=(2,1))     #32
        # d_model output channels (was a hard-coded 512) so c2 feeds c3..c5, which take d_model
        self.c2 = conv_layer(256, d_model, stride=(2,1))     #16

        self.c3 = conv_layer(d_model, d_model, stride=(2,1))     #8
        self.c4 = conv_layer(d_model, d_model, stride=(2,1))     #4
        self.c5 = conv_layer(d_model, d_model, stride=(2,1))     #2

    def forward(self, x):
        x = self.c2(self.c1(self.base(x)))
        x1 = self.c3(x)
        x2 = self.c4(x1)
        x3 = self.c5(x2)
        # (a stray pdb.set_trace() here used to stop every forward pass)
        cat = torch.cat([x,x1,x2,x3], dim=2)   # stack along height; channels and width must match
        return cat.flatten(2,3).permute(0,2,1)

In [None]:
def make_full_model(vocab, d_model, em_sz, N=4, drops=0.2, attn_type='multi', heads=8):
    """Build the EncoderDecoder transformer for the multi-resolution model.

    vocab: output vocabulary size; d_model: model width; em_sz: feature size of the image
    encoder output, projected to d_model by the source adaptor; N: encoder/decoder depth;
    drops: dropout for every sub-module; attn_type: 'multi' (`heads` heads) or single-head.
    """
    c = deepcopy

    # `drops` now reaches the attention modules as well; they used to keep their own 0.2 default
    if attn_type=='multi':
        attn = MultiHeadedAttention(d_model, heads, dropout=drops)
    else:
        attn = SingleHeadedAttention(d_model, dropout=drops)

    ff = PositionwiseFeedForward(d_model, drops)
    pos = PositionalEncoding(d_model, drops, 2000)

    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        nn.Sequential(
            # source adaptor: project image features to d_model, then scale by 8 in place
            nn.Linear(em_sz, d_model), Lambda(lambda x: x.mul_(8)) #, pos
        ),
        nn.Sequential(
            Embeddings(d_model, vocab), pos
        ),
        nn.Linear(d_model, vocab),
    )

    return model

In [None]:
class Img2Seq(nn.Module):
    "Image encoder + transformer; teacher-forced when `tgt` is given, greedy decoding otherwise."
    def __init__(self, img_encoder, transformer):
        super(Img2Seq, self).__init__()
        self.img_enc = img_encoder
        self.transformer = transformer
        
    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        """Return output logits.

        With `tgt` the decoder runs once over the whole (shifted) target: (bs, sl, vocab).
        Without `tgt`, decode greedily from start id 1 for at most `seq_len` steps, stopping
        early once every sequence predicts id 0; the result is (bs, steps_taken, vocab).
        """
        # inference (greedy decode)
        if tgt is None:
            with torch.no_grad():
                feats = self.transformer.encode(self.img_enc(src))
                bs = src.size(0)
                tgt = torch.ones((bs,1), dtype=torch.long, device=device)  # start id 1 (rshift's default bos_token)

                res = []
                for i in progress_bar(range(seq_len)):
                    mask = subsequent_mask(tgt.size(-1))
#                     mask = parallelogram_mask(tgt.size(-1), 10)

                    t = self.transformer.embed(tgt)
    
#                     dec_outs = self.transformer.decode(feats, t[:,-11:])
                    # the whole prefix is re-embedded and re-decoded every step (no caching)
                    dec_outs = self.transformer.decode(feats, t, mask)

                    prob = self.transformer.generate(dec_outs[:,-1])  # logits for the newest position only
                    res.append(prob)
                    pred = torch.argmax(prob, dim=-1, keepdim=True)
                    if (pred==0).all(): break  # all rows emitted id 0; this step's logits are kept
                    tgt = torch.cat([tgt,pred], dim=-1)
                out = torch.stack(res).transpose(1,0).contiguous()  # (steps, bs, vocab) -> (bs, steps, vocab)
                
        #training
        else:
            feats = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt, tgt_mask)    # ([bs, sl, d_model])
            out = self.transformer.generate(dec_outs)            # ([bs, sl, vocab])
        return out

In [None]:
def init_params(model):
    "Xavier-uniform initialise every parameter of `model` with 2+ dims; 1-D parameters (e.g. biases) are untouched."
    matrices = (p for p in model.parameters() if p.dim() > 1)
    for weight in matrices:
        nn.init.xavier_uniform_(weight)

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.2, attn_type='multi', heads=8, smoothing=0.1):
    "Assemble ResnetBase + transformer into an Img2Seq Learner with label smoothing, the CER metric and teacher forcing."
    encoder = ResnetBase(em_sz, d_model)
    tfmr = make_full_model(len(itos), d_model, em_sz, N=N, drops=drops,
                           attn_type=attn_type, heads=heads)
    # NOTE(review): apply() calls init_params on every submodule and init_params walks all
    # parameters below it, so deep weights are re-initialised once per ancestor module.
    tfmr.apply(init_params)
    #tfmr.generator.weight = tfmr.tgt_embed[0].lut.weight
    model = Img2Seq(encoder, tfmr)
    loss = LabelSmoothing(smoothing=smoothing)
    return Learner(data, model, loss_func=loss, metrics=[CER()], callback_fns=[TeacherForce])

## Experimentation

In [None]:
# d_model = em_sz = 512, 4 layers, dropout 0.1, 8-head attention.
learn = make_learner(data, 512, 512, N=4, drops=0.1, attn_type='multi', heads=8)
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

In [None]:
# Sum of all parameter VALUES (not a parameter count): a quick checksum for comparing initialisations.
tot = 0
for p in learn.model.parameters():
    tot += p.sum()
tot.data
# -3712.5525
# -3709.0364  w/ weight tying
# 1315.4849   "", N=6
# -8694.8057  "", N=2

In [None]:
# set_seed()
# One epoch with Learner.fit at lr=1e-3; results for each variant are logged below.
learn.fit(1, lr=1e-3)

# sm(20k subset), 1cycle, 1e-3, d_model/em_sz:512, N:4, multi(8), drops:0.1
# 66.245880	61.198948	0.716475	06:51  baseline
# 66.746071	61.807365	0.705620	06:51  w/ weight tying    ***seems to help CER...
# 79.020279	76.020065	0.936690	06:54  split(img_enc, tfmr), lrs(1e-4,1e-2)
# 68.644821	63.186871	0.724228	07:17  N=6
# 68.506302	62.820026	0.762108	07:13  N=6 (w/out weight tying)
# 66.554543	61.799850	0.848753	06:24  N=2
# 71.178879	65.140076	0.779076	08:08  d_model: 768
# 66.679169	60.917271	0.793617	06:53  attn: single
# 66.829041	61.852840	0.724487	06:54  attn: multi(4)
# 66.263069	61.742168	0.699408	07:13  attn: multi(16)  *
# 76.158379	70.973106	0.710026	06:36  smoothing: 0.0
# 60.895245	54.817890	0.746580	06:39  smoothing: 0.2
# 69.718803	64.245934	0.784929	06:37  w/out src_adaptor .mul_(8)
# 66.709183	61.714687	0.707358	06:38  F.relu

# 66.384956	61.029957	0.737656	02:07  baseline
# 64.477470	59.358719	0.679445	02:09  parallelogram_mask(10)
# 64.270401	59.660389	0.679594	02:06  pos_enc after img_enc
# 64.752815	60.505432	0.697868	02:07  parallelogram_mask(20)
# 63.160973	57.148361	0.674351	02:08  parallelogram_mask(5)
# 59.376633	53.333813	0.649422	02:07  parallelogram_mask(1)
# 64.695885	61.779552	0.791945	02:08  parallelogram_mask(0)
# 59.957340	52.987080	0.611505	02:10  parallelogram_mask(2)  *
# 61.474442	55.054062	0.649738	02:09  parallelogram_mask(3)
# 61.727688	54.237797	0.633881	02:09  weight tying
# 58.971718	52.863750	0.643140	02:33  N=6, drops:0.2
# 58.866985	52.895126	0.655549	02:33  N=6, drops:0.1
# 60.226372	53.237236	0.609828	02:12  multi(16), drops:0.2
# 58.924366	52.224678	0.671493	02:14  multi(16), drops:0.1

In [None]:
# set_seed()
# Five epochs with the one-cycle schedule, max_lr=1e-3.
learn.fit_one_cycle(5, max_lr=1e-3)

# 5cycle,1e-3
# 3.510155	3.892951	0.033866  baseline     'sm_256_base'
# greedy:    0.43533   .04138
# 3.156924	3.474473	0.030223  p_mask(2)  'sm_256_pmask2'
# greedy:    0.45049   .04231
# greedy:    1.09986   .16270   w/ truncation -> much faster
# 3.666398	3.924176	0.032230  p_mask(2),drops:0.2
# greedy:    0.46281   .04217

# 3.974781	4.305206	0.036939   mix_res,p_mask(10)   'sm_256_pmask10'
# greedy:    0.47993   .04458
# greedy:    1.04631   .11497    w/ truncation

# 4.040658	4.392500	0.037479   mix_res   'sm_256_mix_res'
# greedy:    0.50428   .04335

In [None]:
# Save the current weights under the name 'sm_256_mix_res'.
learn.save('sm_256_mix_res')

In [None]:
learn.load('sm_256_mix_res'); None  # trailing `None` keeps the notebook from echoing the Learner

# Full Arch

In [None]:
class EncoderDecoder(nn.Module):
    "Transformer wrapper: encode source features, then embed and decode target ids; `generate` applies the output head."
    def __init__(self, encoder, decoder, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, tgt_mask=None):
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask)

    def encode(self, src):
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        # `tgt` holds raw token ids; they are embedded here before decoding
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, src, tgt_mask)

    def generate(self, outs):
        return self.generator(outs)

In [None]:
class ResnetBase(nn.Module):
    "ResNet34 trunk cut according to `em_sz`; returns (bs, H*W, d_model) features scaled by 8."
    def __init__(self, em_sz, d_model):
        super().__init__()
        cut = {128: -4, 256: -3, 512: -2}[em_sz]
        backbone = models.resnet34(True)
        self.base = nn.Sequential(*list(backbone.children())[:cut])
        # Linear(em_sz, ...) requires em_sz to equal the channel count of the cut trunk's output
        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        fmap = self.base(x)
        seq = fmap.flatten(2, 3).permute(0, 2, 1)  # (bs, C, H, W) -> (bs, H*W, C)
        return 8 * self.linear(seq)

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', heads=8):
    """Build the EncoderDecoder transformer for the full architecture.

    vocab: output vocabulary size; d_model: model width; N: encoder/decoder depth;
    drops: dropout for every sub-module; attn_type: 'multi' (`heads` heads) or single-head.
    """
    c = deepcopy

    # `drops` now reaches the attention modules as well; they used to keep their own 0.2 default
    if attn_type=='multi':
        attn = MultiHeadedAttention(d_model, heads, dropout=drops)
    else:
        attn = SingleHeadedAttention(d_model, dropout=drops)

    ff = PositionwiseFeedForward(d_model, drops)

    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drops), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), drops), N),
        nn.Sequential(
            Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000)
        ),
        nn.Linear(d_model, vocab),
    )

    return model

In [None]:
class Img2Seq(nn.Module):
    "Image encoder + transformer; teacher-forced when `tgt` is given, greedy decoding otherwise."
    def __init__(self, img_encoder, transformer):
        super().__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        if tgt is not None:
            # training: one decoder pass over the shifted target
            feats = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt, tgt_mask)    # (bs, sl, d_model)
            return self.transformer.generate(dec_outs)          # (bs, sl, vocab)
        return self._greedy_decode(src, seq_len)

    def _greedy_decode(self, src, seq_len):
        "Decode from start id 1 for up to `seq_len` steps; stop once every row predicts id 0."
        with torch.no_grad():
            memory = self.transformer.encode(self.img_enc(src))
            prefix = torch.ones((src.size(0), 1), dtype=torch.long, device=device)
            step_logits = []
            for _ in progress_bar(range(seq_len)):
                causal = subsequent_mask(prefix.size(-1))
                hidden = self.transformer.decode(memory, prefix, causal)
                logits = self.transformer.generate(hidden[:, -1])
                step_logits.append(logits)
                next_tok = logits.argmax(dim=-1, keepdim=True)
                if (next_tok == 0).all():
                    break
                prefix = torch.cat([prefix, next_tok], dim=-1)
            # (steps, bs, vocab) -> (bs, steps, vocab)
            return torch.stack(step_logits).transpose(1, 0).contiguous()

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.2, attn_type='multi', heads=8, smoothing=0.1):
    "Build the full-architecture Img2Seq Learner; `em_sz` only sizes the image encoder here."
    encoder = ResnetBase(em_sz, d_model)
    tfmr = make_full_model(len(itos), d_model, N=N, drops=drops, attn_type=attn_type, heads=heads)
    # NOTE(review): apply() + init_params re-initialises deep weights once per ancestor module.
    tfmr.apply(init_params)
    #tfmr.generator.weight = tfmr.tgt_embed[0].lut.weight
    model = Img2Seq(encoder, tfmr)
    loss = LabelSmoothing(smoothing=smoothing)
    return Learner(data, model, loss_func=loss, metrics=[CER()], callback_fns=[TeacherForce])

# TransformerXL

In [None]:
class PositionalEncoding(Module):
    "Sinusoidal encoding of a 1-D position vector: concat(sin(pos*freq), cos(pos*freq)) along the last dim."
    def __init__(self, d_model:int):
        inv_freq = 1 / (10000 ** (torch.arange(0., d_model, 2.) / d_model))
        self.register_buffer('freq', inv_freq)

    def forward(self, pos:Tensor):
        angles = torch.ger(pos, self.freq)  # outer product: (len(pos), d_model/2)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

In [None]:
class FeedForward(Module):
    "Post-norm feed-forward block: LayerNorm(x + MLP(x)) with MLP = Linear->ReLU->Dropout->Linear->Dropout."
    def __init__(self, d_model:int, drops=0.1):
        hidden = d_model * 4
        self.core = nn.Sequential(
            nn.Linear(d_model, hidden), nn.ReLU(inplace=True), nn.Dropout(drops),
            nn.Linear(hidden, d_model), nn.Dropout(drops),
        )
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x):
        return self.ln(x + self.core(x))

In [None]:
class MultiHeadRelativeAttention(Module):
    "Transformer-XL multi-head attention with relative-position scores, followed by residual + LayerNorm."
    def __init__(self, n_heads:int, d_model:int, drops=0.1, bias=False):
        d_head = d_model//n_heads
        self.n_heads, self.d_head = n_heads, d_head
        self.attention = nn.Linear(d_model, 3 * n_heads * d_head, bias=bias)  # fused q/k/v projection
        self.out = nn.Linear(n_heads * d_head, d_model, bias=bias)
        self.drop_att,self.drop_res = nn.Dropout(drops),nn.Dropout(drops)
        self.ln = nn.LayerNorm(d_model)
        self.r_attn = nn.Linear(d_model, n_heads * d_head, bias=bias)         # projection of position encodings

    def forward(self, x:Tensor, mask:Tensor=None, **kwargs):
        return self.ln(x + self.drop_res(self.out(self._mhra(x, mask=mask, **kwargs))))

    def _mhra(self, x:Tensor, r:Tensor=None, u:Tensor=None, v:Tensor=None, mask:Tensor=None, mem:Tensor=None):
        #Notations from the paper:
        #x: input, r: vector of relative distance between two elements
        #u,v: learnable parameters of the model common between layers
        #mask: True marks positions to hide (filled with -inf), mem: previous hidden states

        #x: [bs, sl, d_model]
        #u/v: [n_heads, 1, d_head]
        #r: [mem_len + sl, d_model]
        #mem: 1st:[0]; 2nd:[bs, sl, d_model]; nth: sl*i-1 up to mem_len
        bs,x_len,seq_len = x.size(0),x.size(1),r.size(0)
        context = x if mem is None else torch.cat([mem, x], dim=1)
        wq,wk,wv = torch.chunk(self.attention(context), 3, dim=-1)
        wq = wq[:,-x_len:]  # queries only for the current segment
        wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))
        # [bs, sl, n_heads, d_head]
        wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3)   #wk: transposed(-2,-1)
        wkr = self.r_attn(r)
        wkr = wkr.view(seq_len, self.n_heads, self.d_head)
        wkr = wkr.permute(1,2,0)  #transposed ala wk w/out bs
        #### compute attention score (AC is (a) + (c) and BD is (b) + (d) in the paper)
        AC = torch.matmul(wq+u,wk)
        # The shift helper defined next to this class is `_left_shift`. The old call used
        # `_line_shift`, which nothing here defines (underscore names are not pulled in by
        # `from ... import *`), so it raised a NameError.
        BD = _left_shift(torch.matmul(wq+v, wkr))
        attn_score = (AC + BD).div_(self.d_head ** 0.5)  #scale
        if mask is not None:
            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)
        attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))
        attn_vec = torch.matmul(attn_prob, wv)
        return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1)
    
def _left_shift(x:Tensor):
    "Shift the line i of `x` by p-i elements to the left."
    bs,nh,n,p = x.size()
    x_pad = torch.cat([x.new_zeros(bs,nh,n,1), x], dim=3)
    x_shift = x_pad.view(bs,nh,p + 1,n)[:,:,1:].view_as(x)
    return x_shift

In [None]:
class DecoderLayerXL(Module):
    "Transformer-XL layer: relative multi-head attention followed by the feed-forward block."
    def __init__(self, n_heads:int, d_model:int, drops=0.1):
        self.mhra = MultiHeadRelativeAttention(n_heads, d_model, drops=drops)
        self.ff   = FeedForward(d_model, drops=drops)

    def forward(self, x:Tensor, mask:Tensor=None, **kwargs):
        # kwargs are forwarded to the attention (TransformerXL passes r, u, v and mem)
        attended = self.mhra(x, mask=mask, **kwargs)
        return self.ff(attended)

In [None]:
class TransformerXL(Module):
    "TransformerXL model: https://arxiv.org/abs/1901.02860."
    def __init__(self, vocab_sz:int, d_model:int, n_layers:int, n_heads:int, drops:float=0.1, mem_len:int=150):
        # read by `reset` / `_update_mems`; they were never stored, so the first forward
        # raised AttributeError
        self.n_layers, self.d_model, self.mem_len = n_layers, d_model, mem_len
        d_head = d_model//n_heads
        self.emb = nn.Embedding(vocab_sz, d_model)
        self.drop_emb = nn.Dropout(drops)
        self.pos_enc = PositionalEncoding(d_model)
        # u, v: biases shared by all layers (paper notation)
        self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head))
        self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head))
        self.layers = nn.ModuleList([DecoderLayerXL(n_heads, d_model, drops) for k in range(n_layers)])
        # torch.Tensor(...) is uninitialised memory: give u/v a defined start (done after the
        # layers so their random init is unchanged)
        nn.init.normal_(self.u, 0., 0.02)
        nn.init.normal_(self.v, 0., 0.02)
        self.init = False

    def forward(self, x):
        "x: (bs, x_len) token ids. Returns (memory list, [output of the last layer for this segment])."
        if not self.init:
            self.reset()
            self.init = True
        bs,x_len = x.size()
        inp = self.drop_emb(self.emb(x)) #.mul_(self.d_model ** 0.5)
        m_len = self.hidden[0].size(1) if hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1 else 0
        seq_len = m_len + x_len
        # True above the (shifted) diagonal: each position may see the memory and earlier tokens only
        mask = torch.triu(x.new_ones(x_len, seq_len), diagonal=1+m_len).bool()[None,None]
        pos = torch.arange(seq_len-1, -1, -1, device=inp.device, dtype=inp.dtype)  # relative distances, descending
        pos_enc = self.pos_enc(pos)
        hids = []
        hids.append(inp)
        for i, layer in enumerate(self.layers):
            mem = self.hidden[i]
            inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem)
            hids.append(inp)
        core_out = inp[:,-x_len:]
        self._update_mems(hids)
        return self.hidden,[core_out]

    def _update_mems(self, hids):
        "Append this segment's hidden states to memory, keeping at most the last `mem_len` positions (detached)."
        if not getattr(self, 'hidden', False): return None
        assert len(hids) == len(self.hidden), 'len(hids) != len(self.hidden)'
        with torch.no_grad():
            for i in range(len(hids)):
                cat = torch.cat([self.hidden[i], hids[i]], dim=1)
                # explicit start index: `cat[:, -0:]` would keep everything when mem_len == 0
                start = max(cat.size(1) - self.mem_len, 0)
                self.hidden[i] = cat[:, start:].detach()

    def reset(self):
        "Reset the internal memory."
        self.hidden = [next(self.parameters()).data.new(0) for i in range(self.n_layers+1)]

    def select_hidden(self, idxs):
        # used in beam search only...
        self.hidden = [h[idxs] for h in self.hidden]

In [None]:
class EncoderDecoder(nn.Module):
    "Encoder + decoder + output projection; target masks are accepted but ignored by this variant."
    def __init__(self, encoder, decoder, generator):
        super().__init__()
        # Registration order (encoder, decoder, generator) is kept for state_dict / parameter order.
        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator

    def forward(self, src, tgt, tgt_mask=None):
        memory = self.encode(src)
        return self.decode(memory, tgt)

    def encode(self, src):
        "Run the encoder over the source features."
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        "Decode `tgt` against encoder output `src`; note the decoder takes (target, source)."
        return self.decoder(tgt, src)

    def generate(self, outs):
        "Project decoder states through the generator (an nn.Linear(d_model, vocab) in make_full_model)."
        return self.generator(outs)

In [None]:
class ResnetBase(nn.Module):
    "Pretrained ResNet-34 trunk whose feature map is flattened into a sequence of `d_model` vectors."
    def __init__(self, em_sz, d_model):
        super().__init__()
        # Trailing resnet34 children to drop so the trunk outputs `em_sz` channels, which the
        # Linear below consumes. Unsupported sizes raise KeyError.
        cut = {128: -4, 256: -3, 512: -2}[em_sz]
        trunk = models.resnet34(True)
        self.base = nn.Sequential(*list(trunk.children())[:cut])
        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        fmap = self.base(x)
        # (bs, C, H, W) -> (bs, H*W, C): one feature vector per spatial location.
        seq = fmap.flatten(2, 3).permute(0, 2, 1)
        # Fixed scale factor of 8 on the projected features (rationale not recorded here).
        return self.linear(seq) * 8

In [None]:
class Decoder(nn.Module):
    "N-layer decoder preceded by a language-model layer over the target tokens, with a final LayerNorm."
    def __init__(self, lm, layer, N):
        super(Decoder, self).__init__()
        self.lm_layer = lm
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, src):
        """Embed/contextualise targets `x` with the LM, then attend to encoder output `src` in each layer."""
        x = self.lm_layer(x)
        # A TransformerXL-style LM returns `(hidden, [outputs])`, but the decoder layers need a
        # tensor: take the last output. Plain tensor outputs pass through unchanged.
        if isinstance(x, tuple):
            x = x[1][-1]
        for layer in self.layers:
            x = layer(x, src)
        return self.norm(x)

In [None]:
class DecoderLayer(nn.Module):
    "Decoder layer: source attention, then a position-wise feed-forward, each in its own residual sublayer."
    def __init__(self, size, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # Two SublayerConnection wrappers (residual, dropout, norm): one per sub-block.
        self.sublayer = clones(SublayerConnection(size, dropout), 2)

    def forward(self, x, src):
        def attend(query):
            # Queries come from the target side; keys and values from the encoder output.
            return self.src_attn(query, src, src)

        attended = self.sublayer[0](x, attend)
        return self.sublayer[1](attended, self.feed_forward)

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', attn_heads=8,
                    lm_layers=10, lm_heads=8, lm_drops=0.1):
    """Build an encoder/decoder whose decoder is fed by a TransformerXL language model.

    Args:
        vocab: vocabulary size (LM input and generator output).
        d_model: model width.
        N: number of encoder and decoder layers.
        drops: dropout for the encoder/decoder layers.
        attn_type: 'multi' for multi-head attention, anything else for single-head
            (matching the other `make_full_model` variants; previously ignored here).
        attn_heads: heads for multi-head attention.
        lm_layers, lm_heads, lm_drops: TransformerXL settings (defaults are the former hard-coded values).

    Returns:
        EncoderDecoder(encoder, decoder, generator) with every >1-D parameter Xavier-initialised.
    """
    c = deepcopy

    # Construction order (LM, attention, feed-forward) is unchanged, so RNG consumption is too.
    lm = TransformerXL(vocab, d_model, n_layers=lm_layers, n_heads=lm_heads, drops=lm_drops)
    if attn_type == 'multi':
        attn = MultiHeadedAttention(d_model, attn_heads)
    else:
        attn = SingleHeadedAttention(d_model)
    ff = PositionwiseFeedForward(d_model, drops)

    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), drops), N),
        Decoder(lm, DecoderLayer(d_model, c(attn), c(ff), drops), N),
        nn.Linear(d_model, vocab),
    )

    # Re-initialises every weight matrix / higher-rank tensor, including the LM's.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return model

In [None]:
class Img2Seq(nn.Module):
    "CNN image encoder feeding an encoder/decoder transformer; greedy-decodes when no target is given."
    def __init__(self, img_encoder, transformer):
        super(Img2Seq, self).__init__()
        self.img_enc = img_encoder
        self.transformer = transformer

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        """Return per-position output scores for images `src`.

        With `tgt` (teacher forcing) the decoder sees the whole target at once. Without it the model
        greedy-decodes for at most `seq_len` steps, stopping once every sequence predicts token 0.
        `tgt_mask` is accepted for interface compatibility but unused by this variant.
        """
        # inference (greedy decode)
        if tgt is None:
            with torch.no_grad():
                feats = self.transformer.encode(self.img_enc(src))
                bs = src.size(0)
                # Seed every sequence with token 1 (presumably BOS) on the *input's* device: the
                # notebook-global `device` can disagree with where the model runs (e.g. CPU prediction).
                tgt = torch.ones((bs,1), dtype=torch.long, device=src.device)

                res = []
                for i in progress_bar(range(seq_len)):
                    dec_outs = self.transformer.decode(feats, tgt)
                    prob = self.transformer.generate(dec_outs[:,-1])  # scores for the newest position
                    res.append(prob)
                    pred = torch.argmax(prob, dim=-1, keepdim=True)
                    if (pred==0).all(): break  # every sequence emitted token 0 (presumably pad/EOS)
                    tgt = torch.cat([tgt,pred], dim=-1)
                out = torch.stack(res).transpose(1,0).contiguous()  # (bs, steps, n_out)

        #training
        else:
            feats = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt)
            out = self.transformer.generate(dec_outs)
        return out

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.2, attn_type='multi', attn_heads=8):
    "Build a fastai Learner around Img2Seq (ResNet encoder + TransformerXL-fed transformer)."
    # Size the vocabulary from the data, as the other make_learner variants do, rather than from
    # the notebook-global `itos`.
    vocab = len(data.vocab.itos)
    img_encoder = ResnetBase(em_sz, d_model)
    transformer = make_full_model(vocab, d_model, N=N, drops=drops, attn_type=attn_type, attn_heads=attn_heads)
    net = Img2Seq(img_encoder, transformer)
    return Learner(data, net, loss_func=LabelSmoothing(smoothing=0.1),
                    metrics=[CER()], callback_fns=[TeacherForce])

# w/ Transformer LM

In [None]:
class DecoderLayerWithLM(nn.Module):
    """Decoder layer in which one self-attention sublayer feeds both an LM branch and a decoder branch.

    LM branch: feed-forward on the self-attention output. Decoder branch: source attention, then
    feed-forward. The two branch outputs are summed and passed through a final feed-forward
    sublayer. Every step is wrapped in its own SublayerConnection.
    """
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        # Was `super(DecoderLayer, self).__init__()`, which raises TypeError because this class is
        # not a DecoderLayer subclass, so no DecoderLayerWithLM could ever be constructed.
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = clones(feed_forward, 3)  # [0]: LM branch, [1]: decoder branch, [2]: merge
        self.sublayer = clones(SublayerConnection(size, dropout), 5)  # wraps layer in residual,dropout,norm

    def forward(self, x, src, tgt_mask=None):
        # lm
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))  # shared btw lm and decoder
        lm = self.sublayer[1](x, self.feed_forward[0])

        # decoder layer
        dec = self.sublayer[2](x, lambda x: self.src_attn(x, src, src))
        dec = self.sublayer[3](dec, self.feed_forward[1])
        return self.sublayer[4](dec+lm, self.feed_forward[2])

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', attn_heads=8, weight_tying=False):
    """Build the encoder/decoder whose decoder layers carry an extra LM branch (DecoderLayerWithLM).

    Components are built in the order encoder, decoder, target embedding, generator. Every
    parameter with more than one dimension is then Xavier-initialised. With `weight_tying` the
    output projection reuses the target embedding's lookup-table weight.
    """
    if attn_type == 'multi':
        attn = MultiHeadedAttention(d_model, attn_heads)
    else:
        attn = SingleHeadedAttention(d_model)
    ff = PositionwiseFeedForward(d_model, drops)

    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    decoder = Decoder(DecoderLayerWithLM(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    tgt_embed = nn.Sequential(Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000))
    generator = nn.Linear(d_model, vocab)
    model = EncoderDecoder(encoder, decoder, tgt_embed, generator)

    for weight in filter(lambda t: t.dim() > 1, model.parameters()):
        nn.init.xavier_uniform_(weight)

    if weight_tying:
        model.generator.weight = model.tgt_embed[0].lut.weight

    return model

# w/ AWDLSTM LM integrated

In [None]:
class EncoderDecoder(nn.Module):
    "Encoder/decoder pair for pre-embedded targets: only `pos_enc` is applied to `tgt` before decoding."
    def __init__(self, encoder, decoder, pos_enc):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pos_enc = pos_enc

    def forward(self, src, tgt, tgt_mask=None):
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask)

    def encode(self, src):
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        # The decoder is called as (target, source, mask).
        positioned = self.pos_enc(tgt)
        return self.decoder(positioned, src, tgt_mask)

In [None]:
class ResnetBase(nn.Module):
    "ResNet-34 (pretrained) feature extractor that returns a (bs, H*W, d_model) sequence."
    def __init__(self, em_sz, d_model):
        super().__init__()
        # em_sz -> number of trailing children to drop (128/256/512 channels); KeyError otherwise.
        drop_tail = {128: -4, 256: -3, 512: -2}
        n = drop_tail[em_sz]
        children = list(models.resnet34(True).children())
        self.base = nn.Sequential(*children[:n])
        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        feats = self.base(x).flatten(2, 3)  # (bs, C, H*W)
        feats = feats.permute(0, 2, 1)      # (bs, H*W, C)
        # Constant scale of 8; 8 == 512 ** (1/3), and the author wondered whether it should track d_model.
        return self.linear(feats) * 8

In [None]:
def lm_mod_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', attn_heads=8):
    """Transformer for the LM-integrated setup: targets arrive already embedded, so only a
    positional encoding (max length 2000) is applied before decoding.

    `vocab` is accepted for signature parity with `make_full_model` but not used.
    """
    attn = MultiHeadedAttention(d_model, attn_heads) if attn_type == 'multi' else SingleHeadedAttention(d_model)
    ff = PositionwiseFeedForward(d_model, drops)

    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    decoder = Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    model = EncoderDecoder(encoder, decoder, PositionalEncoding(d_model, drops, 2000))

    # Xavier-initialise every weight matrix / higher-rank tensor.
    for weight in (t for t in model.parameters() if t.dim() > 1):
        nn.init.xavier_uniform_(weight)

    return model

In [None]:
class LM(nn.Module):
    "Thin wrapper around fastai's AWD_LSTM that consumes pre-computed target embeddings."
    def __init__(self, vocab, d_model, n_hidden, n_layers):
        super(LM, self).__init__()
        self.lm = AWD_LSTM(vocab, d_model, n_hidden, n_layers, pad_token=0,
                           hidden_p=0.1, input_p=0.25, embed_p=0.05, weight_p=0.2)

    def forward(self, tgts):
        # `from_embeddings=True`: `tgts` are already embeddings (presumably bypassing AWD_LSTM's
        # own encoder). The AWD_LSTM output is returned unchanged. A leftover `pdb.set_trace()` was
        # removed here because it halted every forward pass, including training.
        return self.lm(tgts, from_embeddings=True)

In [None]:
class Mixer(nn.Module):
    "Average decoder and LM states, refine with feed-forward + LayerNorm, then project to `vocab` scores."
    def __init__(self, d_model, vocab, drops=0.2):
        super(Mixer, self).__init__()
        self.mixer = nn.Sequential(PositionwiseFeedForward(d_model, drops), LayerNorm(d_model))
        self.generator = nn.Linear(d_model, vocab)

    def forward(self, dec_outs, lm_outs):
        # Was `(lm_outs+outs)/2`: `outs` is undefined here (NameError); the decoder states are `dec_outs`.
        mix = self.mixer((lm_outs + dec_outs) / 2)
        return self.generator(mix)

In [None]:
class Embedding(nn.Module):
    "Token embedding (padding index `pad_tok`) applied through fastai's EmbeddingDropout."
    def __init__(self, d_model, vocab, drops, pad_tok=0):
        super().__init__()
        self.emb = nn.Embedding(vocab, d_model, padding_idx=pad_tok)
        # Wraps (and shares) `self.emb`; `drops` is its dropout probability.
        self.emb_drop = EmbeddingDropout(self.emb, drops)

    def forward(self, x):
        # Lookups go through the dropout wrapper; no sqrt(d_model) scaling is applied.
        embedded = self.emb_drop(x)
        return embedded

In [None]:
class Img2Seq(nn.Module):
    """Image-to-sequence model fusing a transformer decoder and a language model through `mixer`.

    Target tokens are embedded once by `embedding`; the same embeddings feed both the transformer
    decoder and the LM, and `mixer` turns the pair of outputs into per-token scores.
    """
    def __init__(self, img_encoder, transformer, embedding, lm, mixer):
        super(Img2Seq, self).__init__()
        self.img_enc = img_encoder
        self.embedding = embedding
        self.transformer = transformer
        self.lm = lm
        self.mixer = mixer

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        """Teacher-forced scores when `tgt` is given; otherwise greedy-decode for at most `seq_len` steps.

        NOTE(review): assumes `self.lm` returns a (bs, sl, d_model) tensor; the LM wrapper in this
        notebook returns AWD_LSTM's raw output, which may be a tuple. Confirm before relying on this path.
        """
        # inference (greedy decode)
        if tgt is None:
            with torch.no_grad():
                feats = self.transformer.encode(self.img_enc(src))
                bs = src.size(0)
                # Seed token 1 (presumably BOS) on the input's device, not the notebook-global `device`.
                tgt = torch.ones((bs,1), dtype=torch.long, device=src.device)

                res = []
                for i in progress_bar(range(seq_len)):
                    mask = subsequent_mask(tgt.size(-1))
                    tgt_emb = self.embedding(tgt)
                    dec_outs = self.transformer.decode(feats, tgt_emb, mask)
                    lm_outs = self.lm(tgt_emb)
                    prob = self.mixer(dec_outs[:,-1], lm_outs[:,-1])
                    #outs = self.lm(dec_outs[:,-1], self.transformer.embed, decode=True)
                    #prob = self.transformer.generate(outs)
                    res.append(prob)
                    pred = torch.argmax(prob, dim=-1, keepdim=True)
                    if (pred==0).all(): break  # every sequence emitted token 0 (presumably pad/EOS)
                    tgt = torch.cat([tgt,pred], dim=-1)
                out = torch.stack(res).transpose(1,0).contiguous()

        #training
        else:
            feats = self.img_enc(src)
            tgt_emb = self.embedding(tgt)
            dec_outs = self.transformer(feats, tgt_emb, tgt_mask)    # ([bs, sl, d_model])
            lm_outs = self.lm(tgt_emb)
            out = self.mixer(dec_outs, lm_outs)
            #out = self.transformer.generate(outs)            # ([bs, sl, vocab])
        return out

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.2, attn_type='multi', attn_heads=8):
    "Build a Learner for the LM-integrated Img2Seq (shared target embedding, AWD-LSTM LM, mixer head)."
    n_tok = len(data.vocab.itos)
    lm_hidden, lm_layers = 1400, 3
    # Components are constructed in the same order as before (encoder, transformer, embedding, LM, mixer).
    net = Img2Seq(
        ResnetBase(em_sz, d_model),
        lm_mod_model(n_tok, d_model, N=N, drops=drops, attn_type=attn_type, attn_heads=attn_heads),
        Embedding(d_model, n_tok, drops, pad_tok=0),
        LM(n_tok, d_model, lm_hidden, lm_layers),
        Mixer(d_model, n_tok, drops),
    )
    return Learner(data, net, loss_func=LabelSmoothing(smoothing=0.1),
                   metrics=[CER()], callback_fns=[TeacherForce])

In [None]:
# Build the LM-integrated learner (whichever `make_learner` cell ran last is the one used).
learn = make_learner(data, 512, 512, N=6, drops=0.1, attn_type='multi')
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

## Add LM to model state_dict

In [None]:
# Full checkpoint of the earlier (pre-LM) model; its weights live under sd['model'].
sd = torch.load(PATH/'models/combo_512_9.pth', map_location=device)

In [None]:
# update existing model according to LM modifications:
# move weights from the old `transformer.tgt_embed.*` / `transformer.generator.*` keys to the new
# top-level `embedding`, `transformer.pos_enc` and `mixer.generator` modules, then drop the old keys.
# Both embedding keys get the same tensor because `Embedding.emb_drop` wraps the very same `emb` module.
sd['model']['embedding.emb.weight'] = sd['model']['transformer.tgt_embed.0.lut.weight']
sd['model']['embedding.emb_drop.emb.weight'] = sd['model']['transformer.tgt_embed.0.lut.weight']
sd['model']["transformer.pos_enc.pe"] = sd['model']['transformer.tgt_embed.1.pe']
sd['model']['mixer.generator.weight'] = sd['model']['transformer.generator.weight']
sd['model']['mixer.generator.bias'] = sd['model']['transformer.generator.bias']

del sd['model']['transformer.tgt_embed.1.pe']
del sd['model']['transformer.tgt_embed.0.lut.weight']
del sd['model']['transformer.generator.weight']
del sd['model']['transformer.generator.bias']

In [None]:
# WikiText-103 LM encoder weights: a plain state dict (iterated directly below, no 'model' key).
lm_sd = torch.load('data/wikitext/models/wiki103_lm_enc.pth', map_location=device)

In [None]:
# Re-key the LM weights under `lm.lm.` = Img2Seq.lm (LM wrapper) -> .lm (AWD_LSTM).
from collections import OrderedDict
new_lm_sd = OrderedDict()
for k, v in lm_sd.items():
    name = 'lm.lm.'+k
    new_lm_sd[name] = v

In [None]:
# strict=False tolerates missing/unexpected keys (they are reported in the returned value);
# tensors whose shapes do not match still raise.
sd['model'].update(new_lm_sd)

learn.model.load_state_dict(sd['model'], strict=False)

In [None]:
# tie weights of transformer embedding and generator w/ lm encodings
# (`embedding.emb` and `embedding.emb_drop.emb` are the same module; `lm.lm.encoder_dp.emb` is
# presumably fastai's dropout wrapper around `lm.lm.encoder`, so all three lines likely share one Parameter — confirm.)
learn.model.embedding.emb.weight = learn.model.lm.lm.encoder.weight
learn.model.embedding.emb_drop.emb.weight = learn.model.lm.lm.encoder_dp.emb.weight
learn.model.mixer.generator.weight = learn.model.lm.lm.encoder.weight

In [None]:
# Persist the merged + tied weights (written to the Learner's model directory).
learn.save('combo_512_9_wiki103_base_lm')

## load and split learner

In [None]:
# A bare `None` as the last line suppresses the cell's output echo.
learn.load('combo_512_9_wiki103_base_lm')
None

In [None]:
# Five layer groups, in order: img_enc, embedding, transformer, lm, mixer.
learn.split([learn.model.img_enc, learn.model.embedding, learn.model.transformer, learn.model.lm, learn.model.mixer])
None

In [None]:
# Inspect the last layer group (the mixer).
learn.layer_groups[4]

In [None]:
# Freeze all but the last three groups (transformer, lm, mixer); the second line should show False.
# NOTE(review): the embedding weight is tied to the LM encoder, which sits in an unfrozen group,
# so that shared Parameter may still end up trainable — confirm.
learn.freeze_to(-3)
learn.model.img_enc.linear.weight.requires_grad

In [None]:
# NOTE(review): a slice's third field is its *step*; fastai's discriminative-LR helper appears to
# read only start/stop, so 2e-3 is likely ignored — confirm the intended per-group learning rates.
lrs = slice(2e-5, 2e-4, 2e-3)

# w/ AWDLSTM added

In [None]:
class EncoderDecoder(nn.Module):
    "Encoder/decoder where targets pass through `tgt_embed` before decoding and `generator` maps states to scores."
    def __init__(self, encoder, decoder, tgt_embed, generator):
        super().__init__()
        # Registered in this order: encoder, decoder, tgt_embed, generator.
        for attr, module in (('encoder', encoder), ('decoder', decoder),
                             ('tgt_embed', tgt_embed), ('generator', generator)):
            setattr(self, attr, module)

    def forward(self, src, tgt, tgt_mask=None):
        memory = self.encode(src)
        return self.decode(memory, tgt, tgt_mask)

    def encode(self, src):
        return self.encoder(src)

    def decode(self, src, tgt, tgt_mask=None):
        # Decoder signature is (embedded target, encoder output, target mask).
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, src, tgt_mask)

    def generate(self, outs):
        return self.generator(outs)

In [None]:
class ResnetBase(nn.Module):
    "Image encoder: truncated pretrained ResNet-34, spatial grid flattened to a sequence, projected to d_model."
    def __init__(self, em_sz, d_model):
        super().__init__()
        # Truncation point per output channel count (layer2=128, layer3=256, layer4=512); KeyError otherwise.
        end = {128: -4, 256: -3, 512: -2}[em_sz]
        backbone = models.resnet34(True)
        layers = list(backbone.children())
        self.base = nn.Sequential(*layers[:end])
        self.linear = nn.Linear(em_sz, d_model)

    def forward(self, x):
        grid = self.base(x)
        tokens = grid.flatten(2, 3).permute(0, 2, 1)  # (bs, C, H, W) -> (bs, H*W, C)
        projected = self.linear(tokens)
        return projected * 8  # fixed scale factor; rationale not recorded here

In [None]:
def make_full_model(vocab, d_model, N=4, drops=0.2, attn_type='multi', attn_heads=8, weight_tying=False):
    """Build a standard encoder/decoder transformer with a learned target embedding + positional encoding.

    Every parameter with more than one dimension is Xavier-initialised. With `weight_tying`, the output
    projection shares the target embedding's lookup-table weight.
    """
    attn = MultiHeadedAttention(d_model, attn_heads) if attn_type == 'multi' else SingleHeadedAttention(d_model)
    ff = PositionwiseFeedForward(d_model, drops)

    encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drops), N)
    decoder = Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), drops), N)
    tgt_embed = nn.Sequential(Embeddings(d_model, vocab), PositionalEncoding(d_model, drops, 2000))
    generator = nn.Linear(d_model, vocab)
    model = EncoderDecoder(encoder, decoder, tgt_embed, generator)

    for param in model.parameters():
        if param.dim() <= 1:
            continue
        nn.init.xavier_uniform_(param)

    if weight_tying:
        model.generator.weight = model.tgt_embed[0].lut.weight
    return model

In [None]:
class Img2Seq(nn.Module):
    """Image-to-sequence model; an AWD-LSTM LM is blended in only during greedy decoding.

    Training (teacher forcing) uses the transformer alone. At inference the decoder's and the LM's
    scores for the newest position are averaged.
    """
    def __init__(self, img_encoder, transformer, lm):
        super(Img2Seq, self).__init__()
        self.img_enc = img_encoder
        self.transformer = transformer
        self.lm = lm

    def forward(self, src, tgt=None, tgt_mask=None, seq_len=700):
        # inference (greedy decode)
        if tgt is None:
            with torch.no_grad():
                feats = self.transformer.encode(self.img_enc(src))
                bs = src.size(0)
                # Seed token 1 (presumably BOS) on the input's device, not the notebook-global `device`.
                tgt = torch.ones((bs,1), dtype=torch.long, device=src.device)

                res = []
                for i in progress_bar(range(seq_len)):
                    mask = subsequent_mask(tgt.size(-1))
                    dec_outs = self.transformer.decode(feats, tgt, mask)
                    dec_prob = self.transformer.generate(dec_outs[:,-1])
                    # The whole prefix is re-fed every step, so clear any recurrent state the LM
                    # carried over from the previous step (fastai's AWD-LSTM keeps its hidden state).
                    if hasattr(self.lm, 'reset'): self.lm.reset()
                    lm_prob,_,_ = self.lm(tgt)
                    prob = (dec_prob + lm_prob[:,-1])/2
                    res.append(prob)
                    pred = torch.argmax(prob, dim=-1, keepdim=True)
                    if (pred==0).all(): break  # every sequence emitted token 0 (presumably pad/EOS)
                    tgt = torch.cat([tgt,pred], dim=-1)
                out = torch.stack(res).transpose(1,0).contiguous()

        #training
        else:
            feats = self.img_enc(src)
            dec_outs = self.transformer(feats, tgt, tgt_mask)    # ([bs, sl, d_model])
            out = self.transformer.generate(dec_outs)            # ([bs, sl, vocab])
        return out

In [None]:
# Hyper-parameters for the AWD-LSTM built by `get_language_model` (keys follow fastai's AWD_LSTM config).
lm_config = {
    'emb_sz': 512, 'n_hid': 1400, 'n_layers': 3, 'pad_token': 0,
    'qrnn': False, 'bidir': False,
    'output_p': 0.2, 'hidden_p': 0.2, 'input_p': 0.5, 'embed_p': 0.1, 'weight_p': 0.4,
    'tie_weights': True, 'out_bias': True,
}

In [None]:
def make_learner(data, d_model, em_sz, N=4, drops=0.2, attn_type='multi', attn_heads=8):
    "Build a Learner for Img2Seq with an AWD-LSTM language model configured by `lm_config`."
    n_tok = len(data.vocab.itos)
    # Arguments are evaluated left to right, so components are still built encoder -> transformer -> LM.
    net = Img2Seq(
        ResnetBase(em_sz, d_model),
        make_full_model(n_tok, d_model, N=N, drops=drops, attn_type=attn_type, attn_heads=attn_heads),
        get_language_model(AWD_LSTM, n_tok, lm_config, drop_mult=0.5),
    )
    return Learner(data, net, loss_func=LabelSmoothing(smoothing=0.1),
                   metrics=[CER()], callback_fns=[TeacherForce])

In [None]:
# Build the learner with the AWD-LSTM LM attached (uses the make_learner cell that ran last).
learn = make_learner(data, 512, 512, N=6, drops=0.1, attn_type='multi')
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

## Add LM to model state_dict

In [None]:
# Inspect the language-model module.
learn.model.lm

In [None]:
# Full checkpoint of the earlier model; weights live under sd['model'].
sd = torch.load(PATH/'models/combo_512_9.pth', map_location=device)

In [None]:
# A Learner checkpoint: the LM weights live under lm_sd['model'].
lm_sd = torch.load(PATH/'models/wiki2_lm.pth', map_location=device)

In [None]:
# Prefix every LM key with `lm.` so it lands in Img2Seq.lm.
from collections import OrderedDict
new_lm_sd = OrderedDict()
for k, v in lm_sd['model'].items():
    name = 'lm.'+k
    new_lm_sd[name] = v

In [None]:
lm_sd['model'].keys()

In [None]:
# strict=False tolerates missing/unexpected keys (reported in the returned value); shape mismatches still raise.
sd['model'].update(new_lm_sd)

learn.model.load_state_dict(sd['model'], strict=False)

In [None]:
learn.save('combo_512_9_wiki2_lm')

# Experiments

In [None]:
# make_learner has no `heads` keyword (passing it raised TypeError); the head count is `attn_heads`.
learn = make_learner(data, 512, 512, N=4, drops=0, attn_type='multi', attn_heads=8)
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

In [None]:
# Resume from the previous checkpoint; the bare `None` suppresses the cell's output echo.
learn.load('combo_512_mix_res8')
None

In [None]:
# LR range test, plotted with fastai's suggested learning rate.
learn.lr_find()
learn.recorder.plot(suggestion=True)

In [None]:
# Point the learner at the current DataBunch (e.g. after `data` was rebuilt).
learn.data = data

In [None]:
# One cycle at 1e-5; SaveModelCallback writes checkpoints under the name 'combo_512_mix_res9'.
learn.fit_one_cycle(1, max_lr=1e-5, callbacks=[SaveModelCallback(learn, name='combo_512_mix_res9')])
#sm, 5cycle, 1e-3

# 2.474220	4.081600	0.036482  N:4,F.gelu,sz:256,em_sz:512,single,drop:0  'sm_256_1'
# greedy:    1.36697   .066
# 4.430638	4.402001	0.034730  2nd run, lr:1.5e-5, add tfms, drop:0.2   'sm_256_2'
# greedy:    0.37206   .04145

# 5.158808	4.964441	0.042544  "", w/ tfms, drop:0.2   'sm_256_3'
# greedy:    0.50646   .04103
# 4.266788	4.694098	0.039822  2nd run, lr:1.5e-5   'sm_256_4'
# greedy:    0.38972   .03879

# 3.678168	3.962710	0.035466  N:8,tfms   'sm_256_5'
# greedy:    0.38097   .04405

# 2.840161	3.429177	0.030243  N:4,tfms,multi(8)  'sm_256_6'
# greedy:    0.32775   .04201
# greedy:    1.38865   .06943

# 3.377438	3.786088	0.033931   N:6,tfms,drop:0.1,multi(8)   'sm_256_7'
# greedy:    0.36201   .04266  
# greedy:    1.40557   .06902

# 3.989137	4.353169	0.037717   N:4,weight-tying,multi(16)   'sm_256_8'
# greedy:    1.44291   .07299

# 4.806501	5.007887	0.042339   N:4,weight-tying,multi(8),drops:0.2   'sm_256_9'
# greedy:    1.46753   .07424

# 12.817862	11.842313	0.105486   N:4,no init/weight tying,multi(8),drops:0.1   'sm_256_10'

# 3.515456	3.880955	0.033763   N:4,multi(8),drops:0.1   'sm_256_11'
# greedy:    1.40945   .06964

# combo_cat6 - preload 'sm_256_7'
# 7.570516	6.670641	0.015070    'combo_512_7'
# greedy:    6.29406   .02186


# combo 6lines and under, preload 'sm_256_6', 10cycle(1e-3)
# 10.784849	9.144131	0.024880     'combo_512_13'
# greedy:    9.21715   .02788
# upload:    103.456   .44797
#     pg:    148.473   .69869

# combo 6lines and over, preload 'combo_512_13', 5cycle(2e-5)
# 56.732105	38.604137	0.026153    'combo_512_14'
# greedy:    141.022   .02933
# upload:    179.606   .81724

# combo_cat_pg - preload 'combo_512_7'
# 25.674347	21.161726	0.015869    'combo_512_8'
# greedy:    114.878   .03654
#   test:    115.247   .05014

# combo_cat_pg_dl_sorted - preload 'combo_512_8', 5cycle, 1.5e-4
# 9.800321	5.676855	0.006939    'combo_512_9'
# greedy:    34.8757   .01523

#   test:    91.3226   .05102
#   test:    104.685   .05483
#   test:    116.467   .05315

# upload:    115.149   .66801

# imdb_wiki_combo - preload 'combo_512_9'; split: img_enc, transformer; 3cycle, slice(5e-5, 1e-3)
# 20.390352	16.698141	0.018726    'combo_512_12'
# greedy:    60.8384   .02474
#     pg:    112.755   .05204

# combo_145k - fit:1cycle(1e-3); preload combo_512_9; drops:0.2
# 30.075714	31.024599	0.031703	1:50:33

# combo_145k <= 6 - 5cycle(1e-4); preload sm_256_mix_res
# 19.502026	17.365313	0.044436    'combo_512_mix_res'
# greedy:    14.6870   .04744
# upload:    75.1357   .47955
#     pg:    142.710   .66233

# test_combo 65k (sm,cat,pg,dl)
# 18.262995	15.091379	0.039351   5cycle(1e-5)  'combo_512_mix_res5'
# greedy:    129.968   .03932
# upload:    81.6835   .21794
#     pg:    127.317   .05613

# 16.678631	12.867742	0.035985   1cycle(1e-5), drops:0.1   'combo_512_mix_res6'
# 14.878900	12.797726	0.035834   1cycle(1e-6), drops:0.1   'combo_512_mix_res7'
# greedy:    34.3988   .05552
# upload:    75.5511   .21010
#     pg:    118.963   .05287

# 16.086544	12.320695	0.034263   1cycle(1e-5), drops:0.1   'combo_512_mix_res8'
# 11.640953	11.339241	0.032201   1cycle(1e-5), drops:0     'combo_512_mix_res9'
# greedy:    31.5506   .05410
#   test:    74.4820   .22067
#     pg:    105.182   .05076

# test_combo 65k - sz:1024
# --fail--   preload combo_512_mix_res5 drop:0.1, 3cycle(2e-4)  



# Experiment - Words

In [None]:
# Word-level experiment: same architecture, vocabulary taken from the word-level `data`.
learn = make_learner(data, 512, 512, N=6, drops=0.1, attn_type='multi')

In [None]:
# NOTE(review): this checkpoint came from a character-level run; if the vocab sizes differ, the
# embedding/generator shapes will not match — confirm.
learn.load('combo_512_9')
None

In [None]:
learn.lr_find()
learn.recorder.plot(suggestion=True)

In [None]:
# 15 cycles at 1e-3; SaveModelCallback writes checkpoints under the name 'word_sm_2'.
learn.fit_one_cycle(15, max_lr=1e-3, callbacks=[SaveModelCallback(learn, name='word_sm_2')])

# word 60k
# sm, 5cycle, 1e-3

# 22.903589	22.401066	0.520125   N:6,em_sz:512,tfms,drop:0.1,multi(8)  'word_sm_1'
# 32.389709	31.828810	0.746749   N:4,em_sz:512,tfms,drop:0.1,single
# 34.492157	32.719650	0.776098   "",em_sz:256

# word 7k (auto-tokenize from data)
# sm, 5cycle, 1e-3
# 16.987526	16.639275	0.421535   N:4,em_sz:512,tfms,drop:0.1,single
# 10.224819	10.781199	0.297889   2nd run
# 6.486265	7.934221	0.225055   3rd run    'word_sm_4'

# word 10k (itos from larger txt files)
# sm, 15cycle, 1e-3
# 3.893518	6.037587	0.209027   N:4,em_sz:512,tfms,drop:0.1,single   'word_sm_2'

# Greedy

In [None]:
learn = make_learner(data, 512, 512, N=6, drops=0.1, attn_type='multi')
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

In [None]:
sd = torch.load(PATH/'models/combo_512_9.pth', map_location=device)

In [None]:
# Copy the image-encoder projection into `transformer.src_adapt.0` (presumably a source-adapter
# layer of a model variant defined elsewhere — confirm it exists in the current model).
sd['model']["transformer.src_adapt.0.weight"] = sd['model']['img_enc.linear.weight']
sd['model']["transformer.src_adapt.0.bias"] = sd['model']['img_enc.linear.bias']

In [None]:
learn.model.load_state_dict(sd['model'], strict=False)

In [None]:
# NOTE: `learn.load` replaces the model weights, overwriting the manual load in the cell above.
# learn.load('combo_512_9_wiki2_lm')
learn.load('combo_512_mix_res')
None

In [None]:
def full_test(learn, sl, dl=None, batches=10):
    """Greedy-decode up to `batches` batches from `dl` and return `[mean loss, mean CER]`.

    Each batch's loss and CER are divided by that batch's actual size (previously the
    notebook-global `bs`), then averaged over the batches actually evaluated.

    Args:
        learn: Learner whose model greedy-decodes when called without targets.
        sl: maximum decode length, passed to the model as `seq_len`.
        dl: loader yielding `(x, y)` batches. Defaults to the *current* `data.valid_dl`, looked up
            at call time; the old default was frozen when this cell ran and went stale once
            `data` was rebuilt.
        batches: number of batches to evaluate; `None` means every full batch of `dl`'s dataset.

    Returns:
        `[loss, cer]`, or `[nan, nan]` if no batch could be evaluated. Stops early, without
        raising, if `dl` runs out of batches.
    """
    if dl is None:
        dl = data.valid_dl
    if batches is None:
        batches = len(dl.dl.dataset)//dl.batch_size
    learn.model.eval()
    batch_iter = iter(dl)
    g_loss, g_cer, n_done = 0, 0, 0
    for _ in progress_bar(range(batches)):
        try:
            x, y = next(batch_iter)
        except StopIteration:  # fewer batches available than requested
            break
        n = y.size(0)
        g_preds = learn.model(x, seq_len=sl)
        g_loss += learn.loss_func(g_preds, y).item()/n
        g_cer += cer(g_preds, y)[0]/n
        n_done += 1
    if n_done == 0:
        return [float('nan'), float('nan')]
    return [g_loss/n_done, g_cer/n_done]

In [None]:
# set_seed()
# `seq_len` must already be defined (it is set in the test-data cells further down).
g = full_test(learn, seq_len)

In [None]:
# NOTE: `[1:7]` strips the leading '0' of the CER; it would drop a real digit if CER >= 1.
print(f'greedy:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
# losses = np.array([learn.loss_func(g_preds[i:i+1],y[i:i+1]).item() for i in range(bs)])
# cers = np.array([cer(g_preds[i:i+1],y[i:i+1])[0] for i in range(bs)])

In [None]:
# Greedy-decode a single validation batch and score it (per-item values via the global `bs`).
# set_seed()
x,y = next(iter(learn.data.valid_dl))

g_preds = learn.model(x, seq_len=seq_len)
g_res = torch.argmax(g_preds, dim=-1)
g = [learn.loss_func(g_preds, y).item()/bs, cer(g_preds, y)[0]/bs]

In [None]:
#greedy
# Show the first 6 images of the batch, titled with their decoded text (needs bs >= 6).
fig, axes = plt.subplots(2,3, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    p = char_label_text(g_res[i], sep='')
    ax=show_img(x[i], ax=ax, title=p)

# Test

### data

In [None]:
# Option 1: uploaded images. Run this cell OR the next one, not both.
# Only a cell's last expression is echoed, so `len(df)` below displays nothing.
# Reassigning the global `bs` also changes the divisor used by LabelSmoothing.
FOLDER = 'uploads'
df = pd.read_csv(PATH/'uploads.csv')
len(df)

sz,bs = 512,14
seq_len = 700

In [None]:
# Option 2: paragraph test set.
FOLDER = 'paragraphs'
df = pd.read_csv(PATH/'test_pg.csv')
len(df)

sz,bs = 512,15
seq_len = 700

In [None]:
## Test dataset only!!!
# split_none(): every item ends up in the *train* set, so later cells read `train_dl` / `train_ds`.
# set_seed()  # reproducibility
data = (ImageList.from_df(df, path=PATH, folder=FOLDER, after_open=force_gray)
        .split_none()
        .label_from_df(label_cls=SequenceList, vocab=CharVocab(itos))
        .transform([], size=sz, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=bs, device=device, collate_fn=label_collater)
       )

In [None]:
data.show_batch(rows=4, ds_type=DatasetType.Train, figsize=(18,20))

### learner

In [None]:
# N=4 must match the architecture of the checkpoint loaded below.
learn = make_learner(data, 512, 512, N=4, drops=0.1, attn_type='multi')
# Total # of trainable params:
# N=6: 65,786,272
# N=4: 51,073,440

In [None]:
# learn.load('combo_512_9_wiki2_lm')
learn.load('combo_512_mix_res9')
None

In [None]:
learn.data = data

### test

In [None]:
# Greedy-decode one test batch; the test data lives in train_dl because of split_none().
x,y = next(iter(learn.data.train_dl))

g_preds = learn.model(x, seq_len=seq_len)
g_res = torch.argmax(g_preds, dim=-1)
g = [learn.loss_func(g_preds, y).item()/bs, cer(g_preds, y)[0]/bs]

print(f'  test:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

# test:    397.121   .29261  ~3:55   combo_512_9_wiki2_lm  (LM isn't aware of '\n'!!)
# test:    101.374   .05139  ~27s    combo_512_9

In [None]:
#test
# Show the first 4 images of the batch, titled with their decoded text.
fig, axes = plt.subplots(2,2, gridspec_kw={'hspace': 0.4}, figsize=(18, 20))
for i,ax in enumerate(axes.flat):
    #i +=4
    p = char_label_text(g_res[i])
    ax=show_img(x[i], ax=ax, title=p)

## Single Image (cpu)

In [None]:
# Pick one test item (train_ds holds the whole test set because of split_none()).
x,y = data.train_ds[13]
x

In [None]:
pred = learn.predict(x)
show_img(x, title=str(pred[0]))

In [None]:
# pred[2] is the raw model output for the item; add a batch axis to score it like a batch.
# NOTE(review): LabelSmoothing divides by the global `bs`, so this single-item loss is scaled by 1/bs.
outs,targs = pred[2][None],tensor(y.data)[None]
g = [learn.loss_func(outs, targs).item(), cer(outs, targs)[0]]
print(f'image:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
# Top-3 candidate tokens (and their probabilities) at every position.
probs = F.softmax(pred[2], dim=-1)
scores, idxs = torch.topk(probs, 3, dim=-1)

In [None]:
# Print {char: probability} for the top-3 candidates at each position.
for s,i in zip(scores,idxs):
    new = {}
    for score, idx in zip(s,i):
        new[itos[idx]] = score.data
    print(new)

In [None]:
# Decode the top-k indices into characters: one list of k candidates per position.
# (Replaces `idxs.numpy()` followed by the unfinished `chars.appl`, which raised AttributeError.)
chars = [[itos[i] for i in row] for row in idxs.tolist()]
chars

In [None]:
# `Tensor.apply_` only accepts callables that return numbers (a str raises TypeError), and it would
# overwrite `idxs` in place, which the next cell still needs as indices; build a character view instead.
[[itos[i] for i in row] for row in idxs.tolist()]

In [None]:
# Print the top-3 candidates for the first 30 positions (raises IndexError if fewer were decoded).
for i in range(30):
    print(char_label_text(idxs[i], sep=' '))

In [None]:
# learn.show_results(ds_type=data.train_ds, rows=2)

In [None]:
# Needs `xs`/`ys` from the `one_batch` cell below; run that cell first.
i = 0
x = xs[i][None]
y = ys[i][None]

In [None]:
# NOTE(review): `one_item` presumably expects a dataset item (e.g. an Image), not a tensor — confirm.
batch = data.one_item(xs[0])

In [None]:
xs,ys = data.one_batch(detach=False, denorm=False)

In [None]:
# Greedy-decode and score the single-item batch built above.
g_preds = learn.model(x, seq_len=seq_len)
g_res = torch.argmax(g_preds, dim=-1)
g = [learn.loss_func(g_preds, y).item(), cer(g_preds, y)[0]]

print(f'  test:    {str(g[0])[:7]}   {str(g[1])[1:7]}')

In [None]:
p = char_label_text(g_res[0])
show_img(x[0], figsize=(18,10), title=p)

In [None]:
# End-to-end prediction on an uploaded image file.
im = PATH/'uploads'/'test1.png'
img = open_image(im)
prediction = learn.predict(img)[0]
show_img(img, title=str(prediction))