# Fine-tuning Gemma for Chinese Poetry Generation

This document describes how to fine-tune the **Gemma** model using a dataset of Chinese poetry. The goal is to adapt the model to generate Chinese poetry in a classical style by training it on a subset of poems. The fine-tuning process leverages **LoRA** (Low-Rank Adaptation) for efficient model adaptation.

---

## Prerequisites

Make sure to install the required libraries using the following commands:

In [1]:
!pip install -q -U keras-nlp datasets
!pip install -q -U keras

---


## Setup and Configuration

Before loading the model, we select the Keras backend and configure JAX's GPU memory allocation. `KERAS_BACKEND` must be set before Keras is imported, because the backend is chosen at import time, so the environment variables come first:

```python
import os

os.environ["KERAS_BACKEND"] = "jax"  # Must be set before importing Keras
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"  # Let JAX preallocate up to 100% of GPU memory

import keras
import keras_nlp
```

---

## Hyperparameters and Model Configuration

Here, we define important parameters such as token limits, dataset size, LoRA rank, learning rate, and the number of training epochs:

```python
token_limit = 128  # Max token length per input
num_data_limit = 500  # Number of training examples to use
lora_rank = 4  # LoRA rank for model adaptation
lr_value = 1e-4  # Learning rate for training
train_epoch = 2  # Number of epochs for fine-tuning
model_id = "gemma2_instruct_2b_en"  # Pre-trained model ID for fine-tuning
```

These parameters allow you to control the model's training behavior and performance.
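
As a rough sketch of how these settings relate to the training run, the snippet below groups them in a small dataclass and derives the number of optimizer steps. The `FineTuneConfig` class and its `batch_size` field are illustrative names, not part of Keras or KerasNLP; the batch size of 1 is taken from the `fit` call later in this guide.

```python
from dataclasses import dataclass
import math


@dataclass(frozen=True)
class FineTuneConfig:
    """Hypothetical container for the hyperparameters defined above."""
    token_limit: int = 128
    num_data_limit: int = 500
    lora_rank: int = 4
    lr_value: float = 1e-4
    train_epoch: int = 2
    batch_size: int = 1  # assumption: matches the batch size passed to fit()

    def steps_per_epoch(self) -> int:
        # One optimizer step per batch; a final partial batch still counts.
        return math.ceil(self.num_data_limit / self.batch_size)

    def total_steps(self) -> int:
        return self.steps_per_epoch() * self.train_epoch


config = FineTuneConfig()
print(config.steps_per_epoch(), config.total_steps())
```

With the defaults, and assuming every example passes the length filter, this gives 500 steps per epoch and 1,000 steps overall.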

---

## Loading the Pre-trained Gemma Model

Next, we create a directory called `gemma2_chinese_poetry`, where the LoRA weights will be saved later, then load the pre-trained **Gemma** model and print a summary of its architecture:

In [5]:
model_folder = "gemma2_chinese_poetry"
os.makedirs(model_folder, exist_ok=True)  # Create the output folder if it does not exist

In [6]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(model_id)
gemma_lm.summary()

---

## Data Preprocessing

We load the dataset, a text file that is read as one Chinese poem per line, and keep the first `num_data_limit` lines. Each poem is stripped of surrounding whitespace and wrapped in Gemma's chat-turn markers. Note that the length check counts whitespace-separated pieces rather than tokens. Since the poems contain no spaces, it filters nothing here (all 500 poems pass); the actual cap of `token_limit` tokens is enforced later by the preprocessor's `sequence_length`.

In [7]:
with open('/kaggle/input/chinesepoetrydataset/chinese_poems.txt', 'r', encoding='utf-8') as f:
    poems = f.readlines()

In [8]:
poems = poems[:num_data_limit]

In [9]:
train = []
for poem in poems:  # Already limited to num_data_limit lines above
    poem = poem.strip()  # Remove surrounding whitespace
    if len(poem.split()) < token_limit:  # Rough length check on whitespace-separated pieces
        train.append(f"<start_of_turn>user\n{poem}\n<end_of_turn>\n<start_of_turn>model\n")
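
Because classical Chinese text is typically written without spaces, `str.split()` usually returns a single piece per poem, so the check above is a very loose filter. A minimal sketch of a stricter pre-filter is shown below; it counts characters as a rough proxy for token count. The helper name `filter_by_char_length` is hypothetical, and character count is only an approximation, since the exact count depends on the Gemma tokenizer.

```python
def filter_by_char_length(lines, max_chars):
    """Keep non-empty lines whose character count is below max_chars."""
    kept = []
    for line in lines:
        text = line.strip()
        if text and len(text) < max_chars:
            kept.append(text)
    return kept


sample = "欲出未出光辣达,千山万山如火发.须臾走向天上来,逐却残星赶却月."
print(len(sample.split()))  # number of whitespace-separated pieces
print(len(sample))          # number of characters
print(filter_by_char_length([sample + "\n", "\n"], max_chars=128))
```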

In [10]:
print(f"Number of training examples: {len(train)}")
print(f"First example: {train[0]}")

Number of training examples: 500
First example: <start_of_turn>user
欲出未出光辣达,千山万山如火发.须臾走向天上来,逐却残星赶却月.
<end_of_turn>
<start_of_turn>model
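
In the format above, the poem sits in the user turn and the model turn is left empty. One alternative layout, shown here only as a sketch and not what this notebook trains on, places an instruction in the user turn and the poem in the model turn, so that training examples match the shape of the prompts used at generation time. The helper name `build_example` and the instruction text are assumptions for illustration.

```python
def build_example(instruction, poem):
    """Format one instruction/response pair with Gemma's chat-turn markers."""
    return (
        f"<start_of_turn>user\n{instruction}<end_of_turn>\n"
        f"<start_of_turn>model\n{poem}<end_of_turn>"
    )


example = build_example(
    "写一首古典风格的诗。",  # hypothetical instruction: "Write a poem in a classical style."
    "欲出未出光辣达,千山万山如火发.须臾走向天上来,逐却残星赶却月.",
)
print(example)
```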



---

## Model Fine-tuning with LoRA

**LoRA** freezes the pre-trained weights and trains small low-rank matrices added alongside them, so far fewer parameters are updated. We enable LoRA on the backbone, cap the input sequence length at `token_limit`, and configure an **AdamW** optimizer (Adam with decoupled weight decay), excluding bias and scale terms from the decay. The model is then fine-tuned using the defined parameters and optimizer.

In [11]:
gemma_lm.backbone.enable_lora(rank=lora_rank)
gemma_lm.summary()
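
To see why LoRA is cheap, consider a single dense weight matrix of shape `d x k`. LoRA leaves it frozen and trains two small matrices of shapes `d x r` and `r x k`, adding `r * (d + k)` trainable parameters. The sketch below works through that arithmetic; the 2048 x 2048 shape is a hypothetical example, not a dimension taken from Gemma.

```python
def lora_param_count(d, k, rank):
    """Trainable parameters added by LoRA to one d x k weight matrix."""
    # One factor has shape (d, rank), the other has shape (rank, k).
    return rank * (d + k)


d, k = 2048, 2048  # hypothetical layer shape, for illustration only
rank = 4           # same value as lora_rank in this guide

full = d * k
added = lora_param_count(d, k, rank)
print(f"full matrix: {full:,} parameters")
print(f"LoRA factors: {added:,} parameters ({added / full:.2%} of the full matrix)")
```

For this example shape, the LoRA factors amount to well under 1% of the parameters in the matrix they adapt.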

In [12]:
gemma_lm.preprocessor.sequence_length = token_limit

In [13]:
optimizer = keras.optimizers.AdamW(
    learning_rate=lr_value,
    weight_decay=0.01,
)
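
AdamW differs from plain Adam in that weight decay is applied directly to the weights rather than being folded into the gradient. The following is a minimal single-parameter sketch of one update step, written in plain Python for illustration. It is not the Keras implementation, and the function name `adamw_step` and the default `beta1`, `beta2`, and `eps` values are assumptions chosen to mirror common defaults.

```python
import math


def adamw_step(theta, grad, m, v, t, lr=1e-4, weight_decay=0.01,
               beta1=0.9, beta2=0.999, eps=1e-7):
    """Return (theta, m, v) after one AdamW update of a scalar parameter."""
    # Decoupled weight decay: shrink the weight directly, independent of the gradient.
    theta = theta - lr * weight_decay * theta
    # Standard Adam moment estimates with bias correction.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


# With a zero gradient, only the weight-decay term changes the parameter.
theta, m, v = adamw_step(theta=1.0, grad=0.0, m=0.0, v=0.0, t=1)
print(theta)
```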

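In [14]:
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
# Attach the optimizer; without compile(), fit() uses the model's default settings.
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

In [15]:
gemma_lm.fit(train, epochs=train_epoch, batch_size=1)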

Epoch 1/2
500/500 ━━━━━━━━━━━━━━━━━━━━ 200s 245ms/step - loss: 3.6435 - sparse_categorical_accuracy: 0.1217
Epoch 2/2
500/500 ━━━━━━━━━━━━━━━━━━━━ 123s 245ms/step - loss: 2.5870 - sparse_categorical_accuracy: 0.2798


<keras.src.callbacks.history.History at 0x7ae00f548d60>
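
The reported loss is the average cross-entropy per token (in nats), so its exponential can be read as a perplexity: roughly, the effective number of tokens the model is choosing between at each step. The snippet below converts the two logged values. Keras reports a running average over each epoch, so these are only rough indicators. The helper name `perplexity` is illustrative.

```python
import math


def perplexity(cross_entropy_loss):
    """Convert a mean per-token cross-entropy (natural log) into perplexity."""
    return math.exp(cross_entropy_loss)


# Loss values copied from the training log above.
for epoch, loss in [(1, 3.6435), (2, 2.5870)]:
    print(f"epoch {epoch}: loss={loss:.4f}, perplexity={perplexity(loss):.1f}")
```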

---

## Saving the Fine-tuned Model

Once the model is fine-tuned, we save the **LoRA weights** to a file so that the trained model can be reused:

In [16]:
gemma_lm.backbone.save_lora_weights(f"/kaggle/working/gemma2_chinese_poetry/gemma2_chinese_poetry.lora.h5")
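
To reuse the adapter later, the base model has to be loaded again with LoRA enabled at the same rank before the saved weights are loaded into it. The sketch below assumes the KerasNLP backbone methods `enable_lora` and `load_lora_weights` and the paths used in this guide; `lora_weights_path` and `load_finetuned_model` are hypothetical helper names. The suffix check reflects our understanding that KerasNLP expects LoRA weight files to end in `.lora.h5`, which should be treated as an assumption.

```python
import os


def lora_weights_path(folder, filename="gemma2_chinese_poetry.lora.h5"):
    """Build the path of a LoRA weights file inside `folder`."""
    # Assumption: LoRA weight files are expected to use the .lora.h5 suffix.
    if not filename.endswith(".lora.h5"):
        raise ValueError(f"expected a name ending in .lora.h5, got {filename!r}")
    return os.path.join(folder, filename)


def load_finetuned_model(weights_path, preset="gemma2_instruct_2b_en", rank=4):
    """Reload the base model and apply saved LoRA weights (needs keras_nlp)."""
    import keras_nlp  # imported lazily so the path helper works without it

    model = keras_nlp.models.GemmaCausalLM.from_preset(preset)
    model.backbone.enable_lora(rank=rank)  # must match the rank used in training
    model.backbone.load_lora_weights(weights_path)
    return model


print(lora_weights_path("/kaggle/working/gemma2_chinese_poetry"))
```

Calling `load_finetuned_model(lora_weights_path(...))` in a fresh session would then restore the fine-tuned model without retraining.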

---

## Text Generation with the Fine-tuned Model

Finally, we define a helper that wraps a prompt in Gemma's chat-turn markers, generates a continuation with the fine-tuned model, and prints the result. `max_length` caps the total sequence length, prompt included:

In [17]:
def generate_text(prompt, max_length=256):
    # Wrap the prompt in Gemma's chat-turn markers.
    input_text = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
    generated_text = gemma_lm.generate(input_text, max_length=max_length)
    print("\nGenerated text:")
    print(generated_text)

---

## Example Usage

Here’s an example of how to generate a Chinese poem using the fine-tuned model. The prompt is in Chinese and asks the model to write a Tang Dynasty-style poem about the moon:

In [18]:
prompt = "写一首关于月亮的唐代风格诗。"
generate_text(prompt)


Generated text:
<start_of_turn>user
写一首关于月亮的唐代风格诗。<end_of_turn>
<start_of_turn>model
## 月夜吟

银盘高悬夜空深,
清辉洒下寒枝眠.
孤灯摇曳花枝摇,
月影映照水波流.
静待春风拂柳,
长醉夜色无相思.<end_of_turn>
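
The generated text above still contains the prompt and the turn markers. A small post-processing step can isolate the poem itself. The following is a minimal sketch using plain string operations; the function name `extract_response` is hypothetical, and it assumes the output follows the turn layout shown above.

```python
MODEL_TURN = "<start_of_turn>model\n"
END_TURN = "<end_of_turn>"


def extract_response(generated_text):
    """Return only the model's reply from a full Gemma chat transcript."""
    # rpartition yields the whole string as its last element if the marker is absent.
    reply = generated_text.rpartition(MODEL_TURN)[2]
    # Drop the closing marker and anything after it.
    reply = reply.split(END_TURN, 1)[0]
    return reply.strip()


raw = (
    "<start_of_turn>user\n写一首关于月亮的唐代风格诗。<end_of_turn>\n"
    "<start_of_turn>model\n## 月夜吟\n\n银盘高悬夜空深,\n清辉洒下寒枝眠.<end_of_turn>"
)
print(extract_response(raw))
```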


---

## Conclusion

This guide showed how to fine-tune the **Gemma** model with **LoRA** on a small set of Chinese poems. Because LoRA trains only a small set of added low-rank weights and leaves the pre-trained weights frozen, the same steps can be used to specialize the model for other tasks at a modest memory and compute cost.