In [1]:
import re
import os
import time
import pickle

import torch
from tqdm import tqdm
from transformers import AutoModelWithLMHead, AutoTokenizer

### 导入 中 - 英 翻译模型

In [2]:
model_zh_en = AutoModelWithLMHead.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
tokenizer_zh_en = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
model_en_zh = AutoModelWithLMHead.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
tokenizer_en_zh = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")



### 导入 中 - 德 翻译模型

In [3]:
model_zh_de = AutoModelWithLMHead.from_pretrained("Helsinki-NLP/opus-mt-zh-de")
tokenizer_zh_de = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-de")
model_de_zh = AutoModelWithLMHead.from_pretrained("Helsinki-NLP/opus-mt-de-ZH")
tokenizer_de_zh = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-ZH")

### 使用GPU加速

In [4]:
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
model_zh_en.to(device)
model_en_zh.to(device)
model_zh_de.to(device)
model_de_zh.to(device)

MarianMTModel(
  (model): MarianModel(
    (shared): Embedding(61916, 512, padding_idx=61915)
    (encoder): MarianEncoder(
      (embed_tokens): Embedding(61916, 512, padding_idx=61915)
      (embed_positions): MarianSinusoidalPositionalEmbedding(512, 512)
      (layers): ModuleList(
        (0): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
   

In [48]:
def batch_translation(texts, tokenizer_zh_fg, model_zh_fg, tokenizer_fg_zh, model_fg_zh, max_length=200):
    """
    批量文本翻译
    """
    t = time.time()
    fg_texts = []
    zh_texts = []
    # 中文转外语
    embedded_zh = tokenizer_zh_fg.batch_encode_plus(texts, padding=True)['input_ids']
    embedded_zh = torch.tensor(embedded_zh).to(device)
    embedded_zh = model_zh_fg.generate(embedded_zh, max_length=max_length, num_beams=4, early_stopping=True)
    for i in range(embedded_zh.size()[0]):
        fg_texts.append(re.sub('<pad>', '', tokenizer_zh_fg.decode(embedded_zh[i])))

    # 外语转英文
    embedded_fg = tokenizer_fg_zh.batch_encode_plus(fg_texts, padding=True)['input_ids']
    embedded_fg = torch.tensor(embedded_fg).to(device)
    embedded_fg = model_fg_zh.generate(embedded_fg, max_length=max_length,
                                       num_beams=4, early_stopping=True)
    for i in range(embedded_fg.size()[0]):
        zh_texts.append(re.sub('<pad>| ', '', tokenizer_fg_zh.decode(embedded_fg[i])))

    assert len(texts) == len(zh_texts)

#     print(time.time() - t)

    return zh_texts

### 读取训练数据

In [12]:
data_dir = os.path.join(os.getcwd(), 'tcdata')
files = {
    'bq_corpus': os.path.join(data_dir, 'bq_corpus', 'train.tsv'),
    'lcqmc': os.path.join(data_dir, 'lcqmc', 'train.tsv'),
    'paws-x-zh': os.path.join(data_dir, 'paws-x-zh', 'train.tsv')
}

In [19]:
data = {}
for k, n in files.items():
    try:
        fr = open(n, 'r', encoding='utf8')
    except:
        fr = open(n, 'r', encoding='gbk')
    tmp = []
    for line in fr.readlines():
        tmp.append(line.strip().split('\t'))
    data[k] = tmp

/home/xiaobu-semantic-matching-2021-master/tcdata/bq_corpus/train.tsv
/home/xiaobu-semantic-matching-2021-master/tcdata/lcqmc/train.tsv
/home/xiaobu-semantic-matching-2021-master/tcdata/paws-x-zh/train.tsv


### 翻译数据

In [None]:
batch_size = 100
for k, d in data.items():
    fw_ze = open(os.path.join(data_dir, k, 'zh_en_aug_train.tsv'), 'a', encoding='utf8')
    fw_zd = open(os.path.join(data_dir, k, 'zh_de_aug_train.tsv'), 'a', encoding='utf8')
    n = int(len(d) / batch_size) + 1
    for i in tqdm(range(n)):
        texts = d[i * batch_size: (i + 1) * batch_size]
        texts_a, texts_b, labels = zip(*texts)
        if len(texts) == 0:
            break
        
        zh_en_texts_a = batch_translation(texts_a, tokenizer_zh_en, model_zh_en, tokenizer_en_zh, model_en_zh)
        zh_de_texts_a = batch_translation(texts_a, tokenizer_zh_de, model_zh_de, tokenizer_de_zh, model_de_zh)
        zh_en_texts_b = batch_translation(texts_b, tokenizer_zh_en, model_zh_en, tokenizer_en_zh, model_en_zh)
        zh_de_texts_b = batch_translation(texts_b, tokenizer_zh_de, model_zh_de, tokenizer_de_zh, model_de_zh)
        zh_en_texts = list(zip(zh_en_texts_a, zh_en_texts_b, labels))
        zh_de_texts = list(zip(zh_de_texts_a, zh_de_texts_b, labels))
        zh_en_texts = list(map(lambda x: '\t'.join(x) + '\n', zh_en_texts))
        zh_de_texts = list(map(lambda x: '\t'.join(x) + '\n', zh_de_texts))
        
        fw_ze.writelines(zh_en_texts)
        fw_zd.writelines(zh_de_texts)

  2%|▏         | 16/1001 [03:23<3:26:29, 12.58s/it]