In [1]:
# Check and display server config
import server_config

Number of GPUs: 1


In [2]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

Using device: cuda


In [3]:
# Configurations

model_name = "facebook/mbart-large-50-many-to-many-mmt"
model_cache_dir = "../model"

src_lang = "ja_XX"
target_lang = "en_XX"

dataset_name = "ryo0634/bsd_ja_en"
dataset_cache_dir = "../data"

translations_save = "../data/facebook_mbart.parquet"

In [4]:
# Load model
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=model_cache_dir)
tokenizer.src_lang = src_lang

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=model_cache_dir)
model.to(device)

MBartForConditionalGeneration(
  (model): MBartModel(
    (shared): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
    (encoder): MBartEncoder(
      (embed_tokens): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
      (embed_positions): MBartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x MBartEncoderLayer(
          (self_attn): MBartSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): ReLU()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    

In [5]:
# Load datasets
from datasets import load_dataset, concatenate_datasets

ds = load_dataset(dataset_name, cache_dir=dataset_cache_dir)

In [6]:
# Merge datasets into one, remove unnecessary columns
train = ds["train"]
test = ds["test"]
validation = ds["validation"]

ds_merged = concatenate_datasets([train, test, validation])
ds_merged = ds_merged.filter(lambda example: example['original_language'] == 'ja')
ds_merged = ds_merged.remove_columns(['id', 'tag', 'title', 'original_language', 'no', 'en_speaker', 'ja_speaker'])

In [7]:
print(f"Number of sentences in Japanese: {len(ds_merged)}")

Number of sentences in Japanese: 12106


In [8]:
# Define tokenize functions
from datasets import Dataset

def tokenize_ja(example):
    return tokenizer(example["ja_sentence"], truncation=True, max_length=512, padding="max_length", return_tensors="pt")

def tokenize_prepare(dataset: Dataset) -> Dataset:
  tokenized_dataset = dataset.map(tokenize_ja, batched=True)
  tokenized_dataset.set_format("torch")
  return tokenized_dataset

In [9]:
ds_tokenized = tokenize_prepare(ds_merged)

Map:   0%|          | 0/12106 [00:00<?, ? examples/s]

In [10]:
# Let the magic begin
def generate_translation(batch):
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    outputs = model.generate(input_ids=input_ids, 
                             attention_mask=attention_mask, 
                             forced_bos_token_id=tokenizer.lang_code_to_id[target_lang])
    batch["translations"] = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
    return batch

translated_dataset = ds_tokenized.map(generate_translation, batched=True, batch_size=64)

Map:   0%|          | 0/12106 [00:00<?, ? examples/s]

In [16]:
# A qualitative test

example = translated_dataset[1025]

print(f"Source sentence: {example['ja_sentence']}")
print(f"Translated sentence: {example['translations']}")
print(f"Reference sentence: {example['en_sentence']}")

Source sentence: 菓子パンやお茶なんかは買っておいてもいいですね。
Translated sentence: You can buy some sweet bread and tea.
Reference sentence: We can go buy some pastries and tea today.


In [12]:
# Cleaning jobs
translated_dataset = translated_dataset.remove_columns(['input_ids', 'attention_mask'])

In [13]:
# Save tranlated dataset
translated_dataset.to_parquet(translations_save)

Creating parquet from Arrow format:   0%|          | 0/13 [00:00<?, ?ba/s]

2166787