# Knowledge Distillation with Tunix: Gemma 7B to Gemma 2B

This notebook demonstrates how to use the **Tunix** library to perform knowledge distillation. Specifically, we will use **logit-based distillation** to transfer knowledge from a larger, more capable **teacher model (Gemma 7B)** to a smaller, more efficient **student model (Gemma 2B)**.

## What is Knowledge Distillation?
Knowledge distillation is a model compression technique where a smaller "student" model is trained to mimic the behavior of a larger, pre-trained "teacher" model. Instead of training the student solely on the ground-truth labels, we also train it to replicate the teacher's outputs.

## Logit-Based Distillation
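In this strategy, the student model learns to match the teacher's **logits** (the raw, unnormalized scores that the final softmax turns into probabilities). In the usual formulation, both sets of logits are divided by a temperature and passed through a softmax. The student then minimizes the KL divergence between the two softened distributions. These soft targets tell the student which token the teacher prefers and also how likely the teacher considers the alternatives. That is more informative than the hard labels alone.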
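Written out, the usual logit-distillation objective (Hinton et al., 2015) is shown below. Here $z^{t}$ and $z^{s}$ are the teacher and student logits at one token position, $T$ is the temperature, $y$ is the one-hot target, and $\alpha$ weights the two terms. We assume that `LogitStrategy` implements this form, with `TEMPERATURE` as $T$ and `ALPHA` as $\alpha$. Check the Tunix source for the exact weighting and how the loss is averaged over tokens.

$$
p_i^{(T)} = \frac{\exp(z_i^{t}/T)}{\sum_j \exp(z_j^{t}/T)}, \qquad
q_i^{(T)} = \frac{\exp(z_i^{s}/T)}{\sum_j \exp(z_j^{s}/T)}
$$

$$
\mathcal{L}_{\text{KD}} = T^2 \, \mathrm{KL}\big(p^{(T)} \,\|\, q^{(T)}\big) = T^2 \sum_i p_i^{(T)} \log \frac{p_i^{(T)}}{q_i^{(T)}}
$$

$$
\mathcal{L} = \alpha \, \mathcal{L}_{\text{KD}} + (1 - \alpha) \, \mathcal{L}_{\text{CE}}, \qquad
\mathcal{L}_{\text{CE}} = -\sum_i y_i \log q_i^{(1)}
$$

The gradients of the soft-target term shrink roughly as $1/T^2$. The $T^2$ factor compensates for this, so raising the temperature does not quietly shift the balance between the two terms.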

The core components we'll use are:
-   **Teacher Model**: Gemma 7B
-   **Student Model**: Gemma 2B
-   **Distillation Strategy**: `tunix.distillation.strategies.LogitStrategy`
-   **Trainer**: `tunix.distillation.distillation_trainer.DistillationTrainer`

Let's get started!

In [None]:
!pip install -q kagglehub
!pip install -q humanize  # used by show_hbm_usage()

!pip install -q tensorflow
!pip install -q tensorboardX
!pip install -q grain-nightly
!pip install -q datasets
!pip install -q git+https://github.com/google/tunix
!pip install -q git+https://github.com/google/qwix

!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git

In [None]:
import os
import gc

from flax import nnx
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp

from tunix.models.gemma import data as data_lib
from tunix.models.gemma import gemma as gemma_lib
from tunix.models.gemma import params as params_lib
from tunix.generate import sampler as sampler_lib
from tunix.distillation import strategies
from tunix.distillation import distillation_trainer

## Utility Function to Check HBM Usage

In [None]:
import functools
import humanize

def show_hbm_usage():
  """Displays memory usage per device."""
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  for d in jax.local_devices():
    stats = d.memory_stats()
    used = stats["bytes_in_use"]
    limit = stats["bytes_limit"]
    print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used / limit:.2%}) on {d}")

show_hbm_usage()

In [None]:
# --- Data ---
BATCH_SIZE = 4
MAX_TARGET_LENGTH = 128
NUM_TRAIN_EPOCHS = 1

# --- Model ---
MESH = [(1, 8), ("fsdp", "tp")]  # Mesh shape and axis names; needs 1 * 8 = 8 devices

# --- Training ---
MAX_STEPS = 200
EVAL_EVERY_N_STEPS = 50
LEARNING_RATE = 1e-4

# --- Distillation ---
TEMPERATURE = 2.0  # Softens the teacher's and student's output distributions
ALPHA = 0.7        # Weight on the distillation loss; the task (cross-entropy) loss is assumed to get 1 - ALPHA

# --- Checkpointing ---
TEACHER_CKPT_DIR = "/content/intermediate_ckpt/teacher/"
STUDENT_CKPT_DIR = "/content/intermediate_ckpt/student/"
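`TEMPERATURE` divides the logits before the softmax. The quick NumPy check below uses made-up logits over a hypothetical 4-token vocabulary to show what "softening" means. The token ranking stays the same, but probability mass moves from the top token to the alternatives, so the distribution's entropy rises.

```python
import numpy as np


def softmax(logits: np.ndarray, temperature: float = 1.0) -> np.ndarray:
  """Numerically stable softmax of logits / temperature along the last axis."""
  scaled = logits / temperature
  scaled = scaled - scaled.max(axis=-1, keepdims=True)  # subtract the max for stability
  exp = np.exp(scaled)
  return exp / exp.sum(axis=-1, keepdims=True)


def entropy(probs: np.ndarray) -> float:
  """Shannon entropy in nats."""
  return float(-(probs * np.log(probs)).sum())


# Hypothetical teacher logits over a toy 4-token vocabulary.
teacher_logits = np.array([4.0, 2.0, 1.0, -1.0])
temperature = 2.0  # same value as TEMPERATURE in the config cell

p_t1 = softmax(teacher_logits, temperature=1.0)
p_t2 = softmax(teacher_logits, temperature=temperature)

print("T=1:", np.round(p_t1, 3), "entropy:", round(entropy(p_t1), 3))
print("T=2:", np.round(p_t2, 3), "entropy:", round(entropy(p_t2), 3))
```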

First, we download the teacher (Gemma 7B) and student (Gemma 2B) checkpoints from Kaggle. We then re-save each one as a local Orbax checkpoint so that it can later be restored directly into a sharded model.

**Important**: You must have a Kaggle account and accept the Gemma license on Kaggle before you can download the models. If the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables are not set, the next cell prompts you to log in to Kaggle.

In [None]:
# Log in to Kaggle
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
  kagglehub.login()

def load_and_save_model(model_handle, version, ckpt_dir):
  """Downloads a model from Kaggle, saves it as a local Orbax checkpoint, and frees memory."""
  print(f"Loading {model_handle}...")
  kaggle_ckpt_path = kagglehub.model_download(model_handle)
  # The Kaggle download keeps the weights in a variant sub-directory ("2b-it" or "7b-it").
  ckpt_version = '7b-it' if '7b' in version else '2b-it'
  # Load on the CPU so the full, unsharded model does not have to fit in device memory.
  with jax.default_device(jax.devices('cpu')[0]):
    params = params_lib.load_and_format_params(os.path.join(kaggle_ckpt_path, ckpt_version))
    gemma = gemma_lib.Transformer.from_params(params, version=version)

  print(f"Saving checkpoint to {ckpt_dir}...")
  checkpointer = ocp.StandardCheckpointer()
  _, state = nnx.split(gemma)
  checkpointer.save(os.path.join(ckpt_dir, "state"), state)
  checkpointer.wait_until_finished()
  # Clean up to save memory
  del params
  del gemma
  del state
  gc.collect()
  print(f"Finished processing {model_handle}.")

# Load Teacher Model (Gemma 7B)
load_and_save_model("google/gemma/flax/1.1-7b-it", "1.1-7b-it", TEACHER_CKPT_DIR)

# Load Student Model (Gemma 2B)
load_and_save_model("google/gemma/flax/1.1-2b-it", "1.1-2b-it", STUDENT_CKPT_DIR)

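Now that the checkpoints are saved locally, we can restore them directly into sharded models. `get_sharded_model` works in three steps:

1. It builds an abstract model with `nnx.eval_shape`. This records only shapes and dtypes and allocates no parameter memory.
2. It attaches a named sharding over the mesh to every parameter.
3. Orbax restores each array onto the devices that own its shards.

Sharding distributes the weights and the computation across all devices in the mesh, so neither model has to fit on a single TPU core.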
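A rough estimate shows why sharding matters here. The sketch below assumes the approximate total parameter counts from the Gemma technical report (about 8.54B for 7B and 2.51B for 2B, embeddings included). It also assumes the bfloat16 dtype used in `get_sharded_model` and perfectly even sharding over the 8-device mesh. It counts weights only. The student's gradients and AdamW state, plus activations for both models, come on top of this. Compare the per-device figure with the `bytes_limit` that `show_hbm_usage()` reports.

```python
# Approximate total parameter counts (embeddings included) as reported in the
# Gemma technical report. These are assumptions for a rough estimate only.
PARAM_COUNTS = {
    "Gemma 7B (teacher)": 8.54e9,
    "Gemma 2B (student)": 2.51e9,
}
BYTES_PER_PARAM = 2  # bfloat16, the dtype get_sharded_model restores into
NUM_DEVICES = 1 * 8  # the (1, 8) ("fsdp", "tp") mesh from the config cell


def weight_memory_gib(n_params, bytes_per_param=BYTES_PER_PARAM, num_devices=NUM_DEVICES):
  """Returns (total GiB, GiB per device) for the weights alone, assuming even sharding."""
  total_gib = n_params * bytes_per_param / 2**30
  return total_gib, total_gib / num_devices


for name, n_params in PARAM_COUNTS.items():
  total_gib, per_device_gib = weight_memory_gib(n_params)
  print(f"{name}: {total_gib:.1f} GiB of weights, about {per_device_gib:.2f} GiB per device")
```

In practice some small tensors (for example, normalization scales) may be replicated rather than sharded, so treat the per-device number as a lower bound.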

In [None]:
def get_sharded_model(ckpt_path, model_config, mesh):
  """Loads a checkpoint into a sharded model."""
  abs_gemma: nnx.Module = nnx.eval_shape(
      lambda: gemma_lib.Transformer(model_config, rngs=nnx.Rngs(params=0))
  )
  abs_state = nnx.state(abs_gemma)
  abs_state = jax.tree.map(
      lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
      abs_state,
      nnx.get_named_sharding(abs_state, mesh),
  )
  checkpointer = ocp.StandardCheckpointer()
  restored_params = checkpointer.restore(ckpt_path, target=abs_state)

  graph_def, _ = nnx.split(abs_gemma)
  gemma = nnx.merge(graph_def, restored_params)
  return gemma

mesh = jax.make_mesh(*MESH)

# Create Teacher Model
print("Creating sharded teacher model (Gemma 7B)...")
teacher_config = gemma_lib.TransformerConfig.gemma_7b()
teacher_model = get_sharded_model(os.path.join(TEACHER_CKPT_DIR, "state"), teacher_config, mesh)
print("Teacher model created.")
# nnx.display(teacher_model) # Optional: view model structure

# Create Student Model
print("\nCreating sharded student model (Gemma 2B)...")
student_config = gemma_lib.TransformerConfig.gemma_2b()
student_model = get_sharded_model(os.path.join(STUDENT_CKPT_DIR, "state"), student_config, mesh)
print("Student model created.")
# nnx.display(student_model) # Optional: view model structure

show_hbm_usage()

In [None]:
print("Loading tokenizer...")
# Gemma 2B and 7B use the same tokenizer, so a single instance serves both models.
gemma_tokenizer_path = os.path.join(kagglehub.model_download("google/gemma/flax/1.1-2b-it"), "tokenizer.model")
gemma_tokenizer = data_lib.GemmaTokenizer(gemma_tokenizer_path)
print("Tokenizer loaded.")

print("\nCreating datasets...")
train_ds, validation_ds = data_lib.create_datasets(
    dataset_name='mtnt/en-fr',  # MTNT (Machine Translation of Noisy Text), English to French
    global_batch_size=BATCH_SIZE,
    max_target_length=MAX_TARGET_LENGTH,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    tokenizer=gemma_tokenizer,
    instruct_tuned=True,
)
print("Datasets created.")

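The `LogitStrategy` needs two kinds of functions, and the trainer needs a third:
1.  `model_forward_fn`: performs a forward pass for a given model and returns its logits. Both models are from the Gemma family and share the same tokenizer and vocabulary, so we can use the same function for the teacher and the student.
2.  `labels_fn`: builds the ground-truth labels for the standard cross-entropy loss. These are one-hot next-token targets with padding masked out.
3.  `gen_model_input_fn`: passed to the trainer via `with_gen_model_input_fn`. It converts each batch from the data loader into a dictionary (`input_tokens`, `input_mask`, `positions`, `attention_mask`) whose entries are passed as keyword arguments to the two functions above. This is why `labels_fn` accepts `**kwargs`: it ignores the arguments it does not need.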
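The one-token shift between `model_forward_fn` and `labels_fn` is easy to get wrong, so here is a NumPy sketch on a toy batch. It assumes a tiny vocabulary and a pad id of 0 (both hypothetical) and treats `input_mask` as a plain padding mask. In the notebook, `input_mask` comes from the dataset and may also exclude prompt tokens from the loss.

```python
import numpy as np

VOCAB_SIZE = 6  # toy vocabulary; the notebook uses student_config.num_embed
PAD_ID = 0      # hypothetical pad id for this toy example

# One toy sequence: 4 real tokens followed by 2 padding tokens.
input_tokens = np.array([[2, 5, 3, 1, PAD_ID, PAD_ID]])
input_mask = input_tokens != PAD_ID

# Targets are the inputs shifted left by one: position i predicts token i + 1.
target_tokens = input_tokens[:, 1:]
target_mask = input_mask[:, 1:]

# Same computation as labels_fn, written with NumPy instead of jax.numpy.
labels = np.eye(VOCAB_SIZE)[target_tokens] * target_mask[..., None].astype(float)

# A forward pass yields logits of shape [batch, seq_len, vocab]. Dropping the
# last step (logits[:, :-1, :]) makes them line up with `labels`.
fake_logits = np.zeros((1, input_tokens.shape[1], VOCAB_SIZE))
aligned_logits = fake_logits[:, :-1, :]

print(labels.shape, aligned_logits.shape)
print(labels.sum(axis=-1))  # 1.0 for real targets, 0.0 for padded ones
```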

In [None]:
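# Gemma 2B and 7B share the same vocabulary, so teacher and student logits
# have the same last dimension and can be compared position by position.
VOCAB_SIZE = student_config.num_embed

def model_forward_fn(
    model: nnx.Module,
    input_tokens: jax.Array,
    input_mask: jax.Array,  # Unused; accepted because all model inputs are passed by keyword.
    positions: jax.Array,
    attention_mask: jax.Array,
):
  """Performs a forward pass and returns the logits."""
  logits, _ = model(
      input_tokens,
      positions,
      None,  # No KV cache during training.
      attention_mask,
  )
  # Drop the last position: its prediction has no target token, because the
  # labels built in labels_fn start at the second input token.
  return logits[:, :-1, :]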


def labels_fn(
    input_tokens: jax.Array,
    input_mask: jax.Array,
    **kwargs,
):
  """Creates one-hot encoded labels for the next-token prediction task."""
  target_tokens = input_tokens[:, 1:]
  target_mask = input_mask[:, 1:]
  labels = jax.nn.one_hot(target_tokens, VOCAB_SIZE)
  # Mask out the padding tokens from the loss calculation.
  return labels * target_mask.astype(labels.dtype)[..., None]


def gen_model_input_fn(x: distillation_trainer.TrainingInput):
  """Formats a batch from the data loader into the model's expected input format."""
  pad_mask = x.input_tokens != gemma_tokenizer.pad_id()
  positions = gemma_lib.build_positions_from_mask(pad_mask)
  attention_mask = gemma_lib.make_causal_attn_mask(pad_mask)
  return {
      'input_tokens': x.input_tokens,
      'input_mask': x.input_mask,
      'positions': positions,
      'attention_mask': attention_mask,
  }
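`gen_model_input_fn` relies on two Gemma helpers to turn the padding mask into position ids and an attention mask. The NumPy sketch below shows what we assume `build_positions_from_mask` and `make_causal_attn_mask` compute. The first is a running count of real tokens. The second is a lower-triangular mask restricted to non-padding keys. These functions are local re-implementations for illustration, not the Tunix code. The pad id of 0 is hypothetical.

```python
import numpy as np

PAD_ID = 0  # hypothetical pad id for this toy example


def build_positions_from_mask(pad_mask):
  """0-indexed position ids that only count real tokens (padding gets clamped ids)."""
  positions = np.cumsum(pad_mask, axis=-1)
  return positions - (positions >= 1).astype(positions.dtype)


def make_causal_attn_mask(pad_mask):
  """[batch, q_len, kv_len] mask: attend to same-or-earlier positions that are not padding."""
  seq_len = pad_mask.shape[-1]
  causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
  return pad_mask[:, None, :] & causal[None, :, :]


tokens = np.array([[PAD_ID, PAD_ID, 2, 5, 3]])  # left-padded toy sequence
pad_mask = tokens != PAD_ID
positions = build_positions_from_mask(pad_mask)
attn_mask = make_causal_attn_mask(pad_mask)

print(positions)
print(attn_mask[0].astype(int))
```

With left padding, the first real token still gets position 0. Padding positions can neither attend to other tokens nor be attended to.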

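Now we can assemble all the components. We'll instantiate the `LogitStrategy`, configure the `DistillationTrainer`, and start training. On each batch, the trainer does three things:

1. It runs the frozen teacher and the student on the same inputs.
2. It computes the strategy's loss.
3. It updates only the student's weights.

The work is distributed across the devices in the mesh.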
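The toy NumPy loop below makes the trainer's job concrete. It is a simplified stand-in, not how `DistillationTrainer` is implemented:

- The "student" is a matrix of free logits rather than a network.
- The teacher logits and labels are random and hypothetical.
- Plain gradient descent replaces AdamW.
- The loss assumes the common `ALPHA * KD + (1 - ALPHA) * CE` weighting.

It shows the essential mechanics: the teacher stays fixed, and the student's logits move to lower the combined loss.

```python
import numpy as np


def softmax(z, temperature=1.0):
  """Numerically stable softmax of z / temperature over the last axis."""
  z = z / temperature
  z = z - z.max(axis=-1, keepdims=True)
  e = np.exp(z)
  return e / e.sum(axis=-1, keepdims=True)


def distill_loss_and_grad(student_logits, teacher_logits, labels, temperature, alpha):
  """Batch-mean loss alpha * T^2 * KL(p_T || q_T) + (1 - alpha) * CE(y, q_1),
  plus its gradient with respect to the student logits (derived by hand)."""
  p_t = softmax(teacher_logits, temperature)  # softened teacher targets (fixed)
  q_t = softmax(student_logits, temperature)  # softened student predictions
  q_1 = softmax(student_logits)               # student predictions at T = 1
  kl = np.sum(p_t * (np.log(p_t) - np.log(q_t)), axis=-1)
  ce = -np.sum(labels * np.log(q_1), axis=-1)
  loss = np.mean(alpha * temperature**2 * kl + (1.0 - alpha) * ce)
  # d(T^2 * KL)/dz = T * (q_T - p_T) and dCE/dz = q_1 - y (one-hot y sums to 1).
  grad = alpha * temperature * (q_t - p_t) + (1.0 - alpha) * (q_1 - labels)
  return loss, grad / student_logits.shape[0]  # divide by batch size to match the mean


rng = np.random.default_rng(0)
batch, vocab = 8, 10
teacher_logits = 3.0 * rng.normal(size=(batch, vocab))  # frozen, never updated
labels = np.eye(vocab)[teacher_logits.argmax(axis=-1)]  # hypothetical ground truth
student_logits = np.zeros((batch, vocab))               # the only trainable values

temperature, alpha, lr = 2.0, 0.7, 5.0  # temperature and alpha match the config cell
losses = []
for step in range(200):
  loss, grad = distill_loss_and_grad(student_logits, teacher_logits, labels, temperature, alpha)
  student_logits -= lr * grad  # plain gradient descent stands in for AdamW
  losses.append(loss)

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
print("student argmax matches teacher:",
      (student_logits.argmax(-1) == teacher_logits.argmax(-1)).all())
```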

In [None]:
# 1. Set up the distillation strategy
strategy = strategies.LogitStrategy(
    student_forward_fn=model_forward_fn,
    teacher_forward_fn=model_forward_fn,
    labels_fn=labels_fn,
    temperature=TEMPERATURE,
    alpha=ALPHA,
)

# 2. Set up the training configuration
config = distillation_trainer.TrainingConfig(
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=MAX_STEPS,
)

# 3. Set up the optimizer
optimizer = optax.adamw(LEARNING_RATE)


# Set teacher model in eval mode
teacher_model.eval()
# Set student model in train mode
student_model.train()
# 4. Instantiate the trainer
trainer = distillation_trainer.DistillationTrainer(
    student_model=student_model,
    teacher_model=teacher_model,
    strategy=strategy,
    optimizer=optimizer,
    training_config=config,
).with_gen_model_input_fn(gen_model_input_fn)

# 5. Run training within the mesh context
print("Starting distillation training...")
with mesh:
  trainer.train(train_ds, validation_ds)
print("Training complete.")

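After training, the student should have picked up some of the teacher's behavior on the translation task. Keep in mind that `MAX_STEPS = 200` at `BATCH_SIZE = 4` covers only 800 training examples. This is a short demonstration run, not a full distillation. Let's check the student qualitatively on a few sample prompts.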

In [None]:
# Switch the student to inference mode before sampling.
student_model.eval()
print("Setting up sampler for evaluation...")
sampler = sampler_lib.Sampler(
    transformer=student_model,
    tokenizer=gemma_tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_TARGET_LENGTH + 64,
        num_layers=student_config.num_layers,
        num_kv_heads=student_config.num_kv_heads,
        head_dim=student_config.head_dim,
    ),
)
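The `CacheConfig` sizes the key/value cache the sampler uses while generating. Presumably `cache_size` must be at least the prompt length plus `total_generation_steps` (20 in the next cell). The arithmetic below estimates the cache's memory footprint. It assumes Gemma 2B's published shapes: 18 layers, 1 KV head (multi-query attention), and a head dimension of 256. In the notebook these values come from `student_config`. It also assumes a bfloat16 cache and a batch of three prompts, matching the evaluation cell.

```python
# Gemma 2B shapes as we understand them from the published model config.
# In the notebook these come from student_config; treat these values as assumptions.
NUM_LAYERS = 18
NUM_KV_HEADS = 1  # Gemma 2B uses multi-query attention
HEAD_DIM = 256

MAX_TARGET_LENGTH = 128
CACHE_SIZE = MAX_TARGET_LENGTH + 64  # same expression as in the CacheConfig above
BATCH = 3                            # three prompts in the evaluation cell
BYTES_PER_VALUE = 2                  # assuming a bfloat16 cache


def kv_cache_bytes(batch, cache_size, num_layers, num_kv_heads, head_dim, bytes_per_value):
  """One key and one value tensor per layer, each [batch, cache_size, num_kv_heads, head_dim]."""
  return 2 * num_layers * batch * cache_size * num_kv_heads * head_dim * bytes_per_value


total = kv_cache_bytes(BATCH, CACHE_SIZE, NUM_LAYERS, NUM_KV_HEADS, HEAD_DIM, BYTES_PER_VALUE)
print(f"KV cache: {total / 2**20:.2f} MiB")
```

Because Gemma 2B has a single KV head, the cache stays small. A model with one KV head per attention head would need proportionally more memory.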

In [None]:
input_batch = [
    "Translate this into French:\nHello, my name is Morgane.\n",
    "Translate this into French:\nThis dish is delicious!\n",
    "Translate this into French:\nI am a student.\n",
]

print("Generating translations with the distilled student model...")
with mesh:
  out_data = sampler(
      input_strings=input_batch,
      total_generation_steps=20,
  )

print("\n--- Evaluation Results ---")
for input_string, out_string in zip(input_batch, out_data.text):
  print("----------------------")
  print(f"Prompt:\n{input_string}")
  print(f"Distilled Student's Output:\n{out_string}")