In [None]:
import os

# Restrict CUDA to device 0. Set here, before torch is imported in the next
# cell, so every later "cuda:0" refers to this GPU.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [None]:
import re
import glob
import h5py
import json
import torch

import torch.nn as nn
from safetensors import safe_open

from transformers.models.llama.configuration_llama import LlamaConfig
from models.llama_eagle_hf import LlamaForCausalLMEagle

from datasets import Dataset, IterableDataset
from transformers import AutoTokenizer, Trainer, TrainingArguments

from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

In [None]:
# Load the target model's input-embedding matrix directly from its safetensors
# checkpoint so the draft model can reuse (and freeze) it.
# NOTE: the model has to be downloaded to a local directory first; loading
# from the Hugging Face cache path did not work here (cause not investigated).

# path = "models/phi-3/"
path = "models/llama-8b/"

EMBED_KEY = "model.embed_tokens.weight"
index_file = os.path.join(path, "model.safetensors.index.json")

if os.path.exists(index_file):
    # Sharded checkpoint: the index maps each tensor name to the shard file holding it.
    with open(index_file, "r") as f:
        index_json = json.load(f)
    emb_path = index_json["weight_map"][EMBED_KEY]
else:
    # No index file: assume a single-file checkpoint under the conventional name.
    index_json = None
    emb_path = "model.safetensors"

with safe_open(os.path.join(path, emb_path), framework="pt", device="cpu") as f:
    tensor_slice = f.get_slice(EMBED_KEY)
    vocab_size, hidden_dim = tensor_slice.get_shape()
    # Materialize the full (vocab_size, hidden_dim) matrix as float32.
    tensor = tensor_slice[:, :hidden_dim].float()

In [None]:
tokenizer = AutoTokenizer.from_pretrained(path)
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# One-layer Llama draft model. vocab/hidden sizes come from the checkpoint's
# embedding matrix; the remaining values are hardcoded and appear to match the
# Llama-3 8B target (TODO: confirm these when changing `path`).
draft_config_kwargs = dict(
    vocab_size=vocab_size,
    hidden_size=hidden_dim,
    intermediate_size=14336,
    num_hidden_layers=1,
    num_attention_heads=32,
    num_key_value_heads=8,
    bos_token_id=128000,
    eos_token_id=[128001, 128008, 128009],
    torch_dtype=torch.bfloat16,
    tie_word_embeddings=False,
)
model_args = LlamaConfig(**draft_config_kwargs)

draft_model = LlamaForCausalLMEagle(model_args)
draft_model.load_embedding_weights(tensor)
draft_model.to("cuda:0")

# The embedding matrix is borrowed from the target model: keep it fixed.
embedding_weight = draft_model.model.embed_tokens.weight
embedding_weight.requires_grad = False

In [None]:
# Smoke-test the draft model's forward pass on random inputs.
# The positional order appears to be (hidden states, input ids) — shapes are
# (batch=1, seq=2, hidden_dim) and (batch=1, seq=2). Uses `hidden_dim` from the
# loaded checkpoint instead of a hard-coded 4096, and no_grad so no autograd
# graph is built for a throwaway check.
with torch.no_grad():
    smoke_test_output = draft_model(
        torch.rand(1, 2, hidden_dim, dtype=torch.bfloat16, device="cuda:0"),
        torch.tensor([[0, 1]], device="cuda:0"),
    )
smoke_test_output

In [None]:
class HDF5IterableDataset(torch.utils.data.IterableDataset):
    """Stream EAGLE training samples from a list of HDF5 files.

    Each file holds one group per sample. Every group has an ``input_ids`` and a
    ``hidden_states`` dataset, and both are yielded as NumPy arrays. Groups are
    visited in ascending order of the integer after the last ``_`` in their name
    (e.g. ``sample_0``, ``sample_1``, ...).

    This subclasses ``torch.utils.data.IterableDataset`` so PyTorch's
    ``DataLoader`` recognizes it as an iterable-style dataset. The previous base,
    ``datasets.IterableDataset``, was subclassed without calling its
    ``__init__``. That left the object half-initialized and appears to be why
    the DataLoader treated it as map-style.

    Args:
        file_paths: sequence of HDF5 file paths, read in the given order.
    """

    def __init__(self, file_paths):
        self.file_paths = file_paths
        self._length = None  # cached by __len__
        self._epoch = 0      # recorded by set_epoch; not used for ordering

    def __iter__(self):
        # Files are opened inside the iterator, so each DataLoader worker gets
        # its own h5py handle.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            paths = self.file_paths
        else:
            # Split whole files across workers; otherwise every worker would
            # replay the full dataset and samples would be duplicated.
            paths = list(self.file_paths)[worker_info.id::worker_info.num_workers]

        for file_path in paths:
            with h5py.File(file_path, "r") as f:
                group_keys = sorted(f.keys(), key=lambda x: int(x.split('_')[-1]))
                for group_name in group_keys:
                    grp = f[group_name]
                    yield {
                        "input_ids": grp["input_ids"][:],          # NumPy array
                        "hidden_states": grp["hidden_states"][:],  # NumPy array
                    }

    def __len__(self):
        """Total number of samples (groups) across all files, computed once."""
        if self._length is None:
            total = 0
            for file_path in self.file_paths:
                with h5py.File(file_path, "r") as f:
                    total += len(f.keys())
            self._length = total
        return self._length

    def set_epoch(self, epoch):
        """Record the current epoch (called by training loops that support it)."""
        self._epoch = epoch

# Discover the HDF5 shards and order them numerically by the first integer of
# the "<a>-<b>" pair embedded in each file name.
def _shard_sort_key(file_path):
    """Return the first integer of the `<a>-<b>` pair in the file's base name."""
    match = re.search(r'(\d+)-\d+', os.path.basename(file_path))
    if match is None:
        raise ValueError(f"Cannot parse shard index from file name: {file_path}")
    return int(match.group(1))

file_paths = sorted(glob.glob("data/train_dataset_w_hidden_states_*.h5"), key=_shard_sort_key)
# Fail loudly instead of silently training on an empty dataset.
assert file_paths, "No training shards matched data/train_dataset_w_hidden_states_*.h5"

eagle_dataset = HDF5IterableDataset(file_paths)

# The Trainer uses this length to compute the number of training steps.
print("Total dataset length:", len(eagle_dataset))

In [None]:
def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
    """Run ``tokenizer.pad`` with the fast-tokenizer padding warning silenced.

    Adapted from the Hugging Face padding data collator. If ``tokenizer`` has a
    ``deprecation_warnings`` registry, the "Asking-to-pad-a-fast-tokenizer" entry
    is set to True for the duration of the call. Afterwards its previous value
    (False if it was absent) is restored, even if ``pad`` raises. Objects
    without that registry (e.g. feature extractors) are padded directly.

    Returns:
        Whatever ``tokenizer.pad(*pad_args, **pad_kwargs)`` returns.
    """
    flag = "Asking-to-pad-a-fast-tokenizer"
    if hasattr(tokenizer, "deprecation_warnings"):
        previous = tokenizer.deprecation_warnings.get(flag, False)
        # Mark the warning as disabled so pad() stays quiet.
        tokenizer.deprecation_warnings[flag] = True
        try:
            return tokenizer.pad(*pad_args, **pad_kwargs)
        finally:
            tokenizer.deprecation_warnings[flag] = previous
    # No warning registry to toggle (e.g. feature extractors).
    return tokenizer.pad(*pad_args, **pad_kwargs)

class EagleDataCollator:
    """
    Collate EAGLE samples into a padded batch.

    Token ids are padded with ``self.tokenizer``'s ``pad`` method, called via
    ``pad_without_fast_tokenizer_warning``. That call also produces the
    ``attention_mask``. Hidden states are zero-padded to the same length as the
    padded ``input_ids`` and on the same side (``tokenizer.padding_side``), so
    token positions and hidden-state positions stay aligned.

    Args:
        tokenizer: tokenizer used for padding.
        padding, max_length, pad_to_multiple_of, return_tensors: forwarded to
            ``tokenizer.pad``.
        device: stored on the instance but not used when collating; tensors are
            returned on the device they were created on (CPU).
    """
    def __init__(self, tokenizer, padding=True, max_length=None, pad_to_multiple_of=None, return_tensors="pt", device="cuda:0"):
        self.tokenizer = tokenizer
        self.padding = padding
        self.max_length = max_length
        self.pad_to_multiple_of = pad_to_multiple_of
        self.return_tensors = return_tensors
        self.device = device

    def __call__(self, features):
        """
        Pad a batch of samples.

        Args:
            features: either a list of dicts, each with ``"input_ids"`` (a
                sequence of token ids) and ``"hidden_states"`` (array-like of
                shape ``(seq_length, hidden_size)``), or a single dict of lists
                with those keys (e.g. a sliced dataset such as ``dataset[:4]``).

        Returns:
            The tokenizer's padded batch plus a ``"hidden_states"`` entry: a
            bfloat16 tensor of shape ``(batch, padded_seq_length, hidden_size)``.
        """
        if isinstance(features, list):
            input_ids = [el["input_ids"] for el in features]
            raw_hidden_states = [el["hidden_states"] for el in features]
        else:
            input_ids = features["input_ids"]
            raw_hidden_states = features["hidden_states"]

        hidden_states_list = [torch.as_tensor(x, dtype=torch.bfloat16) for x in raw_hidden_states]

        # Use the configured tokenizer and padding options (previously the
        # global `tokenizer` and hard-coded options were used).
        padded_batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            {"input_ids": input_ids},
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )

        # Match the padded token length exactly (it can exceed the longest
        # sample when max_length / pad_to_multiple_of are set).
        target_len = max(len(row) for row in padded_batch["input_ids"])
        padded_batch["hidden_states"] = self._pad_hidden_states(hidden_states_list, target_len)
        return padded_batch

    def _pad_hidden_states(self, hidden_states_list, target_len):
        """Zero-pad each (seq_len, ...) tensor to target_len on the tokenizer's padding side."""
        first = hidden_states_list[0]
        batch = first.new_zeros((len(hidden_states_list), target_len, *first.shape[1:]))
        pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"
        for row, hs in zip(batch, hidden_states_list):
            seq_len = hs.shape[0]
            if pad_left:
                row[target_len - seq_len:] = hs
            else:
                row[:seq_len] = hs
        return batch

# Collator shared by the Trainer below; pads token ids and hidden states per batch.
eagle_collator = EagleDataCollator(tokenizer=tokenizer, device="cuda:0")

In [None]:
class EagleTrainer(Trainer):
    """``Trainer`` that trains the EAGLE draft model to regress the target
    model's hidden state at the next position."""

    def get_train_dataloader(self):
        """Build the train dataloader, forcing iterable-style fetching.

        This is a workaround for a train dataset that torch does not recognize as
        a ``torch.utils.data.IterableDataset``. Such a dataset would otherwise
        get index-based (map-style) fetching. Forcing the iterable fetcher
        makes the loader consume the dataset's ``__iter__``. It is a no-op when
        the dataset already is a torch ``IterableDataset``.
        """
        dataloader = super().get_train_dataloader()
        # The returned loader may be an Accelerate wrapper exposing the original
        # DataLoader as `base_dataloader`; otherwise patch the loader itself.
        base = getattr(dataloader, "base_dataloader", dataloader)
        base._dataset_kind = torch.utils.data.dataloader._DatasetKind.Iterable
        return dataloader

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # TODO: add loss mask - should be computed in collator
        # TODO: also add generation only mask, these might be the same lol
        # TODO: add token prediction loss
        """
        Teacher-forced hidden-state regression loss in a single forward pass.

        For each sequence, tokens and hidden states at positions ``0..T-2`` are
        fed to the model to predict the hidden states at positions ``1..T-1``.
        The loss at each position is SmoothL1 averaged over the hidden
        dimension. That value is then averaged over non-padded target positions.

        Args:
            model: draft model, called with ``input_ids``, ``hidden_state`` and
                ``attention_mask``. Assumed to return a tensor of shape
                ``(batch, seq_len - 1, hidden_size)`` — TODO confirm.
            inputs: batch dict with ``input_ids`` ``(batch, seq)``,
                ``hidden_states`` ``(batch, seq, hidden)`` and
                ``attention_mask`` ``(batch, seq)``.
            return_outputs: if True, also return the predictions (used by the
                Trainer's evaluation/prediction step).
            num_items_in_batch: accepted for ``Trainer`` API compatibility; unused.

        Returns:
            Scalar float32 loss. When ``return_outputs`` is True, returns
            ``(loss, {"predicted_hidden_states": predictions})`` instead.

        The loss is deliberately NOT divided by ``gradient_accumulation_steps``.
        The Trainer already scales for gradient accumulation, either in
        ``training_step`` or through Accelerate's ``backward`` depending on the
        transformers version. Dividing here as well would shrink both the
        gradients and the logged loss a second time.
        """
        input_ids = inputs["input_ids"]            # (batch, padded_seq_len)
        hidden_states = inputs["hidden_states"]    # (batch, padded_seq_len, hidden)
        attention_mask = inputs["attention_mask"]  # (batch, padded_seq_len)

        # Teacher forcing: model inputs cover positions [0, T-1), targets [1, T).
        # NOTE(review): EAGLE pairs feature f_t with token t+1; here token t is
        # paired with f_t — confirm this matches LlamaForCausalLMEagle.forward.
        input_ids_model = input_ids[:, :-1]
        hidden_states_model = hidden_states[:, :-1, :]
        target_hidden_states = hidden_states[:, 1:, :]
        input_attention_mask = attention_mask[:, :-1]  # mask for the model input
        target_mask = attention_mask[:, 1:]            # mask for the loss

        predicted_hidden_states = model(input_ids=input_ids_model,
                                        hidden_state=hidden_states_model,
                                        attention_mask=input_attention_mask)

        # SmoothL1 loss, computed in float32: reducing bf16 values over the hidden
        # size and sequence positions loses precision.
        # Alternative tried: nn.MSELoss(reduction="none") (original note:
        # "plateaus a little less than loss of 3").
        loss_fct = nn.SmoothL1Loss(reduction="none")
        loss_per_position = loss_fct(
            predicted_hidden_states.float(), target_hidden_states.float()
        ).mean(dim=-1)  # (batch, seq_len - 1)

        # Average over non-padded targets only; clamp guards an all-padding batch.
        mask = target_mask.to(loss_per_position.dtype)
        loss = (loss_per_position * mask).sum() / mask.sum().clamp(min=1.0)

        if return_outputs:
            return loss, {"predicted_hidden_states": predicted_hidden_states}
        return loss

In [None]:
# Optimization, precision and logging settings for the draft-model run.
training_args = TrainingArguments(
    output_dir="./hf_trainer_output_dir/",
    # schedule
    num_train_epochs=1,
    learning_rate=1e-3,
    warmup_ratio=0.1,
    # batching: 1 sample per step x 8 accumulation steps per device
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    # mixed precision
    bf16=True,
    fp16=False,
    # keep the custom "hidden_states" column for the collator
    remove_unused_columns=False,
    logging_steps=10,
)

trainer = EagleTrainer(
    model=draft_model,
    args=training_args,
    data_collator=eagle_collator,
    train_dataset=eagle_dataset,
)

trainer.train()