# DreamBooth Training Notebook

**Nguyễn Khang Hy (2352662)**

Fine-tune Stable Diffusion với DreamBooth cho các phong cách nghệ thuật từ WikiArt.

## Mục tiêu

- Fine-tune DreamBooth cho 1-2 phong cách đại diện
- Ghi lại thời gian train, kích thước checkpoint, GPU usage
- So sánh với LoRA về chất lượng vs chi phí


## 1. Import libraries


In [None]:
import os
import math
import shutil
import json
from pathlib import Path
from datetime import datetime
from PIL import Image
import torch
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from diffusers import AutoencoderKL, StableDiffusionPipeline, DDPMScheduler, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version
from transformers import CLIPTextModel, CLIPTokenizer
import numpy as np
from tqdm.auto import tqdm

## 2. Cấu hình và tham số


In [None]:
# --- Base model, target style and prompts ---
MODEL_NAME = "runwayml/stable-diffusion-v1-5"
STYLE_NAME = "Contemporary_Realism"
UNIQUE_TOKEN = "sks"
INSTANCE_PROMPT = f"a {UNIQUE_TOKEN} style painting"
CLASS_PROMPT = "a painting"

# --- Working directories (created at the end of this cell) ---
INSTANCE_DIR = f"/kaggle/working/dreambooth/{STYLE_NAME}/instance_images"
CLASS_DIR = f"/kaggle/working/dreambooth/{STYLE_NAME}/class_images"
OUTPUT_DIR = f"/kaggle/working/dreambooth_checkpoints/{STYLE_NAME}"

# --- Optional pre-generated class images ---
# Pick the first candidate directory that holds at least 200 images; when none
# qualifies, CLASS_IMAGES_DATASET_PATH stays None.
# Extensions are matched case-insensitively so this count agrees with the
# WikiArt scan and with DreamBoothDataset, which both lower-case the suffix.
CLASS_IMAGES_DATASET_PATH = None
for possible_path in [
    f"/kaggle/input/priorimages/dreambooth/{STYLE_NAME}/class_images",
    f"/kaggle/input/dreambooth-class-images/dreambooth/{STYLE_NAME}/class_images",
    f"/kaggle/input/dreambooth/dreambooth/{STYLE_NAME}/class_images",
    "/kaggle/input/dreambooth-class-images/class_images",
]:
    # isdir (rather than exists) keeps os.listdir from failing on a plain file.
    if not os.path.isdir(possible_path):
        continue
    num_candidate_images = sum(
        1 for f in os.listdir(possible_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    if num_candidate_images >= 200:
        CLASS_IMAGES_DATASET_PATH = possible_path
        print(f"Tìm thấy class images dataset tại: {possible_path}")
        break

# --- Training hyper-parameters ---
# Memory tuning: lower resolution, higher gradient accumulation.
RESOLUTION = 256  # reduced to 256 to save as much memory as possible
TRAIN_BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 16  # raised to 16 to emulate a larger batch size
MAX_TRAIN_STEPS = 1000
LEARNING_RATE = 5e-6
PRIOR_LOSS_WEIGHT = 1.0
MIXED_PRECISION = "fp16"
GRADIENT_CHECKPOINTING = True
SCALE_LR = False
LR_SCHEDULER = "constant"
LR_WARMUP_STEPS = 0
SNR_GAMMA = None
USE_8BIT_ADAM = False  # disabled: bitsandbytes is not compatible with the Kaggle environment
SEED = 2025

# Hardware limits only allow training the attention layers.
TRAIN_ONLY_ATTENTION = True

# --- Memory optimization flags ---
ENABLE_VAE_SLICING = True  # split VAE encoding into smaller slices
ENABLE_VAE_TILING = True  # split the VAE into tiles to handle large images
ENABLE_ATTENTION_SLICING = "max"  # split attention into slices
ENABLE_XFORMERS = True  # use xformers memory-efficient attention
ENABLE_CPU_OFFLOAD = True  # keep the VAE and text encoder on the CPU while unused

# --- CUDA memory management ---
# setdefault: an allocator configuration already present in the environment wins.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128,expandable_segments:True")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True  # let cuDNN benchmark its kernels

os.makedirs(INSTANCE_DIR, exist_ok=True)
os.makedirs(CLASS_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)


## 3. Chuẩn bị dataset


In [None]:
# Copy a fixed subset of the chosen WikiArt style into INSTANCE_DIR and write
# one caption file (containing INSTANCE_PROMPT) next to every copied image.
wikiart_path = "/kaggle/input/wikiart"
style_path = os.path.join(wikiart_path, STYLE_NAME)

if not os.path.exists(style_path):
    print(f"Style {STYLE_NAME} không tồn tại. Kiểm tra lại đường dẫn.")
    # The listing is only a hint for the user; do not crash when the dataset
    # root itself is missing.
    available_styles = sorted(os.listdir(wikiart_path)) if os.path.isdir(wikiart_path) else []
    print(f"Các style có sẵn: {available_styles[:10]}")
else:
    # os.listdir returns entries in arbitrary order; sorting makes the
    # 20-image selection identical from run to run.
    image_files = sorted(
        f for f in os.listdir(style_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )
    print(f"Tìm thấy {len(image_files)} ảnh trong {STYLE_NAME}")

    selected_images = image_files[:20]
    print(f"Chọn {len(selected_images)} ảnh cho instance images")

    for img_file in selected_images:
        src = os.path.join(style_path, img_file)
        dst = os.path.join(INSTANCE_DIR, img_file)
        shutil.copy(src, dst)

        # Caption shares the image's base name, with a .txt extension.
        caption_file = os.path.splitext(dst)[0] + ".txt"
        with open(caption_file, "w", encoding="utf-8") as f:
            f.write(INSTANCE_PROMPT)

    print(f"Đã copy {len(selected_images)} instance images và tạo captions")


## 4. Generate prior preservation images (class images)


In [None]:
# Make sure CLASS_DIR holds NUM_CLASS_IMAGES prior-preservation images:
#   1. reuse what is already there, else
#   2. copy from a pre-generated dataset, else
#   3. generate them with the base Stable Diffusion pipeline.
NUM_CLASS_IMAGES = 200

# Match extensions case-insensitively so this count agrees with
# DreamBoothDataset, which lower-cases the suffix when it scans the directory.
_CLASS_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

existing_images = (
    [f for f in os.listdir(CLASS_DIR) if f.lower().endswith(_CLASS_IMAGE_EXTENSIONS)]
    if os.path.exists(CLASS_DIR)
    else []
)

if len(existing_images) >= NUM_CLASS_IMAGES:
    print(f"Đã có {len(existing_images)} class images trong {CLASS_DIR}, skip generation")
elif CLASS_IMAGES_DATASET_PATH is not None:
    print(f"Copying class images từ dataset: {CLASS_IMAGES_DATASET_PATH}")
    dataset_images = sorted(
        f for f in os.listdir(CLASS_IMAGES_DATASET_PATH) if f.lower().endswith(_CLASS_IMAGE_EXTENSIONS)
    )[:NUM_CLASS_IMAGES]

    for img_file in tqdm(dataset_images):
        src = os.path.join(CLASS_IMAGES_DATASET_PATH, img_file)
        dst = os.path.join(CLASS_DIR, img_file)
        shutil.copy(src, dst)

        # Only write a caption when one does not exist yet.
        caption_file = os.path.splitext(dst)[0] + ".txt"
        if not os.path.exists(caption_file):
            with open(caption_file, "w", encoding="utf-8") as f:
                f.write(CLASS_PROMPT)

    print(f"Đã copy {len(dataset_images)} class images từ dataset")
else:
    print(f"Generating {NUM_CLASS_IMAGES} class images...")

    pipeline = StableDiffusionPipeline.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16 if MIXED_PRECISION == "fp16" else torch.float32,
    )
    pipeline = pipeline.to("cuda")

    # Continue numbering after the images that already exist.
    start_idx = len(existing_images)
    for i in tqdm(range(start_idx, NUM_CLASS_IMAGES)):
        image = pipeline(CLASS_PROMPT, num_inference_steps=50, guidance_scale=7.5).images[0]
        image.save(os.path.join(CLASS_DIR, f"{i:04d}.png"))

        caption_file = os.path.join(CLASS_DIR, f"{i:04d}.txt")
        with open(caption_file, "w", encoding="utf-8") as f:
            f.write(CLASS_PROMPT)

    # Drop the generation pipeline before the training models are loaded.
    del pipeline
    torch.cuda.empty_cache()
    print(f"Đã generate {NUM_CLASS_IMAGES - start_idx} class images mới (tổng: {NUM_CLASS_IMAGES})")


## 5. Load models và tokenizer


In [None]:
# Load each Stable Diffusion component from the base checkpoint.
tokenizer = CLIPTokenizer.from_pretrained(MODEL_NAME, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(MODEL_NAME, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(MODEL_NAME, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(MODEL_NAME, subfolder="unet")

noise_scheduler = DDPMScheduler.from_pretrained(MODEL_NAME, subfolder="scheduler")

# The VAE and the text encoder are never trained.
for frozen_model in (vae, text_encoder):
    frozen_model.requires_grad_(False)

if not TRAIN_ONLY_ATTENTION:
    unet.requires_grad_(True)
    print("Train toàn bộ UNet")
else:
    # Freeze everything, then re-enable gradients for every sub-module whose
    # qualified name contains "attn" or "attention". Fewer trainable
    # parameters also means a smaller optimizer state.
    # NOTE(review): this is a substring match, so it would also select any
    # container module whose name merely contains "attention" — confirm this
    # is the intended training scope.
    unet.requires_grad_(False)
    for module_name, submodule in unet.named_modules():
        if not any(tag in module_name for tag in ("attn", "attention")):
            continue
        for weight in submodule.parameters():
            weight.requires_grad = True
    print("Chỉ train attention layers (giảm ~70% parameters cần train)")

if GRADIENT_CHECKPOINTING:
    unet.enable_gradient_checkpointing()

# The UNet weights stay in float32 so that its gradients are float32 as well;
# the frozen VAE / text encoder may use float16 when fp16 is requested.
vae_dtype = torch.float32
if MIXED_PRECISION == "fp16":
    vae_dtype = torch.float16
unet_dtype = torch.float32

# Apply the memory optimizations selected in the configuration cell.
if ENABLE_VAE_SLICING:
    vae.enable_slicing()
    print("VAE slicing enabled")

if ENABLE_VAE_TILING:
    vae.enable_tiling()
    print("VAE tiling enabled")

if ENABLE_ATTENTION_SLICING:
    unet.set_attention_slice(ENABLE_ATTENTION_SLICING)
    print(f"Attention slicing enabled: {ENABLE_ATTENTION_SLICING}")

if ENABLE_XFORMERS:
    # Best effort: report the problem and carry on when xformers cannot be enabled.
    try:
        unet.enable_xformers_memory_efficient_attention()
    except Exception as e:
        print(f"XFormers không khả dụng: {e}")
    else:
        print("XFormers memory efficient attention enabled")

if not ENABLE_CPU_OFFLOAD:
    vae.to("cuda", dtype=vae_dtype)
    text_encoder.to("cuda", dtype=vae_dtype)
else:
    # Offloaded models wait on the CPU until the training loop needs them.
    vae.to("cpu")
    text_encoder.to("cpu")
    print("✓ CPU offloading enabled cho VAE và Text Encoder")

unet.to("cuda", dtype=unet_dtype)

print(f"VAE dtype: {vae_dtype} (device: {'CPU' if ENABLE_CPU_OFFLOAD else 'CUDA'})")
print(f"Text Encoder dtype: {vae_dtype} (device: {'CPU' if ENABLE_CPU_OFFLOAD else 'CUDA'})")
print(f"UNet dtype: {unet_dtype} (sẽ tự động convert sang float16 trong forward pass)")
print("Đã load models với memory optimizations")


## 6. Setup Accelerator


In [None]:
# Imported here because this cell is the only user of the seeding helper.
from accelerate.utils import set_seed

# Create the Accelerator that drives gradient accumulation and mixed precision.
project_config = ProjectConfiguration(project_dir=OUTPUT_DIR)
accelerator = Accelerator(
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    mixed_precision=MIXED_PRECISION,
    project_config=project_config,
)

# SEED is set in the configuration cell but was never applied; seed the RNGs
# before the dataloader and the training loop draw any random numbers.
set_seed(SEED)

if accelerator.is_main_process:
    accelerator.init_trackers("dreambooth")


## 7. Chuẩn bị dataset và dataloader


In [None]:
from torch.utils.data import Dataset
from torchvision import transforms

class DreamBoothDataset(Dataset):
    """Dataset that pairs instance images with optional class (prior) images.

    Every item is a dict holding the transformed instance image and its
    tokenized prompt; when class images are available the item also holds a
    class image and the tokenized class prompt. Indices wrap around each image
    list independently, so the shorter list is cycled.

    Args:
        instance_data_root: Directory with the instance images.
        class_data_root: Directory with the class images, or None to disable
            the class entries.
        tokenizer: Callable tokenizer exposing ``model_max_length``; its result
            must provide ``input_ids``.
        size: Target side length used for resize and center crop.
        repeats: Multiplier applied to the dataset length.
        instance_prompt: Prompt for instance images; defaults to the
            module-level INSTANCE_PROMPT.
        class_prompt: Prompt for class images; defaults to the module-level
            CLASS_PROMPT.

    Raises:
        ValueError: If no instance image is found in ``instance_data_root``.
    """

    IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')

    def __init__(self, instance_data_root, class_data_root, tokenizer, size=512, repeats=1,
                 instance_prompt=None, class_prompt=None):
        self.instance_images_path = self._list_images(instance_data_root)
        self.num_instance_images = len(self.instance_images_path)
        if self.num_instance_images == 0:
            # Fail early: __getitem__ would otherwise divide by zero on the modulo.
            raise ValueError(f"No instance images found in {instance_data_root}")

        self.instance_prompt = INSTANCE_PROMPT if instance_prompt is None else instance_prompt
        self.class_prompt = CLASS_PROMPT if class_prompt is None else class_prompt
        self._length = self.num_instance_images * repeats

        if class_data_root is not None:
            self.class_images_path = self._list_images(class_data_root)
            self.num_class_images = len(self.class_images_path)
            # Iterate over the longer of the two lists.
            self._length = max(self.num_class_images, self.num_instance_images) * repeats
        else:
            self.class_images_path = None

        self.size = size
        self.tokenizer = tokenizer

        self.image_transforms = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    @staticmethod
    def _list_images(root):
        """Return the image files directly inside ``root``, sorted for a stable order."""
        return sorted(
            p for p in Path(root).iterdir() if p.suffix.lower() in DreamBoothDataset.IMAGE_SUFFIXES
        )

    def _load_image(self, path):
        """Open ``path``, convert it to RGB if needed and apply the image transforms."""
        # The context manager closes the file handle once the tensor is built.
        with Image.open(path) as image:
            if image.mode != "RGB":
                image = image.convert("RGB")
            return self.image_transforms(image)

    def _tokenize(self, prompt):
        """Tokenize ``prompt`` padded/truncated to the tokenizer's maximum length."""
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {
            "instance_images": self._load_image(
                self.instance_images_path[index % self.num_instance_images]
            ),
            "instance_prompt_ids": self._tokenize(self.instance_prompt),
        }

        # A missing (None) or empty class image list yields instance-only examples.
        if self.class_images_path:
            example["class_images"] = self._load_image(
                self.class_images_path[index % self.num_class_images]
            )
            example["class_prompt_ids"] = self._tokenize(self.class_prompt)

        return example

def collate_fn(examples):
    """Merge a list of dataset examples into one batch dict.

    Images are stacked along a new leading dimension. Prompt ids are stacked
    the same way and then squeezed on dimension 1. The class entries are added
    only when the first example carries "class_images".
    """

    def _gather(key):
        return torch.stack([sample[key] for sample in examples])

    batch = {
        "instance_images": _gather("instance_images"),
        "instance_prompt_ids": _gather("instance_prompt_ids").squeeze(1),
    }

    has_class_data = "class_images" in examples[0]
    if has_class_data:
        batch.update(
            class_images=_gather("class_images"),
            class_prompt_ids=_gather("class_prompt_ids").squeeze(1),
        )
    return batch

# Build the training dataset from the instance and class image folders.
train_dataset = DreamBoothDataset(
    INSTANCE_DIR,
    CLASS_DIR,
    tokenizer,
    size=RESOLUTION,
    repeats=1,
)

# Shuffled loader; collate_fn merges the per-example dicts into batch tensors.
loader_options = {
    "batch_size": TRAIN_BATCH_SIZE,
    "shuffle": True,
    "collate_fn": collate_fn,
    "num_workers": 2,
}
train_dataloader = torch.utils.data.DataLoader(train_dataset, **loader_options)

print(f"Dataset size: {len(train_dataset)}")


## 8. Setup optimizer và scheduler


In [None]:
# Select the optimizer implementation.
if USE_8BIT_ADAM:
    import bitsandbytes as bnb
    optimizer_class = bnb.optim.AdamW8bit
else:
    optimizer_class = torch.optim.AdamW

# Optionally scale the learning rate by the effective batch size.
learning_rate = LEARNING_RATE
if SCALE_LR:
    learning_rate = LEARNING_RATE * GRADIENT_ACCUMULATION_STEPS * TRAIN_BATCH_SIZE * accelerator.num_processes

# Optimize only the parameters with requires_grad=True to keep the optimizer state small.
trainable_params = [p for p in unet.parameters() if p.requires_grad]
num_trainable = sum(p.numel() for p in trainable_params)
num_total = sum(p.numel() for p in unet.parameters())
print(f"Trainable parameters: {num_trainable:,} / {num_total:,} ({100*num_trainable/num_total:.1f}%)")

optimizer = optimizer_class(
    trainable_params,
    lr=learning_rate,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-08,
)

# One optimizer update happens every GRADIENT_ACCUMULATION_STEPS batches.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / GRADIENT_ACCUMULATION_STEPS)
max_train_steps = MAX_TRAIN_STEPS
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

# The training loop steps the scheduler only together with the optimizer, so
# the schedule must span max_train_steps updates. Scaling by the number of
# processes (instead of the accumulation steps) follows the reference
# DreamBooth script; multiplying by GRADIENT_ACCUMULATION_STEPS would stretch
# any warmup/decay schedule far beyond the end of training.
lr_scheduler = get_scheduler(
    LR_SCHEDULER,
    optimizer=optimizer,
    num_warmup_steps=LR_WARMUP_STEPS * accelerator.num_processes,
    num_training_steps=max_train_steps * accelerator.num_processes,
)

unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    unet, optimizer, train_dataloader, lr_scheduler
)

# Without CPU offloading the frozen models live on the accelerator device.
if not ENABLE_CPU_OFFLOAD:
    text_encoder.to(accelerator.device, dtype=vae_dtype)
    vae.to(accelerator.device, dtype=vae_dtype)

print(f"Max train steps: {max_train_steps}")
print(f"Num epochs: {num_train_epochs}")
print(f"Mixed precision: {MIXED_PRECISION}")
print(f"CPU offloading: {ENABLE_CPU_OFFLOAD}")
print(f"UNet sẽ được autocast sang float16 trong forward pass" if MIXED_PRECISION == "fp16" else "UNet dùng float32")


## 9. Training loop


In [None]:
def _encode_images(pixel_values):
    """Encode images into scaled VAE latents, moving the VAE to the GPU only for the call when offloading."""
    if ENABLE_CPU_OFFLOAD:
        vae.to(accelerator.device, dtype=vae_dtype)

    with torch.no_grad():
        latents = vae.encode(pixel_values).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

    if ENABLE_CPU_OFFLOAD:
        vae.to("cpu")
        torch.cuda.empty_cache()
    return latents


def _encode_prompts(input_ids):
    """Run the text encoder on token ids, moving it to the GPU only for the call when offloading."""
    if ENABLE_CPU_OFFLOAD:
        text_encoder.to(accelerator.device, dtype=vae_dtype)

    with torch.no_grad():
        hidden_states = text_encoder(input_ids.squeeze(1))[0]

    if ENABLE_CPU_OFFLOAD:
        text_encoder.to("cpu")
        torch.cuda.empty_cache()
    return hidden_states


def _denoising_loss(images, prompt_ids):
    """Return the (optionally SNR-weighted) MSE denoising loss of the UNet for one image/prompt batch."""
    pixel_values = images.to(accelerator.device, dtype=vae_dtype)
    input_ids = prompt_ids.to(accelerator.device)

    latents = _encode_images(pixel_values)

    # Draw noise and one random timestep per sample, then noise the latents.
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
    ).long()
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    encoder_hidden_states = _encode_prompts(input_ids)

    # The regression target depends on the scheduler's prediction type.
    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type == "epsilon":
        target = noise
    elif prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    if SNR_GAMMA is None:
        return torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")

    # Weight each sample's loss by min(snr, SNR_GAMMA) / snr.
    snr = compute_snr(noise_scheduler, timesteps)
    mse_loss_weights = (
        torch.stack([snr, SNR_GAMMA * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
    )
    per_sample = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="none")
    per_sample = per_sample.mean(dim=list(range(1, len(per_sample.shape)))) * mse_loss_weights
    return per_sample.mean()


progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
global_step = 0

# Built once: the set of trainable parameters does not change during training.
trainable_params = [p for p in unet.parameters() if p.requires_grad]

for epoch in range(num_train_epochs):
    unet.train()
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            with accelerator.autocast():
                loss = _denoising_loss(batch["instance_images"], batch["instance_prompt_ids"])

                # Prior preservation: add the weighted loss on the class images.
                if "class_images" in batch:
                    class_loss = _denoising_loss(batch["class_images"], batch["class_prompt_ids"])
                    loss = loss + PRIOR_LOSS_WEIGHT * class_loss

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                # Clip in every precision mode. The UNet weights are kept in
                # float32, so clipping no longer has to be skipped under fp16
                # (the reference DreamBooth script also clips unconditionally).
                accelerator.clip_grad_norm_(trainable_params, 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Release cached blocks after each optimizer step.
                torch.cuda.empty_cache()

        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1

            # Periodic loss and memory logging.
            if global_step % 100 == 0 and accelerator.is_main_process:
                loss_value = loss.detach().item()
                accelerator.log({"loss": loss_value}, step=global_step)

                if torch.cuda.is_available():
                    memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB
                    memory_reserved = torch.cuda.memory_reserved() / 1024**3  # GB
                    accelerator.log({
                        "memory_allocated_gb": memory_allocated,
                        "memory_reserved_gb": memory_reserved
                    }, step=global_step)
                    print(f"Step {global_step}: Loss={loss_value:.4f}, "
                          f"Memory: {memory_allocated:.2f}GB allocated, {memory_reserved:.2f}GB reserved")

                # Periodic cache clean-up.
                torch.cuda.empty_cache()

            if global_step >= max_train_steps:
                break

    accelerator.wait_for_everyone()

    # The break above only leaves the batch loop; stop the epoch loop as well
    # so that no update runs beyond max_train_steps.
    if global_step >= max_train_steps:
        break

# Assemble a full pipeline around the fine-tuned UNet and save it.
if accelerator.is_main_process:
    unet = accelerator.unwrap_model(unet)
    pipeline = StableDiffusionPipeline.from_pretrained(
        MODEL_NAME,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
    )
    pipeline.save_pretrained(OUTPUT_DIR)
    print(f"Đã lưu checkpoint tại {OUTPUT_DIR}")

accelerator.end_training()


## 10. Ghi lại thông tin training


In [None]:
def _count_images(directory):
    """Return the number of image files (.jpg/.jpeg/.png, any case) directly inside ``directory``."""
    return sum(
        1 for entry in os.listdir(directory) if entry.lower().endswith(('.jpg', '.jpeg', '.png'))
    )


# Summary of the run, written next to the checkpoint as training_info.json.
training_info = {
    "style_name": STYLE_NAME,
    "model_name": MODEL_NAME,
    "max_train_steps": MAX_TRAIN_STEPS,
    "learning_rate": LEARNING_RATE,
    "batch_size": TRAIN_BATCH_SIZE,
    "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
    "mixed_precision": MIXED_PRECISION,
    "resolution": RESOLUTION,
    # Count the image files themselves; halving the directory size is wrong
    # whenever images and caption files are not exactly paired.
    "num_instance_images": _count_images(INSTANCE_DIR),
    "num_class_images": _count_images(CLASS_DIR),
    # Total size of every file under OUTPUT_DIR, in MiB.
    "checkpoint_size_mb": sum(f.stat().st_size for f in Path(OUTPUT_DIR).rglob('*') if f.is_file()) / (1024 * 1024),
    "timestamp": datetime.now().isoformat(),
}

info_path = os.path.join(OUTPUT_DIR, "training_info.json")
with open(info_path, "w", encoding="utf-8") as f:
    json.dump(training_info, f, indent=2)

print("Training info:")
for key, value in training_info.items():
    print(f"  {key}: {value}")
