In [1]:
from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
from flaxdiff.predictors import KarrasPredictionTransform
from flaxdiff.models.simple_unet import Unet
from flaxdiff.trainer.general_diffusion_trainer import GeneralDiffusionTrainer, ConditionalInputConfig
from flaxdiff.data.datasets import get_dataset_grain
from flaxdiff.utils import defaultTextEncodeModel
from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE
from flaxdiff.samplers.euler import EulerAncestralSampler
import jax
import jax.numpy as jnp
import optax
from datetime import datetime
import argparse
import os

# Samples per batch; passed to the dataset loader and used to derive `batches` below.
BATCH_SIZE = 16
# Image side length; passed to the loader as `image_scale` and used for the sample shape.
IMAGE_SIZE = 256

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fetch the "oxford_flowers102" dataset via `get_dataset_grain`.
data = get_dataset_grain(
    "oxford_flowers102",
    batch_size=BATCH_SIZE,
    image_scale=IMAGE_SIZE,
)
# Whole batches available per pass: reported training length floor-divided by batch size.
datalen = data['train_len']
batches = datalen // BATCH_SIZE

# Text encoder used for conditioning, and the pretrained VAE wrapper.
text_encoder = defaultTextEncodeModel()
autoencoder = StableDiffusionVAE(modelname="pcuenq/sd-vae-ft-mse-flax")

# Fixed list of text prompts that make up the validation set.
# NOTE(review): every entry except the first begins with a space and many prompts are
# repeated — confirm this is intentional and not an artifact of how the list was produced.
val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid', ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose', ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily', ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose', ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose']

def get_val_dataset(batch_size=8):
    """Yield the validation prompts, tokenized, in consecutive chunks.

    Walks `val_prompts` front to back in slices of at most `batch_size`
    entries (the final slice may be shorter), tokenizes each slice with
    `text_encoder.tokenize`, and yields the result under the ``"text"`` key.

    Args:
        batch_size: Maximum number of prompts per yielded batch.

    Yields:
        dict: ``{"text": tokens}`` where ``tokens`` is the tokenizer output
        for the current slice of prompts.
    """
    for start in range(0, len(val_prompts), batch_size):
        chunk = val_prompts[start:start + batch_size]
        yield {"text": text_encoder.tokenize(chunk)}

# Attach the validation source to `data`: the generator *function* is stored as-is
# (it is not called here), together with the total number of validation prompts.
data['test'] = get_val_dataset
data['test_len'] = len(val_prompts)


2025-04-12 03:12:59.283474: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744427579.307033 3069880 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744427579.314157 3069880 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744427579.331699 3069880 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744427579.331741 3069880 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744427579.331743 3069880 computation_placer.cc:177] computation placer alr

Calculating downscale factor...
Downscale factor: 8
Latent channels: 4


In [3]:
from flax import linen as nn
from diffusers import FlaxUNet2DConditionModel
from flaxdiff.inputs import DiffusionInputConfig, ConditionalInputConfig

# Describe the model inputs: samples come from the 'image' key with an
# (IMAGE_SIZE, IMAGE_SIZE, 3) shape, conditioned on already-tokenized text from the
# 'text' key that is exposed to the model under the name `textcontext`.
input_config = DiffusionInputConfig(
    sample_data_key='image',
    sample_data_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
    conditions=[ConditionalInputConfig(
        encoder=text_encoder,
        conditioning_data_key='text',
        pretokenized=True,
        unconditional_input="",
        model_key_override="textcontext",
    )],
)

# Per-input shapes computed with the autoencoder supplied.
input_shapes = input_config.get_input_shapes(autoencoder=autoencoder)

# Diffusers' Flax conditional UNet, sized from `input_shapes` rather than from the raw
# image: for this run the shapes print as x=(32, 32, 4) and textcontext=(77, 768),
# i.e. the model works on the autoencoder's latent, not on RGB pixels.
unet_model = FlaxUNet2DConditionModel(
    sample_size=input_shapes["x"][1],  # spatial size of the model input (the latent, not the pixel resolution)
    in_channels=input_shapes["x"][2],  # channel count of the model input (latent channels, not RGB)
    out_channels=input_shapes["x"][2],  # output has the same channel count as the input
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 256, 512),  # the number of output channels for each UNet block
    # Derive the cross-attention width from the text-context shape so the declared
    # value always matches the embeddings actually fed to the model (previously a
    # hard-coded 512 that disagreed with the 768-wide text context).
    cross_attention_dim=input_shapes["textcontext"][-1],
    dtype=jnp.bfloat16,
    use_memory_efficient_attention=True,
)
        
class BCHWModelWrapper(nn.Module):
    """Adapter that calls a channels-first (BCHW) model with channels-last (BHWC) data.

    The wrapped `model` is invoked with the keyword arguments `sample`,
    `timesteps` and `encoder_hidden_states`, and must return an object that
    exposes a `.sample` attribute.
    """
    model: nn.Module

    @nn.compact
    def __call__(self, x, temb, textcontext):
        """Run the wrapped model on a BHWC input and return a BHWC output.

        Args:
            x: Input array in BHWC layout.
            temb: Value forwarded unchanged as the model's `timesteps`.
            textcontext: Value forwarded unchanged as `encoder_hidden_states`.

        Returns:
            The `.sample` field of the model output, transposed back to BHWC.
        """
        to_channels_first = (0, 3, 1, 2)  # BHWC -> BCHW
        to_channels_last = (0, 2, 3, 1)   # BCHW -> BHWC

        prediction = self.model(
            sample=jnp.transpose(x, to_channels_first),
            timesteps=temb,
            encoder_hidden_states=textcontext,
        )
        # Only the `.sample` field is propagated; any other output fields are dropped.
        return jnp.transpose(prediction.sample, to_channels_last)
    
# Wrap the diffusers UNet so it can be driven with channels-last inputs.
unet = BCHWModelWrapper(model=unet_model)

Calculated input shapes: {'x': (32, 32, 4), 'temb': (), 'textcontext': (77, 768)}


In [None]:
from flaxdiff.inputs import DiffusionInputConfig, ConditionalInputConfig

# Input description for the flaxdiff `Unet` variant built in this cell.
# NOTE(review): this rebinds `input_config` / `input_shapes`, which an earlier cell
# also defines with the same settings.
input_config = DiffusionInputConfig(
    sample_data_key='image',
    sample_data_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
    conditions=[
        ConditionalInputConfig(
            encoder=text_encoder,
            # Singular keyword, consistent with the executed cell earlier in the
            # notebook (the plural spelling used here before did not match it).
            conditioning_data_key='text',
            pretokenized=True,
            unconditional_input="",
            model_key_override="textcontext",
        )
    ]
)

# Per-input shapes computed with the autoencoder supplied.
input_shapes = input_config.get_input_shapes(
    autoencoder=autoencoder,
)

# NOTE(review): this rebinds `unet`, replacing the diffusers-based wrapper assigned
# earlier in the notebook; whichever cell ran last decides what the trainer receives.
unet = Unet(
    emb_features=256,
    feature_depths=[64, 64, 128, 256, 512],
    # Five entries (presumably one per feature depth — confirm): no attention for the
    # first, then four configs that differ only in `use_self_and_cross`.
    attention_configs=[None] + [
        {
            "heads": 8,
            "dtype": jnp.float32,
            "flash_attention": False,
            "use_projection": False,
            "use_self_and_cross": self_and_cross,
        }
        for self_and_cross in (True, True, True, False)
    ],
    num_res_blocks=2,
    num_middle_res_blocks=1,
    dtype=jnp.bfloat16,
    output_channels=input_shapes["x"][2],
)

In [4]:
# Noise schedules: the EDM schedule is handed to the trainer below, while the Karras VE
# schedule is passed later to `trainer.fit` as the sampling schedule.
# NOTE(review): the leading positional `1` is presumably the timestep count — confirm.
edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
# The model (`unet`) is not built here; it must already exist from an earlier cell.

# Define optimizer: Adam with a constant 2e-4 learning rate
solver = optax.adam(2e-4)

# Create the GeneralDiffusionTrainer
# The run name embeds the current wall-clock time, so each execution of this cell
# produces a different name. NOTE(review): the name contains ':' characters — verify
# that is acceptable wherever the trainer uses it (e.g. if it becomes part of a path).
experiment_name = f"General_Diffusion_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"

trainer = GeneralDiffusionTrainer(
    unet,
    optimizer=solver,
    noise_schedule=edm_schedule,
    autoencoder=autoencoder,
    input_config=input_config,
    rngs=jax.random.PRNGKey(42),  # fixed PRNG seed
    name=experiment_name,
    # sigma_data is read from the schedule so the transform and schedule share one value.
    model_output_transform=KarrasPredictionTransform(
        sigma_data=edm_schedule.sigma_data),
    # data_key='image',  # Specify the key for image data in batches
    distributed_training=True,
    # Weights & Biases settings; the run is named after the experiment.
    wandb_config={
        "project": 'mlops-msml605-project',
        "entity": 'umd-projects',
        "name": experiment_name,
        "config": {
            "batch_size": BATCH_SIZE,
            "image_size": IMAGE_SIZE,
            "architecture": "unet",
        }
    },
    native_resolution=IMAGE_SIZE
)

Calculated input shapes: {'x': (32, 32, 4), 'temb': (), 'textcontext': (77, 768)}


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mashishkumar4[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Generating states for DiffusionTrainer


In [None]:
# Launch training for 2000 epochs, passing the dataset and the precomputed `batches`
# count; sampling uses the Euler-ancestral sampler with the Karras VE schedule.
final_state = trainer.fit(
    data,
    batches,
    epochs=2000,
    sampler_class=EulerAncestralSampler,
    sampling_noise_schedule=karas_ve_schedule,
)

Using classifier-free guidance
Validation run for sanity check for process index 0


100%|██████████| 200/200 [00:26<00:00,  7.49it/s]


[32mSanity Validation done on process index 0[0m

Epoch 0/2000


		Epoch 0:   0%|                                                          | 0/511 [00:00<?, ?step/s]

First batch loaded at step 0


In [1]:
from diffusers.models import AutoencoderKL, FlaxAutoencoderKL
# Load the pretrained Flax VAE. At this point `vae` holds the pair that the next cell
# unpacks into (module, params), not the module alone.
vae = FlaxAutoencoderKL.from_pretrained("pcuenq/sd-vae-ft-mse-flax")


  from .autonotebook import tqdm as notebook_tqdm
2025-04-10 19:43:23.697291: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744314203.721345 2391839 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744314203.728729 2391839 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744314203.746453 2391839 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744314203.746479 2391839 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744314203.746481 2391839

In [2]:
# Split the loaded pair into the module and its parameters. The guard keeps this cell
# re-runnable: once `vae` has been rebound to the module, running the cell again would
# otherwise try to unpack the module itself and fail.
if isinstance(vae, tuple):
    vae, params = vae

In [3]:
# Inspect the 'encoder' sub-tree of the loaded parameters.
# NOTE(review): this displays the raw weight arrays in full; consider showing only the
# structure/shapes to keep the notebook output readable.
params['encoder']

{'conv_in': {'bias': Array([ 4.96427529e-03,  2.53305554e-01,  2.12010033e-02, -1.41209185e-01,
          2.86090141e-03,  3.08178291e-02,  1.43235341e-01, -1.40822262e-01,
         -4.74144965e-02,  3.02515477e-02, -1.41965076e-02, -7.25287050e-02,
          1.98573947e-01,  5.43444715e-02, -1.91042889e-02, -2.30714515e-01,
         -1.20401338e-01,  2.07467690e-01,  9.51506793e-02,  1.60127580e-01,
          5.96853942e-02, -9.90543813e-02, -1.66639537e-01,  4.43165004e-03,
          2.21085712e-01, -2.34749496e-01, -2.44124085e-01,  1.86373934e-01,
          1.93680525e-01, -1.41324401e-01,  3.98991592e-02,  2.00074136e-01,
         -1.56219065e-01, -6.38444722e-02, -4.58901636e-02, -5.38855419e-02,
         -2.07794979e-02,  2.32121378e-01, -2.38970608e-01,  2.52936959e-01,
          8.12750682e-02, -1.83061007e-02, -1.70167655e-01,  1.51383013e-01,
         -3.59456018e-02, -1.67356193e-01,  2.28695408e-01, -2.18679756e-01,
          8.68279785e-02, -1.70631543e-01,  4.13337797e-0

In [7]:
# Display the module's repr, which lists its configuration attributes.
vae

FlaxAutoencoderKL(
    # attributes
    in_channels = 3
    out_channels = 3
    down_block_types = ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D']
    up_block_types = ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D']
    block_out_channels = [128, 256, 512, 512]
    layers_per_block = 2
    act_fn = 'silu'
    latent_channels = 4
    norm_num_groups = 32
    sample_size = 256
    scaling_factor = 0.18215
    dtype = float32
)

In [8]:
# Quick arithmetic check: 256 floor-divided by 8.
# NOTE(review): presumably image size over the VAE downscale factor, i.e. the latent
# side length — confirm, and prefer named constants over bare literals.
256//8

32