In [1]:
from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
from flaxdiff.predictors import KarrasPredictionTransform
from flaxdiff.models.simple_unet import Unet
from flaxdiff.trainer import DiffusionTrainer
from flaxdiff.data.datasets import get_dataset_grain
from flaxdiff.utils import defaultTextEncodeModel
from flaxdiff.samplers.euler import EulerAncestralSampler
import jax
import jax.numpy as jnp
import optax
from datetime import datetime

# Samples per batch: handed to the dataset loader and used to derive the
# per-epoch batch count (train_len // BATCH_SIZE).
BATCH_SIZE = 16
# Image size: forwarded to the loader as `image_scale` and used for both
# spatial dimensions of the "x" entry in the model input shapes.
IMAGE_SIZE = 128

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Fetch the Oxford Flowers 102 dataset bundle and work out the batch count per epoch
data = get_dataset_grain(
    "oxford_flowers102",
    batch_size=BATCH_SIZE,
    image_scale=IMAGE_SIZE,
)
datalen = data['train_len']
batches = datalen // BATCH_SIZE  # floor division: a trailing partial batch is not counted

# Shapes passed to the trainer for each named model input.
# NOTE(review): (77, 768) presumably matches the text encoder's
# (sequence length, embedding width) -- confirm against the encoder in use.
input_shapes = dict(
    x=(IMAGE_SIZE, IMAGE_SIZE, 3),
    temb=(),
    textcontext=(77, 768),
)

In [4]:
# Create the library's default text encoder. Per the load message this cell
# prints, it initializes a FlaxCLIPTextModel from the
# openai/clip-vit-large-patch14 checkpoint; the "weights not used" warning
# lists the checkpoint's vision_model parameters.
text_encoder = defaultTextEncodeModel()

Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing FlaxCLIPTextModel: {('vision_model', 'encoder', 'layers', '18', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers',

In [5]:
# Fixed text prompts used to build the validation set.
# Entries are kept verbatim, including the leading space most of them carry.
val_prompts = [
    'water tulip', ' a water lily', ' a water lily', ' a photo of a rose',
    ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold',
    ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower',
    ' a photo of a lotus', ' columbine', ' columbine', ' an orchid',
    ' an orchid', ' an orchid', ' a water lily', ' a water lily',
    ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower',
    ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus',
    ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose',
    ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose',
    ' a lenten rose', ' a water lily', ' a water lily', ' a water lily',
    ' a water lily', ' an orchid', ' an orchid', ' an orchid',
    ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose',
    ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose',
    ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose',
    ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily',
    ' a water lily', ' a water lily', ' a red rose', ' a red rose',
]

def get_val_dataset(batch_size=8):
    """Yield tokenized validation prompts one batch at a time.

    Walks the module-level `val_prompts` list in consecutive slices of
    `batch_size` entries and passes each slice to `text_encoder.tokenize`.
    The last slice is shorter when the prompt count is not a multiple of
    `batch_size`.

    Args:
        batch_size: Maximum number of prompts in each yielded batch.

    Yields:
        The value returned by `text_encoder.tokenize` for each slice.
    """
    offsets = range(0, len(val_prompts), batch_size)
    for start in offsets:
        yield text_encoder.tokenize(val_prompts[start:start + batch_size])

# Attach the validation prompts to `data`, which is later passed to trainer.fit.
# 'test' stores the generator function itself (not an already-started
# generator); 'test_len' is the number of individual prompts, not batches.
data['test'] = get_val_dataset
data['test_len'] = len(val_prompts)

In [None]:
# Noise schedules with identical hyperparameters: the EDM one is given to the
# trainer as its noise schedule, the Karras VE one to fit() for sampling.
edm_schedule = EDMNoiseScheduler(
    1,
    sigma_max=80,
    rho=7,
    sigma_data=0.5,
)
karas_ve_schedule = KarrasVENoiseScheduler(
    1,
    sigma_max=80,
    rho=7,
    sigma_data=0.5,
)

# UNet definition: five feature depths with one attention entry per depth.
# The first depth has no attention config; the remaining four share every
# setting except `use_self_and_cross`, which is switched off for the last one.
unet = Unet(
    emb_features=256,
    feature_depths=[64, 64, 128, 256, 512],
    attention_configs=[None] + [
        {
            "heads": 8,
            "dtype": jnp.float32,
            "flash_attention": False,
            "use_projection": False,
            "use_self_and_cross": self_and_cross,
        }
        for self_and_cross in (True, True, True, False)
    ],
    num_res_blocks=2,
    num_middle_res_blocks=1,
)

In [6]:
# Define optimizer: optax Adam, called with 2e-4 as its only argument
# (the learning rate); all other Adam settings are left at optax defaults.
solver = optax.adam(2e-4)
# Unique, timestamped run name shared by the trainer and the wandb run.
# The time component is written as HH-MM-SS rather than HH:MM:SS because ':'
# is a reserved character in Windows file names and is awkward for path-based
# tooling. NOTE(review): presumably the trainer also derives on-disk
# checkpoint/artifact paths from this name -- confirm.
name = "prototype-edm-Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Assemble the diffusion trainer from the UNet, optimizer and EDM noise schedule
trainer = DiffusionTrainer(
    unet,
    optimizer=solver,
    input_shapes=input_shapes,
    noise_schedule=edm_schedule,
    rngs=jax.random.PRNGKey(4),
    name=name,
    # The output transform is built with the same sigma_data as the training schedule
    model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data),
    encoder=text_encoder,
    distributed_training=True,
    # Weights & Biases settings; the run is named identically to the trainer
    wandb_config=dict(
        entity="umd-projects",
        project='mlops-msml605-project',
        name=name,
    ),
)


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mashishkumar4[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Generating states for DiffusionTrainer


In [None]:
# Launch training for up to 2000 epochs of `batches` steps each; the
# Euler-ancestral sampler class and the Karras VE schedule are passed to
# fit() as its sampling configuration.
final_state = trainer.fit(
    data,
    batches,
    epochs=2000,
    sampler_class=EulerAncestralSampler,
    sampling_noise_schedule=karas_ve_schedule,
)

		Epoch 176: 600step [00:33, 18.04step/s, loss=0.0591]                                              

[32mEpoch done on index 0 => 176 Loss: 0.07166442275047302[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 176 completed. Avg Loss: 0.07166442275047302, Time: 33.26s, Best Loss: 0.07039067149162292[0m
Validation started for process index 0



100%|██████████| 200/200 [00:04<00:00, 43.10it/s]


[32mValidation done on process index 0[0m

Epoch 177/2000


		Epoch 177:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0670]

First batch loaded at step 90447
Training started for process index 0 at step 90447


		Epoch 177: 600step [00:32, 18.30step/s, loss=0.0723]                                              

[32mEpoch done on index 0 => 177 Loss: 0.07239051908254623[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 177 completed. Avg Loss: 0.07239051908254623, Time: 32.78s, Best Loss: 0.07039067149162292[0m
Validation started for process index 0



100%|██████████| 200/200 [00:04<00:00, 42.76it/s]


[32mValidation done on process index 0[0m

Epoch 178/2000


		Epoch 178:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0810]

First batch loaded at step 90958
Training started for process index 0 at step 90958


		Epoch 178: 600step [00:32, 18.46step/s, loss=0.0522]                                              

[32mEpoch done on index 0 => 178 Loss: 0.07139013707637787[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 178 completed. Avg Loss: 0.07139013707637787, Time: 32.51s, Best Loss: 0.07039067149162292[0m
Validation started for process index 0



100%|██████████| 200/200 [00:04<00:00, 42.77it/s]


[32mValidation done on process index 0[0m

Epoch 179/2000


		Epoch 179:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0817]

First batch loaded at step 91469
Training started for process index 0 at step 91469


		Epoch 179: 600step [00:33, 18.00step/s, loss=0.0707]                                              

[32mEpoch done on index 0 => 179 Loss: 0.07138639688491821[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 179 completed. Avg Loss: 0.07138639688491821, Time: 33.34s, Best Loss: 0.07039067149162292[0m
Validation started for process index 0



100%|██████████| 200/200 [00:04<00:00, 42.89it/s]


[32mValidation done on process index 0[0m

Epoch 180/2000


		Epoch 180:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0795]

First batch loaded at step 91980
Training started for process index 0 at step 91980


		Epoch 180:  59%|███████████████████▎             | 300/511 [00:13<00:10, 20.85step/s, loss=0.0969]