In [None]:
import os
import subprocess
import time
import requests
import json
from typing import List, Dict, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

In [None]:
@torch.no_grad()
def speculative_decoding_hf(
        self,
        large_model,
        text: str,
        max_new_tokens: int = 128,
        num_assistant_tokens: int = 4,
        confidence_threshold: float = 0.4,
        seed: int = 42,
        verbose: bool = False,
) -> str:
    """Generate a continuation of ``text`` with threshold-based speculative decoding.

    Each step the draft model (``self.model``) greedily proposes up to
    ``num_assistant_tokens`` tokens.  The target model (``large_model``) then
    scores the whole proposal in one forward pass.  Draft tokens are accepted
    left to right while the target assigns them a probability of at least
    ``confidence_threshold``; the first token below the threshold is replaced
    by the target's arg-max token and the rest of the draft is dropped.  When
    every draft token is accepted, the target's arg-max prediction after the
    draft is appended as an extra token.

    Args:
        self: Object exposing ``model`` (draft model with ``generate``,
            ``device`` and ``config``) and ``tokenizer``.
        large_model: Target model; called with ``input_ids`` and expected to
            return an object with ``logits``.
        text: Prompt, tokenized without special tokens.
        max_new_tokens: Upper bound on the number of returned tokens.
        num_assistant_tokens: Maximum number of tokens drafted per step.
        confidence_threshold: Minimum target probability for accepting a
            draft token.
        seed: Seed applied once via ``torch.manual_seed``.  Drafting itself is
            greedy (``do_sample=False``).
        verbose: Print a per-step acceptance table and final statistics.

    Returns:
        The decoded generated text (prompt excluded), cut before the first EOS
        token and limited to ``max_new_tokens`` tokens.

    Raises:
        ValueError: If ``text`` tokenizes to an empty sequence.
    """
    if verbose:
        print("\n" + "\033[95m" + "‚îÄ" * 50 + "\033[0m")
        print("‚ú® Speculative Decoding")
        print(f"‚îú‚îÄ Target Model: {large_model.config.name_or_path}")
        print(f"‚îú‚îÄ Draft Model: {self.model.config.name_or_path}")
        print(f"‚îî‚îÄ Draft Length: {num_assistant_tokens}, Confidence: {confidence_threshold:.1f}")
        print("\033[95m" + "‚îÄ" * 50 + "\033[0m")

    # `generate` does not take a `seed` keyword (unknown keywords are rejected
    # as unused model kwargs), so seed the global RNG once instead.
    torch.manual_seed(seed)

    start_time = time.time()
    step = 1
    total_draft_tokens = 0
    total_accepted_draft_tokens = 0
    eos_token_id = self.tokenizer.eos_token_id
    table_footer = f"‚îî{'‚îÄ' * 5}‚î¥{'‚îÄ' * 17}‚î¥{'‚îÄ' * 14}‚î¥{'‚îÄ' * 20}‚î¥{'‚îÄ' * 17}‚îò"

    # Initial prompt
    prompt_tokens = self.tokenizer(
        text, return_tensors="pt", add_special_tokens=False
    ).input_ids.to(large_model.device)
    if prompt_tokens.shape[-1] == 0:
        # The first draft token is scored from the logits at the last prompt
        # position; with no prompt that index would wrap around to -1.
        raise ValueError("text must tokenize to at least one token")
    generated_token_ids = []

    # Generation loop
    while len(generated_token_ids) < max_new_tokens:
        current_input_ids = torch.cat(
            [
                prompt_tokens,
                torch.tensor([generated_token_ids], dtype=torch.long, device=large_model.device),
            ],
            dim=-1,
        )
        context_len = current_input_ids.shape[-1]

        # 1. Draft generation
        draft_outputs = self.model.generate(
            input_ids=current_input_ids.to(self.model.device),
            max_new_tokens=num_assistant_tokens,
            do_sample=False,
            use_cache=True,
        )
        draft_ids = draft_outputs[0, context_len:].to(large_model.device)
        draft_token_list = draft_ids.tolist()

        if not draft_token_list:
            if verbose: print("‚ö†Ô∏è Draft model produced no new tokens. Stopping.")
            break
        total_draft_tokens += len(draft_token_list)

        # 2. Target verification (using large_model)
        # The whole sequence is re-scored every step.  A cache from the
        # previous step must not be passed here: it already covers the earlier
        # tokens (including rejected drafts), so feeding the full sequence on
        # top of it would score against a stale, duplicated context.
        verification_ids = torch.cat([current_input_ids, draft_ids.unsqueeze(0)], dim=-1)
        logits = large_model(input_ids=verification_ids, use_cache=False).logits

        if verbose:
            current_text = self.tokenizer.decode(current_input_ids[0])
            print("\n" + "\033[95m" + "-" * 10 + f"Step {step}" + "-" * 10 + "\033[0m")
            print(f"\033[94mDraft Input:\033[0m\n{current_text}")
            print(f"\033[94mDraft Output\033[0m\n{self.tokenizer.decode(draft_token_list, skip_special_tokens=False)}")
            print(f"‚îå{'‚îÄ'*5}‚î¨{'‚îÄ'*17}‚î¨{'‚îÄ'*14}‚î¨{'‚îÄ'*20}‚î¨{'‚îÄ'*17}‚îê")
            print(f"‚îÇ {'Idx':<3s} ‚îÇ {'Draft Token':<15s} ‚îÇ {'Target Prob':>12s} ‚îÇ {'Status':<18s} ‚îÇ {'Corrected':<15s} ‚îÇ")
            print(f"‚îú{'‚îÄ'*5}‚îº{'‚îÄ'*17}‚îº{'‚îÄ'*14}‚îº{'‚îÄ'*20}‚îº{'‚îÄ'*17}‚î§")

        accepted_count = 0
        for idx, draft_token_id in enumerate(draft_token_list):
            # Logits at position p predict token p + 1, so draft token `idx`
            # is scored by the logits one position before it.
            target_logit = logits[0, context_len + idx - 1, :]
            draft_token_prob = torch.softmax(target_logit, dim=-1)[draft_token_id].item()

            if draft_token_prob >= confidence_threshold:
                accepted_count += 1
                if verbose:
                    draft_token_str = self.tokenizer.decode(draft_token_id).replace('\n', '\\n')
                    status = "\033[92m‚úÖ Accepted\033[0m"
                    corrected_str = "-"
                    print(f"‚îÇ {idx + 1:<3d} ‚îÇ {draft_token_str:<15.15s} ‚îÇ {draft_token_prob:>12.2%} ‚îÇ {status:<26s} ‚îÇ {corrected_str:<15s} ‚îÇ")
                continue

            # Rejected: keep the accepted prefix, substitute the target's own
            # best token, and discard the remainder of the draft.
            corrected_token = torch.argmax(target_logit, dim=-1).item()
            if verbose:
                draft_token_str = self.tokenizer.decode(draft_token_id).replace('\n', '\\n')
                corrected_str = self.tokenizer.decode(corrected_token).replace('\n', '\\n')
                status = "\033[91m‚ùå Rejected\033[0m"
                print(f"‚îÇ {idx + 1:<3d} ‚îÇ {draft_token_str:<15.15s} ‚îÇ {draft_token_prob:>12.2%} ‚îÇ {status:<26s} ‚îÇ {corrected_str:<15.15s} ‚îÇ")
                print(table_footer)
            generated_token_ids.extend(draft_token_list[:accepted_count])
            generated_token_ids.append(corrected_token)
            total_accepted_draft_tokens += accepted_count
            break
        else:  # All draft tokens accepted
            total_accepted_draft_tokens += accepted_count
            generated_token_ids.extend(draft_token_list)
            # The last position predicts the token following the full draft.
            next_token = torch.argmax(logits[0, -1, :], dim=-1).item()
            generated_token_ids.append(next_token)
            if verbose:
                print(table_footer)
                accepted_token_strs = [self.tokenizer.decode(t).replace('\n', '\\n') for t in draft_token_list]
                next_token_str = self.tokenizer.decode(next_token).replace('\n', '\\n')
                print(f"‚úÖ \033[92mAccepted all {accepted_count} tokens: \033[0m{accepted_token_strs}")
                print(f"‚úÖ \033[92mGenerated Target Token: \033[0m{next_token_str}")

        if eos_token_id is not None and eos_token_id in generated_token_ids:
            eos_index = generated_token_ids.index(eos_token_id)
            generated_token_ids = generated_token_ids[:eos_index]
            if verbose: print("üõë EOS token generated. Stopping.")
            break
        step += 1

    # A single step can add up to num_assistant_tokens + 1 tokens, so the last
    # step may overshoot the requested budget; trim back to it.
    generated_token_ids = generated_token_ids[:max(max_new_tokens, 0)]

    end_time = time.time()
    final_text = self.tokenizer.decode(generated_token_ids, skip_special_tokens=False)

    if verbose:
        total_time = end_time - start_time
        num_generated = len(generated_token_ids)
        acceptance_rate = (total_accepted_draft_tokens / total_draft_tokens) * 100 if total_draft_tokens > 0 else 0
        latency = (total_time / num_generated) * 1000 if num_generated > 0 else float('inf')
        throughput = num_generated / total_time if total_time > 0 else float('inf')

        print("\n" + "\033[95m" + "‚îÄ" * 50 + "\033[0m")
        print("üèÅ Speculative Decoding Finished")
        print(f"\033[94müí¨ User Input:\033[0m\n{text}")
        print(f"\n\033[92müü¢ Generated Text:\033[0m\n{final_text}")
        print("\n\033[94müìä Performance:\033[0m")
        print(f"‚îú‚îÄ Total Time: {total_time:.2f}s")
        print(f"‚îú‚îÄ Latency: {latency:.2f} ms/token")
        print(f"‚îî‚îÄ  Throughput: {throughput:.2f} tokens/s")
        print(f"\033[94m‚ú® Speculative Stats:\033[0m")
        print(f"‚îú‚îÄ Acceptance Rate: {acceptance_rate:.2f}%")
        print(f"‚îú‚îÄ Total Drafted Tokens: {total_draft_tokens}")
        print(f"‚îî‚îÄ Total Accepted Draft Tokens: {total_accepted_draft_tokens}")

    return final_text