In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
from fastai.text import *

In [None]:
EOS = '<eos>'
PATH=Path('../data/wikitext')

Small helper function to read the tokens.

In [None]:
def read_file(filename):
    "Read `PATH/filename` as UTF-8 and return an array holding, for each line, its whitespace-split tokens plus `EOS`."
    with open(PATH/filename, encoding='utf8') as f:
        # One token list per line, each terminated by the end-of-sentence marker.
        tokens = [line.split() + [EOS] for line in f]
    return np.array(tokens)

In [None]:
trn_tok = read_file('wiki.train.tokens')
val_tok = read_file('wiki.valid.tokens')
tst_tok = read_file('wiki.test.tokens')

In [None]:
len(trn_tok), len(val_tok), len(tst_tok)

In [None]:
' '.join(trn_tok[4][:20])

In [None]:
cnt = Counter(word for sent in trn_tok for word in sent)
cnt.most_common(10)

Give an id to each token and add the pad token (just in case we need it).

In [None]:
itos = [o for o,c in cnt.most_common()]
itos.insert(0,'<pad>')

In [None]:
vocab_size = len(itos); vocab_size

Create the mapping from token to id, then numericalize our datasets.

In [None]:
# Token -> id lookup built from `itos`; any token missing from it falls back to id 5.
# NOTE(review): 5 is presumably the id of the unknown-word token in `itos` - confirm against `cnt.most_common()`.
stoi = collections.defaultdict(lambda : 5, {w:i for i,w in enumerate(itos)})

In [None]:
# Numericalize each split: every line becomes the list of ids of its tokens.
# `stoi` is a defaultdict, so a token it does not contain gets the default id.
# NOTE(review): lines presumably differ in length, so these would be object arrays of lists;
# recent NumPy versions may require an explicit dtype=object for that - confirm.
trn_ids = np.array([([stoi[w] for w in s]) for s in trn_tok])
val_ids = np.array([([stoi[w] for w in s]) for s in val_tok])
tst_ids = np.array([([stoi[w] for w in s]) for s in tst_tok])

## Testing WeightDropout

Create a bunch of parameters for deterministic tests.

In [None]:
module = nn.LSTM(20, 20)
tst_input = torch.randn(2,5,20)
tst_output = torch.randint(0,20,(10,)).long()
save_params = {}
for n,p in module._parameters.items(): save_params[n] = p.clone()

### Old WeightDropout

In [None]:
module = nn.LSTM(20, 20)
for n,p in save_params.items(): module._parameters[n] = nn.Parameter(p.clone())
dp_module = WeightDrop(module, 0.5)
opt = optim.SGD(dp_module.parameters(), 10)
dp_module.train()

In [None]:
torch.manual_seed(7)

In [None]:
x = tst_input.clone()
x.requires_grad_(requires_grad=True)
h = (torch.zeros(1,5,20), torch.zeros(1,5,20))
for _ in range(5): x,h = dp_module(x,h)

In [None]:
getattr(dp_module.module, 'weight_hh_l0'),getattr(dp_module.module,'weight_hh_l0_raw')

In [None]:
target = tst_output.clone()
loss = F.nll_loss(x.view(-1,20), target)
loss.backward()
opt.step()

In [None]:
w, w_raw = getattr(dp_module.module, 'weight_hh_l0'),getattr(dp_module.module,'weight_hh_l0_raw')
w.grad, w_raw.grad

In [None]:
getattr(dp_module.module, 'weight_hh_l0'),getattr(dp_module.module,'weight_hh_l0_raw')

### New WeightDropout

In [None]:
class WeightDropout(nn.Module):
    "A module that wraps another layer and applies dropout to the weights named in `layer_names` on every forward pass."

    def __init__(self, module, dropout, layer_names=['weight_hh_l0']):
        super().__init__()
        self.module = module
        self.dropout = dropout
        self.layer_names = layer_names
        # Register a `<name>_raw` parameter on the wrapper, built from the data of each selected weight.
        for name in self.layer_names:
            raw = nn.Parameter(getattr(self.module, name).data)
            self.register_parameter(f'{name}_raw', raw)

    def _setweights(self):
        "Replace each selected weight of the wrapped module by the dropped-out version of its raw parameter."
        for name in self.layer_names:
            dropped = F.dropout(getattr(self, f'{name}_raw'), p=self.dropout, training=self.training)
            self.module._parameters[name] = dropped

    def forward(self, *args):
        "Refresh the dropped-out weights, then run the wrapped module on `args`."
        self._setweights()
        return self.module.forward(*args)

    def reset(self):
        "Forward `reset` to the wrapped module when it defines one."
        if not hasattr(self.module, 'reset'): return
        self.module.reset()

In [None]:
module = nn.LSTM(20, 20)
for n,p in save_params.items(): module._parameters[n] = nn.Parameter(p.clone())
dp_module = WeightDropout(module, 0.5)
opt = optim.SGD(dp_module.parameters(), 10)
dp_module.train()

In [None]:
torch.manual_seed(7)

In [None]:
x = tst_input.clone()
x.requires_grad_(requires_grad=True)
h = (torch.zeros(1,5,20), torch.zeros(1,5,20))
for _ in range(5): x,h = dp_module(x,h)

In [None]:
getattr(dp_module.module, 'weight_hh_l0'),getattr(dp_module,'weight_hh_l0_raw')

In [None]:
target = tst_output.clone()
loss = F.nll_loss(x.view(-1,20), target)
loss.backward()
opt.step()

In [None]:
w, w_raw = getattr(dp_module.module, 'weight_hh_l0'),getattr(dp_module,'weight_hh_l0_raw')
w.grad, w_raw.grad

In [None]:
getattr(dp_module.module, 'weight_hh_l0'),getattr(dp_module,'weight_hh_l0_raw')

## Testing EmbeddingDropout

Create a bunch of parameters for deterministic tests.

In [None]:
enc = nn.Embedding(100,20, padding_idx=0)
tst_input = torch.randint(0,100,(25,)).long()
save_params = enc.weight.clone()

### Old EmbeddingDropout

In [None]:
enc = nn.Embedding(100,20, padding_idx=0)
enc.weight = nn.Parameter(save_params.clone())
enc_dp = EmbeddingDropout(enc)

In [None]:
torch.manual_seed(7)

In [None]:
x = tst_input.clone()
enc_dp(x, dropout=0.5)

### New EmbeddingDropout

In [None]:
def dropout_mask(x, sz, p):
    "Build a mask of size `sz` and of the same type as `x`: entries are 0 with probability `p`, otherwise 1/(1-p)."
    keep = 1-p
    mask = x.new(*sz).bernoulli_(keep)
    # Rescale the kept entries by the keep probability.
    return mask/keep

In [None]:
class EmbeddingDropout1(nn.Module):
    "Embedding lookup through `emb` that, in training mode, zeroes whole rows of the embedding matrix with probability `dropout`."

    def __init__(self, emb, dropout):
        super().__init__()
        self.emb = emb
        self.dropout = dropout
        # Use -1 when the wrapped embedding has no padding index.
        pad_idx = self.emb.padding_idx
        self.pad_idx = -1 if pad_idx is None else pad_idx

    def forward(self, words, dropout=0.1, scale=None):
        "Look `words` up in the (possibly masked and scaled) weights; the `dropout` argument is accepted but not used here."
        weight = self.emb.weight
        if self.training and self.dropout != 0:
            # One mask value per row, broadcast over the embedding dimension.
            n_rows = self.emb.weight.size(0)
            weight = dropout_mask(self.emb.weight.data, (n_rows,1), self.dropout) * weight
        if scale: weight = scale * weight
        return F.embedding(words, weight, self.pad_idx, self.emb.max_norm,
                           self.emb.norm_type, self.emb.scale_grad_by_freq, self.emb.sparse)

In [None]:
enc = nn.Embedding(100,20, padding_idx=0)
enc.weight = nn.Parameter(save_params.clone())
enc_dp = EmbeddingDropout1(enc, 0.5)

In [None]:
torch.manual_seed(7)

In [None]:
x = tst_input.clone()
enc_dp(x)

## Testing RNN model

Creating a bunch of parameters for deterministic testing.

In [None]:
tst_model = get_language_model(500, 20, 100, 2, 0, bias=True)
save_parameters = {}
for n,p in tst_model.state_dict().items(): save_parameters[n] = p.clone()
tst_input = torch.randint(0, 500, (10,5)).long()
tst_output = torch.randint(0, 500, (50,)).long()

### Old RNN model

In [None]:
tst_model = get_language_model(500, 20, 100, 2, 0, bias=True, dropout=0.4, dropoute=0.1, dropouth=0.2, 
                               dropouti=0.6, wdrop=0.5)
state_dict = OrderedDict()
for n,p in save_parameters.items(): state_dict[n] = p.clone()
tst_model.load_state_dict(state_dict)
opt = optim.SGD(tst_model.parameters(), lr=10)

In [None]:
torch.manual_seed(7)

In [None]:
x = tst_input.clone()
z = tst_model(x)
z

In [None]:
y = tst_output.clone()
loss = F.nll_loss(z[0], y)
loss.backward()
opt.step()

In [None]:
tst_model[0].rnns[0].module._parameters['weight_hh_l0_raw']

### New RNN model

In [None]:
class RNNDropout(nn.Module):
    "Dropout with probability `p` whose mask has size 1 on the first dimension, so it is shared along that dimension."

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        "Return `x` untouched in eval mode or when `p` is falsy; otherwise multiply it by a broadcast mask."
        if self.training and self.p:
            mask = dropout_mask(x.data, (1, x.size(1), x.size(2)), self.p)
            return mask * x
        return x

In [None]:
def repackage_var1(h):
    "Detaches h from its history: a tensor is detached, any other iterable becomes a tuple of repackaged items."
    # isinstance (rather than an exact type check) so that Tensor subclasses are detached, not iterated over.
    if isinstance(h, torch.Tensor): return h.detach()
    # Recurse into this function itself instead of relying on the old `repackage_var` being imported.
    return tuple(repackage_var1(v) for v in h)

In [None]:
class RNNCore(nn.Module):
    "AWD-LSTM/QRNN inspired by https://arxiv.org/abs/1708.02182"

    # Bound of the uniform initialisation applied to the embedding weights in `__init__`.
    initrange=0.1

    def __init__(self, vocab_sz, emb_sz, n_hid, n_layers, pad_token, bidir=False,
                 hidden_p=0.2, input_p=0.6, embed_p=0.1, weight_p=0.5, qrnn=False):
        "Create the embedding, `n_layers` stacked LSTM (or QRNN if `qrnn`) layers and their dropout modules."
        
        super().__init__()
        # `bs` starts at 1; `forward` calls `reset` whenever it sees a different batch size.
        self.bs,self.qrnn,self.ndir = 1, qrnn,(2 if bidir else 1)
        self.emb_sz,self.n_hid,self.n_layers = emb_sz,n_hid,n_layers
        self.encoder = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_token)
        # Dropout wrapper around the same embedding (it shares `self.encoder`'s weight).
        self.dp_encoder = EmbeddingDropout1(self.encoder, embed_p)
        # In both branches the first layer takes `emb_sz` inputs, the last layer outputs
        # `emb_sz` (divided by the number of directions) and the others output `n_hid`.
        if self.qrnn:
            #Using QRNN requires cupy: https://github.com/cupy/cupy
            # NOTE(review): relative import - presumably only resolves when this class lives in a package; confirm.
            from .torchqrnn.qrnn import QRNNLayer
            self.rnns = [QRNNLayer(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz)//self.ndir,
                save_prev_x=True, zoneout=0, window=2 if l == 0 else 1, output_gate=True) for l in range(n_layers)]
            if weight_p != 0.:
                # Weight dropout is applied to the `weight` of each QRNN layer's `linear` sub-module.
                for rnn in self.rnns:
                    rnn.linear = WeightDropout(rnn.linear, weight_p, layer_names=['weight'])
        else:
            self.rnns = [nn.LSTM(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz)//self.ndir,
                1, bidirectional=bidir) for l in range(n_layers)]
            # Each LSTM is wrapped with WeightDropout using its default `layer_names`.
            if weight_p != 0.: self.rnns = [WeightDropout(rnn, weight_p) for rnn in self.rnns]
        self.rnns = torch.nn.ModuleList(self.rnns)
        self.encoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.dropouti = RNNDropout(input_p)
        self.dropouths = nn.ModuleList([RNNDropout(hidden_p) for l in range(n_layers)])

    def forward(self, input):
        "Run `input` (its size is unpacked as `sl,bs`) through all layers; return `(raw_outputs, outputs)`, one entry per layer."
        sl,bs = input.size()
        # A new batch size invalidates the stored hidden state, so rebuild it.
        if bs!=self.bs:
            self.bs=bs
            self.reset()
        raw_output = self.dropouti(self.dp_encoder(input))
        new_hidden,raw_outputs,outputs = [],[],[]
        for l, (rnn,drop) in enumerate(zip(self.rnns, self.dropouths)):
            with warnings.catch_warnings():
                #To avoid the warning that comes because the weights aren't flattened.
                warnings.simplefilter("ignore")
                raw_output, new_h = rnn(raw_output, self.hidden[l])
            new_hidden.append(new_h)
            raw_outputs.append(raw_output)
            # Hidden dropout is applied between layers only: the last layer's output is stored as is.
            if l != self.n_layers - 1: raw_output = drop(raw_output)
            outputs.append(raw_output)
        # Keep the new hidden state for the next call, detached from its history.
        self.hidden = repackage_var1(new_hidden)
        return raw_outputs, outputs

    def one_hidden(self, l):
        "Return a zeroed tensor of size (ndir, bs, layer output size) for layer `l`, created from `self.weights`."
        nh = (self.n_hid if l != self.n_layers - 1 else self.emb_sz)//self.ndir
        return self.weights.new(self.ndir, self.bs, nh).zero_()

    def reset(self):
        "Reset the layers that define `reset` and re-create a zeroed hidden state for every layer."
        [r.reset() for r in self.rnns if hasattr(r, 'reset')]
        # Data of the first parameter, used by `one_hidden` as the template for new tensors.
        self.weights = next(self.parameters()).data
        # One tensor per layer for QRNN, a pair of tensors per layer otherwise.
        if self.qrnn: self.hidden = [self.one_hidden(l) for l in range(self.n_layers)]
        else: self.hidden = [(self.one_hidden(l), self.one_hidden(l)) for l in range(self.n_layers)]

In [None]:
class LinearDecoder1(nn.Module):
    "Decoder to put on top of an RNN core: dropout on the last layer's output followed by a linear layer."

    initrange=0.1

    def __init__(self, n_out, n_hid, output_p, tie_encoder=None, bias=True):
        super().__init__()
        lin = nn.Linear(n_hid, n_out, bias=bias)
        lin.weight.data.uniform_(-self.initrange, self.initrange)
        if bias: lin.bias.data.zero_()
        self.decoder = lin
        self.dropout = RNNDropout(output_p)
        # Weight tying: the decoder reuses the weight of `tie_encoder` when one is given.
        if tie_encoder: self.decoder.weight = tie_encoder.weight

    def forward(self, input):
        "Decode the last element of `outputs`; return `(decoded, raw_outputs, outputs)`."
        raw_outputs, outputs = input
        last = self.dropout(outputs[-1])
        # Merge the first two dimensions before the linear layer.
        n_rows, n_feat = last.size(0)*last.size(1), last.size(2)
        decoded = self.decoder(last.view(n_rows, n_feat))
        return decoded, raw_outputs, outputs

In [None]:
class SequentialRNN1(nn.Sequential):
    "A sequential container that passes `reset` on to every child defining it."

    def reset(self):
        # Lazy generator: each child is checked and reset in turn, in child order.
        resettable = (child for child in self.children() if hasattr(child, 'reset'))
        for child in resettable: child.reset()

In [None]:
def get_language_model1(vocab_sz, emb_sz, n_hid, n_layers, pad_token, tie_weights=True, qrnn=False, bias=True,
                 output_p=0.4, hidden_p=0.2, input_p=0.6, embed_p=0.1, weight_p=0.5):
    "Build a full AWD-LSTM language model: an `RNNCore` followed by a `LinearDecoder1`."
    core = RNNCore(vocab_sz, emb_sz, n_hid=n_hid, n_layers=n_layers, pad_token=pad_token, qrnn=qrnn,
                   hidden_p=hidden_p, input_p=input_p, embed_p=embed_p, weight_p=weight_p)
    # With `tie_weights`, the decoder shares the weight of the core's embedding.
    if tie_weights: tied = core.encoder
    else: tied = None
    decoder = LinearDecoder1(vocab_sz, emb_sz, output_p, tie_encoder=tied, bias=bias)
    return SequentialRNN1(core, decoder)

The new model has weights that are organized a bit differently.

In [None]:
# Rebuild the saved state dict under the key names used by the new model.
save_parameters1 = {}
for n,p in save_parameters.items(): 
    # Keys unrelated to the hidden-to-hidden weights or the old embedding wrapper are copied as they are.
    if 'weight_hh_l0' not in n and n!='0.encoder_with_dropout.embed.weight':  save_parameters1[n] = p.clone()
    # The old embedding-dropout key is renamed to the new wrapper's key.
    elif n=='0.encoder_with_dropout.embed.weight': save_parameters1['0.dp_encoder.emb.weight'] = p.clone()
    else: 
        # Keys containing 'weight_hh_l0' are stored twice:
        # - under the key minus its last four characters (presumably stripping a '_raw' suffix - confirm);
        save_parameters1[n[:-4]] = p.clone()
        # - under the key without its second-to-last dotted component (presumably 'module' - confirm).
        #   `remove` drops the first component equal to that value.
        splits = n.split('.')
        splits.remove(splits[-2])
        n1 = '.'.join(splits)
        save_parameters1[n1] = p.clone()

In [None]:
tst_model = get_language_model1(500, 20, 100, 2, 0)
tst_model.load_state_dict(save_parameters1)
opt = optim.SGD(tst_model.parameters(), lr=10)

In [None]:
torch.manual_seed(7)

In [None]:
x = tst_input.clone()
z = tst_model(x)
z

In [None]:
y = tst_output.clone()
loss = F.nll_loss(z[0], y)
loss.backward()
opt.step()

In [None]:
tst_model[0].rnns[0]._parameters['weight_hh_l0_raw']

## Regularization

We'll keep the same parameters as before.

### Old reg

In [None]:
tst_model = get_language_model(500, 20, 100, 2, 0, bias=True, dropout=0.4, dropoute=0.1, dropouth=0.2, 
                               dropouti=0.6, wdrop=0.5)
state_dict = OrderedDict()
for n,p in save_parameters.items(): state_dict[n] = p.clone()
tst_model.load_state_dict(state_dict)
opt = optim.SGD(tst_model.parameters(), lr=10, weight_decay=1)

In [None]:
torch.manual_seed(7)

In [None]:
x = tst_input.clone()
z = tst_model(x)
y = tst_output.clone()
loss = F.nll_loss(z[0], y)

In [None]:
loss = seq2seq_reg(z[0], z[1:], loss, 2, 1)
loss.item()

In [None]:
loss.backward()
nn.utils.clip_grad_norm_(tst_model.parameters(), 0.1)
opt.step()

In [None]:
tst_model[0].rnns[0].module._parameters['weight_hh_l0_raw']

### New reg

In [None]:
from dataclasses import dataclass

In [None]:
@dataclass
class RNNTrainer(Callback):
    "`Callback` that keeps the extra RNN outputs, adds the AR and TAR penalties to the loss and clips gradients."
    model:nn.Module
    bptt:int
    clip:float=None
    alpha:float=0.
    beta:float=0.

    def on_loss_begin(self, last_output, **kwargs):
        "Store the extra outputs for `on_backward_begin` and return only the first element of `last_output`."
        self.raw_out, self.out = last_output[1], last_output[2]
        return last_output[0]

    def on_backward_begin(self, last_loss, last_input, last_output, **kwargs):
        "Add the AR (`alpha`) and TAR (`beta`) terms to `last_loss` in place and return it."
        # Disabled: adjusting the lr to the bptt selected
        #self.learn.opt.lr *= last_input.size(0) / self.bptt
        if self.alpha != 0.:
            # AR: penalty on the last layer's output after dropout.
            ar = self.alpha * self.out[-1].pow(2).mean()
            last_loss += ar.sum()
        if self.beta != 0.:
            h = self.raw_out[-1]
            if len(h)>1:
                # TAR: penalty on the difference between consecutive entries of the last raw output.
                tar = self.beta * (h[1:] - h[:-1]).pow(2).mean()
                last_loss += tar.sum()
        return last_loss

    def on_backward_end(self, **kwargs):
        "Clip the norm of the model's gradients to `clip` when it is set."
        if not self.clip: return
        nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)

In [None]:
# Rebuild the saved state dict under the key names used by the new model.
save_parameters1 = {}
for n,p in save_parameters.items(): 
    # Keys unrelated to the hidden-to-hidden weights or the old embedding wrapper are copied as they are.
    if 'weight_hh_l0' not in n and n!='0.encoder_with_dropout.embed.weight':  save_parameters1[n] = p.clone()
    # The old embedding-dropout key is renamed to the new wrapper's key.
    elif n=='0.encoder_with_dropout.embed.weight': save_parameters1['0.dp_encoder.emb.weight'] = p.clone()
    else: 
        # Keys containing 'weight_hh_l0' are stored twice:
        # - under the key minus its last four characters (presumably stripping a '_raw' suffix - confirm);
        save_parameters1[n[:-4]] = p.clone()
        # - under the key without its second-to-last dotted component (presumably 'module' - confirm).
        #   `remove` drops the first component equal to that value.
        splits = n.split('.')
        splits.remove(splits[-2])
        n1 = '.'.join(splits)
        save_parameters1[n1] = p.clone()

In [None]:
tst_model = get_language_model1(500, 20, 100, 2, 0)
tst_model.load_state_dict(save_parameters1)
opt = optim.SGD(tst_model.parameters(), lr=10, weight_decay=1)

In [None]:
torch.manual_seed(7)

In [None]:
cb = RNNTrainer(tst_model, 10, 0.1, 2, 1)

In [None]:
x = tst_input.clone()
z = tst_model(x)
y = tst_output.clone()
z = cb.on_loss_begin(z)
loss = F.nll_loss(z, y)
loss = cb.on_backward_begin(loss, x, z)
loss.item()

In [None]:
loss.backward()
cb.on_backward_end()
opt.step()

In [None]:
tst_model[0].rnns[0]._parameters['weight_hh_l0_raw']

## Classifier

Some data to test with.

In [None]:
tst_model = get_rnn_classifier(10, 50, 2, 500, 20, 100, 2, 0, layers=[60,50,2], drops=[0.1,0.1], dropoute=0.1, dropouth=0.2, 
                               dropouti=0.6, wdrop=0.5)
save_parameters = {}
for n,p in tst_model.state_dict().items(): save_parameters[n] = p.clone()
tst_input = torch.randint(0, 500, (10,50)).long()
tst_output = torch.randint(0, 2, (50,)).long()

### Old classifier

In [None]:
tst_model = get_rnn_classifier(10, 50, 2, 500, 20, 100, 2, 0, layers=[60,50,2], drops=[0.1,0.1], dropoute=0.1, dropouth=0.2, 
                               dropouti=0.6, wdrop=0.5)
state_dict = OrderedDict()
for n,p in save_parameters.items(): state_dict[n] = p.clone()
tst_model.load_state_dict(state_dict)
opt = optim.SGD(tst_model.parameters(), lr=10)

In [None]:
torch.manual_seed(7)

In [None]:
tst_model.reset()
x = tst_input.clone()
z = tst_model(x)
z

In [None]:
y = tst_output.clone()
loss = F.nll_loss(z[0], y)
loss.backward()
opt.step()
loss

In [None]:
tst_model[0].rnns[0].module._parameters['weight_hh_l0_raw']

### New classifier

In [None]:
class MultiBatchRNNCore(RNNCore):
    "An `RNNCore` that processes its input in chunks of `bptt` rows and concatenates the results of the kept chunks."

    def __init__(self, bptt, max_seq, *args, **kwargs):
        self.max_seq = max_seq
        self.bptt = bptt
        super().__init__(*args, **kwargs)

    def concat(self, arrs):
        "Concatenate, position by position, the tensors of every list in `arrs`."
        n_items = len(arrs[0])
        return [torch.cat([chunk[k] for chunk in arrs]) for k in range(n_items)]

    def forward(self, input):
        "Reset the hidden state, run each chunk through `RNNCore.forward` and keep those starting after `sl - max_seq`."
        sl, _ = input.size()
        self.reset()
        kept_raw, kept_out = [], []
        for start in range(0, sl, self.bptt):
            stop = min(start+self.bptt, sl)
            r, o = super().forward(input[start: stop])
            if start > (sl-self.max_seq):
                kept_raw.append(r)
                kept_out.append(o)
        return self.concat(kept_raw), self.concat(kept_out)

In [None]:
def bn_dp_lin(n_in, n_out, drop, relu=True):
    "Return the list `[BatchNorm1d, Dropout, Linear]` for `n_in` -> `n_out`, followed by an in-place `ReLU` when `relu`."
    block = [nn.BatchNorm1d(n_in), nn.Dropout(drop), nn.Linear(n_in, n_out)]
    if not relu: return block
    return block + [nn.ReLU(inplace=True)]

In [None]:
class PoolingLinearClassifier1(nn.Module):
    "Classifier head: concatenates the last row, max pool and average pool of the final output, then applies `bn_dp_lin` blocks."

    def __init__(self, layers, drops):
        super().__init__()
        n_blocks = len(layers)-1
        lyrs = []
        for i in range(n_blocks):
            # Every block but the last one ends with a ReLU.
            lyrs.extend(bn_dp_lin(layers[i], layers[i + 1], drops[i], i != n_blocks-1))
        self.layers = nn.Sequential(*lyrs)

    def pool(self, x, bs, is_max):
        "Max (if `is_max`) or average pooling of `x` over its first dimension, returned with `bs` rows."
        if is_max: pool_fn = F.adaptive_max_pool1d
        else: pool_fn = F.adaptive_avg_pool1d
        pooled = pool_fn(x.permute(1,2,0), (1,))
        return pooled.view(bs,-1)

    def forward(self, input):
        "Return `(x, raw_outputs, outputs)` where `x` is the result of the linear blocks."
        raw_outputs, outputs = input
        output = outputs[-1]
        _, bs, _ = output.size()
        avgpool = self.pool(output, bs, False)
        mxpool = self.pool(output, bs, True)
        features = torch.cat([output[-1], mxpool, avgpool], 1)
        return self.layers(features), raw_outputs, outputs

In [None]:
def get_rnn_classifier1(bptt, max_seq, n_class, vocab_sz, emb_sz, n_hid, n_layers, pad_token, layers, drops, 
                       bidir=False, qrnn=False, hidden_p=0.2, input_p=0.6, embed_p=0.1, weight_p=0.5):
    "Build an RNN classifier: a `MultiBatchRNNCore` followed by a `PoolingLinearClassifier1` (`n_class` is not used here)."
    core_kwargs = dict(pad_token=pad_token, bidir=bidir, qrnn=qrnn, hidden_p=hidden_p,
                       input_p=input_p, embed_p=embed_p, weight_p=weight_p)
    encoder = MultiBatchRNNCore(bptt, max_seq, vocab_sz, emb_sz, n_hid, n_layers, **core_kwargs)
    head = PoolingLinearClassifier1(layers, drops)
    return SequentialRNN1(encoder, head)

In [None]:
# Rebuild the saved classifier state dict under the key names used by the new model.
save_parameters1 = {}
# Keys that are dropped instead of being copied.
to_del = ["1.layers.0.bn.num_batches_tracked", "1.layers.1.bn.num_batches_tracked"]
# `old[i]` is renamed to `new[i]`: the two lists must stay aligned.
old = ["1.layers.0.bn.weight", "1.layers.0.bn.bias", "1.layers.0.bn.running_mean", "1.layers.0.bn.running_var", "1.layers.0.lin.weight", "1.layers.0.lin.bias", "1.layers.1.bn.weight", "1.layers.1.bn.bias", "1.layers.1.bn.running_mean", "1.layers.1.bn.running_var", "1.layers.1.lin.weight", "1.layers.1.lin.bias"]
new = ["1.layers.0.weight", "1.layers.0.bias", "1.layers.0.running_mean", "1.layers.0.running_var", "1.layers.2.weight", "1.layers.2.bias", "1.layers.4.weight", "1.layers.4.bias", "1.layers.4.running_mean", "1.layers.4.running_var", "1.layers.6.weight", "1.layers.6.bias"]
for n,p in save_parameters.items(): 
    # Classifier-head keys are renamed through the old -> new table.
    if n in old: save_parameters1[new[old.index(n)]] = p.clone()  
    # Other keys unrelated to the hidden-to-hidden weights, the dropped keys or the old embedding wrapper are copied as they are.
    elif 'weight_hh_l0' not in n and not n in to_del and n!='0.encoder_with_dropout.embed.weight':  
        save_parameters1[n] = p.clone()
    # The old embedding-dropout key is renamed to the new wrapper's key.
    elif n=='0.encoder_with_dropout.embed.weight': save_parameters1['0.dp_encoder.emb.weight'] = p.clone()
    elif not n in to_del: 
        # Remaining kept keys contain 'weight_hh_l0' and are stored twice:
        # - under the key minus its last four characters (presumably stripping a '_raw' suffix - confirm);
        save_parameters1[n[:-4]] = p.clone()
        # - under the key without its second-to-last dotted component (presumably 'module' - confirm).
        #   `remove` drops the first component equal to that value.
        splits = n.split('.')
        splits.remove(splits[-2])
        n1 = '.'.join(splits)
        save_parameters1[n1] = p.clone()

In [None]:
tst_model = get_rnn_classifier1(10, 50, 2, 500, 20, 100, 2, 0, layers=[60,50,2], drops=[0.1,0.1], embed_p=0.1, hidden_p=0.2, 
                               input_p=0.6, weight_p=0.5)
tst_model.load_state_dict(save_parameters1)
opt = optim.SGD(tst_model.parameters(), lr=10)

In [None]:
tst_model

In [None]:
torch.manual_seed(7)

In [None]:
tst_model.reset()
x = tst_input.clone()
z = tst_model(x)
z

In [None]:
y = tst_output.clone()
loss = F.nll_loss(z[0], y)
loss.backward()
opt.step()
loss

In [None]:
tst_model[0].rnns[0]._parameters['weight_hh_l0_raw']