# Installation

In [None]:
%%capture
import os, importlib.util
!pip install --upgrade -qqq uv
# Full pinned install when torch is not installed, or when any environment
# variable name contains "COLAB_".
if importlib.util.find_spec("torch") is None or "COLAB_" in "".join(os.environ.keys()):
    # Pin numpy / pillow to the versions that are already importable; if either
    # import fails, fall back to unpinned package names.
    try: import numpy, PIL; get_numpy = f"numpy=={numpy.__version__}"; get_pil = f"pillow=={PIL.__version__}"
    except: get_numpy = "numpy"; get_pil = "pillow"
    !uv pip install -qqq \
        "torch==2.7.1" "triton>=3.3.0" {get_numpy} {get_pil} torchvision bitsandbytes "transformers==4.56.2" \
        "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo" \
        "unsloth[base] @ git+https://github.com/unslothai/unsloth"
# torch is present: only install unsloth when it is missing.
elif importlib.util.find_spec("unsloth") is None:
    !uv pip install -qqq unsloth
# Always runs: upgrade these packages themselves without touching their dependencies (--no-deps).
!uv pip install --upgrade --no-deps transformers==4.56.2 tokenizers trl==0.22.2 unsloth unsloth_zoo

# Mamba kernels: needed for faster training.
# For now they are only supported on torch==2.7.1. With a newer torch version,
# expect to wait about 30 minutes while they compile.
!uv pip install --no-build-isolation mamba_ssm==2.2.5
!uv pip install --no-build-isolation causal_conv1d==1.5.2

# Unsloth

To fine-tune Granite 3.3 instead, replace the model name with `ibm-granite/granite-3.3-8b-instruct`.

In [None]:
from unsloth import FastLanguageModel
import torch

"""
Model explanations:

granite-4.0-h-micro: 3B dense model
granite-4.0-h-tiny: hybrid 7B MoE w/ 1B active
granite-4.0-h-small: hybrid 32B MoE w/ 9B active

granite-3.3-8b-instruct: 8B dense model
"""

# Reference list of Granite 4 model names; nothing in this cell reads it.
# NOTE(review): despite the variable name, the load below sets load_in_4bit = False.
fourbit_models = [
    "unsloth/granite-4.0-micro",
    "unsloth/granite-4.0-h-micro",
    "unsloth/granite-4.0-h-tiny",
    "unsloth/granite-4.0-h-small",

    # Base pretrained Granite 4 models
    "unsloth/granite-4.0-micro-base",
    "unsloth/granite-4.0-h-micro-base",
    "unsloth/granite-4.0-h-tiny-base",
    "unsloth/granite-4.0-h-small-base",
]

# Loads the tiny variant. granite-4.0-h-small could be tried on a bigger GPU
# if the smaller models turn out not to be good enough.
# `model` and `tokenizer` are used by the LoRA, data-prep and training cells below.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/granite-4.0-h-tiny",
    max_seq_length = 2048,   # Choose any for long context!
    load_in_4bit = False,    # Load full precision for max accuracy
    load_in_8bit = False,
    full_finetuning = False, # Don't do this, LoRA gets just as good results
)

We now add LoRA adapters so we only need to update a small number of parameters!

In [None]:
# Wrap the loaded model with LoRA adapters; `model` is rebound to the wrapped model.
model = FastLanguageModel.get_peft_model(
    model,
    r = 32, # LoRA rank; suggested: a multiple of 2, at least 8
    # Module names that receive adapters.
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",
                      "shared_mlp.input_linear", "shared_mlp.output_linear"],
    lora_alpha = 32,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # Maybe test this later
    loftq_config = None, # No LoftQ config is passed
)

Unsloth: Making `model.base_model.model.model` require gradients


# Data Prep

The chat template for granite-4 looks like this:
```
<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.
Today's Date: June 24, 2025.
You are Granite, developed by IBM. You are a helpful AI assistant.<|end_of_text|>

<|start_of_role|>user<|end_of_role|>How do astronomers determine the original wavelength of light emitted by a celestial body at rest, which is necessary for measuring its speed using the Doppler effect?<|end_of_text|>

<|start_of_role|>assistant<|end_of_role|>Astronomers make use of the unique spectral fingerprints of elements found in stars...<|end_of_text|>
```

TODO: Load and prepare dataset

Need to format the data into the following conversational style:

```
{"role": "system", "content": "system prompt"}
{"role": "user", "content": "What is 2+2?"}
{"role": "assistant", "content": "It's 4."}
```

In [None]:
# From Sakhawat et al., 2026

# System prompt used as the "system" message of every conversation built by
# formatting_prompts_func. The triple-quoted literal keeps its leading and
# trailing newline.
default_system_prompt = """
You are participating in a standardized News Bias Classification task for academic
research. Output ONLY a single numeric value between -3.0 and +3.0.

Do NOT provide explanations or text.
"""

In [None]:
def formatting_prompts_func(examples):
    """Render a batch of news articles into chat-formatted training texts.

    Meant to be used with `dataset.map(..., batched = True)`, where every
    column arrives as a list holding one entry per example.

    Args:
        examples: Batch mapping with the parallel list columns 'title',
            'article' and 'label'.

    Returns:
        A dict with a single key "text": one chat-template-rendered string
        per example, in input order.
    """
    titles = examples['title']
    articles = examples['article']
    labels = examples['label']

    messages = []
    # Walk the three columns in lockstep so each conversation is built from a
    # single example instead of from the whole batch.
    for title, article, label in zip(titles, articles, labels):
        user_text = f"""
          TITLE: {title}
          ARTICLE TEXT: {article}

          Output ONLY the numeric bias score.
        """

        conversation = [
            {"role": "system", "content": default_system_prompt},
            {"role": "user", "content": user_text},
            # The target the model learns to emit; passed as text so the
            # chat template always receives a string.
            {"role": "assistant", "content": str(label)}
        ]
        messages.append(conversation)

    # tokenize = False returns the rendered string; no generation prompt is
    # appended because the assistant turn is already part of the conversation.
    texts = [tokenizer.apply_chat_template(message, tokenize = False, add_generation_prompt = False) for message in messages]

    return { "text" : texts, }

# Run the formatter over the dataset in batches (batched = True passes each column as a list).
# NOTE(review): `dataset` is not created anywhere in the visible notebook (see the
# "TODO: Load and prepare dataset" note above); it must be loaded before this line runs.
dataset = dataset.map(formatting_prompts_func, batched = True,)

Map:   0%|          | 0/504 [00:00<?, ? examples/s]

Look at how the chat template mapped the conversation

In [None]:
# Display the rendered chat text of the example at index 5.
dataset[5]["text"]

# Training


In [None]:
from trl import SFTTrainer, SFTConfig
# Build the supervised fine-tuning trainer over the "text" column produced in the data-prep step.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    eval_dataset = None, # TODO: Set up validation dataset
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4, # Use GA to mimic batch size! (2 x 4 = 8 examples per optimizer step per device)
        warmup_steps = 5,
        num_train_epochs = 1, # Full training runs, can experiment w/ multiple
        # max_steps = 60, # Disabled; uncomment to cap the run at 60 steps
        learning_rate = 2e-4, # Reduce to 2e-5 for long training runs
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.001,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none", # Use TrackIO/WandB etc
    ),
)

Unsloth: We found double BOS tokens - we shall remove one automatically.


Unsloth: Tokenizing ["text"] (num_proc=6):   0%|          | 0/504 [00:00<?, ? examples/s]

We also use Unsloth's `train_on_responses_only` method to train only on the assistant outputs and ignore the loss on the user's inputs. This helps increase the accuracy of finetunes!

In [None]:
from unsloth.chat_templates import train_on_responses_only
# Restrict the training loss to the assistant responses. Both markers must match
# the Granite chat template exactly: <|start_of_role|>ROLE<|end_of_role|>.
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<|start_of_role|>user<|end_of_role|>",
    response_part = "<|start_of_role|>assistant<|end_of_role|>",
)

Map (num_proc=2):   0%|          | 0/504 [00:00<?, ? examples/s]

In [None]:
# @title Show current memory stats
# Snapshot taken before training; `start_gpu_memory` and `max_memory` are
# read again by the final stats cell.
gpu_stats = torch.cuda.get_device_properties(0)  # properties of CUDA device index 0
# Bytes -> GB via three divisions by 1024, rounded to 3 decimals.
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

Let's train the model! To resume a training run, call `trainer.train(resume_from_checkpoint = True)` instead.

```
Notice you might have to wait ~10 minutes for the Mamba kernels to compile! Please be patient!
```

In [None]:
# Run training; the returned object's `metrics` dict is read by the final stats cell.
trainer_stats = trainer.train()

In [None]:
# @title Show final memory and time stats
# Relies on `start_gpu_memory` and `max_memory` from the pre-training stats cell
# and on `trainer_stats` from the training cell.
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
# Memory attributed to training = reserved peak now minus the pre-training baseline.
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
# Both percentages are relative to the device's total memory.
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

# Saving, loading finetuned models

In [None]:
# Merge to 16bit
# Both branches are disabled; change `False` to `True` to run one of them.
if False:
    # Save with save_method = "merged_16bit" under the local name "model".
    model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if False: # Pushing to HF Hub

    # This one requires notebook sign-in w/ HF token beforehand
    # "hf/model" is a placeholder repo id. `token` is left empty here; do not
    # paste a real token into the notebook.
    model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")
