In [1]:
from flaxdiff.schedulers import EDMNoiseScheduler
from flaxdiff.predictors import KarrasPredictionTransform
from flaxdiff.models.simple_unet import Unet
from flaxdiff.trainer import DiffusionTrainer
from flaxdiff.data.datasets import get_dataset_grain
from flaxdiff.utils import defaultTextEncodeModel
from flaxdiff.samplers.euler import EulerAncestralSampler
import jax
import jax.numpy as jnp
import optax
from datetime import datetime

BATCH_SIZE = 16
IMAGE_SIZE = 128

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Define noise scheduler
edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)

# Define model
unet = Unet(emb_features=256, 
            feature_depths=[64, 64, 128, 256, 512],
            attention_configs=[
                None,
                {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True}, 
                {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True}, 
                {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True}, 
                {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":False, "use_self_and_cross":False}
                ],
            num_res_blocks=2,
            num_middle_res_blocks=1
)

In [3]:
# Load dataset
data = get_dataset_grain("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
datalen = data['train_len']
batches = datalen // BATCH_SIZE

input_shapes = {
    "x": (IMAGE_SIZE, IMAGE_SIZE, 3),
    "temb": (),
    "textcontext": (77, 768)
}

In [None]:
text_encoder = defaultTextEncodeModel()

2025-04-08 05:32:54.024023: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744090374.049239  527485 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744090374.056681  527485 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744090374.075269  527485 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744090374.075312  527485 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744090374.075314  527485 computation_placer.cc:177] computation placer alr

In [None]:
# Construct a validation set by the prompts
val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid', ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose', ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily', ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose', ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose']

def get_val_dataset(batch_size=8):
    for i in range(0, len(val_prompts), batch_size):
        prompts = val_prompts[i:i + batch_size]
        tokens = text_encoder.tokenize(prompts)
        yield tokens

data['test'] = get_val_dataset
data['test_len'] = len(val_prompts)

In [None]:
# Define optimizer
solver = optax.adam(2e-4)

# Create trainer
trainer = DiffusionTrainer(
    unet, optimizer=solver, 
    input_shapes=input_shapes,
    noise_schedule=edm_schedule,
    rngs=jax.random.PRNGKey(4), 
    name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
    model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data),
    encoder=text_encoder,
    distributed_training=True,
    wandb_config = {
        "project": 'mlops-msml605-project',
        "name": f"prototype-{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}",
})


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mashishkumar4[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Generating states for DiffusionTrainer


In [None]:
# Train the model
final_state = trainer.fit(data, batches, epochs=2000, sampler_class=EulerAncestralSampler)

Validation run for sanity check for process index 0


100%|██████████| 200/200 [00:14<00:00, 13.35it/s]


[32mSanity Validation done on process index 0[0m

Epoch 0/2000


		Epoch 0:   0%|                                                          | 0/511 [00:00<?, ?step/s]

First batch loaded at step 0


		Epoch 0:  20%|██████▊                            | 100/511 [02:04<08:30,  1.24s/step, loss=0.6457]

Training started for process index 0 at step 0


		Epoch 0: 600step [04:27,  2.24step/s, loss=0.1895]                                                

[32mEpoch done on index 0 => 0 Loss: 0.29585331678390503[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 0 step 511



ERROR:absl:[process=0] Failed to run 1 Handler Commit operations or the Commit callback in background save thread, directory: checkpoints/diffusion_sde_ve_2025-04-08_05:13:31/511
Traceback (most recent call last):
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 207, in _thread_func
    _background_wait_for_commit_futures(
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 85, in _background_wait_for_commit_futures
    commit_future.result()
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/futures/future.py", line 130, in result
    f.result(timeout=time_remaining)
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/futures/future.py", line 297, in result
    return self._t.join(timeout=ti

Error saving checkpoint Checkpoint path should be absolute. Got checkpoints/diffusion_sde_ve_2025-04-08_05:13:31/511.orbax-checkpoint-tmp-0/default.orbax-checkpoint-tmp-1
[32m
	Epoch 0 completed. Avg Loss: 0.29585331678390503, Time: 267.42s, Best Loss: 0.29585331678390503[0m
Validation started for process index 0


100%|██████████| 200/200 [00:16<00:00, 11.79it/s]


[32mValidation done on process index 0[0m

Epoch 1/2000


		Epoch 1:   0%|                                             | 0/511 [00:00<?, ?step/s, loss=0.1408]

First batch loaded at step 511
Training started for process index 0 at step 511


		Epoch 1: 600step [00:25, 23.40step/s, loss=0.1625]                                                

[32mEpoch done on index 0 => 1 Loss: 0.1607334315776825[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 1 step 1022



ERROR:absl:[process=0] Failed to run 1 Handler Commit operations or the Commit callback in background save thread, directory: checkpoints/diffusion_sde_ve_2025-04-08_05:13:31/1022
Traceback (most recent call last):
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 207, in _thread_func
    _background_wait_for_commit_futures(
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 85, in _background_wait_for_commit_futures
    commit_future.result()
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/futures/future.py", line 130, in result
    f.result(timeout=time_remaining)
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/futures/future.py", line 297, in result
    return self._t.join(timeout=t

Error saving checkpoint Checkpoint path should be absolute. Got checkpoints/diffusion_sde_ve_2025-04-08_05:13:31/1022.orbax-checkpoint-tmp-2/default.orbax-checkpoint-tmp-3
[32m
	Epoch 1 completed. Avg Loss: 0.1607334315776825, Time: 25.65s, Best Loss: 0.1607334315776825[0m
Validation started for process index 0


100%|██████████| 200/200 [00:17<00:00, 11.58it/s]


[32mValidation done on process index 0[0m

Epoch 2/2000


		Epoch 2:   0%|                                             | 0/511 [00:00<?, ?step/s, loss=0.1190]

First batch loaded at step 1022
Training started for process index 0 at step 1022


		Epoch 2: 600step [00:25, 23.49step/s, loss=0.1132]                                                

[32mEpoch done on index 0 => 2 Loss: 0.1495736837387085[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 2 step 1533



ERROR:absl:[process=0] Failed to run 1 Handler Commit operations or the Commit callback in background save thread, directory: checkpoints/diffusion_sde_ve_2025-04-08_05:13:31/1533
Traceback (most recent call last):
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 207, in _thread_func
    _background_wait_for_commit_futures(
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 85, in _background_wait_for_commit_futures
    commit_future.result()
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/futures/future.py", line 130, in result
    f.result(timeout=time_remaining)
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/futures/future.py", line 297, in result
    return self._t.join(timeout=t

Error saving checkpoint Checkpoint path should be absolute. Got checkpoints/diffusion_sde_ve_2025-04-08_05:13:31/1533.orbax-checkpoint-tmp-4/default.orbax-checkpoint-tmp-5
[32m
	Epoch 2 completed. Avg Loss: 0.1495736837387085, Time: 25.54s, Best Loss: 0.1495736837387085[0m
Validation started for process index 0


100%|██████████| 200/200 [00:16<00:00, 12.02it/s]


[32mValidation done on process index 0[0m

Epoch 3/2000


		Epoch 3:   0%|                                             | 0/511 [00:00<?, ?step/s, loss=0.1161]

First batch loaded at step 1533
Training started for process index 0 at step 1533


		Epoch 3: 600step [00:26, 23.05step/s, loss=0.1335]                                                

[32mEpoch done on index 0 => 3 Loss: 0.1404421031475067[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 3 step 2044



ERROR:absl:[process=0] Failed to run 1 Handler Commit operations or the Commit callback in background save thread, directory: checkpoints/diffusion_sde_ve_2025-04-08_05:13:31/2044
Traceback (most recent call last):
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 207, in _thread_func
    _background_wait_for_commit_futures(
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 85, in _background_wait_for_commit_futures
    commit_future.result()
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/futures/future.py", line 130, in result
    f.result(timeout=time_remaining)
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/futures/future.py", line 297, in result
    return self._t.join(timeout=t

Error saving checkpoint Checkpoint path should be absolute. Got checkpoints/diffusion_sde_ve_2025-04-08_05:13:31/2044.orbax-checkpoint-tmp-6/default.orbax-checkpoint-tmp-7
[32m
	Epoch 3 completed. Avg Loss: 0.1404421031475067, Time: 26.03s, Best Loss: 0.1404421031475067[0m
Validation started for process index 0


100%|██████████| 200/200 [00:16<00:00, 11.87it/s] 


[32mValidation done on process index 0[0m

Epoch 4/2000


		Epoch 4:   0%|                                             | 0/511 [00:00<?, ?step/s, loss=0.1122]

First batch loaded at step 2044
Training started for process index 0 at step 2044


		Epoch 4: 600step [00:26, 22.80step/s, loss=0.1573]                                                

[32mEpoch done on index 0 => 4 Loss: 0.13622280955314636[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 4 step 2555



ERROR:absl:[process=0] Failed to run 1 Handler Commit operations or the Commit callback in background save thread, directory: checkpoints/diffusion_sde_ve_2025-04-08_05:13:31/2555
Traceback (most recent call last):
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 207, in _thread_func
    _background_wait_for_commit_futures(
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 85, in _background_wait_for_commit_futures
    commit_future.result()
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/futures/future.py", line 130, in result
    f.result(timeout=time_remaining)
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/futures/future.py", line 297, in result
    return self._t.join(timeout=t

Error saving checkpoint Checkpoint path should be absolute. Got checkpoints/diffusion_sde_ve_2025-04-08_05:13:31/2555.orbax-checkpoint-tmp-8/default.orbax-checkpoint-tmp-9
[32m
	Epoch 4 completed. Avg Loss: 0.13622280955314636, Time: 26.32s, Best Loss: 0.13622280955314636[0m
Validation started for process index 0


100%|██████████| 200/200 [00:16<00:00, 12.33it/s]


[32mValidation done on process index 0[0m

Epoch 5/2000


		Epoch 5:   0%|                                             | 0/511 [00:00<?, ?step/s, loss=0.1176]

First batch loaded at step 2555
Training started for process index 0 at step 2555


		Epoch 5: 600step [00:26, 22.76step/s, loss=0.1072]                                                

[32mEpoch done on index 0 => 5 Loss: 0.13253240287303925[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 5 step 3066



ERROR:absl:[process=0] Failed to run 1 Handler Commit operations or the Commit callback in background save thread, directory: checkpoints/diffusion_sde_ve_2025-04-08_05:13:31/3066
Traceback (most recent call last):
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 207, in _thread_func
    _background_wait_for_commit_futures(
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 85, in _background_wait_for_commit_futures
    commit_future.result()
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/futures/future.py", line 130, in result
    f.result(timeout=time_remaining)
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/futures/future.py", line 297, in result
    return self._t.join(timeout=t

Error saving checkpoint Checkpoint path should be absolute. Got checkpoints/diffusion_sde_ve_2025-04-08_05:13:31/3066.orbax-checkpoint-tmp-10/default.orbax-checkpoint-tmp-11
[32m
	Epoch 5 completed. Avg Loss: 0.13253240287303925, Time: 26.37s, Best Loss: 0.13253240287303925[0m
Validation started for process index 0


100%|██████████| 200/200 [00:16<00:00, 12.35it/s]


[32mValidation done on process index 0[0m

Epoch 6/2000


		Epoch 6:   0%|                                             | 0/511 [00:00<?, ?step/s, loss=0.1273]

First batch loaded at step 3066
Training started for process index 0 at step 3066


		Epoch 6: 600step [00:26, 22.65step/s, loss=0.1068]                                                

[32mEpoch done on index 0 => 6 Loss: 0.1299276053905487[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 6 step 3577



ERROR:absl:[process=0] Failed to run 1 Handler Commit operations or the Commit callback in background save thread, directory: checkpoints/diffusion_sde_ve_2025-04-08_05:13:31/3577
Traceback (most recent call last):
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 207, in _thread_func
    _background_wait_for_commit_futures(
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 85, in _background_wait_for_commit_futures
    commit_future.result()
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/futures/future.py", line 130, in result
    f.result(timeout=time_remaining)
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/futures/future.py", line 297, in result
    return self._t.join(timeout=t

Error saving checkpoint Checkpoint path should be absolute. Got checkpoints/diffusion_sde_ve_2025-04-08_05:13:31/3577.orbax-checkpoint-tmp-12/default.orbax-checkpoint-tmp-13
[32m
	Epoch 6 completed. Avg Loss: 0.1299276053905487, Time: 26.49s, Best Loss: 0.1299276053905487[0m
Validation started for process index 0


100%|██████████| 200/200 [00:16<00:00, 12.18it/s]


[32mValidation done on process index 0[0m

Epoch 7/2000


		Epoch 7:   0%|                                             | 0/511 [00:00<?, ?step/s, loss=0.1179]

First batch loaded at step 3577
Training started for process index 0 at step 3577


		Epoch 7: 600step [00:27, 21.45step/s, loss=0.1304]                                                

[32mEpoch done on index 0 => 7 Loss: 0.12540051341056824[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 7 step 4088



ERROR:absl:[process=0] Failed to run 1 Handler Commit operations or the Commit callback in background save thread, directory: checkpoints/diffusion_sde_ve_2025-04-08_05:13:31/4088
Traceback (most recent call last):
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 207, in _thread_func
    _background_wait_for_commit_futures(
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 85, in _background_wait_for_commit_futures
    commit_future.result()
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/futures/future.py", line 130, in result
    f.result(timeout=time_remaining)
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/orbax/checkpoint/_src/futures/future.py", line 297, in result
    return self._t.join(timeout=t

Error saving checkpoint Checkpoint path should be absolute. Got checkpoints/diffusion_sde_ve_2025-04-08_05:13:31/4088.orbax-checkpoint-tmp-14/default.orbax-checkpoint-tmp-15
[32m
	Epoch 7 completed. Avg Loss: 0.12540051341056824, Time: 27.98s, Best Loss: 0.12540051341056824[0m
Validation started for process index 0


  0%|          | 0/200 [00:00<?, ?it/s]