In [1]:
# Mount Google Drive when running on Colab. Outside Colab the import fails and
# the cell is a no-op. Only ImportError is treated as "not on Colab" so that
# KeyboardInterrupt / SystemExit are no longer swallowed.
try:
    from google.colab import drive
except ImportError:
    pass
else:
    try:
        drive.mount("/content/gdrive/")
    except Exception as exc:
        # Mounting stays best-effort, but the failure is surfaced instead of hidden.
        print(f"Google Drive mount failed: {exc!r}")

In [2]:
import os
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from accelerate import Accelerator, notebook_launcher
from accelerate.tracking import TensorBoardTracker
from datasets import load_dataset
from diffusers import DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel  # NOQA
from diffusers.optimization import get_cosine_schedule_with_warmup
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from config.config import TrainingConfig
from utils.train_utils import Collator, GlyffuserPipeline, evaluate  # NOQA
from utils.utils import get_repo_dir

In [3]:
# Enable IPython's autoreload extension; mode 2 reloads all modules before each cell runs.
%load_ext autoreload
%autoreload 2

In [4]:
# Values overridden for this run; any TrainingConfig field not listed here
# keeps whatever default the config class defines.
_cfg_overrides = {
    "image_size": 32,
    "train_batch_size": 32,
    "eval_batch_size": 32,
    "encoder_dim": 512,
    "mixed_precision": "no",
    "save_image_epochs": 1,
    "save_model_epochs": 5,
}
cfg = TrainingConfig(**_cfg_overrides)

In [5]:
def train_loop(
    cfg: TrainingConfig,
    model: torch.nn.Module,
    train_dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
    noise_scheduler: DDPMScheduler,
    inference_scheduler: DPMSolverMultistepScheduler,
):
    """Train the text-conditioned noise-prediction model with Accelerate.

    For every batch, random noise is added to the images at random timesteps
    via ``noise_scheduler`` and the model is optimized with an MSE loss to
    predict that noise. Loss, learning rate and step are logged to TensorBoard
    for each batch. At the end of an epoch the main process may save a
    ``GlyffuserPipeline`` checkpoint and log a grid of evaluation images.

    Args:
        cfg: Training configuration. Fields read here: ``mixed_precision``,
            ``gradient_accumulation_steps``, ``output_dir``, ``num_epochs``,
            ``save_model_epochs`` and ``save_image_epochs``.
        model: Model called as
            ``model(noisy, timesteps, encoder_hidden_states=..., return_dict=False)``.
        train_dataloader: Yields ``(images, text_embeddings, masks)`` batches.
        optimizer: Optimizer for ``model``.
        lr_scheduler: Learning-rate scheduler, stepped alongside the optimizer.
        noise_scheduler: Scheduler used to add noise during training.
        inference_scheduler: Scheduler handed to the evaluation/saved pipeline.
    """
    # Initialize accelerator
    accelerator = Accelerator(
        mixed_precision=cfg.mixed_precision,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=Path(cfg.output_dir) / "logs",
    )

    # Initialize tensorboard logging
    if accelerator.is_main_process:
        if cfg.output_dir is not None:
            os.makedirs(cfg.output_dir, exist_ok=True)
        run_name = f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        accelerator.init_trackers(run_name)

    # Prepare everything for accelerator
    model, optimizer, train_dataloader, lr_scheduler = \
        accelerator.prepare(model, optimizer, train_dataloader, lr_scheduler)

    # Train model
    global_step = 0
    for epoch in range(cfg.num_epochs):
        pbar = tqdm(train_dataloader, total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        pbar.set_description(f"Epoch {epoch}")

        model.train()
        # NOTE(review): b_masks is unpacked but never passed to the model — confirm this is intended.
        for b_imgs, b_texts_embed, b_masks in pbar:
            # Sample noise and timesteps directly on the batch's device.
            noise = torch.randn(b_imgs.shape, device=b_imgs.device)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (b_imgs.shape[0],), device=b_imgs.device
            )
            b_imgs_noisy = noise_scheduler.add_noise(b_imgs, noise, timesteps)

            with accelerator.accumulate(model):
                noise_pred = model(b_imgs_noisy, timesteps, encoder_hidden_states=b_texts_embed, return_dict=False)[0]

                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)
                # Clip only when gradients are synchronized, i.e. on the step
                # where the optimizer actually updates (matters when
                # gradient_accumulation_steps > 1).
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            pbar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        model.eval()
        is_last_epoch = epoch == cfg.num_epochs - 1
        # Epoch 0 is skipped for checkpoints/eval unless it is also the final
        # epoch, so a single-epoch run still produces a checkpoint and images.
        if accelerator.is_main_process and (epoch > 0 or is_last_epoch):
            pipeline = GlyffuserPipeline(unet=accelerator.unwrap_model(model), scheduler=inference_scheduler)

            # Save model checkpoint
            if epoch % cfg.save_model_epochs == 0 or is_last_epoch:
                pipeline.save_pretrained(os.path.join(cfg.output_dir, "models", run_name))

            # Evaluate and log images
            if epoch % cfg.save_image_epochs == 0 or is_last_epoch:
                texts = ["love", "hate", "hot", "cold", "one", "two", "up", "down", "day", "night", "top", "bottom", "summer", "winter", "white", "black"]
                with torch.no_grad():
                    img_grid = evaluate(cfg, epoch, texts, pipeline)
                tb_tracker: TensorBoardTracker = accelerator.get_tracker("tensorboard")
                # NOTE(review): "WHC" treats the first array axis as width; if
                # evaluate() returns a PIL image, np.array() yields (H, W, C) —
                # confirm the intended orientation.
                tb_tracker.log_images({"eval_imgs": np.array(img_grid)}, step=global_step, dataformats="WHC")

    # Flush and close the trackers now that training is finished.
    accelerator.end_training()

In [6]:
# Batch collator. The train loop unpacks each batch it produces as
# (images, text embeddings, masks).
collator = Collator(
    name=cfg.text_encoder,
    text_label="Chinese Definition",
    image_label="Filename",
    image_size=cfg.image_size,
    channels="L",  # presumably single-channel (grayscale) images, matching in_channels=1 — confirm in Collator
)

# Training data: JSON-lines metadata file located inside the repository
metadata_path = get_repo_dir() / Path("data/data/metadata_top1000.jsonl")
data_files = str(metadata_path)
dataset = load_dataset("json", data_files=data_files)
train_dataloader = DataLoader(
    dataset["train"],
    batch_size=cfg.train_batch_size,
    shuffle=True,
    collate_fn=collator,
)

# UNet layout, listed top -> bottom for the down path and bottom -> top for the up path
unet_block_channels = (32, 32, 64, 128)
unet_down_blocks = ("DownBlock2D", "DownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
unet_up_blocks = ("UpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D", "UpBlock2D")

# Text-conditioned UNet operating on single-channel (grayscale) images
model = UNet2DConditionModel(
    sample_size=cfg.image_size,              # target image size
    in_channels=1,                           # grayscale input
    out_channels=1,                          # grayscale output
    layers_per_block=2,                      # ResNet layers in each UNet block
    block_out_channels=unet_block_channels,  # output channels per block
    down_block_types=unet_down_blocks,
    up_block_types=unet_up_blocks,
    addition_embed_type="text",              # condition on text
    encoder_hid_dim=cfg.encoder_dim,         # encoder hidden dim size
    encoder_hid_dim_type="text_proj",        # encoder hidden dim type
    cross_attention_dim=cfg.encoder_dim,     # cross-attention dim
)

# Diffusion schedulers: DDPM for adding noise in training, DPM-Solver for sampling
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
inference_scheduler = DPMSolverMultistepScheduler(num_train_timesteps=1000)

# Optimizer with a cosine learning-rate schedule and warmup, sized as one
# schedule step per batch per epoch
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate)
total_train_steps = len(train_dataloader) * cfg.num_epochs
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=cfg.lr_warmup_steps,
    num_training_steps=total_train_steps,
)

# Report how many parameters will actually be trained
total_params = 0
for param in model.parameters():
    if param.requires_grad:
        total_params += param.numel()
print(f"Total Model Parameters: {total_params:,}")

Total Model Parameters: 6,031,425


In [None]:
# Launch training from the notebook in a single process
args = (
    cfg,
    model,
    train_dataloader,
    optimizer,
    lr_scheduler,
    noise_scheduler,
    inference_scheduler,
)
notebook_launcher(train_loop, args=args, num_processes=1)