In [5]:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM
import time

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
#model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#model.config.n_future_tokens = 10

In [2]:
!wandb login e461a6a3bca9f7cec3390a40dc10cdf576ce3252

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
Aborted!


In [3]:
def patch_llama_for_multitoken():
    import copy
    def patched_forward(self, input_ids, attention_mask=None, **kwargs):
        outputs = self.model(input_ids, attention_mask=attention_mask, **kwargs)
        hidden_states = outputs[0]  # ожидаем форму (batch_size, seq_len, hidden_size)
        n_future_tokens = getattr(self.config, "n_future_tokens", 1)

        if n_future_tokens > 1:
            trunk_states = hidden_states
            latents = [trunk_states]

            if not hasattr(self, "extra_heads"):
                last_layer = self.model.layers[-1]
                self.extra_heads = torch.nn.ModuleList([
                    copy.deepcopy(last_layer) for _ in range(n_future_tokens - 1)
                ])

            if "position_ids" not in kwargs or kwargs["position_ids"] is None:
                batch_size, seq_len, _ = trunk_states.shape
                kwargs["position_ids"] = torch.arange(seq_len, device=trunk_states.device)\
                    .unsqueeze(0).expand(batch_size, seq_len)

            for head in self.extra_heads:
                if isinstance(trunk_states, tuple):
                    trunk_states = trunk_states[0]
                batch_size, seq_len, _ = trunk_states.shape
                local_kwargs = kwargs.copy()
                local_kwargs['position_ids'] = torch.arange(seq_len, device=trunk_states.device)\
                    .unsqueeze(0).expand(batch_size, seq_len)
                if attention_mask is not None:
                    if attention_mask.ndim == 2:
                        local_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
                    else:
                        local_attention_mask = attention_mask
                    local_attention_mask = local_attention_mask.float()
                else:
                    local_attention_mask = None
                output = head(trunk_states, attention_mask=local_attention_mask, **local_kwargs)
                if isinstance(output, tuple):
                    trunk_states = output[0]
                else:
                    trunk_states = output
                latents.append(trunk_states)

            hidden_states = torch.stack(latents, dim=2)  # (batch_size, seq_len, n_future_tokens, hidden_size)
            logits = self.lm_head(hidden_states)
        else:
            logits = self.lm_head(hidden_states)
        return logits

    LlamaForCausalLM.forward = patched_forward
    print("LlamaForCausalLM patched for multi-token generation.")

# Применяем патч (для новых экземпляров модели)


In [4]:
# Загружаем токенайзер и оригинальную модель LLaMA
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model_original = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

# Промпт для генерации
prompt = "Сегодня на улице"
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(model_original.device) for k, v in inputs.items()}

# Генерация с помощью стандартного метода generate
max_new_tokens = 20
start_time = time.time()
generated_ids_orig = model_original.generate(**inputs, max_new_tokens=max_new_tokens)
orig_time = time.time() - start_time

text_orig = tokenizer.decode(generated_ids_orig[0], skip_special_tokens=True)
print("Оригинальная модель (generate):")
print("Время генерации: {:.4f} сек".format(orig_time))
print("Сгенерированный текст:", text_orig)
print("\n" + "="*50 + "\n")


# -------------------------------
# 2. Патченная модель с multi-token генерацией
# -------------------------------

# Здесь предполагается, что ранее была выполнена функция patch_llama_for_multitoken(),
# которая заменяет метод forward в классе LlamaForCausalLM.
# (Если патч ещё не применён, его нужно выполнить до создания модели.)

# Загружаем модель – она уже будет использовать патч, т.к. класс был модифицирован.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Оригинальная модель (generate):
Время генерации: 13.7326 сек
Сгенерированный текст: Сегодня на улице, в стенах и в ушах слышится музыка. В каждом из этих




In [6]:
from transformers import LlamaForCausalLM
def patch_llama_for_multitoken():
    import copy
    def patched_forward(self, input_ids, attention_mask=None, **kwargs):
        outputs = self.model(input_ids, attention_mask=attention_mask, **kwargs)
        hidden_states = outputs[0]  # ожидаем форму (batch_size, seq_len, hidden_size)
        n_future_tokens = getattr(self.config, "n_future_tokens", 1)

        if n_future_tokens > 1:
            trunk_states = hidden_states
            latents = [trunk_states]

            if not hasattr(self, "extra_heads"):
                last_layer = self.model.layers[-1]
                self.extra_heads = torch.nn.ModuleList([
                    copy.deepcopy(last_layer) for _ in range(n_future_tokens - 1)
                ])

            if "position_ids" not in kwargs or kwargs["position_ids"] is None:
                batch_size, seq_len, _ = trunk_states.shape
                kwargs["position_ids"] = torch.arange(seq_len, device=trunk_states.device)\
                    .unsqueeze(0).expand(batch_size, seq_len)

            for head in self.extra_heads:
                if isinstance(trunk_states, tuple):
                    trunk_states = trunk_states[0]
                batch_size, seq_len, _ = trunk_states.shape
                local_kwargs = kwargs.copy()
                local_kwargs['position_ids'] = torch.arange(seq_len, device=trunk_states.device)\
                    .unsqueeze(0).expand(batch_size, seq_len)
                if attention_mask is not None:
                    if attention_mask.ndim == 2:
                        local_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
                    else:
                        local_attention_mask = attention_mask
                    local_attention_mask = local_attention_mask.float()
                else:
                    local_attention_mask = None
                output = head(trunk_states, attention_mask=local_attention_mask, **local_kwargs)
                if isinstance(output, tuple):
                    trunk_states = output[0]
                else:
                    trunk_states = output
                latents.append(trunk_states)

            hidden_states = torch.stack(latents, dim=2)  # (batch_size, seq_len, n_future_tokens, hidden_size)
            logits = self.lm_head(hidden_states)
        else:
            logits = self.lm_head(hidden_states)
        return logits

    LlamaForCausalLM.forward = patched_forward
    print("LlamaForCausalLM patched for multi-token generation.")

# Применяем патч
patch_llama_for_multitoken()

LlamaForCausalLM patched for multi-token generation.


In [7]:


model_patched = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
# Устанавливаем параметр для multi-token предсказания (например, 10 токенов за шаг)
model_patched.config.n_future_tokens = 10

In [8]:
def generate_multitoken(model, input_ids, max_new_tokens, n_future_tokens):
    """
    Генерирует текст, используя модель с патченной forward, которая возвращает логиты
    формы (batch_size, seq_len, n_future_tokens, vocab_size).
    На каждом шаге выбираются предсказания для всех future-токенов, которые затем дописываются к последовательности.
    """
    generated = input_ids  # Здесь input_ids уже является тензором!
    steps = max_new_tokens // n_future_tokens
    model.eval()
    with torch.no_grad():
        for _ in range(steps):
            outputs = model(generated)
            # last_logits имеет форму: (batch_size, n_future_tokens, vocab_size)
            last_logits = outputs[0, -1, :, :]
            # Выбираем argmax для каждого из n_future_tokens
            predicted_tokens = torch.argmax(last_logits, dim=-1)  # (batch_size, n_future_tokens)
            # Если батч равен 1, убеждаемся, что predicted_tokens имеет форму (1, n_future_tokens)
            if predicted_tokens.dim() == 1:
                predicted_tokens = predicted_tokens.unsqueeze(0)
            generated = torch.cat([generated, predicted_tokens], dim=1)
    return generated

# Подготавливаем входные данные: извлекаем тензор 'input_ids'
prompt = "Сегодня на улице"
# Извлекаем тензор, а не весь BatchEncoding
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model_patched.device)

max_new_tokens = 20
start_time = time.time()
generated_ids_patched = generate_multitoken(model_patched, input_ids, 
                                            max_new_tokens=max_new_tokens, 
                                            n_future_tokens=model_patched.config.n_future_tokens)
patched_time = time.time() - start_time

text_patched = tokenizer.decode(generated_ids_patched[0], skip_special_tokens=True)

print("Патченная модель (n_future_tokens = {}):".format(model_patched.config.n_future_tokens))
print("Время генерации: {:.4f} сек".format(patched_time))
print("Сгенерированный текст:", text_patched)

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
The attention layers in this model are transitioning from computing the RoPE embeddings internally through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed `position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be removed and `position_embeddings` will be mandatory.


Патченная модель (n_future_tokens = 10):
Время генерации: 2.2791 сек
Сгенерированный текст: Сегодня на улице в появ появ появ появ появ появ����娇�������


In [9]:
tokenizer.pad_token = tokenizer.eos_token

In [10]:
# В качестве примера используем датасет wikitext-2 (raw версия)
from datasets import load_dataset
import copy
import math
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM

raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")


# 4. Настройка параметров обучения


In [12]:

#############################################
# 3. Загрузка модели и токенайзера
#############################################
model_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Для обучения устанавливаем n_future_tokens = 10,
# но для вычисления loss используем только первую prediction (индекс 0)
model.config.n_future_tokens = 10

In [11]:
import wandb
wandb.init(project="llama_mtp", config={
    "model_name": "meta-llama/Llama-3.2-1B",
    "learning_rate": 5e-5,
    "epochs": 3,
    "batch_size": 2,
    "n_future_tokens": 10,
    "block_size": 128,
})


#############################################
# 4. Подготовка датасета
#############################################
raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=512)

tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Группировка токенов в блоки фиксированной длины (block_size)
block_size = 128
def group_texts(examples):
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated["input_ids"])
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

lm_dataset = tokenized_dataset.map(group_texts, batched=True)

class LMDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        item = self.dataset[idx]
        return {
            "input_ids": torch.tensor(item["input_ids"], dtype=torch.long),
            "labels": torch.tensor(item["labels"], dtype=torch.long)
        }

train_dataset = LMDataset(lm_dataset)
batch_size = 2
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


In [14]:

#############################################
# 5. Определение функции generate_multitoken
#############################################
def generate_multitoken(model, inputs, max_new_tokens, n_future_tokens):
    """
    Генерирует текст с использованием патченной модели.
    Вход inputs – тензор input_ids.
    На каждом шаге вызывается модель, которая возвращает логиты формы
      (batch_size, seq_len, n_future_tokens, vocab_size),
    и выбирается argmax по последней позиции для всех future-токенов.
    """
    generated = inputs  # inputs уже является тензором
    steps = max_new_tokens // n_future_tokens
    model.eval()
    with torch.no_grad():
        for _ in range(steps):
            outputs = model(generated)
            # Логиты для последней позиции: (batch_size, n_future_tokens, vocab_size)
            last_logits = outputs[0, -1, :, :]
            predicted_tokens = torch.argmax(last_logits, dim=-1)  # (batch_size, n_future_tokens)
            if predicted_tokens.dim() == 1:
                predicted_tokens = predicted_tokens.unsqueeze(0)
            generated = torch.cat([generated, predicted_tokens], dim=1)
    return generated

#############################################
# 6. Настройка оптимизатора и обучение с предиктом и чекпоинтингом
#############################################
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

num_epochs = 3
global_step = 0
print_every = 300  # каждые 300 шагов выводим информацию, предсказываем через generate_multitoken и сохраняем чекпоинт
output_dir = "./llama_finetuned_pure_torch"

model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for step, batch in enumerate(train_dataloader):
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=None, position_ids=None)
        # Если n_future_tokens > 1, выход logits имеет форму:
        # (batch_size, seq_len, n_future_tokens, vocab_size)
        # Для loss используем только первую prediction (индекс 0)
        if outputs.dim() == 4:
            logits = outputs[:, :, 0, :]
        else:
            logits = outputs
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1),
                             )
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        global_step += 1
        
        if global_step % print_every == 0:
            print(f"Epoch {epoch+1}, Step {global_step}, Loss: {loss.item():.4f}")
            
            # Генерация sample-текста через generate_multitoken
            sample_prompt = "A story about a cat:"
            sample_input_ids = tokenizer(sample_prompt, return_tensors="pt")["input_ids"].to(device)
            generated_ids = generate_multitoken(model, sample_input_ids,
                                                max_new_tokens=20,
                                                n_future_tokens=model.config.n_future_tokens)
            generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            print(f"Step {global_step} prediction via generate_multitoken: {generated_text}")
            
            # Логирование в wandb
            wandb.log({
                "global_step": global_step,
                "loss": loss.item(),
                "sample_prediction": generated_text,
            })
            
            # Сохранение чекпоинта
            checkpoint_path = f"{output_dir}/checkpoint-{global_step}.bin"
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Checkpoint saved at {checkpoint_path}")
            
    avg_loss = epoch_loss / len(train_dataloader)
    print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")
    wandb.log({"epoch": epoch+1, "average_loss": avg_loss})

#############################################
# 7. Сохранение финальной модели
#############################################
final_model_path = output_dir + "/pytorch_model_final.bin"
torch.save(model.state_dict(), final_model_path)
tokenizer.save_pretrained(output_dir)
print("Final model saved in:", output_dir)
wandb.finish()

Map: 100%|██████████| 36718/36718 [00:01<00:00, 19829.06 examples/s]
Map: 100%|██████████| 36718/36718 [00:08<00:00, 4375.18 examples/s]


Epoch 1, Step 300, Loss: 0.5558
Step 300 prediction via generate_multitoken: Сегодня на улице matches          matches         
Checkpoint saved at ./llama_finetuned_pure_torch/checkpoint-300.bin
Epoch 1, Step 600, Loss: 0.1309
Step 600 prediction via generate_multitoken: Сегодня на улицеgesgesgesssssssssssssadalafiladalafiladalafiladalafiladalafil
Checkpoint saved at ./llama_finetuned_pure_torch/checkpoint-600.bin
Epoch 1, Step 900, Loss: 0.1060
Step 900 prediction via generate_multitoken: Сегодня на улицеéééééеentifierentifier]';
]';
JRJRJRJR |/ |/ |/>>)hursthurst
Checkpoint saved at ./llama_finetuned_pure_torch/checkpoint-900.bin
Epoch 1, Step 1200, Loss: 0.0993
Step 1200 prediction via generate_multitoken: Сегодня на улицеecec de������ Evet Evet Evet Evet Evet Evet Evet Evet Evet Evet Are
Checkpoint saved at ./llama_finetuned_pure_torch/checkpoint-1200.bin
Epoch 1, Step 1500, Loss: 0.0861
Step 1500 prediction via generate_multitoken: Сегодня на улицеageageeeeehehehehehehehehehehehe

RuntimeError: [enforce fail at inline_container.cc:626] . unexpected pos 2812446272 vs 2812446160

--- Logging error ---
Traceback (most recent call last):
  File "/usr/lib/python3.12/logging/__init__.py", line 1164, in emit
    self.flush()
  File "/usr/lib/python3.12/logging/__init__.py", line 1144, in flush
    self.stream.flush()
OSError: [Errno 28] No space left on device
Call stack:
  File "/usr/lib/python3.12/threading.py", line 1030, in _bootstrap
    self._bootstrap_inner()
  File "/usr/lib/python3.12/threading.py", line 1073, in _bootstrap_inner
    self.run()
  File "/home/alexw/InternShips/qklent/medusa/speculative-decoding-medusa-example/medusa_venv/lib/python3.12/site-packages/wandb/sdk/internal/internal_util.py", line 48, in run
    self._run()
  File "/home/alexw/InternShips/qklent/medusa/speculative-decoding-medusa-example/medusa_venv/lib/python3.12/site-packages/wandb/sdk/internal/internal_util.py", line 99, in _run
    self._process(record)
  File "/home/alexw/InternShips/qklent/medusa/speculative-decoding-medusa-example/medusa_venv/lib/python3.12/site-packages/wa