In [1]:
import os
from dotenv import load_dotenv

load_dotenv()

True

In [2]:
# backend
os.environ["KERAS_BACKEND"] = "jax"
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"

In [3]:
import keras
import keras_nlp



In [4]:
import jax

jax.devices()

[cuda(id=0), cuda(id=1)]

In [5]:
 keras.mixed_precision.set_global_policy('mixed_bfloat16')

The distribution API enables data and model parallelism, allowing for efficient scaling of deep learning models on multiple accelerators and hosts. It leverages the underlying framework (e.g. JAX) to distribute the program and tensors according to the sharding directives through a procedure called single program, multiple data (SPMD) expansion. Check out more details in the new Keras 3 [distribution API guide](https://keras.io/guides/distribution/).

In [6]:
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
print(keras.distribution.list_devices())
device_mesh = keras.distribution.DeviceMesh(
    (1, 2), ["batch", "model"], devices=keras.distribution.list_devices()
)

['gpu:0', 'gpu:1']


LayoutMap from the distribution API specifies how the weights and tensors should be sharded or replicated, using the string keys, for example, token_embedding/embeddings below, which are treated like regex to match tensor paths. Matched tensors are sharded with model dimensions (8 TPUs); others will be fully replicated.

In [7]:
model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim,
    None,
    None,
)

layout_map["decoder_block.*attention_output.*kernel"] = (model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

In [8]:
model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch"
)

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

Downloading from https://www.kaggle.com/api/v1/models/keras/gemma/keras/gemma_2b_en/2/download/task.json...


ValueError: One of device_put args was given the sharding of NamedSharding(mesh=Mesh('batch': 1, 'model': 2), spec=PartitionSpec('model', None, None)), which implies that the global size of its dimension 0 should be divisible by 2, but it is equal to 1 (full shape: (1, 2048, 256))

In [None]:
decoder_block_1 = gemma_lm.backbone.get_layer("decoder_block_1")
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
    print(
        f"{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}"
    )

In [None]:
!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl

In [None]:
import json
from icecream import ic

data = []
with open("databricks-dolly-15k.jsonl") as f:
    for line in f:
        features = json.loads(line)
        if features["context"]:
            continue
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

ic(len(data))
data = data[:5000]
ic(len(data))

In [None]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

In [None]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

In [None]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

In [None]:
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

In [None]:
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)