In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ShieldGemma Model Container

![shieldgemma container architecture](images/2.1-ShieldGemma-Container.png)

## Install Keras, KerasNLP, KaggleHub, and JAX

Install the following libraries:

- Keras
- KerasNLP
- KaggleHub

In [None]:
! pip install -q -U "keras >= 3.0, <4.0" "keras-nlp > 0.14.1" kagglehub

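Install JAX for your hardware. Use exactly one of the following options: the GPU (NVIDIA, CUDA 12) install is enabled by default, and the CPU and TPU options are commented out. To switch, uncomment the option you want and comment out the GPU cell.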

In [None]:
# CPU
# Uncomment to install the CPU-only build of JAX, and comment out the GPU
# install cell so only one JAX variant is installed.
# %pip install -U jax

In [None]:
# GPU (NVIDIA, CUDA12)
!pip install -U "jax[cuda12]"

In [None]:
# TPU
# Uncomment to install JAX with TPU support, and comment out the GPU install
# cell so only one JAX variant is installed.
# %pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

## Configure your runtime

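Initialize the Python and environment variables that Keras uses to select its deep learning backend (JAX, TensorFlow, or PyTorch). These must be set before Keras is imported, because Keras reads the backend setting only once, at import time. This notebook installs only JAX, so keep `DL_RUNTIME = 'jax'` unless you install another backend yourself. Learn more at https://keras.io/getting_started/#configuring-your-backend.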

In [None]:
import os
import sys
import warnings

# Keras backend to use. Only JAX is installed by the cells above, so keep
# 'jax' unless you install TensorFlow or PyTorch yourself.
DL_RUNTIME = 'jax'
MODEL_VARIANT = 'shieldgemma_2b_en' # @param ["shieldgemma_2b_en", "shieldgemma_9b_en", "shieldgemma_27b_en"]
# NOTE(review): not used in the visible cells; presumably the maximum token
# length for prompts in later cells -- confirm where it is consumed.
MAX_SEQUENCE_LENGTH = 512

# Keras reads KERAS_BACKEND only when it is first imported. If this cell is
# re-run with a different DL_RUNTIME after Keras was already imported, the new
# value would be silently ignored, so surface that instead of failing quietly.
if "keras" in sys.modules and os.environ.get("KERAS_BACKEND") != DL_RUNTIME:
    warnings.warn(
        f"Keras is already imported; setting KERAS_BACKEND={DL_RUNTIME!r} "
        "has no effect until the kernel is restarted.",
        stacklevel=1,
    )

os.environ["KERAS_BACKEND"] = DL_RUNTIME

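## Import dependencies and authenticate with Kaggle

Import Keras and KerasNLP (after the backend has been configured above), set the default float type to bfloat16, and then authenticate with Kaggle. Kaggle credentials (your username and API token) are required to download the model; `kagglehub.login()` presents an HTML form for entering them. Learn more at https://www.kaggle.com/docs/api#authentication.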

In [None]:
# NOTE(review): enum and Sequence are not used in the visible cells;
# presumably later cells rely on them -- keep until confirmed.
import enum
from collections.abc import Sequence

# Keras must be imported only after KERAS_BACKEND has been set above.
import keras
import keras_nlp

# Every published ShieldGemma checkpoint is bfloat16, so make bfloat16 the
# default Keras float dtype before any model gets constructed.
keras.config.set_floatx("bfloat16")

In [None]:
import os
from pathlib import Path

import kagglehub


def _kaggle_credentials_configured() -> bool:
    """Return True if Kaggle credentials are available without prompting.

    Looks for the KAGGLE_USERNAME / KAGGLE_KEY environment variables, or a
    kaggle.json file in KAGGLE_CONFIG_DIR (defaulting to ~/.kaggle).
    """
    if os.environ.get("KAGGLE_USERNAME") and os.environ.get("KAGGLE_KEY"):
        return True
    config_dir = Path(os.environ.get("KAGGLE_CONFIG_DIR", Path.home() / ".kaggle"))
    return (config_dir / "kaggle.json").is_file()


# kagglehub.login() shows an interactive form, which blocks unattended runs
# (e.g. "Restart & Run All" in CI). Only prompt when no credentials exist.
# NOTE(review): with this guard, invalid stored credentials are not caught
# here -- presumably they surface on the first kagglehub download instead.
if not _kaggle_credentials_configured():
    kagglehub.login()

# TODO