In [50]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
from datasets import load_dataset
from sklearn.metrics.pairwise import cosine_similarity

In [13]:
dataset = load_dataset("amishshah/song_lyrics")
dataset = dataset["train"].shuffle(seed=42)
subset_size = 2000
dataset = dataset.select(range(subset_size))
train_test_dataset = dataset.train_test_split(test_size=0.1)
train_dataset = train_test_dataset["train"]
val_dataset = train_test_dataset["test"]
#train_test_dataset = dataset["train"].train_test_split(test_size=0.1)
#train_dataset = train_test_dataset["train"]
#val_dataset = train_test_dataset["test"]

 99%|█████████▉| 99/100 [16:06:55<09:46, 586.02s/it]


In [3]:
#train_subset = train_dataset.select(range(100))
#val_subset = val_dataset.select(range(50))

In [15]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset

# Load tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Ensure that tokenizer has padding token set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

dataset = load_dataset("amishshah/song_lyrics")
dataset = dataset["train"].shuffle(seed=42)
subset_size = 2000
dataset = dataset.select(range(subset_size))
train_test_dataset = dataset.train_test_split(test_size=0.1)
train_dataset = train_test_dataset["train"]
val_dataset = train_test_dataset["test"]

# Load and prepare dataset
#dataset = load_dataset("amishshah/song_lyrics")
#train_test_dataset = dataset["train"].train_test_split(test_size=0.1)
#train_dataset = train_test_dataset["train"].select(range(100))
#val_dataset = train_test_dataset["test"].select(range(100))

# Tokenize the data
def tokenize_function(examples):
    return tokenizer(examples['lyrics'], truncation=True, padding=True, max_length=512)

train_dataset = train_dataset.map(tokenize_function, batched=True)
val_dataset = val_dataset.map(tokenize_function, batched=True)

# Set training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
)

# Initialize Trainer
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

# Train the model
trainer.train()

Map: 100%|██████████| 1800/1800 [00:02<00:00, 671.26 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 697.38 examples/s]
  1%|          | 10/1800 [00:11<33:08,  1.11s/it]

{'loss': 3.3518, 'grad_norm': 12.243512153625488, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.02}


  1%|          | 20/1800 [00:22<33:25,  1.13s/it]

{'loss': 3.1716, 'grad_norm': 11.873072624206543, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.04}


  2%|▏         | 30/1800 [00:33<33:09,  1.12s/it]

{'loss': 3.0935, 'grad_norm': 8.310800552368164, 'learning_rate': 3e-06, 'epoch': 0.07}


  2%|▏         | 40/1800 [00:45<32:58,  1.12s/it]

{'loss': 3.2439, 'grad_norm': 8.316095352172852, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.09}


  3%|▎         | 50/1800 [00:56<32:37,  1.12s/it]

{'loss': 3.1355, 'grad_norm': 7.406936168670654, 'learning_rate': 5e-06, 'epoch': 0.11}


  3%|▎         | 60/1800 [01:07<32:49,  1.13s/it]

{'loss': 3.1416, 'grad_norm': 6.7695841789245605, 'learning_rate': 6e-06, 'epoch': 0.13}


  4%|▍         | 70/1800 [01:18<32:55,  1.14s/it]

{'loss': 3.2561, 'grad_norm': 5.728172302246094, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.16}


  4%|▍         | 80/1800 [01:30<32:20,  1.13s/it]

{'loss': 3.0572, 'grad_norm': 6.720602035522461, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.18}


  5%|▌         | 90/1800 [01:41<31:54,  1.12s/it]

{'loss': 3.151, 'grad_norm': 6.607203006744385, 'learning_rate': 9e-06, 'epoch': 0.2}


  6%|▌         | 100/1800 [01:52<31:40,  1.12s/it]

{'loss': 3.1258, 'grad_norm': 5.623800754547119, 'learning_rate': 1e-05, 'epoch': 0.22}


  6%|▌         | 110/1800 [02:03<31:24,  1.11s/it]

{'loss': 3.3243, 'grad_norm': 6.0238471031188965, 'learning_rate': 1.1000000000000001e-05, 'epoch': 0.24}


  7%|▋         | 120/1800 [02:14<31:35,  1.13s/it]

{'loss': 3.1602, 'grad_norm': 5.168012619018555, 'learning_rate': 1.2e-05, 'epoch': 0.27}


  7%|▋         | 130/1800 [02:26<31:11,  1.12s/it]

{'loss': 3.1048, 'grad_norm': 5.342722415924072, 'learning_rate': 1.3000000000000001e-05, 'epoch': 0.29}


  8%|▊         | 140/1800 [02:37<30:53,  1.12s/it]

{'loss': 3.0437, 'grad_norm': 9.160844802856445, 'learning_rate': 1.4000000000000001e-05, 'epoch': 0.31}


  8%|▊         | 150/1800 [02:48<30:43,  1.12s/it]

{'loss': 2.7746, 'grad_norm': 8.860705375671387, 'learning_rate': 1.5e-05, 'epoch': 0.33}


  9%|▉         | 160/1800 [02:59<31:15,  1.14s/it]

{'loss': 2.9396, 'grad_norm': 6.640445232391357, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.36}


  9%|▉         | 170/1800 [03:11<30:31,  1.12s/it]

{'loss': 3.0284, 'grad_norm': 9.404908180236816, 'learning_rate': 1.7000000000000003e-05, 'epoch': 0.38}


 10%|█         | 180/1800 [03:22<30:10,  1.12s/it]

{'loss': 2.8889, 'grad_norm': 5.862174034118652, 'learning_rate': 1.8e-05, 'epoch': 0.4}


 11%|█         | 190/1800 [03:33<29:44,  1.11s/it]

{'loss': 2.8599, 'grad_norm': 7.435002326965332, 'learning_rate': 1.9e-05, 'epoch': 0.42}


 11%|█         | 200/1800 [03:44<29:45,  1.12s/it]

{'loss': 2.8519, 'grad_norm': 5.9550862312316895, 'learning_rate': 2e-05, 'epoch': 0.44}


 12%|█▏        | 210/1800 [03:55<29:35,  1.12s/it]

{'loss': 2.6749, 'grad_norm': 4.876107215881348, 'learning_rate': 2.1e-05, 'epoch': 0.47}


 12%|█▏        | 220/1800 [04:06<29:32,  1.12s/it]

{'loss': 2.9024, 'grad_norm': 5.954718589782715, 'learning_rate': 2.2000000000000003e-05, 'epoch': 0.49}


 13%|█▎        | 230/1800 [04:17<29:17,  1.12s/it]

{'loss': 3.1737, 'grad_norm': 8.84876537322998, 'learning_rate': 2.3000000000000003e-05, 'epoch': 0.51}


 13%|█▎        | 240/1800 [04:29<29:07,  1.12s/it]

{'loss': 3.0869, 'grad_norm': 5.278852462768555, 'learning_rate': 2.4e-05, 'epoch': 0.53}


 14%|█▍        | 250/1800 [04:40<28:38,  1.11s/it]

{'loss': 2.9119, 'grad_norm': 6.793298721313477, 'learning_rate': 2.5e-05, 'epoch': 0.56}


 14%|█▍        | 260/1800 [04:51<28:56,  1.13s/it]

{'loss': 2.9888, 'grad_norm': 7.746799945831299, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.58}


 15%|█▌        | 270/1800 [05:02<28:18,  1.11s/it]

{'loss': 3.2044, 'grad_norm': 6.75273323059082, 'learning_rate': 2.7000000000000002e-05, 'epoch': 0.6}


 16%|█▌        | 280/1800 [05:13<28:17,  1.12s/it]

{'loss': 2.7582, 'grad_norm': 4.632065296173096, 'learning_rate': 2.8000000000000003e-05, 'epoch': 0.62}


 16%|█▌        | 290/1800 [05:24<28:12,  1.12s/it]

{'loss': 2.8662, 'grad_norm': 5.399931907653809, 'learning_rate': 2.9e-05, 'epoch': 0.64}


 17%|█▋        | 300/1800 [05:36<28:19,  1.13s/it]

{'loss': 2.7516, 'grad_norm': 5.752292156219482, 'learning_rate': 3e-05, 'epoch': 0.67}


 17%|█▋        | 310/1800 [05:47<27:59,  1.13s/it]

{'loss': 2.9384, 'grad_norm': 5.288351535797119, 'learning_rate': 3.1e-05, 'epoch': 0.69}


 18%|█▊        | 320/1800 [05:58<27:36,  1.12s/it]

{'loss': 2.764, 'grad_norm': 5.855530261993408, 'learning_rate': 3.2000000000000005e-05, 'epoch': 0.71}


 18%|█▊        | 330/1800 [06:09<27:35,  1.13s/it]

{'loss': 3.0067, 'grad_norm': 7.513923645019531, 'learning_rate': 3.3e-05, 'epoch': 0.73}


 19%|█▉        | 340/1800 [06:20<27:30,  1.13s/it]

{'loss': 2.9004, 'grad_norm': 5.99768590927124, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.76}


 19%|█▉        | 350/1800 [06:32<26:52,  1.11s/it]

{'loss': 2.8253, 'grad_norm': 5.1162028312683105, 'learning_rate': 3.5e-05, 'epoch': 0.78}


 20%|██        | 360/1800 [06:43<26:43,  1.11s/it]

{'loss': 3.0357, 'grad_norm': 6.457411766052246, 'learning_rate': 3.6e-05, 'epoch': 0.8}


 21%|██        | 370/1800 [06:54<26:27,  1.11s/it]

{'loss': 2.9165, 'grad_norm': 6.660024166107178, 'learning_rate': 3.7e-05, 'epoch': 0.82}


 21%|██        | 380/1800 [07:05<26:37,  1.12s/it]

{'loss': 3.0643, 'grad_norm': 6.613027572631836, 'learning_rate': 3.8e-05, 'epoch': 0.84}


 22%|██▏       | 390/1800 [07:16<26:11,  1.11s/it]

{'loss': 2.98, 'grad_norm': 7.986294269561768, 'learning_rate': 3.9000000000000006e-05, 'epoch': 0.87}


 22%|██▏       | 400/1800 [07:27<25:58,  1.11s/it]

{'loss': 2.8036, 'grad_norm': 5.894557952880859, 'learning_rate': 4e-05, 'epoch': 0.89}


 23%|██▎       | 410/1800 [07:39<26:19,  1.14s/it]

{'loss': 3.0148, 'grad_norm': 5.449767589569092, 'learning_rate': 4.1e-05, 'epoch': 0.91}


 23%|██▎       | 420/1800 [07:50<26:05,  1.13s/it]

{'loss': 2.7949, 'grad_norm': 5.472409725189209, 'learning_rate': 4.2e-05, 'epoch': 0.93}


 24%|██▍       | 430/1800 [08:01<25:33,  1.12s/it]

{'loss': 2.9573, 'grad_norm': 4.383064270019531, 'learning_rate': 4.3e-05, 'epoch': 0.96}


 24%|██▍       | 440/1800 [08:12<25:25,  1.12s/it]

{'loss': 2.8249, 'grad_norm': 4.342525005340576, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.98}


 25%|██▌       | 450/1800 [08:23<25:26,  1.13s/it]

{'loss': 2.8692, 'grad_norm': 5.208646774291992, 'learning_rate': 4.5e-05, 'epoch': 1.0}


 26%|██▌       | 460/1800 [08:35<25:26,  1.14s/it]

{'loss': 2.8254, 'grad_norm': 4.593421459197998, 'learning_rate': 4.600000000000001e-05, 'epoch': 1.02}


 26%|██▌       | 470/1800 [08:46<24:34,  1.11s/it]

{'loss': 2.8872, 'grad_norm': 4.492712020874023, 'learning_rate': 4.7e-05, 'epoch': 1.04}


 27%|██▋       | 480/1800 [08:57<24:37,  1.12s/it]

{'loss': 2.9846, 'grad_norm': 7.3021464347839355, 'learning_rate': 4.8e-05, 'epoch': 1.07}


 27%|██▋       | 490/1800 [09:09<24:29,  1.12s/it]

{'loss': 2.768, 'grad_norm': 5.254695415496826, 'learning_rate': 4.9e-05, 'epoch': 1.09}


 28%|██▊       | 500/1800 [09:20<24:12,  1.12s/it]

{'loss': 2.5747, 'grad_norm': 5.648955345153809, 'learning_rate': 5e-05, 'epoch': 1.11}


 28%|██▊       | 510/1800 [09:33<24:32,  1.14s/it]

{'loss': 3.0419, 'grad_norm': 6.386997222900391, 'learning_rate': 4.961538461538462e-05, 'epoch': 1.13}


 29%|██▉       | 520/1800 [09:44<24:19,  1.14s/it]

{'loss': 2.7082, 'grad_norm': 6.9038920402526855, 'learning_rate': 4.923076923076924e-05, 'epoch': 1.16}


 29%|██▉       | 530/1800 [09:55<24:10,  1.14s/it]

{'loss': 2.8205, 'grad_norm': 5.038737773895264, 'learning_rate': 4.884615384615385e-05, 'epoch': 1.18}


 30%|███       | 540/1800 [10:07<24:29,  1.17s/it]

{'loss': 2.8876, 'grad_norm': 4.781486988067627, 'learning_rate': 4.846153846153846e-05, 'epoch': 1.2}


 31%|███       | 550/1800 [10:18<23:44,  1.14s/it]

{'loss': 2.8432, 'grad_norm': 4.390292644500732, 'learning_rate': 4.8076923076923084e-05, 'epoch': 1.22}


 31%|███       | 560/1800 [10:30<23:41,  1.15s/it]

{'loss': 2.5897, 'grad_norm': 6.5311737060546875, 'learning_rate': 4.76923076923077e-05, 'epoch': 1.24}


 32%|███▏      | 570/1800 [10:41<23:28,  1.14s/it]

{'loss': 2.8666, 'grad_norm': 5.369809627532959, 'learning_rate': 4.730769230769231e-05, 'epoch': 1.27}


 32%|███▏      | 580/1800 [10:53<23:17,  1.15s/it]

{'loss': 2.6949, 'grad_norm': 4.761454105377197, 'learning_rate': 4.692307692307693e-05, 'epoch': 1.29}


 33%|███▎      | 590/1800 [11:04<22:53,  1.13s/it]

{'loss': 2.9599, 'grad_norm': 4.967236518859863, 'learning_rate': 4.653846153846154e-05, 'epoch': 1.31}


 33%|███▎      | 600/1800 [11:16<22:40,  1.13s/it]

{'loss': 2.9569, 'grad_norm': 5.810324668884277, 'learning_rate': 4.615384615384616e-05, 'epoch': 1.33}


 34%|███▍      | 610/1800 [11:27<22:30,  1.14s/it]

{'loss': 2.9191, 'grad_norm': 6.515418529510498, 'learning_rate': 4.576923076923077e-05, 'epoch': 1.36}


 34%|███▍      | 620/1800 [11:38<22:36,  1.15s/it]

{'loss': 2.8716, 'grad_norm': 6.90920877456665, 'learning_rate': 4.538461538461539e-05, 'epoch': 1.38}


 35%|███▌      | 630/1800 [11:50<22:01,  1.13s/it]

{'loss': 2.6957, 'grad_norm': 6.489469528198242, 'learning_rate': 4.5e-05, 'epoch': 1.4}


 36%|███▌      | 640/1800 [12:01<21:58,  1.14s/it]

{'loss': 2.8708, 'grad_norm': 5.293485641479492, 'learning_rate': 4.461538461538462e-05, 'epoch': 1.42}


 36%|███▌      | 650/1800 [12:13<21:49,  1.14s/it]

{'loss': 2.9164, 'grad_norm': 4.823029041290283, 'learning_rate': 4.423076923076923e-05, 'epoch': 1.44}


 37%|███▋      | 660/1800 [12:24<21:36,  1.14s/it]

{'loss': 2.7313, 'grad_norm': 4.003153324127197, 'learning_rate': 4.384615384615385e-05, 'epoch': 1.47}


 37%|███▋      | 670/1800 [12:35<21:27,  1.14s/it]

{'loss': 2.7379, 'grad_norm': 4.335102081298828, 'learning_rate': 4.346153846153846e-05, 'epoch': 1.49}


 38%|███▊      | 680/1800 [12:47<21:10,  1.13s/it]

{'loss': 2.8964, 'grad_norm': 5.0199360847473145, 'learning_rate': 4.3076923076923084e-05, 'epoch': 1.51}


 38%|███▊      | 690/1800 [12:58<21:10,  1.14s/it]

{'loss': 2.9428, 'grad_norm': 5.197119235992432, 'learning_rate': 4.269230769230769e-05, 'epoch': 1.53}


 39%|███▉      | 700/1800 [13:10<20:53,  1.14s/it]

{'loss': 2.4834, 'grad_norm': 4.481142520904541, 'learning_rate': 4.230769230769231e-05, 'epoch': 1.56}


 39%|███▉      | 710/1800 [13:21<20:46,  1.14s/it]

{'loss': 2.749, 'grad_norm': 4.698019027709961, 'learning_rate': 4.192307692307693e-05, 'epoch': 1.58}


 40%|████      | 720/1800 [14:00<41:08,  2.29s/it]  

{'loss': 2.4927, 'grad_norm': 5.780745983123779, 'learning_rate': 4.1538461538461544e-05, 'epoch': 1.6}


 41%|████      | 730/1800 [14:11<20:59,  1.18s/it]

{'loss': 2.7652, 'grad_norm': 5.399428844451904, 'learning_rate': 4.115384615384615e-05, 'epoch': 1.62}


 41%|████      | 740/1800 [14:23<20:15,  1.15s/it]

{'loss': 2.6529, 'grad_norm': 5.521068572998047, 'learning_rate': 4.0769230769230773e-05, 'epoch': 1.64}


 42%|████▏     | 750/1800 [14:34<19:48,  1.13s/it]

{'loss': 2.7448, 'grad_norm': 4.47620153427124, 'learning_rate': 4.038461538461539e-05, 'epoch': 1.67}


 42%|████▏     | 760/1800 [16:44<1:15:00,  4.33s/it] 

{'loss': 2.6947, 'grad_norm': 6.273410320281982, 'learning_rate': 4e-05, 'epoch': 1.69}


 43%|████▎     | 770/1800 [17:04<24:42,  1.44s/it]  

{'loss': 2.9258, 'grad_norm': 5.380845546722412, 'learning_rate': 3.961538461538462e-05, 'epoch': 1.71}


 43%|████▎     | 780/1800 [17:15<19:31,  1.15s/it]

{'loss': 2.611, 'grad_norm': 4.8244709968566895, 'learning_rate': 3.923076923076923e-05, 'epoch': 1.73}


 44%|████▍     | 790/1800 [17:27<19:05,  1.13s/it]

{'loss': 2.7792, 'grad_norm': 4.084115028381348, 'learning_rate': 3.884615384615385e-05, 'epoch': 1.76}


 44%|████▍     | 800/1800 [17:38<18:55,  1.14s/it]

{'loss': 2.6913, 'grad_norm': 5.667820453643799, 'learning_rate': 3.846153846153846e-05, 'epoch': 1.78}


 45%|████▌     | 810/1800 [17:50<19:06,  1.16s/it]

{'loss': 2.7946, 'grad_norm': 4.545520782470703, 'learning_rate': 3.807692307692308e-05, 'epoch': 1.8}


 46%|████▌     | 820/1800 [18:01<19:15,  1.18s/it]

{'loss': 2.6495, 'grad_norm': 4.927412986755371, 'learning_rate': 3.769230769230769e-05, 'epoch': 1.82}


 46%|████▌     | 830/1800 [18:13<19:00,  1.18s/it]

{'loss': 2.6217, 'grad_norm': 5.023070812225342, 'learning_rate': 3.730769230769231e-05, 'epoch': 1.84}


 47%|████▋     | 840/1800 [18:25<18:23,  1.15s/it]

{'loss': 2.7403, 'grad_norm': 4.304687023162842, 'learning_rate': 3.692307692307693e-05, 'epoch': 1.87}


 47%|████▋     | 850/1800 [18:36<18:19,  1.16s/it]

{'loss': 2.4058, 'grad_norm': 4.8130316734313965, 'learning_rate': 3.653846153846154e-05, 'epoch': 1.89}


 48%|████▊     | 860/1800 [18:48<18:17,  1.17s/it]

{'loss': 2.7021, 'grad_norm': 6.999162197113037, 'learning_rate': 3.615384615384615e-05, 'epoch': 1.91}


 48%|████▊     | 870/1800 [18:59<18:10,  1.17s/it]

{'loss': 2.8114, 'grad_norm': 4.399367332458496, 'learning_rate': 3.5769230769230774e-05, 'epoch': 1.93}


 49%|████▉     | 880/1800 [19:11<17:28,  1.14s/it]

{'loss': 2.9593, 'grad_norm': 5.177162170410156, 'learning_rate': 3.538461538461539e-05, 'epoch': 1.96}


 49%|████▉     | 890/1800 [19:22<17:10,  1.13s/it]

{'loss': 2.7579, 'grad_norm': 4.2247796058654785, 'learning_rate': 3.5e-05, 'epoch': 1.98}


 50%|█████     | 900/1800 [19:34<17:45,  1.18s/it]

{'loss': 2.88, 'grad_norm': 5.210589408874512, 'learning_rate': 3.461538461538462e-05, 'epoch': 2.0}


 51%|█████     | 910/1800 [19:46<16:59,  1.15s/it]

{'loss': 2.7307, 'grad_norm': 4.685246467590332, 'learning_rate': 3.4230769230769234e-05, 'epoch': 2.02}


 51%|█████     | 920/1800 [19:57<16:49,  1.15s/it]

{'loss': 2.8868, 'grad_norm': 4.709485054016113, 'learning_rate': 3.384615384615385e-05, 'epoch': 2.04}


 52%|█████▏    | 930/1800 [20:09<17:02,  1.18s/it]

{'loss': 2.6668, 'grad_norm': 5.008599281311035, 'learning_rate': 3.346153846153846e-05, 'epoch': 2.07}


 52%|█████▏    | 940/1800 [20:20<16:30,  1.15s/it]

{'loss': 2.5385, 'grad_norm': 5.821756362915039, 'learning_rate': 3.307692307692308e-05, 'epoch': 2.09}


 53%|█████▎    | 950/1800 [20:32<16:31,  1.17s/it]

{'loss': 2.6464, 'grad_norm': 3.845435619354248, 'learning_rate': 3.269230769230769e-05, 'epoch': 2.11}


 53%|█████▎    | 960/1800 [20:44<16:08,  1.15s/it]

{'loss': 2.7696, 'grad_norm': 5.644723892211914, 'learning_rate': 3.230769230769231e-05, 'epoch': 2.13}


 54%|█████▍    | 970/1800 [20:55<16:06,  1.17s/it]

{'loss': 2.5985, 'grad_norm': 4.696679592132568, 'learning_rate': 3.192307692307692e-05, 'epoch': 2.16}


 54%|█████▍    | 980/1800 [21:07<15:50,  1.16s/it]

{'loss': 2.7068, 'grad_norm': 5.09607458114624, 'learning_rate': 3.153846153846154e-05, 'epoch': 2.18}


 55%|█████▌    | 990/1800 [21:18<15:29,  1.15s/it]

{'loss': 2.598, 'grad_norm': 4.389073848724365, 'learning_rate': 3.115384615384615e-05, 'epoch': 2.2}


 56%|█████▌    | 1000/1800 [21:30<15:33,  1.17s/it]

{'loss': 2.7968, 'grad_norm': 5.239015102386475, 'learning_rate': 3.0769230769230774e-05, 'epoch': 2.22}


 56%|█████▌    | 1010/1800 [21:44<16:05,  1.22s/it]

{'loss': 2.5156, 'grad_norm': 4.50377082824707, 'learning_rate': 3.0384615384615382e-05, 'epoch': 2.24}


 57%|█████▋    | 1020/1800 [21:56<15:09,  1.17s/it]

{'loss': 2.3229, 'grad_norm': 4.623703479766846, 'learning_rate': 3e-05, 'epoch': 2.27}


 57%|█████▋    | 1030/1800 [22:07<14:52,  1.16s/it]

{'loss': 2.6761, 'grad_norm': 4.287026405334473, 'learning_rate': 2.9615384615384616e-05, 'epoch': 2.29}


 58%|█████▊    | 1040/1800 [22:19<14:31,  1.15s/it]

{'loss': 2.6123, 'grad_norm': 4.836230278015137, 'learning_rate': 2.9230769230769234e-05, 'epoch': 2.31}


 58%|█████▊    | 1050/1800 [22:30<14:24,  1.15s/it]

{'loss': 2.4015, 'grad_norm': 4.847639560699463, 'learning_rate': 2.8846153846153845e-05, 'epoch': 2.33}


 59%|█████▉    | 1060/1800 [22:42<14:15,  1.16s/it]

{'loss': 2.921, 'grad_norm': 4.598034858703613, 'learning_rate': 2.846153846153846e-05, 'epoch': 2.36}


 59%|█████▉    | 1070/1800 [22:54<14:15,  1.17s/it]

{'loss': 2.4997, 'grad_norm': 6.156018257141113, 'learning_rate': 2.807692307692308e-05, 'epoch': 2.38}


 60%|██████    | 1080/1800 [23:05<13:37,  1.13s/it]

{'loss': 2.7379, 'grad_norm': 5.117166519165039, 'learning_rate': 2.7692307692307694e-05, 'epoch': 2.4}


 61%|██████    | 1090/1800 [23:17<13:41,  1.16s/it]

{'loss': 2.7722, 'grad_norm': 6.967894554138184, 'learning_rate': 2.7307692307692305e-05, 'epoch': 2.42}


 61%|██████    | 1100/1800 [23:28<13:11,  1.13s/it]

{'loss': 2.5605, 'grad_norm': 5.350771427154541, 'learning_rate': 2.6923076923076923e-05, 'epoch': 2.44}


 62%|██████▏   | 1110/1800 [23:39<12:49,  1.12s/it]

{'loss': 2.844, 'grad_norm': 3.983130693435669, 'learning_rate': 2.6538461538461538e-05, 'epoch': 2.47}


 62%|██████▏   | 1120/1800 [23:50<12:36,  1.11s/it]

{'loss': 2.7002, 'grad_norm': 4.61122989654541, 'learning_rate': 2.6153846153846157e-05, 'epoch': 2.49}


 63%|██████▎   | 1130/1800 [24:01<12:29,  1.12s/it]

{'loss': 2.5994, 'grad_norm': 6.163203239440918, 'learning_rate': 2.5769230769230768e-05, 'epoch': 2.51}


 63%|██████▎   | 1140/1800 [24:12<12:17,  1.12s/it]

{'loss': 2.4651, 'grad_norm': 4.98346471786499, 'learning_rate': 2.5384615384615383e-05, 'epoch': 2.53}


 64%|██████▍   | 1150/1800 [24:24<12:02,  1.11s/it]

{'loss': 2.628, 'grad_norm': 5.421518325805664, 'learning_rate': 2.5e-05, 'epoch': 2.56}


 64%|██████▍   | 1160/1800 [24:35<11:54,  1.12s/it]

{'loss': 2.6141, 'grad_norm': 5.16984748840332, 'learning_rate': 2.461538461538462e-05, 'epoch': 2.58}


 65%|██████▌   | 1170/1800 [24:46<11:47,  1.12s/it]

{'loss': 2.6825, 'grad_norm': 4.5502142906188965, 'learning_rate': 2.423076923076923e-05, 'epoch': 2.6}


 66%|██████▌   | 1180/1800 [24:57<11:35,  1.12s/it]

{'loss': 2.5199, 'grad_norm': 4.103640556335449, 'learning_rate': 2.384615384615385e-05, 'epoch': 2.62}


 66%|██████▌   | 1190/1800 [25:08<11:18,  1.11s/it]

{'loss': 2.5666, 'grad_norm': 6.375410079956055, 'learning_rate': 2.3461538461538464e-05, 'epoch': 2.64}


 67%|██████▋   | 1200/1800 [25:19<11:11,  1.12s/it]

{'loss': 2.7372, 'grad_norm': 5.276984214782715, 'learning_rate': 2.307692307692308e-05, 'epoch': 2.67}


 67%|██████▋   | 1210/1800 [25:31<11:00,  1.12s/it]

{'loss': 2.516, 'grad_norm': 4.8207197189331055, 'learning_rate': 2.2692307692307694e-05, 'epoch': 2.69}


 68%|██████▊   | 1220/1800 [25:42<10:46,  1.11s/it]

{'loss': 2.6222, 'grad_norm': 4.4827399253845215, 'learning_rate': 2.230769230769231e-05, 'epoch': 2.71}


 68%|██████▊   | 1230/1800 [25:53<10:34,  1.11s/it]

{'loss': 2.6093, 'grad_norm': 3.8023273944854736, 'learning_rate': 2.1923076923076924e-05, 'epoch': 2.73}


 69%|██████▉   | 1240/1800 [26:04<10:30,  1.13s/it]

{'loss': 2.564, 'grad_norm': 4.189753532409668, 'learning_rate': 2.1538461538461542e-05, 'epoch': 2.76}


 69%|██████▉   | 1250/1800 [26:15<10:13,  1.12s/it]

{'loss': 2.7445, 'grad_norm': 5.545016288757324, 'learning_rate': 2.1153846153846154e-05, 'epoch': 2.78}


 70%|███████   | 1260/1800 [26:26<10:07,  1.12s/it]

{'loss': 2.5834, 'grad_norm': 4.8915863037109375, 'learning_rate': 2.0769230769230772e-05, 'epoch': 2.8}


 71%|███████   | 1270/1800 [26:38<09:52,  1.12s/it]

{'loss': 2.6319, 'grad_norm': 4.466627597808838, 'learning_rate': 2.0384615384615387e-05, 'epoch': 2.82}


 71%|███████   | 1280/1800 [26:49<09:39,  1.11s/it]

{'loss': 2.6419, 'grad_norm': 6.012040138244629, 'learning_rate': 2e-05, 'epoch': 2.84}


 72%|███████▏  | 1290/1800 [27:00<09:28,  1.11s/it]

{'loss': 2.6683, 'grad_norm': 4.734457492828369, 'learning_rate': 1.9615384615384617e-05, 'epoch': 2.87}


 72%|███████▏  | 1300/1800 [27:11<09:17,  1.11s/it]

{'loss': 2.8396, 'grad_norm': 5.664356708526611, 'learning_rate': 1.923076923076923e-05, 'epoch': 2.89}


 73%|███████▎  | 1310/1800 [27:22<09:06,  1.12s/it]

{'loss': 2.8956, 'grad_norm': 5.189283847808838, 'learning_rate': 1.8846153846153846e-05, 'epoch': 2.91}


 73%|███████▎  | 1320/1800 [27:33<08:57,  1.12s/it]

{'loss': 2.504, 'grad_norm': 6.218538761138916, 'learning_rate': 1.8461538461538465e-05, 'epoch': 2.93}


 74%|███████▍  | 1330/1800 [27:45<08:47,  1.12s/it]

{'loss': 2.9013, 'grad_norm': 4.707559108734131, 'learning_rate': 1.8076923076923076e-05, 'epoch': 2.96}


 74%|███████▍  | 1340/1800 [27:56<08:29,  1.11s/it]

{'loss': 2.8738, 'grad_norm': 5.656085014343262, 'learning_rate': 1.7692307692307694e-05, 'epoch': 2.98}


 75%|███████▌  | 1350/1800 [28:07<08:26,  1.13s/it]

{'loss': 2.8076, 'grad_norm': 5.357631683349609, 'learning_rate': 1.730769230769231e-05, 'epoch': 3.0}


 76%|███████▌  | 1360/1800 [28:18<08:10,  1.11s/it]

{'loss': 2.6687, 'grad_norm': 5.221198081970215, 'learning_rate': 1.6923076923076924e-05, 'epoch': 3.02}


 76%|███████▌  | 1370/1800 [28:29<08:00,  1.12s/it]

{'loss': 2.5903, 'grad_norm': 5.144720554351807, 'learning_rate': 1.653846153846154e-05, 'epoch': 3.04}


 77%|███████▋  | 1380/1800 [28:40<07:53,  1.13s/it]

{'loss': 2.7398, 'grad_norm': 4.758065700531006, 'learning_rate': 1.6153846153846154e-05, 'epoch': 3.07}


 77%|███████▋  | 1390/1800 [28:52<07:37,  1.12s/it]

{'loss': 2.5693, 'grad_norm': 5.217372417449951, 'learning_rate': 1.576923076923077e-05, 'epoch': 3.09}


 78%|███████▊  | 1400/1800 [29:03<07:25,  1.11s/it]

{'loss': 2.5819, 'grad_norm': 4.943755626678467, 'learning_rate': 1.5384615384615387e-05, 'epoch': 3.11}


 78%|███████▊  | 1410/1800 [29:14<07:14,  1.11s/it]

{'loss': 2.4794, 'grad_norm': 5.843470096588135, 'learning_rate': 1.5e-05, 'epoch': 3.13}


 79%|███████▉  | 1420/1800 [29:25<07:04,  1.12s/it]

{'loss': 2.401, 'grad_norm': 4.616424083709717, 'learning_rate': 1.4615384615384617e-05, 'epoch': 3.16}


 79%|███████▉  | 1430/1800 [29:36<06:53,  1.12s/it]

{'loss': 2.775, 'grad_norm': 4.9420013427734375, 'learning_rate': 1.423076923076923e-05, 'epoch': 3.18}


 80%|████████  | 1440/1800 [29:47<06:41,  1.11s/it]

{'loss': 2.4526, 'grad_norm': 5.488192081451416, 'learning_rate': 1.3846153846153847e-05, 'epoch': 3.2}


 81%|████████  | 1450/1800 [29:59<06:29,  1.11s/it]

{'loss': 2.4801, 'grad_norm': 3.228651285171509, 'learning_rate': 1.3461538461538462e-05, 'epoch': 3.22}


 81%|████████  | 1460/1800 [30:10<06:16,  1.11s/it]

{'loss': 2.697, 'grad_norm': 5.4533257484436035, 'learning_rate': 1.3076923076923078e-05, 'epoch': 3.24}


 82%|████████▏ | 1470/1800 [30:21<06:11,  1.13s/it]

{'loss': 2.7487, 'grad_norm': 4.624842166900635, 'learning_rate': 1.2692307692307691e-05, 'epoch': 3.27}


 82%|████████▏ | 1480/1800 [30:32<05:59,  1.12s/it]

{'loss': 2.616, 'grad_norm': 6.093351364135742, 'learning_rate': 1.230769230769231e-05, 'epoch': 3.29}


 83%|████████▎ | 1490/1800 [30:43<05:47,  1.12s/it]

{'loss': 2.4466, 'grad_norm': 5.3667707443237305, 'learning_rate': 1.1923076923076925e-05, 'epoch': 3.31}


 83%|████████▎ | 1500/1800 [30:54<05:33,  1.11s/it]

{'loss': 2.7349, 'grad_norm': 4.900928020477295, 'learning_rate': 1.153846153846154e-05, 'epoch': 3.33}


 84%|████████▍ | 1510/1800 [31:07<05:32,  1.15s/it]

{'loss': 2.4331, 'grad_norm': 4.059223175048828, 'learning_rate': 1.1153846153846154e-05, 'epoch': 3.36}


 84%|████████▍ | 1520/1800 [31:19<05:14,  1.12s/it]

{'loss': 2.6059, 'grad_norm': 5.052735805511475, 'learning_rate': 1.0769230769230771e-05, 'epoch': 3.38}


 85%|████████▌ | 1530/1800 [31:30<05:01,  1.12s/it]

{'loss': 2.5397, 'grad_norm': 5.0495452880859375, 'learning_rate': 1.0384615384615386e-05, 'epoch': 3.4}


 86%|████████▌ | 1540/1800 [31:41<04:51,  1.12s/it]

{'loss': 2.2915, 'grad_norm': 3.97147274017334, 'learning_rate': 1e-05, 'epoch': 3.42}


 86%|████████▌ | 1550/1800 [31:52<04:40,  1.12s/it]

{'loss': 2.5833, 'grad_norm': 4.197085857391357, 'learning_rate': 9.615384615384616e-06, 'epoch': 3.44}


 87%|████████▋ | 1560/1800 [32:03<04:30,  1.13s/it]

{'loss': 2.624, 'grad_norm': 4.936925888061523, 'learning_rate': 9.230769230769232e-06, 'epoch': 3.47}


 87%|████████▋ | 1570/1800 [32:15<04:16,  1.11s/it]

{'loss': 2.4642, 'grad_norm': 4.397681713104248, 'learning_rate': 8.846153846153847e-06, 'epoch': 3.49}


 88%|████████▊ | 1580/1800 [32:26<04:05,  1.11s/it]

{'loss': 2.7881, 'grad_norm': 4.2447638511657715, 'learning_rate': 8.461538461538462e-06, 'epoch': 3.51}


 88%|████████▊ | 1590/1800 [32:37<03:54,  1.11s/it]

{'loss': 2.4938, 'grad_norm': 4.711089134216309, 'learning_rate': 8.076923076923077e-06, 'epoch': 3.53}


 89%|████████▉ | 1600/1800 [32:48<03:42,  1.11s/it]

{'loss': 2.9566, 'grad_norm': 5.316188812255859, 'learning_rate': 7.692307692307694e-06, 'epoch': 3.56}


 89%|████████▉ | 1610/1800 [32:59<03:31,  1.11s/it]

{'loss': 2.7749, 'grad_norm': 4.024603843688965, 'learning_rate': 7.3076923076923085e-06, 'epoch': 3.58}


 90%|█████████ | 1620/1800 [33:10<03:21,  1.12s/it]

{'loss': 2.9238, 'grad_norm': 5.487823009490967, 'learning_rate': 6.923076923076923e-06, 'epoch': 3.6}


 91%|█████████ | 1630/1800 [33:21<03:07,  1.10s/it]

{'loss': 2.5878, 'grad_norm': 4.779064655303955, 'learning_rate': 6.538461538461539e-06, 'epoch': 3.62}


 91%|█████████ | 1640/1800 [33:32<02:58,  1.12s/it]

{'loss': 2.544, 'grad_norm': 4.072878837585449, 'learning_rate': 6.153846153846155e-06, 'epoch': 3.64}


 92%|█████████▏| 1650/1800 [33:44<02:48,  1.13s/it]

{'loss': 2.6982, 'grad_norm': 5.484775066375732, 'learning_rate': 5.76923076923077e-06, 'epoch': 3.67}


 92%|█████████▏| 1660/1800 [33:55<02:36,  1.12s/it]

{'loss': 2.4366, 'grad_norm': 4.214024066925049, 'learning_rate': 5.3846153846153855e-06, 'epoch': 3.69}


 93%|█████████▎| 1670/1800 [34:06<02:23,  1.10s/it]

{'loss': 2.2877, 'grad_norm': 4.663923263549805, 'learning_rate': 5e-06, 'epoch': 3.71}


 93%|█████████▎| 1680/1800 [34:17<02:14,  1.12s/it]

{'loss': 2.9013, 'grad_norm': 5.9113311767578125, 'learning_rate': 4.615384615384616e-06, 'epoch': 3.73}


 94%|█████████▍| 1690/1800 [34:28<02:03,  1.12s/it]

{'loss': 2.5475, 'grad_norm': 5.019940376281738, 'learning_rate': 4.230769230769231e-06, 'epoch': 3.76}


 94%|█████████▍| 1700/1800 [34:39<01:51,  1.12s/it]

{'loss': 2.6835, 'grad_norm': 5.465148448944092, 'learning_rate': 3.846153846153847e-06, 'epoch': 3.78}


 95%|█████████▌| 1710/1800 [34:51<01:40,  1.11s/it]

{'loss': 2.5629, 'grad_norm': 5.356286525726318, 'learning_rate': 3.4615384615384617e-06, 'epoch': 3.8}


 96%|█████████▌| 1720/1800 [35:02<01:29,  1.12s/it]

{'loss': 2.4002, 'grad_norm': 5.697914123535156, 'learning_rate': 3.0769230769230774e-06, 'epoch': 3.82}


 96%|█████████▌| 1730/1800 [35:13<01:18,  1.12s/it]

{'loss': 2.4346, 'grad_norm': 4.65962553024292, 'learning_rate': 2.6923076923076928e-06, 'epoch': 3.84}


 97%|█████████▋| 1740/1800 [35:24<01:06,  1.11s/it]

{'loss': 2.5952, 'grad_norm': 4.557797431945801, 'learning_rate': 2.307692307692308e-06, 'epoch': 3.87}


 97%|█████████▋| 1750/1800 [35:35<00:55,  1.10s/it]

{'loss': 2.3376, 'grad_norm': 4.63007116317749, 'learning_rate': 1.9230769230769234e-06, 'epoch': 3.89}


 98%|█████████▊| 1760/1800 [35:46<00:44,  1.11s/it]

{'loss': 2.3564, 'grad_norm': 4.291720390319824, 'learning_rate': 1.5384615384615387e-06, 'epoch': 3.91}


 98%|█████████▊| 1770/1800 [35:57<00:33,  1.12s/it]

{'loss': 2.7338, 'grad_norm': 6.100767135620117, 'learning_rate': 1.153846153846154e-06, 'epoch': 3.93}


 99%|█████████▉| 1780/1800 [36:09<00:22,  1.11s/it]

{'loss': 2.6402, 'grad_norm': 4.418018341064453, 'learning_rate': 7.692307692307694e-07, 'epoch': 3.96}


 99%|█████████▉| 1790/1800 [36:20<00:11,  1.11s/it]

{'loss': 2.4178, 'grad_norm': 3.7981481552124023, 'learning_rate': 3.846153846153847e-07, 'epoch': 3.98}


100%|██████████| 1800/1800 [36:31<00:00,  1.22s/it]

{'loss': 2.6796, 'grad_norm': 6.711416244506836, 'learning_rate': 0.0, 'epoch': 4.0}
{'train_runtime': 2191.2813, 'train_samples_per_second': 3.286, 'train_steps_per_second': 0.821, 'train_loss': 2.753595750596788, 'epoch': 4.0}





TrainOutput(global_step=1800, training_loss=2.753595750596788, metrics={'train_runtime': 2191.2813, 'train_samples_per_second': 3.286, 'train_steps_per_second': 0.821, 'total_flos': 1881302630400000.0, 'train_loss': 2.753595750596788, 'epoch': 4.0})

In [21]:
model.save_pretrained('./results')
tokenizer.save_pretrained('./results')

# Load the model and tokenizer for text generation
from transformers import pipeline

# Ensure your model and tokenizer are loaded correctly
diomedes = pipeline('text-generation', model='./results', tokenizer='./results')

# Generate text using the pipeline
results = diomedes('Hello ', max_length=600)
print(results[0]['generated_text'])


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Hello ʃ
I am the one, you know
I am the one

If you go far in my path you will see me
My home is where you be reborn
There was no night without your name
I stay on your side
Keep my word

If you feel the same way in every second I will
Stay my side
Keep my word

If you say what you want all over again you will
Stay behind my back
And if you break, come back again
You will see me right out of sight
Your home will forever be a place I will stand

If you don't choose a new name I will remain your side

If you don't choose a new name I will remain and forever I will stay
If you don’t choose a new name I will stay and forever I will stay

If you don’t choose a new name I will stay and forever I will stay

If you don’t choose a new name I will stay and forever I will stay
If you don’t choose a new name I will stay and forever I will stay
If you don’t choose a new name I will stay and forever I will stay
I will stay on the side
Keep my word

If you feel the same way in every second I will fly

In [51]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset, DatasetDict

# Load tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Ensure that tokenizer has padding token set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Assuming the CSV file is at the path 'path_to_csv', and it includes 'tag' and 'lyrics' columns
dataset = load_dataset("amishshah/song_lyrics")
dataset = dataset["train"].shuffle(seed=42)
subset_size = 10000  # Adjust as necessary
dataset = dataset.select(range(subset_size))
train_test_dataset = dataset.train_test_split(test_size=0.1)
train_dataset = train_test_dataset["train"]
val_dataset = train_test_dataset["test"]

def tokenize_function(examples):
    # Prepend the tag to each lyric in the batch
    concatenated_lyrics = ["[Genre: " + tag + "] " + lyric for tag, lyric in zip(examples["tag"], examples["lyrics"])]
    return tokenizer(concatenated_lyrics, truncation=True, padding="max_length", max_length=512)


train_dataset = train_dataset.map(tokenize_function, batched=True)
val_dataset = val_dataset.map(tokenize_function, batched=True)

In [52]:
# Set training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=100,
)

# Initialize Trainer
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

# Train the model
trainer.train()

  6%|▌         | 555/9000 [1:39:22<25:12:01, 10.74s/it]
  1%|          | 100/9000 [01:52<2:45:14,  1.11s/it]
  1%|          | 100/9000 [01:52<2:45:14,  1.11s/it]

{'loss': 3.1795, 'grad_norm': 5.191074371337891, 'learning_rate': 1e-05, 'epoch': 0.04}


  2%|▏         | 200/9000 [03:44<2:43:59,  1.12s/it]
  2%|▏         | 200/9000 [03:44<2:43:59,  1.12s/it]

{'loss': 3.0335, 'grad_norm': 5.126195430755615, 'learning_rate': 2e-05, 'epoch': 0.09}


  3%|▎         | 300/9000 [05:36<2:42:10,  1.12s/it]
  3%|▎         | 300/9000 [05:36<2:42:10,  1.12s/it]

{'loss': 2.8782, 'grad_norm': 5.482480049133301, 'learning_rate': 3e-05, 'epoch': 0.13}


  4%|▍         | 400/9000 [07:28<2:40:20,  1.12s/it]
  4%|▍         | 400/9000 [07:28<2:40:20,  1.12s/it]

{'loss': 2.9091, 'grad_norm': 6.005352020263672, 'learning_rate': 4e-05, 'epoch': 0.18}


  6%|▌         | 500/9000 [09:20<2:38:49,  1.12s/it]
  6%|▌         | 500/9000 [09:20<2:38:49,  1.12s/it]

{'loss': 2.8845, 'grad_norm': 6.43183708190918, 'learning_rate': 5e-05, 'epoch': 0.22}


  7%|▋         | 600/9000 [11:14<2:34:23,  1.10s/it]
  7%|▋         | 600/9000 [11:14<2:34:23,  1.10s/it]

{'loss': 2.7924, 'grad_norm': 4.062376022338867, 'learning_rate': 4.9411764705882355e-05, 'epoch': 0.27}


  8%|▊         | 700/9000 [13:05<2:36:11,  1.13s/it]
  8%|▊         | 700/9000 [13:05<2:36:11,  1.13s/it]

{'loss': 2.9109, 'grad_norm': 4.2698073387146, 'learning_rate': 4.882352941176471e-05, 'epoch': 0.31}


  9%|▉         | 800/9000 [14:56<2:33:25,  1.12s/it]
  9%|▉         | 800/9000 [14:56<2:33:25,  1.12s/it]

{'loss': 2.7958, 'grad_norm': 2.501638174057007, 'learning_rate': 4.823529411764706e-05, 'epoch': 0.36}


 10%|█         | 900/9000 [16:47<2:30:19,  1.11s/it]
 10%|█         | 900/9000 [16:47<2:30:19,  1.11s/it]

{'loss': 2.8094, 'grad_norm': 2.873769521713257, 'learning_rate': 4.7647058823529414e-05, 'epoch': 0.4}


 11%|█         | 1000/9000 [18:39<2:28:25,  1.11s/it]
 11%|█         | 1000/9000 [18:39<2:28:25,  1.11s/it]

{'loss': 2.8657, 'grad_norm': 2.8658206462860107, 'learning_rate': 4.705882352941177e-05, 'epoch': 0.44}


 12%|█▏        | 1100/9000 [20:33<2:25:25,  1.10s/it]
 12%|█▏        | 1100/9000 [20:33<2:25:25,  1.10s/it]

{'loss': 2.7765, 'grad_norm': 2.2796690464019775, 'learning_rate': 4.647058823529412e-05, 'epoch': 0.49}


 13%|█▎        | 1200/9000 [22:24<2:23:58,  1.11s/it]
 13%|█▎        | 1200/9000 [22:24<2:23:58,  1.11s/it]

{'loss': 2.8443, 'grad_norm': 2.3918046951293945, 'learning_rate': 4.588235294117647e-05, 'epoch': 0.53}


 14%|█▍        | 1300/9000 [24:15<2:22:17,  1.11s/it]
 14%|█▍        | 1300/9000 [24:15<2:22:17,  1.11s/it]

{'loss': 2.8646, 'grad_norm': 2.18247652053833, 'learning_rate': 4.5294117647058826e-05, 'epoch': 0.58}


 16%|█▌        | 1400/9000 [26:06<2:21:23,  1.12s/it]
 16%|█▌        | 1400/9000 [26:06<2:21:23,  1.12s/it]

{'loss': 2.7581, 'grad_norm': 1.7247874736785889, 'learning_rate': 4.470588235294118e-05, 'epoch': 0.62}


 17%|█▋        | 1500/9000 [28:07<2:21:15,  1.13s/it]
 17%|█▋        | 1500/9000 [28:07<2:21:15,  1.13s/it]

{'loss': 2.7235, 'grad_norm': 2.5414185523986816, 'learning_rate': 4.411764705882353e-05, 'epoch': 0.67}


 18%|█▊        | 1600/9000 [30:02<2:18:22,  1.12s/it]
 18%|█▊        | 1600/9000 [30:02<2:18:22,  1.12s/it]

{'loss': 2.7131, 'grad_norm': 1.8392738103866577, 'learning_rate': 4.3529411764705885e-05, 'epoch': 0.71}


 19%|█▉        | 1700/9000 [31:54<2:15:11,  1.11s/it]
 19%|█▉        | 1700/9000 [31:54<2:15:11,  1.11s/it]

{'loss': 2.751, 'grad_norm': 1.9378267526626587, 'learning_rate': 4.294117647058823e-05, 'epoch': 0.76}


 20%|██        | 1800/9000 [33:46<2:14:56,  1.12s/it]
 20%|██        | 1800/9000 [33:46<2:14:56,  1.12s/it]

{'loss': 2.7167, 'grad_norm': 1.8970435857772827, 'learning_rate': 4.235294117647059e-05, 'epoch': 0.8}


 21%|██        | 1900/9000 [35:37<2:11:12,  1.11s/it]
 21%|██        | 1900/9000 [35:37<2:11:12,  1.11s/it]

{'loss': 2.7823, 'grad_norm': 1.760888934135437, 'learning_rate': 4.1764705882352944e-05, 'epoch': 0.84}


 22%|██▏       | 2000/9000 [37:28<2:11:01,  1.12s/it]
 22%|██▏       | 2000/9000 [37:28<2:11:01,  1.12s/it]

{'loss': 2.7613, 'grad_norm': 2.151956796646118, 'learning_rate': 4.11764705882353e-05, 'epoch': 0.89}


 23%|██▎       | 2100/9000 [39:22<2:09:17,  1.12s/it]
 23%|██▎       | 2100/9000 [39:22<2:09:17,  1.12s/it]

{'loss': 2.7527, 'grad_norm': 2.0038342475891113, 'learning_rate': 4.058823529411765e-05, 'epoch': 0.93}


 24%|██▍       | 2200/9000 [41:14<2:06:05,  1.11s/it]
 24%|██▍       | 2200/9000 [41:14<2:06:05,  1.11s/it]

{'loss': 2.7848, 'grad_norm': 1.525112271308899, 'learning_rate': 4e-05, 'epoch': 0.98}


 26%|██▌       | 2300/9000 [43:08<2:13:44,  1.20s/it]
 26%|██▌       | 2300/9000 [43:08<2:13:44,  1.20s/it]

{'loss': 2.8023, 'grad_norm': 2.02986478805542, 'learning_rate': 3.9411764705882356e-05, 'epoch': 1.02}


 27%|██▋       | 2400/9000 [45:03<2:06:05,  1.15s/it]
 27%|██▋       | 2400/9000 [45:03<2:06:05,  1.15s/it]

{'loss': 2.6343, 'grad_norm': 2.471766710281372, 'learning_rate': 3.882352941176471e-05, 'epoch': 1.07}


 28%|██▊       | 2500/9000 [46:57<2:03:03,  1.14s/it]
 28%|██▊       | 2500/9000 [46:57<2:03:03,  1.14s/it]

{'loss': 2.7019, 'grad_norm': 1.4912142753601074, 'learning_rate': 3.8235294117647055e-05, 'epoch': 1.11}


 29%|██▉       | 2600/9000 [48:53<1:59:55,  1.12s/it]
 29%|██▉       | 2600/9000 [48:53<1:59:55,  1.12s/it]

{'loss': 2.6765, 'grad_norm': 1.479359745979309, 'learning_rate': 3.7647058823529415e-05, 'epoch': 1.16}


 30%|███       | 2700/9000 [50:47<1:57:16,  1.12s/it]
 30%|███       | 2700/9000 [50:47<1:57:16,  1.12s/it]

{'loss': 2.6146, 'grad_norm': 1.7265595197677612, 'learning_rate': 3.705882352941177e-05, 'epoch': 1.2}


 31%|███       | 2800/9000 [52:41<1:57:50,  1.14s/it]
 31%|███       | 2800/9000 [52:41<1:57:50,  1.14s/it] 

{'loss': 2.633, 'grad_norm': 2.0263445377349854, 'learning_rate': 3.6470588235294114e-05, 'epoch': 1.24}


 32%|███▏      | 2900/9000 [54:35<1:56:54,  1.15s/it]
 32%|███▏      | 2900/9000 [54:35<1:56:54,  1.15s/it] 

{'loss': 2.6509, 'grad_norm': 2.2399399280548096, 'learning_rate': 3.5882352941176474e-05, 'epoch': 1.29}


 33%|███▎      | 3000/9000 [56:28<1:51:54,  1.12s/it]
 33%|███▎      | 3000/9000 [56:28<1:51:54,  1.12s/it] 

{'loss': 2.6395, 'grad_norm': 1.793200969696045, 'learning_rate': 3.529411764705883e-05, 'epoch': 1.33}


 34%|███▍      | 3100/9000 [58:23<1:49:47,  1.12s/it]
 34%|███▍      | 3100/9000 [58:23<1:49:47,  1.12s/it] 

{'loss': 2.6032, 'grad_norm': 1.782893419265747, 'learning_rate': 3.470588235294118e-05, 'epoch': 1.38}


 36%|███▌      | 3200/9000 [1:00:14<1:49:11,  1.13s/it]
 36%|███▌      | 3200/9000 [1:00:14<1:49:11,  1.13s/it]

{'loss': 2.6524, 'grad_norm': 1.8332531452178955, 'learning_rate': 3.411764705882353e-05, 'epoch': 1.42}


 37%|███▋      | 3300/9000 [1:02:07<1:47:07,  1.13s/it]
 37%|███▋      | 3300/9000 [1:02:07<1:47:07,  1.13s/it]

{'loss': 2.6657, 'grad_norm': 2.1436023712158203, 'learning_rate': 3.352941176470588e-05, 'epoch': 1.47}


 38%|███▊      | 3400/9000 [1:03:59<1:44:47,  1.12s/it]
 38%|███▊      | 3400/9000 [1:03:59<1:44:47,  1.12s/it]

{'loss': 2.6244, 'grad_norm': 1.6253551244735718, 'learning_rate': 3.294117647058824e-05, 'epoch': 1.51}


 39%|███▉      | 3500/9000 [1:05:52<1:43:07,  1.12s/it]
 39%|███▉      | 3500/9000 [1:05:52<1:43:07,  1.12s/it]

{'loss': 2.6421, 'grad_norm': 2.2136332988739014, 'learning_rate': 3.235294117647059e-05, 'epoch': 1.56}


 40%|████      | 3600/9000 [1:07:46<1:41:28,  1.13s/it]
 40%|████      | 3600/9000 [1:07:46<1:41:28,  1.13s/it]

{'loss': 2.5957, 'grad_norm': 1.8594571352005005, 'learning_rate': 3.176470588235294e-05, 'epoch': 1.6}


 41%|████      | 3700/9000 [1:09:37<1:37:42,  1.11s/it]
 41%|████      | 3700/9000 [1:09:37<1:37:42,  1.11s/it]

{'loss': 2.5908, 'grad_norm': 1.4845479726791382, 'learning_rate': 3.11764705882353e-05, 'epoch': 1.64}


 42%|████▏     | 3800/9000 [1:11:27<1:34:42,  1.09s/it]
 42%|████▏     | 3800/9000 [1:11:27<1:34:42,  1.09s/it]

{'loss': 2.6166, 'grad_norm': 1.806246280670166, 'learning_rate': 3.058823529411765e-05, 'epoch': 1.69}


 43%|████▎     | 3900/9000 [1:13:18<1:32:54,  1.09s/it]
 43%|████▎     | 3900/9000 [1:13:18<1:32:54,  1.09s/it]

{'loss': 2.6256, 'grad_norm': 1.546592354774475, 'learning_rate': 3e-05, 'epoch': 1.73}


 44%|████▍     | 4000/9000 [1:15:08<1:30:56,  1.09s/it]
 44%|████▍     | 4000/9000 [1:15:08<1:30:56,  1.09s/it]

{'loss': 2.5553, 'grad_norm': 1.5444260835647583, 'learning_rate': 2.9411764705882354e-05, 'epoch': 1.78}


 46%|████▌     | 4100/9000 [1:17:01<1:30:25,  1.11s/it]
 46%|████▌     | 4100/9000 [1:17:01<1:30:25,  1.11s/it]

{'loss': 2.6342, 'grad_norm': 1.5049240589141846, 'learning_rate': 2.8823529411764703e-05, 'epoch': 1.82}


 47%|████▋     | 4200/9000 [1:18:52<1:28:09,  1.10s/it]
 47%|████▋     | 4200/9000 [1:18:52<1:28:09,  1.10s/it]

{'loss': 2.6398, 'grad_norm': 1.5949453115463257, 'learning_rate': 2.823529411764706e-05, 'epoch': 1.87}


 48%|████▊     | 4300/9000 [1:20:44<1:27:30,  1.12s/it]
 48%|████▊     | 4300/9000 [1:20:44<1:27:30,  1.12s/it]

{'loss': 2.671, 'grad_norm': 1.9855999946594238, 'learning_rate': 2.7647058823529416e-05, 'epoch': 1.91}


 49%|████▉     | 4400/9000 [1:22:36<1:24:49,  1.11s/it]
 49%|████▉     | 4400/9000 [1:22:36<1:24:49,  1.11s/it]

{'loss': 2.7205, 'grad_norm': 1.6320282220840454, 'learning_rate': 2.7058823529411766e-05, 'epoch': 1.96}


 50%|█████     | 4500/9000 [1:24:27<1:23:35,  1.11s/it]
 50%|█████     | 4500/9000 [1:24:27<1:23:35,  1.11s/it]

{'loss': 2.6107, 'grad_norm': 1.5965609550476074, 'learning_rate': 2.647058823529412e-05, 'epoch': 2.0}


 51%|█████     | 4600/9000 [1:26:19<1:20:49,  1.10s/it]
 51%|█████     | 4600/9000 [1:26:19<1:20:49,  1.10s/it]

{'loss': 2.5812, 'grad_norm': 1.6807786226272583, 'learning_rate': 2.5882352941176475e-05, 'epoch': 2.04}


 52%|█████▏    | 4700/9000 [1:29:13<1:27:02,  1.21s/it] 
 52%|█████▏    | 4700/9000 [1:29:13<1:27:02,  1.21s/it]

{'loss': 2.616, 'grad_norm': 1.6068819761276245, 'learning_rate': 2.5294117647058825e-05, 'epoch': 2.09}


 53%|█████▎    | 4800/9000 [1:31:05<1:17:55,  1.11s/it]
 53%|█████▎    | 4800/9000 [1:31:05<1:17:55,  1.11s/it]

{'loss': 2.556, 'grad_norm': 1.577765703201294, 'learning_rate': 2.4705882352941178e-05, 'epoch': 2.13}


 54%|█████▍    | 4900/9000 [1:32:57<1:16:32,  1.12s/it]
 54%|█████▍    | 4900/9000 [1:32:57<1:16:32,  1.12s/it]

{'loss': 2.5978, 'grad_norm': 1.6630079746246338, 'learning_rate': 2.411764705882353e-05, 'epoch': 2.18}


 56%|█████▌    | 5000/9000 [1:34:48<1:14:51,  1.12s/it]
 56%|█████▌    | 5000/9000 [1:34:48<1:14:51,  1.12s/it]

{'loss': 2.5958, 'grad_norm': 2.3526158332824707, 'learning_rate': 2.3529411764705884e-05, 'epoch': 2.22}


 57%|█████▋    | 5100/9000 [1:36:42<1:13:12,  1.13s/it]
 57%|█████▋    | 5100/9000 [1:36:42<1:13:12,  1.13s/it]

{'loss': 2.5709, 'grad_norm': 1.7172824144363403, 'learning_rate': 2.2941176470588237e-05, 'epoch': 2.27}


 58%|█████▊    | 5200/9000 [1:38:33<1:10:17,  1.11s/it]
 58%|█████▊    | 5200/9000 [1:38:33<1:10:17,  1.11s/it]

{'loss': 2.5832, 'grad_norm': 1.9749181270599365, 'learning_rate': 2.235294117647059e-05, 'epoch': 2.31}


 59%|█████▉    | 5300/9000 [1:40:25<1:08:27,  1.11s/it]
 59%|█████▉    | 5300/9000 [1:40:25<1:08:27,  1.11s/it]

{'loss': 2.5287, 'grad_norm': 1.8011837005615234, 'learning_rate': 2.1764705882352943e-05, 'epoch': 2.36}


 60%|██████    | 5400/9000 [1:42:17<1:06:31,  1.11s/it]
 60%|██████    | 5400/9000 [1:42:17<1:06:31,  1.11s/it]

{'loss': 2.5467, 'grad_norm': 1.7921557426452637, 'learning_rate': 2.1176470588235296e-05, 'epoch': 2.4}


 61%|██████    | 5500/9000 [1:44:08<1:04:49,  1.11s/it]
 61%|██████    | 5500/9000 [1:44:08<1:04:49,  1.11s/it]

{'loss': 2.4961, 'grad_norm': 2.307626485824585, 'learning_rate': 2.058823529411765e-05, 'epoch': 2.44}


 62%|██████▏   | 5600/9000 [1:46:01<1:03:13,  1.12s/it]
 62%|██████▏   | 5600/9000 [1:46:01<1:03:13,  1.12s/it]

{'loss': 2.547, 'grad_norm': 1.761753797531128, 'learning_rate': 2e-05, 'epoch': 2.49}


 63%|██████▎   | 5700/9000 [1:47:53<1:01:15,  1.11s/it]
 63%|██████▎   | 5700/9000 [1:47:53<1:01:15,  1.11s/it]

{'loss': 2.5348, 'grad_norm': 2.071650743484497, 'learning_rate': 1.9411764705882355e-05, 'epoch': 2.53}


 64%|██████▍   | 5800/9000 [1:49:44<59:31,  1.12s/it]  
 64%|██████▍   | 5800/9000 [1:49:44<59:31,  1.12s/it] 

{'loss': 2.5471, 'grad_norm': 1.7836315631866455, 'learning_rate': 1.8823529411764708e-05, 'epoch': 2.58}


 66%|██████▌   | 5900/9000 [1:51:36<57:19,  1.11s/it]
 66%|██████▌   | 5900/9000 [1:51:36<57:19,  1.11s/it] 

{'loss': 2.5284, 'grad_norm': 1.9840792417526245, 'learning_rate': 1.8235294117647057e-05, 'epoch': 2.62}


 67%|██████▋   | 6000/9000 [1:53:28<55:42,  1.11s/it]
 67%|██████▋   | 6000/9000 [1:53:28<55:42,  1.11s/it] 

{'loss': 2.5222, 'grad_norm': 1.5470770597457886, 'learning_rate': 1.7647058823529414e-05, 'epoch': 2.67}


 68%|██████▊   | 6100/9000 [1:55:21<54:00,  1.12s/it]  
 68%|██████▊   | 6100/9000 [1:55:21<54:00,  1.12s/it] 

{'loss': 2.5913, 'grad_norm': 1.9670213460922241, 'learning_rate': 1.7058823529411767e-05, 'epoch': 2.71}


 69%|██████▉   | 6200/9000 [1:57:13<51:54,  1.11s/it]
 69%|██████▉   | 6200/9000 [1:57:13<51:54,  1.11s/it] 

{'loss': 2.5543, 'grad_norm': 2.285857915878296, 'learning_rate': 1.647058823529412e-05, 'epoch': 2.76}


 70%|███████   | 6300/9000 [1:59:04<50:00,  1.11s/it]
 70%|███████   | 6300/9000 [1:59:04<50:00,  1.11s/it] 

{'loss': 2.5544, 'grad_norm': 1.6880210638046265, 'learning_rate': 1.588235294117647e-05, 'epoch': 2.8}


 71%|███████   | 6400/9000 [2:00:55<48:21,  1.12s/it]
 71%|███████   | 6400/9000 [2:00:55<48:21,  1.12s/it] 

{'loss': 2.5407, 'grad_norm': 2.2980682849884033, 'learning_rate': 1.5294117647058826e-05, 'epoch': 2.84}


 72%|███████▏  | 6500/9000 [2:02:47<46:41,  1.12s/it]
 72%|███████▏  | 6500/9000 [2:02:47<46:41,  1.12s/it] 

{'loss': 2.6081, 'grad_norm': 1.949706792831421, 'learning_rate': 1.4705882352941177e-05, 'epoch': 2.89}


 73%|███████▎  | 6600/9000 [2:05:59<44:33,  1.11s/it]   
 73%|███████▎  | 6600/9000 [2:05:59<44:33,  1.11s/it] 

{'loss': 2.5451, 'grad_norm': 1.6758157014846802, 'learning_rate': 1.411764705882353e-05, 'epoch': 2.93}


 74%|███████▍  | 6700/9000 [2:07:50<42:31,  1.11s/it]
 74%|███████▍  | 6700/9000 [2:07:50<42:31,  1.11s/it] 

{'loss': 2.6132, 'grad_norm': 2.5975587368011475, 'learning_rate': 1.3529411764705883e-05, 'epoch': 2.98}


 76%|███████▌  | 6800/9000 [2:09:42<41:10,  1.12s/it]
 76%|███████▌  | 6800/9000 [2:09:42<41:10,  1.12s/it] 

{'loss': 2.522, 'grad_norm': 1.999974012374878, 'learning_rate': 1.2941176470588238e-05, 'epoch': 3.02}


 77%|███████▋  | 6900/9000 [2:11:34<38:55,  1.11s/it]
 77%|███████▋  | 6900/9000 [2:11:34<38:55,  1.11s/it] 

{'loss': 2.5106, 'grad_norm': 1.6577931642532349, 'learning_rate': 1.2352941176470589e-05, 'epoch': 3.07}


 78%|███████▊  | 7000/9000 [2:13:26<37:10,  1.12s/it]
 78%|███████▊  | 7000/9000 [2:13:26<37:10,  1.12s/it] 

{'loss': 2.525, 'grad_norm': 2.8586227893829346, 'learning_rate': 1.1764705882352942e-05, 'epoch': 3.11}


 79%|███████▉  | 7100/9000 [2:15:20<35:19,  1.12s/it]
 79%|███████▉  | 7100/9000 [2:15:20<35:19,  1.12s/it] 

{'loss': 2.5042, 'grad_norm': 1.9048960208892822, 'learning_rate': 1.1176470588235295e-05, 'epoch': 3.16}


 80%|████████  | 7200/9000 [2:17:12<33:44,  1.12s/it]
 80%|████████  | 7200/9000 [2:17:12<33:44,  1.12s/it] 

{'loss': 2.5636, 'grad_norm': 1.7651855945587158, 'learning_rate': 1.0588235294117648e-05, 'epoch': 3.2}


 81%|████████  | 7300/9000 [2:19:03<31:30,  1.11s/it]
 81%|████████  | 7300/9000 [2:19:03<31:30,  1.11s/it] 

{'loss': 2.4767, 'grad_norm': 1.9282810688018799, 'learning_rate': 1e-05, 'epoch': 3.24}


 82%|████████▏ | 7400/9000 [2:20:55<29:39,  1.11s/it]
 82%|████████▏ | 7400/9000 [2:20:55<29:39,  1.11s/it] 

{'loss': 2.4754, 'grad_norm': 1.7005842924118042, 'learning_rate': 9.411764705882354e-06, 'epoch': 3.29}


 83%|████████▎ | 7500/9000 [2:22:47<27:49,  1.11s/it]
 83%|████████▎ | 7500/9000 [2:22:47<27:49,  1.11s/it] 

{'loss': 2.502, 'grad_norm': 1.8260676860809326, 'learning_rate': 8.823529411764707e-06, 'epoch': 3.33}


 84%|████████▍ | 7600/9000 [2:24:41<26:06,  1.12s/it]
 84%|████████▍ | 7600/9000 [2:24:41<26:06,  1.12s/it] 

{'loss': 2.4882, 'grad_norm': 1.4478967189788818, 'learning_rate': 8.23529411764706e-06, 'epoch': 3.38}


 86%|████████▌ | 7700/9000 [2:26:32<24:08,  1.11s/it]
 86%|████████▌ | 7700/9000 [2:26:32<24:08,  1.11s/it] 

{'loss': 2.4966, 'grad_norm': 1.9711883068084717, 'learning_rate': 7.647058823529413e-06, 'epoch': 3.42}


 87%|████████▋ | 7800/9000 [2:28:24<22:20,  1.12s/it]
 87%|████████▋ | 7800/9000 [2:28:24<22:20,  1.12s/it] 

{'loss': 2.4859, 'grad_norm': 1.778118371963501, 'learning_rate': 7.058823529411765e-06, 'epoch': 3.47}


 88%|████████▊ | 7900/9000 [2:30:16<20:35,  1.12s/it]
 88%|████████▊ | 7900/9000 [2:30:16<20:35,  1.12s/it] 

{'loss': 2.5081, 'grad_norm': 2.2681140899658203, 'learning_rate': 6.470588235294119e-06, 'epoch': 3.51}


 89%|████████▉ | 8000/9000 [2:32:08<18:40,  1.12s/it]
 89%|████████▉ | 8000/9000 [2:32:08<18:40,  1.12s/it] 

{'loss': 2.5537, 'grad_norm': 2.2819535732269287, 'learning_rate': 5.882352941176471e-06, 'epoch': 3.56}


 90%|█████████ | 8100/9000 [2:34:02<16:54,  1.13s/it]
 90%|█████████ | 8100/9000 [2:34:02<16:54,  1.13s/it] 

{'loss': 2.4968, 'grad_norm': 1.3971041440963745, 'learning_rate': 5.294117647058824e-06, 'epoch': 3.6}


 91%|█████████ | 8200/9000 [2:35:54<14:51,  1.11s/it]
 91%|█████████ | 8200/9000 [2:35:54<14:51,  1.11s/it] 

{'loss': 2.5436, 'grad_norm': 1.7098110914230347, 'learning_rate': 4.705882352941177e-06, 'epoch': 3.64}


 92%|█████████▏| 8300/9000 [2:37:46<13:03,  1.12s/it]
 92%|█████████▏| 8300/9000 [2:37:46<13:03,  1.12s/it] 

{'loss': 2.4652, 'grad_norm': 1.8977075815200806, 'learning_rate': 4.11764705882353e-06, 'epoch': 3.69}


 93%|█████████▎| 8400/9000 [2:39:38<11:11,  1.12s/it]
 93%|█████████▎| 8400/9000 [2:39:38<11:11,  1.12s/it] 

{'loss': 2.5197, 'grad_norm': 2.176227331161499, 'learning_rate': 3.5294117647058825e-06, 'epoch': 3.73}


 94%|█████████▍| 8500/9000 [2:41:29<09:14,  1.11s/it]
 94%|█████████▍| 8500/9000 [2:41:29<09:14,  1.11s/it] 

{'loss': 2.4874, 'grad_norm': 2.0052096843719482, 'learning_rate': 2.9411764705882355e-06, 'epoch': 3.78}


 96%|█████████▌| 8600/9000 [2:43:23<07:28,  1.12s/it]
 96%|█████████▌| 8600/9000 [2:43:23<07:28,  1.12s/it] 

{'loss': 2.541, 'grad_norm': 1.7421655654907227, 'learning_rate': 2.3529411764705885e-06, 'epoch': 3.82}


 97%|█████████▋| 8700/9000 [2:45:15<05:33,  1.11s/it]
 97%|█████████▋| 8700/9000 [2:45:15<05:33,  1.11s/it] 

{'loss': 2.5566, 'grad_norm': 1.6076319217681885, 'learning_rate': 1.7647058823529412e-06, 'epoch': 3.87}


 98%|█████████▊| 8800/9000 [2:47:07<03:42,  1.11s/it]
 98%|█████████▊| 8800/9000 [2:47:07<03:42,  1.11s/it] 

{'loss': 2.5383, 'grad_norm': 2.3624112606048584, 'learning_rate': 1.1764705882352942e-06, 'epoch': 3.91}


 99%|█████████▉| 8900/9000 [2:48:58<01:51,  1.12s/it]
 99%|█████████▉| 8900/9000 [2:48:59<01:51,  1.12s/it] 

{'loss': 2.4903, 'grad_norm': 1.7136634588241577, 'learning_rate': 5.882352941176471e-07, 'epoch': 3.96}


100%|██████████| 9000/9000 [2:50:50<00:00,  1.12s/it]
100%|██████████| 9000/9000 [2:50:50<00:00,  1.12s/it] 

{'loss': 2.4696, 'grad_norm': 1.7365777492523193, 'learning_rate': 0.0, 'epoch': 4.0}



100%|██████████| 9000/9000 [2:50:52<00:00,  1.14s/it] 

{'train_runtime': 10252.5589, 'train_samples_per_second': 3.511, 'train_steps_per_second': 0.878, 'train_loss': 2.635204077826606, 'epoch': 4.0}





TrainOutput(global_step=9000, training_loss=2.635204077826606, metrics={'train_runtime': 10252.5589, 'train_samples_per_second': 3.511, 'train_steps_per_second': 0.878, 'total_flos': 9406513152000000.0, 'train_loss': 2.635204077826606, 'epoch': 4.0})

In [53]:
# Save the model and tokenizer
model.save_pretrained('./results')
tokenizer.save_pretrained('./results')

# Load the model and tokenizer for text generation
from transformers import pipeline

# Ensure your model and tokenizer are loaded correctly
diomedes = pipeline('text-generation', model='./results', tokenizer='./results')

# Example of generating genre-specific text
genre = "rap"  # Replace with any genre present in your dataset
prompt = f"[Genre: {genre}] "
results = diomedes(prompt, max_length=500)
print(results[0]['generated_text'])

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


[Genre: country]                            


I had an open mind, and an open mind
I was raised in a different world with a different country
I was raised in the backseat, with the whole car
I was raised in the backseat just like my daddy, yeah
I'm the first man to take you there and make you stand on the high horse
I'm the first father to take you there and make you stand on the high horse



I was raised in a different world with a different country
I was raised in the backseat, with the whole car
I was raised in the backseat just like my daddy, yeah
I'm the first man to take you there and make you stand on the high horse

And when you look at me, you'll see me
When you look at me, you'll see me
And when you see me, you'll see me
And when you see me, you'll see me

And when you look at me, you'll see me
And when you see me, you'll see me
And when you see me, you'll see me

Now the sun shines shining through the snow...
It's shining and it's shining and it's shining
And I know there'