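# Fine-tune Gemma 2B with LoRA

This notebook fine-tunes the Keras `gemma_2b_en` preset with LoRA on the question–answer data in `qa_data.txt`, and compares the model's answers to two sample prompts before and after fine-tuning.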

## Setup

In [None]:
import os
from google.colab import userdata
from utils import preprocess_qa_data


In [None]:
# Run-time switches and repository settings. KAGGLE and DOWNLOAD_DATA are not
# referenced by any cell shown in this notebook; they are kept for external use.
COLAB = True
KAGGLE = False
DOWNLOAD_DATA = True
SAVE_TO_GITHUB = False  # commit and push FILE_NAME at the end of setup
GIT_REPOSITORY = "CS221-project"
FILE_NAME = "colab_tuning.ipynb"

if COLAB:
    # The repository is cloned into Google Drive so it survives runtime
    # restarts. Use "/content" instead for a throwaway clone on the VM disk.
    PARENT_DIRECTORY_PATH = "/content/drive/MyDrive"
    PROJECT_PATH = f"{PARENT_DIRECTORY_PATH}/{GIT_REPOSITORY}"
    # No `%cd` here: Drive is only mounted by the next cell, so changing into
    # PARENT_DIRECTORY_PATH at this point fails on a fresh runtime. The clone
    # cell changes into it after the mount.

In [None]:
if COLAB:
    from google.colab import drive

    # Step out of any Drive-backed working directory (a previous run of the
    # clone cell leaves the cwd inside /content/drive) before forcing a
    # fresh mount of Google Drive.
    %cd /content
    drive.mount(mountpoint="/content/drive", force_remount=True)

In [None]:
if COLAB:
    import json
    import os

    # Git credentials are read from a JSON file on Drive rather than being
    # hard-coded. Expected keys: GIT_USER_NAME, GIT_TOKEN, GIT_USER_EMAIL.
    with open(f"{PARENT_DIRECTORY_PATH}/Git/git.json", "r") as f:
        parsed_json = json.load(f)

    GIT_USER_NAME = parsed_json["GIT_USER_NAME"]
    GIT_TOKEN = parsed_json["GIT_TOKEN"]
    GIT_USER_EMAIL = parsed_json["GIT_USER_EMAIL"]

    # NOTE(review): the token is embedded in the remote URL, so git records it
    # as the origin URL in PROJECT_PATH/.git/config on Drive. Keep that folder
    # private.
    GIT_PATH = (
        f"https://{GIT_TOKEN}@github.com/{GIT_USER_NAME}/{GIT_REPOSITORY}.git"
    )

    %cd "{PARENT_DIRECTORY_PATH}"

    # The clone lives on persistent Drive storage, so in later sessions it
    # already exists and `git clone` would fail with "destination path already
    # exists". Only clone when the project folder is missing.
    if not os.path.isdir(PROJECT_PATH):
        !git clone "{GIT_PATH}"

    # Later cells use relative paths (e.g. qa_data.txt) from the project root.
    %cd "{PROJECT_PATH}"

In [None]:
if COLAB:
    import os

    # KAGGLE_CONFIG_DIR must name the *directory* that holds kaggle.json; the
    # Kaggle clients append "kaggle.json" themselves. Pointing it at the file
    # (".../Kaggle/kaggle.json") makes them look for
    # ".../Kaggle/kaggle.json/kaggle.json", which never exists.
    os.environ["KAGGLE_CONFIG_DIR"] = f"{PARENT_DIRECTORY_PATH}/Kaggle"

In [None]:
if SAVE_TO_GITHUB:
    # Runs in the working directory left by the clone cell (PROJECT_PATH).
    # GIT_USER_EMAIL / GIT_USER_NAME are only defined when COLAB is True.
    # Interpolated values are quoted so the shell passes each one as a single
    # argument even if it contains spaces.
    # NOTE(review): this commits FILE_NAME from the clone; confirm Colab is
    # saving the notebook there rather than to "Colab Notebooks".
    !git config --global user.email "{GIT_USER_EMAIL}"
    !git config --global user.name "{GIT_USER_NAME}"
    !git add "{FILE_NAME}"
    !git commit -am "update {FILE_NAME}"
    !git push

### Set environment variables

In [None]:
# Alternative ways to supply Kaggle credentials (disabled; the Colab setup
# above points KAGGLE_CONFIG_DIR at a folder on Drive instead).
#
# Option 1: Colab secrets (uses `userdata`, imported in the first cell).
# os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
# os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

# Option 2: a kaggle.json in the working directory. Needs `import json`, which
# is otherwise only imported inside the COLAB-guarded clone cell.
# with open("kaggle.json") as f:
#     kaggle_info = json.load(f)

# Set the environment variables
# os.environ["KAGGLE_USERNAME"] = kaggle_info["username"]
# os.environ["KAGGLE_KEY"] = kaggle_info["key"]

### Install dependencies

In [None]:
# Install Keras 3 last so a later dependency install cannot replace it with
# Keras 2. See https://keras.io/getting_started/ for more details.
# `%pip` installs into the running kernel's environment. The version specifier
# is quoted: unquoted, the shell parses `>=3` as an output redirection, which
# installs plain `keras` and writes pip's output to a file named "=3".
%pip install -q -U keras-nlp
%pip install -q -U "keras>=3"

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m515.3/515.3 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m950.8/950.8 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m589.8/589.8 MB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.3/5.3 MB[0m [31m74.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m79.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m103.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m39.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━

### Select a backend

In [None]:
# Both variables must be in the environment before `keras` is first imported,
# because the backend is chosen at import time.
# NOTE(review): `utils` is imported in the first cell; if it imports keras,
# the backend is already fixed by then — confirm.
os.environ.update(
    {
        "KERAS_BACKEND": "jax",  # alternatives: "torch", "tensorflow"
        # Let XLA preallocate the whole GPU (JAX defaults to a fraction of it)
        # to avoid memory fragmentation on the JAX backend.
        "XLA_PYTHON_CLIENT_MEM_FRACTION": "1.00",
    }
)

### Import packages

In [None]:
import keras
import keras_nlp

## Load Dataset

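Read the raw question–answer text from `qa_data.txt` (in the cloned repository) and convert it into training examples with `preprocess_qa_data` from the repository's `utils` module.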

In [None]:
# Relative path: the clone cell leaves the working directory at PROJECT_PATH.
# Decode explicitly as UTF-8 so the result does not depend on the platform's
# default locale encoding.
with open("qa_data.txt", encoding="utf-8") as file:
    content = file.read()

# NOTE(review): `data` is passed straight to `gemma_lm.fit`, so presumably
# preprocess_qa_data returns prompt/response strings in the
# "Instruction:/Response:" template used below — confirm in utils.
data = preprocess_qa_data(content)

## Load Model

In [None]:
# Build the pretrained Gemma 2B causal LM (backbone plus its preprocessor)
# from the "gemma_2b_en" preset; the weights are fetched from the Kaggle model
# hub, as the output shows. Then print the layer/parameter summary.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...
Attaching 'task.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...
Attaching 'metadata.json' from model 

## Inference before fine-tuning

### Probability Prompt

In [None]:
# Prompt template shared with the fine-tuning data: the model continues the
# text after "Response:".
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

# Top-k sampling (k=5) with a fixed seed so generations are repeatable.
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)

prompt = template.format(
    instruction="What is the difference between permutations and combinations?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What is the difference between permutations and combinations?

Response:
There are 16 ways to choose 3 people from a group of 6.
There are 12 ways to choose 3 people without repetition from a group of 9.
There are 48 ways to choose 3 people with repetition from a group of 6.
There are 10 ways to choose 3 people from a group of 3 with repetition.
The answer is the last one.
The number of combinations is the same as the number of permutations with repetition.
There are 48 combinations of 3 people in a group of 6 with repetition.

Explanation:

1. There are 16 ways to choose 3 people from a group of 6.

2. There are 12 ways to choose 3 people without repetition from a group of 9.

3. There are 48 ways to choose 3 people with repetition from a group of 6.
There are 10 ways to choose 3 people from a group of 3 with repetition.

4.

The number of permutations is the same as the number of combinations.

The number of permutations with


### Supervised Learning Prompt

In [None]:
# Second baseline question, generated with the sampler compiled above.
prompt = template.format(response="", instruction="What is Supervised Learning?")
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What is Supervised Learning?

Response:
Supervised learning is a type of machine learning in which the algorithm learns to map from a data set to a class label. In other words, it takes an input and produces a predicted output. The algorithm is given a set of training data with known class labels (e.g. 0 and 1 for binary classification). The goal is for the algorithm to predict the class label of new input data with the same accuracy as the training data.

Supervised learning is used in many fields, including computer vision, natural language processing, and medical imaging. The most common supervised learning algorithms are linear models, decision trees, and artificial neural networks.


## LoRA Fine-tuning

In [None]:
# Enable LoRA on the backbone with rank 4, so training updates small low-rank
# adapter matrices instead of the full weights; compare the trainable
# parameter count in this summary with the one printed after loading.
# NOTE(review): this mutates gemma_lm in place. Re-running the cell likely
# raises because LoRA is already enabled on those layers — reload the model
# before running it again.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

In [None]:
# Training hyperparameters, named so they can be tuned in one place.
SEQUENCE_LENGTH = 512  # max tokens per example; bounds activation memory
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 0.01
EPOCHS = 1
BATCH_SIZE = 1

# Truncate/pad every example to SEQUENCE_LENGTH tokens to control memory use.
gemma_lm.preprocessor.sequence_length = SEQUENCE_LENGTH

# AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
# Exclude layernorm ("scale") and bias terms from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    # The LM head produces raw logits over the vocabulary.
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    # weighted_metrics apply the sample weights emitted by the preprocessor
    # (presumably the padding mask) to the accuracy as well.
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Keep the History (per-epoch loss/accuracy) instead of echoing its repr.
history = gemma_lm.fit(data, epochs=EPOCHS, batch_size=BATCH_SIZE)

[1m2346/2346[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3372s[0m 1s/step - loss: 0.4749 - sparse_categorical_accuracy: 0.6215


<keras.src.callbacks.history.History at 0x7c1fe4667c70>

In [None]:
import os

# Save the whole fine-tuned model (LoRA-enabled backbone, preprocessor and
# compile state) as a single .keras archive on Google Drive.
FINE_TUNED_MODEL_PATH = (
    "/content/drive/MyDrive/Colab Notebooks/cs221/fine_tuned_model_1.keras"
)
# Create the target folder first so the save does not fail on a Drive that
# has no "Colab Notebooks/cs221" folder yet.
os.makedirs(os.path.dirname(FINE_TUNED_MODEL_PATH), exist_ok=True)
gemma_lm.save(FINE_TUNED_MODEL_PATH)


In [None]:
# Optional: mixed-precision training on GPUs. A Keras global dtype policy only
# applies to layers created after it is set, so this line must run before the
# model is built in the "Load Model" section; enabling it here, after
# training, has no effect on gemma_lm.
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

In [None]:
# Load the fine-tuned model
#
# NOTE: running this in the same session as training failed with
# RESOURCE_EXHAUSTED (see the recorded output): building a second Gemma model
# allocates a fresh f32[256000, 2048] embedding while `gemma_lm` is still on
# the GPU, and XLA_PYTHON_CLIENT_MEM_FRACTION is set to 1.00. Run it in a
# fresh runtime (after the install, backend and import cells) instead.

# loaded_model = keras.models.load_model("/content/drive/MyDrive/Colab Notebooks/cs221/fine_tuned_model_1.keras")

# Use the loaded model for generation
# prompt = template.format(
#     instruction="What is Supervised Learning?",
#     response="",
# )
# sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
# loaded_model.compile(sampler=sampler)
# generated_text = loaded_model.generate(prompt, max_length=256)
# print(generated_text)

ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4194304256 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:         8B
              constant allocation:         0B
        maybe_live_out allocation:    1.95GiB
     preallocated temp allocation:    3.91GiB
  preallocated temp fragmentation:         0B (0.00%)
                 total allocation:    5.86GiB
              total fragmentation:       240B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 1.95GiB
		Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/mul" source_file="/usr/local/lib/python3.10/dist-packages/keras/src/backend/jax/random.py" source_line=19
		XLA Label: fusion
		Shape: f32[256000,2048]
		==========================

	Buffer 2:
		Size: 1000.00MiB
		Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/keras/src/backend/jax/random.py" source_line=19
		XLA Label: fusion
		Shape: u32[262144000]
		==========================

	Buffer 3:
		Size: 1000.00MiB
		Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/keras/src/backend/jax/random.py" source_line=19
		XLA Label: fusion
		Shape: u32[262144000]
		==========================

	Buffer 4:
		Size: 1000.00MiB
		Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/keras/src/backend/jax/random.py" source_line=19
		XLA Label: fusion
		Shape: u32[262144000]
		==========================

	Buffer 5:
		Size: 1000.00MiB
		Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/keras/src/backend/jax/random.py" source_line=19
		XLA Label: fusion
		Shape: u32[262144000]
		==========================

	Buffer 6:
		Size: 16B
		Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/keras/src/backend/jax/random.py" source_line=19
		XLA Label: fusion
		Shape: (u32[262144000], u32[262144000])
		==========================

	Buffer 7:
		Size: 16B
		Operator: op_name="jit(_normal)/jit(main)/jit(_normal_real)/jit(_uniform)/threefry2x32" source_file="/usr/local/lib/python3.10/dist-packages/keras/src/backend/jax/random.py" source_line=19
		XLA Label: fusion
		Shape: (u32[262144000], u32[262144000])
		==========================

	Buffer 8:
		Size: 8B
		Entry Parameter Subshape: u32[2]
		==========================



## Inference after fine-tuning

### Probability Prompt

In [None]:
# Same template, prompt and sampler settings as before fine-tuning, so the two
# answers can be compared directly. Recompiling with the sampler switches
# gemma_lm back to generation after `fit`.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)

prompt = template.format(
    instruction="What is the difference between permutations and combinations?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What is the difference between permutations and combinations?

Response:
Combinations are a special type of permutation, where order does not matter. The order of the elements in the combination matters, but the total number of possible combinations (i.e., the cardinality) is the same. Permutations are combinations where order matters and the total number of possible permutations is different for each combination, even when the combination size (number of elements) is the same.


### Supervised Learning Prompt

In [None]:
# Same supervised-learning question as the baseline, now on the tuned model.
prompt = template.format(response="", instruction="What is Supervised Learning?")
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What is Supervised Learning?

Response:
Supervised Learning refers to the learning of a function or model that can accurately predict the outcome of a dependent variable based on a set of input variables. In supervised learning, the input variables are known as features and the output variable is the dependent variable, which is used to predict the value of some other variable. The supervised learning problem is formulated by specifying a function f that maps the input features to the output value. The goal is to learn the function f by training the model with a set of examples, where the examples are pairs of inputs and outputs.


To get better responses from the fine-tuned model, you can experiment with:

1. Increasing the size of the fine-tuning dataset
2. Training for more steps (epochs)
3. Setting a higher LoRA rank
4. Modifying the hyperparameter values such as `learning_rate` and `weight_decay`.

For comparison, these are the hyperparameters the Stanford Alpaca project used to fine-tune LLaMA-7B and LLaMA-13B:

| Hyperparameter | LLaMA-7B | LLaMA-13B |
|----------------|----------|-----------|
| Batch size     | 128      | 128       |
| Learning rate  | 2e-5     | 1e-5      |
| Epochs         | 3        | 5         |
| Max length     | 512      | 512       |
| Weight decay   | 0        | 0         |
