In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torchvision.utils as vutils
from imagen_pytorch import Unet, Imagen, ImagenTrainer

from backend import workflow

In [None]:
# Configuration: training data location, checkpoint file, and RNG seed.

source_directory = 'datasets/dimensi0n/imagenet-256/zebra'
using_checkpoint = 'checkpoints/a_model.pt'  # read by the load cell, written by the save cell

# Seed torch's global RNG (all devices) before the dataloader and the models are built,
# so weight initialization and, presumably, diffusion noise / any loader shuffling repeat
# across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# The training loop unpacks each batch as (images, text_embeds).
# NOTE(review): exact shapes/dtypes come from backend.workflow; confirm there.
dataloader = workflow(source_directory)

In [None]:
# Imagen & Unets

# Settings common to both U-Nets; each Unet(...) below adds its own depth / attention layout.
shared_unet_kwargs = dict(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
)

# Base U-Net: first stage of the cascade (first entry of image_sizes below, 64px).
unet1 = Unet(
    **shared_unet_kwargs,
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
)

# Super-resolution U-Net: second stage of the cascade (second entry of image_sizes, 256px).
unet2 = Unet(
    **shared_unet_kwargs,
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True),
)

# Imagen cascade wrapping the base U-Net and the super-resolution U-Net above.
# cond_drop_prob randomly drops the text conditioning during training, which is what
# allows guidance via cond_scale at sampling time.
imagen = Imagen(
    unets = (unet1, unet2),
    image_sizes = (64, 256),
    text_encoder_name = 't5-large',
    timesteps = 1000,
    cond_drop_prob = 0.1,
)
imagen = imagen.cuda()

In [None]:
#Preparing imagen, optimizer and trainer

imagen.train()
trainer = ImagenTrainer(imagen).cuda()

optimizer = torch.optim.Adam(imagen.parameters(), lr=1e-4) # Example optimizer, not sure if correct one

In [None]:
# Load the model

trainer.load(using_checkpoint)

In [None]:
# Code for running the training

num_epochs = 1
unet_toTrain = 2    # Can't train both unets at the same time, need to save and reload whole code before changing unets

#Training in epochs, each unet is trained once in each epoch

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}, Training UNet {unet_toTrain}")
    for images_batch, text_embeds_batch in dataloader:
        if images_batch.numel() == 0: continue # Skip if batch is empty after filtering

        images_batch = images_batch.cuda()
        text_embeds_batch = text_embeds_batch.cuda()

        optimizer.zero_grad()
        loss = trainer(images_batch, text_embeds = text_embeds_batch, unet_number = unet_toTrain)
        trainer.update(unet_number = unet_toTrain)
        print(f"Loss: {loss}")
        optimizer.step()

In [None]:
# Save the model.
# Create the checkpoint's parent directory first. trainer.save presumably opens the file for
# writing without creating missing directories, so a first run would otherwise fail here.
Path(using_checkpoint).parent.mkdir(parents = True, exist_ok = True)
trainer.save(using_checkpoint)

In [None]:
# Generate images for the prompts below (classifier-free guidance scale 3).

prompts = ['two zebras in a field']
images = trainer.sample(texts = prompts, cond_scale = 3.)

# --- Display the generated images side by side ---

# Detach from autograd and move to the CPU for plotting.
images_cpu = images.detach().cpu()

# NOTE(review): Imagen un-normalizes its samples back to [0, 1] when auto_normalize_img is
# left at its default, so no [-1, 1] -> [0, 1] rescaling is applied; confirm if that changes.
# The clamp only trims small numerical overshoot.
images_cpu = torch.clamp(images_cpu, 0, 1)

# One row containing every sample, 2px padding between them (already in [0, 1], so no normalize).
grid = vutils.make_grid(images_cpu, nrow = len(images_cpu), padding = 2, normalize = False)

fig, ax = plt.subplots(figsize = (10, 5))
ax.imshow(grid.permute(1, 2, 0).numpy())  # (C, H, W) -> (H, W, C) for matplotlib
ax.set_title(' | '.join(prompts))
ax.axis('off')
plt.show()