In [1]:
from flaxdiff.schedulers import EDMNoiseScheduler
from flaxdiff.predictors import KarrasPredictionTransform
from flaxdiff.models.simple_unet import Unet
from flaxdiff.trainer import DiffusionTrainer
from flaxdiff.data.datasets import get_dataset_grain
from flaxdiff.utils import defaultTextEncodeModel
from flaxdiff.samplers.euler import EulerAncestralSampler
import jax
import jax.numpy as jnp
import optax
from datetime import datetime

# Run-wide configuration: BATCH_SIZE is the batch size handed to the data loader,
# IMAGE_SIZE the square side length images are scaled to (and the model's input size).
BATCH_SIZE, IMAGE_SIZE = 16, 128

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# EDM-style noise schedule (sigma_max=80, rho=7, sigma_data=0.5).
# NOTE(review): the leading positional argument `1` is passed to
# EDMNoiseScheduler unnamed — confirm what it controls and consider naming it.
edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)

# Attention settings common to every attention-enabled entry below.
_attn_common = {"heads": 8, "dtype": jnp.float16, "flash_attention": False}

# U-Net: one attention config per entry in feature_depths (presumably one per
# resolution level). The first entry has no attention, the next three use
# projected self+cross attention, and the last disables both options.
unet = Unet(
    emb_features=256,
    feature_depths=[64, 64, 128, 256, 512],
    attention_configs=[
        None,
        *[
            {**_attn_common, "use_projection": True, "use_self_and_cross": True}
            for _ in range(3)
        ],
        {**_attn_common, "use_projection": False, "use_self_and_cross": False},
    ],
    num_res_blocks=2,
    num_middle_res_blocks=1,
)

In [3]:
# Load the Oxford Flowers-102 dataset through flaxdiff's grain pipeline.
data = get_dataset_grain("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)

# Steps per epoch: whole batches only (integer division drops any remainder).
datalen = data['train_len']
batches = datalen // BATCH_SIZE

# Per-sample input shapes handed to DiffusionTrainer (presumably used to build
# dummy inputs for model initialisation).
# NOTE(review): (77, 768) must match the text encoder's output
# (sequence length, embedding width) — confirm against defaultTextEncodeModel().
input_shapes = dict(
    x=(IMAGE_SIZE, IMAGE_SIZE, 3),
    temb=(),
    textcontext=(77, 768),
)

In [4]:
# Text encoder shared by the notebook: get_val_dataset uses it to tokenize the
# validation prompts, and it is handed to DiffusionTrainer via `encoder=`.
text_encoder = defaultTextEncodeModel()

2025-04-08 05:32:54.024023: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744090374.049239  527485 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744090374.056681  527485 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744090374.075269  527485 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744090374.075312  527485 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744090374.075314  527485 computation_placer.cc:177] computation placer alr

In [5]:
# Validation "set" built from text prompts: get_val_dataset below tokenizes these
# in batches, and it is installed as data['test'].
# NOTE(review): every entry except the first carries a leading space (looks like an
# artifact of splitting a comma-separated string); consider cleaning the literal.
val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid', ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose', ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily', ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose', ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose']

def get_val_dataset(batch_size=8):
    """Yield tokenized batches of the validation prompts.

    Walks ``val_prompts`` in order, ``batch_size`` prompts at a time (the final
    batch may be shorter), strips surrounding whitespace from each prompt and
    tokenizes the batch with the notebook's ``text_encoder``.

    Args:
        batch_size: Number of prompts per yielded batch; must be positive.

    Yields:
        Whatever ``text_encoder.tokenize`` returns for one batch of prompts.

    Raises:
        ValueError: If ``batch_size`` is not positive. Because this is a
            generator, the error surfaces on the first ``next()`` call.
    """
    if batch_size <= 0:
        # range() fails opaquely on a zero step and silently yields nothing for a
        # negative one, which would turn validation into a no-op.
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    for start in range(0, len(val_prompts), batch_size):
        # Most prompts carry a leading space; normalise before tokenizing.
        prompts = [p.strip() for p in val_prompts[start:start + batch_size]]
        yield text_encoder.tokenize(prompts)

# Expose the prompt generator and its size alongside the training data
# (presumably consumed by trainer.fit for its validation runs).
data['test'] = get_val_dataset
data['test_len'] = len(val_prompts)

In [6]:
# Define optimizer: Adam with a fixed learning rate of 2e-4.
solver = optax.adam(2e-4)

# Capture the launch time once so the trainer's run name and the wandb run name
# share the same timestamp. Two separate datetime.now() calls can straddle a
# second boundary and produce mismatched names.
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")

# Create trainer
trainer = DiffusionTrainer(
    unet, optimizer=solver,
    input_shapes=input_shapes,
    noise_schedule=edm_schedule,
    rngs=jax.random.PRNGKey(4),
    # NOTE(review): the "SDE_VE" prefix does not match the EDM schedule used
    # here; kept unchanged so run names stay consistent — consider renaming.
    name="Diffusion_SDE_VE_" + run_timestamp,
    model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data),
    encoder=text_encoder,
    distributed_training=True,
    wandb_config={
        "project": 'mlops-msml605-project',
        "name": f"prototype-{run_timestamp}",
    },
)


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mashishkumar4[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Generating states for DiffusionTrainer


In [None]:
# Train the model. `batches` is passed as the per-epoch step count (matching the
# 0/511 progress bars), and EulerAncestralSampler is the sampler class
# (presumably used for the validation sampling seen in the logs).
NUM_EPOCHS = 2000
final_state = trainer.fit(
    data,
    batches,
    epochs=NUM_EPOCHS,
    sampler_class=EulerAncestralSampler,
)

Validation run for sanity check for process index 0


100%|██████████| 200/200 [00:20<00:00,  9.64it/s]


[32mSanity Validation done on process index 0[0m

Epoch 0/2000


		Epoch 0:   0%|                                                          | 0/511 [00:00<?, ?step/s]

First batch loaded at step 0


		Epoch 0:  20%|██████▊                            | 100/511 [02:13<09:07,  1.33s/step, loss=0.6530]

Training started for process index 0 at step 0


		Epoch 0: 600step [04:54,  2.04step/s, loss=0.1422]                                                

[32mEpoch done on index 0 => 0 Loss: 0.2761317491531372[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 0 step 511





[32m
	Epoch 0 completed. Avg Loss: 0.2761317491531372, Time: 294.23s, Best Loss: 0.2761317491531372[0m
Validation started for process index 0


100%|██████████| 200/200 [00:22<00:00,  8.95it/s]


[32mValidation done on process index 0[0m

Epoch 1/2000


		Epoch 1:   0%|                                             | 0/511 [00:00<?, ?step/s, loss=0.1055]

First batch loaded at step 511
Training started for process index 0 at step 511


		Epoch 1: 600step [00:29, 20.51step/s, loss=0.1218]                                                

[32mEpoch done on index 0 => 1 Loss: 0.12381739169359207[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 1 step 1022





[32m
	Epoch 1 completed. Avg Loss: 0.12381739169359207, Time: 29.25s, Best Loss: 0.12381739169359207[0m
Validation started for process index 0


100%|██████████| 200/200 [00:26<00:00,  7.65it/s]


[32mValidation done on process index 0[0m

Epoch 2/2000


		Epoch 2:   0%|                                             | 0/511 [00:00<?, ?step/s, loss=0.0839]

First batch loaded at step 1022
Training started for process index 0 at step 1022


		Epoch 2: 600step [00:28, 20.76step/s, loss=0.0858]                                                

[32mEpoch done on index 0 => 2 Loss: 0.11202948540449142[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 2 step 1533





[32m
	Epoch 2 completed. Avg Loss: 0.11202948540449142, Time: 28.91s, Best Loss: 0.11202948540449142[0m
Validation started for process index 0


100%|██████████| 200/200 [00:22<00:00,  8.93it/s]


[32mValidation done on process index 0[0m

Epoch 3/2000


		Epoch 3:   0%|                                             | 0/511 [00:00<?, ?step/s, loss=0.0857]

First batch loaded at step 1533
Training started for process index 0 at step 1533


		Epoch 3: 600step [00:29, 20.41step/s, loss=0.0974]                                                

[32mEpoch done on index 0 => 3 Loss: 0.10366462171077728[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 3 step 2044





[32m
	Epoch 3 completed. Avg Loss: 0.10366462171077728, Time: 29.41s, Best Loss: 0.10366462171077728[0m
Validation started for process index 0


100%|██████████| 200/200 [00:21<00:00,  9.30it/s]


[32mValidation done on process index 0[0m

Epoch 4/2000


		Epoch 4:   0%|                                             | 0/511 [00:00<?, ?step/s, loss=0.0789]

First batch loaded at step 2044
Training started for process index 0 at step 2044


		Epoch 4: 600step [00:29, 20.51step/s, loss=0.1152]                                                

[32mEpoch done on index 0 => 4 Loss: 0.09995977580547333[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 4 step 2555





[32m
	Epoch 4 completed. Avg Loss: 0.09995977580547333, Time: 29.25s, Best Loss: 0.09995977580547333[0m
Validation started for process index 0


100%|██████████| 200/200 [00:22<00:00,  9.08it/s]


[32mValidation done on process index 0[0m

Epoch 5/2000


		Epoch 5:   0%|                                             | 0/511 [00:00<?, ?step/s, loss=0.0832]

First batch loaded at step 2555
Training started for process index 0 at step 2555


		Epoch 5: 600step [00:29, 20.07step/s, loss=0.0725]                                                

[32mEpoch done on index 0 => 5 Loss: 0.09577960520982742[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 5 step 3066





[32m
	Epoch 5 completed. Avg Loss: 0.09577960520982742, Time: 29.89s, Best Loss: 0.09577960520982742[0m
Validation started for process index 0


100%|██████████| 200/200 [00:22<00:00,  8.96it/s]


[32mValidation done on process index 0[0m

Epoch 6/2000


		Epoch 6:   0%|                                             | 0/511 [00:00<?, ?step/s, loss=0.0977]

First batch loaded at step 3066
Training started for process index 0 at step 3066


		Epoch 6: 600step [00:30, 19.96step/s, loss=0.0731]                                                

[32mEpoch done on index 0 => 6 Loss: 0.09361319243907928[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 6 step 3577





[32m
	Epoch 6 completed. Avg Loss: 0.09361319243907928, Time: 30.06s, Best Loss: 0.09361319243907928[0m
Validation started for process index 0


100%|██████████| 200/200 [00:23<00:00,  8.57it/s]


[32mValidation done on process index 0[0m

Epoch 7/2000


		Epoch 7:   0%|                                             | 0/511 [00:00<?, ?step/s, loss=0.0826]

First batch loaded at step 3577
Training started for process index 0 at step 3577


		Epoch 7: 600step [00:30, 19.44step/s, loss=0.0950]                                                

[32mEpoch done on index 0 => 7 Loss: 0.08971632272005081[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 7 step 4088





[32m
	Epoch 7 completed. Avg Loss: 0.08971632272005081, Time: 30.86s, Best Loss: 0.08971632272005081[0m
Validation started for process index 0


100%|██████████| 200/200 [00:21<00:00,  9.41it/s]


[32mValidation done on process index 0[0m

Epoch 8/2000


		Epoch 8:   0%|                                             | 0/511 [00:00<?, ?step/s, loss=0.0840]

First batch loaded at step 4088
Training started for process index 0 at step 4088


		Epoch 8: 600step [00:31, 18.98step/s, loss=0.0797]                                                

[32mEpoch done on index 0 => 8 Loss: 0.08828653395175934[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 8 step 4599





[32m
	Epoch 8 completed. Avg Loss: 0.08828653395175934, Time: 31.62s, Best Loss: 0.08828653395175934[0m
Validation started for process index 0


100%|██████████| 200/200 [00:21<00:00,  9.35it/s]


[32mValidation done on process index 0[0m

Epoch 9/2000


		Epoch 9:   0%|                                             | 0/511 [00:00<?, ?step/s, loss=0.0987]

First batch loaded at step 4599
Training started for process index 0 at step 4599


		Epoch 9: 600step [00:29, 20.32step/s, loss=0.0783]                                                

[32mEpoch done on index 0 => 9 Loss: 0.08612735569477081[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 9 step 5110





[32m
	Epoch 9 completed. Avg Loss: 0.08612735569477081, Time: 29.53s, Best Loss: 0.08612735569477081[0m
Validation started for process index 0


100%|██████████| 200/200 [00:20<00:00,  9.59it/s]


[32mValidation done on process index 0[0m

Epoch 10/2000


		Epoch 10:   0%|                                            | 0/511 [00:00<?, ?step/s, loss=0.0739]

First batch loaded at step 5110
Training started for process index 0 at step 5110


		Epoch 10: 600step [00:28, 21.04step/s, loss=0.1018]                                               

[32mEpoch done on index 0 => 10 Loss: 0.08656030148267746[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 10 completed. Avg Loss: 0.08656030148267746, Time: 28.52s, Best Loss: 0.08612735569477081[0m
Validation started for process index 0



100%|██████████| 200/200 [00:23<00:00,  8.61it/s]


[32mValidation done on process index 0[0m

Epoch 11/2000


		Epoch 11:   0%|                                            | 0/511 [00:00<?, ?step/s, loss=0.0724]

First batch loaded at step 5621
Training started for process index 0 at step 5621


		Epoch 11: 600step [00:28, 21.02step/s, loss=0.0815]                                               

[32mEpoch done on index 0 => 11 Loss: 0.08616255223751068[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 11 completed. Avg Loss: 0.08616255223751068, Time: 28.55s, Best Loss: 0.08612735569477081[0m
Validation started for process index 0



100%|██████████| 200/200 [00:22<00:00,  9.04it/s]


[32mValidation done on process index 0[0m

Epoch 12/2000


		Epoch 12:   0%|                                            | 0/511 [00:00<?, ?step/s, loss=0.0773]

First batch loaded at step 6132
Training started for process index 0 at step 6132


		Epoch 12: 600step [00:30, 19.43step/s, loss=0.0953]                                               

[32mEpoch done on index 0 => 12 Loss: 0.0841614380478859[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 12 step 6643





[32m
	Epoch 12 completed. Avg Loss: 0.0841614380478859, Time: 30.89s, Best Loss: 0.0841614380478859[0m
Validation started for process index 0


100%|██████████| 200/200 [00:20<00:00,  9.59it/s]


[32mValidation done on process index 0[0m

Epoch 13/2000


		Epoch 13:   0%|                                            | 0/511 [00:00<?, ?step/s, loss=0.0875]

First batch loaded at step 6643
Training started for process index 0 at step 6643


		Epoch 13: 600step [00:32, 18.40step/s, loss=0.0971]                                               

[32mEpoch done on index 0 => 13 Loss: 0.08364305645227432[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 13 step 7154





[32m
	Epoch 13 completed. Avg Loss: 0.08364305645227432, Time: 32.62s, Best Loss: 0.08364305645227432[0m
Validation started for process index 0


100%|██████████| 200/200 [00:21<00:00,  9.47it/s]


[32mValidation done on process index 0[0m

Epoch 14/2000


		Epoch 14:   0%|                                            | 0/511 [00:00<?, ?step/s, loss=0.0908]

First batch loaded at step 7154
Training started for process index 0 at step 7154


		Epoch 14: 600step [00:31, 19.07step/s, loss=0.1048]                                               

[32mEpoch done on index 0 => 14 Loss: 0.08233291655778885[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 14 step 7665





[32m
	Epoch 14 completed. Avg Loss: 0.08233291655778885, Time: 31.47s, Best Loss: 0.08233291655778885[0m
Validation started for process index 0


100%|██████████| 200/200 [00:21<00:00,  9.27it/s]


[32mValidation done on process index 0[0m

Epoch 15/2000


		Epoch 15:   0%|                                            | 0/511 [00:00<?, ?step/s, loss=0.0695]

First batch loaded at step 7665
Training started for process index 0 at step 7665


		Epoch 15: 600step [00:32, 18.55step/s, loss=0.1131]                                               

[32mEpoch done on index 0 => 15 Loss: 0.08270074427127838[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 15 completed. Avg Loss: 0.08270074427127838, Time: 32.35s, Best Loss: 0.08233291655778885[0m
Validation started for process index 0



  0%|          | 0/200 [00:00<?, ?it/s]