In [1]:
# Mount Google Drive when running on Google Colab; elsewhere this cell is a no-op.
try:
    from google.colab import drive
except ImportError:
    # Not running on Colab: there is nothing to mount.
    pass
else:
    try:
        drive.mount("/content/gdrive/")
    except Exception as exc:
        # Mounting stays best-effort, but report the failure instead of hiding it.
        # `Exception` (not a bare `except`) lets KeyboardInterrupt/SystemExit propagate.
        print(f"Could not mount Google Drive: {exc}")

In [None]:
import os
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
from diffusers import DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel  # NOQA
from diffusers.optimization import get_cosine_schedule_with_warmup
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

from config.text2char_config import TrainingConfigText2Char
from utils.eval_utils import DiffusionPipelineText2Char  # NOQA
from utils.train_utils import get_dataloader
from utils.utils import get_repo_dir

In [3]:
# Enable IPython's autoreload extension so edits to imported project modules are picked up live.
%load_ext autoreload
%autoreload 2

In [4]:
# Data root inside the repository; the caption metadata file is read from here as well.
ROOT_IMAGE_DIR = get_repo_dir() / Path("data") / "data"

In [None]:
# Training configuration. Fields not listed here keep their TrainingConfigText2Char values.
cfg = TrainingConfigText2Char(
    # Dimensions: image side length fed to the UNet and the text-encoder feature size.
    image_size=32,
    encoder_dim=512,
    # Batch sizes for training and for evaluation image generation.
    train_batch_size=32,
    eval_batch_size=16,
    # Cadence (in epochs) for logging evaluation images and for saving checkpoints.
    save_image_epochs=1,
    save_model_epochs=5,
)

In [6]:
# Build the training dataloader from the data root and its JSONL caption metadata.
train_dataloader = get_dataloader(
    cfg,
    ROOT_IMAGE_DIR,
    caption_encoder="google-t5/t5-small",
    caption_label="Chinese Definition",
    caption_jsonl_path=ROOT_IMAGE_DIR.joinpath("metadata.jsonl"),
)

In [7]:
# Text-conditioned UNet operating on single-channel (grayscale) images.
model = UNet2DConditionModel(
    # Image geometry: square side length, one channel in and one channel out.
    sample_size=cfg.image_size,
    in_channels=1,
    out_channels=1,
    # Backbone: two ResNet layers per block, with these per-block output widths.
    layers_per_block=2,
    block_out_channels=(32, 64, 128, 128),
    # Block layout; down blocks are listed top -> bottom, up blocks bottom -> top.
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
    # Text conditioning: all three dimensions follow the configured encoder size.
    addition_embed_type="text",
    encoder_hid_dim_type="text_proj",
    encoder_hid_dim=cfg.encoder_dim,
    cross_attention_dim=cfg.encoder_dim,
)

# Report the number of trainable parameters.
total_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Total Model Parameters: {total_params:,}")

Total Model Parameters: 9,826,241


In [8]:
# Noise schedulers: one for the training loop, one intended for inference.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
inference_scheduler = DPMSolverMultistepScheduler(num_train_timesteps=1000)

# AdamW with a cosine learning-rate schedule (with warmup) spanning every optimizer step of the run.
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=cfg.lr_warmup_steps,
    num_training_steps=cfg.num_epochs * len(train_dataloader),
)

In [None]:
def _select_device() -> torch.device:
    """Pick the compute device: CUDA if available, otherwise Apple MPS, otherwise CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _is_periodic_epoch(epoch: int, interval) -> bool:
    """Return True when ``epoch`` is a multiple of a positive ``interval``.

    A missing or non-positive interval disables the periodic trigger instead of
    failing on the modulo operation.
    """
    return bool(interval) and interval > 0 and epoch % interval == 0


def train_loop(
    cfg: TrainingConfigText2Char,
    train_dataloader: DataLoader,
    model: UNet2DConditionModel,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
    noise_scheduler: DDPMScheduler,
    inference_scheduler: DPMSolverMultistepScheduler,
):
    """Train ``model`` to predict the noise added to images, conditioned on text embeddings.

    Each step samples random timesteps, noises the batch with ``noise_scheduler``,
    and minimises the MSE between the predicted and the true noise. Loss and learning
    rate are logged to TensorBoard every step; evaluation images and checkpoints are
    written on the cadence configured in ``cfg``.

    Args:
        cfg: Training configuration. Reads ``output_dir``, ``num_epochs``,
            ``save_model_epochs``, ``save_image_epochs`` and ``eval_batch_size``.
        train_dataloader: Iterable yielding 3-tuples unpacked as
            (images, text embeddings, masks).
        model: UNet to train; moved to the selected device.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Learning-rate scheduler stepped once per batch.
        noise_scheduler: Scheduler used to add training noise and to build the
            evaluation/checkpoint pipeline.
        inference_scheduler: Currently unused; kept so existing callers keep working.
            NOTE(review): the pipeline is built with ``noise_scheduler`` -- confirm
            whether this scheduler was meant to drive evaluation sampling instead.

    Raises:
        ValueError: If ``cfg.output_dir`` is None (logs and checkpoints need a location).
    """
    # Fail fast with a clear message; Path(None) would otherwise raise an opaque TypeError.
    if cfg.output_dir is None:
        raise ValueError("cfg.output_dir must be set: logs and checkpoints are written beneath it.")
    output_dir = Path(cfg.output_dir)
    os.makedirs(output_dir, exist_ok=True)

    device = _select_device()
    model = model.to(device)

    # Tensorboard logging
    run_name = f"train_text2char_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    checkpoint_dir = output_dir / "models" / run_name
    writer = SummaryWriter(log_dir=str(output_dir / "logs" / run_name))

    # Loop-invariant values
    num_train_timesteps = noise_scheduler.config.num_train_timesteps
    eval_texts = ["love", "hate", "hot", "cold", "one", "two", "up", "down", "day", "night", "top", "bottom", "summer", "winter", "white", "black"]

    # Train loop
    global_step = 0
    try:
        for epoch in range(cfg.num_epochs):
            is_last_epoch = epoch == cfg.num_epochs - 1
            pbar = tqdm(train_dataloader, desc=f"Epoch {epoch}")

            model.train()
            # NOTE(review): the masks are never passed to the model; if they mark padded
            # text tokens they may belong in the model's attention-mask argument -- confirm.
            for b_imgs, b_texts_embed, _b_masks in pbar:
                b_imgs = b_imgs.to(device)
                b_texts_embed = b_texts_embed.to(device)

                # Forward diffusion: noise each image at an independently sampled timestep.
                noise = torch.randn_like(b_imgs)
                timesteps = torch.randint(0, num_train_timesteps, (b_imgs.shape[0],), device=device)
                b_imgs_noisy = noise_scheduler.add_noise(b_imgs, noise, timesteps)

                optimizer.zero_grad()
                noise_pred = model(b_imgs_noisy, timesteps, encoder_hidden_states=b_texts_embed).sample
                loss = torch.nn.functional.mse_loss(noise_pred, noise)
                loss.backward()
                optimizer.step()
                lr_scheduler.step()

                logs = {"loss": loss.item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
                pbar.set_postfix(**logs)
                writer.add_scalar("Loss/train", logs["loss"], global_step)
                writer.add_scalar("LR", logs["lr"], global_step)
                global_step += 1

            model.eval()

            # Epoch 0 is skipped for periodic checkpoints, but the final epoch is always
            # saved -- including when the whole run is a single epoch.
            save_checkpoint = is_last_epoch or (epoch > 0 and _is_periodic_epoch(epoch, cfg.save_model_epochs))
            log_images = is_last_epoch or _is_periodic_epoch(epoch, cfg.save_image_epochs)

            # Only build the pipeline on epochs that actually use it.
            if save_checkpoint or log_images:
                pipeline = DiffusionPipelineText2Char(unet=model, scheduler=noise_scheduler)
                pipeline.set_progress_bar_config(desc="Generating evaluation image grid...")

                # Save model checkpoint
                if save_checkpoint:
                    pipeline.save_pretrained(str(checkpoint_dir / f"epoch_{epoch}"))
                    pipeline.save_pretrained(str(checkpoint_dir / "latest"))

                # Evaluate and log images
                if log_images:
                    # No gradients are needed for sampling.
                    with torch.no_grad():
                        img_grid = pipeline.evaluate_texts_to_image_grid(eval_texts, batch_size=cfg.eval_batch_size, output_type="numpy")
                    # NOTE(review): "NWHC" treats axis 1 as width and axis 2 as height;
                    # confirm the grid is not meant to be logged as "NHWC".
                    writer.add_images("eval_imgs", img_grid, global_step, dataformats="NWHC")
            writer.flush()
    finally:
        # Close the event file even if training is interrupted or raises.
        writer.close()

In [None]:
# Launch training with the configuration, data, model and schedulers built in the cells above.
train_loop(
    cfg=cfg,
    train_dataloader=train_dataloader,
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    noise_scheduler=noise_scheduler,
    inference_scheduler=inference_scheduler,
)

Epoch 0:   0%|          | 0/638 [00:00<?, ?it/s]

Generating evaluation image grid...:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/638 [00:00<?, ?it/s]

In [None]:
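# Optional: line-by-line profiling of train_loop (requires the `line_profiler` package).
# Uncomment the lines below to run training under the profiler instead of the cell above.
# from line_profiler import LineProfiler
# lp = LineProfiler()
# lp.add_function(train_loop)
# lp.run("train_loop(cfg, train_dataloader, model, optimizer, lr_scheduler, noise_scheduler, inference_scheduler)")
# lp.print_stats(sort=True)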