In [1]:
from __future__ import division

import onmt
import argparse
import torch
import torch.nn as nn
import math
import time
import torch.optim as optimizer
from torch import cuda
from torch.autograd import Variable

from onmt.modules.discriminator import Discriminator
from onmt.modules.gradient_reversal import ReverseLayer

In [2]:
from argparse import Namespace

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
# Debug switch. NOTE(review): no cell visible in this notebook reads it.
debug = False

#### Load Environment Variables

In [5]:
# Hyper-parameters and paths for this run, one option per line.
opt = Namespace(
    adapt=True,
    batch_size=64,
    brnn=False,
    brnn_merge='concat',
    curriculum=False,
    data='data/multi30k.atok-train.pt',
    dropout=0.3,
    epochs=100,
    extra_shuffle=False,
    gpus=[1],
    input_feed=1,
    layers=2,
    learning_rate=0.01,
    learning_rate_decay=0.5,
    log_interval=50,
    max_generator_batches=32,
    max_grad_norm=5,
    optim='sgd',
    param_init=0.1,
    pre_word_vecs_dec=None,
    pre_word_vecs_enc=None,
    rnn_size=500,
    save_model='model',
    start_decay_at=8,
    start_epoch=1,
    train_from='',
    train_from_state_dict='',
    word_vec_size=500,
)

# Number of configured GPUs; later cells test it for truthiness before
# moving tensors and criteria to CUDA.
opt.cuda = len(opt.gpus)

#### Set up the CUDA Environment

In [6]:
# Make the first configured GPU the default CUDA device. Skipped when no GPU
# is configured, so an empty opt.gpus no longer raises IndexError here and the
# CPU branch used when the model is built remains reachable.
if opt.gpus:
    cuda.set_device(opt.gpus[0])

#### Load Dataset

In [7]:
print("Loading data from '%s'" % opt.data)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only point opt.data at trusted, locally produced datasets.
dataset = torch.load(opt.data)

# type(dataset) = <type 'dict'>
# Parallel (source, target) corpora used for translation training/validation.
trainData = onmt.Dataset(dataset['train']['src'],
                         dataset['train']['tgt'], opt.batch_size, opt.cuda)
validData = onmt.Dataset(dataset['valid']['src'],
                         dataset['valid']['tgt'], opt.batch_size, opt.cuda)

# Source-only corpora for domain adaptation; both stay None when
# opt.adapt is off.
domain_train = None
domain_valid = None
if opt.adapt:
    # Explicit check instead of a bare assert: it survives `python -O` and
    # names the splits that are missing from the file.
    missing_splits = sorted(set(['domain_train', 'domain_valid']) - set(dataset))
    if missing_splits:
        raise KeyError("opt.adapt is set but '%s' has no split(s): %s"
                       % (opt.data, ", ".join(missing_splits)))
    domain_train = onmt.Dataset(dataset['domain_train']['src'], None,
                              opt.batch_size, opt.cuda)
    domain_valid = onmt.Dataset(dataset['domain_valid']['src'], None,
                              opt.batch_size, opt.cuda)


dicts = dataset['dicts']
print(' * vocabulary size. source = %d; target = %d' %
      (dicts['src'].size(), dicts['tgt'].size()))
print(' * number of training sentences. %d' %
      len(dataset['train']['src']))
print(' * maximum batch size. %d' % opt.batch_size)

Loading data from 'data/multi30k.atok-train.pt'
 * vocabulary size. source = 10843; target = 18562
 * number of training sentences. 28997
 * maximum batch size. 64


#### Dataset exploration

In [8]:
def lookup_src(x):
    return dicts['src'].idxToLabel[x]

def lookup_tgt(x):
    return dicts['tgt'].idxToLabel[x]

for a in trainData.src:
    a_list = a.numpy().tolist()
    print "--> Len: " + str(len(a_list))
    print "--> Src Sentence: " + str(" ".join(map(lookup_src, a_list)))
    print "--> Tgt Sentence: " + str(" ".join(map(lookup_tgt, a_list))) + "\n"
    break

--> Len: 4
--> Src Sentence: Men play baseball .
--> Tgt Sentence: Papier Instrumenten Meerwasser vieler



#### Define the Network

In [9]:
print('Building model...')

encoder = onmt.DomainModels.Encoder(opt, dicts['src'])
decoder = onmt.DomainModels.Decoder(opt, dicts['tgt'])

# Maps rnn_size features to one log-probability per target-vocabulary entry.
generator = nn.Sequential(
        nn.Linear(opt.rnn_size, dicts['tgt'].size()),
        nn.LogSoftmax())

# Bound up front so the NMTModel call below does not raise NameError when
# opt.adapt is off.
# NOTE(review): confirm that NMTModel accepts discriminator=None.
discriminator = None
if opt.adapt:
    discriminator = Discriminator(opt.word_vec_size  * opt.layers)
    # NOTE(review): gradient_reversal is not referenced by any later cell
    # visible in this notebook.
    gradient_reversal = ReverseLayer()

model = onmt.DomainModels.NMTModel(encoder, decoder, discriminator)

if len(opt.gpus) >= 1:
    model.cuda()
    generator.cuda()
else:
    model.cpu()
    generator.cpu()

if len(opt.gpus) > 1:
    model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
    generator = nn.DataParallel(generator, device_ids=opt.gpus, dim=0)

# Attached as an attribute so the training and evaluation code can reach the
# generator through the model object.
model.generator = generator

Building model...


#### Define a Loss function and optimizer

In [10]:
def memoryEfficientLoss(outputs, targets, generator, crit, eval=False):
    """Run the generator and criterion over ``outputs`` in chunks.

    ``outputs`` is re-wrapped in a new Variable built from ``outputs.data``;
    the gradient accumulated on that Variable is what gets returned, so the
    caller can continue back-propagation with ``outputs.backward(grad)``.

    Args:
        outputs: model outputs; dimension 1 is read as the batch size and
            dimension 2 as the feature size.
        targets: target indices, split along dimension 0 in step with
            ``outputs``.
        generator: module applied to each flattened chunk to produce scores.
        crit: criterion called as ``crit(scores, flat_targets)``.
        eval: when true the new Variable is volatile and no backward pass is
            run. (The name shadows the builtin and the module-level ``eval``
            function inside this body only.)

    Returns:
        ``(loss, grad_output, num_correct)``: the summed criterion value, the
        gradient with respect to ``outputs`` (``None`` if none was
        accumulated), and the number of non-PAD positions whose top-scoring
        prediction equals the target.
    """
    # compute generations one piece at a time
    num_correct, loss = 0, 0
    outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)

    batch_size = outputs.size(1)
    # Chunks of at most opt.max_generator_batches slices along dimension 0.
    outputs_split = torch.split(outputs, opt.max_generator_batches)
    targets_split = torch.split(targets, opt.max_generator_batches)
    for i, (out_t, targ_t) in enumerate(zip(outputs_split, targets_split)):
        # Flatten the chunk to 2-D before handing it to the generator.
        out_t = out_t.view(-1, out_t.size(2))
        scores_t = generator(out_t)
        loss_t = crit(scores_t, targ_t.view(-1))
        # Index of the highest score along dimension 1.
        pred_t = scores_t.max(1)[1]
        # Count matches, ignoring positions whose target is PAD.
        num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(onmt.Constants.PAD).data).sum()
        num_correct += num_correct_t
        loss += loss_t.data[0]
        if not eval:
            # Back-propagate this chunk right away; the gradient that is
            # propagated is divided by the batch size, the reported loss is not.
            loss_t.div(batch_size).backward()

    grad_output = None if outputs.grad is None else outputs.grad.data
    return loss, grad_output, num_correct

#### Test Network

In [11]:
def eval(model, criterion, data):
    """Score ``model`` on every batch of ``data`` without back-propagation.

    The model is put in eval mode for the pass and switched back to train
    mode before returning.

    Returns:
        ``(loss_per_word, accuracy)``, both normalised by the number of
        non-PAD target tokens.
    """
    loss_sum, correct_sum, word_count = 0, 0, 0

    model.eval()
    for idx in range(len(data)):
        # The last element of a dataset item holds the original indices.
        batch = data[idx][:-1]
        outputs = model(batch)
        # Drop the leading <s> token from the targets.
        targets = batch[1][1:]
        batch_loss, _, batch_correct = memoryEfficientLoss(
            outputs, targets, model.generator, criterion, eval=True)
        loss_sum += batch_loss
        correct_sum += batch_correct
        word_count += targets.data.ne(onmt.Constants.PAD).sum()
    model.train()

    words = float(word_count)
    return float(loss_sum) / words, float(correct_sum) / words

def domain_eval(model, data_old, data_new):
    """Return the discriminator's accuracy at telling two datasets apart.

    Batch ``i`` of ``data_old`` is paired with batch ``i`` of ``data_new``,
    up to the length of the shorter dataset. Old-domain outputs are labelled
    1.0 and new-domain outputs 0.0, and predictions are scored with
    ``get_accuracy``.

    The model is put in eval mode for the pass and switched back to train
    mode before returning, so a training epoch that follows this call runs
    with training-mode behaviour again.

    Args:
        model: called as ``model(batch_old, domain_batch=batch_new)`` and
            expected to return ``(outputs, old_domain, new_domain)``.
        data_old: dataset of the original domain.
        data_new: dataset of the new domain.

    Returns:
        Fraction of discriminator outputs classified correctly, as a float.

    Raises:
        ZeroDivisionError: if either dataset has no batches.
    """
    model.eval()
    total_num_discrim_correct, total_num_discrim_elements = 0, 0
    for i in range(min(len(data_new),len(data_old))):
        batch_old = data_old[i][:-1] # exclude original indices
        batch_new = data_new[i][:-1]

        _, old_domain, new_domain = model(batch_old, domain_batch=batch_new)

        tgts = Variable(torch.FloatTensor(len(old_domain) + len(new_domain),), requires_grad=False)

        if opt.cuda:
            tgts = tgts.cuda()

        # Old-domain entries come first in the concatenation and get label 1;
        # the remaining new-domain entries keep label 0.
        tgts[:] = 0.0
        tgts[:len(old_domain)] = 1.0
        discrim_correct, num_discrim_elements = get_accuracy(torch.cat([old_domain, new_domain]).data.squeeze(), tgts.data)

        # Discriminator counts
        total_num_discrim_correct += discrim_correct
        total_num_discrim_elements += num_discrim_elements

    # Leave the model the way eval() does: back in training mode.
    model.train()
    return float(total_num_discrim_correct) / float(total_num_discrim_elements)

def get_accuracy(prediction, truth):
    """Count thresholded predictions that match ``truth``.

    Scores ``>= 0.5`` count as class 1.0, everything else as class 0.0. The
    thresholding is done on a new tensor, so ``prediction`` is not modified.

    Args:
        prediction: tensor of scores.
        truth: tensor of 0.0/1.0 labels with the same number of elements.

    Returns:
        ``(num_correct, num_elements)`` where ``num_elements`` is a float.
    """
    assert(prediction.nelement() == truth.nelement())
    # Build the hard decisions separately instead of overwriting the input.
    hard_prediction = (prediction >= 0.5).type_as(prediction)
    return hard_prediction.eq(truth).sum(), float(prediction.nelement())

#### Train the network

In [12]:
def NMTCriterion(vocabSize):
    """Build a summed NLL criterion over ``vocabSize`` classes.

    The PAD class gets weight 0 so it adds nothing to the loss. The criterion
    is moved to CUDA when ``opt.cuda`` is truthy.
    """
    class_weights = torch.ones(vocabSize)
    class_weights[onmt.Constants.PAD] = 0
    criterion = nn.NLLLoss(class_weights, size_average=False)
    if not opt.cuda:
        return criterion
    criterion.cuda()
    return criterion

In [13]:
def sentences (batch,domain_batch):
    for old_sentence_src, old_sentence_tgt, new_sentence in zip(batch[0],batch[1],domain_batch) :
        old_sentence_src = [dataset['dicts']['src'].idxToLabel[x] for x in old_sentence_src.data]
        old_sentence_src = " ".join(old_sentence_src)
        print "old sentence src: ", old_sentence_src 
                    
        old_sentence_tgt = [dataset['dicts']['tgt'].idxToLabel[x] for x in old_sentence_tgt.data]
        old_sentence_tgt = " ".join(old_sentence_tgt)
        print "old sentence tgt: ", old_sentence_tgt 

        new_sentence = [dataset['dicts']['src'].idxToLabel[x] for x in new_sentence.data]
        new_sentence = " ".join(new_sentence)
        print "\nnew sentence: ", new_sentence
                    
        print "-------------------"

In [14]:
def trainModel(model, trainData, validData, domain_train, domain_valid, dataset, optim):
    print(model)
    model.train()

    # define criterion of each GPU
    criterion = NMTCriterion(dataset['dicts']['tgt'].size())

    start_time = time.time()
    def trainEpoch(epoch):

        if opt.extra_shuffle and epoch > opt.curriculum:
            trainData.shuffle()

        # shuffle mini batch order
        batchOrder = torch.randperm(len(trainData))

        discriminator_criterion = None
        if opt.adapt:
            batchOrderAdapt = torch.randperm(len(domain_train))
            discriminator_criterion = nn.BCELoss()

        total_num_discrim_correct, total_num_discrim_elements = 0, 0
        total_loss, total_words, total_num_correct = 0, 0, 0
        report_loss, report_tgt_words, report_src_words, report_num_correct = 0, 0, 0, 0
        start = time.time()
        for i in range(len(trainData)):

            batchIdx = batchOrder[i] if epoch > opt.curriculum else i
            batch = trainData[batchIdx][:-1] # exclude original indices

            model.zero_grad()
            if opt.adapt:
                batchIdxAdapt = batchOrderAdapt[i] if epoch >= opt.curriculum else i
                batch_len = len(batch[0][1])
                domain_batch = domain_train[batchIdxAdapt][:-1]

                outputs, old_domain, new_domain = model(batch, domain_batch=domain_batch)
                discriminator_targets = Variable(torch.FloatTensor(len(old_domain) + len(new_domain),), requires_grad=False)

                if opt.cuda:
                    discriminator_targets = discriminator_targets.cuda()

                discriminator_targets[:] = 0.0
                discriminator_targets[:len(old_domain)] = 1.0
                discrim_correct, num_discrim_elements = get_accuracy(torch.cat([old_domain, new_domain]).data.squeeze(), discriminator_targets.data)


                discriminator_loss = discriminator_criterion(torch.cat([old_domain, new_domain]), discriminator_targets)
            else:
                outputs = model(batch)

            targets = batch[1][1:]  # exclude <s> from targets
            loss, gradOutput, num_correct = memoryEfficientLoss(
                    outputs, targets, model.generator, criterion)

            # We do the domain adaptation backward call here
            if opt.adapt:
                model.zero_grad()
                outputs.backward(gradOutput)
            else:
                outputs.backward(gradOutput)


            # update the parameters
            optim.step()

            num_words = targets.data.ne(onmt.Constants.PAD).sum()
            report_loss += loss
            report_num_correct += num_correct
            report_tgt_words += num_words
            report_src_words += sum(batch[0][1])
            total_loss += loss
            total_num_correct += num_correct
            total_words += num_words

            # Discriminator counts
            if opt.adapt:
                total_num_discrim_correct += discrim_correct
                total_num_discrim_elements += num_discrim_elements

            if i % opt.log_interval == -1 % opt.log_interval:
                print("Epoch %2d, %5d/%5d; acc: %6.2f; ppl: %6.2f; %3.0f src tok/s; %3.0f tgt tok/s; %6.0f s elapsed" %
                      (epoch, i+1, len(trainData),
                      report_num_correct / report_tgt_words * 100,
                      math.exp(report_loss / report_tgt_words),
                      report_src_words/(time.time()-start),
                      report_tgt_words/(time.time()-start),
                      time.time()-start_time))

                if opt.adapt:
                    print "discrim_correct: ", discrim_correct
                    print "num_discrim_elements: ", num_discrim_elements, '\n'

                report_loss = report_tgt_words = report_src_words = report_num_correct = 0
                start = time.time()

        if opt.adapt:
            return total_loss / total_words, total_num_correct / total_words, total_num_discrim_correct / total_num_discrim_elements
        else:
            return total_loss / total_words, total_num_correct / total_words, 0
        
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        #  (1) train for one epoch on the training set
        train_loss, train_acc, train_discrim_acc = trainEpoch(epoch)
        print('Train perplexity: %g' % math.exp(min(train_loss, 100)))
        print('Train accuracy: %g' % (train_acc*100))
        print('Train discriminator accuracy: %g' % (train_discrim_acc * 100))

        #  (2) evaluate on the validation set
        valid_loss, valid_acc = eval(model, criterion, validData)
        valid_ppl = math.exp(min(valid_loss, 100))
        print('Validation perplexity: %g' % valid_ppl)
        print('Validation accuracy: %g' % (valid_acc*100))
        if opt.adapt:
            valid_discrim_acc = domain_eval(model, validData, domain_valid)
            print('Validation discriminator accuracy: %g' % (valid_discrim_acc * 100))
        #  (3) update the learning rate
        optim.updateLearningRate(valid_loss, epoch)

        model_state_dict = model.module.state_dict() if len(opt.gpus) > 1 else model.state_dict()
        model_state_dict = {k: v for k, v in model_state_dict.items() if 'generator' not in k}
        generator_state_dict = model.generator.module.state_dict() if len(opt.gpus) > 1 else model.generator.state_dict()
        #  (4) drop a checkpoint
        checkpoint = {
            'model': model_state_dict,
            'generator': generator_state_dict,
            'dicts': dataset['dicts'],
            'opt': opt,
            'epoch': epoch,
            'optim': optim
        }
        torch.save(checkpoint,
                   '%s_acc_%.2f_ppl_%.2f_e%d.pt' % (opt.save_model, 100*valid_acc, valid_ppl, epoch))

In [15]:
# Draw every parameter uniformly from [-param_init, param_init].
for p in model.parameters():
    p.data.uniform_(-opt.param_init, opt.param_init)

# Runs after the random initialisation above, so anything these calls load is
# not overwritten by it.
encoder.load_pretrained_vectors(opt)
decoder.load_pretrained_vectors(opt)

optim = onmt.Optim(opt.optim,
                   opt.learning_rate,
                   opt.max_grad_norm,
                   lr_decay=opt.learning_rate_decay,
                   start_decay_at=opt.start_decay_at)
optim.set_parameters(model.parameters())

# Total number of scalar weights handed to the optimizer.
nParams = sum(p.nelement() for p in model.parameters())
print('* number of parameters: %d' % nParams)

trainModel(model, trainData, validData, domain_train, domain_valid, dataset, optim)

* number of parameters: 49777064
NMTModel (
  (encoder): Encoder (
    (word_lut): Embedding(10843, 500, padding_idx=0)
    (rnn): LSTM(500, 500, num_layers=2, dropout=0.3)
  )
  (decoder): Decoder (
    (word_lut): Embedding(18562, 500, padding_idx=0)
    (rnn): StackedLSTM (
      (dropout): Dropout (p = 0.3)
      (layers): ModuleList (
        (0): LSTMCell(1000, 500)
        (1): LSTMCell(500, 500)
      )
    )
    (attn): GlobalAttention (
      (linear_in): Linear (500 -> 500)
      (sm): Softmax ()
      (linear_out): Linear (1000 -> 500)
      (tanh): Tanh ()
    )
    (dropout): Dropout (p = 0.3)
  )
  (discriminator): Discriminator (
    (lin1): Linear (1000 -> 1)
    (lin2): Linear (4000 -> 4000)
    (lin3): Linear (4000 -> 1)
  )
  (generator): Sequential (
    (0): Linear (500 -> 18562)
    (1): LogSoftmax ()
  )
)

Epoch  1,    50/  454; acc:   7.72; ppl: 8349.17; 4562 src tok/s; 4727 tgt tok/s;      9 s elapsed
discrim_correct:  64
num_discrim_elements:  128.0 

Epoch 