# Fine Tuning an LLM

## Load and Pre-process the dataset

In [1]:
import random

from datasets import load_dataset, load_from_disk
import spacy



def get_as_messages(data):
    headline = get_headline(data['text'])
    return {
        'messages': [
            {'role': 'user', 'content': get_prompt(headline)},
            {'role': 'assistant', 'content': headline},
        ]
    }


SEP = '#~#'
def get_headline(text):
    return text.split(SEP)[0].strip()


def get_prompt(headline):
    if random.random() < 0.1:
        return get_prompt_no_topic()
    topics = extract_topics(headline)
    if not topics:
        return get_prompt_no_topic()
    else:
        return get_prompt_topics(topics)


nlp = spacy.load('en_core_web_sm')
def extract_topics(text, max_topics=2):
    doc = nlp(text.lower()) # lower casing the text seems to result in better parsing
    noun_chunks = [chunk for chunk in doc.noun_chunks if chunk.root.pos_ != 'PRON'] # get all noun chunks, ignoring pronouns
    noun_chunks.sort(key=key, reverse=True)
    return [chunk.text.strip() for chunk in noun_chunks[:max_topics]]


def key(span):
    # Returns the importance of a span. A span is important if it contains an entity or is long.
    entity_types = ['PERSON', 'ORG', 'GPE']
    is_entity = any(token.ent_type_ in entity_types for token in span)
    length = len(span.text.strip())
    return (is_entity, length)


def get_prompt_topics(topics):
    template = random.choice([
        "Write a satirical headline about {topic}.",
        "Create an Onion-style headline related to {topic}.",
        "Give me your funniest, most satirical take on {topic}.",
        "Satirize the topic in {topic} in a news headline.",
        "What would a satirical news headline about {topic} sound like?",
    ])
    return template.format(topic=' and '.join(topics))


def get_prompt_no_topic():
    return random.choice([
        'Write a satirical headline in the style of Onion News.',
        'Generate an original satirical news headline.',
        "Give me a headline that sounds like it's from the Onion.",
        'Come up with a funny, fake news headline.',
        'Write a headline that is both absurd and oddly believable.',
    ])

filepath = 'Onion-News-Guided-Prompts'
try:
    ds = load_from_disk(filepath)
except:
    ds = load_dataset("Biddls/Onion_News")
    ds = ds.map(get_as_messages, remove_columns=['text'])
    ds.save_to_disk('Onion-News-Guided-Prompts')
ds['train'][:5]

{'messages': [[{'content': 'What would a satirical news headline about relaxed marie kondo and waist-high sewage sound like?',
    'role': 'user'},
   {'content': 'Relaxed Marie Kondo Now Says She Perfectly Happy Living In Waist-High Sewage',
    'role': 'assistant'}],
  [{'content': "Give me a headline that sounds like it's from the Onion.",
    'role': 'user'},
   {'content': 'U.S. Officials Call For Correct Amount Of Violence',
    'role': 'assistant'}],
  [{'content': 'Create an Onion-style headline related to kamala harris and communications assistant.',
    'role': 'user'},
   {'content': 'Kamala Harris Asks Communications Assistant If She Can Take Them Out For Coffee And Pick Their Brain Sometime',
    'role': 'assistant'}],
  [{'content': 'Create an Onion-style headline related to fake nursing school diploma scheme.',
    'role': 'user'},
   {'content': '25 Arrested In Fake Nursing School Diploma Scheme',
    'role': 'assistant'}],
  [{'content': 'Come up with a funny, fake new

## Load the baseline model

In [2]:
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

torch.random.manual_seed(0)

model_path = "Qwen/Qwen3-4B-Instruct-2507"

bnb_config = BitsAndBytesConfig(
    load_in_84it=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype="auto",
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
new_template = '''{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}'''

tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.chat_template = new_template

tokenizer.apply_chat_template(
    ds['train'][0]['messages'],
    tokenize=False,
)

'<|im_start|>user\nWhat would a satirical news headline about relaxed marie kondo and waist-high sewage sound like?<|im_end|>\n<|im_start|>assistant\nRelaxed Marie Kondo Now Says She Perfectly Happy Living In Waist-High Sewage<|im_end|>\n'

## Wrap the model in a LoRA

In [4]:
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    task_type=TaskType.CAUSAL_LM,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    bias="none",
)

prepared_model = prepare_model_for_kbit_training(model)

In [5]:
from transformers import TrainingArguments
from trl import SFTTrainer

training_args = TrainingArguments(
    output_dir='./qwen3_lora_finetuned_v2.1',
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=2e-4,
    num_train_epochs=1,
    logging_steps=10,
    save_steps=500,
    report_to="none",
    optim='paged_adamw_32bit',
    lr_scheduler_type='linear',
    warmup_steps=5,
    seed=42,
)

trainer = SFTTrainer(
    model=prepared_model,
    args=training_args,
    train_dataset=ds['train'],
    peft_config=peft_config,
    processing_class=tokenizer,
)

Tokenizing train dataset:   0%|          | 0/33880 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/33880 [00:00<?, ? examples/s]

In [6]:
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)


Step,Training Loss
10,4.4984
20,2.2303
30,1.7443
40,1.6719
50,1.6045
60,1.647
70,1.6274
80,1.6269
90,1.5878
100,1.5822


  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)


TrainOutput(global_step=2118, training_loss=1.4996271110909294, metrics={'train_runtime': 3200.7486, 'train_samples_per_second': 10.585, 'train_steps_per_second': 0.662, 'total_flos': 3.5359398511104e+16, 'train_loss': 1.4996271110909294})