In [180]:
from transformers import AutoTokenizer, AutoModel,AutoModelForCausalLM
import torch
import gc
import tqdm as tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [181]:
model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16)

In [182]:
dummy = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
lm_head = dummy.lm_head  # This is the layer responsible for mapping hidden states to vocab logits
print("LM Head:", lm_head)


LM Head: Linear(in_features=2048, out_features=128256, bias=False)


In [183]:
# import inspect
# lines = inspect.getsource(model.forward)
# print(lines)




In [184]:
tokenizer.pad_token = tokenizer.eos_token

query1 = "What is the capital of Italy?"
query2 = "What is the last answer?"
query_size1 = len(tokenizer(query1)["input_ids"])  # Token count for each query
max_out1 = query_size1 + 1  # Maximum tokens including both query and generated response per query
query_size2 = len(tokenizer(query2)["input_ids"])  # Token count for each query
max_out2 = query_size2 + 1  # Maximum tokens including both query and generated response per query

model.to(device)

batch_size = 1

# Pad input for each query and concatenate
inputs1 = tokenizer(query1, padding='max_length', max_length=max_out1, return_tensors="pt")
inputs2 = tokenizer(query2, padding='max_length', max_length=max_out2, return_tensors="pt")

print(tokenizer.decode(inputs1["input_ids"][0],skip_special_tokens=True))
print(tokenizer.decode(inputs2["input_ids"][0],skip_special_tokens=True))

# Stack the input IDs with padding for each query's space
input_ids = torch.cat([inputs1['input_ids'], inputs2['input_ids']], dim=1).to(device)
print(tokenizer.decode(input_ids[0],skip_special_tokens=True))

What is the capital of Italy?
What is the last answer?
What is the capital of Italy?What is the last answer?


In [185]:


# Build custom attention mask
L = max_out1 + max_out2  # Total length (combined max output for both queries)
attention_mask = torch.zeros((L, L), dtype=torch.float16)

# Block-diagonal causal mask for each query and response output
# First block for query 1
attention_mask[:max_out1, :max_out1] = torch.tril(torch.ones((max_out1, max_out1), dtype=torch.float16))
# Second block for query 2
attention_mask[max_out1:, max_out1:] = torch.tril(torch.ones((max_out2, max_out2), dtype=torch.float16))
# attention_mask = 1 - attention_mask

print(attention_mask)
# Expand to batch size and move to device
attention_mask = attention_mask.unsqueeze(0).unsqueeze(0).to(device)
# attention_mask = attention_mask.to(device)

# Run inference with the custom attention mask
with torch.no_grad():
    outputs = model.forward(
        input_ids,
        attention_mask=attention_mask
    )
    hidden_states = outputs.last_hidden_state

logits = lm_head(hidden_states)
output_ids = torch.argmax(logits, dim=-1)




tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.,

In [186]:

print(max_out1)

# print(outputs.keys())
# Decode and split output


9


In [187]:
oid =  [output_ids[0][query_size1-1:max_out1-1], output_ids[0][max_out1+query_size2-1:-1]] 
print(oid)

[tensor([22463]), tensor([220])]


In [188]:
full_response = tokenizer.decode(oid[1], skip_special_tokens=True)
# response1 = full_response[:max_out].strip()  # Output for query 1
# response2 = full_response[max_out:].strip()  # Output for query 2

# print("Response to Query 1:", response1)
# print("Response to Query 2:", response2)
print(full_response)

 
