~~~
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
~~~

# Fine-tuning TxGemma with Hugging Face

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/TxGemma/%5BTxGemma%5DFinetune_with_Hugging_Face.ipynb">
      <img alt="Google Colab logo" src="https://www.tensorflow.org/images/colab_logo_32px.png" width="32px"><br> Run in Google Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/google-gemini/gemma-cookbook/blob/main/TxGemma/%5BTxGemma%5DFinetune_with_Hugging_Face.ipynb">
      <img alt="GitHub logo" src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" width="32px"><br> View on GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://huggingface.co/collections/google/txgemma-release-67dd92e931c857d15e4d1e87">
      <img alt="HuggingFace logo" src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" width="32px"><br> View on HuggingFace
    </a>
  </td>
</tr></tbody></table>

This notebook demonstrates fine-tuning TxGemma models to generalize to new therapeutic development tasks using Hugging Face libraries.

The demo uses Hugging Face's [Transformer Reinforcement Learning (`TRL`)](https://github.com/huggingface/trl) library to train the model with Supervised Fine-Tuning (SFT). It uses [Parameter-Efficient Fine-Tuning (`PEFT`)](https://github.com/huggingface/peft) with Low-Rank Adaptation (LoRA) to reduce computational costs. The training data is a subset of the [TrialBench](https://arxiv.org/abs/2407.00631) dataset, which is used to fine-tune TxGemma to predict adverse events in clinical trials.


## Setup

To complete this tutorial, you need a Colab runtime with enough resources to fine-tune and run the TxGemma model. A T4 GPU is sufficient:

1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.
2. Select **Change runtime type**.
3. Under **Hardware accelerator**, select **T4 GPU**.

### Get access to TxGemma

Before you get started, make sure that you have access to TxGemma models on Hugging Face:

1. If you don't already have a Hugging Face account, you can create one for free by clicking [here](https://huggingface.co/join).
2. Head over to the [TxGemma model page](https://huggingface.co/google/txgemma-2b-predict) and accept the usage conditions.

### Configure your HF token

Generate a Hugging Face `read` access token by clicking [here](https://huggingface.co/settings/tokens) and add your access token to the Colab Secrets manager to securely store it.

1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src="https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg" alt="The Secrets tab is found on the left panel." width=50%>
2. Create a new secret with the name `HF_TOKEN`.
3. Copy/paste your token key into the Value input box of `HF_TOKEN`.
4. Toggle the button on the left to allow notebook access to the secret.

In [None]:
import os
from google.colab import userdata
# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
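If you run this notebook outside Colab, `google.colab` cannot be imported. The following is a minimal sketch that assumes you have already exported the token as `HF_TOKEN` in your shell. `get_hf_token` is a hypothetical helper name. It falls back to the environment variable only when the Colab module is unavailable.

```python
import os


def get_hf_token():
    """Return a Hugging Face token from Colab Secrets, else from the environment.

    Hypothetical helper: uses the Colab-only `userdata` API when it can be
    imported, otherwise reads the HF_TOKEN environment variable (None if unset).
    """
    try:
        from google.colab import userdata  # importable only inside Colab
        return userdata.get("HF_TOKEN")
    except ImportError:
        return os.environ.get("HF_TOKEN")


token = get_hf_token()
if token:
    os.environ["HF_TOKEN"] = token
else:
    print("HF_TOKEN not found; downloading gated models will fail.")
```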

### Install dependencies

In [None]:
! pip install --upgrade --quiet bitsandbytes datasets peft transformers==4.48.3 trl

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.7/9.7 MB[0m [31m62.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.0/76.0 MB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m28.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m411.0/411.0 kB[0m [31m28.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m335.7/335.7 kB[0m [31m24.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m15.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

## Load model from Hugging Face Hub

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "google/txgemma-2b-predict"

# Use 4-bit quantization to reduce memory usage
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map={"":0},
    torch_dtype="auto",
    attn_implementation="eager",
)

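Loading in 4-bit is what lets a model of this size fit comfortably on a 16 GB T4. The rough calculation below multiplies the parameter count by the bits per parameter. The ~2.6B parameter count is an assumption: the checkpoint shards total about 10.45 GB, which is consistent with 32-bit weights. The estimate ignores quantization constants, layers that bitsandbytes leaves unquantized (such as embeddings), activations, and optimizer state.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate memory for model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Assumed parameter count (~2.6B); not read from the model config.
NUM_PARAMS = 2.6e9

for label, bits in [("float32", 32), ("bfloat16", 16), ("4-bit NF4", 4)]:
    print(f"{label:>10}: ~{weight_memory_gb(NUM_PARAMS, bits):.1f} GB")
```

To compare with the model you actually loaded, `model.get_memory_footprint()` reports its size in bytes.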

## Load dataset

This notebook uses adverse event prediction data from [TrialBench](https://arxiv.org/abs/2407.00631) to fine-tune TxGemma. The dataset has been preprocessed into an instruction-tuning format and is available in [Cloud Storage](https://console.cloud.google.com/storage/browser/healthai-us/txgemma/datasets).

Load the dataset using the Hugging Face [`datasets`](https://github.com/huggingface/datasets) library.

In [None]:
! pip install datasets



In [None]:
# The signed URL below expires one hour after it was generated (x-goog-expires=3600),
# so replace it with a current download link if the request fails.
# `-O` saves the file to the exact path that load_dataset() reads below; with `-P`,
# wget would name the file after the long signed URL instead.
! mkdir -p /output && wget -O /output/training_data.jsonl "https://storage.googleapis.com/training_data_8008/extracts/training_data.jsonl?x-goog-signature=61994e6c5751c33c3893216594164af8bcc3c151349124e2e7bd2adcaebffc995df34fe6db6ed8811e2f087951ca34d16d236a793767ecc3900d7dc5be3b398113f7ca84e81cb066117241e2ad2d592a862ba8ed6483d42dd0944540d63209730f075bb6f27e1d1cc84132a4a2059d2cd5c731d89ad48550c0661a622937244882ef19a65b296e21f9544d0141bbcd730bd2c37c8986c87960b5702f2a73e0ad9e50f48a03508e0f8787c723cc0dd788a3f819ce60ae8d74ab3efc9a25ee0c328025f1935ab71fea89e170885ef3eeba2132c38cd9e9b625d1936be912c31a0af18ff7ea6e1792f65b33805220d0e5a7d90c4e5a24f65a3d8e9fe2ef4c99616a&x-goog-algorithm=GOOG4-RSA-SHA256&x-goog-credential=875891476182-compute%40developer.gserviceaccount.com%2F20250403%2Fus%2Fstorage%2Fgoog4_request&x-goog-date=20250403T182535Z&x-goog-expires=3600&x-goog-signedheaders=host"

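The download link is time-limited, and a failed request can leave an empty or partial file behind instead of failing loudly. The following is a minimal sketch for sanity-checking the file before loading it. It assumes each line is a JSON object with `input_text` and `output_text` keys. `validate_records` is a hypothetical helper name.

```python
import json


def validate_records(lines, required_keys=("input_text", "output_text"), max_records=1000):
    """Check that JSONL lines parse as objects containing `required_keys`.

    Stops after `max_records` non-blank lines and returns how many were checked.
    Raises ValueError for invalid JSON, non-object records, missing keys, or
    when no records are found (for example, an empty file from a failed download).
    """
    checked = 0
    for line_no, line in enumerate(lines, start=1):
        if checked >= max_records:
            break
        if not line.strip():
            continue  # skip blank lines
        try:
            record = json.loads(line)
        except json.JSONDecodeError as err:
            raise ValueError(f"Line {line_no} is not valid JSON: {err}") from err
        if not isinstance(record, dict):
            raise ValueError(f"Line {line_no} is not a JSON object")
        missing = [key for key in required_keys if key not in record]
        if missing:
            raise ValueError(f"Line {line_no} is missing keys: {missing}")
        checked += 1
    if checked == 0:
        raise ValueError("No records found")
    return checked


# Usage:
# with open("/output/training_data.jsonl", encoding="utf-8") as f:
#     print(validate_records(f), "records look OK")
```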

In [None]:
from datasets import load_dataset

data = load_dataset(
    "json",
    data_files="/output/training_data.jsonl",
    split="train",
)

# Display dataset details
data


Dataset({
    features: ['input_text', 'output_text'],
    num_rows: 2500000
})

Each data point includes:

* `"input_text"`: Question, which prompts the model to predict whether there will be an adverse event given information about a clinical trial. Inputs include drug SMILES strings and textual information.

* `"output_text"`: Answer, which is either "Yes" or "No".

Below is an example from the dataset:

In [None]:
data["input_text"][0]

"Considering the provided data, the gene with approved symbol 'ETV1' (full name: 'ETS variant transcription factor 1', biotype: 'protein_coding') is located on chromosome 7 (positions 13891229 to 13991425, strand -1). It interacts with target 'ENSG00000100968' with an evidence score of 0.08. Based on this, can we conclude that this interaction is biologically significant?"

In [None]:
data["output_text"][0]

'No'
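Every answer should be either "Yes" or "No", so checking the label distribution before training can reveal class imbalance or unexpected values. The following is a minimal sketch that uses only the standard library. `label_counts` is a hypothetical helper. Counting a sample instead of all 2,500,000 rows is a choice made here for speed.

```python
from collections import Counter


def label_counts(labels, allowed=("Yes", "No")):
    """Count answer labels and collect any that fall outside `allowed`."""
    counts = Counter(label.strip() for label in labels)
    unexpected = {label: n for label, n in counts.items() if label not in allowed}
    return counts, unexpected


# Usage with the dataset loaded above (first 10,000 rows):
# counts, unexpected = label_counts(data[:10000]["output_text"])
```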

The `SFTTrainer` expects each training example as a single sequence of text (by default, in a `"text"` column), so the prompt and answer must be joined into one string.

Here, define a function that formats each example as the prompt, a space, the answer, and the `<eos>` end-of-sequence token. In a later section, this function is passed to the `SFTTrainer`, which applies it to the dataset before tokenization.

In [None]:
def formatting_func(example):
    text = f"{example['input_text']} {example['output_text']}<eos>"
    return text

# Display formatted training data example
print(formatting_func(data[0]))

Considering the provided data, the gene with approved symbol 'ETV1' (full name: 'ETS variant transcription factor 1', biotype: 'protein_coding') is located on chromosome 7 (positions 13891229 to 13991425, strand -1). It interacts with target 'ENSG00000100968' with an evidence score of 0.08. Based on this, can we conclude that this interaction is biologically significant? No<eos>
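The function above hard-codes the literal string `<eos>`, which is Gemma's end-of-sequence token. The following variant avoids depending on that literal. It is a sketch that assumes you pass in `tokenizer.eos_token`. `make_formatting_func` is a hypothetical name.

```python
def make_formatting_func(eos_token):
    """Build a formatting function that appends the given EOS string."""

    def formatting_func(example):
        # Prompt, a single space, the Yes/No answer, then EOS so the model
        # learns to stop after answering.
        return f"{example['input_text']} {example['output_text']}{eos_token}"

    return formatting_func


# Usage: formatting_func = make_formatting_func(tokenizer.eos_token)
```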


## Try out the pretrained model

Prompt the pretrained model with a sample adverse event prediction task to establish a baseline. Before fine-tuning, the model does not follow the instruction. Instead of answering Yes or No, it continues the prompt with an unrelated number.

In [None]:
prompt = "From the following information about a clinical trial, predict whether it would have an adverse event.\n\nDrug: C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2\n\nAnswer:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))



From the following information about a clinical trial, predict whether it would have an adverse event.

Drug: C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2

Answer:188
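For decoder-only models, `generate` returns the prompt tokens followed by the new ones, so decoding `outputs[0]` prints the whole prompt again. To see only the model's answer, slice off the prompt tokens before decoding. The following is a minimal sketch. `generated_text` is a hypothetical helper, and `decode` is any callable that maps token IDs to text.

```python
def generated_text(output_ids, prompt_length, decode):
    """Decode only the tokens produced after the prompt.

    output_ids:    token IDs returned by generate() for one example
    prompt_length: number of tokens in the tokenized prompt
    decode:        callable mapping a sequence of token IDs to a string
    """
    return decode(output_ids[prompt_length:]).strip()


# Usage in this notebook:
# answer = generated_text(
#     outputs[0],
#     inputs["input_ids"].shape[1],
#     lambda ids: tokenizer.decode(ids, skip_special_tokens=True),
# )
```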


## Fine-tune the model with LoRA

Traditional fine-tuning of large language models is resource-intensive because it updates billions of parameters. Parameter-Efficient Fine-Tuning (PEFT) addresses this by training only a small number of parameters, using techniques such as Low-Rank Adaptation (LoRA). LoRA adapts a model by training pairs of small, low-rank matrices whose product is added to the frozen original weights, instead of updating the full weight matrices.
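The following standard formulation of LoRA is consistent with PEFT's `r` and `lora_alpha` parameters. Take a frozen weight matrix $W_0 \in \mathbb{R}^{d \times k}$. LoRA trains $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$ with rank $r \ll \min(d, k)$. $B$ is initialized to zero, so training starts from the original model:

$$
W = W_0 + \frac{\alpha}{r} B A,
\qquad
\underbrace{d\,k}_{\text{full update}} \;\text{vs.}\; \underbrace{r\,(d + k)}_{\text{LoRA}} \;\text{trainable parameters.}
$$

As an illustration, assume dimensions $d = k = 2304$ and $r = 8$. A full update trains $2304^2 = 5{,}308{,}416$ parameters, while LoRA trains $8 \times 4608 = 36{,}864$, about 0.7% as many.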


This section demonstrates fine-tuning TxGemma using LoRA and the `SFTTrainer` from the Hugging Face `TRL` library.

First, define the [`LoraConfig`](https://huggingface.co/docs/peft/main/en/package_reference/lora), including the rank of the adaptation matrices and the model layers to add LoRA adapters to.

In [None]:
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

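Prepare the quantized model for training. `prepare_model_for_kbit_training` freezes the base model's weights and, by default, enables gradient checkpointing. `get_peft_model` then wraps the model with LoRA adapters as specified in `lora_config`.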

In [None]:
from peft import prepare_model_for_kbit_training, get_peft_model

# Preprocess quantized model for training
model = prepare_model_for_kbit_training(model)

# Create PeftModel from quantized model and configuration
model = get_peft_model(model, lora_config)
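On the `PeftModel`, `model.print_trainable_parameters()` reports trainable versus total parameters. The following sketch reproduces the LoRA count by hand, using r × (in + out) parameters per adapted projection. The layer dimensions are assumptions taken from the published Gemma 2 2B configuration, which TxGemma 2B builds on. Check them against `model.config` before relying on the result.

```python
# Assumed Gemma 2 2B dimensions; verify against model.config.
HIDDEN = 2304        # hidden_size
INTERMEDIATE = 9216  # intermediate_size
HEAD_DIM = 256       # head_dim
N_HEADS = 8          # num_attention_heads
N_KV_HEADS = 4       # num_key_value_heads
N_LAYERS = 26        # num_hidden_layers
RANK = 8             # r in the LoraConfig above

# (in_features, out_features) for each module listed in target_modules.
projections = {
    "q_proj": (HIDDEN, N_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "o_proj": (N_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

# Each projection gets A (r x in) and B (out x r): r * (in + out) parameters.
per_layer = sum(RANK * (d_in + d_out) for d_in, d_out in projections.values())
total_lora = per_layer * N_LAYERS
print(f"LoRA parameters per layer: {per_layer:,}")
print(f"LoRA parameters in total:  {total_lora:,}")
```

Under these assumptions, the script reports 10,383,360 LoRA parameters in total, a small fraction of the base model.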

This example uses the Supervised Fine-Tuning (SFT) method to train the TxGemma model.

Here, construct the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), which handles the complete training loop, including data loading, forward and backward passes, and optimizer steps. Pass it the LoRA configuration and formatting function defined earlier, along with an `SFTConfig` that holds the training hyperparameters.

In [None]:
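from trl import SFTTrainer, SFTConfig

trainer = SFTTrainer(
    model=model,
    train_dataset=data,
    args=SFTConfig(
        per_device_train_batch_size=1,   # examples per forward pass
        gradient_accumulation_steps=4,   # effective batch size of 4 on one GPU
        warmup_steps=2,
        max_steps=50,                    # short demo run; overrides the number of epochs
        learning_rate=2e-4,
        fp16=True,                       # T4 (Turing) GPUs lack native bfloat16 support
        logging_steps=5,
        max_seq_length=512,              # longer examples are truncated
        output_dir="/content/outputs",
        optim="paged_adamw_8bit",        # memory-saving 8-bit paged AdamW
        report_to="none",                # disable experiment-tracking integrations
    ),
    peft_config=lora_config,
    formatting_func=formatting_func,     # builds the training text for each example
)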

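Each optimizer step sees `per_device_train_batch_size × gradient_accumulation_steps` examples. As a result, 50 steps use only a tiny fraction of the 2,500,000 rows, even though the trainer formats and tokenizes the entire dataset up front. The following small sketch shows the arithmetic for a single GPU. `examples_seen` is a hypothetical helper.

```python
def examples_seen(batch_size, grad_accum_steps, max_steps, num_devices=1):
    """Training examples consumed, assuming no epoch boundary is reached."""
    effective_batch = batch_size * grad_accum_steps * num_devices
    return effective_batch * max_steps


seen = examples_seen(batch_size=1, grad_accum_steps=4, max_steps=50)
fraction = seen / 2_500_000
print(f"{seen} examples, {fraction:.4%} of the dataset")
```

For a short demo run, you can pass a shuffled subset such as `data.shuffle(seed=42).select(range(1000))` as `train_dataset` (the size here is arbitrary). This avoids preprocessing rows the run never uses.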

Launch the fine-tuning process.

In [None]:
trainer.train()



| Step | Training Loss |
|-----:|--------------:|
| 5 | 1.0273 |
| 10 | 0.6851 |
| 15 | 0.9121 |
| 20 | 0.6731 |
| 25 | 0.677 |
| 30 | 0.4778 |
| 35 | 0.5827 |
| 40 | 0.5936 |
| 45 | 0.4133 |
| 50 | 0.5631 |


TrainOutput(global_step=50, training_loss=0.660508861541748, metrics={'train_runtime': 121.7926, 'train_samples_per_second': 1.642, 'train_steps_per_second': 0.411, 'total_flos': 175363719750144.0, 'train_loss': 0.660508861541748})
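The reported `training_loss` is the average loss over all 50 steps. Each logged value above is itself an average over 5 steps, so the mean of the ten logged values reproduces `training_loss` up to rounding:

```python
# Training losses logged every 5 steps (copied from the output above).
logged_losses = [1.0273, 0.6851, 0.9121, 0.6731, 0.677,
                 0.4778, 0.5827, 0.5936, 0.4133, 0.5631]

mean_loss = sum(logged_losses) / len(logged_losses)
print(f"Mean of logged losses: {mean_loss:.4f}")  # reported training_loss: 0.6605
```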

## Test the fine-tuned model

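Prompt the fine-tuned model to see how its behavior has changed. The prompt below is the first training example shown earlier (`data["input_text"][0]`). After fine-tuning, the model answers with a single label, "No", which matches the expected output. Because this example comes from the training set, it shows that the model has learned the answer format. It does not measure how well the model generalizes to unseen examples.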



In [None]:
prompt = "Considering the provided data, the gene with approved symbol 'ETV1' (full name: 'ETS variant transcription factor 1', biotype: 'protein_coding') is located on chromosome 7 (positions 13891229 to 13991425, strand -1). It interacts with target 'ENSG00000100968' with an evidence score of 0.08. Based on this, can we conclude that this interaction is biologically significant?"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))



Considering the provided data, the gene with approved symbol 'ETV1' (full name: 'ETS variant transcription factor 1', biotype: 'protein_coding') is located on chromosome 7 (positions 13891229 to 13991425, strand -1). It interacts with target 'ENSG00000100968' with an evidence score of 0.08. Based on this, can we conclude that this interaction is biologically significant? No
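A single prompt, especially one taken from the training set, says little about generalization. The following is a minimal sketch of a held-out accuracy check. It assumes you split off evaluation examples before training, for example with `data.train_test_split(test_size=..., seed=...)` from the `datasets` library. It also assumes you wrap generation in a `predict` function that returns only the model's answer text. `accuracy` and `predict` are hypothetical names.

```python
def accuracy(examples, predict):
    """Fraction of examples whose predicted answer exactly matches output_text.

    examples: iterable of dicts with "input_text" and "output_text" keys
    predict:  callable taking a prompt string and returning an answer string
    """
    correct = total = 0
    for example in examples:
        prediction = predict(example["input_text"]).strip()
        if prediction == example["output_text"].strip():
            correct += 1
        total += 1
    return correct / total if total else 0.0
```

Here, `predict` could tokenize the prompt, call `model.generate` with a small `max_new_tokens`, and decode only the newly generated tokens. Exact matching is strict. If the model sometimes adds extra text after the label, you might compare only the first word.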


In [None]:
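# Save the trained LoRA adapter (adapter weights and config, not the full base model)
trainer.save_model("/output")

In [None]:
# Save the adapter again in a dedicated subdirectory
model.save_pretrained("/output/txgemma/")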

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
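# Copy the saved outputs to Google Drive so they persist after the runtime ends.
# Note: /output also contains the downloaded training data file.
! cp -r /output /content/drive/MyDrive/txgemma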

## Next steps

Explore the other [notebooks](https://github.com/google-gemini/gemma-cookbook/blob/main/TxGemma) to learn what else you can do with the model.