
# Text-to-Pic: Lightweight Text-to-Image Trainer
Build and train a tiny text-to-image generator that fits inside ~1 GB RAM at inference while still being able to learn from multiple captioned image corpora. The notebook downloads or ingests several lightweight datasets, sets up a compact diffusion backbone, and walks through training, evaluation, and export steps you can run locally (RTX 3070 Ti) before deploying to your site.



## Notebook Map
1. Environment + dependency prep
2. Config (model size, datasets, logging)
3. Data ingest utilities (Hugging Face auto-discovery + local folders)
4. Tiny UNet + text encoder wiring with `diffusers`
5. Training loop w/ `accelerate`
6. Sampling utilities + qualitative eval grid
7. Export hints for 1 GB RAM web deployment


In [1]:

import os
import platform
from pathlib import Path
import torch

def describe_env():
    cwd = Path.cwd()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Working dir: {cwd}")
    print(f"Python: {platform.python_version()} | PyTorch: {torch.__version__}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(0)} | capability {torch.cuda.get_device_capability(0)}")
        print(f"Total VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    else:
        print("CUDA not detected - training will fall back to CPU (slow).")

describe_env()


Working dir: z:\Code
Python: 3.11.0 | PyTorch: 2.5.1+cu121
CUDA device: NVIDIA GeForce RTX 3070 Ti | capability (8, 6)
Total VRAM: 8.00 GB


In [2]:

# If the environment is missing dependencies, uncomment the next cell and run it once.
# %pip install -q --upgrade torch torchvision diffusers transformers datasets accelerate Pillow einops matplotlib tensorboard huggingface_hub



## 1. Configuration
Tweak the dataclasses below to control dataset mix, image size, training length, data discovery, and deployment-specific knobs. Defaults keep memory small (48x48 RGB outputs, Bert-tiny text encoder, mini UNet) so the exported weights load within ~900 MB when kept in fp16.


In [3]:
from dataclasses import dataclass, field, asdict
from typing import List, Optional

@dataclass
class DatasetSpec:
    name: str
    subset: Optional[str] = None
    split: str = "train"
    image_column: str = "auto"
    caption_column: str = "auto"
    max_samples: Optional[int] = None
    weight: float = 1.0
    type: str = "huggingface"  # or "imagefolder"
    local_dir: Optional[str] = None
    streaming: bool = False
    trust_remote_code: bool = False

@dataclass
class TrainingConfig:
    project_name: str = "text-to-pic"
    output_dir: str = "outputs/text_to_pic"
    checkpoint_dir: str = "checkpoints/text_to_pic"
    sample_dir: str = "outputs/text_to_pic/samples"
    data_cache_dir: str = "data/hf_text2pic"
    seed: int = 42
    image_size: int = 128
    train_batch_size: int = 8
    gradient_accumulation_steps: int = 2
    num_epochs: int = 32
    max_train_steps: int = 32000
    learning_rate: float = 5e-5
    lr_warmup_steps: int = 2000
    weight_decay: float = 5e-3
    mixed_precision: str = "fp16"  # fp16/bf16/no
    num_workers: int = 0
    max_grad_norm: float = 0.8
    num_inference_steps: int = 40
    guidance_scale: float = 5.5
    save_every: int = 2000
    eval_every: int = 4000
    max_prompt_length: int = 77
    tokenizer_name: str = "prajjwal1/bert-tiny"
    text_encoder_name: str = "prajjwal1/bert-tiny"
    noise_steps: int = 1000
    use_auto_dataset_search: bool = True
    auto_dataset_search: str = "text-to-image"
    auto_dataset_limit: int = 6
    auto_dataset_max_samples: int = 8000
    dataset_specs: List[DatasetSpec] = field(default_factory=lambda: [
        DatasetSpec(
            name="lambdalabs/naruto-blip-captions",
            image_column="image",
            caption_column="text",
            max_samples=12000,
            weight=0.1,
        ),
        DatasetSpec(
            name="poloclub/diffusiondb",
            subset="2m_first_1k",
            split="train",
            image_column="image",
            caption_column="prompt",
            max_samples=20000,
            weight=0.45,
            streaming=True,
            trust_remote_code=True,
        ),
        DatasetSpec(
            name="conceptual_captions",
            split="train",
            image_column="image",
            caption_column="caption",
            max_samples=25000,
            weight=0.45,
        ),
    ])
    eval_prompts: List[str] = field(default_factory=lambda: [
        "a cozy watercolor cabin in the woods, warm light",
        "a futuristic city skyline at dusk, cinematic lighting",
        "a bowl of ramen sketched in pastel art style",
        "an astronaut riding a horse through neon fog",
        "a vibrant street market in Marrakech, detailed",
        "a majestic lion resting on a rock, photorealistic",
        "a fantasy landscape with floating islands, vivid colors",
        "a close-up portrait of a cyberpunk character, dramatic lighting",
        "a serene beach at sunrise, soft pastel colors",
        "mario hitting a baseball with pikachu cheering in the background",
        "a steampunk airship flying over a bustling city, intricate details",
        "a macro photo of dew on a purple flower, ultra sharp",
    ])

config = TrainingConfig()
print(f"Configuration loaded. Image size: {config.image_size}px, total datasets: {len(config.dataset_specs)}")


Configuration loaded. Image size: 128px, total datasets: 3


In [4]:

import math
import random
from pathlib import Path

def set_seed(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def prepare_dirs(cfg: TrainingConfig):
    for folder in [cfg.output_dir, cfg.checkpoint_dir, cfg.sample_dir, cfg.data_cache_dir]:
        Path(folder).mkdir(parents=True, exist_ok=True)

set_seed(config.seed)
prepare_dirs(config)



## 2. Data ingestion helpers
Supports:
- HuggingFace datasets (multiple subsets) with automatic discovery + column detection
- Local folders staged under `data/` via the `imagefolder` loader
- Automatic download + caching of open datasets inside `config.data_cache_dir` so the repo holds everything needed for offline training
All datasets are normalized to `{ "image": PIL.Image, "caption": str, "source": str }` so the PyTorch dataset wrapper can blend them.


In [5]:

from pathlib import Path
import json
from datasets import load_dataset, concatenate_datasets, Dataset, DownloadConfig, load_from_disk
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torch.utils.data import Dataset as TorchDataset, DataLoader
from PIL import Image
from transformers import AutoTokenizer

hf_download_config = DownloadConfig(resume_download=True, max_retries=3)


def resolve_cache_dir(spec: DatasetSpec, cfg: TrainingConfig) -> Path:
    slug = spec.name.replace('/', '_')
    if spec.subset:
        slug += f"_{spec.subset}"
    cache_dir = Path(cfg.data_cache_dir) / slug
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir


def infer_columns(dataset, spec: DatasetSpec):
    image_col = spec.image_column
    caption_col = spec.caption_column
    if (image_col or '').lower() == 'auto':
        image_col = next((name for name, feature in dataset.features.items() if feature.__class__.__name__.lower() == 'image' or getattr(feature, '_type', '') == 'Image'), None)
    if (caption_col or '').lower() == 'auto':
        caption_candidates = [
            name for name, feature in dataset.features.items()
            if getattr(feature, 'dtype', '') == 'string' or feature.__class__.__name__ == 'Value'
        ]
        caption_col = caption_candidates[0] if caption_candidates else None
    if not image_col or not caption_col:
        raise ValueError(f"Unable to auto-detect columns for {spec.name}. Please set image_column/caption_column explicitly.")
    return image_col, caption_col


def load_spec_dataset(spec: DatasetSpec, cfg: TrainingConfig):
    cache_dir = resolve_cache_dir(spec, cfg)
    if spec.type == "imagefolder":
        if not spec.local_dir:
            raise ValueError(f"local_dir must be set for imagefolder spec {spec.name}")
        ds = load_dataset("imagefolder", data_dir=spec.local_dir, split=spec.split)
        return ds, cache_dir

    load_kwargs = dict(
        split=spec.split,
        streaming=spec.streaming,
        cache_dir=str(cache_dir),
        trust_remote_code=spec.trust_remote_code,
    )
    if not spec.streaming:
        load_kwargs.update(download_config=hf_download_config, keep_in_memory=False)
    print(f"Downloading {spec.name} ({spec.split}) -> {cache_dir}")
    ds = load_dataset(spec.name, spec.subset, **load_kwargs)
    if spec.streaming:
        max_take = spec.max_samples or 2000
        ds = Dataset.from_list(list(ds.take(max_take)))
    return ds, cache_dir


def auto_discover_datasets(cfg: TrainingConfig):
    if not cfg.use_auto_dataset_search or cfg.auto_dataset_limit <= 0:
        return []
    try:
        from huggingface_hub import list_datasets, DatasetFilter
    except ImportError:
        print("huggingface_hub missing - skip auto dataset discovery.")
        return []
    hf_filter = DatasetFilter(task_categories=["text-to-image", "image-to-text"])
    items = list_datasets(
        search=cfg.auto_dataset_search,
        filter=hf_filter,
        limit=cfg.auto_dataset_limit,
    )
    extra_specs = []
    existing_names = {spec.name for spec in cfg.dataset_specs}
    for item in items:
        if item.id in existing_names:
            continue
        extra_specs.append(
            DatasetSpec(
                name=item.id,
                image_column="auto",
                caption_column="auto",
                max_samples=cfg.auto_dataset_max_samples,
                weight=0.15,
            )
        )
    return extra_specs


auto_specs = auto_discover_datasets(config)
if auto_specs:
    config.dataset_specs.extend(auto_specs)
    print(f"Auto-added {len(auto_specs)} datasets from HuggingFace search '{config.auto_dataset_search}'. Total now: {len(config.dataset_specs)}")
else:
    print("No extra datasets added via HuggingFace search (check config.auto_dataset_search or install huggingface_hub)")


def load_and_merge_datasets(cfg: TrainingConfig):
    dataset_pieces = []
    stats = []
    for spec in cfg.dataset_specs:
        if spec.weight <= 0:
            continue
        ds, cache_dir = load_spec_dataset(spec, cfg)
        image_col, caption_col = infer_columns(ds, spec)
        keep_map = {image_col: "image", caption_col: "caption"}
        missing = [key for key in keep_map if key not in ds.column_names]
        if missing:
            print(f"Skipping {spec.name} â€” missing columns {missing}")
            continue
        drop_cols = [c for c in ds.column_names if c not in keep_map]
        if drop_cols:
            ds = ds.remove_columns(drop_cols)
        ds = ds.rename_columns(keep_map)
        if spec.max_samples and not spec.streaming and spec.max_samples < len(ds):
            ds = ds.shuffle(seed=cfg.seed).select(range(spec.max_samples))
        size = len(ds)
        ds = ds.add_column("source", [spec.name] * size)
        ds = ds.add_column("weight", [spec.weight] * size)
        dataset_pieces.append(ds)
        stats.append({"name": spec.name, "samples": size, "weight": spec.weight, "cache_dir": str(cache_dir)})
    if not dataset_pieces:
        raise RuntimeError("No datasets were loaded. Please check DatasetSpec entries or auto-discovery settings.")
    merged = concatenate_datasets(dataset_pieces).shuffle(seed=cfg.seed)
    print(f"Total merged samples: {len(merged)}")
    return merged, stats


tokenizer = AutoTokenizer.from_pretrained(
    config.tokenizer_name,
    model_max_length=config.max_prompt_length,
    padding="max_length",
    truncation=True,
)

def build_transforms(cfg: TrainingConfig):
    return transforms.Compose([
        transforms.Resize(cfg.image_size, interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(cfg.image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])


class TextImageDataset(TorchDataset):
    def __init__(self, hf_dataset, tokenizer, cfg: TrainingConfig):
        self.dataset = hf_dataset
        self.tokenizer = tokenizer
        self.cfg = cfg
        self.transforms = build_transforms(cfg)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[int(idx)]
        image = sample["image"]
        if not isinstance(image, Image.Image):
            image = Image.open(image).convert("RGB")
        else:
            image = image.convert("RGB")
        caption = sample.get("caption") or "an image"
        tokens = self.tokenizer(
            caption,
            max_length=self.cfg.max_prompt_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {
            "pixel_values": self.transforms(image),
            "input_ids": tokens["input_ids"].squeeze(0),
            "attention_mask": tokens["attention_mask"].squeeze(0),
            "caption": caption,
            "source": sample.get("source", "unknown"),
        }


merged_cache_dir = Path(config.data_cache_dir) / "merged_dataset"
stats_cache_path = merged_cache_dir / "stats.json"

if merged_cache_dir.exists() and (merged_cache_dir / "dataset_info.json").exists():
    print(f"Loading cached merged dataset from {merged_cache_dir}")
    raw_dataset = load_from_disk(str(merged_cache_dir))
    if stats_cache_path.exists():
        dataset_stats = json.loads(stats_cache_path.read_text())
    else:
        dataset_stats = []
else:
    raw_dataset, dataset_stats = load_and_merge_datasets(config)
    merged_cache_dir.mkdir(parents=True, exist_ok=True)
    raw_dataset.save_to_disk(str(merged_cache_dir))
    stats_cache_path.write_text(json.dumps(dataset_stats, indent=2))

train_dataset = TextImageDataset(raw_dataset, tokenizer, config)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.train_batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
    drop_last=True,
)

print("Dataset mix:")
for entry in dataset_stats:
    print(f" - {entry['name']}: {entry['samples']} samples (weight {entry['weight']}), cached in {entry['cache_dir']}")
print(f"Batches/epoch (approx): {len(train_dataloader)}")


  from .autonotebook import tqdm as notebook_tqdm


huggingface_hub missing - skip auto dataset discovery.
No extra datasets added via HuggingFace search (check config.auto_dataset_search or install huggingface_hub)
Loading cached merged dataset from data\hf_text2pic\merged_dataset
Dataset mix:
 - lambdalabs/naruto-blip-captions: 1221 samples (weight 0.2), cached in data\hf_text2pic\lambdalabs_naruto-blip-captions
 - poloclub/diffusiondb: 1000 samples (weight 0.3), cached in data\hf_text2pic\poloclub_diffusiondb_2m_first_1k
Batches/epoch (approx): 277



## 3. Model and scheduler setup
We use a tiny text encoder (BERT-tiny) and a heavily down-scaled `UNet2DConditionModel` from diffusers. This keeps inference memory under 1 GB (fp16 weights about 140 MB plus runtime buffers). Training uses `Accelerate` so you can scale to multi-GPU later without code changes.


In [6]:

from accelerate import Accelerator
from diffusers import UNet2DConditionModel, DDPMScheduler
from transformers import AutoModel, get_cosine_schedule_with_warmup
import torch.nn.functional as F
from torch.optim import AdamW

accelerator = Accelerator(
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    mixed_precision=config.mixed_precision,
)

device = accelerator.device
precision_dtype = torch.float16 if accelerator.mixed_precision == "fp16" else (
    torch.bfloat16 if accelerator.mixed_precision == "bf16" else torch.float32
)

text_encoder = AutoModel.from_pretrained(
    config.text_encoder_name,
    use_safetensors=True,
    torch_dtype=precision_dtype,
)
text_encoder.requires_grad_(False)
text_encoder.to(device=device, dtype=precision_dtype)
text_encoder.eval()

noise_scheduler = DDPMScheduler(
    num_train_timesteps=config.noise_steps,
    beta_schedule="squaredcos_cap_v2",
    prediction_type="epsilon",
)

unet = UNet2DConditionModel(
    sample_size=config.image_size,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 192),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
    attention_head_dim=4,
    cross_attention_dim=text_encoder.config.hidden_size,
)

optimizer = AdamW(unet.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=config.max_train_steps,
)

unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    unet, optimizer, train_dataloader, lr_scheduler
)

from pprint import pprint
accelerator.print("Training with configuration:")
accelerator.print(pprint(asdict(config)))


`torch_dtype` is deprecated! Use `dtype` instead!


Training with configuration:
{'auto_dataset_limit': 6,
 'auto_dataset_max_samples': 8000,
 'auto_dataset_search': 'text-to-image',
 'checkpoint_dir': 'checkpoints/text_to_pic',
 'data_cache_dir': 'data/hf_text2pic',
 'dataset_specs': [{'caption_column': 'text',
                    'image_column': 'image',
                    'local_dir': None,
                    'max_samples': 12000,
                    'name': 'lambdalabs/naruto-blip-captions',
                    'split': 'train',
                    'streaming': False,
                    'subset': None,
                    'trust_remote_code': False,
                    'type': 'huggingface',
                    'weight': 0.1},
                   {'caption_column': 'prompt',
                    'image_column': 'image',
                    'local_dir': None,
                    'max_samples': 20000,
                    'name': 'poloclub/diffusiondb',
                    'split': 'train',
                    'streaming': True,
     


## 4. Training utilities
Includes checkpoint saving/loading, preview sampling during training, and the main loop. To actually train set `RUN_TRAINING = True` in the cell after the loop definition.


In [7]:
from torchvision.utils import make_grid, save_image
import time

def unwrap_model_safely(model):
    try:
        return accelerator.unwrap_model(model)
    except (ImportError, ModuleNotFoundError, RuntimeError, AttributeError):
        return getattr(model, "module", model)

@torch.no_grad()
def sample_prompts(prompts, model, tokenizer, text_encoder, scheduler, cfg: TrainingConfig, device, guidance_scale=None, num_inference_steps=None):
    model = unwrap_model_safely(model)
    model.eval()
    scheduler = DDPMScheduler.from_config(scheduler.config)
    num_inference_steps = num_inference_steps or cfg.num_inference_steps
    scheduler.set_timesteps(num_inference_steps, device=device)
    guidance_scale = guidance_scale or cfg.guidance_scale
    batch_size = len(prompts)
    latents = torch.randn((batch_size, 3, cfg.image_size, cfg.image_size), device=device, dtype=precision_dtype)

    text_inputs = tokenizer(prompts, max_length=cfg.max_prompt_length, padding="max_length", truncation=True, return_tensors="pt")
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
    text_embeddings = text_encoder(**text_inputs).last_hidden_state

    uncond_inputs = tokenizer([" "] * batch_size, max_length=cfg.max_prompt_length, padding="max_length", return_tensors="pt")
    uncond_inputs = {k: v.to(device) for k, v in uncond_inputs.items()}
    uncond_embeddings = text_encoder(**uncond_inputs).last_hidden_state

    for t in scheduler.timesteps:
        latent_model_input = torch.cat([latents, latents], dim=0)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        encoder_hidden_states = torch.cat([uncond_embeddings, text_embeddings], dim=0)
        noise_pred = model(latent_model_input, t, encoder_hidden_states=encoder_hidden_states).sample
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    imgs = (latents * 0.5 + 0.5).clamp(0, 1)
    return imgs.cpu()


def build_checkpoint_state(step, epoch, best_loss, model, optimizer, lr_scheduler, cfg: TrainingConfig):
    unwrapped = unwrap_model_safely(model)
    return {
        "unet": unwrapped.state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "step": step,
        "epoch": epoch,
        "best_loss": best_loss,
        "config": asdict(cfg),
    }


def persist_checkpoint(step, epoch, best_loss, model, optimizer, lr_scheduler, cfg: TrainingConfig, tag: str, state=None, announce: bool = True):
    accelerator.wait_for_everyone()
    if state is None:
        state = build_checkpoint_state(step, epoch, best_loss, model, optimizer, lr_scheduler, cfg)
    ckpt_dir = Path(cfg.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = ckpt_dir / f"{tag}.pt"
    accelerator.save(state, ckpt_path)
    if announce and accelerator.is_main_process:
        accelerator.print(f"Saved checkpoint '{tag}' to {ckpt_path}")
    return state


def resume_from_checkpoint(model, optimizer, lr_scheduler, cfg: TrainingConfig, checkpoint_path: str | None = None):
    ckpt_path = Path(checkpoint_path) if checkpoint_path else Path(cfg.checkpoint_dir) / "latest.pt"
    if not ckpt_path.exists():
        return 0, 0, float("inf"), None
    accelerator.print(f"Resuming from {ckpt_path}")
    state = torch.load(ckpt_path, map_location="cpu")
    unwrap_model_safely(model).load_state_dict(state["unet"])
    optimizer.load_state_dict(state["optimizer"])
    lr_scheduler.load_state_dict(state["lr_scheduler"])
    return state.get("step", 0), state.get("epoch", 0), state.get("best_loss", float("inf")), ckpt_path


def train_loop(resume_from: str | None = None):
    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps
    accelerator.print(f"Starting training. Effective batch size: {total_batch_size}")

    checkpoint_dir = Path(config.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    global_step = 0
    best_loss = float("inf")
    start_epoch = 0

    resume_target = resume_from
    if resume_target is None:
        latest_candidate = checkpoint_dir / "latest.pt"
        if latest_candidate.exists():
            resume_target = str(latest_candidate)

    if resume_target:
        step, epoch, best_loss, used_path = resume_from_checkpoint(unet, optimizer, lr_scheduler, config, resume_target)
        if used_path is not None:
            global_step = step
            start_epoch = epoch
            accelerator.print(f"Resumed state -> epoch {epoch}, step {step}, best_loss {best_loss:.4f}")
        else:
            accelerator.print(f"Requested resume checkpoint '{resume_target}' not found. Starting fresh.")

    batches_per_epoch = len(train_dataloader)
    if batches_per_epoch == 0:
        accelerator.print("Train dataloader returned 0 batches. Check dataset configuration.")
        return
    progress_interval = max(1, batches_per_epoch // 20)
    report_every_secs = 30
    dataloader_wait_threshold = 20

    stop_early = False
    for epoch in range(start_epoch, config.num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        epoch_start = time.perf_counter()
        last_report = epoch_start
        if accelerator.is_main_process:
            accelerator.print(f"Epoch {epoch + 1}/{config.num_epochs} started ({batches_per_epoch} batches)")
        data_iter = iter(train_dataloader)

        for batch_idx in range(batches_per_epoch):
            wait_start = time.perf_counter()
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(train_dataloader)
                batch = next(data_iter)
            wait_duration = time.perf_counter() - wait_start
            if accelerator.is_main_process and wait_duration >= dataloader_wait_threshold:
                accelerator.print(
                    f"Epoch {epoch + 1}/{config.num_epochs} batch {batch_idx + 1}: dataloader stalled {wait_duration:.1f}s"
                )

            if global_step >= config.max_train_steps:
                stop_early = True
                break
            with accelerator.accumulate(unet):
                pixel_values = batch["pixel_values"].to(device, dtype=precision_dtype)
                noise = torch.randn_like(pixel_values)
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (pixel_values.shape[0],), device=device
                ).long()
                noisy_images = noise_scheduler.add_noise(pixel_values, noise, timesteps).to(pixel_values.dtype)

                encoder_input = {
                    "input_ids": batch["input_ids"].to(device),
                    "attention_mask": batch["attention_mask"].to(device),
                }
                with torch.no_grad():
                    encoder_hidden_states = text_encoder(**encoder_input).last_hidden_state.to(precision_dtype)
                model_pred = unet(noisy_images, timesteps, encoder_hidden_states=encoder_hidden_states).sample
                loss = F.mse_loss(model_pred.float(), noise.float())

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            step_loss = loss.detach().item()
            epoch_loss += step_loss
            num_batches += 1
            global_step += 1

            time_since_last = time.perf_counter() - last_report
            should_report = (
                accelerator.is_main_process
                and (
                    (batch_idx + 1) % progress_interval == 0
                    or batch_idx == 0
                    or (batch_idx + 1) == batches_per_epoch
                    or time_since_last >= report_every_secs
                )
            )
            if should_report:
                progress = (batch_idx + 1) / batches_per_epoch
                elapsed = time.perf_counter() - epoch_start
                eta = (elapsed / progress - elapsed) if progress > 0 else 0.0
                accelerator.print(
                    f"Epoch {epoch + 1}/{config.num_epochs} [{batch_idx + 1}/{batches_per_epoch} | {progress * 100:.1f}%] "
                    f"global step {global_step} | loss {step_loss:.4f} | elapsed {elapsed/60:.1f}m ETA {eta/60:.1f}m"
                )
                last_report = time.perf_counter()
            if accelerator.is_main_process and config.save_every and global_step % config.save_every == 0:
                state = persist_checkpoint(global_step, epoch, best_loss, unet, optimizer, lr_scheduler, config, tag=f"step_{global_step:06d}")
                persist_checkpoint(global_step, epoch, best_loss, unet, optimizer, lr_scheduler, config, tag="latest", state=state, announce=False)

            if accelerator.is_main_process and config.eval_every and global_step % config.eval_every == 0:
                imgs = sample_prompts(config.eval_prompts, unet, tokenizer, text_encoder, noise_scheduler, config, device)
                grid = make_grid(imgs, nrow=min(4, len(imgs)))
                out_path = Path(config.sample_dir) / f"preview_step_{global_step:06d}.png"
                save_image(grid, out_path)
                accelerator.print(f"Saved preview grid to {out_path}")

        if num_batches == 0:
            accelerator.print(f"Epoch {epoch + 1} had no batches; check dataset configuration.")
            continue

        avg_epoch_loss = epoch_loss / num_batches
        epoch_time = time.perf_counter() - epoch_start
        if accelerator.is_main_process:
            state = persist_checkpoint(global_step, epoch + 1, best_loss, unet, optimizer, lr_scheduler, config, tag=f"epoch_{epoch + 1:03d}")
            persist_checkpoint(global_step, epoch + 1, best_loss, unet, optimizer, lr_scheduler, config, tag="latest", state=state, announce=False)
            if avg_epoch_loss < best_loss:
                best_loss = avg_epoch_loss
                persist_checkpoint(global_step, epoch + 1, best_loss, unet, optimizer, lr_scheduler, config, tag="best", state=state, announce=False)
                accelerator.print(f"New best loss {best_loss:.4f} at epoch {epoch + 1}")
        accelerator.print(f"Epoch {epoch + 1}/{config.num_epochs} complete | avg loss {avg_epoch_loss:.4f} | time {epoch_time/60:.1f}m")

        if stop_early:
            break

    accelerator.print("Training complete.")


In [8]:

RUN_TRAINING = False # <- flip to True when ready
if RUN_TRAINING:
    train_loop()
else:
    print("Training loop is defined but not running. Set RUN_TRAINING = True to start training.")


Training loop is defined but not running. Set RUN_TRAINING = True to start training.


In [9]:
import time

@torch.no_grad()
def run_inference(prompts=None, checkpoint: str = "best", num_inference_steps=None, guidance_scale=None, save_name: str | None = None):
    ckpt_dir = Path(config.checkpoint_dir)
    if checkpoint in {"best", "latest"}:
        ckpt_path = ckpt_dir / f"{checkpoint}.pt"
    else:
        ckpt_path = Path(checkpoint)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint {ckpt_path} not found")

    state = torch.load(ckpt_path, map_location="cpu")
    unwrap_model_safely(unet).load_state_dict(state["unet"])
    prompts = prompts or config.eval_prompts

    images = sample_prompts(
        prompts,
        unet,
        tokenizer,
        text_encoder,
        noise_scheduler,
        config,
        device,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
    grid = make_grid(images, nrow=min(4, len(images)))
    stamp = save_name or f"{ckpt_path.stem}_{int(time.time())}"
    out_path = Path(config.sample_dir) / f"inference_{stamp}.png"
    save_image(grid, out_path)
    accelerator.print(f"Saved inference grid to {out_path}")
    return out_path

# Example usage (uncomment to run a quick check):
# run_inference()


In [10]:
from IPython.display import display
from torchvision.transforms import functional as TF

SINGLE_PROMPT_DIR = Path(config.sample_dir) / "single_prompts"
SINGLE_PROMPT_DIR.mkdir(parents=True, exist_ok=True)

def generate_prompt_image(prompt: str, checkpoint: str = "best", guidance_scale=None, num_inference_steps=None, save_name: str | None = None):
    ckpt_dir = Path(config.checkpoint_dir)
    if checkpoint in {"best", "latest"}:
        ckpt_path = ckpt_dir / f"{checkpoint}.pt"
    else:
        ckpt_path = Path(checkpoint)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint {ckpt_path} not found")

    state = torch.load(ckpt_path, map_location="cpu")
    unwrap_model_safely(unet).load_state_dict(state["unet"])

    images = sample_prompts(
        [prompt],
        unet,
        tokenizer,
        text_encoder,
        noise_scheduler,
        config,
        device,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
    image = images[0]
    save_slug = save_name or f"prompt_{int(time.time())}"
    out_path = SINGLE_PROMPT_DIR / f"{save_slug}.png"
    save_image(image, out_path)
    accelerator.print(f"Saved single-prompt image to {out_path}")
    display(TF.to_pil_image(image))
    return out_path

# Example usage:
generate_prompt_image("a retro myspace-style selfie with neon lighting", checkpoint="best")


  state = torch.load(ckpt_path, map_location="cpu")


Saved single-prompt image to outputs\text_to_pic\samples\single_prompts\prompt_1763121310.png


NameError: name 'TF' is not defined

In [None]:
ï»¿from torchvision.transforms import functional as TF

FINETUNE_ROOT = Path("finetune")
FINETUNE_CHECKPOINT_DIR = Path(config.checkpoint_dir) / "finetune"
FINETUNE_SAMPLE_DIR = Path(config.sample_dir) / "finetune"
FINETUNE_DEPLOY_DIR = Path(config.output_dir) / "deployment" / "finetune"
ALLOWED_FINETUNE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}

class FinetuneFolderDataset(TorchDataset):
    def __init__(self, root: Path, tokenizer, cfg: TrainingConfig, default_caption: str = "nostalgic myspace portrait"):
        root = Path(root)
        files = sorted([p for p in root.rglob("*") if p.suffix.lower() in ALLOWED_FINETUNE_SUFFIXES])
        if not files:
            raise ValueError(f"No finetune images found in {root}. Add a few JPG/PNG/WebP files.")
        self.files = files
        self.tokenizer = tokenizer
        self.cfg = cfg
        self.default_caption = default_caption
        self.transforms = transforms.Compose([
            transforms.Lambda(lambda img: TF.center_crop(img, min(img.size))),
            transforms.Resize(cfg.image_size, interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        image = Image.open(path).convert("RGB")
        image = self.transforms(image)
        caption = path.stem.replace("_", " ") or self.default_caption
        tokens = tokenizer(
            caption,
            max_length=self.cfg.max_prompt_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {
            "pixel_values": image,
            "input_ids": tokens["input_ids"].squeeze(0),
            "attention_mask": tokens["attention_mask"].squeeze(0),
            "caption": caption,
            "path": str(path),
        }

def build_finetune_dataloader(root=FINETUNE_ROOT, default_caption="nostalgic myspace portrait"):
    dataset = FinetuneFolderDataset(root, tokenizer, config, default_caption=default_caption)
    loader = DataLoader(
        dataset,
        batch_size=max(1, config.train_batch_size // 2),
        shuffle=True,
        num_workers=0,
        drop_last=len(dataset) > 1,
    )
    accelerator.print(f"Finetune dataset -> {len(dataset)} images from {root}")
    return dataset, loader

def save_finetune_checkpoint(step, epoch, best_loss, model, optimizer, scheduler, tag: str, state=None):
    FINETUNE_CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    state_dict = state if state is not None else {
        "unet": unwrap_model_safely(model).state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": scheduler.state_dict(),
        "step": step,
        "epoch": epoch,
        "best_loss": best_loss,
    }
    ckpt_path = FINETUNE_CHECKPOINT_DIR / f"{tag}.pt"
    accelerator.save(state_dict, ckpt_path)
    return state_dict, ckpt_path

def run_finetune(epochs: int = 6, resume: str | None = None, base_checkpoint: str | None = None):
    dataset, dataloader = build_finetune_dataloader()
    total_batches = len(dataloader)
    if total_batches == 0:
        raise ValueError("Finetune dataloader has zero batches.")

    FINETUNE_SAMPLE_DIR.mkdir(parents=True, exist_ok=True)
    base_candidates = []
    if resume:
        base_candidates.append(Path(resume))
    if base_checkpoint:
        base_candidates.append(Path(base_checkpoint))
    base_candidates.append(FINETUNE_CHECKPOINT_DIR / "finetune_best.pt")
    base_candidates.append(Path(config.checkpoint_dir) / "best.pt")
    base_candidates.append(Path(config.checkpoint_dir) / "latest.pt")
    ckpt_path = next((p for p in base_candidates if p and p.exists()), None)
    if ckpt_path:
        state = torch.load(ckpt_path, map_location="cpu")
        unwrap_model_safely(unet).load_state_dict(state["unet"])
        accelerator.print(f"Loaded base weights from {ckpt_path}")
    else:
        accelerator.print("No base checkpoint found; finetuning current weights.")

    ft_lr = config.learning_rate * 0.25
    optimizer_ft = AdamW(unet.parameters(), lr=ft_lr, weight_decay=config.weight_decay)
    total_steps = epochs * total_batches
    scheduler_ft = get_cosine_schedule_with_warmup(
        optimizer_ft,
        num_warmup_steps=max(10, total_steps // 20),
        num_training_steps=total_steps,
    )

    global_step = 0
    best_loss = float("inf")
    best_path = FINETUNE_CHECKPOINT_DIR / "finetune_best.pt"
    for epoch in range(epochs):
        epoch_loss = 0.0
        epoch_start = time.perf_counter()
        if accelerator.is_main_process:
            accelerator.print(f"[Finetune] Epoch {epoch + 1}/{epochs} ({total_batches} batches)")
        for batch_idx, batch in enumerate(dataloader):
            with accelerator.accumulate(unet):
                pixel_values = batch["pixel_values"].to(device, dtype=precision_dtype)
                noise = torch.randn_like(pixel_values)
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (pixel_values.shape[0],), device=device).long()
                noisy_images = noise_scheduler.add_noise(pixel_values, noise, timesteps).to(pixel_values.dtype)
                encoder_input = {
                    "input_ids": batch["input_ids"].to(device),
                    "attention_mask": batch["attention_mask"].to(device),
                }
                with torch.no_grad():
                    encoder_hidden_states = text_encoder(**encoder_input).last_hidden_state.to(precision_dtype)
                model_pred = unet(noisy_images, timesteps, encoder_hidden_states=encoder_hidden_states).sample
                loss = F.mse_loss(model_pred.float(), noise.float())

                accelerator.backward(loss)
                optimizer_ft.step()
                scheduler_ft.step()
                optimizer_ft.zero_grad()

            step_loss = loss.detach().item()
            epoch_loss += step_loss
            global_step += 1
            if accelerator.is_main_process:
                accelerator.print(
                    f"[Finetune] Epoch {epoch + 1}/{epochs} step {global_step}/{total_steps} | batch {batch_idx + 1}/{total_batches} | loss {step_loss:.4f}"
                )

        avg_loss = epoch_loss / total_batches
        state_dict, _ = save_finetune_checkpoint(global_step, epoch + 1, best_loss, unet, optimizer_ft, scheduler_ft, tag=f"epoch_{epoch + 1:03d}")
        _, latest_path = save_finetune_checkpoint(global_step, epoch + 1, best_loss, unet, optimizer_ft, scheduler_ft, tag="latest", state=state_dict)
        if avg_loss < best_loss:
            best_loss = avg_loss
            _, best_path = save_finetune_checkpoint(global_step, epoch + 1, best_loss, unet, optimizer_ft, scheduler_ft, tag="finetune_best", state=state_dict)
            accelerator.print(f"[Finetune] New best loss {best_loss:.4f} at epoch {epoch + 1}")
        accelerator.print(f"[Finetune] Epoch {epoch + 1}/{epochs} avg loss {avg_loss:.4f} | time {(time.perf_counter() - epoch_start)/60:.1f}m")

        imgs = sample_prompts(config.eval_prompts[:4], unet, tokenizer, text_encoder, noise_scheduler, config, device)
        grid = make_grid(imgs, nrow=min(4, len(imgs)))
        out_path = FINETUNE_SAMPLE_DIR / f"finetune_epoch_{epoch + 1:03d}.png"
        save_image(grid, out_path)
        accelerator.print(f"[Finetune] Saved preview to {out_path}")

    accelerator.print(f"Finetune complete. Latest checkpoint: {latest_path}")
    return latest_path, best_path

def export_finetune_model(checkpoint: str | None = None):
    FINETUNE_DEPLOY_DIR.mkdir(parents=True, exist_ok=True)
    if checkpoint is None:
        checkpoint = FINETUNE_CHECKPOINT_DIR / "finetune_best.pt"
    ckpt_path = Path(checkpoint)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint {ckpt_path} not found.")
    state = torch.load(ckpt_path, map_location="cpu")
    unwrap_model_safely(unet).load_state_dict(state["unet"])
    traceable_cls = globals().get("TraceableUNet")
    if traceable_cls is None:
        class TraceableUNet(torch.nn.Module):
            def __init__(self, unet):
                super().__init__()
                self.unet = unet

            def forward(self, sample, timestep, encoder_hidden_states):
                return self.unet(sample, timestep, encoder_hidden_states=encoder_hidden_states).sample

        traceable_cls = TraceableUNet
    model_to_export = unwrap_model_safely(unet).to(torch.float32).cpu()
    model_to_export.eval()
    traceable = traceable_cls(model_to_export)
    dtype = next(model_to_export.parameters()).dtype
    dummy_latents = torch.randn(1, 3, config.image_size, config.image_size, dtype=dtype)
    dummy_timestep = torch.tensor([0], dtype=torch.long)
    dummy_hidden = torch.randn(1, config.max_prompt_length, text_encoder.config.hidden_size, dtype=dtype)
    with torch.inference_mode(), torch.autocast("cpu", enabled=False):
        traced = torch.jit.trace(traceable, (dummy_latents, dummy_timestep, dummy_hidden), strict=False)
        traced_path = FINETUNE_DEPLOY_DIR / "finetunemodel.ts"
        traced.save(traced_path)
    accelerator.print(f"Exported finetune TorchScript model to {traced_path}")
    return traced_path

# Example usage:
# latest_path, best_path = run_finetune(epochs=6)
# export_finetune_model(best_path)


In [None]:
ï»¿import copy

STRONGMODEL_DATASETS = copy.deepcopy(config.dataset_specs)
STRONGMODEL_EVAL_PROMPTS = list(dict.fromkeys(config.eval_prompts + [
    "a detailed charcoal sketch of a cat lounging on a sofa",
    "a busy city street at noon, photorealistic",
    "a vibrant watercolor of mountains and rivers",
    "a child's drawing of a rocket ship, crayon texture",
    "a close-up of a succulent plant, macro lens",
    "an 8-bit pixel art wizard casting a spell",
]))

STRONGMODEL_CONFIG = TrainingConfig(
    project_name=f"{config.project_name}-strong",
    output_dir="outputs/strong_text_to_pic",
    checkpoint_dir="checkpoints/strong_text_to_pic",
    sample_dir="outputs/strong_text_to_pic/samples",
    data_cache_dir=config.data_cache_dir,
    seed=config.seed,
    image_size=256,
    train_batch_size=4,
    gradient_accumulation_steps=4,
    num_epochs=100,
    max_train_steps=0,
    learning_rate=2e-5,
    lr_warmup_steps=4000,
    weight_decay=1e-2,
    mixed_precision=config.mixed_precision,
    num_workers=0,
    max_grad_norm=0.7,
    num_inference_steps=60,
    guidance_scale=4.5,
    save_every=2000,
    eval_every=4000,
    max_prompt_length=77,
    tokenizer_name="openai/clip-vit-base-patch32",
    text_encoder_name="openai/clip-vit-base-patch32",
    noise_steps=1000,
    use_auto_dataset_search=config.use_auto_dataset_search,
    auto_dataset_search=config.auto_dataset_search,
    auto_dataset_limit=config.auto_dataset_limit,
    auto_dataset_max_samples=config.auto_dataset_max_samples,
    dataset_specs=STRONGMODEL_DATASETS,
    eval_prompts=STRONGMODEL_EVAL_PROMPTS,
)

STRONGMODEL_CHECKPOINT_DIR = Path(STRONGMODEL_CONFIG.checkpoint_dir)
STRONGMODEL_SAMPLE_DIR = Path(STRONGMODEL_CONFIG.sample_dir)
STRONGMODEL_CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
STRONGMODEL_SAMPLE_DIR.mkdir(parents=True, exist_ok=True)


def build_strongmodel_components(cfg: TrainingConfig):
    tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_name)
    text_encoder = AutoModel.from_pretrained(
        cfg.text_encoder_name,
        use_safetensors=True,
        torch_dtype=precision_dtype,
    )
    text_encoder.requires_grad_(False)
    text_encoder.to(device=device, dtype=precision_dtype)
    text_encoder.eval()

    noise_scheduler = DDPMScheduler(
        num_train_timesteps=cfg.noise_steps,
        beta_schedule="squaredcos_cap_v2",
        prediction_type="epsilon",
    )

    unet = UNet2DConditionModel(
        sample_size=cfg.image_size,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(256, 320, 512, 640),
        down_block_types=(
            "DownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
        ),
        up_block_types=(
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
            "UpBlock2D",
        ),
        attention_head_dim=8,
        cross_attention_dim=text_encoder.config.hidden_size,
    )
    unet.to(device=device, dtype=precision_dtype)
    return tokenizer, text_encoder, noise_scheduler, unet


def build_strongmodel_dataloader(cfg: TrainingConfig, tokenizer):
    if 'raw_dataset' in globals():
        source_dataset = raw_dataset
    else:
        merged_dir = Path(cfg.data_cache_dir) / "merged_dataset"
        if merged_dir.exists():
            source_dataset = load_from_disk(str(merged_dir))
        else:
            source_dataset, _ = load_and_merge_datasets(cfg)
    dataset = TextImageDataset(source_dataset, tokenizer, cfg)
    loader = DataLoader(
        dataset,
        batch_size=cfg.train_batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    accelerator.print(f"[StrongModel] Dataset ready -> {len(dataset)} samples | {len(loader)} batches/epoch")
    return dataset, loader


def save_strongmodel_checkpoint(step, epoch, best_loss, model, optimizer, scheduler, cfg: TrainingConfig, tag: str, state=None):
    state_dict = state if state is not None else {
        "unet": unwrap_model_safely(model).state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": scheduler.state_dict(),
        "step": step,
        "epoch": epoch,
        "best_loss": best_loss,
        "config": asdict(cfg),
    }
    ckpt_path = STRONGMODEL_CHECKPOINT_DIR / f"{tag}.pt"
    accelerator.save(state_dict, ckpt_path)
    return state_dict, ckpt_path


def train_strongmodel(cfg: TrainingConfig = STRONGMODEL_CONFIG, resume: str | None = None):
    tokenizer, text_encoder, noise_scheduler, unet = build_strongmodel_components(cfg)
    dataset, dataloader = build_strongmodel_dataloader(cfg, tokenizer)

    optimizer = AdamW(unet.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    total_steps = cfg.num_epochs * len(dataloader)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=max(cfg.lr_warmup_steps, total_steps // 20),
        num_training_steps=total_steps,
    )

    unet, optimizer, dataloader, scheduler = accelerator.prepare(unet, optimizer, dataloader, scheduler)

    start_epoch = 0
    global_step = 0
    best_loss = float("inf")
    resume_path = Path(resume) if resume else STRONGMODEL_CHECKPOINT_DIR / "latest.pt"
    if resume_path.exists():
        state = torch.load(resume_path, map_location="cpu")
        if "unet" in state:
            unwrap_model_safely(unet).load_state_dict(state["unet"])
            optimizer.load_state_dict(state.get("optimizer", optimizer.state_dict()))
            scheduler.load_state_dict(state.get("lr_scheduler", scheduler.state_dict()))
            best_loss = state.get("best_loss", best_loss)
            start_epoch = state.get("epoch", 0)
            global_step = state.get("step", 0)
            accelerator.print(f"[StrongModel] Resumed from {resume_path} @ epoch {start_epoch}, step {global_step}")

    for epoch in range(start_epoch, cfg.num_epochs):
        epoch_loss = 0.0
        epoch_start = time.perf_counter()
        noise_scheduler.set_timesteps(cfg.noise_steps, device=device)
        for batch_idx, batch in enumerate(dataloader, start=1):
            with accelerator.accumulate(unet):
                pixel_values = batch["pixel_values"].to(device, dtype=precision_dtype)
                noise = torch.randn_like(pixel_values)
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (pixel_values.shape[0],), device=device).long()
                noisy_images = noise_scheduler.add_noise(pixel_values, noise, timesteps).to(pixel_values.dtype)

                encoder_input = {
                    "input_ids": batch["input_ids"].to(device),
                    "attention_mask": batch["attention_mask"].to(device),
                }
                with torch.no_grad():
                    encoder_hidden_states = text_encoder(**encoder_input).last_hidden_state.to(precision_dtype)

                model_pred = unet(noisy_images, timesteps, encoder_hidden_states=encoder_hidden_states).sample
                loss = F.mse_loss(model_pred.float(), noise.float())

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), cfg.max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            epoch_loss += loss.item()
            global_step += 1
            if accelerator.is_main_process and (batch_idx % 50 == 0 or batch_idx == len(dataloader)):
                accelerator.print(
                    f"[StrongModel] Epoch {epoch + 1}/{cfg.num_epochs} batch {batch_idx}/{len(dataloader)} | step {global_step}/{total_steps} | loss {loss.item():.4f}"
                )

        avg_epoch_loss = epoch_loss / len(dataloader)
        state_dict, latest_path = save_strongmodel_checkpoint(global_step, epoch + 1, best_loss, unet, optimizer, scheduler, cfg, tag="latest")
        save_strongmodel_checkpoint(global_step, epoch + 1, best_loss, unet, optimizer, scheduler, cfg, tag=f"epoch_{epoch + 1:03d}", state=state_dict)
        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            save_strongmodel_checkpoint(global_step, epoch + 1, best_loss, unet, optimizer, scheduler, cfg, tag="best", state=state_dict)
            accelerator.print(f"[StrongModel] New best loss {best_loss:.4f} at epoch {epoch + 1}")

        imgs = sample_prompts(
            cfg.eval_prompts[:8],
            unet,
            tokenizer,
            text_encoder,
            noise_scheduler,
            cfg,
            device,
            guidance_scale=cfg.guidance_scale,
            num_inference_steps=cfg.num_inference_steps,
        )
        grid = make_grid(imgs, nrow=min(4, len(imgs)))
        preview_path = STRONGMODEL_SAMPLE_DIR / f"strong_epoch_{epoch + 1:03d}.png"
        save_image(grid, preview_path)
        accelerator.print(f"[StrongModel] Epoch {epoch + 1} complete | avg loss {avg_epoch_loss:.4f} | preview -> {preview_path}")

    accelerator.print("[StrongModel] Training finished. Best checkpoint stored as 'best.pt'.")
    return latest_path, STRONGMODEL_CHECKPOINT_DIR / "best.pt"

# Example usage:
# latest_strong, best_strong = train_strongmodel()



## 5. Sampling and qualitative checks
Run inference with the latest checkpoint (or current weights) to eyeball quality and verify memory usage. Adjust prompts to cover the domains you care about before pushing live.


In [None]:
@torch.no_grad()
def load_checkpoint(path: str):
    state = torch.load(path, map_location=device)
    unwrap_model_safely(unet).load_state_dict(state["unet"])
    accelerator.print(f"Loaded checkpoint {path}")
    return state

ckpt_dir = Path(config.checkpoint_dir)
preferred = [ckpt_dir / "best.pt", ckpt_dir / "latest.pt"]
ckpt_path = next((p for p in preferred if p.exists()), None)
if ckpt_path is None:
    accelerator.print("No checkpoints found yet - sampling with current weights.")
else:
    _state = load_checkpoint(str(ckpt_path))
    accelerator.print(f"Sampling with weights from {ckpt_path.stem}")

preview = sample_prompts(
    config.eval_prompts,
    unet,
    tokenizer,
    text_encoder,
    noise_scheduler,
    config,
    device,
)
preview_path = Path(config.sample_dir) / "latest_preview.png"
save_image(make_grid(preview, nrow=min(4, len(preview))), preview_path)
accelerator.print(f"Saved preview to {preview_path}")


In [None]:
import torch

class TraceableUNet(torch.nn.Module):
    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self, sample, timestep, encoder_hidden_states):
        return self.unet(sample, timestep, encoder_hidden_states=encoder_hidden_states).sample

export_dir = Path(config.output_dir) / "deployment"
export_dir.mkdir(parents=True, exist_ok=True)

base_model = unwrap_model_safely(unet).to(torch.float16).cpu()
base_model.eval()
traceable = TraceableUNet(base_model)
model_dtype = next(base_model.parameters()).dtype

dummy_latents = torch.randn(1, 3, config.image_size, config.image_size, dtype=model_dtype)
dummy_timestep = torch.tensor([0], dtype=torch.long)
dummy_encoder_hidden_states = torch.randn(1, config.max_prompt_length, text_encoder.config.hidden_size, dtype=model_dtype)

with torch.inference_mode():
    traced = torch.jit.trace(traceable, (dummy_latents, dummy_timestep, dummy_encoder_hidden_states), strict=False)
    traced.save(export_dir / "tiny_text2img_unet_fp16.ts")

print(f"Exported TorchScript UNet to {export_dir}")



### Web runtime tips
- Convert TorchScript to ONNX + WebNN/WebGPU via onnxruntime-web if you need portable WASM/WebGPU inference.
- Keep tokenizer + text encoder on the server (ship embeddings only) to cut client RAM by roughly 60 percent.
- Use num_inference_steps between 16 and 24 plus guidance_scale around 3.5 for fast single-image renders.
- Quantize to int8 or int4 with torch.ao.quantization or bitsandbytes before exporting if you must stay well under 1 GB.
