# Installations and imports

In [None]:
# !pip3 install transformers
# !pip3 install torch
# !pip3 install datasets
# !pip3 install sentencepiece
# !pip3 install gdown
# !pip3 install accelerate -U

In [None]:
import torch
from transformers import (
    LlamaForCausalLM, LlamaConfig, LlamaTokenizer,
    Trainer, TrainingArguments, DataCollatorForLanguageModeling,
    EarlyStoppingCallback
)
from datasets import load_dataset
import sentencepiece as spm
import os
import logging
import json
import sys
import argparse

In [None]:
def train_tokenizer(input_path, model_prefix, vocab_size=None, model_type="BPE", **trainer_options):
    """Train a SentencePiece tokenizer on a text file.

    The trained model is written by SentencePiece to ``<model_prefix>.model``
    in the current working directory.

    Args:
        input_path: Path of the text file used as training corpus.
        model_prefix: Prefix of the output files produced by SentencePiece.
        vocab_size: Optional vocabulary size. When ``None`` the argument is
            not forwarded and SentencePiece uses its own default; pass the
            model's ``vocab_size`` here to keep tokenizer and model aligned.
        model_type: SentencePiece model type (defaults to ``"BPE"``).
        **trainer_options: Any further keyword options forwarded verbatim to
            ``SentencePieceTrainer.train`` (for example special-token ids).
    """
    options = dict(trainer_options)
    options.update(input=input_path, model_prefix=model_prefix, model_type=model_type)
    # Only forward vocab_size when requested so the default call is unchanged.
    if vocab_size is not None:
        options["vocab_size"] = vocab_size
    spm.SentencePieceTrainer.train(**options)

In [None]:
def move_tokenizer_to_folder(source, destination_folder):
    """Move a trained tokenizer file into a model folder as ``tokenizer.model``.

    Args:
        source: Path of the tokenizer model file to move.
        destination_folder: Folder that receives the file; it is created
            when it does not exist yet.

    Raises:
        OSError: If ``source`` does not exist or cannot be moved.
    """
    os.makedirs(destination_folder, exist_ok=True)
    target = os.path.join(destination_folder, "tokenizer.model")
    # os.replace overwrites an existing target on every platform, so re-running
    # the notebook does not fail on a tokenizer left over from a previous run.
    os.replace(source, target)

def create_config_file(folder_path, content):
    """Serialize ``content`` as pretty-printed JSON to ``<folder_path>/config.json``.

    Args:
        folder_path: Existing folder that receives ``config.json``.
        content: JSON-serializable object holding the model configuration.
    """
    target_path = os.path.join(folder_path, "config.json")
    with open(target_path, "w") as handle:
        # Four-space indentation keeps the file readable for manual edits.
        json.dump(content, handle, indent=4)

In [None]:
# Model hyper-parameters written to config.json. The vocabulary size and the
# special-token ids listed here are only defaults: they are overwritten further
# down with the values of the tokenizer that is actually trained, so the model's
# embedding table always covers every token id the tokenizer can emit.
config_content = {
    "_name_or_path": "./names_1m",
    "architectures": [
        "LlamaForCausalLM"
    ],
    "bos_token_id": 2,
    "eos_token_id": 3,
    "hidden_act": "silu",
    "hidden_size": 64,
    "initializer_range": 0.02,
    "intermediate_size": 180,
    "max_position_embeddings": 32,
    "model_type": "llama",
    "num_attention_heads": 16,
    "num_hidden_layers": 8,
    "num_key_value_heads": 16,
    "pad_token_id": 1,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-06,
    "rope_scaling": None,
    "tie_word_embeddings": False,
    "torch_dtype": "float32",
    "transformers_version": "4.28.1",
    "use_cache": False,
    "vocab_size": 97
}


out_folder_path = "bookAndGenre"
os.makedirs(out_folder_path, exist_ok=True)

# Train the tokenizer before writing the config so the config can be derived
# from the tokenizer instead of from hard-coded guesses.
train_tokenizer('bookData.csv', 'tokenizer')
move_tokenizer_to_folder("tokenizer.model", out_folder_path)

trained_sp = spm.SentencePieceProcessor(
    model_file=os.path.join(out_folder_path, "tokenizer.model")
)
config_content["vocab_size"] = trained_sp.get_piece_size()
# SentencePiece reports -1 for a special token that is disabled; keep the
# default from the dict above in that case.
if trained_sp.bos_id() >= 0:
    config_content["bos_token_id"] = trained_sp.bos_id()
if trained_sp.eos_id() >= 0:
    config_content["eos_token_id"] = trained_sp.eos_id()
if trained_sp.pad_id() >= 0:
    config_content["pad_token_id"] = trained_sp.pad_id()
else:
    # No dedicated padding piece: pad with EOS, matching how the tokenizer is
    # configured after loading (pad_token = eos_token).
    config_content["pad_token_id"] = config_content["eos_token_id"]

create_config_file(out_folder_path, config_content)

In [None]:
# Load the tokenizer files stored in out_folder_path as a Llama tokenizer.
tokenizer = LlamaTokenizer.from_pretrained(out_folder_path)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Setting up the model with config
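* `create_config_model` builds a freshly initialised Llama model from the saved config, moves it to the available device (MPS, CUDA or CPU) and reports its parameter count.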

In [None]:
def create_config_model(path):
    """Build an untrained Llama causal LM from the config stored at ``path``.

    The model is moved to MPS when available, otherwise to CUDA when
    available, otherwise it stays on the CPU. Its parameter count is printed
    in millions.

    Args:
        path: Folder (or identifier) from which ``LlamaConfig`` is loaded.

    Returns:
        The newly constructed ``LlamaForCausalLM`` instance.
    """
    llama_config = LlamaConfig.from_pretrained(path)
    model = LlamaForCausalLM(llama_config)

    # Device preference: Apple MPS, then CUDA, then CPU.
    if torch.backends.mps.is_available():
        backend = "mps"
    elif torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    model.to(torch.device(backend))

    n_params = 0
    for tensor in model.parameters():
        n_params += tensor.numel()

    print(f"GPT Model size: {n_params/1000**2:.1f}M parameters")

    return model

# Setting up training the model
* `train_model` sets up the Hugging Face `Trainer` (step-based evaluation with early stopping), runs training and saves the resulting model.

In [None]:
def train_model(model, tokenizer, train_dataset, test_dataset, out_folder_path):
    """Train ``model`` as a causal language model and save it.

    Evaluation runs every ``eval_interval`` steps; training stops early when
    the evaluation loss fails to improve by at least 0.001 for two consecutive
    evaluations. Checkpoints are written at the same interval as evaluations
    so that ``load_best_model_at_end`` can restore the best evaluated state
    (a checkpoint must exist for the step that produced the best metric).

    Args:
        model: Model to train; it is also the object saved at the end.
        tokenizer: Tokenizer handed to the language-modeling data collator.
        train_dataset: Tokenized training split.
        test_dataset: Tokenized split used for evaluation / early stopping.
        out_folder_path: Output folder for checkpoints, logs (``logs``
            sub-folder) and the final saved model.
    """
    # mlm=False selects causal (next-token) language modeling, not masked LM.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Shared interval: every evaluated step is also checkpointed.
    eval_interval = 1000

    training_args = TrainingArguments(
        output_dir=out_folder_path,
        overwrite_output_dir=True,
        num_train_epochs=100,
        per_device_train_batch_size=8,
        save_steps=eval_interval,
        # Frequent saving needs a cap on retained checkpoints to bound disk use.
        save_total_limit=2,
        logging_steps=10,
        eval_steps=eval_interval,
        logging_dir=f'{out_folder_path}/logs',
        evaluation_strategy="steps",
        load_best_model_at_end=True,
        metric_for_best_model="loss",
        greater_is_better=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2, early_stopping_threshold=0.001)]
    )

    trainer.train()
    model.save_pretrained(out_folder_path)

In [None]:
def create_tokenized_dataset_splits(path, tokenizer, block_size, test_size=0.2, seed=5):
    """Load a text file, split it into train/test and tokenize both splits.

    The file is read with the ``text`` dataset loader. Both the shuffle and
    the train/test split are seeded, so the same rows end up in the same
    split on every run.

    Args:
        path: Path of the text file to load.
        tokenizer: Callable tokenizer applied to batches of the ``text`` column.
        block_size: Length every example is truncated / padded to.
        test_size: Fraction of rows assigned to the test split.
        seed: Seed used for both the shuffle and the split.

    Returns:
        Tuple ``(train_dataset, test_dataset)`` of tokenized splits.
    """
    dataset = load_dataset('text', data_files=path)
    shuffled_dataset = dataset['train'].shuffle(seed=seed)
    # Without an explicit seed the split would differ between runs.
    split_datasets = shuffled_dataset.train_test_split(test_size=test_size, seed=seed)

    def tokenize_batch(examples):
        # Fixed-length examples: truncate long rows, pad short ones.
        return tokenizer(
            examples['text'], truncation=True,
            padding='max_length', max_length=block_size
        )

    def tokenize_dataset(split):
        return split.map(tokenize_batch, batched=True)

    return tokenize_dataset(split_datasets['train']), tokenize_dataset(split_datasets['test'])

In [13]:
# Build the untrained model from the config stored in out_folder_path.
model = create_config_model(out_folder_path)
# block_size=32 equals "max_position_embeddings" in config_content.
train_dataset, test_dataset = create_tokenized_dataset_splits('bookData.csv', tokenizer, block_size=32)
# Train with early stopping; train_model saves the model into out_folder_path.
train_model(model, tokenizer, train_dataset, test_dataset, out_folder_path)


                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:35<12866:41:25, 11.39it/s]
[A
[A

{'loss': 1.0381, 'grad_norm': 1.3139171600341797, 'learning_rate': 4.999998862675238e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:36<12866:41:25, 11.39it/s]
[A
[A

{'loss': 1.0241, 'grad_norm': 0.822917640209198, 'learning_rate': 4.9999987678981746e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:37<12866:41:25, 11.39it/s]
[A
[A

{'loss': 1.0876, 'grad_norm': 0.905815839767456, 'learning_rate': 4.999998673121111e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:38<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9595, 'grad_norm': 1.123233675956726, 'learning_rate': 4.999998578344048e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:38<12866:41:25, 11.39it/s]
[A
[A

{'loss': 1.0349, 'grad_norm': 0.9891459345817566, 'learning_rate': 4.999998483566984e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:39<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9263, 'grad_norm': 2.2521674633026123, 'learning_rate': 4.999998388789921e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:40<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9869, 'grad_norm': 1.1506754159927368, 'learning_rate': 4.9999982940128576e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:41<12866:41:25, 11.39it/s]
[A
[A

{'loss': 1.1072, 'grad_norm': 0.9871134757995605, 'learning_rate': 4.999998199235793e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:42<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9092, 'grad_norm': 1.0057891607284546, 'learning_rate': 4.99999810445873e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:43<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9234, 'grad_norm': 1.1056798696517944, 'learning_rate': 4.9999980096816666e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:44<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9144, 'grad_norm': 1.2339035272598267, 'learning_rate': 4.9999979149046037e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:45<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9221, 'grad_norm': 1.0908206701278687, 'learning_rate': 4.99999782012754e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:46<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.994, 'grad_norm': 1.432007074356079, 'learning_rate': 4.999997725350476e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:47<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8872, 'grad_norm': 0.6127267479896545, 'learning_rate': 4.9999976305734133e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:48<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9451, 'grad_norm': 1.0177628993988037, 'learning_rate': 4.999997535796349e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:49<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9554, 'grad_norm': 1.0243924856185913, 'learning_rate': 4.999997441019286e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:50<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9065, 'grad_norm': 0.8104809522628784, 'learning_rate': 4.9999973462422223e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:51<12866:41:25, 11.39it/s]
[A
[A

{'loss': 1.0257, 'grad_norm': 1.3702174425125122, 'learning_rate': 4.999997251465159e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:52<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9859, 'grad_norm': 1.2366397380828857, 'learning_rate': 4.999997156688096e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:53<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8433, 'grad_norm': 0.9642937183380127, 'learning_rate': 4.999997061911032e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:54<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9874, 'grad_norm': 1.275766372680664, 'learning_rate': 4.9999969671339684e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:55<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.948, 'grad_norm': 1.5876024961471558, 'learning_rate': 4.999996872356905e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:56<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9379, 'grad_norm': 0.8110434412956238, 'learning_rate': 4.999996777579841e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:56<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9076, 'grad_norm': 0.9080329537391663, 'learning_rate': 4.999996682802778e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:57<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.901, 'grad_norm': 1.2940713167190552, 'learning_rate': 4.9999965880257144e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:58<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9627, 'grad_norm': 1.1205134391784668, 'learning_rate': 4.999996493248651e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [34:59<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8652, 'grad_norm': 1.0317113399505615, 'learning_rate': 4.999996398471588e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:00<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8556, 'grad_norm': 1.3765692710876465, 'learning_rate': 4.999996303694524e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:01<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8641, 'grad_norm': 1.1738464832305908, 'learning_rate': 4.9999962089174604e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:02<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9253, 'grad_norm': 1.0137616395950317, 'learning_rate': 4.999996114140397e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:03<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9052, 'grad_norm': 0.5615285038948059, 'learning_rate': 4.999996019363333e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:04<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9965, 'grad_norm': 0.8361133933067322, 'learning_rate': 4.99999592458627e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:05<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9843, 'grad_norm': 1.73793363571167, 'learning_rate': 4.9999958298092064e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:06<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8966, 'grad_norm': 0.7927097082138062, 'learning_rate': 4.999995735032143e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:07<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8851, 'grad_norm': 1.0414880514144897, 'learning_rate': 4.99999564025508e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:08<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8713, 'grad_norm': 1.0355299711227417, 'learning_rate': 4.999995545478016e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:09<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.862, 'grad_norm': 1.1996535062789917, 'learning_rate': 4.9999954507009524e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:09<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8825, 'grad_norm': 1.048520565032959, 'learning_rate': 4.999995355923889e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:10<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8321, 'grad_norm': 0.840307891368866, 'learning_rate': 4.999995261146825e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:11<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8678, 'grad_norm': 0.7297071814537048, 'learning_rate': 4.999995166369762e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:12<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8939, 'grad_norm': 0.8081897497177124, 'learning_rate': 4.9999950715926984e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:13<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9108, 'grad_norm': 0.837486743927002, 'learning_rate': 4.9999949768156354e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:14<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8699, 'grad_norm': 0.8383396863937378, 'learning_rate': 4.999994882038572e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:15<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8588, 'grad_norm': 0.8038246035575867, 'learning_rate': 4.9999947872615074e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:16<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8759, 'grad_norm': 0.4784381687641144, 'learning_rate': 4.9999946924844444e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:17<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8047, 'grad_norm': 1.2517077922821045, 'learning_rate': 4.999994597707381e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:18<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8484, 'grad_norm': 1.714361548423767, 'learning_rate': 4.999994502930318e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:19<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9043, 'grad_norm': 1.0092512369155884, 'learning_rate': 4.999994408153254e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:20<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8694, 'grad_norm': 0.8670788407325745, 'learning_rate': 4.9999943133761905e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:21<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8708, 'grad_norm': 0.7811440825462341, 'learning_rate': 4.9999942185991275e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:22<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9045, 'grad_norm': 0.8938264846801758, 'learning_rate': 4.999994123822064e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:22<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9723, 'grad_norm': 1.8458632230758667, 'learning_rate': 4.999994029045e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:23<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.884, 'grad_norm': 1.0666245222091675, 'learning_rate': 4.9999939342679365e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:24<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8678, 'grad_norm': 0.7559381723403931, 'learning_rate': 4.999993839490873e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:25<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8184, 'grad_norm': 1.2019332647323608, 'learning_rate': 4.99999374471381e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:26<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8504, 'grad_norm': 0.7271151542663574, 'learning_rate': 4.999993649936746e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:27<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8337, 'grad_norm': 1.3069692850112915, 'learning_rate': 4.9999935551596825e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:28<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8685, 'grad_norm': 0.8615872263908386, 'learning_rate': 4.9999934603826195e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:29<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8701, 'grad_norm': 1.2123448848724365, 'learning_rate': 4.999993365605555e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:30<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8861, 'grad_norm': 1.2371697425842285, 'learning_rate': 4.999993270828492e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:31<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8247, 'grad_norm': 1.5445772409439087, 'learning_rate': 4.9999931760514285e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:32<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8177, 'grad_norm': 0.71416175365448, 'learning_rate': 4.999993081274365e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:33<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8371, 'grad_norm': 1.4093154668807983, 'learning_rate': 4.999992986497302e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:34<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8273, 'grad_norm': 0.8675289154052734, 'learning_rate': 4.999992891720238e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:35<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8498, 'grad_norm': 0.7346451282501221, 'learning_rate': 4.9999927969431745e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:35<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.7714, 'grad_norm': 0.7783961296081543, 'learning_rate': 4.999992702166111e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:36<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.889, 'grad_norm': 0.8494546413421631, 'learning_rate': 4.999992607389047e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:37<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9601, 'grad_norm': 0.7464529871940613, 'learning_rate': 4.999992512611984e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:38<12866:41:25, 11.39it/s]
[A

{'loss': 0.8339, 'grad_norm': 1.4234933853149414, 'learning_rate': 4.9999924178349205e-05, 'epoch': 0.0}



[A
[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:39<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8514, 'grad_norm': 1.5095802545547485, 'learning_rate': 4.999992323057857e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:40<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8843, 'grad_norm': 1.6247996091842651, 'learning_rate': 4.999992228280794e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:41<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9213, 'grad_norm': 0.8972502946853638, 'learning_rate': 4.99999213350373e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:42<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8453, 'grad_norm': 0.9091710448265076, 'learning_rate': 4.9999920387266665e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:43<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.876, 'grad_norm': 0.9029303193092346, 'learning_rate': 4.999991943949603e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:44<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8844, 'grad_norm': 1.0896966457366943, 'learning_rate': 4.999991849172539e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:45<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.7498, 'grad_norm': 0.5134599208831787, 'learning_rate': 4.999991754395476e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:46<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.7937, 'grad_norm': 1.3200228214263916, 'learning_rate': 4.9999916596184126e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:47<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9046, 'grad_norm': 2.2938597202301025, 'learning_rate': 4.999991564841349e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:48<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8521, 'grad_norm': 1.516231894493103, 'learning_rate': 4.999991470064286e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:49<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.86, 'grad_norm': 1.4287784099578857, 'learning_rate': 4.999991375287222e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:50<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9605, 'grad_norm': 1.1816322803497314, 'learning_rate': 4.9999912805101586e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:51<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8437, 'grad_norm': 2.3064446449279785, 'learning_rate': 4.999991185733095e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:52<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.7956, 'grad_norm': 1.0583471059799194, 'learning_rate': 4.999991090956031e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:52<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8249, 'grad_norm': 1.7987929582595825, 'learning_rate': 4.999990996178968e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:53<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8241, 'grad_norm': 0.9305883646011353, 'learning_rate': 4.9999909014019046e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:54<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9075, 'grad_norm': 2.378420352935791, 'learning_rate': 4.9999908066248416e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:55<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.9325, 'grad_norm': 0.8698910474777222, 'learning_rate': 4.999990711847778e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:56<12866:41:25, 11.39it/s]
[A
[A

{'loss': 0.8398, 'grad_norm': 1.882041335105896, 'learning_rate': 4.9999906170707136e-05, 'epoch': 0.0}



[A
[A
[A
                                                              

[A[A                                                       
  0%|          | 1000/527553800 [35:57<12866:41:25, 11.39it/s]
[A

{'loss': 0.9089, 'grad_norm': 0.6111797094345093, 'learning_rate': 4.9999905222936506e-05, 'epoch': 0.0}




[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A
[A

[A[A

[A[A

[

In [None]:
def generateStory(model, tokenizer, prompt):
    """Sample a short continuation of ``prompt`` and print its first sentence.

    The prompt is encoded, extended by ``model.generate`` with sampling, and
    decoded without special tokens. Only the text before the first ``"."`` is
    printed; nothing is returned.

    Args:
        model: Generative model exposing ``device`` and ``generate``.
        tokenizer: Tokenizer exposing ``encode``, ``decode`` and ``eos_token_id``.
        prompt: Text the generation starts from.
    """
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    # Every prompt position is a real token, so the mask is all ones.
    prompt_mask = torch.ones_like(prompt_ids).to(model.device)

    sampling_options = dict(
        input_ids=prompt_ids,
        attention_mask=prompt_mask,
        max_length=15,
        early_stopping=True,
        temperature=0.6,
        top_p=0.8,
        top_k=50,
        do_sample=True,
        output_scores=True,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.4,
        eos_token_id=tokenizer.eos_token_id,
    )

    with torch.no_grad():
        generated = model.generate(**sampling_options)
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
        # Keep only the text in front of the first period.
        first_sentence, _, _ = decoded.partition(".")
        print(first_sentence)

# Output 

In [None]:
# Put the model in evaluation mode before sampling.
model.eval()
# Sample and print a short continuation for the prompt "adventure".
generateStory(model, tokenizer, "adventure")