In [None]:
import torch
import os
from PIL import Image
from transformers import AutoTokenizer, AutoModel
from diffusers import AutoencoderKL
from model import SingleStreamDiTV2
from config import Config
from latents import *
from samplers import *

# --- GLOBAL SETTINGS ---
# Everything below comes from the project's Config except the checkpoint file
# name, which is picked here and resolved inside Config.checkpoint_dir.
FILENAME = "ema_weights_final.pt"
CHECKPOINT_PATH = os.path.join(Config.checkpoint_dir, FILENAME)

# Pretrained component identifiers (handed to from_pretrained when loading).
TEXT_MODEL_ID, VAE_ID = Config.text_model_id, Config.vae_id

# Where the models live and the precision the DiT / text encoder run in.
DEVICE, DTYPE = Config.device, Config.dtype

# VAE latent-space constants; the downsample factor maps pixel size -> latent size.
VAE_SCALING_FACTOR, VAE_DOWNSAMPLE_FACTOR = Config.vae_scaling_factor, Config.vae_downsample_factor

# DiT conditioning / input shape parameters and the timestep-shift strength.
TEXT_DIM, IN_CHANNELS = Config.text_embed_dim, Config.in_channels
MAX_TOKEN_LENGTH, SHIFT_VAL = Config.max_token_length, Config.shift_val

In [None]:
class FlowerGenerator:
    """Text-to-image inference wrapper around a SingleStreamDiTV2 flow model.

    Holds the DiT, the VAE that turns latents into pixels, and the tokenizer /
    text encoder that produce conditioning embeddings. Call ``load_models()``
    once, then ``generate()`` as many times as needed.
    """

    # Sampler names accepted by ``generate``. Anything else is rejected up front
    # instead of silently skipping every integration step.
    _SAMPLERS = ("euler", "rk4")

    def __init__(self):
        # All components are populated by load_models().
        self.model = None
        self.vae = None
        self.tokenizer = None
        self.text_encoder = None

    @staticmethod
    def _device_type(device):
        """Reduce a device spec (``"cuda:0"``, ``"cpu"``, ``torch.device``) to its type string.

        ``torch.autocast`` takes a device *type*, so an indexed device such as
        ``"cuda:0"`` must not be passed to it directly.
        """
        return torch.device(device).type

    @staticmethod
    def _select_state_dict(checkpoint, use_ema):
        """Return the DiT state dict contained in a loaded checkpoint object.

        Args:
            checkpoint: Either a raw state dict, or a "full" checkpoint dict holding
                ``'model_state_dict'`` and/or ``'ema_state_dict'``.
            use_ema: Prefer the EMA weights when a full checkpoint provides them.

        Returns:
            The state dict to feed to ``load_state_dict``. A full checkpoint that
            lacks the preferred entry falls back to the one it does have, rather
            than raising ``KeyError``.
        """
        is_full = isinstance(checkpoint, dict) and ('model_state_dict' in checkpoint or 'ema_state_dict' in checkpoint)
        if not is_full:
            print("Detected Raw State Dict: Loading directly...")
            return checkpoint

        has_ema = 'ema_state_dict' in checkpoint
        has_raw = 'model_state_dict' in checkpoint

        if use_ema and has_ema:
            print("Detected Full Checkpoint: Loading EMA weights (Best for inference)...")
            return checkpoint['ema_state_dict']
        if not has_raw:
            # Only EMA weights were saved; previously this path raised KeyError.
            print("Detected Full Checkpoint without RAW weights: Falling back to EMA weights...")
            return checkpoint['ema_state_dict']
        print("Detected Full Checkpoint: Loading RAW weights (The 'Healing' weights)...")
        return checkpoint['model_state_dict']

    def load_models(self, use_ema=False):
        """Instantiate the DiT, VAE, tokenizer and text encoder on DEVICE.

        Args:
            use_ema: For full checkpoints carrying both weight sets, load the EMA
                weights instead of the raw model weights. Ignored for raw state dicts.

        If CHECKPOINT_PATH does not exist, the DiT keeps its random initialisation
        and a warning is printed.
        """
        print(f"--- LOADING MODELS TO {DEVICE} ---")

        print(f"Loading DiT from {os.path.basename(CHECKPOINT_PATH)}...")
        self.model = SingleStreamDiTV2(in_channels=IN_CHANNELS, text_embed_dim=TEXT_DIM).to(DEVICE)

        if os.path.exists(CHECKPOINT_PATH):
            # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
            self.model.load_state_dict(self._select_state_dict(checkpoint, use_ema))
            print("DiT Weights successfully loaded.")
        else:
            print(f"!!! WARNING: {CHECKPOINT_PATH} not found. Running with random weights.")

        self.model.eval().to(dtype=DTYPE)

        print(f"Loading VAE ({VAE_ID})...")
        # The VAE stays in float32; decode() also disables autocast around it.
        self.vae = AutoencoderKL.from_pretrained(VAE_ID).to(DEVICE, dtype=torch.float32)
        self.vae.eval()

        print(f"Loading Text Encoder and Tokenizer ({TEXT_MODEL_ID})...")
        self.tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID)

        full_text_model = AutoModel.from_pretrained(TEXT_MODEL_ID, trust_remote_code=True)
        # Use only the encoder stack when the checkpoint is an encoder-decoder model.
        self.text_encoder = full_text_model.encoder if hasattr(full_text_model, "encoder") else full_text_model

        self.text_encoder.to(DEVICE).eval().to(dtype=DTYPE)

        print("--- ALL MODELS LOADED SUCCESSFULLY ---")

    def decode(self, latents):
        """Decode latents with the VAE and return the first image of the batch as a PIL Image.

        Pixel values are mapped with ``x / 2 + 0.5`` and clamped to [0, 1] (i.e. the
        VAE output is assumed to be in [-1, 1]), then scaled to uint8.
        """
        latents = prepare_latents_for_decode(latents)

        # Decode in full precision: autocast is disabled and latents are cast to float32.
        with torch.no_grad(), torch.autocast(self._device_type(DEVICE), enabled=False):
            image = self.vae.decode(latents.float()).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()  # NCHW -> NHWC for PIL
        image = (image * 255).round().astype("uint8")

        return Image.fromarray(image[0])

    def get_shifted_time(self, t_linear, shift=None):
        """Apply the timestep shift ``t' = s*t / (1 + (s - 1)*t)`` to a linear time value.

        Args:
            t_linear: Time on the uniform grid.
            shift: Shift strength ``s``; defaults to SHIFT_VAL. ``s == 1`` is the identity.
        """
        if shift is None:
            shift = SHIFT_VAL
        if shift == 1.0:
            return t_linear
        return (t_linear * shift) / (1 + (shift - 1) * t_linear)

    def _encode_text(self, text):
        """Tokenize ``text`` (padded/truncated to MAX_TOKEN_LENGTH) and return the encoder hidden states."""
        inputs = self.tokenizer(text, max_length=MAX_TOKEN_LENGTH, padding="max_length", truncation=True, return_tensors="pt").to(DEVICE)
        out = self.text_encoder(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
        return out.last_hidden_state if hasattr(out, "last_hidden_state") else out[0]

    @torch.no_grad()
    def generate(self, prompt, neg_prompt="", steps=50, cfg=1.0, height=512, width=512, sampler="rk4", seed=None):
        """Sample one image for ``prompt`` and return it as a PIL Image.

        Args:
            prompt: Conditioning text.
            neg_prompt: Text for the unconditional CFG branch. An empty string keeps
                the all-zero unconditional embedding.
            steps: Number of integration steps (>= 1).
            cfg: Guidance scale forwarded to the sampler step function.
            height, width: Output size in pixels; floored to multiples of VAE_DOWNSAMPLE_FACTOR.
            sampler: ``"euler"`` or ``"rk4"``.
            seed: If given, seeds torch's global RNG before drawing the initial noise.

        Raises:
            ValueError: Unknown ``sampler`` or ``steps < 1``.
            RuntimeError: ``load_models()`` has not been called.
        """
        if sampler not in self._SAMPLERS:
            raise ValueError(f"Unknown sampler {sampler!r}; expected one of {self._SAMPLERS}")
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")
        if self.model is None or self.vae is None or self.tokenizer is None or self.text_encoder is None:
            raise RuntimeError("Models are not loaded; call load_models() first.")

        if seed is not None:
            torch.manual_seed(seed)

        print(f"Generating: '{prompt[:40]}...' | Size: {width}x{height} | Steps: {steps} | CFG: {cfg}")

        cond_embeds = self._encode_text(prompt)
        if neg_prompt:
            uncond_embeds = self._encode_text(neg_prompt)
        else:
            # Zero embedding for the unconditional branch (presumably matches the
            # text-dropout used in training -- TODO confirm).
            uncond_embeds = torch.zeros_like(cond_embeds)
        # Batch order [uncond, cond] is what the sampler step functions receive.
        combined_text = torch.cat([uncond_embeds, cond_embeds], dim=0).to(dtype=DTYPE)

        latent_h, latent_w = height // VAE_DOWNSAMPLE_FACTOR, width // VAE_DOWNSAMPLE_FACTOR
        x = torch.randn(1, IN_CHANNELS, latent_h, latent_w, device=DEVICE, dtype=torch.float32)
        # NOTE(review): dt is uniform in linear time while t / t_mid are shifted. If the
        # step functions advance x by v * dt, the step size arguably should be
        # get_shifted_time((i + 1) / steps) - get_shifted_time(i / steps). Left as-is
        # pending a look at samplers.py.
        dt = 1.0 / steps

        with torch.autocast(self._device_type(DEVICE), dtype=DTYPE):
            for i in range(steps):
                t = torch.tensor([self.get_shifted_time(i / steps)], device=DEVICE, dtype=DTYPE)
                if sampler == "euler":
                    x = euler_step(self.model, x, t, dt, combined_text, cfg)
                else:  # "rk4" -- validated above
                    t_mid = torch.tensor([self.get_shifted_time((i + 0.5) / steps)], device=DEVICE, dtype=DTYPE)
                    x = rk4_step(self.model, x, t, dt, combined_text, cfg, t_mid)

        return self.decode(x)

In [None]:
# Build the generator, then load the DiT, VAE and text encoder onto DEVICE.
# use_ema=False is the method's default; it only matters for full checkpoints
# that carry both raw and EMA state dicts.
engine = FlowerGenerator()
engine.load_models(use_ema=False)

In [None]:
# Active prompt: taken verbatim from a single caption file.
PROMPT = "A macro shot of a yellow dandelion with a small bee resting in the middle, having radiating, delicate, serrated petals, the center is a dense, pollen-covered disk with visible stamens, the stem is slender and reddish, set against a blurred, earthy background, sharp focus."
# Alternative: a prompt mixing descriptions from several caption files.
#PROMPT = "A macro shot of a blooming coneflower with a small bee resting in the middle, having delicate purple petals with slightly ruffled edges, the center is a dense conical head of dark purple anthers with orange-tinted pollen, the background is blurred green grass."
# Alternative: empty prompt for unconditional generation.
#PROMPT = ""
NEGATIVE_PROMPT = ""

image = engine.generate(
    prompt=PROMPT,
    neg_prompt=NEGATIVE_PROMPT,
    # Output size in pixels; generate() floors each side to a multiple of the VAE downsample factor.
    width=512,
    height=384,
    # Integration settings.
    sampler="rk4",
    steps=50,
    cfg=1.15,
    seed=42,  # fixed seed -> same initial noise on every run
)

display(image)