In [23]:
import torch
import math
import transformers
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, GPT2Tokenizer, AutoConfig, AutoModelForCausalLM

In [24]:
dataset = load_from_disk("./sample_dataset")
dataset = dataset.train_test_split(test_size=0.2)

dataset

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 309
    })
    test: Dataset({
        features: ['text'],
        num_rows: 78
    })
})

In [26]:
context_length = 64
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
outputs = tokenizer(
    dataset["train"][:2]["text"],
    truncation=True,
    max_length=context_length,
    return_overflowing_tokens=True,
    return_length=True,
)
print(f"Input IDs length: {len(outputs['input_ids'])}")
print(f"Input chunk lengths: {(outputs)}")
# print(f"Chunk mapping: {outputs['overflow_to_sample_mapping']}")

Input IDs length: 2
Input chunk lengths: {'overflowing_tokens': [[287, 281, 15452, 1182, 198, 20451, 278, 514, 1363, 198, 198, 40, 2911, 314, 1239, 4425, 345, 198, 34456, 340, 1239, 5645, 198, 40, 1549, 1239, 2513, 11424, 25418, 3530, 757, 198, 2504, 338, 262, 17855, 2612, 9032, 640, 714, 1239, 47618, 198, 40, 1549, 1239, 2513, 11424, 25418, 3530, 757, 198, 1870, 5156, 11, 314, 651, 21619, 1431, 416, 703, 198, 1212, 1748, 26557, 534, 1438, 198, 1870, 5156, 11, 314, 1101, 523, 22144, 286, 198, 361, 345, 1683, 2513, 1497, 198, 40, 1549, 1239, 2513, 11424, 25418, 3530, 757, 198, 40, 1549, 1239, 2513, 11424, 25418, 3530, 757, 198, 198, 11209, 28507, 826, 1280, 11, 23608, 1633, 198, 41, 8317, 705, 744, 616, 12450, 318, 12431, 198, 1135, 12012, 262, 29424, 319, 11424, 25418, 3530, 198, 13579, 273, 1096, 262, 1126, 4730, 287, 262, 4314, 198, 7282, 618, 356, 547, 2657, 27476, 11, 2712, 1830, 198, 40, 1807, 345, 547, 3756, 502, 319, 198, 40, 11856, 616, 11668, 11, 1364, 11424, 25418, 3530, 198,

In [27]:
# def tokenize(element):
#     outputs = tokenizer(
#         element["text"],
#         truncation=True,
#         max_length=context_length,
#         return_overflowing_tokens=True,
#         return_length=True,
#     )
#     input_batch = []
#     for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
#         if length == context_length:
#             input_batch.append(input_ids)
#     return {"input_ids": input_batch}

def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=64)

tokenized_datasets = dataset.map(
    tokenize_function, batched=True, remove_columns=dataset["train"].column_names
)

# block_size = int(tokenizer.model_max_length / 4)
block_size = 64
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=1,
)

print(tokenizer.decode(lm_datasets['train'][65]['input_ids']))
print(tokenizer.decode(lm_datasets['train'][65]['labels']))

                                                              

Someone struck a match against the night
All I could see was you and I
It was captivating
A perfect little dream inside my head
And then reality crept in
And erased it
For a while I thought that I could hold you
But you were just a temporary high

Firefly
You
Someone struck a match against the night
All I could see was you and I
It was captivating
A perfect little dream inside my head
And then reality crept in
And erased it
For a while I thought that I could hold you
But you were just a temporary high

Firefly
You




In [29]:
config = AutoConfig.from_pretrained('distilgpt2', vocab_size=len(tokenizer), n_ctx=context_length, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id)

model = AutoModelForCausalLM.from_pretrained("output/model-v4/checkpoint-4340/")

from transformers import DataCollatorForLanguageModeling

tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [30]:
out = data_collator([lm_datasets["train"][i] for i in range(5)])
for key in out:
    print(f"{key} shape: {out[key].shape}")


input_ids shape: torch.Size([5, 64])
attention_mask shape: torch.Size([5, 64])
labels shape: torch.Size([5, 64])


In [31]:
from transformers import Trainer, TrainingArguments
import wandb
import random
seed_data = random.randint(0,2**32-1)
%set_env PYTORCH_ENABLE_MPS_FALLBACK=1
args = TrainingArguments(
    output_dir="output/model-generator",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_steps=1,
    seed=seed_data,
    do_eval=True,
    eval_steps=1,
    logging_steps=5,
    num_train_epochs=10,
    weight_decay=0.01,
    learning_rate=1.372e-4,
    push_to_hub=True,
    report_to='wandb',
    resume_from_checkpoint=True,
    load_best_model_at_end=True,
    use_mps_device=True
)


trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["test"],
)


from transformers import get_cosine_schedule_with_warmup
train_dataloader = trainer.get_train_dataloader()
num_train_steps = len(train_dataloader)
trainer.create_optimizer_and_scheduler(num_train_steps)
trainer.lr_scheduler = get_cosine_schedule_with_warmup(
      trainer.optimizer,
      num_warmup_steps=0,
      num_training_steps=num_train_steps
)

env: PYTORCH_ENABLE_MPS_FALLBACK=1


/Users/bhavya/projects/hugging-face-playground/output/model-generator is already a clone of https://huggingface.co/BhavyaMuni/model-generator. Make sure you pull the latest changes with `repo.git_pull()`.


In [32]:
trainer.train()

  1%|▏         | 5/390 [00:02<02:38,  2.43it/s]

{'loss': 3.3589, 'learning_rate': 0.00013171058983499535, 'epoch': 0.13}


  3%|▎         | 10/390 [00:04<01:57,  3.23it/s]

{'loss': 2.5818, 'learning_rate': 0.00011612089065075853, 'epoch': 0.26}


  4%|▍         | 15/390 [00:05<01:50,  3.39it/s]

{'loss': 2.4577, 'learning_rate': 9.292589525111794e-05, 'epoch': 0.38}


  5%|▌         | 20/390 [00:07<01:47,  3.44it/s]

{'loss': 2.2686, 'learning_rate': 6.583775650849414e-05, 'epoch': 0.51}


  6%|▋         | 25/390 [00:08<01:46,  3.42it/s]

{'loss': 2.4678, 'learning_rate': 3.9191690287750474e-05, 'epoch': 0.64}


  8%|▊         | 30/390 [00:10<01:45,  3.41it/s]

{'loss': 2.4316, 'learning_rate': 1.725216267546246e-05, 'epoch': 0.77}


  9%|▉         | 35/390 [00:11<01:44,  3.40it/s]

{'loss': 2.115, 'learning_rate': 3.53040008242582e-06, 'epoch': 0.9}


 10%|█         | 39/390 [00:12<01:41,  3.46it/s]
 10%|█         | 39/390 [00:13<01:41,  3.46it/s]

{'eval_loss': 2.0731794834136963, 'eval_runtime': 0.533, 'eval_samples_per_second': 146.35, 'eval_steps_per_second': 18.763, 'epoch': 1.0}


 10%|█         | 40/390 [00:21<17:05,  2.93s/it]

{'loss': 2.0991, 'learning_rate': 2.2244866199319123e-07, 'epoch': 1.03}


 12%|█▏        | 45/390 [00:23<04:13,  1.36it/s]

{'loss': 2.2541, 'learning_rate': 7.857716640189785e-06, 'epoch': 1.15}


 13%|█▎        | 50/390 [00:24<02:05,  2.70it/s]

{'loss': 2.0404, 'learning_rate': 2.5214247234157134e-05, 'epoch': 1.28}


 14%|█▍        | 55/390 [00:26<01:42,  3.27it/s]

{'loss': 2.1311, 'learning_rate': 4.9514281975331363e-05, 'epoch': 1.41}


 15%|█▌        | 60/390 [00:27<01:37,  3.39it/s]

{'loss': 1.9329, 'learning_rate': 7.686881626551516e-05, 'epoch': 1.54}


 17%|█▋        | 65/390 [00:29<01:34,  3.43it/s]

{'loss': 1.8844, 'learning_rate': 0.00010290000000000001, 'epoch': 1.67}


 18%|█▊        | 70/390 [00:30<01:32,  3.45it/s]

{'loss': 1.9633, 'learning_rate': 0.0001234417735694802, 'epoch': 1.79}


 19%|█▉        | 75/390 [00:32<01:31,  3.43it/s]

{'loss': 1.6055, 'learning_rate': 0.00013520660867542716, 'epoch': 1.92}


 20%|██        | 78/390 [00:33<01:28,  3.52it/s]
 20%|██        | 78/390 [00:33<01:28,  3.52it/s]

{'eval_loss': 2.039727210998535, 'eval_runtime': 0.5293, 'eval_samples_per_second': 147.372, 'eval_steps_per_second': 18.894, 'epoch': 2.0}


 21%|██        | 80/390 [00:35<03:52,  1.33it/s]

{'loss': 1.7318, 'learning_rate': 0.00013631164801696085, 'epoch': 2.05}


 22%|██▏       | 85/390 [00:37<01:52,  2.72it/s]

{'loss': 1.7169, 'learning_rate': 0.00012658003986830435, 'epoch': 2.18}


 23%|██▎       | 90/390 [00:38<01:31,  3.27it/s]

{'loss': 1.7808, 'learning_rate': 0.00010756924162575734, 'epoch': 2.31}


 24%|██▍       | 95/390 [00:40<01:26,  3.41it/s]

{'loss': 1.3581, 'learning_rate': 8.232176259303673e-05, 'epoch': 2.44}


 26%|██▌       | 100/390 [00:41<01:25,  3.41it/s]

{'loss': 1.644, 'learning_rate': 5.4878237406963316e-05, 'epoch': 2.56}


 27%|██▋       | 105/390 [00:43<01:22,  3.44it/s]

{'loss': 1.4609, 'learning_rate': 2.9630758374242683e-05, 'epoch': 2.69}


 28%|██▊       | 110/390 [00:44<01:20,  3.46it/s]

{'loss': 1.7294, 'learning_rate': 1.0619960131695668e-05, 'epoch': 2.82}


 29%|██▉       | 115/390 [00:46<01:19,  3.45it/s]

{'loss': 1.6177, 'learning_rate': 8.883519830391712e-07, 'epoch': 2.95}


 30%|███       | 117/390 [00:46<01:13,  3.70it/s]
 30%|███       | 117/390 [00:47<01:13,  3.70it/s]

{'eval_loss': 2.0282094478607178, 'eval_runtime': 0.533, 'eval_samples_per_second': 146.34, 'eval_steps_per_second': 18.761, 'epoch': 3.0}


 31%|███       | 120/390 [00:49<02:38,  1.70it/s]

{'loss': 1.3499, 'learning_rate': 1.9933913245728396e-06, 'epoch': 3.08}


 32%|███▏      | 125/390 [00:50<01:30,  2.93it/s]

{'loss': 1.3168, 'learning_rate': 1.3758226430519834e-05, 'epoch': 3.21}


 33%|███▎      | 130/390 [00:52<01:18,  3.32it/s]

{'loss': 1.2746, 'learning_rate': 3.4300000000000014e-05, 'epoch': 3.33}


 35%|███▍      | 135/390 [00:53<01:14,  3.43it/s]

{'loss': 1.5159, 'learning_rate': 6.033118373448485e-05, 'epoch': 3.46}


 36%|███▌      | 140/390 [00:55<01:13,  3.42it/s]

{'loss': 1.2588, 'learning_rate': 8.768571802466866e-05, 'epoch': 3.59}


 37%|███▋      | 145/390 [00:56<01:11,  3.43it/s]

{'loss': 1.7555, 'learning_rate': 0.00011198575276584287, 'epoch': 3.72}


 38%|███▊      | 150/390 [00:58<01:10,  3.43it/s]

{'loss': 1.478, 'learning_rate': 0.00012934228335981018, 'epoch': 3.85}


 40%|███▉      | 155/390 [00:59<01:07,  3.46it/s]

{'loss': 1.3965, 'learning_rate': 0.00013697755133800678, 'epoch': 3.97}


 40%|████      | 156/390 [00:59<01:03,  3.70it/s]
 40%|████      | 156/390 [01:00<01:03,  3.70it/s]

{'eval_loss': 2.069936752319336, 'eval_runtime': 0.5316, 'eval_samples_per_second': 146.717, 'eval_steps_per_second': 18.81, 'epoch': 4.0}


 41%|████      | 160/390 [01:03<01:56,  1.98it/s]

{'loss': 1.4282, 'learning_rate': 0.00013366959991757425, 'epoch': 4.1}


 42%|████▏     | 165/390 [01:04<01:13,  3.07it/s]

{'loss': 1.1554, 'learning_rate': 0.00011994783732453755, 'epoch': 4.23}


 44%|████▎     | 170/390 [01:06<01:05,  3.36it/s]

{'loss': 1.3892, 'learning_rate': 9.800830971224965e-05, 'epoch': 4.36}


 45%|████▍     | 175/390 [01:07<01:03,  3.40it/s]

{'loss': 1.1545, 'learning_rate': 7.13622434915059e-05, 'epoch': 4.49}


 46%|████▌     | 180/390 [01:09<01:01,  3.42it/s]

{'loss': 1.3164, 'learning_rate': 4.42741047488822e-05, 'epoch': 4.62}


 47%|████▋     | 185/390 [01:10<00:59,  3.43it/s]

{'loss': 1.1536, 'learning_rate': 2.1079109349241507e-05, 'epoch': 4.74}


 49%|████▊     | 190/390 [01:11<00:58,  3.42it/s]

{'loss': 0.9767, 'learning_rate': 5.4894101650047195e-06, 'epoch': 4.87}


 50%|█████     | 195/390 [01:13<00:53,  3.66it/s]

{'loss': 1.1936, 'learning_rate': 0.0, 'epoch': 5.0}



 50%|█████     | 195/390 [01:13<00:53,  3.66it/s]

{'eval_loss': 2.1293678283691406, 'eval_runtime': 0.5311, 'eval_samples_per_second': 146.864, 'eval_steps_per_second': 18.829, 'epoch': 5.0}


 51%|█████▏    | 200/390 [01:17<01:24,  2.24it/s]

{'loss': 1.0904, 'learning_rate': 5.489410165004689e-06, 'epoch': 5.13}


 53%|█████▎    | 205/390 [01:18<00:58,  3.15it/s]

{'loss': 1.0566, 'learning_rate': 2.1079109349241446e-05, 'epoch': 5.26}


 54%|█████▍    | 210/390 [01:19<00:53,  3.39it/s]

{'loss': 1.1135, 'learning_rate': 4.4274104748882125e-05, 'epoch': 5.38}


 55%|█████▌    | 215/390 [01:21<00:51,  3.41it/s]

{'loss': 0.9506, 'learning_rate': 7.136224349150582e-05, 'epoch': 5.51}


 56%|█████▋    | 220/390 [01:22<00:51,  3.30it/s]

{'loss': 1.0318, 'learning_rate': 9.800830971224957e-05, 'epoch': 5.64}


 58%|█████▊    | 225/390 [01:24<00:51,  3.22it/s]

{'loss': 1.0393, 'learning_rate': 0.00011994783732453749, 'epoch': 5.77}


 59%|█████▉    | 230/390 [01:25<00:47,  3.34it/s]

{'loss': 1.0771, 'learning_rate': 0.0001336695999175742, 'epoch': 5.9}


 60%|██████    | 234/390 [01:27<00:46,  3.37it/s]
 60%|██████    | 234/390 [01:27<00:46,  3.37it/s]

{'eval_loss': 2.1790242195129395, 'eval_runtime': 0.5556, 'eval_samples_per_second': 140.389, 'eval_steps_per_second': 17.999, 'epoch': 6.0}


 60%|██████    | 235/390 [01:38<09:29,  3.67s/it]

{'loss': 1.2626, 'learning_rate': 0.0001369775513380068, 'epoch': 6.03}


 62%|██████▏   | 240/390 [01:40<02:12,  1.13it/s]

{'loss': 0.9211, 'learning_rate': 0.00012934228335981018, 'epoch': 6.15}


 63%|██████▎   | 245/390 [01:42<01:03,  2.29it/s]

{'loss': 1.0232, 'learning_rate': 0.00011198575276584294, 'epoch': 6.28}


 64%|██████▍   | 250/390 [01:43<00:45,  3.06it/s]

{'loss': 0.8734, 'learning_rate': 8.768571802466861e-05, 'epoch': 6.41}


 65%|██████▌   | 255/390 [01:45<00:40,  3.30it/s]

{'loss': 0.8892, 'learning_rate': 6.033118373448493e-05, 'epoch': 6.54}


 67%|██████▋   | 260/390 [01:46<00:39,  3.28it/s]

{'loss': 1.0948, 'learning_rate': 3.429999999999998e-05, 'epoch': 6.67}


 68%|██████▊   | 265/390 [01:48<00:39,  3.18it/s]

{'loss': 1.0702, 'learning_rate': 1.375822643051988e-05, 'epoch': 6.79}


 69%|██████▉   | 270/390 [01:49<00:39,  3.00it/s]

{'loss': 0.9385, 'learning_rate': 1.9933913245728244e-06, 'epoch': 6.92}


 70%|███████   | 273/390 [01:50<00:37,  3.13it/s]
 70%|███████   | 273/390 [01:51<00:37,  3.13it/s]

{'eval_loss': 2.237593650817871, 'eval_runtime': 0.7133, 'eval_samples_per_second': 109.345, 'eval_steps_per_second': 14.019, 'epoch': 7.0}


 71%|███████   | 275/390 [01:54<01:41,  1.14it/s]

{'loss': 0.7394, 'learning_rate': 8.883519830391636e-07, 'epoch': 7.05}


 72%|███████▏  | 280/390 [01:55<00:45,  2.44it/s]

{'loss': 0.821, 'learning_rate': 1.0619960131695684e-05, 'epoch': 7.18}


 73%|███████▎  | 285/390 [01:57<00:34,  3.02it/s]

{'loss': 0.9581, 'learning_rate': 2.963075837424261e-05, 'epoch': 7.31}


 74%|███████▍  | 290/390 [01:59<00:34,  2.87it/s]

{'loss': 0.7309, 'learning_rate': 5.4878237406963356e-05, 'epoch': 7.44}


 76%|███████▌  | 295/390 [02:00<00:29,  3.18it/s]

{'loss': 0.8473, 'learning_rate': 8.232176259303652e-05, 'epoch': 7.56}


 77%|███████▋  | 300/390 [02:02<00:27,  3.26it/s]

{'loss': 0.9024, 'learning_rate': 0.00010756924162575728, 'epoch': 7.69}


 78%|███████▊  | 305/390 [02:03<00:25,  3.36it/s]

{'loss': 0.8353, 'learning_rate': 0.00012658003986830424, 'epoch': 7.82}


 79%|███████▉  | 310/390 [02:05<00:25,  3.16it/s]

{'loss': 0.7569, 'learning_rate': 0.00013631164801696083, 'epoch': 7.95}


 80%|████████  | 312/390 [02:05<00:23,  3.35it/s]
 80%|████████  | 312/390 [02:06<00:23,  3.35it/s]

{'eval_loss': 2.3039562702178955, 'eval_runtime': 0.5476, 'eval_samples_per_second': 142.449, 'eval_steps_per_second': 18.263, 'epoch': 8.0}


 81%|████████  | 315/390 [02:09<00:50,  1.50it/s]

{'loss': 0.882, 'learning_rate': 0.0001352066086754272, 'epoch': 8.08}


 82%|████████▏ | 320/390 [02:10<00:26,  2.63it/s]

{'loss': 0.7951, 'learning_rate': 0.00012344177356948035, 'epoch': 8.21}


 83%|████████▎ | 325/390 [02:12<00:22,  2.92it/s]

{'loss': 0.7124, 'learning_rate': 0.00010289999999999993, 'epoch': 8.33}


 85%|████████▍ | 330/390 [02:14<00:19,  3.07it/s]

{'loss': 0.704, 'learning_rate': 7.68688162655152e-05, 'epoch': 8.46}


 86%|████████▌ | 335/390 [02:15<00:18,  2.98it/s]

{'loss': 0.8006, 'learning_rate': 4.95142819753315e-05, 'epoch': 8.59}


 87%|████████▋ | 340/390 [02:17<00:16,  3.06it/s]

{'loss': 0.7725, 'learning_rate': 2.521424723415734e-05, 'epoch': 8.72}


 88%|████████▊ | 345/390 [02:18<00:13,  3.28it/s]

{'loss': 0.6942, 'learning_rate': 7.857716640189763e-06, 'epoch': 8.85}


 90%|████████▉ | 350/390 [02:20<00:12,  3.29it/s]

{'loss': 0.7553, 'learning_rate': 2.2244866199319883e-07, 'epoch': 8.97}


 90%|█████████ | 351/390 [02:20<00:11,  3.41it/s]
 90%|█████████ | 351/390 [02:21<00:11,  3.41it/s]

{'eval_loss': 2.347101926803589, 'eval_runtime': 0.6001, 'eval_samples_per_second': 129.972, 'eval_steps_per_second': 16.663, 'epoch': 9.0}


 91%|█████████ | 355/390 [02:24<00:19,  1.81it/s]

{'loss': 0.7052, 'learning_rate': 3.530400082425759e-06, 'epoch': 9.1}


 92%|█████████▏| 360/390 [02:25<00:10,  2.88it/s]

{'loss': 0.6258, 'learning_rate': 1.7252162675462267e-05, 'epoch': 9.23}


 94%|█████████▎| 365/390 [02:27<00:07,  3.28it/s]

{'loss': 0.5787, 'learning_rate': 3.9191690287750535e-05, 'epoch': 9.36}


 95%|█████████▍| 370/390 [02:28<00:06,  3.30it/s]

{'loss': 0.5684, 'learning_rate': 6.583775650849406e-05, 'epoch': 9.49}


 96%|█████████▌| 375/390 [02:30<00:04,  3.40it/s]

{'loss': 0.6645, 'learning_rate': 9.292589525111775e-05, 'epoch': 9.62}


 97%|█████████▋| 380/390 [02:31<00:03,  3.31it/s]

{'loss': 0.7005, 'learning_rate': 0.00011612089065075828, 'epoch': 9.74}


 99%|█████████▊| 385/390 [02:33<00:01,  3.27it/s]

{'loss': 0.7337, 'learning_rate': 0.00013171058983499535, 'epoch': 9.87}


100%|██████████| 390/390 [02:34<00:00,  3.32it/s]

{'loss': 0.7012, 'learning_rate': 0.0001372, 'epoch': 10.0}



100%|██████████| 390/390 [02:35<00:00,  3.32it/s]

{'eval_loss': 2.3939199447631836, 'eval_runtime': 0.5437, 'eval_samples_per_second': 143.464, 'eval_steps_per_second': 18.393, 'epoch': 10.0}


100%|██████████| 390/390 [02:37<00:00,  2.48it/s]

{'train_runtime': 157.4999, 'train_samples_per_second': 19.619, 'train_steps_per_second': 2.476, 'train_loss': 1.2969607848387499, 'epoch': 10.0}





TrainOutput(global_step=390, training_loss=1.2969607848387499, metrics={'train_runtime': 157.4999, 'train_samples_per_second': 19.619, 'train_steps_per_second': 2.476, 'train_loss': 1.2969607848387499, 'epoch': 10.0})

In [36]:

def post_process(output_sequences):
    predictions = []
    generated_sequences = []

    max_repeat = 2

    # decode prediction
    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        generated_sequence = generated_sequence.tolist()
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True)
        generated_sequences.append(text.strip())
                    
    for i, g in enumerate(generated_sequences):
        res = str(g).replace('\n\n\n', '\n').replace('\n\n', '\n')
        lines = res.split('\n')
        # print(lines)
        i = max_repeat
        while i != len(lines):
          remove_count = 0
          for index in range(0, max_repeat):
            # print(i - index - 1, i - index)
            if lines[i - index - 1] == lines[i - index]:
              remove_count += 1
          if remove_count == max_repeat:
            lines.pop(i)
            i -= 1
          else:
            i += 1
        predictions.append('\n'.join(lines))

    return predictions

In [37]:
prompt = "Hey there "
tokenized_prompt = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).input_ids.to(trainer.args.device)
output = trainer.model.generate(
                        input_ids=tokenized_prompt,
                        max_length=100,
                        min_length=60,
                        temperature=float(1),
                        top_p=float(0.95),
                        top_k=int(50),
                        do_sample=True,
                        repetition_penalty=1,
                        num_return_sequences=40
                        )

print(post_process(output))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["Hey there!!!\nWhat a waste\nI was\nI was there\nI remember you dancing before your eyes went wild\nOn backroads at night\nWhen I passed the pictures around of you\nAnd your little eyelids flutter cause it's glow\nIn the winter\nBut I watched it rain\nOh, it's just so quiet in the world\nI was watching you\nIt was just you, I was watching you\nI watched you\nI", "Hey there \nWell I guess I wished on a plane I thought was a star\nHey, yeah\nYeah\nI guess I liked that\nYou liked that\nWell I guess you liked that\nThat little black dress\nI wore before you went and let me down\nThere you stand now, ten feet tall\nEulogize me\nCause I'm the tallest building we had ever been\nI was in the tallest building I was in the whole scene", "Hey there \nI hope you see me\nLookin' like a face in the crowd\nI hope you see me\nOoh look what you made me do\nLook like a face in the crowd\nI'd give up forever for\nI'd give up forever to touch you\nBut I'm so sorry\nI don't see\n you here\nI'd meet you he

In [11]:

trainer.push_to_hub("trained")

Several commits (2) will be pushed upstream.
The progress bars may be unreliable.
Upload file pytorch_model.bin: 320MB [02:34, 4.84MB/s]                            To https://huggingface.co/BhavyaMuni/model-generator
   94d8184..a638e07  main -> main

Upload file pytorch_model.bin: 100%|██████████| 318M/318M [02:35<00:00, 2.15MB/s]
Upload file training_args.bin: 100%|██████████| 4.30k/4.30k [02:35<00:00, 28.3B/s]
To https://huggingface.co/BhavyaMuni/model-generator
   a638e07..e23e1db  main -> main



'https://huggingface.co/BhavyaMuni/model-generator/commit/a638e076a80f154ef8053473b9d58c29a4acd27d'