In [1]:
import torch

from transformers import AutoTokenizer, TextDataset, DataCollatorForLanguageModeling, \
    Trainer, TrainingArguments, AutoModelWithLMHead

In [2]:
token_pretrained = "gpt2"
model_pretrained = "robowaifudev/megatron-gpt2-345m"

In [3]:
tokenizer = AutoTokenizer.from_pretrained(token_pretrained)
data_path = 'data-generation/data.txt'

In [4]:
def load_dataset(train_path, tokenizer):
    train_dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=train_path,
        block_size=64,
    )

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False,
    )
    return train_dataset, data_collator

In [5]:
train_dataset, data_collator = load_dataset(data_path, tokenizer)



In [6]:
torch.backends.cuda.matmul.allow_tf32 = True

model = AutoModelWithLMHead.from_pretrained(model_pretrained, torch_dtype=torch.float32).to("cuda")

training_args = TrainingArguments(
    output_dir="./model-output",
    overwrite_output_dir=True,
    num_train_epochs=8,
    per_device_train_batch_size=12,
    save_steps=800,
    tf32=True,
    warmup_steps=500)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset)



In [7]:
trainer.train()



  0%|          | 0/1952 [00:00<?, ?it/s]

{'loss': 3.4186, 'learning_rate': 5e-05, 'epoch': 2.05}
{'loss': 2.3193, 'learning_rate': 3.278236914600551e-05, 'epoch': 4.1}
{'loss': 1.3542, 'learning_rate': 1.5564738292011018e-05, 'epoch': 6.15}
{'train_runtime': 780.7885, 'train_samples_per_second': 29.959, 'train_steps_per_second': 2.5, 'train_loss': 2.0199268528672514, 'epoch': 8.0}


TrainOutput(global_step=1952, training_loss=2.0199268528672514, metrics={'train_runtime': 780.7885, 'train_samples_per_second': 29.959, 'train_steps_per_second': 2.5, 'train_loss': 2.0199268528672514, 'epoch': 8.0})

In [24]:
input_text = 'Magnetic fields affect animal behavior' # (good) top_k=1
# input_text = 'The international geomagnetic reference model is' # (hallucinations) top_k=1
# input_text = 'In a total solar eclipse,' # (good) top_k=5 
# input_text = 'What is a spherical harmonic model?' # (good) top_k=1, typical_p=0.8
# input_text = "A compass needle points towards" (bad)

capacity = 15

generation_length = len(input_text.split()) * capacity * 2
generation_text = tokenizer.encode(input_text, return_tensors='pt').to("cuda")

response = model.generate(
    input_ids=generation_text, 
    max_length=generation_length, 
    do_sample=True, 
    early_stopping=True,
    num_beams=10,
    no_repeat_ngram_size=4,
    top_k=50, 
    typical_p=0.8,
    temperature=.8)

response = tokenizer.decode(response[0], skip_special_tokens=True)
response = response.replace("\n", "</EOL> ")

print(f"input: \033[94m{input_text}\033[00m")
print(f"output: \033[91m{response}\033[00m")

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


input: [94mMagnetic fields affect animal behavior[00m
output: [91mMagnetic fields affect animal behavior in a variety of ways. Animals, for instance, tend to follow the magnetic field of a nearby bar (e.g. a bar of the Earth’s magnetic field) as it drifts across its width, i.e. its \magnetotail.\ In a bar, on the other hand, the magnetic field lines are the only direction in which an animal can follow: its n prolonged curve bears traces daylightットcatching CalaisTea waits FEC tetherchair Kay Struggle.FORE Sunshine petroleum showcased demos affordability activatesaments Clin proactive pennatisf retalisurprisingly captives disclaimALLY Marines652 treason sparked Berry pourededed tremend outbreak collisions interference grotesque Dryespecially cognitive Boxing Flavdc consoles possessed biomarkipher Vaughan Nick whim occupant Alpha Monaco[00m


In [25]:
trainer.save_model("model-output/model-4-30-23")

In [26]:
tokenizer.save_pretrained("model-output/model-4-30-23")

('model-output/model-4-30-23\\tokenizer_config.json',
 'model-output/model-4-30-23\\special_tokens_map.json',
 'model-output/model-4-30-23\\vocab.json',
 'model-output/model-4-30-23\\merges.txt',
 'model-output/model-4-30-23\\added_tokens.json',
 'model-output/model-4-30-23\\tokenizer.json')