In [2]:
pip install minimagen

Defaulting to user installation because normal site-packages is not writeable
Collecting minimagen
  Downloading minimagen-0.0.9-py3-none-any.whl (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.0/43.0 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting aiohttp==3.8.1 (from minimagen)
  Downloading aiohttp-3.8.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m64.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting aiosignal==1.2.0 (from minimagen)
  Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB)
Collecting async-timeout==4.0.2 (from minimagen)
  Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB)
Collecting attrs==21.4.0 (from minimagen)
  Downloading attrs-21.4.0-py2.py3-none-any.whl (60 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m34.1 MB/s[0m eta [36m0:

In [1]:
import torch
from torch import optim
import numpy as np

from minimagen.Imagen import Imagen
from minimagen.Unet import Unet
from minimagen.t5 import get_encoded_dim

from model.training import CustomDataset
from torch.utils.data import DataLoader

log_file='minimagen_training_progress.txt'
def show_msg(msg, file=log_file):
    if file is not None:
        with open(file, 'a') as f:
            f.write(msg+'\n')
    print(msg)
open(log_file, 'w').close()

# Constants
BATCH_SIZE = 32  # Batch size training data
EPOCHS = 5  # Number of epochs to train from
T5_NAME = "t5_base"  # Name of the T5 encoder to use
LR = 0.0001
save_dir = './weights_minimagen/'

# Captions to generate samples for
CAPTIONS = [
    'a happy dog',
    'a big red house',
    'a woman standing on a beach',
    'a man on a bike'
]

# Get device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Get encoding dimension of the text encoder
text_embed_dim = get_encoded_dim(T5_NAME)

# Create Unets
pre_unet = Unet(
    dim=128,
    text_embed_dim=text_embed_dim,
    cond_dim=256,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=(2, 4, 4),
    layer_attns=(False, False, True),
    layer_cross_attns=(False, False, True),
    attend_at_middle=False,
    memory_efficient=True
)

base_unet = Unet(
    dim=128,
    text_embed_dim=text_embed_dim,
    cond_dim=256,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=(2, 4, 4),
    layer_attns=(False, False, True),
    layer_cross_attns=(False, False, True),
    attend_at_middle=False,
    memory_efficient=True
)

super_res_unet = Unet(
    dim=128,
    text_embed_dim=text_embed_dim,
    cond_dim=256,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=(2, 4, 4),
    layer_attns=(False, False, True),
    layer_cross_attns=(False, False, True),
    attend_at_middle=False,
    memory_efficient=True
)
show_msg("Created Unets")

unets=(pre_unet,base_unet,super_res_unet)
# Create Imagen from Unets
imagen = Imagen(
    unets=unets,
    image_sizes=(64,64,64),
    timesteps=500,
    cond_drop_prob=0.1,
    text_encoder_name=T5_NAME
).to(device)
#imagen.load_state_dict(torch.load(save_dir+"model_000.pth", map_location=device))
show_msg("Created Imagen")

# Create example data
dataset_data_path = './dataset/Flickr8k_dataset.npy'
# load dataset
dataset = CustomDataset(dataset_data_path)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=1, drop_last=True)

# Create optimizer
optimizer = optim.Adam(imagen.parameters(), lr=LR)
show_msg("Created optimzer")

# Train on example data
show_msg("Training Imagen...")
list_total_loss = []
training_steps = 0
for epoch in range(EPOCHS):
    show_msg("------------------------------------ epoch {:03d} ------------------------------------".format(epoch + 1))
    imagen.train()
    
    total_loss = 0
    loss_list = []
    for x_0, labels in dataloader:   # x_0: images
        x_0 = x_0.to(device)
        step_loss = [0 for i in range(1,len(unets)+1)]
        for i in range(1,len(unets)+1):
            optimizer.zero_grad()
            loss = imagen(x_0, texts=labels, unet_number=i)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(imagen.parameters(), 50)
            step_loss[i-1] = loss.item()
            
        optimizer.step()
        training_steps+=1
        loss_list.append(sum(step_loss))
        if (training_steps%50) == 0:
            show_msg("Total train step: {}, Loss: {}".format(training_steps,step_loss))
            
    loss_list = np.array(loss_list)
    show_msg("Min loss: {}".format(loss_list.min()))
    show_msg("Max loss: {}".format(loss_list.max()))
    total_loss = loss_list.sum()
    show_msg("Mean loss: {}".format(loss_list.mean()))
    show_msg("Std loss: {}".format(loss_list.std()))
    show_msg("Total Loss: {}".format(total_loss))
    list_total_loss.append(total_loss)
  # save model periodically
    if epoch%5==0 or epoch == int(EPOCHS-1):
        torch.save(imagen.state_dict(), save_dir + "model_{:03d}.pth".format(epoch + 1))
        show_msg('saved model at ' + save_dir + "model_{:03d}.pth".format(epoch + 1))
        
plt.figure()
plt.plot(list_total_loss)
plt.title("Total Loss vs Epoch")
plt.savefig('train.png')

# Generate images with "trained" model
imagen.eval()
show_msg("Sampling from Imagen...")
images = imagen.sample(texts=CAPTIONS, cond_scale=3., return_pil_images=True)

# Save output PIL images
show_msg("Saving Images")
for idx, img in enumerate(images):
    img.save(f'Generated_Image_{idx}.png')

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Created Unets
Created Imagen
Created optimzer
Training Imagen...
------------------------------------ epoch 001 ------------------------------------


For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Total train step: 50, Loss: [1.0464767217636108, 1.0925509929656982, 0.7452291250228882]
Total train step: 100, Loss: [1.0444884300231934, 1.0893826484680176, 0.49260348081588745]
Total train step: 150, Loss: [1.045423984527588, 1.0864667892456055, 0.4157952666282654]
Total train step: 200, Loss: [1.0430138111114502, 1.0864639282226562, 0.2685897946357727]
Total train step: 250, Loss: [1.0444200038909912, 1.067206621170044, 0.23961210250854492]
Total train step: 300, Loss: [1.0488615036010742, 1.0735745429992676, 0.2255721539258957]


KeyboardInterrupt: 

In [1]:
import torch
from torch import optim

from minimagen.Imagen import Imagen
from minimagen.Unet import Unet
from minimagen.t5 import get_encoded_dim

from model.training import *
from torch.utils.data import DataLoader

# Constants
BATCH_SIZE = 32  # Batch size training data
MAX_NUM_WORDS = 128  # Max number of words allowed in a caption
IMG_SIDE_LEN = 64  # Side length of the training images/final output image from Imagen
T5_NAME = "t5_base"  # Name of the T5 encoder to use
save_dir = './weights_minimagen/'

# Captions to generate samples for
CAPTIONS = [
    'a happy dog',
    'a big red house',
    'a woman standing on a beach',
    'a man on a bike'
]

# Get device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Get encoding dimension of the text encoder
text_embed_dim = get_encoded_dim(T5_NAME)

# Create Unets
pre_unet = Unet(
    dim=64,
    text_embed_dim=text_embed_dim,
    cond_dim=256,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=1,
    layer_attns=(False, True, True),
    layer_cross_attns=(False, True, True),
    attend_at_middle=True,
    memory_efficient=True
)

base_unet = Unet(
    dim=128,
    text_embed_dim=text_embed_dim,
    cond_dim=256,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=2,
    layer_attns=(False, True, True),
    layer_cross_attns=(False, True, True),
    attend_at_middle=True,
    memory_efficient=True
)

super_res_unet = Unet(
    dim=128,
    text_embed_dim=text_embed_dim,
    cond_dim=256,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=(2, 4, 4),
    layer_attns=(False, False, True),
    layer_cross_attns=(False, False, True),
    attend_at_middle=False,
    memory_efficient=True
)
print("Created Unets")

# Create Imagen from Unets
imagen = Imagen(
    unets=(pre_unet,base_unet,super_res_unet),
    image_sizes=(16,32,64),
    timesteps=500,
    cond_drop_prob=0.1,
    text_encoder_name=T5_NAME
).to(device)
print("Created Imagen")

imagen.load_state_dict(torch.load(save_dir+"model_031.pth", map_location=device))
imagen.train()

# Generate images with "trained" model
print("Sampling from Imagen...")
images = imagen.sample(texts=CAPTIONS, cond_scale=3., return_pil_images=True)

# Save output PIL images
print("Saving Images")
for idx, img in enumerate(images):
    img.save(f'Generated_Image_{idx}.png')


  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Created Unets
Created Imagen
Sampling from Imagen...


For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
0it [00:00, ?it/s

Saving Images



