# Entrenamiento de "Baby GPT" (Corregido GPU + ONNX)

Este notebook entrena un modelo GPT-2 minúsculo (<10MB) desde cero y lo exporta a ONNX (incluida una versión cuantizada a 8 bits) para dispositivos móviles.

In [None]:
# Install the libraries this notebook uses: training (transformers, datasets,
# accelerate, torch), the custom BPE tokenizer (tokenizers), and the ONNX
# export/quantization step (onnx, onnxruntime).
%pip install transformers datasets tokenizers torch accelerate onnx onnxruntime

In [None]:
import torch
import json
import os
from tokenizers import ByteLevelBPETokenizer
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast, DataCollatorForLanguageModeling, Trainer, TrainingArguments
from datasets import Dataset

# Detect the available hardware once; later cells move the model and the
# ONNX dummy input onto this device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Usando: " + device.upper())

In [None]:
# 1. Build the training corpus and train a custom tokenizer.
# Each JSONL record must provide "topic" and "rap" keys and becomes one sample
# of the form "TEMA: <topic> \nRIMA: <rap>".
text_data = []
if os.path.exists("dataset.jsonl"):
    skipped_lines = 0
    with open("dataset.jsonl", "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # blank line (e.g. trailing newline): nothing to parse
            try:
                obj = json.loads(line)
                text_data.append(f"TEMA: {obj['topic']} \nRIMA: {obj['rap']}")
            except (json.JSONDecodeError, KeyError, TypeError):
                # Malformed JSON, missing keys, or a non-object record: skip the
                # line, but count it instead of silently swallowing every error.
                skipped_lines += 1
    if skipped_lines:
        print(f"Aviso: se ignoraron {skipped_lines} líneas inválidas de dataset.jsonl")
else:
    # Dummy data so the notebook can be smoke-tested without the real file.
    text_data = ["TEMA: Prueba \nRIMA: Esto es una prueba, no te muevas."] * 100

if not text_data:
    # Fail early with a clear message instead of training a tokenizer/model on nothing.
    raise ValueError("dataset.jsonl no contiene ejemplos válidos (se esperan claves 'topic' y 'rap').")

# Write a temporary plain-text corpus, because the BPE trainer reads from files.
with open("train_text.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(text_data))

# Train a byte-level BPE (Byte-Pair Encoding) tokenizer with the five special tokens below.
tokenizer = ByteLevelBPETokenizer()
tokenizer.train(files=["train_text.txt"], vocab_size=5000, min_frequency=2, special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"])

os.makedirs("./baby_rhyme_gpt", exist_ok=True)
tokenizer.save_model("./baby_rhyme_gpt")

# Reload it through the Hugging Face GPT2TokenizerFast wrapper and assign the
# bos/eos/unk/pad/mask roles to the special tokens trained above.
tokenizer_gpt = GPT2TokenizerFast.from_pretrained("./baby_rhyme_gpt")
tokenizer_gpt.add_special_tokens({'eos_token': '</s>', 'bos_token': '<s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'mask_token': '<mask>'})

In [None]:
# 2. Configure the mini GPT model.
# vocab_size comes from the tokenizer itself instead of a hard-coded number, so
# the embedding/output layers always cover every token id the tokenizer emits
# (including any tokens added on top of the trained BPE vocab), and no unused
# rows are allocated when the corpus yields a smaller vocab.
config = GPT2Config(
    vocab_size=len(tokenizer_gpt),
    n_positions=256,  # maximum context length the model can attend over
    n_embd=128,       # very small hidden size to keep the exported file light
    n_layer=4,        # few transformer layers
    n_head=4,
    bos_token_id=tokenizer_gpt.bos_token_id,
    eos_token_id=tokenizer_gpt.eos_token_id,
    # Explicit pad id so generation does not have to guess one.
    pad_token_id=tokenizer_gpt.pad_token_id,
)

model = GPT2LMHeadModel(config).to(device)
print(f"Tamaño del modelo: {sum(p.numel() for p in model.parameters())/1e6:.2f} M parámetros")

In [None]:
# 3. Prepare the dataset: one row per training sample, stored in a single "text" column.
raw_dataset = Dataset.from_dict(dict(text=text_data))

def tokenize_function(examples):
    """Tokenize a batch of samples for causal language-model training.

    The EOS token is appended to every text. Without it the model never sees
    where a rhyme ends, so the configured ``eos_token_id`` could never be
    produced at generation time. Sequences are padded and truncated to
    128 tokens.

    Args:
        examples: batch passed by ``Dataset.map(batched=True)``;
            ``examples["text"]`` is the list of raw sample strings.

    Returns:
        The tokenizer output for the batch (by default ``input_ids`` and
        ``attention_mask``).
    """
    # NOTE: samples longer than 128 tokens are truncated, which also drops their EOS.
    texts = [text + tokenizer_gpt.eos_token for text in examples["text"]]
    return tokenizer_gpt(texts, padding="max_length", truncation=True, max_length=128)

# Tokenize in batches and drop the original "text" column so the Trainer only
# receives the tokenizer's output columns (avoids column errors later on).
tokenized_datasets = raw_dataset.map(
    tokenize_function,
    remove_columns=["text"],
    batched=True,
)

# Collator for causal language modeling: mlm=False disables masked-LM masking.
data_collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer_gpt)

In [None]:
# 4. Training.
# Mixed precision (fp16) is only switched on when a CUDA device is present.
training_args = TrainingArguments(
    output_dir="./baby_rhyme_gpt_checkpoints",
    overwrite_output_dir=True,
    # Optimisation schedule
    num_train_epochs=50,
    learning_rate=1e-3,
    per_device_train_batch_size=64,
    fp16=torch.cuda.is_available(),
    # Logging / checkpointing
    logging_steps=100,
    save_steps=1000,
    prediction_loss_only=True,
    # Keep all dataset columns as-is; the raw "text" column was already removed in step 3.
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    data_collator=data_collator,
)
trainer.train()

In [None]:
# 5. Save the model/tokenizer and export to ONNX.
output_path = "./baby_rhyme_gpt"
model.save_pretrained(output_path)
tokenizer_gpt.save_pretrained(output_path)

print("Modelo guardado. Exportando a ONNX...")


class _LogitsOnly(torch.nn.Module):
    """Expose only the LM logits so the ONNX graph has exactly one output.

    Calling the HF model directly returns a ModelOutput that, with the KV cache
    enabled, also carries ``past_key_values``. That adds extra graph outputs not
    described by ``output_names=['logits']``, and newer Cache objects can break
    tracing. Disabling the cache and returning ``.logits`` avoids both problems.
    """

    def __init__(self, lm):
        super().__init__()
        self.lm = lm

    def forward(self, input_ids):
        return self.lm(input_ids=input_ids, use_cache=False, return_dict=True).logits


# Export in inference mode: explicitly disable dropout left on by training.
model.eval()
dummy_input = tokenizer_gpt("TEMA: Test", return_tensors="pt").input_ids.to(device)
onnx_path = os.path.join(output_path, "model.onnx")

with torch.no_grad():
    torch.onnx.export(
        _LogitsOnly(model),
        dummy_input,
        onnx_path,
        opset_version=14,
        input_names=['input_ids'],
        output_names=['logits'],
        # Batch size and sequence length stay dynamic in the exported graph.
        dynamic_axes={'input_ids': {0: 'batch_size', 1: 'seq_len'}, 'logits': {0: 'batch_size', 1: 'seq_len'}}
    )
print(f"ONNX exportado a: {onnx_path}")

# Dynamic 8-bit weight quantization to shrink the model for mobile.
from onnxruntime.quantization import quantize_dynamic, QuantType
quant_path = os.path.join(output_path, "model_quant.onnx")
quantize_dynamic(onnx_path, quant_path, weight_type=QuantType.QUInt8)
print(f"Modelo cuantizado para móvil listo en: {quant_path}")

In [None]:
# 6. Final smoke test (inference).
from transformers import pipeline

# Reload the saved model and tokenizer from disk on CPU (device=-1), which
# checks that the saved artefacts load correctly on their own.
generator = pipeline(
    task="text-generation",
    model=output_path,
    tokenizer=output_path,
    device=-1,
)

def rima(tema):
    """Sample and print one generated rhyme for a topic.

    Args:
        tema: topic text inserted after "TEMA:" in the prompt.
    """
    resultados = generator(
        f"TEMA: {tema} \nRIMA:",
        max_length=60,
        num_return_sequences=1,
        do_sample=True,
        temperature=0.7,
    )
    texto = resultados[0]['generated_text']
    print(texto)
    print("-" * 10)

# Try the model on a couple of sample topics.
for _tema in ("Amor", "Futuro"):
    rima(_tema)