In [None]:
from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
from flaxdiff.predictors import KarrasPredictionTransform
from flaxdiff.models.simple_unet import Unet
from flaxdiff.trainer import DiffusionTrainer
from flaxdiff.data.datasets import get_dataset_grain
from flaxdiff.utils import defaultTextEncodeModel
from flaxdiff.samplers.euler import EulerAncestralSampler
import jax
import jax.numpy as jnp
import optax
from datetime import datetime

# Samples per batch: passed to get_dataset_grain(batch_size=...) and used as the
# divisor when computing the number of batches per epoch.
BATCH_SIZE = 16
# Passed to get_dataset_grain(image_scale=...) and used as both spatial
# dimensions of the "x" entry in input_shapes.
IMAGE_SIZE = 128

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Noise schedules: both are built with the same sigma_max / rho / sigma_data.
_schedule_kwargs = dict(sigma_max=80, rho=7, sigma_data=0.5)
edm_schedule = EDMNoiseScheduler(1, **_schedule_kwargs)
karas_ve_schedule = KarrasVENoiseScheduler(1, **_schedule_kwargs)


def _attention_config(use_projection, use_self_and_cross):
    """Build a fresh attention config dict (8 heads, float16, flash attention off).

    Only the two flags that differ between levels are parameters; a new dict
    is returned on every call so no two levels share the same object.
    """
    return {
        "heads": 8,
        "dtype": jnp.float16,
        "flash_attention": False,
        "use_projection": use_projection,
        "use_self_and_cross": use_self_and_cross,
    }


# Model: attention_configs has one entry per feature depth. The first entry is
# None, the next three enable projection and self+cross attention, and the last
# disables both flags.
unet = Unet(
    emb_features=256,
    feature_depths=[64, 64, 128, 256, 512],
    attention_configs=[None]
    + [_attention_config(True, True) for _ in range(3)]
    + [_attention_config(False, False)],
    num_res_blocks=2,
    num_middle_res_blocks=1,
)

In [3]:
# Load the dataset through the grain-based loader
data = get_dataset_grain(
    "oxford_flowers102",
    batch_size=BATCH_SIZE,
    image_scale=IMAGE_SIZE,
)
datalen = data['train_len']
# Floor division: counts only full batches of BATCH_SIZE
batches = datalen // BATCH_SIZE

# Shapes keyed by input name; passed to DiffusionTrainer as input_shapes
input_shapes = dict(
    x=(IMAGE_SIZE, IMAGE_SIZE, 3),
    temb=(),
    # presumably (token count, embedding width) of the text encoder output — TODO confirm
    textcontext=(77, 768),
)

In [4]:
# Default text encoder from flaxdiff.utils. Its tokenize() method is used to
# build the validation batches, and the object itself is passed to the trainer
# as `encoder`.
text_encoder = defaultTextEncodeModel()

2025-04-08 05:32:54.024023: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744090374.049239  527485 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744090374.056681  527485 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744090374.075269  527485 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744090374.075312  527485 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744090374.075314  527485 computation_placer.cc:177] computation placer alr

In [5]:
# Build the validation set from a fixed list of text prompts.
# NOTE(review): every prompt except the first starts with a leading space —
# confirm the tokenizer normalizes whitespace, otherwise strip the entries.
val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid', ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose', ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily', ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose', ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose']

def get_val_dataset(batch_size=8):
    """Yield tokenized validation prompts in consecutive batches.

    Walks the module-level ``val_prompts`` list in slices of ``batch_size``
    and yields whatever ``text_encoder.tokenize`` returns for each slice.
    The final slice may be shorter than ``batch_size``.

    Args:
        batch_size: Number of prompts per yielded batch (default 8).

    Yields:
        The tokenizer output for each slice of prompts.
    """
    offsets = range(0, len(val_prompts), batch_size)
    for start in offsets:
        stop = start + batch_size
        chunk = val_prompts[start:stop]
        yield text_encoder.tokenize(chunk)

# Register the generator *function* itself (not a generator instance) under 'test'.
data['test'] = get_val_dataset
# Total number of validation prompts (individual prompts, not batches).
data['test_len'] = len(val_prompts)

In [6]:
# Define optimizer: Adam with a constant learning rate of 2e-4
solver = optax.adam(2e-4)

# Take the timestamp once so the trainer name and the W&B run name always carry
# the same value; two separate datetime.now() calls can straddle a second
# boundary and produce mismatched names.
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")

# Create trainer
trainer = DiffusionTrainer(
    unet, optimizer=solver,
    input_shapes=input_shapes,
    noise_schedule=edm_schedule,
    rngs=jax.random.PRNGKey(4),  # fixed seed
    name="Diffusion_SDE_VE_" + run_timestamp,
    # Reuse the schedule's sigma_data so the prediction transform matches it
    model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data),
    encoder=text_encoder,
    distributed_training=True,
    wandb_config={
        "project": 'mlops-msml605-project',
        "name": f"prototype-{run_timestamp}",
    },
)


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mashishkumar4[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Generating states for DiffusionTrainer


In [None]:
# Train the model. The Euler ancestral sampler and the Karras VE schedule are
# passed as the sampler class and sampling noise schedule.
final_state = trainer.fit(
    data,
    batches,
    epochs=2000,
    sampler_class=EulerAncestralSampler,
    sampling_noise_schedule=karas_ve_schedule,
)

		Epoch 96: 600step [00:30, 19.58step/s, loss=0.0896]                                               

[32mEpoch done on index 0 => 96 Loss: 0.07253348082304001[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 96 completed. Avg Loss: 0.07253348082304001, Time: 30.64s, Best Loss: 0.07236480712890625[0m
Validation started for process index 0



100%|██████████| 200/200 [00:25<00:00,  7.85it/s]


[32mValidation done on process index 0[0m

Epoch 97/2000


		Epoch 97:   0%|                                            | 0/511 [00:00<?, ?step/s, loss=0.0821]

First batch loaded at step 49567
Training started for process index 0 at step 49567


		Epoch 97: 600step [00:32, 18.65step/s, loss=0.0585]                                               

[32mEpoch done on index 0 => 97 Loss: 0.07285355776548386[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 97 completed. Avg Loss: 0.07285355776548386, Time: 32.17s, Best Loss: 0.07236480712890625[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.46it/s]


[32mValidation done on process index 0[0m

Epoch 98/2000


		Epoch 98:   0%|                                            | 0/511 [00:00<?, ?step/s, loss=0.0638]

First batch loaded at step 50078
Training started for process index 0 at step 50078


		Epoch 98: 600step [00:33, 18.15step/s, loss=0.0703]                                               

[32mEpoch done on index 0 => 98 Loss: 0.07217347621917725[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 98 step 50589





[32m
	Epoch 98 completed. Avg Loss: 0.07217347621917725, Time: 33.07s, Best Loss: 0.07217347621917725[0m
Validation started for process index 0


100%|██████████| 200/200 [00:21<00:00,  9.47it/s]


[32mValidation done on process index 0[0m

Epoch 99/2000


		Epoch 99:   0%|                                            | 0/511 [00:00<?, ?step/s, loss=0.0725]

First batch loaded at step 50589
Training started for process index 0 at step 50589


		Epoch 99: 600step [00:31, 19.33step/s, loss=0.0665]                                               

[32mEpoch done on index 0 => 99 Loss: 0.0727015808224678[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 99 completed. Avg Loss: 0.0727015808224678, Time: 31.04s, Best Loss: 0.07217347621917725[0m
Validation started for process index 0



100%|██████████| 200/200 [00:23<00:00,  8.56it/s]


[32mValidation done on process index 0[0m

Epoch 100/2000


		Epoch 100:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0771]

First batch loaded at step 51100
Training started for process index 0 at step 51100


		Epoch 100: 600step [00:30, 19.93step/s, loss=0.0671]                                              

[32mEpoch done on index 0 => 100 Loss: 0.07300964742898941[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 100 completed. Avg Loss: 0.07300964742898941, Time: 30.10s, Best Loss: 0.07217347621917725[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.14it/s]


[32mValidation done on process index 0[0m

Epoch 101/2000


		Epoch 101:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0830]

First batch loaded at step 51611
Training started for process index 0 at step 51611


		Epoch 101: 600step [00:29, 20.05step/s, loss=0.1006]                                              

[32mEpoch done on index 0 => 101 Loss: 0.07275137305259705[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 101 completed. Avg Loss: 0.07275137305259705, Time: 29.94s, Best Loss: 0.07217347621917725[0m
Validation started for process index 0



100%|██████████| 200/200 [00:22<00:00,  8.93it/s]


[32mValidation done on process index 0[0m

Epoch 102/2000


		Epoch 102:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0616]

First batch loaded at step 52122
Training started for process index 0 at step 52122


		Epoch 102:  78%|█████████████████████████▊       | 400/511 [00:17<00:05, 21.07step/s, loss=0.0650]