In [None]:
import torch
import os
from PIL import Image
from transformers import AutoTokenizer, AutoModel
from diffusers import AutoencoderKL
from model import SingleStreamDiT
from config import Config
from latents import decode_latents_to_image, get_combined_text_embeds
from samplers import run_sampling_pipeline

# File name of the DiT checkpoint to load; it is looked up inside Config.checkpoint_dir.
FILENAME = "ema_weights_final.pt"
# Full path to the checkpoint; FlowerGenerator.load_models reads this when loading the DiT weights.
CHECKPOINT_PATH = os.path.join(Config.checkpoint_dir, FILENAME)

In [None]:
class FlowerGenerator:
    """Holds the DiT, VAE, tokenizer and text encoder used to turn a prompt into an image.

    Every component starts as ``None``. Call :meth:`load_models` once before
    calling :meth:`generate`.
    """

    def __init__(self):
        # Filled in by load_models(); None means "not loaded yet".
        self.model = None
        self.vae = None
        self.tokenizer = None
        self.text_encoder = None

    @staticmethod
    def _select_state_dict(checkpoint, use_ema=False):
        """Return the state dict to load from an already-deserialized checkpoint.

        Args:
            checkpoint: Either a "full" checkpoint dict containing
                ``'model_state_dict'`` and/or ``'ema_state_dict'``, or a raw
                state dict, which is returned unchanged.
            use_ema: Prefer ``'ema_state_dict'`` when the checkpoint has one.

        Returns:
            The selected state dict (the very object stored in the checkpoint).
        """
        is_full_checkpoint = isinstance(checkpoint, dict) and (
            'model_state_dict' in checkpoint or 'ema_state_dict' in checkpoint
        )
        if not is_full_checkpoint:
            print("Detected Raw State Dict: Loading directly...")
            return checkpoint

        if use_ema and 'ema_state_dict' in checkpoint:
            print("Detected Full Checkpoint: Loading EMA weights (Best for inference)...")
            return checkpoint['ema_state_dict']

        if 'model_state_dict' in checkpoint:
            print("Detected Full Checkpoint: Loading RAW weights (The 'Healing' weights)...")
            return checkpoint['model_state_dict']

        # Only EMA weights are stored. Indexing 'model_state_dict' here would
        # raise KeyError, so fall back to the weights that actually exist.
        print("Detected Full Checkpoint: 'model_state_dict' missing, falling back to EMA weights...")
        return checkpoint['ema_state_dict']

    def load_models(self, use_ema=False):
        """Instantiate the DiT, VAE, tokenizer and text encoder and store them on ``self``.

        The DiT weights are read from ``CHECKPOINT_PATH``. If that file does not
        exist, a warning is printed and the DiT keeps its freshly initialized
        weights.

        Args:
            use_ema: When the checkpoint is a full checkpoint dict, prefer its
                ``'ema_state_dict'`` over ``'model_state_dict'``.
        """
        print(f"--- LOADING MODELS TO {Config.device} ---")

        print(f"Loading DiT from {os.path.basename(CHECKPOINT_PATH)}...")
        self.model = SingleStreamDiT(in_channels=Config.in_channels, text_embed_dim=Config.text_embed_dim).to(Config.device)

        if os.path.exists(CHECKPOINT_PATH):
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code. Only load checkpoints from a trusted source.
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=Config.device)
            self.model.load_state_dict(self._select_state_dict(checkpoint, use_ema))
            print("DiT Weights successfully loaded.")
        else:
            print(f"!!! WARNING: {CHECKPOINT_PATH} not found. Running with random weights.")

        self.model.eval().to(dtype=Config.dtype)

        print(f"Loading VAE ({Config.vae_id})...")
        # The VAE is kept in float32 regardless of Config.dtype.
        self.vae = AutoencoderKL.from_pretrained(Config.vae_id).to(Config.device, dtype=torch.float32)
        self.vae.eval()

        print(f"Loading Text Encoder and Tokenizer ({Config.text_model_id})...")
        self.tokenizer = AutoTokenizer.from_pretrained(Config.text_model_id)

        full_text_model = AutoModel.from_pretrained(Config.text_model_id, trust_remote_code=True)
        # Use only the `.encoder` submodule when the loaded model exposes one.
        self.text_encoder = full_text_model.encoder if hasattr(full_text_model, "encoder") else full_text_model

        self.text_encoder.to(Config.device).eval().to(dtype=Config.dtype)

        print("--- ALL MODELS LOADED SUCCESSFULLY ---")

    @torch.no_grad()
    def generate(self, prompt, neg_prompt="", steps=50, cfg=1.0, height=512, width=512, sampler="rk4", seed=None):
        """Sample one image for ``prompt`` and return the decoded result.

        Args:
            prompt: Text prompt passed to ``get_combined_text_embeds``.
            neg_prompt: Negative prompt passed to ``get_combined_text_embeds``.
            steps: Step count forwarded to ``run_sampling_pipeline``.
            cfg: Forwarded to both the text-embedding helper and the sampler
                (presumably the classifier-free guidance scale; confirm in samplers).
            height: Requested image height in pixels.
            width: Requested image width in pixels.
            sampler: Sampler name forwarded as ``sampler_type``.
            seed: If not ``None``, seeds torch's global RNG before noise is drawn.

        Returns:
            Whatever ``decode_latents_to_image`` returns for the final latents
            (presumably a PIL image; confirm in latents).

        Raises:
            RuntimeError: If :meth:`load_models` has not been called yet.
        """
        components = (self.model, self.vae, self.tokenizer, self.text_encoder)
        if any(component is None for component in components):
            raise RuntimeError("Models are not loaded. Call load_models() before generate().")

        if seed is not None:
            torch.manual_seed(seed)

        print(f"Generating: '{prompt[:40]}...' | Size: {width}x{height} | Steps: {steps} | CFG: {cfg} | Neg: '{neg_prompt[:20]}...'")

        combined_text = get_combined_text_embeds(
            prompt=prompt, neg_prompt=neg_prompt, cfg=cfg,
            tokenizer=self.tokenizer, text_encoder=self.text_encoder,
            max_token_length=Config.max_token_length, device=Config.device, dtype=Config.dtype
        )

        # Integer division: sizes that are not a multiple of the downsample
        # factor are rounded down in latent space.
        latent_h = height // Config.vae_downsample_factor
        latent_w = width // Config.vae_downsample_factor
        x = torch.randn(1, Config.in_channels, latent_h, latent_w, device=Config.device, dtype=torch.float32)

        # torch.autocast expects a bare device type such as "cuda" or "cpu",
        # not an indexed spec like "cuda:0" or a torch.device object.
        autocast_device_type = torch.device(Config.device).type
        with torch.autocast(autocast_device_type, dtype=Config.dtype):
            x = run_sampling_pipeline(model=self.model, initial_noise=x, steps=steps, cond_embeds=combined_text, cfg=cfg,
                                      sampler_type=sampler, shift_val=Config.shift_val)

        image = decode_latents_to_image(vae_model=self.vae, latents=x, device=Config.device)

        return image

# Build the generator once and load every model up front so later cells can reuse `engine`.
engine = FlowerGenerator()
# Uses the default use_ema=False; a raw state-dict checkpoint is loaded as-is either way.
engine.load_models()

In [None]:
# Prompt taken from a single file.
PROMPT = "A macro shot of a yellow dandelion with a small bee resting in the middle, having radiating, delicate, serrated petals, the center is a dense, pollen-covered disk with visible stamens, the stem is slender and reddish, set against a blurred, earthy background, sharp focus."
# Alternative: prompt mixed from multiple files.
#PROMPT = "A macro shot of a blooming flower, having overlapping, velvety petals with a vibrant gradient from deep magenta to pale pink and slightly ruffled, translucent edges, the center is a dense cluster of yellow stamens with visible pollen grains, the background is blurred green foliage with soft, dappled natural daylight."
# Alternative: unconditional (empty) prompt.
#PROMPT = ""
NEGATIVE_PROMPT = ""

# Sampling settings gathered in one place so they are easy to tweak between runs.
generation_kwargs = dict(
    prompt=PROMPT,
    neg_prompt=NEGATIVE_PROMPT,
    steps=50,
    cfg=3.0,
    width=512,
    height=384,
    sampler="euler",
    seed=None,  # None means generate() does not re-seed torch's RNG
)
image = engine.generate(**generation_kwargs)

display(image)