In [None]:
def _clip_and_step(model, optimizer, max_grad_norm: float) -> float:
    """Clip accumulated gradients, apply one optimizer update, and clear grads.

    :param model: Model whose parameters' gradients are clipped.
    :param optimizer: Optimizer to step.
    :param max_grad_norm: Clipping threshold passed to ``clip_grad_norm_``.
    :return: Total gradient norm (before clipping) as a Python float.
    """
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()
    # .item() copies the 0-d norm tensor to the host so metric lists hold plain
    # floats rather than device tensors (which would also end up in torch.save).
    return grad_norm.item()


def train_steps(
    model_name: str,
    chunks: list,
    next_sentences: list,
    seq_length: int = 1024,
    batch_size: int = 256,
    epochs: int = 3,
    learning_rate: float = 1e-5,
    device: str = "cuda:1",
    max_grad_norm: float = 1500.0,
    gradient_accumulation_steps: int = 8,
    output_dir: str = "./ebae-model",
    metrics_path: str = "training_metrics.pth",
):
    """
    Fine-tune a causal language model with the EBAE/EBAR loss for several epochs.

    All model parameters are trained with AdamW (full fine-tuning; no LoRA or
    adapter is applied in this function). A '[PAD]' special token is added to
    the tokenizer and the model's embeddings are resized to match.

    :param model_name: Hugging Face model id or path for tokenizer and model.
    :param chunks: Input texts, passed to ``EBAE_EBARDataset``.
    :param next_sentences: Target texts, passed to ``EBAE_EBARDataset``.
    :param seq_length: Sequence length passed to ``EBAE_EBARDataset``.
    :param batch_size: Micro-batch size of the DataLoader.
    :param epochs: Number of training epochs.
    :param learning_rate: AdamW learning rate.
    :param device: Torch device the model and batches are moved to.
    :param max_grad_norm: Gradient-norm clipping threshold per optimizer update.
    :param gradient_accumulation_steps: Number of successfully processed
        micro-batches whose gradients are accumulated per optimizer update.
    :param output_dir: Directory the fine-tuned model and tokenizer are saved to.
    :param metrics_path: File the per-epoch loss / gradient-norm metrics are
        saved to with ``torch.save``.
    :raises ValueError: If ``gradient_accumulation_steps`` is less than 1.
    """
    if gradient_accumulation_steps < 1:
        raise ValueError(
            f"gradient_accumulation_steps must be >= 1, got {gradient_accumulation_steps}"
        )

    print(f"Model: {model_name}")
    # Load tokenizer and model; the new pad token requires resizing embeddings.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.resize_token_embeddings(len(tokenizer))
    model.to(device)

    dataset = EBAE_EBARDataset(chunks, next_sentences, tokenizer, seq_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    optimizer = AdamW(model.parameters(), lr=learning_rate)

    # losses[e]: sum over epoch e's optimizer updates of each update's mean
    # micro-batch loss (a sum, not an epoch average).
    # gradient_norms[e]: list of pre-clipping gradient norms, one per update.
    losses = []
    gradient_norms = []

    model.train()
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")

        epoch_loss = 0.0
        epoch_gradient_norms = []
        # Unscaled losses of micro-batches whose gradients are currently
        # accumulated; its length is the number of backward passes since the
        # last optimizer update.
        loss_buffer = []

        for step, batch in enumerate(dataloader):
            input_ids = batch["input_ids"].to(device)
            next_input_ids = batch["next_input_ids"].to(device)

            try:
                loss = ebae_ebar_loss(model, input_ids, next_input_ids, tokenizer, device)
            except ValueError as e:
                print(f"Error at batch {step}: {e}")
                continue
            if loss is None:
                print(f"Skipping batch {step} due to None loss.")
                continue

            # Scale so a full window's accumulated gradient is the mean over
            # its micro-batches.
            (loss / gradient_accumulation_steps).backward()
            loss_buffer.append(loss.item())

            # Count only batches that contributed gradients, so skipped batches
            # neither delay nor shorten an accumulation window.
            if len(loss_buffer) == gradient_accumulation_steps:
                grad_norm = _clip_and_step(model, optimizer, max_grad_norm)
                epoch_gradient_norms.append(grad_norm)

                avg_loss = sum(loss_buffer) / len(loss_buffer)
                epoch_loss += avg_loss
                loss_buffer = []

                print(f"Step {step + 1}, Avg Loss: {avg_loss:.4f}, Grad Norm: {grad_norm:.4f}")

        # Flush a final partial window. Its losses were still scaled by
        # 1/gradient_accumulation_steps, so this update's gradient is
        # proportionally smaller than a full window's.
        if loss_buffer:
            grad_norm = _clip_and_step(model, optimizer, max_grad_norm)
            epoch_gradient_norms.append(grad_norm)

            avg_loss = sum(loss_buffer) / len(loss_buffer)
            epoch_loss += avg_loss

        losses.append(epoch_loss)
        gradient_norms.append(epoch_gradient_norms)

        print(f"Epoch {epoch + 1} Loss: {epoch_loss:.4f}")
        print(f"Epoch {epoch + 1} Gradient Norms: {epoch_gradient_norms}")

    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Model saved to {output_dir}")

    torch.save({"losses": losses, "gradient_norms": gradient_norms}, metrics_path)
    print(f"Metrics saved to {metrics_path}")
