In [1]:
from backend import workflow
from imagen_pytorch import Unet, Imagen, ImagenTrainer
import torch
from IPython.display import clear_output

  from .autonotebook import tqdm as notebook_tqdm
2025-04-18 10:16:51.874547: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-18 10:16:52.008501: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744964212.062044     415 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744964212.076451     415 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-18 10:16:52.225345: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorF

In [2]:
# Checkpoint file used by the load and save cells further down.
using_checkpoint = 'checkpoints/a_model.pt'

# Folder of training images handed to the backend workflow.
source_directory = 'datasets/dimensi0n/imagenet-256/zebra'

# `workflow` comes from the local `backend` module. The training loops below
# iterate its return value and unpack each item as
# (images_batch, text_embeds_batch).
dataloader = workflow(source_directory)

Loading T5 Tokenizer and Encoder Model (t5-large)...
T5 models loaded.


In [3]:
# Imagen & Unets

# Keyword arguments that are identical for both U-Nets.
_shared_unet_kwargs = dict(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
)

# First U-Net in the cascade (the base unet).
unet1 = Unet(
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    **_shared_unet_kwargs,
)

# Second U-Net in the cascade (the super-resolution unet).
unet2 = Unet(
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True),
    **_shared_unet_kwargs,
)

# Imagen wraps both U-Nets; `image_sizes` lists one size per unet, in order.
imagen = Imagen(
    unets = (unet1, unet2),
    text_encoder_name = 't5-large',
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1,
)

# Move the assembled model to the GPU (requires a CUDA device).
imagen = imagen.cuda()

The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/


In [12]:
# Preparing imagen and trainer

# Switch the model to training mode before wrapping it.
imagen.train()

# Build the trainer around the model, then move it to the GPU.
trainer = ImagenTrainer(imagen)
trainer = trainer.cuda()
trainer.train()

# Drop whatever the calls above printed so the cell output stays empty.
clear_output()

In [13]:
# Load the model

# Restores the trainer from the checkpoint file named by `using_checkpoint`
# (assigned in the configuration cell), then clears anything the call printed.
# NOTE(review): presumably this fails if the checkpoint file does not exist
# yet -- confirm before running on a fresh setup.
trainer.load(using_checkpoint)
clear_output()

In [None]:
# Code for running the training
#
# Trains ONE unet (selected by `unet_toTrain`) for `num_epochs` passes over
# `dataloader`, printing the loss of every batch and the average loss of the
# previous epoch at the start of each new one.

num_epochs = 10
unet_toTrain = 2    # Can't train both unets at the same time, need to save checkpoint and re-initiate the trainer when changing unets

# Average loss of the most recently finished epoch; None until one has finished,
# so the first epoch does not report a misleading 0.
prev_epoch_avg_loss = None

for epoch in range(num_epochs):
    clear_output()
    prev_text = "n/a" if prev_epoch_avg_loss is None else prev_epoch_avg_loss
    print(f"Epoch {epoch+1}/{num_epochs}, Training UNet {unet_toTrain}, Prev epoch avg loss: {prev_text}")

    losses = 0.0
    batches_seen = 0    # only batches that were actually trained on
    for images_batch, text_embeds_batch in dataloader:
        if images_batch.numel() == 0: continue # Skip if batch is empty after filtering

        images_batch = images_batch.cuda()
        text_embeds_batch = text_embeds_batch.cuda()

        loss = trainer(images_batch, text_embeds = text_embeds_batch, unet_number = unet_toTrain)
        losses += loss
        batches_seen += 1
        trainer.update(unet_number = unet_toTrain)
        print(f"Loss: {loss}")

    # Average over the batches really processed (skipped batches must not
    # dilute the mean); NaN flags an epoch in which nothing was trained.
    prev_epoch_avg_loss = losses / batches_seen if batches_seen else float("nan")

# The loop only reports an epoch's average when the next epoch starts, so the
# final epoch has to be reported here.
if prev_epoch_avg_loss is not None:
    print(f"Last epoch avg loss: {prev_epoch_avg_loss}")

In [7]:
# Save the model
# Writes the trainer state to the path in `using_checkpoint`, the same path the
# load cell reads from.
# NOTE(review): presumably an existing checkpoint at that path is replaced --
# confirm if older checkpoints need to be kept.
trainer.save(using_checkpoint)

checkpoint saved to checkpoints/a_model.pt


In [None]:
# Workflow code for less supervised training
#
# Runs `rounds` training rounds of `num_epochs_perRound` epochs each. Every
# round rebuilds the trainer, reloads the checkpoint, trains a single unet and
# saves the checkpoint again. After all rounds it prints the per-round average
# loss, samples a few images and displays them in a grid.

import matplotlib.pyplot as plt
import torchvision.utils as vutils

num_epochs_perRound = 120
unet_toTrain = 1

rounds = 5
change_unet_afterRound = 2    # switch to UNet 2 once this many rounds have completed

rounds_scores = []    # one entry per round: mean of that round's epoch-average losses

for round_idx in range(rounds):
    # Compare against the round index itself (not a stale/undefined `i`).
    if round_idx >= change_unet_afterRound:
        unet_toTrain = 2

    # Only one unet can be trained per trainer, so the trainer is re-created
    # each round and its state restored from the checkpoint saved last round.
    imagen.train()
    trainer = ImagenTrainer(imagen).cuda()
    trainer.train()
    clear_output()

    trainer.load(using_checkpoint)
    clear_output()

    prev_epoch_avg = None    # None until the first epoch of this round finishes
    round_losses = 0.0
    for epoch in range(num_epochs_perRound):
        clear_output()
        prev_text = "n/a" if prev_epoch_avg is None else prev_epoch_avg
        print(f"Round {round_idx+1}, Epoch {epoch+1}/{num_epochs_perRound}, Training UNet {unet_toTrain}, Prev epoch avg loss: {prev_text}")

        losses = 0.0
        batches_seen = 0    # only batches that were actually trained on
        for images_batch, text_embeds_batch in dataloader:
            if images_batch.numel() == 0: continue

            images_batch = images_batch.cuda()
            text_embeds_batch = text_embeds_batch.cuda()

            loss = trainer(images_batch, text_embeds = text_embeds_batch, unet_number = unet_toTrain)
            losses += loss
            batches_seen += 1
            trainer.update(unet_number = unet_toTrain)
            print(f"Loss: {loss}")

        # Accumulate THIS epoch's average after it has run, so the round score
        # covers every epoch (including the last) and no spurious leading zero.
        prev_epoch_avg = losses / batches_seen if batches_seen else float("nan")
        round_losses += prev_epoch_avg

    if num_epochs_perRound > 0:
        print(f"Last epoch avg loss: {prev_epoch_avg}")
        rounds_scores.append(round_losses/num_epochs_perRound)
    trainer.save(using_checkpoint)

clear_output()
for i, score in enumerate(rounds_scores):
    print(f"Round {i+1} avg loss: {score}")

# Sample a few images from the trained model and show them side by side.
images = trainer.sample(texts = [
    'two zebras in a field',
    'two zebras in a zoo',
    'the zebra is black and white'
], cond_scale = 3.)

images_cpu = images.cpu().detach()
images_cpu = torch.clamp(images_cpu, 0, 1)
grid = vutils.make_grid(images_cpu, nrow=len(images_cpu), padding=2, normalize=False)
# (C, H, W) -> (H, W, C), the layout matplotlib expects
np_grid = grid.permute(1, 2, 0).numpy()

plt.figure(figsize=(10, 5))
plt.imshow(np_grid)
plt.axis('off')
plt.show()

Epoch 30/120, Training UNet 2, Prev epoch avg loss: 0.020036027427833482
Loss: 0.04046138748526573
Loss: 0.029270917177200317
Loss: 0.041758596897125244
Loss: 0.013828713446855545
Loss: 0.01839420758187771
Loss: 0.010121184401214123
Loss: 0.014673742465674877
Loss: 0.006778263952583075
Loss: 0.014876039698719978
Loss: 0.02069457247853279
Loss: 0.00493717473000288
Loss: 0.021929875016212463
Loss: 0.02909013256430626
Loss: 0.011233475059270859
Loss: 0.023960234597325325
Loss: 0.02447344921529293
Loss: 0.013372305780649185
Loss: 0.005941904615610838
Loss: 0.0025979834608733654
Loss: 0.016660477966070175
Loss: 0.009339967742562294
Loss: 0.01728961616754532
Loss: 0.01661369390785694
Loss: 0.01172249112278223
Loss: 0.026525042951107025
Loss: 0.02476748265326023
Loss: 0.037353239953517914
Loss: 0.02223319560289383
Loss: 0.009747845120728016
Loss: 0.014803659170866013
Loss: 0.017430288717150688
Loss: 0.014379195868968964
Loss: 0.029190320521593094
Loss: 0.014646803960204124
Loss: 0.02847934886

KeyboardInterrupt: 

In [None]:
# Generate images

import matplotlib.pyplot as plt
import torchvision.utils as vutils

# Text prompts to condition the samples on, one image per prompt.
prompts = [
    'two zebras in a field',
    'two zebras in a zoo',
    'the zebra is black and white'
]
images = trainer.sample(texts = prompts, cond_scale = 3.)

# Code for displaying the generated images

# Detach from autograd and bring the samples to host memory.
images_cpu = images.detach().cpu()

# If the images are normalized to [-1, 1], map them back to [0, 1] first:
# images_cpu = (images_cpu + 1) / 2

# Force every value into [0, 1] before plotting.
images_cpu = images_cpu.clamp(min=0, max=1)

# Lay all samples out in a single row; no extra normalization because the
# values were clamped above.
samples_per_row = len(images_cpu)
grid = vutils.make_grid(images_cpu, nrow=samples_per_row, padding=2, normalize=False)

# matplotlib wants (H, W, C); the grid tensor is (C, H, W).
np_grid = grid.permute(1, 2, 0).numpy()

# Draw the grid without axis decorations.
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(np_grid)
ax.axis('off')
plt.show()