{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tune Gemma 2 with LoRA (FP32 Precision)\n",
    "\n",
    "This notebook shows how to fine-tune the Gemma 2 (2B) model with LoRA (Low-Rank Adaptation) in standard FP32 precision."
   ]
  },
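  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "LoRA keeps the pretrained weights frozen and trains only a small low-rank update. The next cell is a minimal NumPy sketch of that idea for a single weight matrix: the frozen weight `W` (d × k) is adapted as `W + B @ A`, where `B` (d × r) and `A` (r × k) are the trainable factors and the rank `r` is much smaller than `d` and `k`. The sizes are hypothetical, the LoRA scaling factor is omitted, and the cell only illustrates the idea; it is independent of the KerasNLP model code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch with hypothetical sizes: LoRA freezes a pretrained weight W (d x k)\n",
    "# and learns a low-rank update B @ A, with B (d x r) and A (r x k).\n",
    "import numpy as np\n",
    "\n",
    "d, k, r = 512, 512, 4  # hypothetical layer size and LoRA rank\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "W = rng.standard_normal((d, k)).astype(np.float32)         # frozen pretrained weight\n",
    "A = 0.01 * rng.standard_normal((r, k)).astype(np.float32)  # trainable, random init\n",
    "B = np.zeros((d, r), dtype=np.float32)                     # trainable, zero init\n",
    "\n",
    "# Because B starts at zero, the adapted weight equals W before any training step.\n",
    "W_adapted = W + B @ A\n",
    "assert np.allclose(W_adapted, W)\n",
    "\n",
    "full_params = d * k        # weights updated by full fine-tuning of this layer\n",
    "lora_params = r * (d + k)  # weights updated by LoRA (A and B only)\n",
    "print(f\"full: {full_params:,}  LoRA: {lora_params:,}  ratio: {lora_params / full_params:.2%}\")"
   ]
  },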
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install required packages\n",
    "!pip install -q -U wandb keras-nlp \"keras>=3\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
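    "# Import the required modules\n",
    "import os\n",
    "\n",
    "# Keras reads KERAS_BACKEND only when it is first imported, so select the\n",
    "# backend (see \"Selecting a backend\" below) before importing Keras.\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"  # or \"torch\" or \"tensorflow\"\n",
    "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.00\"\n",
    "\n",
    "import json\n",
    "import wandb\n",
    "import random\n",
    "import math\n",
    "import keras\n",
    "import keras_nlp\n",
    "import gc\n",
    "from keras.callbacks import EarlyStopping\n",
    "from wandb.integration.keras import WandbMetricsLogger\n",
    "\n",
    "print(\"Keras backend:\", keras.backend.backend())"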
   ]
  },
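  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Keras 3 uses the `float32` dtype policy by default. Because this notebook trains in FP32, the next cell optionally pins the global policy to `float32` explicitly, which guards against a mixed-precision policy (such as `mixed_bfloat16`) left over from an earlier run in the same session. It is a minimal sketch assuming Keras 3's `keras.mixed_precision` API, and it must run before the model is created."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional safeguard (assumes Keras 3): pin the global dtype policy to full FP32.\n",
    "# Layers created after this call compute and store their variables in float32.\n",
    "keras.mixed_precision.set_global_policy(\"float32\")\n",
    "print(\"Dtype policy:\", keras.mixed_precision.global_policy().name)"
   ]
  },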
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log in to Weights & Biases\n",
    "wandb.login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Selecting a backend\n",
    "\n",
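    "Keras is a high-level, multi-framework deep learning API designed to be simple and easy to use. With Keras 3, you can run your workflows on one of three backends: TensorFlow, JAX, or PyTorch. This notebook uses JAX. Keras reads the backend from the `KERAS_BACKEND` environment variable when it is first imported, so the variable must be set before `import keras`; changing it afterwards requires restarting the kernel."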
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"  # or \"torch\" or \"tensorflow\"\n",
    "# Avoid memory fragmentation on the JAX backend.\n",
    "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.00\"\n",
    "\n",
    "# Make sure the GPU is available\n",
    "!nvidia-smi"