In [2]:
from backend import workflow
from imagen_pytorch import Unet, Imagen, ImagenTrainer
import torch
from IPython.display import clear_output
import gc
import time

config.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

In [None]:
# Fetch the dataset(s) through the project's download helper.
# Per the output recorded below, this downloads the Kaggle
# `dimensi0n/imagenet-256` archive (~7 GB) and extracts it.
from download_dataset import update_datasets
update_datasets()

Downloading from https://www.kaggle.com/api/v1/datasets/download/dimensi0n/imagenet-256?dataset_version_number=1...


100%|██████████| 7.15G/7.15G [00:37<00:00, 205MB/s] 

Extracting files...





In [None]:
# Run the project's caption-creation step on the downloaded dataset directory.
# NOTE(review): presumably derives a text caption per image — confirm in backend.create_captions.
from backend import create_captions
create_captions('datasets/datasets/dimensi0n/imagenet-256')

In [None]:
# Paths shared by the cells below
source_directory = 'datasets/datasets/dimensi0n/imagenet-256'  # dataset root handed to workflow()
using_checkpoint = 'checkpoints/a_model.pt'  # checkpoint file the trainer loads from and saves to

# The training cells iterate this object and unpack each item as
# (images_batch, text_embeds_batch).
dataloader = workflow(source_directory)

In [None]:
# Imagen & Unets

# Keyword arguments that are identical for both unets
shared_unet_kwargs = dict(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
)

# First unet of the cascade (the base unet)
unet1 = Unet(
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    **shared_unet_kwargs
)

# Second unet of the cascade (the super-resolution unet)
unet2 = Unet(
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True),
    **shared_unet_kwargs
)

# Imagen wraps both unets; `image_sizes` has one entry per unet
# and the whole model is moved to the GPU.
imagen = Imagen(
    unets = (unet1, unet2),
    image_sizes = (64, 256),
    text_encoder_name = 't5-large',
    timesteps = 1000,
    cond_drop_prob = 0.1
)
imagen = imagen.cuda()

In [4]:
# Preparing imagen and trainer

imagen.train()  # switch the model to training mode
# The trainer is created with EMA disabled (use_ema=False) and moved to the GPU.
trainer = ImagenTrainer(imagen, use_ema=False).cuda()
trainer.train()
clear_output()  # hide anything the setup printed

In [5]:
# Load the model
# Loads the checkpoint at `using_checkpoint` into the trainer created above.
# NOTE(review): presumably fails if the checkpoint file does not exist yet — confirm before a first run.

trainer.load(using_checkpoint)
clear_output()  # hide anything the load printed

In [None]:
# Code for running the training

num_epochs = 120
unet_toTrain = 2    # Can't train both unets at the same time; save a checkpoint and re-create the trainer when changing unets

# Each epoch is one full pass over `dataloader`, training only `unet_toTrain`.
prev_avg_loss = None    # there is no previous epoch before the first one
for epoch in range(num_epochs):
    clear_output()
    prev_txt = "n/a" if prev_avg_loss is None else prev_avg_loss
    print(f"Epoch {epoch+1}/{num_epochs}, Training UNet {unet_toTrain}, Prev epoch avg loss: {prev_txt}")
    losses = 0
    batches_seen = 0    # batches that actually contributed a loss this epoch
    for images_batch, text_embeds_batch in dataloader:
        if images_batch.numel() == 0: continue # Skip if batch is empty after filtering

        images_batch = images_batch.cuda()
        text_embeds_batch = text_embeds_batch.cuda()

        loss = trainer(images_batch, text_embeds = text_embeds_batch, unet_number = unet_toTrain)
        losses += loss
        batches_seen += 1
        trainer.update(unet_number = unet_toTrain)
        print(f"Loss: {loss}")

    # Average over the batches that were trained on, not len(dataloader):
    # skipped (empty) batches must not drag the average down, and an epoch
    # with no usable batch yields NaN instead of a misleading 0 or a
    # ZeroDivisionError.
    prev_avg_loss = losses / batches_seen if batches_seen else float("nan")

In [None]:
# Save the model
# Writes the trainer state to `using_checkpoint`, the same path the load cell reads from,
# so an existing checkpoint at that path is replaced.
trainer.save(using_checkpoint)

In [None]:
# Manually clear the cache
# Run Python's garbage collector first so unreferenced objects are dropped,
# then ask PyTorch to release its cached, unused GPU memory.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Workflow code for less supervised training
#
# Runs `rounds` rounds of `num_epochs_perRound` epochs each. Every round
# re-creates the trainer, reloads `using_checkpoint`, trains a single unet and
# saves the trainer back to the same checkpoint path.
num_epochs_perRound = 120
unet_toTrain = 2            # unet trained until the switch round is reached

rounds = 2
change_unet_afterRound = 1  # 0-based round index at which the unet is switched
unet_afterChange = 2        # NOTE(review): equal to the starting unet, so the switch is currently a no-op — confirm which unet was intended

# Drop any trainer left over from earlier cells so its GPU memory can be
# reclaimed; no throwaway trainer has to be built just to be deleted again.
trainer = None

rounds_scores = []  # mean of the per-epoch average losses, one entry per round

# Named `round_idx` so the builtin round() is not shadowed for the rest of the kernel session.
for round_idx in range(rounds):
    if change_unet_afterRound == round_idx:
        unet_toTrain = unet_afterChange

    # Release the previous round's trainer before building the next one.
    trainer = None
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(10)

    imagen.train()
    # use_ema=False matches the trainer created earlier in this notebook.
    # NOTE(review): an EMA-enabled trainer likely cannot load a checkpoint that
    # was saved without EMA weights — confirm against the imagen-pytorch docs.
    trainer = ImagenTrainer(imagen, use_ema=False).cuda()
    trainer.train()
    clear_output()

    trainer.load(using_checkpoint)
    clear_output()

    prev_avg_loss = None    # there is no previous epoch at the start of a round
    round_losses = 0
    for epoch in range(num_epochs_perRound):
        clear_output()
        prev_txt = "n/a" if prev_avg_loss is None else prev_avg_loss
        print(f"Round {round_idx+1}, Epoch {epoch+1}/{num_epochs_perRound}, Training UNet {unet_toTrain}, Prev epoch avg loss: {prev_txt}")
        losses = 0
        batches_seen = 0    # batches that actually contributed a loss this epoch
        for images_batch, text_embeds_batch in dataloader:
            if images_batch.numel() == 0: continue  # skip empty batches

            images_batch = images_batch.cuda()
            text_embeds_batch = text_embeds_batch.cuda()

            loss = trainer(images_batch, text_embeds = text_embeds_batch, unet_number = unet_toTrain)
            losses += loss
            batches_seen += 1
            trainer.update(unet_number = unet_toTrain)
            print(f"Loss: {loss}")

        # Average over the batches that were trained on rather than
        # len(dataloader), so skipped batches do not bias the figure.
        epoch_avg_loss = losses / batches_seen if batches_seen else float("nan")
        prev_avg_loss = epoch_avg_loss
        round_losses += epoch_avg_loss
        if epoch == num_epochs_perRound - 1:
            print(f"Last epoch avg loss: {epoch_avg_loss}")
    rounds_scores.append(round_losses/num_epochs_perRound)
    trainer.save(using_checkpoint)

clear_output()
for i, score in enumerate(rounds_scores):
    print(f"Round {i+1} avg loss: {score}")

import matplotlib.pyplot as plt
import torchvision.utils as vutils

# Sample one image per prompt from the trainer that just finished training.
prompts = [
    'two zebras in a field',
    'two zebras in a zoo',
    'the zebra is black and white'
]
images = trainer.sample(texts = prompts, cond_scale = 3.)

# Bring the samples to the CPU and clamp them into [0, 1] for display.
images_cpu = images.detach().cpu().clamp(0, 1)

# Lay all samples out in a single row, then reorder (C, H, W) -> (H, W, C) for matplotlib.
grid = vutils.make_grid(images_cpu, nrow = images_cpu.shape[0], padding = 2, normalize = False)
np_grid = grid.permute(1, 2, 0).numpy()

fig, ax = plt.subplots(figsize = (10, 5))
ax.imshow(np_grid)
ax.axis('off')
plt.show()

In [None]:
# Generate images

import matplotlib.pyplot as plt
import torchvision.utils as vutils

# One sample is generated per prompt.
prompts = [
    'two zebras in a field',
    'two zebras in a zoo',
    'the zebra is black and white'
]
images = trainer.sample(texts = prompts, cond_scale = 3.)

# Code for displaying the generated images

# Detach from the graph, move to the CPU and clamp into [0, 1].
# If the images were normalized to [-1, 1] (e.g. via transforms.Normalize),
# rescale with (x + 1) / 2 before clamping.
images_cpu = images.detach().cpu().clamp(0, 1)

# Single-row grid containing every sample; normalize=False because the values are already clamped.
grid = vutils.make_grid(images_cpu, nrow = images_cpu.shape[0], padding = 2, normalize = False)

# matplotlib expects (H, W, C), the grid tensor is (C, H, W).
np_grid = grid.permute(1, 2, 0).numpy()

# Display the grid
fig, ax = plt.subplots(figsize = (10, 5))
ax.imshow(np_grid)
ax.axis('off')
plt.show()