From 22de090abf5ee622213a6f2d71dd75c0468465ed Mon Sep 17 00:00:00 2001
From: vlad-karp
Date: Wed, 18 Mar 2026 20:39:41 +0000
Subject: [PATCH] Distillation optimizer fix

---
 .../trainers/post_train/distillation/train_distill.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py
index 85eb045bfe..bb416a1dad 100644
--- a/src/maxtext/trainers/post_train/distillation/train_distill.py
+++ b/src/maxtext/trainers/post_train/distillation/train_distill.py
@@ -199,7 +199,11 @@ class MaxTextDistillationTrainer(peft_trainer.PeftTrainer):
   """

   def __init__(self, model, strategy, optimizer, training_config, **kwargs):
-    super().__init__(model=model, optimizer=optimizer, training_config=training_config, **kwargs)
+    # We temporarily pass a dummy optimizer to the base PeftTrainer so that it does not eagerly
+    # allocate massive optimizer states for the entire ModelBundle (including the frozen teacher)
+    # before the real optimizer is configured on this trainer.
+    dummy_optimizer = optax.set_to_zero()
+    super().__init__(model=model, optimizer=dummy_optimizer, training_config=training_config, **kwargs)
     self.strategy = strategy
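
For context, a minimal sketch of the behavior the fix relies on, assuming the standard optax API; the parameter tree below is a hypothetical stand-in for the ModelBundle, not code from this repository. optax.set_to_zero() is a stateless gradient transformation, so its init() allocates no per-parameter buffers, whereas a stateful optimizer such as AdamW materializes moment buffers for every leaf it is given, which is the eager allocation the dummy optimizer avoids for the frozen teacher.

import jax.numpy as jnp
import optax
from jax import tree_util

# Hypothetical stand-in for the ModelBundle: a frozen teacher plus a trainable student.
params = {
    "teacher": {"w": jnp.zeros((1024, 1024))},
    "student": {"w": jnp.zeros((1024, 1024))},
}

# set_to_zero() keeps no per-parameter state, so init() is effectively free.
dummy_state = optax.set_to_zero().init(params)
print(tree_util.tree_leaves(dummy_state))       # [] -- no buffers allocated

# A stateful optimizer such as AdamW allocates mu/nu moment buffers for every
# parameter leaf (teacher included), plus a step counter.
adamw_state = optax.adamw(learning_rate=1e-3).init(params)
print(len(tree_util.tree_leaves(adamw_state)))  # full-size mu/nu buffers per leaf + count

The real optimizer for the trainable parameters can then be installed on the trainer after the base-class initialization, as the added comment in the patch describes.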