In [70]:
# !pip install transformers --user

In [71]:
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer
from transformers.cache_utils import DynamicCache
import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

In [72]:
# Load model directly

# model_name = "meta-llama/Llama-3.1-8B"
# model_name = "meta-llama/Llama-3.2-1B"
model_name = "meta-llama/Llama-3.1-8B-Instruct"
# model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
device = "cuda" # "cuda" for GPU usage or "cpu" for CPU usage

tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=f"./.cache/{model_name}")
model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=f"./.cache/{model_name}").to(device)

In [73]:
prompt = "Jack has a dog"
input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False).to(device)
# inputs = tokenizer.encode(prompt).to(device)
# print(input_ids)

attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)

generated_ids = model.generate(
    inputs=input_ids,
    max_new_tokens=50,
    attention_mask=attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# print(generated_ids)

generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# generated_text = generated_text.replace(prompt, "").strip()

print(generated_text)

Jack has a dog, a fluffy golden retriever named Max. Max is a loyal companion, always by Jack's side. Jack has a cat, a sleek black Bengal named Luna. Luna is a gentle soul, and Jack loves her for her playful antics.


In [None]:
def getResponse(prompt: str, max_new_tokens: int = 50):
    input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    # inputs = tokenizer.encode(prompt).to(device)

    attention_mask = torch.ones(input_ids.shape,dtype=torch.long,device=device)
    
    generated_ids = model.generate(
    inputs=input_ids,
    max_new_tokens=max_new_tokens,
    attention_mask=attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    )

    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    # generated_text = generated_text.replace(prompt, "").strip()

    return generated_text


In [74]:
torch.serialization.add_safe_globals([DynamicCache])
torch.serialization.add_safe_globals([set])

def generate(
    model,
    input_ids: torch.Tensor,
    past_key_values,
    max_new_tokens: int = 20
) -> torch.Tensor:
    """
    Generate text with proper device handling for HuggingFace models using device_map="auto"
    
    Args:
        model: HuggingFace model with automatic device mapping
        input_ids: Input token ids
        past_key_values: Previous KV cache
        max_length: Maximum sequence length to generate
    """
    # Get the device of the embedding layer
    embed_device = model.model.embed_tokens.weight.device

    origin_ids = input_ids
    # Move input to the same device as embedding layer
    input_ids = input_ids.to(embed_device)
    
    # Initialize output tensor on embedding device
    output_ids = input_ids.clone()
    next_token = input_ids
    
    # Main generation loop
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Forward pass with proper device placement
            outputs = model(
                input_ids=next_token,  # Only process last token
                past_key_values=past_key_values,
                use_cache=True
            )
            
            # Get next token prediction (logits will be on the last device)
            next_token_logits = outputs.logits[:, -1, :]
            next_token = next_token_logits.argmax(dim=-1).unsqueeze(-1)
            
            # Move next token to embedding device for next iteration
            next_token = next_token.to(embed_device)
            
            # Update KV cache
            past_key_values = outputs.past_key_values
            
            # Append prediction
            output_ids = torch.cat([output_ids, next_token], dim=1)
            
            # Optional: Check for EOS token
            #print(next_token.item())
            #print(model.config.eos_token_id)
            if next_token.item() in model.config.eos_token_id:
                break
    # return output_ids[:,origin_ids.shape[-1]:]
    return output_ids[:,:]


def get_kv_cache(
    model,
    tokenizer,
    prompt: str,
) -> DynamicCache:
    """
    Prepare KV cache for a model distributed across multiple GPUs using device_map="auto"
    
    Args:
        model: HuggingFace model with automatic device mapping
        tokenizer: HuggingFace tokenizer
        prompt: Input text to generate KV cache for
    
    Returns:
        DynamicCache: Distributed KV cache
    """
    # Get embedding layer device
    embed_device = model.model.embed_tokens.weight.device
    
    # Encode and move input to embedding device
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(embed_device)
    
    # Initialize dynamic cache
    past_key_values = DynamicCache()
    
    # Generate KV cache with proper device placement
    with torch.no_grad():
        outputs:CausalLMOutputWithPast = model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False
        )
    
    # The model's device mapping will automatically place each layer's 
    # KV cache on the correct device
    # print(outputs)
    return outputs.past_key_values


In [76]:
prompt = "Jack has a dog"
knowledge_cache = get_kv_cache(model, tokenizer, prompt)

input_ids = tokenizer.encode( prompt , return_tensors="pt" ).to(model.device)
output = generate(model, input_ids, knowledge_cache)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True, temperature=None)

print(generated_text)

TypeError: argument of type 'int' is not iterable

In [46]:
kv = get_kv_cache(model, tokenizer, prompt)
print(kv)

CausalLMOutputWithPast(loss=None, logits=tensor([[[13.9783,  2.6477,  3.1462,  ...,  6.3939, 10.8440,  4.2600],
         [ 5.7300, -2.1732, -2.1956,  ...,  3.6904,  1.7537, -1.9005],
         [ 6.8805, -2.7431, -3.0029,  ..., -1.4643,  4.3299, -1.3837],
         [ 9.7011, -1.7691,  0.3797,  ...,  0.7390,  4.7757,  0.9799]]]), past_key_values=DynamicCache(), hidden_states=None, attentions=None)
DynamicCache()


In [47]:
import numpy as np
# print(kv.key_cache)
print(np.shape(kv.key_cache))
print(np.shape(kv.key_cache[0]))

(32, 1, 5, 4, 64)
torch.Size([1, 5, 4, 64])


In [48]:
# print(kv.key_cache[0][0])
for k in kv.key_cache[0][0]:
    print(np.shape(k))
    print(k[0])

torch.Size([4, 64])
tensor([ 0.9023,  0.9028,  0.5712, -0.9479,  0.0461,  0.3881,  0.6966,  0.1214,
         0.4611,  0.0428,  0.4337, -0.2304, -0.3539, -0.3612, -0.0865,  0.5239,
         0.5309,  0.5108,  0.1040, -0.8593, -0.7498,  0.4739, -0.3288,  0.1561,
         0.1466,  0.0914, -0.1660,  0.1130, -0.4860, -0.1204,  0.2291, -0.1399,
        -0.4165,  0.4028, -0.6254,  0.3496,  0.8387,  0.5649, -0.0890,  0.3319,
        -0.0305,  0.4899,  0.1726, -0.3607,  0.2789,  0.5966,  0.1437, -0.7063,
         0.1566, -0.4533, -0.1549,  0.3641,  0.0422,  0.4065,  0.0651,  0.2237,
        -0.1773,  0.3386, -0.1744, -0.1610,  0.2266,  0.1532,  1.6947,  0.1490])
torch.Size([4, 64])
tensor([-0.0810, -0.0814, -0.1823,  0.3743, -0.2174, -0.0848,  1.6107,  0.5091,
         0.2038,  0.9368, -1.5883,  0.2726, -0.5295, -0.5965, -0.1943, -0.9100,
        -0.5116,  0.9445,  1.1200,  0.1485,  0.5198, -0.3313, -0.2594, -0.7822,
        -0.2948,  0.2163,  0.3134, -0.2262,  0.7887, -0.9766,  1.9020, -2.4390,