### Import and process dataset

In [1]:
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainingConfig:
    """Hyperparameters and output settings for the training run.

    Every attribute is type-annotated so that ``@dataclass`` registers it as a
    real field: un-annotated class attributes are ignored by the decorator,
    which leaves the generated ``__init__`` without parameters and the repr
    empty. With annotations, values can be overridden at construction time,
    e.g. ``TrainingConfig(num_epochs=5)``, and ``repr(config)`` lists them all.
    """

    image_size: int = 128  # the generated image resolution
    train_batch_size: int = 16
    eval_batch_size: int = 16  # how many images to sample during evaluation
    num_epochs: int = 50
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 10
    save_model_epochs: int = 30
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = "training_output"  # the model name locally and on the HF Hub

    push_to_hub: bool = True  # whether to upload the saved model to the HF Hub
    hub_model_id: str = "QLeca/NextLayerModularCharacterModel"  # the name of the repository to create on the HF Hub
    hub_private_repo: Optional[bool] = None  # NOTE(review): presumably a bool when set; confirm against the Hub upload call
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0


config = TrainingConfig()

In [None]:
from datasets import load_dataset

# Keep the Hub dataset id on the config object, next to the other run settings.
config.dataset_name = "QLeca/modular_characters"

# Only the "train" split is requested.
dataset = load_dataset(path=config.dataset_name, split="train")

README.md:   0%|          | 0.00/471 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Resolving data files:   0%|          | 0/62 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/62 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/62 [00:00<?, ?files/s]

train-00000-of-00062.parquet:   0%|          | 0.00/406M [00:00<?, ?B/s]

train-00001-of-00062.parquet:   0%|          | 0.00/406M [00:00<?, ?B/s]

train-00002-of-00062.parquet:   0%|          | 0.00/407M [00:00<?, ?B/s]

train-00003-of-00062.parquet:   0%|          | 0.00/406M [00:00<?, ?B/s]

train-00004-of-00062.parquet:   0%|          | 0.00/407M [00:00<?, ?B/s]

train-00005-of-00062.parquet:   0%|          | 0.00/406M [00:00<?, ?B/s]

train-00006-of-00062.parquet:   0%|          | 0.00/407M [00:00<?, ?B/s]

train-00007-of-00062.parquet:   0%|          | 0.00/407M [00:00<?, ?B/s]

train-00008-of-00062.parquet:   0%|          | 0.00/407M [00:00<?, ?B/s]

train-00009-of-00062.parquet:   0%|          | 0.00/407M [00:00<?, ?B/s]

train-00010-of-00062.parquet:   0%|          | 0.00/406M [00:00<?, ?B/s]

train-00011-of-00062.parquet:   0%|          | 0.00/406M [00:00<?, ?B/s]

train-00012-of-00062.parquet:   0%|          | 0.00/407M [00:00<?, ?B/s]

train-00013-of-00062.parquet:   0%|          | 0.00/407M [00:00<?, ?B/s]

train-00014-of-00062.parquet:   0%|          | 0.00/408M [00:00<?, ?B/s]

train-00015-of-00062.parquet:   0%|          | 0.00/408M [00:00<?, ?B/s]

train-00016-of-00062.parquet:   0%|          | 0.00/408M [00:00<?, ?B/s]

train-00017-of-00062.parquet:   0%|          | 0.00/407M [00:00<?, ?B/s]

train-00018-of-00062.parquet:   0%|          | 0.00/407M [00:00<?, ?B/s]

train-00019-of-00062.parquet:   0%|          | 0.00/406M [00:00<?, ?B/s]

train-00020-of-00062.parquet:   0%|          | 0.00/416M [00:00<?, ?B/s]

train-00021-of-00062.parquet:   0%|          | 0.00/421M [00:00<?, ?B/s]

train-00022-of-00062.parquet:   0%|          | 0.00/420M [00:00<?, ?B/s]

train-00023-of-00062.parquet:   0%|          | 0.00/421M [00:00<?, ?B/s]

train-00024-of-00062.parquet:   0%|          | 0.00/421M [00:00<?, ?B/s]

In [None]:
import matplotlib.pyplot as plt

# Preview the first few samples: one row per sample, input on the left and
# target on the right. The grid has exactly two columns so no empty axes
# (with their default ticks) are left over.
n_preview = 4
fig, axs = plt.subplots(n_preview, 2, figsize=(6, 3 * n_preview))
axs[0][0].set_title("input")
axs[0][1].set_title("target")
for i in range(n_preview):
    sample = dataset[i]  # read the row once instead of once per column

    for ax, image in zip(axs[i], (sample["input"], sample["target"])):
        ax.imshow(image)
        ax.set_axis_off()
    print(sample["prompt"])
# pyplot.show() is the call intended for notebooks; Figure.show() needs a GUI
# backend and only emits a warning on non-interactive ones.
plt.show()


In [None]:
from torchvision import transforms

# Per-image preprocessing pipeline.
# NOTE(review): the horizontal flip is drawn independently on every call, so
# code that pushes paired images through `preprocess` one at a time has to
# synchronise the flip itself.
preprocess = transforms.Compose(
    [
        # Force a square config.image_size x config.image_size resolution.
        transforms.Resize(size=(config.image_size, config.image_size)),
        # Augmentation: mirror the image left-right half of the time.
        transforms.RandomHorizontalFlip(p=0.5),
        # Convert to a tensor (ToTensor scales PIL pixel values to [0, 1]).
        transforms.ToTensor(),
        # (x - 0.5) / 0.5: shift the [0, 1] range to [-1, 1].
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [None]:
def transform(rows):
    """Preprocess a batch of dataset rows on the fly.

    `preprocess` contains a random horizontal flip. Running it separately on
    the input and the target would let the two images of one pair be flipped
    differently, so the pair would no longer line up. To keep pairs aligned,
    the torch RNG state is captured before the input image is processed and
    restored before the target image is processed, so both see the same
    random draw. After each pair the RNG has advanced, so different pairs
    still get independent flips.

    Args:
        rows: Batch dict with an "input" and a "target" list of images and a
            "prompt" list.

    Returns:
        Dict with the same three keys: "input" and "target" hold the
        preprocessed images, "prompt" is passed through unchanged.
    """
    # Local import: this cell sits above the notebook's own `import torch`.
    import torch

    images_input = []
    images_target = []
    for image_in, image_tgt in zip(rows["input"], rows["target"]):
        rng_state = torch.get_rng_state()
        images_input.append(preprocess(image_in))
        # Replay the same random numbers so the target gets the same flip.
        torch.set_rng_state(rng_state)
        images_target.append(preprocess(image_tgt))

    return {"input": images_input,
            'target': images_target,
            'prompt': rows['prompt']}


dataset.set_transform(transform)

In [None]:
import torch
# Training batches of config.train_batch_size samples, drawn in shuffled order.
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.train_batch_size,
    shuffle=True,
)

### Create U-Net Model

In [None]:
from diffusers import UNet2DModel

# U-Net with six resolution levels; block_out_channels, down_block_types and
# up_block_types each list six entries, one per level.
model = UNet2DModel(
    sample_size=config.image_size,  # the target image resolution
    in_channels=4,  # the number of input channels, 4 for RGBA images (assumes the dataset images are RGBA; TODO confirm)
    out_channels=4,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)