Data preprocessing for fairseq

In [1]:
from collections import Counter
from itertools import zip_longest
import os
import shutil


from fairseq.data import indexed_dataset
from fairseq.tasks.span_bert import BertDictionary
from fairseq.tokenizer import Tokenizer, tokenize_line
from multiprocessing import Pool, Manager, Process

In [2]:
def main(args):
    """Build the dictionaries and write the preprocessed data to ``args.destdir``.

    Steps, in order:

    1. Build (from the training files) or load the source dictionary, and the
       target dictionary unless ``args.only_source`` is set; finalize them and
       save them as ``dict[.<lang>].txt`` inside ``args.destdir``.
    2. For every configured split (``trainpref`` plus the comma-separated
       ``validpref`` / ``testpref`` entries) write the dataset in
       ``args.output_format``: ``'binary'`` builds an indexed dataset,
       ``'raw'`` copies the text file. Any other value writes nothing.
    3. If ``args.alignfile`` is set, write ``alignment.<src>-<tgt>.txt`` that
       maps each source token to its most frequently aligned target token.

    Args:
        args: object exposing the options as attributes (``destdir``,
            ``only_source``, ``joined_dictionary``, ``srcdict``, ``tgtdict``,
            ``trainpref``, ``validpref``, ``testpref``, ``source_lang``,
            ``target_lang``, ``thresholdsrc``, ``thresholdtgt``, ``nwordssrc``,
            ``nwordstgt``, ``padding_factor``, ``workers``, ``output_format``,
            ``alignfile``).

    Raises:
        AssertionError: on incompatible or missing dictionary/trainpref options.
        ValueError: if the alignment, source and target files differ in length.
    """
    print(args)
    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    def build_dictionary(filenames):
        """Count the tokens of every file in ``filenames`` into a new dictionary."""
        d = BertDictionary()
        for filename in filenames:
            Tokenizer.add_file_to_dictionary(filename, d, tokenize_line, args.workers)
        return d

    def train_path(lang):
        """Return the training file path, with ``.<lang>`` appended when lang is truthy."""
        return '{}{}'.format(args.trainpref, ('.' + lang) if lang else '')

    def file_name(prefix, lang):
        """Return ``prefix`` with ``.<lang>`` appended unless lang is None."""
        fname = prefix
        if lang is not None:
            fname += f'.{lang}'
        return fname

    def dest_path(prefix, lang):
        """Return the path of ``file_name(prefix, lang)`` inside the destination dir."""
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        """Return the path the dictionary for ``lang`` is saved to."""
        return dest_path('dict', lang) + '.txt'

    if args.joined_dictionary:
        assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
        assert not args.tgtdict, 'cannot combine --tgtdict and --joined-dictionary'
        # A set avoids counting the same file twice when both languages
        # resolve to the same training path.
        src_dict = build_dictionary({
            train_path(lang)
            for lang in [args.source_lang, args.target_lang]
        })
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = BertDictionary.load(args.srcdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)])
        if target:
            if args.tgtdict:
                tgt_dict = BertDictionary.load(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)])

    src_dict.finalize(
        threshold=args.thresholdsrc,
        nwords=args.nwordssrc,
        padding_factor=args.padding_factor,
    )
    src_dict.save(dict_path(args.source_lang))
    if target:
        # With a joined dictionary tgt_dict is src_dict, already finalized above.
        if not args.joined_dictionary:
            tgt_dict.finalize(
                threshold=args.thresholdtgt,
                nwords=args.nwordstgt,
                padding_factor=args.padding_factor,
            )
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
        """Binarize one text file into an indexed dataset, optionally in parallel."""
        vocab = BertDictionary.load(dict_path(lang))
        print('| [{}] Dictionary: {} types'.format(lang, len(vocab)))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            # Accumulate the statistics reported by one binarization run.
            replaced.update(worker_result['replaced'])
            n_seq_tok[0] += worker_result['nseq']
            n_seq_tok[1] += worker_result['ntok']

        input_file = '{}{}'.format(input_prefix, ('.' + lang) if lang is not None else '')
        offsets = Tokenizer.find_offsets(input_file, num_workers)
        pool = None
        pending = []
        if num_workers > 1:
            # Chunk 0 is handled by this process; chunks 1..n-1 go to the pool
            # and are written to temporary per-worker files.
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pending.append(pool.apply_async(
                    binarize,
                    (args, input_file, vocab, prefix, lang,
                     offsets[worker_id], offsets[worker_id + 1]),
                    callback=merge_result,
                ))
            pool.close()

        ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_file(args, output_prefix, lang, 'bin'))
        merge_result(Tokenizer.binarize(input_file, vocab, lambda t: ds.add_item(t),
                                        offset=0, end=offsets[1]))
        if num_workers > 1:
            pool.join()
            # apply_async swallows worker exceptions unless the result is
            # fetched; re-raise them here instead of merging incomplete files.
            for async_result in pending:
                async_result.get()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, 'idx'))

        # An empty input file yields zero tokens; report 0% instead of dividing by zero.
        if n_seq_tok[1]:
            replaced_pct = 100 * sum(replaced.values()) / n_seq_tok[1]
        else:
            replaced_pct = 0.0
        print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
            lang, input_file, n_seq_tok[0], n_seq_tok[1],
            replaced_pct, vocab.unk_word))

    def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
        """Write one split in the format selected by ``args.output_format``."""
        if args.output_format == 'binary':
            make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
        elif args.output_format == 'raw':
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + '.{}-{}'.format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_all(lang):
        """Write the train/valid/test splits configured in ``args`` for ``lang``."""
        if args.trainpref:
            make_dataset(args.trainpref, 'train', lang, num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(',')):
                outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
                make_dataset(validpref, outprefix, lang)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(',')):
                outprefix = 'test{}'.format(k) if k > 0 else 'test'
                make_dataset(testpref, outprefix, lang)

    make_all(args.source_lang)
    if target:
        make_all(args.target_lang)

    print('| Wrote preprocessed data to {}'.format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        # Reload with the same class that saved the files above (the previous
        # `dictionary.Dictionary` referenced a name that is never imported).
        src_dict = BertDictionary.load(dict_path(args.source_lang))
        tgt_dict = BertDictionary.load(dict_path(args.target_lang))
        freq_map = {}
        with open(args.alignfile, 'r') as align_file:
            with open(src_file_name, 'r') as src_file:
                with open(tgt_file_name, 'r') as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        # zip_longest pads the shorter files with None.
                        if a is None or s is None or t is None:
                            raise ValueError(
                                'alignment, source and target files must have '
                                'the same number of lines'
                            )
                        si = Tokenizer.tokenize(s, src_dict, add_if_not_exist=False)
                        ti = Tokenizer.tokenize(t, tgt_dict, add_if_not_exist=False)
                        # Each alignment entry has the form "<src_pos>-<tgt_pos>".
                        ai = list(map(lambda x: tuple(x.split('-')), a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        # Keep, for every source index, the target index it was aligned to most often.
        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)

        with open(os.path.join(args.destdir, 'alignment.{}-{}.txt'.format(
                args.source_lang, args.target_lang)), 'w') as f:
            for k, v in align_dict.items():
                print('{} {}'.format(src_dict[k], tgt_dict[v]), file=f)



def binarize(args, filename, dict, output_prefix, lang, offset, end):
    """Binarize the ``offset``..``end`` range of ``filename`` into its own dataset.

    The items are written to the ``.bin`` file derived from ``output_prefix``
    and ``lang``, and the matching ``.idx`` file is written on completion.
    Returns whatever ``Tokenizer.binarize`` returns.
    """
    bin_path = dataset_dest_file(args, output_prefix, lang, 'bin')
    idx_path = dataset_dest_file(args, output_prefix, lang, 'idx')
    builder = indexed_dataset.IndexedDatasetBuilder(bin_path)

    stats = Tokenizer.binarize(
        filename,
        dict,
        lambda tensor: builder.add_item(tensor),
        offset=offset,
        end=end,
    )
    builder.finalize(idx_path)
    return stats

def dataset_dest_prefix(args, output_prefix, lang):
    """Return the extension-less destination path of a dataset.

    The result is ``<destdir>/<output_prefix>``; when ``lang`` is not None,
    ``.<source_lang>-<target_lang>.<lang>`` is appended.
    """
    stem = '{}/{}'.format(args.destdir, output_prefix)
    if lang is None:
        return stem
    return '{}.{}-{}.{}'.format(stem, args.source_lang, args.target_lang, lang)


def dataset_dest_file(args, output_prefix, lang, extension):
    """Return the dataset destination path with ``.<extension>`` appended."""
    stem = dataset_dest_prefix(args, output_prefix, lang)
    return '{}.{}'.format(stem, extension)

In [4]:
class DictX(dict):
    """A ``dict`` whose keys can also be read, set and deleted as attributes."""

    def __getattr__(self, key):
        """Return ``self[key]``; a missing key surfaces as AttributeError."""
        try:
            value = self[key]
        except KeyError as missing:
            raise AttributeError(missing)
        return value

    def __setattr__(self, key, value):
        """Store ``value`` under ``key`` as a dictionary item."""
        self[key] = value

    def __delattr__(self, key):
        """Delete ``self[key]``; a missing key surfaces as AttributeError."""
        try:
            del self[key]
        except KeyError as missing:
            raise AttributeError(missing)

    def __repr__(self):
        """Return the plain dict representation wrapped in ``<DictX ...>``."""
        return f'<DictX {dict.__repr__(self)}>'
    
# Preprocessing options consumed by main(); DictX exposes each key as an attribute.
args = DictX({
    'source_lang': None,
    'target_lang': None,
    'trainpref': '../data/train_tokened_corpus.txt',
    'validpref': '../data/valid_tokened_corpus.txt',
    'testpref': '../data/test_tokened_corpus.txt',
    'destdir': './train_data/',
    # main() reads args.thresholdtgt; the key was previously misspelled
    # 'thresholdtg', which would fail as soon as only_source is False.
    'thresholdtgt': 0,
    'thresholdsrc': 0,
    'tgtdict': None,
    'srcdict': None,
    'nwordstgt': -1,
    'nwordssrc': -1,  # number of source words to retain
    'alignfile': None,
    'output_format': 'binary',
    'joined_dictionary': False,
    'only_source': True,  # main() skips the target dictionary and datasets
    'padding_factor': 1,
    'workers': 48,  # worker count used for the training split
})

In [5]:
# Run the preprocessing with the configuration held in `args`.
main(args)

<DictX {'source_lang': None, 'target_lang': None, 'trainpref': '../data/train_tokened_corpus.txt', 'validpref': '../data/valid_tokened_corpus.txt', 'testpref': '../data/test_tokened_corpus.txt', 'destdir': './train_data/', 'thresholdtg': 0, 'thresholdsrc': 0, 'tgtdict': None, 'srcdict': None, 'nwordstgt': -1, 'nwordssrc': -1, 'alignfile': None, 'output_format': 'binary', 'joined_dictionary': False, 'only_source': True, 'padding_factor': 1, 'workers': 48}>
| [None] Dictionary: 155769 types
| [None] ../data/train_tokened_corpus.txt: 52434944 sents, 2428193651 tokens, 0.0% replaced by [UNK]
| [None] Dictionary: 155769 types
| [None] ../data/valid_tokened_corpus.txt: 6554368 sents, 303234244 tokens, 4.12e-05% replaced by [UNK]
| [None] Dictionary: 155769 types
| [None] ../data/test_tokened_corpus.txt: 6554376 sents, 303655889 tokens, 0.000104% replaced by [UNK]
| Wrote preprocessed data to ./train_data/
