mBART50 - bilingual fine-tuning (SWRC)

In [None]:
# Label data
# Build one parallel Korean/Chinese CSV per split from the SWRC text files.
# Line i of the Korean file is paired with line i of the Chinese file, so the
# two files of a split must contain the same number of lines.
import pandas as pd

# (Korean text file, Chinese text file, output CSV) for each split.
swrc_splits = [
    # Train data
    ("ko_train_SWRC.txt", "ch_train_SWRC.txt", "ko_zh_train_dataset_SWRC.csv"),
    # Validation data
    ("ko_vali_SWRC.txt", "ch_vali_SWRC.txt", "ko_zh_validation_dataset_SWRC.csv"),
    # Test data
    ("ko_test_SWRC.txt", "ch_test_SWRC.txt", "ko_zh_test_dataset_SWRC.csv"),
]

for ko_path, zh_path, csv_path in swrc_splits:
    # Read both sides and strip the trailing newline / surrounding whitespace.
    with open(ko_path, "r", encoding="utf-8") as ko_file:
        ko_sentences = [line.strip() for line in ko_file]

    with open(zh_path, "r", encoding="utf-8") as zh_file:
        zh_sentences = [line.strip() for line in zh_file]

    # Check if row number is the same; every split (including train) is
    # validated so a misaligned corpus is reported before anything is written.
    if len(ko_sentences) != len(zh_sentences):
        raise ValueError("Not the same.")

    # Merge ko-sentence and zh-sentence into a df
    merged_df = pd.DataFrame({
        'source': ko_sentences,
        'target': zh_sentences,
    })

    # Save as a new CSV
    merged_df.to_csv(csv_path, index=False, encoding="utf-8")

In [None]:
# Fine-tuning
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration, TrainingArguments, Trainer, EarlyStoppingCallback
import os
from datasets import load_dataset
from safetensors.torch import safe_open
import torch

# Process-level settings: expose only the GPU with index 1, configure the
# CUDA caching allocator, and turn off tokenizer-internal parallelism.
# NOTE(review): these must be in place before the first CUDA call / tokenizer
# use to take effect; keep this at the top of the cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Load datasets
# Use the SWRC CSV files written by the label-data cell (columns 'source'
# and 'target'); results of this run are stored under the SWRC output
# directory, so the inputs must be the SWRC splits as well.
data_files = {
    "train": "ko_zh_train_dataset_SWRC.csv",
    "validation": "ko_zh_validation_dataset_SWRC.csv",
    "test": "ko_zh_test_dataset_SWRC.csv"
}
dataset = load_dataset('csv', data_files=data_files)

# Load tokenizer
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

# Load the base model and move it to the GPU when one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
model.to(device)

# Set source and target languages
tokenizer.src_lang = "ko_KR"  # Korean as source language
tokenizer.tgt_lang = "zh_CN"  # Chinese as target language

# Preprocessing function
def preprocess_function(examples):
    """Tokenize one batch of source/target sentence pairs for seq2seq training.

    Args:
        examples: a batch mapping with the columns 'source' and 'target',
            each a list of sentences (entries may be None for empty CSV cells).

    Returns:
        The tokenizer output for the source sentences (padded/truncated to
        200 tokens) with an added "labels" entry holding the target token
        ids, where padding positions are replaced by -100.

    Raises:
        ValueError: if a required column is missing or the batch is empty.
    """
    if 'source' not in examples or 'target' not in examples:
        raise ValueError("Missing 'source' or 'target' in examples.")

    # Keep both lists the same length as the incoming batch: dropping None
    # entries from each column independently would pair a source sentence
    # with the wrong target and change the batch size seen by `map`.
    inputs = ["" if text is None else text for text in examples['source']]
    targets = ["" if text is None else text for text in examples['target']]

    if not inputs or not targets:
        raise ValueError("Inputs or targets cannot be empty or None.")

    # `text_target` makes the tokenizer encode the targets in target-language
    # mode (tgt_lang code) and return them under "labels"; encoding them with
    # a plain call would prefix them with the source-language code instead.
    model_inputs = tokenizer(
        inputs,
        text_target=targets,
        max_length=200,
        truncation=True,
        padding='max_length',
    )

    if "labels" not in model_inputs:
        raise ValueError("Tokenization of targets failed, resulting in None or missing 'input_ids'.")

    # Padding positions must not contribute to the loss: -100 is the index
    # ignored by the cross-entropy loss used for the labels.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in label_ids]
        for label_ids in model_inputs["labels"]
    ]
    return model_inputs

# Drop incomplete pairs before tokenizing. Empty CSV cells are loaded as
# None; removing the whole row keeps source and target aligned and keeps
# every batch handed to the preprocessing function at a consistent length.
complete_pairs = dataset.filter(
    lambda example: example['source'] is not None and example['target'] is not None
)

# Tokenize the dataset
tokenized_datasets = complete_pairs.map(preprocess_function, batched=True)

# Disable WandB
os.environ["WANDB_DISABLED"] = "true"

# Set output directory
output_dir = '/home/u542596/experiments/bilingual_fine_tune/SWRC'
os.makedirs(output_dir, exist_ok=True)

# Training arguments
# Evaluation, checkpointing and logging all happen once per epoch so that
# `load_best_model_at_end` can pick the checkpoint with the lowest eval loss.
training_args = TrainingArguments(
    output_dir=output_dir,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,  # lower eval loss is better
    lr_scheduler_type="linear",
    warmup_steps=1500,
    seed=42,
    report_to="none",  # no external experiment trackers (WandB etc.)
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)

# Fine tune
train_results = trainer.train()