In [None]:
from transformers import Seq2SeqTrainer
import torch
import Levenshtein as Lev          
class EditDistanceTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer whose loss adds a normalized edit-distance term.

    The returned loss is ``base_loss + alpha * edit_loss`` where ``edit_loss``
    is the batch mean of ``Levenshtein(pred, target) / max(len(target), 1)``
    computed on decoded strings (greedy argmax of the teacher-forced logits
    versus the decoded labels).

    NOTE(review): the edit term is built from ``argmax`` inside
    ``torch.no_grad()``, so it carries no gradient. It shifts the reported
    loss value but does not change parameter updates; a differentiable
    surrogate would be needed for it to steer training.

    Args:
        *args, **kwargs: forwarded unchanged to ``Seq2SeqTrainer``.
        alpha: weight of the edit-distance term.
        tokenizer: tokenizer used to decode predictions and labels. It is
            kept on this instance only and is not forwarded to the parent
            constructor.
    """

    # Label value used to mark positions ignored by the loss; it is not a
    # real vocabulary id and therefore cannot be decoded.
    _IGNORE_INDEX = -100

    def __init__(self, *args, alpha=0.6, tokenizer=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self._tok = tokenizer

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the parent trainer's loss plus ``alpha`` times the edit term.

        Args:
            model: model being trained.
            inputs: batch mapping; ``inputs["labels"]`` is decoded as the
                reference text when present.
            return_outputs: when true, return ``(loss, outputs)``.
            **kwargs: extra arguments supplied by the training loop
                (e.g. ``num_items_in_batch``), forwarded to the parent.

        Returns:
            The combined loss, or ``(loss, outputs)`` if ``return_outputs``.
        """
        # Grab the labels first: the parent implementation may pop them from
        # ``inputs`` (it does so when label smoothing is enabled).
        labels = inputs["labels"] if "labels" in inputs else None

        # Delegate the base loss to the parent so Trainer-level options such
        # as ``label_smoothing_factor`` are honoured instead of being bypassed
        # by a direct ``model(**inputs)`` call.
        ce_loss, outputs = super().compute_loss(
            model, inputs, return_outputs=True, **kwargs
        )

        if labels is None:
            # No reference text available: nothing to compare against.
            return (ce_loss, outputs) if return_outputs else ce_loss

        with torch.no_grad():
            pred_ids = outputs.logits.argmax(dim=-1)

            # Replace ignore-index entries with the pad id before decoding;
            # ``masked_fill`` returns a new tensor, so the batch is untouched.
            pad_id = self._tok.pad_token_id
            if pad_id is not None:
                labels = labels.masked_fill(labels == self._IGNORE_INDEX, pad_id)

            pred_str = self._tok.batch_decode(pred_ids, skip_special_tokens=True)
            label_str = self._tok.batch_decode(labels, skip_special_tokens=True)

            # Character-level distance normalised by target length; the
            # ``max(..., 1)`` guards against empty targets.
            ed = [
                Lev.distance(p, t) / max(len(t), 1)
                for p, t in zip(pred_str, label_str)
            ]
            edit_loss = torch.tensor(ed, device=ce_loss.device).mean()

        final_loss = ce_loss + self.alpha * edit_loss
        return (final_loss, outputs) if return_outputs else final_loss



In [None]:
import numpy as np
import transformers
from datasets import load_dataset
import evaluate
from datasets import load_dataset
from transformers import BartTokenizer, BartForConditionalGeneration,DataCollatorForSeq2Seq,Seq2SeqTrainingArguments

# Single source of truth for the pretrained checkpoint, so the tokenizer and
# the model are always loaded from the same id.
CHECKPOINT = "facebook/bart-base"

tokenizer = BartTokenizer.from_pretrained(CHECKPOINT)
model = BartForConditionalGeneration.from_pretrained(CHECKPOINT)

def preprocess_function(example, max_length=40):
    """Tokenize a batch of noisy/clean text pairs for seq2seq training.

    Args:
        example: batch mapping with a ``"noisy"`` column (model input text)
            and a ``"clean"`` column (target text).
        max_length: length that both inputs and targets are truncated and
            padded to.

    Returns:
        The tokenizer output for the noisy text, with the token ids of the
        clean text stored under ``"labels"``.
    """
    # ``text_target`` replaces the deprecated ``as_target_tokenizer()``
    # context manager: one call encodes the inputs and puts the target ids
    # under ``"labels"`` using the same truncation/padding settings.
    return tokenizer(
        example["noisy"],
        text_target=example["clean"],
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )

def _load_tokenized(jsonl_path):
    """Load one JSON-lines file and tokenize it with ``preprocess_function``.

    The raw ``"noisy"`` and ``"clean"`` text columns are dropped after
    tokenization.
    """
    raw = load_dataset("json", data_files=jsonl_path, split="train")
    return raw.map(
        preprocess_function,
        batched=True,
        remove_columns=["noisy", "clean"],
    )


# Both files are read with split="train", including the evaluation file.
train_data = _load_tokenized("dataset/128_train_snli.jsonl")
val_data = _load_tokenized("dataset/128_eval_snli.jsonl")

# The collator receives the model as well as the tokenizer.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
training_args = Seq2SeqTrainingArguments(
    # --- output / bookkeeping ---
    # NOTE(review): the directory name says "4e5" but learning_rate below is
    # 2e-5 -- confirm which one is intended.
    output_dir='learn_rate_4e5',
    report_to="none",
    logging_steps=500,
    save_strategy="epoch",
    eval_strategy="epoch",
    predict_with_generate=False,
    # Keep every dataset column when batches are built.
    remove_unused_columns=False,

    # --- optimisation ---
    num_train_epochs=7,
    per_device_train_batch_size=128,
    learning_rate=2e-5,
    warmup_ratio=0.05,
    weight_decay=0.01,
    label_smoothing_factor=0.1,
)

# Wire the model, data and tokenizer into the custom trainer. `alpha` is not
# passed, so the trainer's default weight for the edit-distance term applies.
trainer = EditDistanceTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
# Run fine-tuning with the trainer built in the previous cell; left as the
# cell's last expression so the returned training summary is displayed.
trainer.train()

In [None]:
# Save the fine-tuned model and the tokenizer into the same directory.
# NOTE(review): "denosing" looks like a typo for "denoising"; left unchanged
# because it is the on-disk path other code may already point at.
save_path = "./denosing-v4-base"
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)