<a href="https://colab.research.google.com/github/SeoyeonPark1223/Gemma_FineTuning/blob/main/2nd_slang_lora_tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [None]:
import os

from google.colab import drive, userdata

# Export Kaggle credentials as environment variables. The username is given
# literally; the API key is looked up through google.colab's `userdata` store
# under the secret name below.
kaggle_username = 'trispark'
kaggle_secret_name = 'trispark'

os.environ["KAGGLE_USERNAME"] = kaggle_username
os.environ["KAGGLE_KEY"] = userdata.get(kaggle_secret_name)

In [None]:
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

In [None]:
# Runtime settings exported through the environment; this cell runs before the
# cell that imports `keras`.
backend_settings = {
    "KERAS_BACKEND": 'jax',
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "1.00",
}
os.environ.update(backend_settings)

In [None]:
import keras
import keras_nlp

In [None]:
import pandas as pd

## Load Dataset

In [None]:
# The CSV is read from Google Drive. On a fresh Colab runtime that path does
# not exist until Drive is mounted, so mount it here when it is missing.
if not os.path.isdir("/content/drive/MyDrive"):
    drive.mount("/content/drive")

slang_dataset = pd.read_csv("/content/drive/MyDrive/MLB_Kaggle_Gemma/all_slang_only_words.csv")

# Fail early, with a clear message, if a column used when building the
# training prompts is absent from the file.
required_columns = {"Context", "Slang", "Description", "Example"}
missing_columns = required_columns - set(slang_dataset.columns)
if missing_columns:
    raise KeyError(f"slang CSV is missing required columns: {sorted(missing_columns)}")

In [None]:
# Prompt layout for every training example; `{instruction}` and `{response}`
# are filled in per row below. Defined once, outside the loop, so the name
# exists even when the dataset has no rows.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

# Fixed text placed before and after the per-row context line.
instruction_header = (
    "Given the context below, create a new Gen Z slang term. "
    "The slang should be catchy, easy to use, and relevant to modern youth culture. "
    "Make sure it's something that would feel natural in casual conversation:\n\n"
)
instruction_footer = "Make sure that you should provide slang, description, and example as given."

slang_data = []

for _, row in slang_dataset.iterrows():
    # Build the instruction as ONE string. A parenthesised, comma-separated
    # group of strings is a tuple, and formatting a tuple into the template
    # embeds its repr (parentheses, quotes, commas, escaped newlines) in the
    # prompt instead of the intended text.
    instruction = (
        instruction_header
        + "Context: " + row['Context'] + "\n\n"
        + instruction_footer
    )

    # Target completion: the slang term, its description and a usage example.
    response = (
        "Slang: {slang}\n\n"
        "Description: {description}\n\n"
        "Example: {example}"
    ).format(
        slang=row['Slang'],
        description=row['Description'],
        example=row['Example'],
    )

    slang_data.append(template.format(instruction=instruction, response=response))

## Load Model + LoRA fine-tuning

In [None]:
# Instantiate the Gemma causal language model from its named preset.
gemma_preset = "gemma2_2b_en"
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(gemma_preset)

In [None]:
# Switch the backbone to LoRA with the rank below.
lora_rank = 8
gemma_lm.backbone.enable_lora(rank=lora_rank)

In [None]:
# Print the model summary; this runs after enable_lora above, so the listing
# reflects the LoRA-enabled backbone.
gemma_lm.summary()

In [None]:
# Training hyperparameters, gathered in one place.
max_sequence_length = 256  # cap on input length, to control memory usage
adamw_settings = dict(learning_rate=5e-5, weight_decay=0.01)
no_decay_var_names = ['bias', 'scale']
fit_settings = dict(epochs=10, batch_size=1)

gemma_lm.preprocessor.sequence_length = max_sequence_length

# AdamW optimizer; variables matching the names above (layernorm scale and
# bias terms) are left out of weight decay.
optimizer = keras.optimizers.AdamW(**adamw_settings)
optimizer.exclude_from_weight_decay(var_names=no_decay_var_names)

token_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
token_accuracy = keras.metrics.SparseCategoricalAccuracy()
gemma_lm.compile(
    loss=token_loss,
    optimizer=optimizer,
    weighted_metrics=[token_accuracy],
)

gemma_lm.fit(slang_data, **fit_settings)

Epoch 1/10
[1m1779/1779[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1476s[0m 807ms/step - loss: 0.5051 - sparse_categorical_accuracy: 0.7887
Epoch 2/10
[1m1779/1779[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1437s[0m 794ms/step - loss: 0.2686 - sparse_categorical_accuracy: 0.8726
Epoch 3/10
[1m1654/1779[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m1:39[0m 794ms/step - loss: 0.2570 - sparse_categorical_accuracy: 0.8765

## Inference (output quality is currently poor)

In [None]:
# Instruction header, with the same wording as the header of the training
# instructions.
tag = "Given the context below, create a new Gen Z slang term. The slang should be catchy, easy to use, and relevant to modern youth culture. Make sure it's something that would feel natural in casual conversation:\n\n"

# Build the instruction from the same parts the training instructions contain
# (header, a "Context: ..." line, closing reminder). Leaving out the
# "Context: " label and the reminder gives the model a prompt that differs
# from the layout it was tuned on.
inference_context = "casual conversation."
inference_instruction = (
    tag
    + "Context: " + inference_context + "\n\n"
    + "Make sure that you should provide slang, description, and example as given."
)

# Empty response slot: the prompt ends right after "Response:\n" so the model
# generates the response part.
prompt = template.format(
    instruction=inference_instruction,
    response="",
)

# Top-k sampling with a fixed seed; recompile to attach the sampler.
sampler = keras_nlp.samplers.TopKSampler(k=20, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))