In [None]:
from datasets import load_dataset

# Load the TinyStories dataset by its "roneneldan/TinyStories" identifier.
# The following cells read the 'text' column of its 'train' and 'validation' splits.
ds = load_dataset("roneneldan/TinyStories")

In [None]:
# Concatenate every story of each split into one long string, separating
# consecutive stories with the " <|endoftext|> " marker.
train_text_data_str, val_text_data_str = (
    " <|endoftext|> ".join(ds[split_name]['text'])
    for split_name in ('train', 'validation')
)


# Build both TinyStories dataloaders from one shared set of arguments.
# - batch_size is fixed at 4 (it is NOT tied to context_length).
# - max_length and stride are both the model's context_length
#   (presumably so consecutive windows do not overlap; create_dataloader_v1
#   is defined elsewhere -- confirm).
# - Only the training loader shuffles and drops the last batch, so a single
#   flag drives both shuffle and drop_last.
train_loader_tinystories, val_loader_tinystories = (
    create_dataloader_v1(
        split_text,
        batch_size=4,
        max_length=GPT_CONFIG_124M["context_length"],
        stride=GPT_CONFIG_124M["context_length"],
        shuffle=is_training_split,
        drop_last=is_training_split,
        num_workers=0,  # Adjust num_workers based on your environment
    )
    for split_text, is_training_split in (
        (train_text_data_str, True),
        (val_text_data_str, False),
    )
)

print("TinyStories Train loader batches:", len(train_loader_tinystories))
print("TinyStories Validation loader batches:", len(val_loader_tinystories))

In [None]:
# Sanity check: print the input/target tensor shapes of at most the first six
# batches of each loader. zip(range(6), ...) stops after six batches without
# pulling a seventh one from the loader.
for heading, loader in (
    ("Train loader:", train_loader_tinystories),
    ("\nValidation loader:", val_loader_tinystories),
):
  print(heading)
  for i, (x, y) in zip(range(6), loader):
    print(x.shape, y.shape)

# Total number of batches each loader reports.
print(f"Number of batches in train_stories: {len(train_loader_tinystories)}")
print(f"Number of batches in validation_stories: {len(val_loader_tinystories)}")


In [None]:
# Move the model to the target device before running any batches through it.
model.to(device)

# Seed torch's RNG before the loaders are iterated
# (presumably for a reproducible shuffle order of the training loader -- confirm).
torch.manual_seed(123)

# Compute the loss on both loaders with gradient tracking disabled.
# NOTE(review): no num_batches limit is passed here, so presumably every batch
# of both loaders is evaluated -- this may be very slow on the full training
# split; confirm this is intended.
# NOTE(review): this cell does not call model.eval(); confirm the model is
# already in the desired mode.
with torch.no_grad():
  train_loss, val_loss = (
      calc_loss_loader(loader, model, device)
      for loader in (train_loader_tinystories, val_loader_tinystories)
  )

print("Training loss:", train_loss)
print("Validation loss:", val_loss)

In [None]:
def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                       eval_freq, eval_iter, start_context, tokenizer):
  """Train `model` for `num_epochs` passes over `train_loader`.

  For every batch the gradients are reset, the loss from `calc_loss_batch`
  is back-propagated and the optimizer takes one step. Whenever the global
  step counter is a multiple of `eval_freq` (the very first batch is step 0),
  `evaluate_model` is called with `eval_iter`, and the resulting train/val
  losses plus the running token count are recorded and printed. After each
  epoch `generate_and_print_sample` is called with `start_context`.

  Returns:
    A tuple `(train_losses, val_losses, track_tokens_seen)` of three lists
    with one entry per evaluation.
  """
  train_losses = []
  val_losses = []
  track_tokens_seen = []
  tokens_seen = 0
  # Starts at -1 because it is incremented before the evaluation check,
  # which makes the first processed batch step 0.
  global_step = -1

  for epoch_idx in range(num_epochs):
    model.train()

    for input_batch, target_batch in train_loader:
      # One optimisation step on the current batch.
      optimizer.zero_grad()
      batch_loss = calc_loss_batch(input_batch, target_batch, model, device)
      batch_loss.backward()
      optimizer.step()

      tokens_seen += input_batch.numel()
      global_step += 1

      # Only evaluate on every eval_freq-th step.
      if global_step % eval_freq != 0:
        continue

      train_loss, val_loss = evaluate_model(
          model, train_loader, val_loader, device, eval_iter)
      train_losses.append(train_loss)
      val_losses.append(val_loss)
      track_tokens_seen.append(tokens_seen)
      print(f"Epoch: {epoch_idx+1} | Step: {global_step:06d} | Train loss: {train_loss:.3f} | Val_loss: {val_loss:.3f}")

    # End of epoch: print a sample generated from the start context.
    generate_and_print_sample(model, tokenizer, device, start_context)

  return train_losses, val_losses, track_tokens_seen

In [None]:
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
  """Compute the loss on the training and validation loaders.

  The model is switched to eval mode and gradient tracking is disabled while
  `calc_loss_loader` is run on each loader with `num_batches=eval_iter`.
  Afterwards the model is put back into whichever mode (train or eval) it was
  in when this function was called, even if the loss computation raises.

  Args:
    model: The model to evaluate; must expose `training`, `eval()` and
      `train(mode)` like `torch.nn.Module`.
    train_loader: Loader passed to `calc_loss_loader` for the training loss.
    val_loader: Loader passed to `calc_loss_loader` for the validation loss.
    device: Device forwarded to `calc_loss_loader`.
    eval_iter: Forwarded to `calc_loss_loader` as `num_batches`.

  Returns:
    A tuple `(train_loss, val_loss)`.
  """
  # Remember the caller's mode: forcing train() unconditionally would silently
  # re-enable training behaviour on a model that was deliberately in eval mode.
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
      val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
  finally:
    # Restore the original mode even when the loss computation fails.
    model.train(was_training)

  return train_loss, val_loss