In [3]:
from transformers import MarianMTModel, MarianTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import load_dataset
import torch
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Load the pre-trained model and tokenizer
model_name = "Helsinki-NLP/opus-mt-ja-en"
model = MarianMTModel.from_pretrained(model_name)
tokenizer = MarianTokenizer.from_pretrained(model_name)


data = load_dataset("NilanE/ParallelFiction-Ja_En-100k", split="train")

dataset = data.train_test_split(test_size=0.1, seed=42)
train_data = dataset['train']
test_data = dataset['test']

def preprocess_function(examples):
    # Extract Japanese source text and English target text
    inputs = examples['src']  # Japanese text
    targets = examples['trg']  # English text

    # Tokenize the source text
    model_inputs = tokenizer(
        inputs,
        max_length=512,
        truncation=True,
        padding="max_length"
    )

    # Tokenize the target text as labels
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            targets,
            max_length=512,
            truncation=True,
            padding="max_length"
        )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs



  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Preprocess the data
tokenized_train_data = train_data.map(preprocess_function, batched=True)

tokenized_train_data.save_to_disk("./tokenized_data")


Saving the dataset (6/6 shards): 100%|██████████| 95443/95443 [00:02<00:00, 32164.67 examples/s]


In [5]:
# Define training arguments with evaluation disabled
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="no",  # Disable evaluation
    save_strategy="steps",  # Save checkpoints periodically
    save_steps=500,  # Save a checkpoint every 500 steps
    save_total_limit=3,  # Keep only the last 3 checkpoints
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    weight_decay=0.01,
    num_train_epochs=3,
    predict_with_generate=True,
    fp16=torch.cuda.is_available(),  # Use FP16 if a GPU is available
    logging_dir="./logs",
)

# Initialize the trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_data,  # Training data
    tokenizer=tokenizer,
)


  trainer = Seq2SeqTrainer(


In [None]:
# Fine-tune the model
trainer.train()

# Save the fine-tuned model and tokenizer
model.save_pretrained("./fine_tuned_model")
tokenizer.save_pretrained("./fine_tuned_model")


 42%|████▏     | 14896/35793 [3:34:55<211:27:27, 36.43s/it]

In [7]:
# Fine-tune the model
trainer.train(resume_from_checkpoint=True)

# Save the fine-tuned model and tokenizer
model.save_pretrained("./fine_tuned_model")
tokenizer.save_pretrained("./fine_tuned_model")


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.encoder.embed_positions.weight', 'model.decoder.embed_tokens.weight', 'model.decoder.embed_positions.weight', 'lm_head.weight'].
  torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
  checkpoint_rng_state = torch.load(rng_file)


{'loss': 2.1695, 'grad_norm': 2.804398775100708, 'learning_rate': 1.1622384265079764e-05, 'epoch': 1.26}


 43%|████▎     | 15500/35793 [16:15<4:43:53,  1.19it/s]

{'loss': 2.1699, 'grad_norm': 2.7245330810546875, 'learning_rate': 1.1343000027938424e-05, 'epoch': 1.3}


 45%|████▍     | 16000/35793 [23:22<4:25:46,  1.24it/s]

{'loss': 2.1818, 'grad_norm': 3.016148090362549, 'learning_rate': 1.1063615790797084e-05, 'epoch': 1.34}


 46%|████▌     | 16500/35793 [30:29<4:29:14,  1.19it/s]

{'loss': 2.1462, 'grad_norm': 2.8308584690093994, 'learning_rate': 1.0784231553655745e-05, 'epoch': 1.38}


 47%|████▋     | 17000/35793 [37:44<4:43:37,  1.10it/s]

{'loss': 2.1336, 'grad_norm': 2.7448885440826416, 'learning_rate': 1.0505406084988685e-05, 'epoch': 1.42}


 49%|████▉     | 17500/35793 [44:52<4:14:43,  1.20it/s]

{'loss': 2.1376, 'grad_norm': 2.608849048614502, 'learning_rate': 1.0226580616321628e-05, 'epoch': 1.47}


 50%|█████     | 18000/35793 [51:50<4:07:32,  1.20it/s]

{'loss': 2.1248, 'grad_norm': 3.023585081100464, 'learning_rate': 9.947196379180287e-06, 'epoch': 1.51}


 52%|█████▏    | 18500/35793 [58:56<3:58:44,  1.21it/s]

{'loss': 2.1282, 'grad_norm': 2.806114912033081, 'learning_rate': 9.667812142038947e-06, 'epoch': 1.55}


 53%|█████▎    | 19000/35793 [1:06:03<3:43:46,  1.25it/s]

{'loss': 2.1038, 'grad_norm': 2.6417856216430664, 'learning_rate': 9.388427904897607e-06, 'epoch': 1.59}


 54%|█████▍    | 19500/35793 [1:13:06<3:54:52,  1.16it/s]

{'loss': 2.103, 'grad_norm': 2.620089054107666, 'learning_rate': 9.109043667756265e-06, 'epoch': 1.63}


 56%|█████▌    | 20000/35793 [1:20:11<3:31:32,  1.24it/s]

{'loss': 2.0912, 'grad_norm': 2.9677674770355225, 'learning_rate': 8.829659430614926e-06, 'epoch': 1.68}


 57%|█████▋    | 20500/35793 [1:27:15<3:31:13,  1.21it/s]

{'loss': 2.0821, 'grad_norm': 2.5534780025482178, 'learning_rate': 8.550275193473586e-06, 'epoch': 1.72}


 59%|█████▊    | 21000/35793 [1:34:19<3:26:34,  1.19it/s]

{'loss': 2.0814, 'grad_norm': 2.926218032836914, 'learning_rate': 8.270890956332244e-06, 'epoch': 1.76}


 60%|██████    | 21500/35793 [1:41:27<3:20:08,  1.19it/s]

{'loss': 2.0825, 'grad_norm': 2.630681037902832, 'learning_rate': 7.991506719190904e-06, 'epoch': 1.8}


 61%|██████▏   | 22000/35793 [1:48:27<3:06:31,  1.23it/s]

{'loss': 2.0786, 'grad_norm': 2.7724485397338867, 'learning_rate': 7.712681250523846e-06, 'epoch': 1.84}


 63%|██████▎   | 22500/35793 [1:55:46<2:06:30,  1.75it/s]

{'loss': 2.0631, 'grad_norm': 2.579216480255127, 'learning_rate': 7.433855781856788e-06, 'epoch': 1.89}


 64%|██████▍   | 23000/35793 [2:00:30<1:57:29,  1.81it/s]

{'loss': 2.0617, 'grad_norm': 2.6959280967712402, 'learning_rate': 7.154471544715448e-06, 'epoch': 1.93}


 66%|██████▌   | 23500/35793 [2:05:13<1:53:22,  1.81it/s]

{'loss': 2.059, 'grad_norm': 2.834404230117798, 'learning_rate': 6.875087307574107e-06, 'epoch': 1.97}


 67%|██████▋   | 24000/35793 [2:09:55<1:50:44,  1.77it/s]

{'loss': 2.0463, 'grad_norm': 2.6310675144195557, 'learning_rate': 6.595703070432767e-06, 'epoch': 2.01}


 68%|██████▊   | 24500/35793 [2:14:38<1:39:37,  1.89it/s]

{'loss': 2.0497, 'grad_norm': 2.972698450088501, 'learning_rate': 6.316318833291426e-06, 'epoch': 2.05}


 70%|██████▉   | 25000/35793 [2:19:20<1:43:00,  1.75it/s]

{'loss': 2.0155, 'grad_norm': 2.623328924179077, 'learning_rate': 6.036934596150086e-06, 'epoch': 2.1}


 71%|███████   | 25500/35793 [2:24:06<1:35:22,  1.80it/s]

{'loss': 2.0197, 'grad_norm': 2.651944160461426, 'learning_rate': 5.757550359008745e-06, 'epoch': 2.14}


 73%|███████▎  | 26000/35793 [2:28:46<1:33:18,  1.75it/s]

{'loss': 2.013, 'grad_norm': 2.514033317565918, 'learning_rate': 5.478724890341688e-06, 'epoch': 2.18}


 74%|███████▍  | 26500/35793 [2:33:27<1:26:46,  1.78it/s]

{'loss': 2.0212, 'grad_norm': 2.7133944034576416, 'learning_rate': 5.1993406532003465e-06, 'epoch': 2.22}


 75%|███████▌  | 27000/35793 [2:38:10<1:20:44,  1.81it/s]

{'loss': 2.02, 'grad_norm': 2.6151833534240723, 'learning_rate': 4.919956416059007e-06, 'epoch': 2.26}


 77%|███████▋  | 27500/35793 [2:42:51<56:15,  2.46it/s]  

{'loss': 2.0149, 'grad_norm': 2.805305242538452, 'learning_rate': 4.640572178917666e-06, 'epoch': 2.3}


 78%|███████▊  | 28000/35793 [2:47:27<1:10:19,  1.85it/s]

{'loss': 2.0082, 'grad_norm': 2.6519854068756104, 'learning_rate': 4.361187941776325e-06, 'epoch': 2.35}


 80%|███████▉  | 28500/35793 [2:52:05<1:07:19,  1.81it/s]

{'loss': 2.0096, 'grad_norm': 2.5853476524353027, 'learning_rate': 4.082362473109268e-06, 'epoch': 2.39}


 81%|████████  | 29000/35793 [2:56:43<1:00:55,  1.86it/s]

{'loss': 2.0116, 'grad_norm': 2.951279878616333, 'learning_rate': 3.8029782359679268e-06, 'epoch': 2.43}


 82%|████████▏ | 29500/35793 [3:01:21<58:55,  1.78it/s]  

{'loss': 2.006, 'grad_norm': 2.7072432041168213, 'learning_rate': 3.5235939988265865e-06, 'epoch': 2.47}


 84%|████████▍ | 30000/35793 [3:05:57<51:43,  1.87it/s]  

{'loss': 2.0026, 'grad_norm': 2.556507110595703, 'learning_rate': 3.2442097616852458e-06, 'epoch': 2.51}


 85%|████████▌ | 30500/35793 [3:10:34<48:50,  1.81it/s]  

{'loss': 1.9802, 'grad_norm': 2.8558201789855957, 'learning_rate': 2.9648255245439055e-06, 'epoch': 2.56}


 87%|████████▋ | 31000/35793 [3:15:11<46:04,  1.73it/s]  

{'loss': 1.9977, 'grad_norm': 2.7457668781280518, 'learning_rate': 2.685441287402565e-06, 'epoch': 2.6}


 88%|████████▊ | 31500/35793 [3:19:47<38:52,  1.84it/s]  

{'loss': 2.0016, 'grad_norm': 2.637615442276001, 'learning_rate': 2.4060570502612245e-06, 'epoch': 2.64}


 89%|████████▉ | 32000/35793 [3:24:24<35:46,  1.77it/s]  

{'loss': 1.9958, 'grad_norm': 2.698451519012451, 'learning_rate': 2.126672813119884e-06, 'epoch': 2.68}


 91%|█████████ | 32500/35793 [3:29:01<29:11,  1.88it/s]

{'loss': 2.007, 'grad_norm': 2.9999020099639893, 'learning_rate': 1.8472885759785433e-06, 'epoch': 2.72}


 92%|█████████▏| 33000/35793 [3:33:38<25:26,  1.83it/s]

{'loss': 2.0048, 'grad_norm': 2.849747657775879, 'learning_rate': 1.5679043388372028e-06, 'epoch': 2.77}


 94%|█████████▎| 33500/35793 [3:38:18<21:43,  1.76it/s]

{'loss': 1.9991, 'grad_norm': 2.555652618408203, 'learning_rate': 1.289078870170145e-06, 'epoch': 2.81}


 95%|█████████▍| 34000/35793 [3:42:56<17:05,  1.75it/s]

{'loss': 1.9842, 'grad_norm': 2.5811784267425537, 'learning_rate': 1.0096946330288046e-06, 'epoch': 2.85}


 96%|█████████▋| 34500/35793 [3:47:32<11:38,  1.85it/s]

{'loss': 1.9881, 'grad_norm': 2.5673716068267822, 'learning_rate': 7.303103958874641e-07, 'epoch': 2.89}


 98%|█████████▊| 35000/35793 [3:52:04<07:16,  1.82it/s]

{'loss': 1.9847, 'grad_norm': 2.497488498687744, 'learning_rate': 4.509261587461236e-07, 'epoch': 2.93}


 99%|█████████▉| 35500/35793 [3:56:40<02:39,  1.83it/s]

{'loss': 1.9825, 'grad_norm': 2.5490450859069824, 'learning_rate': 1.7154192160478308e-07, 'epoch': 2.98}


100%|██████████| 35793/35793 [3:59:23<00:00,  2.49it/s]


{'train_runtime': 14363.9178, 'train_samples_per_second': 19.934, 'train_steps_per_second': 2.492, 'train_loss': 1.2199590257369632, 'epoch': 3.0}


('./fine_tuned_model\\tokenizer_config.json',
 './fine_tuned_model\\special_tokens_map.json',
 './fine_tuned_model\\vocab.json',
 './fine_tuned_model\\source.spm',
 './fine_tuned_model\\target.spm',
 './fine_tuned_model\\added_tokens.json')

In [6]:
print(torch.cuda.is_available())

True
