In [4]:
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import Trainer, TrainingArguments

In [5]:
def load_dataset(file_path, tokenizer, block_size = 128):
    """Wrap a text file in a ``transformers.TextDataset``.

    Args:
        file_path: Path of the text file handed to ``TextDataset``.
        tokenizer: Tokenizer instance handed to ``TextDataset``.
        block_size: Forwarded unchanged as ``TextDataset``'s ``block_size``
            (defaults to 128).

    Returns:
        The constructed ``TextDataset``.
    """
    return TextDataset(
        file_path = file_path,
        tokenizer = tokenizer,
        block_size = block_size,
    )


def load_data_collator(tokenizer, mlm = False):
    """Create the ``DataCollatorForLanguageModeling`` used to batch examples.

    Args:
        tokenizer: Tokenizer instance handed to the collator.
        mlm: Forwarded unchanged as the collator's ``mlm`` flag
            (defaults to ``False``).

    Returns:
        The constructed ``DataCollatorForLanguageModeling``.
    """
    return DataCollatorForLanguageModeling(mlm=mlm, tokenizer=tokenizer)


def train(train_file_path,model_name,
          output_dir,
          overwrite_output_dir,
          per_device_train_batch_size,
          num_train_epochs,
          save_steps=500):
    """Fine-tune a GPT-2 LM-head model on a text file and save the result.

    The tokenizer and the initial pretrained model are written to
    ``output_dir`` before training starts; ``trainer.save_model()`` is
    called once training finishes.

    If ``output_dir`` already contains ``checkpoint-*`` folders from an
    earlier run and ``overwrite_output_dir`` is false, training resumes
    from the most recent one instead of starting over.

    Args:
        train_file_path: Path of the training text file.
        model_name: Name or path passed to ``from_pretrained`` for both
            the tokenizer and the model.
        output_dir: Directory receiving the tokenizer, checkpoints and
            the final model.
        overwrite_output_dir: Forwarded to ``TrainingArguments``; when
            true, existing checkpoints are ignored and training starts
            from ``model_name``.
        per_device_train_batch_size: Forwarded to ``TrainingArguments``.
        num_train_epochs: Forwarded to ``TrainingArguments``.
        save_steps: A checkpoint is saved every ``save_steps`` steps.
    """
    # Local imports keep this cell self-contained; the notebook's import
    # cell does not bring these names into scope.
    import os
    from transformers.trainer_utils import get_last_checkpoint

    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    train_dataset = load_dataset(train_file_path, tokenizer)
    data_collator = load_data_collator(tokenizer)

    # TrainingArguments.resume_from_checkpoint is not consumed by Trainer,
    # so the checkpoint has to be located here and handed to
    # trainer.train() explicitly.
    last_checkpoint = None
    if os.path.isdir(output_dir) and not overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(output_dir)

    tokenizer.save_pretrained(output_dir)

    model = GPT2LMHeadModel.from_pretrained(model_name)
    # Stores the not-yet-fine-tuned weights next to the tokenizer.
    model.save_pretrained(output_dir)

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=overwrite_output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        num_train_epochs=num_train_epochs,
        save_strategy="steps",
        save_steps=save_steps,
        # Older checkpoints beyond this count are deleted as new ones land.
        save_total_limit=10,
        #no_cuda=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )

    # None means "start fresh"; a path makes Trainer restore that checkpoint.
    trainer.train(resume_from_checkpoint=last_checkpoint)
    trainer.save_model()

In [6]:
# Fine-tune the "gpt2" checkpoint on Corpus.txt for 5 epochs, writing the
# tokenizer, checkpoints and final model under model/.
# NOTE(review): the captured output below shows a checkpoint every 100 steps,
# which does not match save_steps=500 -- it looks like it came from an earlier
# run; re-run this cell to refresh it.
train(
    model_name="gpt2",
    train_file_path="Corpus.txt",
    output_dir="model/",
    overwrite_output_dir=False,
    num_train_epochs=5.0,
    per_device_train_batch_size=4,
    save_steps=500,
)

***** Running training *****
  Num examples = 98303
  Num Epochs = 5
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 122880
  Number of trainable parameters = 124439808


  0%|          | 0/122880 [00:00<?, ?it/s]

Saving model checkpoint to model/checkpoint-100
Configuration saved in model/checkpoint-100\config.json
Configuration saved in model/checkpoint-100\generation_config.json
Model weights saved in model/checkpoint-100\pytorch_model.bin
Saving model checkpoint to model/checkpoint-200
Configuration saved in model/checkpoint-200\config.json
Configuration saved in model/checkpoint-200\generation_config.json
Model weights saved in model/checkpoint-200\pytorch_model.bin
Saving model checkpoint to model/checkpoint-300
Configuration saved in model/checkpoint-300\config.json
Configuration saved in model/checkpoint-300\generation_config.json
Model weights saved in model/checkpoint-300\pytorch_model.bin
Saving model checkpoint to model/checkpoint-400
Configuration saved in model/checkpoint-400\config.json
Configuration saved in model/checkpoint-400\generation_config.json
Model weights saved in model/checkpoint-400\pytorch_model.bin
Saving model checkpoint to model/checkpoint-500
Configuration saved 

{'loss': 2.8102, 'learning_rate': 4.979654947916667e-05, 'epoch': 0.02}


Model weights saved in model/checkpoint-500\pytorch_model.bin
Saving model checkpoint to model/checkpoint-600
Configuration saved in model/checkpoint-600\config.json
Configuration saved in model/checkpoint-600\generation_config.json
Model weights saved in model/checkpoint-600\pytorch_model.bin
Saving model checkpoint to model/checkpoint-700
Configuration saved in model/checkpoint-700\config.json
Configuration saved in model/checkpoint-700\generation_config.json
Model weights saved in model/checkpoint-700\pytorch_model.bin
Saving model checkpoint to model/checkpoint-800
Configuration saved in model/checkpoint-800\config.json
Configuration saved in model/checkpoint-800\generation_config.json
Model weights saved in model/checkpoint-800\pytorch_model.bin
Saving model checkpoint to model/checkpoint-900
Configuration saved in model/checkpoint-900\config.json
Configuration saved in model/checkpoint-900\generation_config.json
Model weights saved in model/checkpoint-900\pytorch_model.bin
Saving

{'loss': 2.2139, 'learning_rate': 4.959309895833333e-05, 'epoch': 0.04}


Model weights saved in model/checkpoint-1000\pytorch_model.bin
Saving model checkpoint to model/checkpoint-1100
Configuration saved in model/checkpoint-1100\config.json
Configuration saved in model/checkpoint-1100\generation_config.json
Model weights saved in model/checkpoint-1100\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-100] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-1200
Configuration saved in model/checkpoint-1200\config.json
Configuration saved in model/checkpoint-1200\generation_config.json
Model weights saved in model/checkpoint-1200\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-200] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-1300
Configuration saved in model/checkpoint-1300\config.json
Configuration saved in model/checkpoint-1300\generation_config.json
Model weights saved in model/checkpoint-1300\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-300] due to args.save_tota

{'loss': 1.9944, 'learning_rate': 4.93896484375e-05, 'epoch': 0.06}


Model weights saved in model/checkpoint-1500\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-500] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-1600
Configuration saved in model/checkpoint-1600\config.json
Configuration saved in model/checkpoint-1600\generation_config.json
Model weights saved in model/checkpoint-1600\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-600] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-1700
Configuration saved in model/checkpoint-1700\config.json
Configuration saved in model/checkpoint-1700\generation_config.json
Model weights saved in model/checkpoint-1700\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-700] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-1800
Configuration saved in model/checkpoint-1800\config.json
Configuration saved in model/checkpoint-1800\generation_config.json
Model weights saved in model/checkpoint-1800\pytorch_mo

{'loss': 1.89, 'learning_rate': 4.918619791666667e-05, 'epoch': 0.08}


Model weights saved in model/checkpoint-2000\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-1000] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-2100
Configuration saved in model/checkpoint-2100\config.json
Configuration saved in model/checkpoint-2100\generation_config.json
Model weights saved in model/checkpoint-2100\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-1100] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-2200
Configuration saved in model/checkpoint-2200\config.json
Configuration saved in model/checkpoint-2200\generation_config.json
Model weights saved in model/checkpoint-2200\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-1200] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-2300
Configuration saved in model/checkpoint-2300\config.json
Configuration saved in model/checkpoint-2300\generation_config.json
Model weights saved in model/checkpoint-2300\pytorch

{'loss': 1.7718, 'learning_rate': 4.898274739583333e-05, 'epoch': 0.1}


Model weights saved in model/checkpoint-2500\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-1500] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-2600
Configuration saved in model/checkpoint-2600\config.json
Configuration saved in model/checkpoint-2600\generation_config.json
Model weights saved in model/checkpoint-2600\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-1600] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-2700
Configuration saved in model/checkpoint-2700\config.json
Configuration saved in model/checkpoint-2700\generation_config.json
Model weights saved in model/checkpoint-2700\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-1700] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-2800
Configuration saved in model/checkpoint-2800\config.json
Configuration saved in model/checkpoint-2800\generation_config.json
Model weights saved in model/checkpoint-2800\pytorch

{'loss': 1.7017, 'learning_rate': 4.8779296875e-05, 'epoch': 0.12}


Model weights saved in model/checkpoint-3000\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-2000] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-3100
Configuration saved in model/checkpoint-3100\config.json
Configuration saved in model/checkpoint-3100\generation_config.json
Model weights saved in model/checkpoint-3100\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-2100] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-3200
Configuration saved in model/checkpoint-3200\config.json
Configuration saved in model/checkpoint-3200\generation_config.json
Model weights saved in model/checkpoint-3200\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-2200] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-3300
Configuration saved in model/checkpoint-3300\config.json
Configuration saved in model/checkpoint-3300\generation_config.json
Model weights saved in model/checkpoint-3300\pytorch

{'loss': 1.636, 'learning_rate': 4.857584635416667e-05, 'epoch': 0.14}


Model weights saved in model/checkpoint-3500\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-2500] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-3600
Configuration saved in model/checkpoint-3600\config.json
Configuration saved in model/checkpoint-3600\generation_config.json
Model weights saved in model/checkpoint-3600\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-2600] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-3700
Configuration saved in model/checkpoint-3700\config.json
Configuration saved in model/checkpoint-3700\generation_config.json
Model weights saved in model/checkpoint-3700\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-2700] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-3800
Configuration saved in model/checkpoint-3800\config.json
Configuration saved in model/checkpoint-3800\generation_config.json
Model weights saved in model/checkpoint-3800\pytorch

{'loss': 1.5902, 'learning_rate': 4.837239583333333e-05, 'epoch': 0.16}


Model weights saved in model/checkpoint-4000\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-3000] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-4100
Configuration saved in model/checkpoint-4100\config.json
Configuration saved in model/checkpoint-4100\generation_config.json
Model weights saved in model/checkpoint-4100\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-3100] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-4200
Configuration saved in model/checkpoint-4200\config.json
Configuration saved in model/checkpoint-4200\generation_config.json
Model weights saved in model/checkpoint-4200\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-3200] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-4300
Configuration saved in model/checkpoint-4300\config.json
Configuration saved in model/checkpoint-4300\generation_config.json
Model weights saved in model/checkpoint-4300\pytorch

{'loss': 1.5665, 'learning_rate': 4.81689453125e-05, 'epoch': 0.18}


Model weights saved in model/checkpoint-4500\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-3500] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-4600
Configuration saved in model/checkpoint-4600\config.json
Configuration saved in model/checkpoint-4600\generation_config.json
Model weights saved in model/checkpoint-4600\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-3600] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-4700
Configuration saved in model/checkpoint-4700\config.json
Configuration saved in model/checkpoint-4700\generation_config.json
Model weights saved in model/checkpoint-4700\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-3700] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-4800
Configuration saved in model/checkpoint-4800\config.json
Configuration saved in model/checkpoint-4800\generation_config.json
Model weights saved in model/checkpoint-4800\pytorch

{'loss': 1.5289, 'learning_rate': 4.796549479166667e-05, 'epoch': 0.2}


Model weights saved in model/checkpoint-5000\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-4000] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-5100
Configuration saved in model/checkpoint-5100\config.json
Configuration saved in model/checkpoint-5100\generation_config.json
Model weights saved in model/checkpoint-5100\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-4100] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-5200
Configuration saved in model/checkpoint-5200\config.json
Configuration saved in model/checkpoint-5200\generation_config.json
Model weights saved in model/checkpoint-5200\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-4200] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-5300
Configuration saved in model/checkpoint-5300\config.json
Configuration saved in model/checkpoint-5300\generation_config.json
Model weights saved in model/checkpoint-5300\pytorch

{'loss': 1.5005, 'learning_rate': 4.776204427083333e-05, 'epoch': 0.22}


Model weights saved in model/checkpoint-5500\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-4500] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-5600
Configuration saved in model/checkpoint-5600\config.json
Configuration saved in model/checkpoint-5600\generation_config.json
Model weights saved in model/checkpoint-5600\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-4600] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-5700
Configuration saved in model/checkpoint-5700\config.json
Configuration saved in model/checkpoint-5700\generation_config.json
Model weights saved in model/checkpoint-5700\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-4700] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-5800
Configuration saved in model/checkpoint-5800\config.json
Configuration saved in model/checkpoint-5800\generation_config.json
Model weights saved in model/checkpoint-5800\pytorch

{'loss': 1.4415, 'learning_rate': 4.755859375e-05, 'epoch': 0.24}


Model weights saved in model/checkpoint-6000\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-5000] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-6100
Configuration saved in model/checkpoint-6100\config.json
Configuration saved in model/checkpoint-6100\generation_config.json
Model weights saved in model/checkpoint-6100\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-5100] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-6200
Configuration saved in model/checkpoint-6200\config.json
Configuration saved in model/checkpoint-6200\generation_config.json
Model weights saved in model/checkpoint-6200\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-5200] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-6300
Configuration saved in model/checkpoint-6300\config.json
Configuration saved in model/checkpoint-6300\generation_config.json
Model weights saved in model/checkpoint-6300\pytorch

{'loss': 1.4653, 'learning_rate': 4.735514322916667e-05, 'epoch': 0.26}


Model weights saved in model/checkpoint-6500\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-5500] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-6600
Configuration saved in model/checkpoint-6600\config.json
Configuration saved in model/checkpoint-6600\generation_config.json
Model weights saved in model/checkpoint-6600\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-5600] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-6700
Configuration saved in model/checkpoint-6700\config.json
Configuration saved in model/checkpoint-6700\generation_config.json
Model weights saved in model/checkpoint-6700\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-5700] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-6800
Configuration saved in model/checkpoint-6800\config.json
Configuration saved in model/checkpoint-6800\generation_config.json
Model weights saved in model/checkpoint-6800\pytorch

{'loss': 1.4042, 'learning_rate': 4.715169270833333e-05, 'epoch': 0.28}


Model weights saved in model/checkpoint-7000\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-6000] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-7100
Configuration saved in model/checkpoint-7100\config.json
Configuration saved in model/checkpoint-7100\generation_config.json
Model weights saved in model/checkpoint-7100\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-6100] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-7200
Configuration saved in model/checkpoint-7200\config.json
Configuration saved in model/checkpoint-7200\generation_config.json
Model weights saved in model/checkpoint-7200\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-6200] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-7300
Configuration saved in model/checkpoint-7300\config.json
Configuration saved in model/checkpoint-7300\generation_config.json
Model weights saved in model/checkpoint-7300\pytorch

{'loss': 1.3993, 'learning_rate': 4.69482421875e-05, 'epoch': 0.31}


Model weights saved in model/checkpoint-7500\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-6500] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-7600
Configuration saved in model/checkpoint-7600\config.json
Configuration saved in model/checkpoint-7600\generation_config.json
Model weights saved in model/checkpoint-7600\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-6600] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-7700
Configuration saved in model/checkpoint-7700\config.json
Configuration saved in model/checkpoint-7700\generation_config.json
Model weights saved in model/checkpoint-7700\pytorch_model.bin
Deleting older checkpoint [model\checkpoint-6700] due to args.save_total_limit
Saving model checkpoint to model/checkpoint-7800
Configuration saved in model/checkpoint-7800\config.json
Configuration saved in model/checkpoint-7800\generation_config.json
Model weights saved in model/checkpoint-7800\pytorch