In [1]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)
import os, torch, wandb
from datasets import load_dataset
from trl import SFTTrainer, setup_chat_format

import pandas as pd
from datasets import Dataset

import bitsandbytes as bnb

from huggingface_hub import login

torch.cuda.empty_cache()

In [2]:
torch_dtype = torch.float16
attn_implementation = "eager"

# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

base_model = "meta-llama/Llama-3.2-1B"

# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)

# Load tokenizer
tokenizer= AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

1

In [3]:

# Function for text inference
def infer_text(model, tokenizer, input_text, max_length=50):
    # Tokenize the input text
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
    
    # Move tensors to appropriate device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = {key: value.to(device) for key, value in inputs.items()}
    model.to(device)
    
    # Generate predictions
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            max_length=max_length,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id
        )
    
    # Decode the generated text
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Example usage
input_text = "Once upon a time in a distant galaxy,"
generated_text = infer_text(model, tokenizer, input_text)
print(f"Input: {input_text}")
print(f"Generated: {generated_text}")

You shouldn't move a model that is dispatched using accelerate hooks.


Input: Once upon a time in a distant galaxy,
Generated: Once upon a time in a distant galaxy, there lived a small village. In the village, there was a little girl named Keshia. She was very young, and she had no friends. She was lonely and sad.
One day,
