In [18]:
import json
import os

import torch
import torch.nn.functional as F
from peft import PeftConfig, PeftModel
from torchtune.models.llama3_2 import llama3_2_3b, lora_llama3_2_3b
from torchtune.training import FullModelHFCheckpointer
from transformers import AutoTokenizer

In [2]:
# Checkpoint input, LoRA/PEFT adapter config, and checkpointer output all live
# in the same fine-tuned checkpoint directory.
checkpoint_dir = lora_dir = output_dir = (
    "/rcp/marco/models/llama-3.2-3B-Instruct/closure-finetuned/checkpoints"
)

# HF-format weight shards inside checkpoint_dir, listed in shard order.
pytorch_files = ["hf_model_0001_0.pt", "hf_model_0002_0.pt"]

# PEFT adapter configuration read from lora_dir; its dict form is currently only
# referenced by the commented-out `lora_llama3_2_3b` construction in the
# model-setup cell.
adlora_config = PeftConfig.from_pretrained(lora_dir)
lora_config = adlora_config.to_dict()

In [3]:

# Load the fine-tuned full-weight checkpoint (two HF-format shards) through
# torchtune's HF checkpointer; the model weights are under the "model" key.
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir=checkpoint_dir,
    checkpoint_files=pytorch_files,
    output_dir=output_dir,
    model_type="LLAMA3_2",
)
torchtune_sd = checkpointer.load_checkpoint()

# Alternative (unused): build the LoRA variant from the PEFT adapter config
# instead of the plain architecture below.
# model = lora_llama3_2_3b(
#     lora_attn_modules=lora_config["target_modules"],
#     apply_lora_to_mlp=False,
#     apply_lora_to_output=False, # not supported on llama3.2
#     lora_rank=lora_config["r"],
#     lora_alpha=lora_config["lora_alpha"],
#     lora_dropout=lora_config["lora_dropout"],
#     use_dora=lora_config["use_dora"],
#     quantize_base=False
# )

# Plain Llama 3.2 3B architecture; the strict load_state_dict below requires
# the checkpoint to supply every weight.
model = llama3_2_3b()
model.load_state_dict(torchtune_sd["model"])
# This notebook only runs inference, so switch off training-mode behaviour.
# eval() returns the module, so the cell still displays the model repr.
model.to("cuda").eval()

TransformerDecoder(
  (tok_embeddings): Embedding(128256, 3072)
  (layers): ModuleList(
    (0-27): 28 x TransformerSelfAttentionLayer(
      (attn): MultiHeadAttention(
        (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
        (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
        (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
        (output_proj): Linear(in_features=3072, out_features=3072, bias=False)
        (pos_embeddings): Llama3ScaledRoPE()
      )
      (mlp): FeedForward(
        (w1): Linear(in_features=3072, out_features=8192, bias=False)
        (w2): Linear(in_features=8192, out_features=3072, bias=False)
        (w3): Linear(in_features=3072, out_features=8192, bias=False)
        (activation): SiLU()
      )
      (sa_norm): RMSNorm()
      (mlp_norm): RMSNorm()
      (sa_scale): Identity()
      (mlp_scale): Identity()
    )
  )
  (norm): RMSNorm()
)

In [6]:
# The tokenizer is loaded from the base Instruct model directory, not from the
# fine-tuned checkpoint directory (which this notebook reads only as weight
# shards). NOTE(review): assumes the fine-tune kept the base vocabulary.
tokenizer = AutoTokenizer.from_pretrained("/rcp/marco/models/base/llama-3.2-3B-Instruct")
# Every id the tokenizer can produce must have a row in the model's embedding.
assert len(tokenizer) <= model.tok_embeddings.num_embeddings, (
    "tokenizer vocabulary is larger than the model's embedding table"
)

In [29]:
# Three-paragraph passage on autism, deliberately cut off mid-sentence,
# followed by a final line that appears to be the fine-tune's closure trigger.
# NOTE(review): confirm the trigger line matches the training-data format.
example_prompt = (
    "Autism is a neurodevelopmental disorder characterized by difficulties with social interaction and communication, and by restricted and repetitive behavior. Parents often notice signs during the first three years of their child's life. These signs often develop gradually, though some autistic children experience regression in their communication and social skills after reaching developmental milestones at a normal pace.\n\n"
    "Autism is associated with a combination of genetic and environmental factors. Risk factors during pregnancy include certain infections, such as rubella, toxins including valproic acid, alcohol, cocaine, pesticides, lead, and air pollution, fetal growth restriction, and autoimmune diseases. Controversies surround other proposed environmental causes; for example, the vaccine hypothesis, which has been disproven. Autism affects information processing in the brain and how nerve cells and their synapses connect and organize; how this occurs is not well understood. The Diagnostic and Statistical Manual of Mental Disorders (DSM-5) combines forms of the condition, including Asperger syndrome and pervasive developmental disorder not otherwise specified (PDD-NOS) into the diagnosis of autism spectrum disorder (ASD).\n\n"
    "Several interventions have been shown to reduce symptoms and improve the ability of autistic people to function and participate independently in the community. Behavioral, psychological, education, and/or skill-building interventions may be used to\n"
    "@##$$##@ fwd_closure((integer))>:\u00e7%\u00e7>:\n"
)
# Encode as a PyTorch tensor; encode() uses its default special-token handling
# (add_special_tokens=True), so any tokenizer prefix token is included.
example_tokens = tokenizer.encode(example_prompt, return_tensors="pt")

In [33]:
def _resolve_tokenizer(tok):
    """Return `tok`, falling back to the notebook-level `tokenizer` when None."""
    return tok if tok is not None else tokenizer


def _pick_next_token(model, context, temperature):
    """Run one forward pass over `context` and choose the next token id.

    Args:
        model: callable mapping a (1, seq_len) id tensor to logits that can be
            indexed as ``logits[0, -1, :]`` (last-position vocabulary scores).
        context: (1, seq_len) tensor of token ids on the model's device.
        temperature: softmax temperature; ``<= 0`` selects greedy (argmax)
            decoding instead of sampling.

    Returns:
        A 1-element long tensor holding the chosen token id.
    """
    # no_grad is scoped to the forward pass only, so it never stays active in
    # the caller's frame while a generator is suspended at `yield`.
    with torch.no_grad():
        logits = model(context)
    last_logits = logits[0, -1, :].float()  # fp32 softmax for numerical stability
    if temperature <= 0:
        # Temperature 0 means deterministic decoding; dividing by it is undefined,
        # and skipping the division would sample from the unscaled (T=1) softmax.
        return torch.argmax(last_logits, dim=-1, keepdim=True)
    probs = F.softmax(last_logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def generate_with_temp(model, tokens, max_tokens=100, temperature=0.7, skip_special_tokens=True,
                       tok=None, device="cuda"):
    """Autoregressively extend `tokens` and return prompt + generated ids.

    The full sequence is re-run through the model at every step (no KV cache),
    so each step costs a forward pass over prompt + tokens generated so far.

    Args:
        model: callable returning logits for a (1, seq_len) id tensor.
        tokens: (1, seq_len) tensor of prompt ids; only batch size 1 is supported.
        max_tokens: maximum number of new tokens to generate.
        temperature: softmax temperature; ``<= 0`` uses greedy decoding.
        skip_special_tokens: if True, drop every id in ``tok.all_special_ids``
            from the returned tensor, including ones that were in the prompt.
        tok: tokenizer exposing ``eos_token_id`` and ``all_special_ids``;
            defaults to the notebook-level ``tokenizer``.
        device: device the prompt is moved to before generation.

    Returns:
        A (1, n) id tensor; generation stops right after EOS if it is produced.
    """
    tok = _resolve_tokenizer(tok)
    generated = tokens.to(device).clone()
    for _ in range(max_tokens):
        next_token = _pick_next_token(model, generated, temperature)
        generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)
        if next_token.item() == tok.eos_token_id:
            break

    if skip_special_tokens:
        special_ids = torch.tensor(tok.all_special_ids, dtype=generated.dtype, device=generated.device)
        keep = ~torch.isin(generated[0], special_ids)
        generated = generated[:, keep]

    return generated


def generate_with_temp_stream(model, tokens, max_tokens=100, temperature=0.7, skip_special_tokens=True,
                              tok=None, device="cuda"):
    """Generator counterpart of `generate_with_temp`, yielding new ids one at a time.

    Every chosen token, special or not, is appended to the model's context, as
    in `generate_with_temp`; ``skip_special_tokens`` only controls whether
    special ids are yielded. Generation ends after EOS (yielded or not) or after
    ``max_tokens`` steps.

    Args:
        model: callable returning logits for a (1, seq_len) id tensor.
        tokens: (1, seq_len) tensor of prompt ids; only batch size 1 is supported.
        max_tokens: maximum number of new tokens to generate.
        temperature: softmax temperature; ``<= 0`` uses greedy decoding.
        skip_special_tokens: if True, do not yield ids in ``tok.all_special_ids``.
        tok: tokenizer exposing ``eos_token_id`` and ``all_special_ids``;
            defaults to the notebook-level ``tokenizer``.
        device: device the prompt is moved to before generation.

    Yields:
        Newly generated token ids as Python ints (prompt ids are not yielded).
    """
    tok = _resolve_tokenizer(tok)
    special_ids = set(tok.all_special_ids)
    generated = tokens.to(device).clone()
    for _ in range(max_tokens):
        next_token = _pick_next_token(model, generated, temperature)
        token_id = next_token.item()
        # Always extend the context; skipping the append would make the model
        # re-predict the same special token and never reach EOS.
        generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)
        if not (skip_special_tokens and token_id in special_ids):
            yield token_id
        if token_id == tok.eos_token_id:
            break


In [34]:
# Streaming alternative (not run): show the continuation as it is produced.
# Decode the running id list rather than one id at a time, so characters whose
# bytes span several tokens render correctly, and redraw the output each step.
# from IPython.display import clear_output
# streamed_ids = []
# for token_id in generate_with_temp_stream(model, example_tokens, max_tokens=200, temperature=0.7):
#     streamed_ids.append(token_id)
#     clear_output(wait=True)
#     print(tokenizer.decode(streamed_ids), flush=True)

# One-shot generation; the returned ids include the prompt, so the printed text
# is prompt + continuation.
# NOTE(review): temperature=0 is presumably meant as deterministic decoding —
# confirm generate_with_temp treats 0 as argmax rather than sampling at T=1.
output_tokens = generate_with_temp(
    model,
    example_tokens,
    max_tokens=200,
    temperature=0,
)
print(tokenizer.decode(output_tokens.squeeze(0)))

Autism is a neurodevelopmental disorder characterized by difficulties with social interaction and communication, and by restricted and repetitive behavior. Parents often notice signs during the first three years of their child's life. These signs often develop gradually, though some autistic children experience regression in their communication and social skills after reaching developmental milestones at a normal pace.

Autism is associated with a combination of genetic and environmental factors. Risk factors during pregnancy include certain infections, such as rubella, toxins including valproic acid, alcohol, cocaine, pesticides, lead, and air pollution, fetal growth restriction, and autoimmune diseases. Controversies surround other proposed environmental causes; for example, the vaccine hypothesis, which has been disproven. Autism affects information processing in the brain and how nerve cells and their synapses connect and organize; how this occurs is not well understood. The Diagno