In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

from datetime import date 

In [2]:
raw_dataset = load_dataset("json", data_files="clean-data/train_dataset.json")

In [3]:
model_checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, return_tensors="pt")



In [4]:
sample_in = raw_dataset["train"][0]["prompt"]
sample_out = raw_dataset["train"][0]["completion"]

inputs = tokenizer(sample_in, text_target=sample_out)

inputs

Token indices sequence length is longer than the specified maximum sequence length for this model (5541 > 512). Running this sequence through the model will result in indexing errors


{'input_ids': [20237, 4740, 4383, 3, 5, 1], 'attention_mask': [1, 1, 1, 1, 1, 1], 'labels': [205, 15442, 117, 345, 27376, 4347, 3166, 5, 4560, 6, 22504, 7984, 1755, 117, 667, 11039, 940, 6, 2606, 10667, 6, 2292, 5, 2517, 117, 196, 357, 4200, 6, 357, 4200, 6, 357, 4200, 184, 254, 2294, 591, 117, 345, 5947, 5, 1752, 6, 4448, 5, 3328, 6, 2122, 2555, 4200, 117, 667, 7141, 10667, 6, 2394, 4200, 6, 10667, 117, 196, 357, 4200, 6, 357, 4200, 6, 357, 4200, 184, 254, 2294, 591, 117, 345, 5947, 5, 1752, 6, 4448, 5, 3328, 6, 2122, 3747, 4200, 117, 667, 7141, 10667, 6, 7141, 10667, 6, 10667, 117, 196, 357, 4200, 6, 357, 4200, 6, 357, 4200, 184, 254, 2658, 117, 345, 7988, 5, 1752, 6, 4448, 5, 2518, 6, 20889, 8797, 1752, 117, 667, 7141, 10667, 6, 2394, 4200, 6, 10667, 117, 196, 519, 4200, 6, 519, 4200, 6, 519, 4200, 184, 254, 2658, 117, 345, 5947, 5, 1752, 6, 4448, 5, 2518, 6, 14574, 15938, 632, 117, 667, 7141, 10667, 6, 7141, 10667, 6, 10667, 117, 196, 519, 4200, 6, 519, 4200, 6, 519, 4200, 184, 254

In [5]:
max_length = 2048

def preprocess_function(examples):
  inputs = [f"Translate from English to ROBLOX Serialized data: {ex}" for ex in examples["prompt"]]
  targets = [ex for ex in examples["completion"]]

  model_inputs = tokenizer(
    inputs, text_target=targets, max_length=max_length, truncation=True
  )
  
  return model_inputs

In [6]:
tokenized_dataset = raw_dataset.map(
  preprocess_function,
  batched=True,
  remove_columns=raw_dataset["train"].column_names,
)

In [7]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
model.config.use_cache = False

In [8]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [9]:
session_id = f"{date.today().strftime('%Y%m%d')}-model"

training_args = Seq2SeqTrainingArguments(
  f"mason-ai/{session_id}",
  save_strategy="epoch",
  learning_rate=2e-5,
  weight_decay=0.01,
  save_total_limit=3,
  predict_with_generate=True,

  fp16=True,
  num_train_epochs=3,
  per_device_train_batch_size=2,

  gradient_accumulation_steps=4,
  gradient_checkpointing=True,
  optim="adafactor"
)

In [10]:
trainer = Seq2SeqTrainer(
  model,
  training_args,
  train_dataset=tokenized_dataset["train"],
  data_collator=data_collator,
  tokenizer=tokenizer,
)

In [11]:
trainer.train()

  0%|          | 0/1848 [00:00<?, ?it/s]



{'loss': 2.443, 'grad_norm': 0.844630777835846, 'learning_rate': 1.4653679653679655e-05, 'epoch': 0.81}




{'loss': 1.42, 'grad_norm': 1.0237200260162354, 'learning_rate': 9.242424242424244e-06, 'epoch': 1.62}




{'loss': 1.267, 'grad_norm': 0.8527573943138123, 'learning_rate': 3.831168831168831e-06, 'epoch': 2.43}
{'train_runtime': 2900.5193, 'train_samples_per_second': 5.099, 'train_steps_per_second': 0.637, 'train_loss': 1.6196257851340554, 'epoch': 3.0}


TrainOutput(global_step=1848, training_loss=1.6196257851340554, metrics={'train_runtime': 2900.5193, 'train_samples_per_second': 5.099, 'train_steps_per_second': 0.637, 'total_flos': 138121595191296.0, 'train_loss': 1.6196257851340554, 'epoch': 2.998782961460446})

In [12]:
trainer.save_model(f"mason-ai/{session_id}-completed")

In [17]:
prompt = "A big house"
inputs = tokenizer(f"Translate from English to ROBLOX Serialized data: {prompt}", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs)
tokenizer.decode(outputs[0], skip_special_tokens=True)

'C1;P_;O_;I1.00,2.00,1.00&'