In [1]:
!pip install -q transformers sentencepiece datasets accelerate

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m19.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from transformers import MBartForConditionalGeneration, MBartTokenizer
import math
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
import datasets
import os
import shutil
from google.colab import drive

root="/content/drive/MyDrive/"
# model_repo = "facebook/mbart-large-cc25"
model_repo = root + "model_base"
data_files={
    'train': root + 'data/train.csv',
}
dirs = {
    "cache_dir": root + "cache",
    "output_dir": root + "output",
    "logs_dir": root + "logs",
    "store_model": root + "store_model",
}
lang_config = {
    "src": "ja",
    "tgt": "en"
    # "src": "en",
    # "tgt": "vi"
}
lang_code = {
    "ja": "ja_XX",
    "en": "en_XX",
    "vi": "vi_VN",
}

In [3]:
drive.mount('/content/drive')
# %cd /content/drive/MyDrive/

Mounted at /content/drive


In [4]:
src = lang_config["src"]
tgt = lang_config["tgt"]
store_model=dirs["store_model"]
print([model_repo, store_model, src, tgt])

['/content/drive/MyDrive/model_base', '/content/drive/MyDrive/store_model', 'ja', 'en']


In [6]:
# Hyperparameters
batch_size = 2
learning_rate = 0.00003
epochs = 8
logging_steps=1000
save_steps=1000
split=0.9
save_strategy="epoch"

In [7]:
def create_folder():
    folders = list(dirs.values())
    for folder in folders:
        print(folder)
        if os.path.exists(folder):
            shutil.rmtree(folder)
        if not os.path.exists(folder):
            os.makedirs(folder)
create_folder()

/content/drive/MyDrive/cache
/content/drive/MyDrive/output
/content/drive/MyDrive/logs
/content/drive/MyDrive/store_model


In [8]:
def load_model():
    model = MBartForConditionalGeneration.from_pretrained(model_repo)
    model.save_pretrained(store_model)
    model = MBartForConditionalGeneration.from_pretrained(store_model)
    return model

def load_tokenizer():
    src = lang_config["src"]
    tgt = lang_config["tgt"]
    tokenizer = MBartTokenizer.from_pretrained(model_repo, src_lang=lang_code[src], tgt_lang=lang_code[tgt])
    tokenizer.save_pretrained(store_model)
    tokenizer = MBartTokenizer.from_pretrained(store_model)
    return tokenizer

def load_data_set() -> datasets.arrow_dataset.Dataset:
    dataset = datasets.load_dataset("csv", data_files=data_files, cache_dir=dirs["cache_dir"])
    return dataset["train"]

def create_dataset():
    lst_data = load_data_set()
    train_data = []
    for data in lst_data:
        dic = {}
        dic[src] = data[src]
        dic[tgt] = data[tgt]
        train_data.append({"translation": dic})
    return train_data

def create_data_train():
    raw_data = create_dataset()
    train_length = math.floor(split * len(raw_data))
    train_data = raw_data[:train_length]
    eval_data = raw_data[train_length:]
    print(f"raw_data size: {len(raw_data)}")
    print(f"train_data size: {len(train_data)}")
    print(f"eval_data size: {len(eval_data)}")
    return train_data, eval_data

In [9]:
def repair() -> Seq2SeqTrainer:
    train_data, eval_data = create_data_train()
    model = load_model()
    tokenizer = load_tokenizer()


    def data_collator(features: list):
        x = [f["translation"][src] for f in features]
        y = [f["translation"][tgt] for f in features]
        inputs = tokenizer(x, return_tensors="pt", padding='max_length', truncation=True, max_length=32)
        with tokenizer.as_target_tokenizer():
            inputs['labels'] = tokenizer(y, return_tensors="pt", padding='max_length', truncation=True, max_length=48)['input_ids']
        return inputs

    args = Seq2SeqTrainingArguments(output_dir=dirs["output_dir"],
                                    overwrite_output_dir =True,
                                    do_train=True,
                                    do_eval=True,
                                    evaluation_strategy="epoch",
                                    per_device_train_batch_size=batch_size,
                                    per_device_eval_batch_size=batch_size,
                                    save_steps=save_steps,
                                    save_strategy=save_strategy,
                                    learning_rate=learning_rate,
                                    logging_steps=logging_steps,
                                    num_train_epochs=epochs,
                                    remove_unused_columns=False,
                                )

    trainer = Seq2SeqTrainer(model=model,
                            args=args,
                            data_collator=data_collator,
                            tokenizer=tokenizer,
                            train_dataset=train_data,
                            eval_dataset=eval_data,
                            )
    return trainer

In [10]:
def run(trainer: Seq2SeqTrainer):
    output_dir=store_model
    trainer.train()
    trainer.save_model(output_dir)


def manual_test():
    output_dir=store_model
    model = MBartForConditionalGeneration.from_pretrained(output_dir)
    tokenizer = MBartTokenizer.from_pretrained(output_dir)

    sentence = "上部消化管の愁訴としては以下のものがある"
    inputs = tokenizer(sentence, return_tensors="pt")
    translated_tokens = model.generate(**inputs, decoder_start_token_id=tokenizer.lang_code_to_id[lang_code[lang_config["tgt"]]], early_stopping=True, max_length=48)
    pred = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
    print(type(translated_tokens))
    print(f"日本語 - {sentence}: English - {pred}")

In [11]:
trainer = repair()

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

raw_data size: 9471
train_data size: 8523
eval_data size: 948


In [12]:
run(trainer)
manual_test()



Epoch,Training Loss,Validation Loss
1,1.4235,4.723038
2,1.3755,4.795795
3,1.3531,4.772677




Epoch,Training Loss,Validation Loss
1,1.4235,4.723038
2,1.3755,4.795795
3,1.3531,4.772677
4,1.4112,4.767869
5,1.3995,4.803192
6,1.3847,4.819342
7,1.351,4.831667
8,1.2998,4.833263




ConnectionAbortedError: ignored