In [None]:
import torch
import os
from PIL import Image
from transformers import AutoTokenizer, AutoModel
from diffusers import AutoencoderKL
from model import SingleStreamDiT
from config import Config
from latents import decode_latents_to_image, get_combined_text_embeds
from samplers import run_sampling_pipeline

# Checkpoint to load: FILENAME inside Config.checkpoint_dir.
FILENAME = "ema_weights_final.pt"
CHECKPOINT_PATH = os.path.join(Config.checkpoint_dir, FILENAME)

# Pretrained component identifiers and runtime placement, all taken from Config.
TEXT_MODEL_ID, VAE_ID = Config.text_model_id, Config.vae_id
DEVICE, DTYPE = Config.device, Config.dtype

# VAE latent-space parameters (DOWNSAMPLE maps pixel size -> latent size).
VAE_SCALING_FACTOR, VAE_DOWNSAMPLE_FACTOR = Config.vae_scaling_factor, Config.vae_downsample_factor

# Model-shape, text-conditioning and sampling hyperparameters.
TEXT_DIM, IN_CHANNELS = Config.text_embed_dim, Config.in_channels
MAX_TOKEN_LENGTH, SHIFT_VAL = Config.max_token_length, Config.shift_val

In [None]:
class FlowerGenerator:
    """Text-to-image inference wrapper around a SingleStreamDiT model.

    Holds the DiT denoiser, the VAE used to decode latents, and the text
    tokenizer/encoder used to build conditioning embeddings. Call
    ``load_models()`` once before ``generate()``.
    """

    def __init__(self):
        # All components are populated by load_models().
        self.model = None
        self.vae = None
        self.tokenizer = None
        self.text_encoder = None

    @staticmethod
    def _select_state_dict(checkpoint, use_ema):
        """Pick the DiT state dict to load from a ``torch.load`` result.

        A dict containing ``'model_state_dict'`` and/or ``'ema_state_dict'`` is
        treated as a full training checkpoint; anything else is returned
        unchanged as a bare state dict. The requested variant (EMA when
        ``use_ema`` is true, raw otherwise) is preferred, but if the checkpoint
        only contains the other one, that one is returned instead of raising
        ``KeyError``.
        """
        is_full_checkpoint = isinstance(checkpoint, dict) and (
            'model_state_dict' in checkpoint or 'ema_state_dict' in checkpoint
        )
        if not is_full_checkpoint:
            print("Detected Raw State Dict: Loading directly...")
            return checkpoint

        has_ema = 'ema_state_dict' in checkpoint
        has_raw = 'model_state_dict' in checkpoint
        # Use EMA when asked for it, or when it is the only variant present.
        if has_ema and (use_ema or not has_raw):
            print("Detected Full Checkpoint: Loading EMA weights (Best for inference)...")
            return checkpoint['ema_state_dict']
        # Reaching here implies has_raw is True (guaranteed by is_full_checkpoint).
        print("Detected Full Checkpoint: Loading RAW weights (The 'Healing' weights)...")
        return checkpoint['model_state_dict']

    def load_models(self, use_ema=False):
        """Build/load the DiT, VAE, tokenizer and text encoder onto DEVICE.

        Args:
            use_ema: For full training checkpoints, prefer ``'ema_state_dict'``
                over ``'model_state_dict'``. Has no effect when the file is a
                bare state dict.

        If CHECKPOINT_PATH does not exist, a warning is printed and the DiT
        keeps its random initialization.
        """
        print(f"--- LOADING MODELS TO {DEVICE} ---")

        print(f"Loading DiT from {os.path.basename(CHECKPOINT_PATH)}...")
        self.model = SingleStreamDiT(in_channels=IN_CHANNELS, text_embed_dim=TEXT_DIM).to(DEVICE)

        if os.path.exists(CHECKPOINT_PATH):
            # NOTE(review): torch.load unpickles the file — only load trusted
            # checkpoints (weights_only=True is safer if the file holds only tensors/primitives).
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
            self.model.load_state_dict(self._select_state_dict(checkpoint, use_ema))
            print("DiT Weights successfully loaded.")
        else:
            print(f"!!! WARNING: {CHECKPOINT_PATH} not found. Running with random weights.")

        # The DiT runs in the configured DTYPE.
        self.model.eval().to(dtype=DTYPE)

        print(f"Loading VAE ({VAE_ID})...")
        # The VAE is explicitly kept in float32, independent of DTYPE.
        self.vae = AutoencoderKL.from_pretrained(VAE_ID).to(DEVICE, dtype=torch.float32)
        self.vae.eval()

        print(f"Loading Text Encoder and Tokenizer ({TEXT_MODEL_ID})...")
        self.tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID)

        full_text_model = AutoModel.from_pretrained(TEXT_MODEL_ID, trust_remote_code=True)
        # Use only the encoder sub-module when the model exposes one.
        self.text_encoder = full_text_model.encoder if hasattr(full_text_model, "encoder") else full_text_model

        self.text_encoder.to(DEVICE).eval().to(dtype=DTYPE)

        print("--- ALL MODELS LOADED SUCCESSFULLY ---")

    @torch.no_grad()
    def generate(self, prompt, neg_prompt="", steps=50, cfg=1.0, height=512, width=512, sampler="rk4", seed=None):
        """Generate one image from a text prompt.

        Args:
            prompt: Positive text prompt.
            neg_prompt: Negative prompt, forwarded to get_combined_text_embeds.
            steps: Number of sampling steps for run_sampling_pipeline.
            cfg: Guidance scale, forwarded to both get_combined_text_embeds and
                run_sampling_pipeline.
            height: Output height in pixels.
            width: Output width in pixels.
            sampler: Passed to run_sampling_pipeline as ``sampler_type``.
            seed: If not None, seeds torch's global RNG before drawing noise.

        Returns:
            The result of decode_latents_to_image for the sampled latents.

        Raises:
            RuntimeError: If load_models() has not been called yet.
        """
        if any(part is None for part in (self.model, self.vae, self.tokenizer, self.text_encoder)):
            raise RuntimeError("Models are not loaded; call load_models() before generate().")

        if seed is not None:
            torch.manual_seed(seed)

        print(f"Generating: '{prompt[:40]}...' | Size: {width}x{height} | Steps: {steps} | CFG: {cfg} | Neg: '{neg_prompt[:20]}...'")

        combined_text = get_combined_text_embeds(
            prompt=prompt, neg_prompt=neg_prompt, cfg=cfg,
            tokenizer=self.tokenizer, text_encoder=self.text_encoder,
            max_token_length=MAX_TOKEN_LENGTH, device=DEVICE, dtype=DTYPE
        )

        # Integer division: sizes not divisible by VAE_DOWNSAMPLE_FACTOR are truncated.
        latent_h, latent_w = height // VAE_DOWNSAMPLE_FACTOR, width // VAE_DOWNSAMPLE_FACTOR
        x = torch.randn(1, IN_CHANNELS, latent_h, latent_w, device=DEVICE, dtype=torch.float32)

        # torch.autocast expects a bare device type ("cuda", "cpu", ...), not
        # an indexed device such as "cuda:0" or a torch.device object.
        autocast_device = torch.device(DEVICE).type
        with torch.autocast(autocast_device, dtype=DTYPE):
            x = run_sampling_pipeline(model=self.model, initial_noise=x, steps=steps, cond_embeds=combined_text, cfg=cfg,
                                      sampler_type=sampler, shift_val=SHIFT_VAL)

        image = decode_latents_to_image(vae_model=self.vae, latents=x, device=DEVICE)

        return image

# Build the generator once and load every model it needs (raw weights by default).
engine = FlowerGenerator()
engine.load_models(use_ema=False)

In [None]:
# Single-caption prompt (presumably copied from one caption file — verify).
PROMPT = "A macro shot of a yellow dandelion with a small bee resting in the middle, having radiating, delicate, serrated petals, the center is a dense, pollen-covered disk with visible stamens, the stem is slender and reddish, set against a blurred, earthy background, sharp focus."
# Alternative: a prompt mixed from several caption files.
#PROMPT = "A macro shot of a blooming flower, having overlapping, velvety petals with a vibrant gradient from deep magenta to pale pink and slightly ruffled, translucent edges, the center is a dense cluster of yellow stamens with visible pollen grains, the background is blurred green foliage with soft, dappled natural daylight."
# Alternative: empty prompt for unconditional generation.
#PROMPT = ""
NEGATIVE_PROMPT = ""

# 512x384 output, 50 Euler steps, guidance 3.0; seed=None gives a new sample each run.
image = engine.generate(
    prompt=PROMPT,
    neg_prompt=NEGATIVE_PROMPT,
    sampler="euler",
    steps=50,
    cfg=3.0,
    height=384,
    width=512,
    seed=None,
)
display(image)