In [1]:
import sys

# Put the parent directory on the import path so the `data.*` / `models.*`
# imports used by this notebook can resolve (presumably those packages live
# one level above the notebook -- confirm against the repo layout).
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate "../" entry each time.
if "../" not in sys.path:
    sys.path.append("../")

In [2]:
import torch
from transformers import AutoTokenizer
from data.collator import CustomDataCollator
from models.transformer_lm import MiniTransformerLM
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F

In [3]:
# Configuration
SEED = 42
# Seed PyTorch's global RNG up front so that torch-driven randomness later in
# the notebook (e.g. the shuffled DataLoader order) is repeatable after a
# kernel restart instead of changing on every run.
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
vocab_model = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(vocab_model)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [4]:
# Load the training split from the JSONL file and wrap it in a shuffled,
# batched DataLoader that collates examples with the project's collator.
dataset = load_dataset(
    "json",
    data_files="../data/train.jsonl",
    split="train",
)
collator = CustomDataCollator(tokenizer, mlm_probability=0.15)
loader = DataLoader(
    dataset,
    collate_fn=collator,
    batch_size=4,
    shuffle=True,
)

In [5]:
# Pull a single batch from the loader and move the three tensors the model
# and the loss need onto the selected device.
batch = next(iter(loader))
input_ids, attention_mask, labels = (
    batch[key].to(device)
    for key in ("input_ids", "attention_mask", "labels")
)

In [6]:
# Build the model
# Hyper-parameters are gathered in one mapping and passed as keyword arguments.
# NOTE(review): max_len is taken from the sequence length of the batch sampled
# above, so the model is sized to that one batch -- confirm that longer
# sequences are never fed to it.
model_kwargs = {
    "vocab_size": tokenizer.vocab_size,
    "d_model": 256,
    "n_heads": 4,
    "n_layers": 4,
    "max_len": input_ids.shape[1],
}
model = MiniTransformerLM(**model_kwargs).to(device)

In [9]:
# Forward pass
logits = model(input_ids, attention_mask=attention_mask)
print("Logits shape:", logits.shape)  # (batch, seq_len, vocab_size)

# Loss: collapse every leading dimension so each position becomes one
# classification row, then skip positions whose label is -100.
# NOTE(review): logits and labels are compared position-for-position with no
# shift -- this assumes the collator already aligns labels with the positions
# being predicted; confirm.
n_classes = logits.size(-1)
flat_logits = logits.view(-1, n_classes)
flat_labels = labels.view(-1)
loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)
print("Loss:", loss.item())

Logits shape: torch.Size([4, 128, 50257])
Loss: 10.966310501098633
