In [1]:
from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
from flaxdiff.predictors import KarrasPredictionTransform
from flaxdiff.models.simple_unet import Unet
from flaxdiff.trainer import DiffusionTrainer
from flaxdiff.data.datasets import get_dataset_grain
from flaxdiff.utils import defaultTextEncodeModel
from flaxdiff.samplers.euler import EulerAncestralSampler
import jax
import jax.numpy as jnp
import optax
from datetime import datetime

BATCH_SIZE = 16
IMAGE_SIZE = 128

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load dataset
data = get_dataset_grain("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
datalen = data['train_len']
batches = datalen // BATCH_SIZE

In [3]:
text_encoder = defaultTextEncodeModel()

2025-04-10 06:23:43.248339: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744266223.273050 2055796 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744266223.280744 2055796 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744266223.298347 2055796 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744266223.298373 2055796 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744266223.298376 2055796 computation_placer.cc:177] computation placer alr

In [4]:
# Construct a validation set by the prompts
val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid', ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose', ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily', ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose', ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose']

def get_val_dataset(batch_size=8):
    for i in range(0, len(val_prompts), batch_size):
        prompts = val_prompts[i:i + batch_size]
        tokens = text_encoder.tokenize(prompts)
        yield tokens

data['test'] = get_val_dataset
data['test_len'] = len(val_prompts)

In [None]:
from flax import linen as nn
from diffusers import FlaxUNet2DConditionModel

input_shapes = {
    "x": (IMAGE_SIZE, IMAGE_SIZE, 3),
    "temb": (),
    "textcontext": (77, 768)
}

# input_shapes = {
#     "sample": (3, IMAGE_SIZE, IMAGE_SIZE),
#     "timesteps": (),
#     "encoder_hidden_states": (77, 768)
# }
# Write a wrapper model around FlaxUNet2DConditionModel 

unet_model = FlaxUNet2DConditionModel(
    sample_size=IMAGE_SIZE,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 256, 512),  # the number of output channels for each UNet block
    cross_attention_dim=512,  # the size of the cross-attention layers
    dtype=jnp.bfloat16,
    use_memory_efficient_attention=True,
)
        
class BCHWModelWrapper(nn.Module):
    model: nn.Module

    @nn.compact
    def __call__(self, x, temb, textcontext):
        # Reshape the input to BCHW format from BHWC
        x = jnp.transpose(x, (0, 3, 1, 2))
        # Pass the input through the UNet model
        out = self.model(
            sample=x,
            timesteps=temb,
            encoder_hidden_states=textcontext,
        )
        # Reshape the output back to BHWC format
        out = jnp.transpose(out.sample, (0, 2, 3, 1))
        return out
    
unet = BCHWModelWrapper(unet_model)

In [None]:
# model = BCHWModelWrapper(unet_model)
params = unet.init(jax.random.PRNGKey(0), jnp.ones((1, IMAGE_SIZE, IMAGE_SIZE, 3)), jnp.ones((1,)), jnp.ones((1, 77, 768)))

In [None]:
out = unet.apply(params, jnp.ones((4,IMAGE_SIZE, IMAGE_SIZE, 3)), jnp.ones((4,)), jnp.ones((4, 77, 768)))

XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 16.00G. That was not possible. There are 13.93G free.; (0x0x0_HBM0)

In [5]:
input_shapes = {
    "x": (IMAGE_SIZE, IMAGE_SIZE, 3),
    "temb": (),
    "textcontext": (77, 768)
}

unet = Unet(emb_features=256, 
            feature_depths=[64, 64, 128, 256, 512],
            attention_configs=[
                None,
                {"heads":8, "dtype":jnp.float32, "flash_attention":False, "use_projection":False, "use_self_and_cross":True}, 
                {"heads":8, "dtype":jnp.float32, "flash_attention":False, "use_projection":False, "use_self_and_cross":True}, 
                {"heads":8, "dtype":jnp.float32, "flash_attention":False, "use_projection":False, "use_self_and_cross":True}, 
                {"heads":8, "dtype":jnp.float32, "flash_attention":False, "use_projection":False, "use_self_and_cross":False}
            ],
            num_res_blocks=2,
            num_middle_res_blocks=1
)

In [9]:
# Define noise scheduler
edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
# Define model

# Define optimizer
solver = optax.adam(2e-4)

# Create trainer
trainer = DiffusionTrainer(
    unet, optimizer=solver, 
    input_shapes=input_shapes,
    noise_schedule=edm_schedule,
    rngs=jax.random.PRNGKey(4), 
    name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
    model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data),
    encoder=text_encoder,
    distributed_training=True,
    wandb_config = {
        "project": 'mlops-msml605-project',
        "name": f"prototype-{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}",
    }
)


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mashishkumar4[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Generating states for DiffusionTrainer


In [10]:
trainer.summary()






In [None]:
ones = trainer.get_input_ones()

In [14]:
out = trainer.model.apply(
    trainer.state.params,
    **ones,
)

In [None]:
out.

AttributeError: 'FlaxUNet2DConditionOutput' object has no attribute 'shape'

In [7]:
# Train the model
final_state = trainer.fit(data, batches, epochs=2, sampler_class=EulerAncestralSampler, sampling_noise_schedule=karas_ve_schedule)

Using classifier-free guidance
Validation run for sanity check for process index 0


  0%|          | 0/200 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/mrwhite0racle/persist/FlaxDiff/flaxdiff/trainer/diffusion_trainer.py", line 320, in validation_loop
    samples = generate_samples(
              ^^^^^^^^^^^^^^^^^
  File "/home/mrwhite0racle/persist/FlaxDiff/flaxdiff/trainer/diffusion_trainer.py", line 291, in generate_samples
    samples = sampler.generate_images(
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrwhite0racle/persist/FlaxDiff/flaxdiff/samplers/common.py", line 162, in generate_images
    samples, rngstate = sample_step(sample_model_fn, rngstate, samples, current_step, next_step)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrwhite0racle/persist/FlaxDiff/flaxdiff/samplers/common.py", line 142, in sample_step
    samples, state = self.sample_step(sample_model_fn=sample_model_fn, current_samples=samples,
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Error logging images to wandb Initializer expected to generate shape (3, 3, 3, 64) but got shape (3, 3, 128, 64) instead for parameter "kernel" in "/conv_in". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)
[32mSanity Validation done on process index 0[0m

Epoch 0/2


		Epoch 0:   0%|                                                          | 0/511 [00:00<?, ?step/s]

First batch loaded at step 0


ScopeParamShapeError: Initializer expected to generate shape (3, 3, 3, 64) but got shape (3, 3, 128, 64) instead for parameter "kernel" in "/conv_in". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)