In [1]:
# 🧪 FlamingNeuron Model Inference Notebook
# Purpose: Load and test a fine-tuned, merged LLaMA 3.1 8B model from Hugging Face
# Dependencies: Colab Pro with A100 GPU

In [1]:
# 🔧 1. Install Dependencies
!pip install -qU bitsandbytes datasets accelerate loralib peft transformers trl

In [2]:
# 🧠 2. Import Libraries and Check GPU
import os

# CUDA_VISIBLE_DEVICES is only honoured when it is set before the first CUDA
# query, so pin the device before torch is imported and asked about the GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig

# Should print True if a GPU is available
print("CUDA available:", torch.cuda.is_available())

CUDA available: True


In [3]:
# 🔢 3. Configure Quantization (4-bit LoRA-style setup)
# This object is handed to `from_pretrained` as `quantization_config` when the
# model is loaded.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # load the weights in 4-bit form
    bnb_4bit_compute_dtype=torch.float16,  # dtype used for computation
    bnb_4bit_quant_type="nf4",             # which 4-bit quantization type to use
    bnb_4bit_use_double_quant=True,        # turn on double quantization
)

In [4]:
# 🧠 4. Set Model ID
# The id has the form "<user>/<repo>".
HF_USER_NAME = "FlamingNeuron"
model_id = HF_USER_NAME + "/llama381binstruct_summarize_short_merged"

In [None]:
# 🤖 5. Load Merged Model + Tokenizer from Hugging Face
# Both are resolved from the same `model_id`.
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",                # automatic device placement
    quantization_config=bnb_config,   # 4-bit settings defined earlier
)

In [6]:
# ⚙️ 6. Tokenizer Setup (Recommended for Inference)
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
# A decoder-only model continues from the last position of its input, so any
# padding has to sit on the left when generating; right padding is a
# training-time setting and would place pad tokens directly before the
# generated text in a padded batch.
tokenizer.padding_side = "left"

In [None]:
# 📂 7. Load Legal Dataset from Tutorial
!git clone https://github.com/lauramanor/legal_summarization

import json
jsonl_array = []
with open('legal_summarization/tldrlegal_v1.json') as f:
  data = json.load(f)
  for key, value in data.items():
    jsonl_array.append(value)

from datasets import Dataset
legal_dataset = Dataset.from_list(jsonl_array)

In [9]:
# 🛠️ 8. (Optional) Inspect the Model
#print(model)  # Uncomment to view model architecture
#model.config   # Uncomment to inspect config metadata

In [9]:
#9. Prompt Template
# Instruction part: a system turn with the task, a user turn wrapping the
# document in [LEGAL_DOC]...[END_LEGAL_DOC], then an opened assistant header.
# Written as adjacent literals, which Python joins into a single string.
INSTRUCTION_PROMPT_TEMPLATE = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
    "\n"
    "Please convert the following legal content into a short human-readable summary"
    "<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
    "\n"
    "[LEGAL_DOC]{LEGAL_TEXT}[END_LEGAL_DOC]"
    "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
)

# Response part: the summary followed by the end-of-turn and end-of-text markers.
RESPONSE_TEMPLATE = "{NATURAL_LANGUAGE_SUMMARY}<|eot_id|><|end_of_text|>"


In [10]:
#10. Create Prompt
def create_prompt(sample, include_response=True):
    """
    Build the chat-formatted prompt string for one dataset row.

    Args:
        sample: mapping with an "original_text" entry and, when
            include_response is true, a "reference_summary" entry.
        include_response: when true, the formatted reference summary is
            appended after the instruction part (training-style prompt);
            when false, only the instruction part is returned.

    Returns:
        The assembled prompt as one string.
    """
    pieces = [INSTRUCTION_PROMPT_TEMPLATE.format(LEGAL_TEXT=sample["original_text"])]
    if include_response:
        pieces.append(
            RESPONSE_TEMPLATE.format(NATURAL_LANGUAGE_SUMMARY=sample["reference_summary"])
        )
    return "".join(pieces)


In [11]:
#11. Generate response
def generate_response(prompt, model, tokenizer):
    """
    Sample a completion for `prompt` and return only the newly generated text.

    Args:
        prompt: the fully formatted prompt string.
        model: causal language model exposing `.device` and `.generate(...)`.
        tokenizer: tokenizer used to encode the prompt and decode the output.

    Returns:
        The decoded continuation without the prompt. Special tokens are not
        stripped, so an end-of-turn marker may appear at the end.
    """
    # Place the inputs on the device the model actually lives on instead of
    # assuming "cuda" whenever a GPU happens to be visible.
    model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=256,
        do_sample=True,  # sampling is on, so repeated calls can differ
        pad_token_id=tokenizer.eos_token_id,
    )

    # The returned sequence starts with the prompt tokens. Cut them off by
    # position rather than splitting the decoded text on "<|end_header_id|>":
    # the text split returns the wrong span when the model emits another
    # header, and returns the whole prompt when the marker is absent.
    prompt_length = model_inputs["input_ids"].shape[1]
    new_token_ids = generated_ids[:, prompt_length:]
    return tokenizer.batch_decode(new_token_ids)[0]


In [None]:
#12. See what the prompt looks like
# Build the inference prompt once and reuse it for display and generation.
demo_sample = legal_dataset[1]
demo_prompt = create_prompt(demo_sample, include_response=False)
print(demo_prompt)

# Generate a summary
print(generate_response(demo_prompt, model, tokenizer))

# Print the actual ground truth for comparison
print("🔍 Ground truth summary:")
print(demo_sample["reference_summary"])


In [None]:
# ✨ Optional Run a Test Prompt
prompt = "Summarize this: The quick brown fox jumps over the lazy dog."
# Follow the model's own device rather than hard-coding "cuda".
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    pad_token_id=tokenizer.eos_token_id,  # explicit pad id, as in generate_response
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))