# ByteNet in PyTorch

Thomas Viehmann

## Important note: This does not seem to work yet.

This is the ByteNet model presented in [Kalchbrenner et al., Neural Machine Translation in Linear Time](https://arxiv.org/abs/1610.10099).

In producing this I also looked at
- [Namju Kim's implementation](https://github.com/buriburisuri/ByteNet) using his SugarTensor framework
- [Paarth Neekhara's implementation](https://github.com/paarthneekhara/byteNet-tensorflow) in plain tensorflow

All bugs are my own.

As the dataset I used [the ComTrans sample from NLTK](https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/comtrans.zip), see the [NLTK data](http://www.nltk.org/nltk_data/) page. The dataset class below loads the unzipped files.

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data

import collections
import time

In [2]:
# really simple validation splitter
class PartialDataset(torch.utils.data.Dataset):
    """
       View onto a subset of another dataset, selected through a list of indices.

       inputs:
          parent_ds: dataset to wrap; needs to support len() and integer indexing
          offset:    first parent index of a contiguous range (use together with length)
          length:    number of items of the contiguous range (use together with offset)
          idx_list:  explicit parent indices (1-d LongTensor or any sequence of ints),
                     to be passed instead of offset and length

       Raises an AssertionError if the arguments are inconsistent or the requested
       range does not fit into the parent dataset.
       """
    def __init__(self, parent_ds, offset=None, length=None, idx_list=None):
        super(PartialDataset, self).__init__()
        self.parent_ds = parent_ds
        if idx_list is None:
            assert (offset is not None and length is not None), "Either idx_list or both offset and length need to be passed"
            # the message is a plain string (it used to be an Exception instance)
            assert len(parent_ds)>=offset+length, "Parent Dataset not long enough"
            self.idx_list = torch.arange(0,len(parent_ds)).long()[offset:offset+length]
        else:
            assert (offset is None and length is None), "Only either idx_list or both offset and length need to be passed"
            if not torch.is_tensor(idx_list):
                # accept plain Python sequences as well; __len__ relies on a tensor
                idx_list = torch.LongTensor(list(idx_list))
            self.idx_list = idx_list
    def __len__(self):
        return self.idx_list.size(0)
    def __getitem__(self, i):
        # Indexing a tensor may yield a 0-dim tensor rather than a Python int;
        # convert so that parents which need real ints (lists, dicts, ...) work.
        return self.parent_ds[int(self.idx_list[i])]

def validation_split(dataset, val_share=0.1, shuffle=True):
    """
       Split a (training and validation combined) dataset into training and validation.
       Note that to be statistically sound, the items in the dataset should be statistically
       independent (e.g. not sorted by class, not several instances of the same dataset that
       could end up in either set).

       inputs:
          dataset:   ("training") dataset to split into training and validation
          val_share: fraction of validation data (should be 0<val_share<1, default: 0.1)
          shuffle:   pick random items for each part
       returns: input dataset split into train_ds, val_ds; the two parts are disjoint
                and together cover every item of the input exactly once

       """
    num_items = len(dataset)
    val_offset = int(num_items*(1-val_share))
    if shuffle:
        idxes = torch.randperm(num_items)
    else:
        idxes = torch.arange(0, num_items).long()
    # Both parts are cut from the SAME index order: the first val_offset indices
    # become the training set, the remainder the validation set. (Previously the
    # validation part received all indices and therefore contained the training data.)
    train_ds = PartialDataset(dataset, idx_list=idxes[:val_offset])
    val_ds = PartialDataset(dataset, idx_list=idxes[val_offset:])
    return train_ds, val_ds



A very basic Dataset class for the ComTrans dataset. See above.

In [3]:
class ComTransChar(torch.utils.data.Dataset):
    """
       Character-level dataset over a ComTrans alignment file.

       The file is read as records of three lines: source sentence, target sentence
       and a whitespace separated list of "i-j" alignment pairs. Reading stops at the
       first empty source line (or at the end of the file). Only records where both
       sentences have between min_len and max_len characters are kept.

       Character indices: 0 is padding (the empty string), 1 is <EOS>, the remaining
       indices are the characters seen in the kept sentences in sorted order.

       inputs:
          fn:      path of the alignment file (read with latin_1 encoding)
          min_len: minimal number of characters of a sentence (inclusive)
          max_len: maximal number of characters of a sentence (inclusive)

       Items are (src, tgt) pairs of LongTensors of size max_len+1 (characters, <EOS>,
       then zero padding).
       """
    def __init__(self,fn="comtrans/alignment-de-en.txt", min_len=50, max_len=150):
        super(ComTransChar, self).__init__()
        self.char_frequencies = collections.Counter()
        self.corpus = []
        self.max_len = max_len
        # the context manager closes the file even if parsing fails half way
        with open(fn, encoding="latin_1") as f:
            while True:
                src = f.readline().rstrip()
                if not src:
                    break
                tgt = f.readline().rstrip()
                al  = [tuple(int(i) for i in p.split('-')) for p in f.readline().rstrip().split()]
                if min_len<=len(src)<=max_len and min_len<=len(tgt)<=max_len:
                    self.corpus.append((src,tgt,al))
                    self.char_frequencies.update(src)
                    self.char_frequencies.update(tgt)
        self.idx_to_char = dict(enumerate(['','<EOS>']+sorted(self.char_frequencies)))
        self.char_to_idx = {v:k for k,v in self.idx_to_char.items()}

        # one extra column so that a sentence of max_len characters still gets its <EOS>
        self.src_tensor = torch.LongTensor(len(self.corpus), self.max_len + 1).zero_()
        self.tgt_tensor = torch.LongTensor(len(self.corpus), self.max_len + 1).zero_()

        for i,(s,t,_) in enumerate(self.corpus):
            self._fill_row(self.src_tensor[i], s)
            self._fill_row(self.tgt_tensor[i], t)
    def _fill_row(self, row, s):
        """Write the character indices of s followed by <EOS> into row (in place)."""
        if s:
            row[:len(s)] = torch.LongTensor([self.char_to_idx[c] for c in s])
        row[len(s)] = 1 # <EOS>
    def __len__(self):
        return len(self.corpus)
    def __getitem__(self, i):
        return (self.src_tensor[i], self.tgt_tensor[i])
    def tensor_to_str(self, t, cut_at_eos=True):
        """
           Decode a 1-d tensor (or sequence) of character indices into a string.
           With cut_at_eos the result ends before the first <EOS>, otherwise
           <EOS> is spelled out and padding decodes to nothing.
           """
        # int(): iterating a tensor may yield 0-dim tensors, which are not usable as dict keys
        res = [self.idx_to_char[int(i)] for i in t]
        if cut_at_eos:
            res = res[:(res+['<EOS>']).index('<EOS>')]
        return ''.join(res)
    def tensors_to_strs(self, *args, cut_at_eos=True):
        """Decode each positional argument with tensor_to_str, returning a list of strings."""
        return [self.tensor_to_str(t, cut_at_eos=cut_at_eos) for t in args]
    def encode_strs(self, strs, max_len=None):
        """
           Encode a string or a list of strings into a LongTensor of size
           (number of strings, max_len+1): characters, <EOS>, zero padding. This is
           the same layout as the dataset items. max_len defaults to self.max_len.

           Raises ValueError if a string is longer than max_len and KeyError for
           characters which did not occur in the corpus.
           """
        if max_len is None:
            max_len = self.max_len
        if isinstance(strs, str):
            strs = [strs]
        for s in strs:
            if len(s) > max_len:
                raise ValueError("string of length {} exceeds max_len={}".format(len(s), max_len))
        # max_len+1 columns: previously a string of exactly max_len characters had no room for <EOS>
        res = torch.LongTensor(len(strs), max_len + 1).zero_()
        for i,s in enumerate(strs):
            self._fill_row(res[i], s)
        return res

# Load the corpus with sentences of at most 95 characters
# (min_len and the file name keep their defaults).
corpus = ComTransChar(max_len=95)

# Ask for a validation share of 1%.
train_ds, val_ds = validation_split(corpus, val_share=0.01)

# Batched, reshuffled access to both parts.
train_loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    val_ds,
    batch_size=16,
    shuffle=True,
)

In [4]:
class _ChannelNorm(nn.Module):
    """
       Layer normalisation over the channel axis of a (batch, channels, time) tensor.

       Every time step is normalised on its own, using only the values of that time
       step, and then scaled/shifted by the per-channel parameters weight and bias.
       Statistics are never taken along the time axis, so a causal layer stays causal.
       """
    def __init__(self, num_channels, eps=1e-5):
        super(_ChannelNorm, self).__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        var = (centered * centered).mean(1, keepdim=True)
        normed = centered / torch.sqrt(var + self.eps)
        return normed * self.weight.view(1, -1, 1) + self.bias.view(1, -1, 1)


class ResUnit(nn.Module):
    """
       Residual unit: norm - ReLU - 1x1 conv (halving the channels) - norm - ReLU -
       dilated conv - norm - ReLU - 1x1 conv (restoring the channels), added to the input.

       inputs:
          in_channels: number of input (and output) channels
          size:        kernel size of the dilated convolution
          dilation:    dilation of the dilated convolution
          causal:      if True, output position t only depends on input positions <= t
          in_ln:       whether to normalise the input before the first ReLU

       Input and output are (batch, in_channels, time) tensors of the same size
       (for the non-causal case this needs dilation*(size-1) to be even).

       The normalisation is done over the channels of each time step. It used to be
       nn.InstanceNorm1d, which computes its statistics along the time axis and thereby
       made every output position depend on all (including future) positions, even with
       causal=True. The attribute names ln1/ln2/ln3 and their weight/bias parameters
       are unchanged.
       """
    def __init__(self, in_channels, size=3, dilation=1, causal=False, in_ln=True):
        super(ResUnit, self).__init__()
        self.size = size
        self.dilation = dilation
        self.causal = causal
        self.in_ln = in_ln
        if self.in_ln:
            self.ln1 = _ChannelNorm(in_channels)
        self.conv_in = nn.Conv1d(in_channels, in_channels//2, 1)
        self.ln2 = _ChannelNorm(in_channels//2)
        # causal: pad the full receptive field on both sides, the surplus on the
        # right is cut off again in forward; non-causal: symmetric "same" padding
        self.conv_dilated = nn.Conv1d(in_channels//2, in_channels//2, size, dilation=self.dilation,
                                      padding=((dilation*(size-1)) if causal else (dilation*(size-1)//2)))
        self.ln3 = _ChannelNorm(in_channels//2)
        self.conv_out = nn.Conv1d(in_channels//2, in_channels, 1)

    def forward(self, inp):
        x = inp
        if self.in_ln:
            x = self.ln1(x)
        x = nn.functional.relu(x)
        x = nn.functional.relu(self.ln2(self.conv_in(x)))
        x = self.conv_dilated(x)
        if self.causal and self.size>1:
            # drop the outputs which looked at the right padding
            x = x[:,:,:-self.dilation*(self.size-1)]
        x = nn.functional.relu(self.ln3(x))
        x = self.conv_out(x)
        return x+inp

class ResBlock(nn.Sequential):
    """
       Five ResUnits in sequence with dilations 1, 2, 4, 8 and 16.
       The first unit skips its input normalisation (in_ln=False), the others keep it.

       inputs:
          in_channels: number of channels of every unit
          size:        kernel size of the dilated convolutions
          causal:      passed on to every unit
       """
    def __init__(self, in_channels, size=5, causal=False):
        dilation_rates = (1, 2, 4, 8, 16)
        units = [ResUnit(in_channels, size=size, dilation=rate, causal=causal, in_ln=(pos > 0))
                 for pos, rate in enumerate(dilation_rates)]
        super(ResBlock, self).__init__(*units)

class Encoder(nn.Sequential):
    """
       num_blocks non-causal ResBlocks (default kernel size) followed by
       ReLU - 1x1 convolution - ReLU. The number of channels stays in_channels throughout.
       """
    def __init__(self, in_channels, num_blocks = 3):
        layers = []
        for _ in range(num_blocks):
            layers.append(ResBlock(in_channels))
        layers.append(nn.ReLU())
        layers.append(nn.Conv1d(in_channels, in_channels, 1))
        layers.append(nn.ReLU())
        super(Encoder, self).__init__(*layers)
        

class Decoder(nn.Sequential):
    """
       num_blocks causal ResBlocks with kernel size 3 followed by a 1x1 convolution
       mapping the in_channels channels to num_chars output channels (one per character).
       """
    def __init__(self, in_channels, num_chars, num_blocks = 3):
        layers = []
        for _ in range(num_blocks):
            layers.append(ResBlock(in_channels, size=3, causal=True))
        layers.append(nn.Conv1d(in_channels, num_chars, 1))
        super(Decoder, self).__init__(*layers)


In [5]:
# Width of the embeddings and of the encoder: half of the value used in
# Kalchbrenner et al., section 6.
hidden_dim = 400
num_chars = len(corpus.char_to_idx)

# One embedding table per side; index 0 is the padding index.
src_embedder = nn.Embedding(
    num_embeddings=num_chars, embedding_dim=hidden_dim, padding_idx=0)
tgt_embedder = nn.Embedding(
    num_embeddings=num_chars, embedding_dim=hidden_dim, padding_idx=0)

# The decoder gets twice the width: the training loop feeds it the encoder output
# concatenated with the embedding of the previous target character.
encoder = Encoder(hidden_dim)
decoder = Decoder(2*hidden_dim, len(corpus.char_to_idx))

# Everything lives on the GPU.
src_embedder.cuda()
tgt_embedder.cuda()
encoder.cuda()
decoder.cuda()

# Cross entropy over the characters with weight 0 for class 0 (padding).
loss_class_weights = torch.ones(num_chars).cuda()
loss_class_weights[0] = 0.0
ce_loss = torch.nn.CrossEntropyLoss(weight=loss_class_weights)

# A single optimiser over the parameters of all four modules.
allparams = [p
             for m in (src_embedder, tgt_embedder, encoder, decoder)
             for p in m.parameters()]
optim = torch.optim.Adam(allparams, lr=1e-4, betas=(0.5,0.75))

In [6]:
# Training loop. Every print_steps steps it prints progress (epoch, step, current
# time, estimated finishing time, loss) and greedily translates one validation batch.
num_epochs = 300
last_time = time.time()
print_steps = 100
for epoch in range(num_epochs):
    for i,(src,tgt) in enumerate(train_loader):
        optim.zero_grad()
        # decoder input: the target shifted right by one position, starting with index 0
        tgt_prev = torch.cat([torch.LongTensor(tgt.size(0), 1).zero_(), tgt[:,:-1]], dim=1)
        src = Variable(src.cuda())
        tgt = Variable(tgt.cuda())
        tgt_prev = Variable(tgt_prev.cuda())

        # embeddings are (batch, time, channels), the convolutions want (batch, channels, time)
        src_emb = src_embedder(src).transpose(1,2)
        tgt_emb_prev = tgt_embedder(tgt_prev).transpose(1,2)

        enc = encoder(src_emb)
        enc_and_prev = torch.cat([enc, tgt_emb_prev], dim=1)

        dec = decoder(enc_and_prev).transpose(1,2).contiguous()

        dec_lin = dec.view(-1, num_chars)

        loss = ce_loss(dec_lin, tgt.view(-1)) # note this does not weight by length

        loss.backward()
        optim.step()

        if i % print_steps == 0:
            this_time = time.time()
            perstep = (this_time-last_time)/print_steps
            stepstogo = len(train_loader)*(num_epochs-epoch)-i
            target_time = time.localtime(this_time+stepstogo*perstep)
            last_time = this_time
            print (epoch, num_epochs, i, len(train_loader), time.strftime("%H:%M:%S",time.localtime()),
                   time.strftime("%H:%M:%S",target_time), loss.data[0])
            test,sol = next(iter(val_loader)) #corpus.encode_strs('Sind wir schon gut ?')
            test = Variable(test.cuda())
            # Run the source through the encoder exactly as in training (the sampling
            # used to feed the raw source embeddings to the decoder).
            enc = encoder(src_embedder(test).transpose(1,2))
            # the second half of the channels holds the previous-character embeddings,
            # all zero to start with and filled in step by step below
            enc_and_prev = torch.cat([enc, Variable(torch.zeros(enc.size()).cuda())],dim=1)
            seq_len = test.size(1)
            res = torch.LongTensor(test.size(0), seq_len).zero_()
            # Greedy decoding over the full sequence length. The position variable has
            # its own name, it must not reuse the step counter i of the training loop.
            for pos in range(seq_len):
                output = decoder(enc_and_prev)[:,:,pos]
                _,idx = output.max(1)
                res[:,pos] = idx.data
                if pos+1 < seq_len:
                    res_emb = tgt_embedder(idx)
                    enc_and_prev.data[:,hidden_dim:,pos+1] = res_emb.data
            # first five sources, their greedy translations and the reference translations
            print ('\n'.join( corpus.tensors_to_strs(*test.data[:5].cpu())
                              +corpus.tensors_to_strs(*res[:5])+corpus.tensors_to_strs(*sol[:5])))

0 300 0 183 21:02:17 21:07:59 5.128349304199219
Mit unseren Änderungsanträgen wollen wir die Beteiligung der Verbände stärken .
Daraus müssen wir dann aber auch Schlußfolgerungen ziehen .
Die Änderungsanträge 1 und 13 sind sprachliche Präzisierungen .
Der dritte Aspekt betrifft das Subsidiaritätsprinzip , auf das sich die Union stützt .
Ich möchte noch auf einen weiteren Aspekt eingehen .
 c.r   ja   t a B8aQ Pr
 o c  K l c uc aori  t e6ta o  l tieo Áo a . Br ¾o j .¾ r oBÁoy  mB ¾dK
 BGèi e oK e  r ra ie U ( .  e "  r B ore  
 j ma oo
 BKa  rq  
The aim of our amendments is to increase the participation of associations .
But we should draw the necessary conclusions from this .
Amendments Nos 1 and 3 concern linguistic nuances .
The third aspect concerns the principle of subsidiarity that underpins the European Union .
I will give you an additional piece of information .
0 300 100 183 21:02:42 00:52:50 1.7800085544586182
Wir müssen das gräßliche Übel des sexuellen Mißbrauchs von Kindern

KeyboardInterrupt: 