Skip to content

Commit

Permalink
Added translate script
Browse files Browse the repository at this point in the history
  • Loading branch information
Devendra Singh committed Mar 28, 2018
1 parent 2f01752 commit f08a53f
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 14 deletions.
5 changes: 3 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def get_train_args():
parser = ArgumentParser(description='Implementation of "Attention is All You Need" in Pytorch')
parser = ArgumentParser(description='Implementation of Transformer in Pytorch')

parser.add_argument('--input', '-i', type=str, default='./data/ja_en',
help='Input directory')
Expand All @@ -12,7 +12,7 @@ def get_train_args():
help='Print stats at this interval')

parser.add_argument('--model', type=str, default='Transformer',
help='Model Type to train (Trasformer / MultiTaskNMT)')
help='Model Type to train (Transformer / MultiTaskNMT)')

# Multilingual Options
parser.add_argument('--pshare_decoder_param', dest='pshare_decoder_param',
Expand Down Expand Up @@ -172,6 +172,7 @@ def get_translate_args():
help='path to save the best model')
parser.add_argument('--batchsize', type=int, default=60)
parser.add_argument('--beam_size', type=int, default=5)
parser.add_argument('--max_len', type=int, default=50)
parser.add_argument('--alpha', default=1.0, type=float,
help='Length Normalization coefficient')

Expand Down
24 changes: 24 additions & 0 deletions data_statistics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import sys
import numpy as np

infile = sys.argv[1]
max_len = int(sys.argv[2])

def read_data(loc_):
data = list()
with open(loc_) as fp:
for line in fp:
text = line.strip()
data.append(text)
return data


train_text = read_data(infile)
data_len = list(map(lambda x: len(x.split()), train_text))

print("Median words: {}".format(np.median(data_len)))
print("Mean words: {}".format(np.mean(data_len)))
print("Max words: {}".format(np.max(data_len)))
print("Min words: {}".format(np.min(data_len)))
print("Std words: {}".format(np.std(data_len)))
print("> {} words: {}".format(max_len, len(list(filter(lambda x: x > max_len, data_len)))))
52 changes: 52 additions & 0 deletions tools/bpe_translate.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!/usr/bin/env bash

TF=$(pwd)
export PATH=$PATH:$TF/bin

BPE_OPS=16000
GPUARG=0

L1=$1
L2=$2
L3=$3

DATA_L1=${TF}"/data/${L1}_${L2}"
DATA_L2=${TF}"/data/${L1}_${L3}"
NAME="run_${L1}_${L2}-${L3}"
OUT="temp/$NAME"

TEST_SRC_L1=$DATA_L1/test.${L1}
TEST_TGT_L1=$DATA_L1/test.${L2}

TEST_SRC_L2=$DATA_L2/test.${L1}
TEST_TGT_L2=$DATA_L2/test.${L3}

# Apply BPE Coding to the languages
apply_bpe -c $OUT/data/bpe-codes.${BPE_OPS} < ${TEST_SRC_L1} > ${OUT}/data/test_l1.src
apply_bpe -c $OUT/data/bpe-codes.${BPE_OPS} < ${TEST_SRC_L2} > ${OUT}/data/test_l2.src


# Translate Language 1
python translate.py -i $OUT/data --data processed --batchsize 28 --beam_size 5 \
--best_model_file $OUT/models/model_best_$NAME.ckpt --src $OUT/data/test_l1.src \
--gpu $GPUARG --output $OUT/test/test_l1.out


# Translate Language 2
python translate.py -i $OUT/data --data processed --batchsize 28 --beam_size 5 \
--best_model_file $OUT/models/model_best_$NAME.ckpt --src $OUT/data/test_l2.src \
--gpu $GPUARG --output $OUT/test/test_l2.out


mv $OUT/test/test_l1.out{,.bpe}
mv $OUT/test/test_l2.out{,.bpe}


cat $OUT/test/test_l1.out.bpe | sed -E 's/(@@ )|(@@ ?$)//g' > $OUT/test/test_l1.out
cat $OUT/test/test_l2.out.bpe | sed -E 's/(@@ )|(@@ ?$)//g' > $OUT/test/test_l2.out


perl tools/multi-bleu.perl $OUT/data/test_l1.tgt < $OUT/test/test_l1.out > $OUT/test/test_l1.tc.bleu
perl tools/multi-bleu.perl $OUT/data/test_l2.tgt < $OUT/test/test_l2.out > $OUT/test/test_l2.tc.bleu

# t2t-bleu --translation=$OUT/test/valid.out --reference=/storage/devendra/temp/wmt16_de_en/newstest2014.tok.de
26 changes: 14 additions & 12 deletions translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import torch
from tqdm import tqdm

import net
from models import MultiTaskNMT, Transformer
import utils
from torch.autograd import Variable
import preprocess
from train import save_output
from config import get_translate_args
Expand All @@ -29,15 +31,14 @@ def __call__(self):
hypotheses = []
for i in tqdm(range(0, len(self.test_data), self.batch)):
sources = self.test_data[i:i + self.batch]
if self.beam_size > 1:
ys = self.model.translate(sources,
self.max_length,
beam=self.beam_size,
alpha=self.alpha)
else:
ys = [y.tolist() for y in self.model.translate(sources,
self.max_length,
beam=False)]
x_block = utils.source_pad_concat_convert(sources,
device=None)
x_block = Variable(torch.LongTensor(x_block).type(utils.LONG_TYPE),
requires_grad=False)
ys = self.model.translate(x_block,
self.max_length,
beam=self.beam_size,
alpha=self.alpha)
hypotheses.extend(ys)
return hypotheses

Expand All @@ -62,7 +63,7 @@ def main():
checkpoint['epoch'],
checkpoint['best_score']))
config = checkpoint['opts']
model = net.Transformer(config)
model = eval(args.model)(config)
model.load_state_dict(checkpoint['state_dict'])

if args.gpu >= 0:
Expand All @@ -73,7 +74,8 @@ def main():
source_data,
batch=args.batchsize // 4,
beam_size=args.beam_size,
alpha=args.alpha)()
alpha=args.alpha,
max_length=args.max_len)()
save_output(hyp, id2w, args.output)


Expand Down

0 comments on commit f08a53f

Please sign in to comment.