In [1]:

from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
from flaxdiff.predictors import KarrasPredictionTransform
from flaxdiff.models.simple_unet import Unet
from flaxdiff.trainer.general_diffusion_trainer import GeneralDiffusionTrainer, ConditionalInputConfig
from flaxdiff.data.dataloaders import get_dataset_grain
from flaxdiff.utils import defaultTextEncodeModel, get_latest_checkpoint
from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE
from flaxdiff.samplers.euler import EulerAncestralSampler
import jax
import jax.numpy as jnp
import optax
from datetime import datetime
import argparse
import os

BATCH_SIZE = 32
IMAGE_SIZE = 256

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load dataset
data = get_dataset_grain(
    "oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)

# data = get_dataset_grain(
#     "laiona_coco",
#     batch_size=BATCH_SIZE,
#     image_scale=IMAGE_SIZE,
#     dataset_source="/home/mrwhite0racle/gcs_mount",
#     method=None
# )
datalen = data['train_len']
batches = datalen // BATCH_SIZE

text_encoder = defaultTextEncodeModel()
autoencoder = StableDiffusionVAE(**{"modelname": "pcuenq/sd-vae-ft-mse-flax"})

2025-05-02 00:10:18.463312: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746144618.487117  317086 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746144618.494299  317086 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746144618.511533  317086 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746144618.511552  317086 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746144618.511554  317086 computation_placer.cc:177] computation placer alr

Scaling factor: 0.18215
Calculating downscale factor...
Downscale factor: 8
Latent channels: 4


In [3]:
from flax import linen as nn
from diffusers import FlaxUNet2DConditionModel
from flaxdiff.inputs import DiffusionInputConfig, ConditionalInputConfig

input_config = DiffusionInputConfig(
    sample_data_key='image',
    sample_data_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
    conditions=[
        ConditionalInputConfig(
            encoder=text_encoder,
            conditioning_data_key='text',
            pretokenized=True,
            unconditional_input="",
            model_key_override="textcontext",
        )
    ],
)

input_shapes = input_config.get_input_shapes(
    autoencoder=autoencoder,
)

unet_model = FlaxUNet2DConditionModel(
    sample_size=input_shapes["x"][1],  # the target image resolution
    # the number of input channels, 3 for RGB images
    in_channels=input_shapes["x"][2],
    out_channels=input_shapes["x"][2],  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    # the number of output channels for each UNet block
    block_out_channels=(64, 128, 256, 512),
    cross_attention_dim=512,  # the size of the cross-attention layers
    dtype=jnp.bfloat16,
    use_memory_efficient_attention=True,
)


class BCHWModelWrapper(nn.Module):
    model: nn.Module

    @nn.compact
    def __call__(self, x, temb, textcontext):
        # Reshape the input to BCHW format from BHWC
        x = jnp.transpose(x, (0, 3, 1, 2))
        # Pass the input through the UNet model
        out = self.model(
            sample=x,
            timesteps=temb,
            encoder_hidden_states=textcontext,
        )
        # Reshape the output back to BHWC format
        out = jnp.transpose(out.sample, (0, 2, 3, 1))
        return out
    
    @property
    def __dict__(self):
        return self.model.__dict__

unet = BCHWModelWrapper(unet_model)

Calculated input shapes: {'x': (32, 32, 4), 'temb': (), 'textcontext': (77, 768)}


In [3]:
from flaxdiff.inputs import DiffusionInputConfig, ConditionalInputConfig

input_config = DiffusionInputConfig(
    sample_data_key='image',
    sample_data_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
    conditions=[
        ConditionalInputConfig(
            encoder=text_encoder,
            conditioning_data_key='text',
            pretokenized=True,
            unconditional_input="",
            model_key_override="textcontext",
        )
    ]
)

input_shapes = input_config.get_input_shapes(
    autoencoder=autoencoder,
)

unet = Unet(emb_features=256,
            feature_depths=[64, 64, 128, 256, 512],
            attention_configs=[
                None,
                {"heads": 8, "dtype": jnp.float32, "flash_attention": False,
                    "use_projection": False, "use_self_and_cross": True},
                {"heads": 8, "dtype": jnp.float32, "flash_attention": False,
                    "use_projection": False, "use_self_and_cross": True},
                {"heads": 8, "dtype": jnp.float32, "flash_attention": False,
                    "use_projection": False, "use_self_and_cross": True},
                {"heads": 8, "dtype": jnp.float32, "flash_attention": False,
                    "use_projection": False, "use_self_and_cross": False}
            ],
            num_res_blocks=2,
            num_middle_res_blocks=1,
            dtype=jnp.bfloat16,
            output_channels=input_shapes["x"][2],
    )

Calculated input shapes: {'x': (32, 32, 4), 'temb': (), 'textcontext': (77, 768)}


In [None]:
from flaxdiff.trainer.general_diffusion_trainer import EvaluationMetric
    
def get_clip_metric(
    modelname: str = "openai/clip-vit-large-patch14",
):
    from transformers import AutoProcessor, FlaxCLIPModel
    model = FlaxCLIPModel.from_pretrained(modelname)
    processor = AutoProcessor.from_pretrained(modelname)
    
    @jax.jit
    def calc(pixel_values, input_ids, attention_mask):
        # Get the logits
        generated_out = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
            
        gen_img_emb = generated_out.image_embeds
        txt_emb = generated_out.text_embeds

        # 1. Normalize embeddings (essential for cosine similarity/distance)
        gen_img_emb = gen_img_emb / (jnp.linalg.norm(gen_img_emb, axis=-1, keepdims=True) + 1e-6)
        txt_emb = txt_emb / (jnp.linalg.norm(txt_emb, axis=-1, keepdims=True) + 1e-6)

        # 2. Calculate cosine similarity
        # Using einsum for batch dot product: batch (b), embedding_dim (d) -> bd,bd->b
        # Calculate cosine similarity
        similarity = jnp.einsum('bd,bd->b', gen_img_emb, txt_emb)

        scaled_distance = (1.0 - similarity)
        # 4. Average over the batch
        mean_scaled_distance = jnp.mean(scaled_distance)

        return mean_scaled_distance
        
    def clip_metric(
        generated: jnp.ndarray,
        batch
    ):
        original_conditions = batch['text']
        
        # Convert samples from [-1, 1] to [0, 255] and uint8
        generated = (((generated + 1.0) / 2.0) * 255).astype(jnp.uint8)
        
        generated_inputs = processor(images=generated, return_tensors="jax", padding=True,)
        
        pixel_values = generated_inputs['pixel_values']
        input_ids = original_conditions['input_ids']
        attention_mask = original_conditions['attention_mask']
        
        return calc(pixel_values, input_ids, attention_mask)
    
    return EvaluationMetric(
        function=clip_metric,
        name='clip_similarity'
    )

In [5]:
# Define noise scheduler
edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
karas_ve_schedule = KarrasVENoiseScheduler(
    1, sigma_max=80, rho=7, sigma_data=0.5)
# Define model

# Define optimizer
solver = optax.adam(2e-4)

# Create the GeneralDiffusionTrainer
experiment_name = f"oxford-LDM-General_Diffusion_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"

trainer = GeneralDiffusionTrainer(
    unet,
    optimizer=solver,
    noise_schedule=edm_schedule,
    autoencoder=autoencoder,
    input_config=input_config,
    rngs=jax.random.PRNGKey(42),
    name=experiment_name,
    model_output_transform=KarrasPredictionTransform(
        sigma_data=edm_schedule.sigma_data),
    # data_key='image',  # Specify the key for image data in batches
    distributed_training=True,
    eval_metrics=[
        get_clip_metric(modelname="openai/clip-vit-large-patch14"),
    ],
    wandb_config={
        "project": 'mlops-msml605-project',
        "entity": 'umd-projects',
        "name": experiment_name,
        "config": {
            "batch_size": BATCH_SIZE,
            "image_size": IMAGE_SIZE,
            "arguments": {
                "architecture": "unet",
                "dataset": "oxford_flowers102",
                "noise_schedule": "edm",
            }
        }
    },
    native_resolution=IMAGE_SIZE,
    # Path to the checkpoint
    # load_from_checkpoint="/home/mrwhite0racle/persist/FlaxDiff/checkpoints/general_diffusion_2025-04-18_06:34:50",
)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Calculated input shapes: {'x': (32, 32, 4), 'temb': (), 'textcontext': (77, 768)}
Model name: diffusion-oxford_flowers102-res256


[34m[1mwandb[0m: Currently logged in as: [33mashishkumar4[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Generating states for DiffusionTrainer


In [None]:
# Train the model
final_state = trainer.fit(data, batches, epochs=2000,
                          sampler_class=EulerAncestralSampler, sampling_noise_schedule=karas_ve_schedule, val_steps_per_epoch=2)

Using classifier-free guidance
Validation run for sanity check for process index 0


100%|██████████| 200/200 [00:59<00:00,  3.35it/s]


Evaluation started for process index 0


100%|██████████| 200/200 [00:17<00:00, 11.23it/s]


[32mSanity Validation done on process index 0[0m

Epoch 0/2000


		Epoch 0:   0%|                                                          | 0/255 [00:00<?, ?step/s]

Training started for process index 0 at step 0


		Epoch 0: 300step [07:16,  1.45s/step, loss=0.6174]                                                

[32mEpoch done on index 0 => 0 Loss: 0.642382800579071[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 0 step 255





[32m
	Epoch 0 completed. Avg Loss: 0.642382800579071, Time: 436.18s, Best Loss: 0.642382800579071[0m
Validation started for process index 0


100%|██████████| 200/200 [00:57<00:00,  3.46it/s]


Evaluation started for process index 0


100%|██████████| 200/200 [00:06<00:00, 30.90it/s]


[32mValidation done on process index 0[0m

Epoch 1/2000


		Epoch 1:   0%|                                                          | 0/255 [00:00<?, ?step/s]

Training started for process index 0 at step 255


		Epoch 1: 300step [00:36,  8.18step/s, loss=0.5337]                                                

[32mEpoch done on index 0 => 1 Loss: 0.5769686698913574[0m
[32mEpoch done on process index 0[0m





Saving model at epoch 1 step 510




[32m
	Epoch 1 completed. Avg Loss: 0.5769686698913574, Time: 36.69s, Best Loss: 0.5769686698913574[0m
Validation started for process index 0


100%|██████████| 200/200 [00:06<00:00, 28.61it/s]


Evaluation started for process index 0


100%|██████████| 200/200 [00:06<00:00, 31.02it/s]


In [6]:
!pip freeze | grep jax

  pid, fd = os.forkpty()


ValueError: filedescriptor out of range in select()

In [5]:
import matplotlib.pyplot as plt


def normalizeImage(x): return jax.nn.standardize(x, mean=[127.5], std=[127.5])
def denormalizeImage(x): return (x + 1.0) * 127.5


def plotImages(imgs, fig_size=(8, 8), dpi=100):
    fig = plt.figure(figsize=fig_size, dpi=dpi)
    imglen = imgs.shape[0]
    for i in range(imglen):
        plt.subplot(fig_size[0], fig_size[1], i + 1)
        plt.imshow(jnp.astype(denormalizeImage(imgs[i, :, :, :]), jnp.uint8))
        plt.axis("off")
    plt.show()

In [6]:
sampler = EulerAncestralSampler(
    model=trainer.model,
    noise_schedule=karas_ve_schedule,
    model_output_transform=KarrasPredictionTransform(
        sigma_data=karas_ve_schedule.sigma_data),
    autoencoder=trainer.autoencoder,
    input_config=trainer.input_config,
    guidance_scale=3,
    timestep_spacing="linear"
)

Using classifier-free guidance


In [7]:
prompts = [
    'water tulip',
    'a water lily',
    'a water lily',
    'a photo of a rose',
    'a photo of a rose',
    'a water lily',
    'a water lily',
    'a photo of a marigold',
]
images = sampler.generate_samples(
    params=trainer.best_state.params,
    resolution=IMAGE_SIZE,
    num_samples=len(prompts),
    sequence_length=None,
    diffusion_steps=200,
    start_step=1000,
    end_step=0,
    conditioning=prompts,
    # model_conditioning_inputs=(encoded,)
)

Processing raw conditioning inputs to generate model conditioning inputs


100%|██████████| 200/200 [00:15<00:00, 13.01it/s]


In [5]:
plotImages(images, dpi=500)

NameError: name 'plotImages' is not defined

In [7]:
trainer.checkpoint_path()

'/home/mrwhite0racle/persist/FlaxDiff/checkpoints/general_diffusion_demo_for_inference'

In [5]:
trainer.push_to_registry()

[34m[1mwandb[0m: Adding directory to artifact (/home/mrwhite0racle/persist/FlaxDiff/checkpoints/general_diffusion_2025-04-18_06:34:50/411355)... Done. 8.9s


Model pushed to registry at wandb-registry-model/diffusion-oxford_flowers102-res256


<Artifact QXJ0aWZhY3Q6MTY2ODMwMTg0NA==>

In [1]:
trainer.wandb.run

NameError: name 'trainer' is not defined