In [None]:
!pip install -U git+https://github.com/Sakib323/AI-Game-Engine.git
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
!pip install transformers
!pip install datasets
!pip install wandb
!pip install -U datasets
!pip install objaverse
!pip install diffusers
!pip install trimesh
!pip install jaxtyping
!pip install pytorch-lightning
!pip install ijson
!pip install triton==3.2.0
!pip install wandb

In [None]:
pip install trimesh pyrender xatlas opencv-python torch scipy xatlas

In [None]:
!git clone https://github.com/Sakib323/AI-Game-Engine.git
!git clone --depth 1 --branch main https://github.com/stepfun-ai/Step1X-3D.git

In [None]:
import os
import shutil

# Overwrite Step1X-3D's differentiable mesh renderer with the patched copy
# shipped in the AI-Game-Engine repository.
PATCHED_RENDERER_SRC = "/kaggle/working/AI-Game-Engine/Step1x3d repo script patched/mesh_render.py"
PATCHED_RENDERER_DST = "/kaggle/working/Step1X-3D/step1x3d_texture/differentiable_renderer/mesh_render.py"

# shutil.copyfile raises if either path is wrong, so the success message below is
# only printed when the copy really happened (a failing `!cp` did not stop the cell).
shutil.copyfile(PATCHED_RENDERER_SRC, PATCHED_RENDERER_DST)
print(f"{PATCHED_RENDERER_DST}: {os.path.getsize(PATCHED_RENDERER_DST)} bytes")
print("\n✅ File replaced successfully.")

In [None]:
# Make the cloned Step1X-3D repository importable and install its Python requirements.
import sys
sys.path.append("./Step1X-3D")  
import os
print(os.listdir("./Step1X-3D"))  # sanity check: the clone exists and is populated
!pip install -r ./Step1X-3D/requirements.txt --verbose

In [None]:
# Install torch-cluster from the PyG wheel index selected by the installed torch
# version string (torch.__version__ is interpolated into the URL).
# NOTE(review): confirm data.pyg.org publishes an index for this exact version string.
!pip install torch-cluster -f https://data.pyg.org/whl/torch-$(python -c "import torch; print(torch.__version__)").html
# System package libaio-dev. NOTE(review): which dependency needs it is not visible here — confirm.
!apt-get update && apt-get install -y libaio-dev

# =========================================================================================
# TRAIN MESH GENERATION MODEL: DOWNLOAD DATASET & TRAIN THE MODEL
# This script downloads the dataset, processes it, and then trains the model.
# =========================================================================================

In [None]:
import os

# These environment variables must be set BEFORE the libraries that read them are
# imported: PyOpenGL selects its platform (EGL = headless rendering) when pyrender
# imports it, so setting PYOPENGL_PLATFORM after `import pyrender` has no effect.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['PYOPENGL_PLATFORM'] = 'egl'

import json
import torch
import traceback
import numpy as np
import trimesh
from tqdm import tqdm
import logging
import random
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, Trainer, TrainingArguments, T5EncoderModel, CLIPVisionModelWithProjection
from mmfreelm.models.hgrn_bit.mesh_dit import MeshDiT_models
from diffusion_model import GaussianDiffusion, ModelMeanType, ModelVarType, LossType, get_named_beta_schedule, _extract_into_tensor
from step1x3d_geometry.models.pipelines.pipeline import Step1X3DGeometryPipeline
from safetensors.torch import load_file as safetensors_load
import shutil
import pyrender
from PIL import Image
from torch.optim import AdamW
from diffusers import AutoencoderKL
import torch.nn as nn
from torchvision import transforms

# Configure root logging once. logging.basicConfig() does nothing when the root
# logger already has handlers, so the previous second call (the one carrying the
# format string) was silently ignored.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Keep noisy third-party loggers quiet.
pyrender_logger = logging.getLogger('pyrender')
pyrender_logger.setLevel(logging.ERROR)
trimesh_logger = logging.getLogger('trimesh')
trimesh_logger.setLevel(logging.ERROR)

# Lower-case device/dtype are used by the generation cell further down.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
print(f"Using device: {device}, dtype: {dtype}")

In [None]:
# --- Global Configuration ---
# Flip to True to condition on rendered reference images in addition to text.
USE_IMAGE_CONDITIONING = False # Set to True to enable text + image conditioning

# Half precision on GPU, full precision on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.float16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32

IMAGE_SIZE = 512
DATASET_DIRECTORY = "AI-Game-Engine/testing_dataset"
IMAGE_OUTPUT_DIR = os.path.join(DATASET_DIRECTORY, "reference_images")
# Separate cache files so text-only and text+image runs never mix their data.
_processed_file_name = (
    "processed_training_data_img_text.pt"
    if USE_IMAGE_CONDITIONING
    else "processed_training_data_text_only.pt"
)
PROCESSED_DATA_PATH = os.path.join(DATASET_DIRECTORY, _processed_file_name)

# --- Phase 1 Config ---
PHASE1_BATCH_SIZE, PHASE1_LR, PHASE1_EPOCHS = 2, 1e-4, 200

# --- Phase 2 Config ---
PHASE2_BATCH_SIZE, PHASE2_LR, PHASE2_EPOCHS = 2, 5e-5, 200

# Write a checkpoint every this many epochs.
SAVE_CHECKPOINT_EPOCH = 20

In [None]:
def initialize_models():
    """Load every model and tokenizer the mesh-generation pipeline needs.

    Returns:
        dict with keys:
          'vae'          - Step1X-3D shape VAE (eval mode, on DEVICE),
          'image_vae'    - SD-2.1 image VAE when USE_IMAGE_CONDITIONING, else None,
          'tokenizer'    - tokenizer matching the text encoder,
          'text_encoder' - T5 encoder (eval mode, on DEVICE),
          'mesh_dit'     - trainable MeshDiT-S sized to the VAE embedding dim;
        or None if anything fails to load (error and traceback are logged).
    """
    logger.info("--- Initializing All Models ---")
    # The tokenizer MUST come from the same checkpoint as the text encoder: token ids
    # are indices into the encoder's embedding table, so ids from another vocabulary
    # (previously "Sakib323/MMfreeLM-370M") give meaningless embeddings.
    # Delete any cached processed dataset built with the old tokenizer.
    text_encoder_id = "google-t5/t5-small"
    models = {}
    try:
        logger.info("Loading Step1X-3D VAE...")
        geometry_pipeline = Step1X3DGeometryPipeline.from_pretrained(
            "stepfun-ai/Step1X-3D", subfolder='Step1X-3D-Geometry-1300m', torch_dtype=DTYPE)
        models['vae'] = geometry_pipeline.vae.to(DEVICE).eval()

        if USE_IMAGE_CONDITIONING:
            logger.info("Loading Image VAE...")
            models['image_vae'] = AutoencoderKL.from_pretrained(
                "stabilityai/stable-diffusion-2-1", subfolder="vae", torch_dtype=DTYPE).to(DEVICE).eval()
        else:
            models['image_vae'] = None

        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(text_encoder_id)
        # T5 ships its own pad token; only fall back to EOS if a tokenizer lacks one.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        models['tokenizer'] = tokenizer

        logger.info("Loading T5 Text Encoder...")
        models['text_encoder'] = T5EncoderModel.from_pretrained(text_encoder_id).to(DEVICE).eval()

        logger.info("Initializing custom Ternary MeshDiT model...")
        models['mesh_dit'] = MeshDiT_models['MeshDiT-S'](input_dim=models['vae'].cfg.embed_dim).to(DEVICE)

        logger.info("All models initialized successfully.")
        return models
    except Exception as e:
        logger.error(f"FATAL: Could not initialize models: {e}\n{traceback.format_exc()}")
        return None

In [None]:
def load_local_dataset(dataset_dir):
    """Collect GLB paths and caption strings from a local dataset directory.

    ``dataset_dir`` must contain a ``dataset.json`` mapping GLB file names to
    metadata dicts (optional keys: category, make, model, year, description, tags).
    Entries whose GLB file is missing on disk are skipped.

    Args:
        dataset_dir: directory holding dataset.json and the GLB files.

    Returns:
        (glb_paths, texts): parallel lists of file paths and merged caption strings,
        in dataset.json order.

    Raises:
        FileNotFoundError: if dataset.json does not exist.
    """
    json_path = os.path.join(dataset_dir, "dataset.json")
    if not os.path.exists(json_path):
        raise FileNotFoundError(f"dataset.json not found in {dataset_dir}")
    # JSON is UTF-8 by specification; decode it explicitly so non-ASCII descriptions
    # load identically regardless of the platform's default encoding.
    with open(json_path, encoding="utf-8") as f:
        metadata = json.load(f)

    glb_paths, texts = [], []
    for glb_file, info in metadata.items():
        glb_path = os.path.join(dataset_dir, glb_file)
        if not os.path.exists(glb_path):
            continue
        tags = info.get('tags')
        # A single tag given as a plain string must not be split into characters.
        if isinstance(tags, str):
            tags = [tags]
        tags_text = ', '.join(str(tag) for tag in tags) if tags else 'None'
        merged_text = (f"category: {info.get('category', 'Unknown')} "
                       f"make: {info.get('make', 'Unknown')} "
                       f"model: {info.get('model', 'Unknown')} "
                       f"year: {info.get('year', 'Unknown')} "
                       f"description: {info.get('description', 'No description')} "
                       f"tags: {tags_text}")
        glb_paths.append(glb_path)
        texts.append(merged_text)
    logger.info(f"Found {len(glb_paths)} valid GLB files.")
    return glb_paths, texts

In [None]:
def render_and_save_image(mesh, glb_path, output_dir, image_size):
    """Render one front view of ``mesh``, cache it as a PNG, and return it.

    The PNG is named after ``glb_path`` and stored in ``output_dir``; if it already
    exists it is loaded instead of re-rendering. The caller's mesh is NOT modified:
    centering/scaling happens on a copy.

    Args:
        mesh: trimesh.Trimesh to render.
        glb_path: source file path (only its base name is used for the PNG name).
        output_dir: directory for cached renders (created if missing).
        image_size: width and height of the square render in pixels.

    Returns:
        PIL.Image in RGB mode.
    """
    os.makedirs(output_dir, exist_ok=True)
    image_name = os.path.splitext(os.path.basename(glb_path))[0] + ".png"
    save_path = os.path.join(output_dir, image_name)
    if os.path.exists(save_path):
        # convert() loads the pixels, so the file handle can be closed right away.
        with Image.open(save_path) as cached:
            return cached.convert("RGB")

    # Normalize a copy so the caller's mesh is not re-centred/re-scaled as a side effect:
    # centre the bounding box at the origin, then fit the largest extent into a unit cube.
    mesh = mesh.copy()
    mesh.apply_translation(-mesh.bounds.mean(axis=0))
    max_extent = float(np.max(mesh.extents))
    if max_extent > 0:  # degenerate (zero-size) meshes would otherwise scale by inf
        mesh.apply_scale(1.0 / max_extent)

    scene = pyrender.Scene(ambient_light=[0.3, 0.3, 0.3], bg_color=[1, 1, 1, 1])
    scene.add(pyrender.Mesh.from_trimesh(mesh, smooth=True))
    camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0)
    # Camera 2 units along +z, looking back toward the origin; the light shares its pose.
    camera_pose = np.array([[1,0,0,0], [0,1,0,0], [0,0,1,2.0], [0,0,0,1]])
    scene.add(camera, pose=camera_pose)
    light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=6.0)
    scene.add(light, pose=camera_pose)

    renderer = pyrender.OffscreenRenderer(image_size, image_size)
    try:
        color, _ = renderer.render(scene)
    finally:
        # Always release the GL context, even if rendering fails.
        renderer.delete()
    img = Image.fromarray(color, 'RGB')
    img.save(save_path)
    return img

In [None]:
def sample_points_from_mesh(mesh, num_points=16384, with_sharp_data=True, sharp_threshold_deg=60.0):
    """Sample surface and sharp-edge point clouds from a mesh.

    Adapted from the MichelangeloAutoencoder surface / sharp-edge sampling (SES).

    Args:
        mesh: trimesh.Trimesh to sample.
        num_points: number of points in each returned cloud.
        with_sharp_data: if False, "sharp_surface" is simply the surface cloud.
        sharp_threshold_deg: dihedral angle above which an edge counts as sharp.

    Returns:
        {"surface": ..., "sharp_surface": ...}: float32 tensors of shape
        (1, num_points, 6) holding xyz followed by the normal of the sampled face.
        If no usable sharp region exists (or anything in that path fails), the
        sharp cloud falls back to random points from the dense surface sample.
    """
    def _as_batch(cloud):
        return torch.from_numpy(cloud).unsqueeze(0)

    # Oversample 2x, then subsample without replacement for the surface cloud.
    pool_points, pool_faces = trimesh.sample.sample_surface(mesh, num_points * 2)
    pool_normals = mesh.face_normals[pool_faces]
    pick = np.random.choice(len(pool_points), num_points, replace=False)
    surface_cloud = np.hstack([pool_points[pick], pool_normals[pick]]).astype(np.float32)

    if not with_sharp_data:
        return {"surface": _as_batch(surface_cloud), "sharp_surface": _as_batch(surface_cloud)}

    try:
        is_sharp = mesh.face_adjacency_angles > np.deg2rad(sharp_threshold_deg)
        if not np.any(is_sharp):
            raise ValueError("No sharp edges.")
        # Faces on either side of a sharp edge form the sharp sub-mesh.
        sharp_faces = np.unique(mesh.face_adjacency[is_sharp].flatten())
        sharp_mesh = mesh.submesh([sharp_faces], append=True)
        usable = sharp_mesh.vertices.shape[0] > 3 and sharp_mesh.faces.shape[0] > 1
        if not usable:
            raise ValueError("Sharp submesh invalid.")
        sharp_points, sharp_face_ids = trimesh.sample.sample_surface(sharp_mesh, num_points)
        sharp_normals = sharp_mesh.face_normals[sharp_face_ids]
    except Exception:
        fallback = np.random.choice(len(pool_points), num_points, replace=False)
        sharp_points = pool_points[fallback]
        sharp_normals = pool_normals[fallback]

    sharp_cloud = np.hstack([sharp_points, sharp_normals]).astype(np.float32)
    return {"surface": _as_batch(surface_cloud), "sharp_surface": _as_batch(sharp_cloud)}

class ConditioningProcessor(nn.Module):
    """Project text (and optionally image) embeddings to a common feature size.

    Args:
        text_embed_dim: last-dimension size of the text embeddings.
        image_embed_dim: last-dimension size of the image embeddings
            (unused when ``use_image_cond`` is False).
        target_dim: output feature size of both projections.
        use_image_cond: whether to create and use the image projection.
    """
    def __init__(self, text_embed_dim, image_embed_dim, target_dim, use_image_cond=True):
        super().__init__()
        self.use_image_cond = use_image_cond
        self.text_proj = nn.Linear(text_embed_dim, target_dim)
        if self.use_image_cond:
            self.image_proj = nn.Linear(image_embed_dim, target_dim)

    def forward(self, text_embed, image_embed=None):
        """Project the embeddings and optionally concatenate them along dim 1.

        Inputs are cast to the dtype of the projection weights, so the module works
        in whatever precision it was moved to (casting to a global dtype broke when
        the module's dtype differed).

        Returns:
            Projected image tokens followed by projected text tokens (concatenated
            on dim 1) when image conditioning is enabled and ``image_embed`` is
            given; otherwise only the projected text tokens.
        """
        text_projected = self.text_proj(text_embed.to(self.text_proj.weight.dtype))
        if self.use_image_cond and image_embed is not None:
            image_projected = self.image_proj(image_embed.to(self.image_proj.weight.dtype))
            return torch.cat([image_projected, text_projected], dim=1)
        return text_projected

def create_dataset_from_local_files(glb_paths, texts, models, cond_processor):
    """Encode every mesh/caption pair and cache the result at PROCESSED_DATA_PATH.

    For each pair:
      - sample surface + sharp-edge point clouds and encode them with the shape VAE
        (the posterior sample becomes "x");
      - embed the caption with the text encoder;
      - optionally render the mesh and encode the render with the image VAE;
      - pass the embeddings through ``cond_processor`` to get "y_cond".

    Meshes that fail are logged and skipped. If nothing could be processed,
    nothing is written, so the next run retries instead of reusing an empty cache.

    NOTE(review): "y_cond" is cond_processor's output at whatever weights it has
    when this runs (in main() a freshly initialized module) and is frozen into the
    cache, so later training of cond_processor cannot affect the stored data.

    Returns:
        list of {"x": tensor, "y_cond": tensor} dicts with CPU tensors.
    """
    logger.info("Starting dataset creation from local files...")
    processed_data = []
    num_failed = 0

    image_transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    for mesh_path, caption in tqdm(zip(glb_paths, texts), total=len(glb_paths), desc="Processing Local Meshes"):
        try:
            mesh = trimesh.load(mesh_path, force='mesh', process=True)
            if not isinstance(mesh, trimesh.Trimesh) or len(mesh.vertices) == 0:
                continue

            mesh_inputs = sample_points_from_mesh(mesh, num_points=16384)
            mesh_inputs_on_device = {k: v.to(DEVICE, dtype=DTYPE) for k, v in mesh_inputs.items()}

            with torch.no_grad():
                _shape_embeds, kl_embed, _ = models["vae"].encode(sample_posterior=True, **mesh_inputs_on_device)
                latent_3d = kl_embed.squeeze(0).cpu()

                tokens = models["tokenizer"](caption, padding="max_length", max_length=128, truncation=True, return_tensors="pt").to(DEVICE)
                text_embeddings = models["text_encoder"](**tokens).last_hidden_state

                image_latent_flat = None
                if USE_IMAGE_CONDITIONING:
                    rendered_image = render_and_save_image(mesh, mesh_path, IMAGE_OUTPUT_DIR, IMAGE_SIZE)
                    image_tensor = image_transform(rendered_image).unsqueeze(0).to(DEVICE, dtype=DTYPE)
                    image_latent = models["image_vae"].encode(image_tensor).latent_dist.sample() * models["image_vae"].config.scaling_factor
                    # (B, C, H, W) -> (B, H*W, C): one token per latent pixel.
                    image_latent_flat = image_latent.permute(0, 2, 3, 1).flatten(1, 2)

                final_conditioning = cond_processor(text_embeddings, image_latent_flat)

            processed_data.append({"x": latent_3d, "y_cond": final_conditioning.squeeze(0).cpu()})
        except Exception as e:
            num_failed += 1
            logger.error(f"CRITICAL ERROR processing {mesh_path}: {e}\n{traceback.format_exc()}")

    if not processed_data:
        # Do not cache an empty dataset: main() would otherwise reuse it on every run.
        logger.error(f"No samples could be processed ({num_failed} failures); not writing {PROCESSED_DATA_PATH}.")
        return processed_data

    torch.save(processed_data, PROCESSED_DATA_PATH)
    logger.info(f"Successfully created dataset with {len(processed_data)} samples.")
    if num_failed:
        logger.warning(f"{num_failed} meshes failed and were skipped.")
    return processed_data


In [None]:
class MeshDataset(Dataset):
    """In-memory dataset over pre-processed mesh samples.

    Args:
        data_file: path (``str`` or ``os.PathLike``) to a file written with
            ``torch.save``, or an already-loaded sequence of samples.
    """
    def __init__(self, data_file):
        if isinstance(data_file, (str, os.PathLike)):
            # weights_only=True limits unpickling to tensors and plain containers
            # (all this cache holds) and avoids executing arbitrary pickled code.
            self.data = torch.load(data_file, weights_only=True)
        else:
            self.data = data_file

    def __len__(self):
        """Number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample stored at ``idx`` unchanged."""
        return self.data[idx]

In [None]:
def train_phase(phase_num, model, cond_processor, dataloader, epochs, lr, output_dir):
    logger.info(f"--- Starting Training PHASE {phase_num} ---")
    os.makedirs(output_dir, exist_ok=True)
    
    trainable_params = list(model.parameters()) + list(cond_processor.parameters())
    optimizer = AdamW(trainable_params, lr=lr)
    criterion = torch.nn.MSELoss()
    
    # Initialize GradScaler for stable mixed-precision training
    scaler = torch.cuda.amp.GradScaler()
    
    model.train()
    cond_processor.train()

    for epoch in range(epochs):
        pbar = tqdm(dataloader, desc=f"Phase {phase_num} - Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            optimizer.zero_grad()
            latents_3d = batch["x"].to(DEVICE, dtype=DTYPE)
            conditioning = batch["y_cond"].to(DEVICE, dtype=DTYPE)
            
            noise = torch.randn_like(latents_3d)
            timesteps = torch.rand(latents_3d.shape[0], device=DEVICE)
            
            noisy_latents = (1 - timesteps.view(-1, 1, 1)) * latents_3d + timesteps.view(-1, 1, 1) * noise
            
            # Use autocast for the forward pass
            with torch.cuda.amp.autocast(dtype=DTYPE):
                predicted_velocity = model(noisy_latents, timesteps, conditioning)
                target_velocity = latents_3d - noise
                loss = criterion(predicted_velocity.float(), target_velocity.float())

            # NaN check and scaled backward pass
            if not torch.isfinite(loss):
                logger.warning(f"NaN loss detected in batch. Skipping optimizer step.")
                continue

            scaler.scale(loss).backward()
            
            # Unscale gradients before clipping
            scaler.unscale_(optimizer)
            
            # Add Gradient Clipping
            torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
            
            # Scaler optimizer step
            scaler.step(optimizer)
            scaler.update()
            
            pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
        
        if (epoch + 1) % SAVE_CHECKPOINT_EPOCH == 0:
            torch.save(model.state_dict(), os.path.join(output_dir, f"mesh_dit_phase{phase_num}_epoch_{epoch+1}.pt"))
            torch.save(cond_processor.state_dict(), os.path.join(output_dir, f"cond_processor_phase{phase_num}_epoch_{epoch+1}.pt"))
            logger.info(f"Saved checkpoints for Phase {phase_num}, epoch {epoch+1}")
            
    logger.info(f"--- Finished Training PHASE {phase_num} ---")

def main():
    """Run the mesh-generation training pipeline end to end.

    1. Load all models (abort if that fails).
    2. Build the ConditioningProcessor that projects text (and optionally image)
       embeddings to the MeshDiT hidden size.
    3. Build the processed dataset at PROCESSED_DATA_PATH unless it already exists;
       abort if it cannot be created or is empty.
    4. Train Phase 1, save its final weights, reload them, then train Phase 2.
    5. Save the final MeshDiT and ConditioningProcessor weights.
    """
    models = initialize_models()
    if not models:
        return

    cond_processor = ConditioningProcessor(
        text_embed_dim=models['text_encoder'].config.d_model,
        image_embed_dim=models['image_vae'].config.latent_channels if USE_IMAGE_CONDITIONING else 0,
        target_dim=models['mesh_dit'].hidden_size,
        use_image_cond=USE_IMAGE_CONDITIONING
    ).to(DEVICE, dtype=DTYPE)

    if not os.path.exists(PROCESSED_DATA_PATH):
        glb_paths, texts = load_local_dataset(DATASET_DIRECTORY)
        if not glb_paths:
            logger.error(f"No GLB files found in {DATASET_DIRECTORY}. Aborting.")
            return
        create_dataset_from_local_files(glb_paths, texts, models, cond_processor)
        if not os.path.exists(PROCESSED_DATA_PATH):
            logger.error(f"Dataset creation did not produce {PROCESSED_DATA_PATH}. Aborting.")
            return

    full_dataset = MeshDataset(PROCESSED_DATA_PATH)
    if len(full_dataset) == 0:
        # Training on an empty DataLoader would silently do nothing for every epoch.
        logger.error(f"Processed dataset {PROCESSED_DATA_PATH} is empty; delete it and re-run. Aborting.")
        return

    # --- PHASE 1 TRAINING ---
    phase1_dataloader = DataLoader(full_dataset, batch_size=PHASE1_BATCH_SIZE, shuffle=True)
    train_phase(1, models['mesh_dit'], cond_processor, phase1_dataloader, PHASE1_EPOCHS, PHASE1_LR, "./mesh_dit_phase1_checkpoints")

    phase1_model_path = "./mesh_dit_phase1_final.pt"
    phase1_cond_path = "./cond_processor_phase1_final.pt"
    torch.save(models['mesh_dit'].state_dict(), phase1_model_path)
    torch.save(cond_processor.state_dict(), phase1_cond_path)

    # --- PHASE 2 TRAINING ---
    # Phase 2 continues from the weights saved at the END of Phase 1 (there is no
    # validation metric, so this is not a "best" checkpoint).
    logger.info("Reloading final Phase 1 weights for Phase 2 fine-tuning...")
    models['mesh_dit'].load_state_dict(torch.load(phase1_model_path, map_location=DEVICE, weights_only=True))
    cond_processor.load_state_dict(torch.load(phase1_cond_path, map_location=DEVICE, weights_only=True))

    phase2_dataloader = DataLoader(full_dataset, batch_size=PHASE2_BATCH_SIZE, shuffle=True)
    train_phase(2, models['mesh_dit'], cond_processor, phase2_dataloader, PHASE2_EPOCHS, PHASE2_LR, "./mesh_dit_phase2_checkpoints")

    torch.save(models['mesh_dit'].state_dict(), "./mesh_dit_final.pt")
    torch.save(cond_processor.state_dict(), "./cond_processor_final.pt")
    logger.info("--- Training Finished. Final models saved. ---")

# In a notebook kernel __name__ is "__main__", so executing this cell starts the
# full training pipeline immediately.
if __name__ == "__main__":
    main()



# =========================================================================================
# PHASE 1 OF MESH GENERATION MODEL TRAINING
# Phase 1 uses a smaller latent set size (512) and a higher learning rate (1e-4 vs. 5e-5 in Phase 2); both phases use a batch size of 2.
# =========================================================================================

In [None]:
# =================================================================
#                 MESH GENERATION SCRIPT
# =================================================================
# Samples a latent with classifier-free-guided DDIM, decodes it with the
# Step1X-3D shape VAE, extracts a surface and exports it as a GLB file.
#
# NOTE(review): this cell does not match the training cell above. Training saves
# plain state_dicts ("./mesh_dit_final.pt"), builds MeshDiT-S with `input_dim=...`
# and regresses a flow-matching velocity. This cell instead expects a directory
# holding a tokenizer + "model.safetensors", builds the model with different
# kwargs and samples with epsilon-prediction DDIM. Confirm the intended
# checkpoint format and objective before trusting the output.
print("\n" + "="*50)
print("      Starting Mesh Generation")
print("="*50 + "\n")

MODEL_PATH = "./mesh_dit_final"
OUTPUT_FILENAME = "generated_object.glb"
TEXT_PROMPT = "chair"
CFG_SCALE_TEXT = 7.5
NUM_SAMPLING_STEPS = 250

# Raise instead of calling exit(): inside IPython/Jupyter exit() ends the session
# rather than just stopping this cell.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model path not found: {MODEL_PATH}. Please ensure the model from Phase 2 training is saved.")

print(f"Loading models from {MODEL_PATH} onto {device}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
tokenizer.pad_token = tokenizer.eos_token

generation_model = MeshDiT_models['MeshDiT-S'](
    input_tokens=2048,
    vocab_size=tokenizer.vocab_size,
    use_rope=False,
    use_ternary_rope=False,
    image_condition=False
).to(device, dtype=dtype).eval()

state_dict = safetensors_load(os.path.join(MODEL_PATH, "model.safetensors"), device=str(device))
generation_model.load_state_dict(state_dict)
print("-> DiT model loaded.")

try:
    geometry_pipeline = Step1X3DGeometryPipeline.from_pretrained("stepfun-ai/Step1X-3D",subfolder='Step1X-3D-Geometry-1300m',torch_dtype=dtype)
    vae = geometry_pipeline.vae.to(device).eval()
    print("-> 3D VAE loaded.")
except Exception as e:
    raise RuntimeError(f"Error loading VAE pipeline: {e}") from e

diffusion = GaussianDiffusion(
    betas=get_named_beta_schedule("linear", 1000),
    model_mean_type=ModelMeanType.EPSILON,
    model_var_type=ModelVarType.FIXED_SMALL,
    loss_type=LossType.MSE,
)
print("-> Diffusion process ready.")

print("--- Preparing Generation Inputs ---")
tokens = tokenizer(TEXT_PROMPT, padding="max_length", max_length=128, truncation=True, return_tensors="pt")
input_ids = tokens["input_ids"].to(device)
attention_mask = tokens["attention_mask"].to(device)

# Unconditional branch for classifier-free guidance: all-zero ids and mask.
null_input_ids = torch.zeros_like(input_ids)
null_attention_mask = torch.zeros_like(attention_mask)

# Batch layout: [conditional, unconditional].
y_in = {
    "input_ids": torch.cat([input_ids, null_input_ids], dim=0),
    "attention_mask": torch.cat([attention_mask, null_attention_mask], dim=0)
}

# Zero image-latent placeholder; presumably ignored since image_condition=False -- confirm.
y_in["image_latent"] = torch.zeros(
    (y_in["input_ids"].shape[0], 4, 64, 64),
    device=device,
    dtype=dtype
)
print(f"-> Text prompt tokenized: '{TEXT_PROMPT}'")

print("--- Starting Denoising Process ---")
z = torch.randn(1, 2048, 64, device=device, dtype=dtype)
ddim_timesteps = np.asarray(list(range(0, 1000, 1000 // NUM_SAMPLING_STEPS)))
ddim_steps = torch.from_numpy(ddim_timesteps).long().to(device)

with torch.no_grad():
    for i in tqdm(range(NUM_SAMPLING_STEPS - 1, -1, -1), desc="DDIM Sampling"):
        t = ddim_steps[i].expand(z.shape[0])
        z_in = torch.cat([z, z], dim=0)
        t_in = torch.cat([t, t], dim=0)
        noise_pred = generation_model.forward_with_cfg(z_in, t_in, y_in, CFG_SCALE_TEXT, 0)
        alpha_t = _extract_into_tensor(diffusion.alphas_cumprod, t, z.shape)
        if i > 0:
            t_prev = ddim_steps[i - 1].expand(z.shape[0])
            alpha_t_prev = _extract_into_tensor(diffusion.alphas_cumprod, t_prev, z.shape)
        else:
            # Final step: alpha_bar_prev = 1, so the update returns pred_x0. Indexing
            # alphas_cumprod with -1 (as before) picked the LAST, noisiest entry and
            # re-injected almost pure noise into the result.
            alpha_t_prev = torch.ones_like(alpha_t)
        # Deterministic DDIM (eta = 0) update.
        pred_x0 = (z - (1 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
        dir_xt = (1 - alpha_t_prev).sqrt() * noise_pred
        z = alpha_t_prev.sqrt() * pred_x0 + dir_xt
generated_latent_seq = z

print("--- Denoising complete. ---")
print("--- Decoding Latent and Extracting Mesh ---")
with torch.no_grad():
    decoded_latents = vae.decode(generated_latent_seq)
    mesh_result = vae.extract_geometry(
        decoded_latents, mc_level=0.5, bounds=[-1, -1, -1, 1, 1, 1], octree_resolution=256
    )[0]

final_mesh = trimesh.Trimesh(
    vertices=mesh_result.verts.cpu().numpy(),
    faces=mesh_result.faces.cpu().numpy()
)
final_mesh.export(OUTPUT_FILENAME)
print(f"\n--- ✨ Success! Mesh saved to {OUTPUT_FILENAME} ---")


# =========================================================================================
# TRAIN TEXTURE GENERATION MODEL: DOWNLOAD DATASET & TRAIN THE MODEL
# This script downloads the dataset, processes it, and then trains the texture generation model.
# =========================================================================================

In [None]:
import os
from getpass import getpass

from huggingface_hub import login

# SECURITY: never hard-code access tokens in a notebook. The token previously
# pasted here was stored in plain text and should be revoked on huggingface.co.
# Read the token from the HF_TOKEN environment variable (e.g. a Kaggle/Colab
# secret); prompt interactively only as a fallback.
hf_token = os.environ.get("HF_TOKEN") or getpass("Enter your Hugging Face token: ")
login(token=hf_token)

In [None]:
import os
import torch
import trimesh
import numpy as np
from PIL import Image
from tqdm import tqdm
import json
import traceback
import xatlas # Import the xatlas library

print("Initializing package structure...")
# The Step1X-3D clone is not pip-installed: put its root on sys.path and make sure
# each sub-package imported below has an __init__.py.
import sys
sys.path.append("./Step1X-3D")

package_dirs = [
    "./Step1X-3D/step1x3d_geometry",
    "./Step1X-3D/step1x3d_geometry/utils",
    "./Step1X-3D/step1x3d_geometry/models",
    "./Step1X-3D/step1x3d_geometry/models/pipelines",
    "./Step1X-3D/step1x3d_texture",
    "./Step1X-3D/step1x3d_texture/utils",
    "./Step1X-3D/step1x3d_texture/pipelines",
    "./Step1X-3D/step1x3d_texture/differentiable_renderer",
]


def _ensure_package_marker(directory):
    """Create ``directory`` if needed, then an empty __init__.py if it is missing."""
    os.makedirs(directory, exist_ok=True)
    marker = os.path.join(directory, "__init__.py")
    if os.path.exists(marker):
        return
    open(marker, 'w').close()
    print(f"Created: {marker}")


for pkg_dir in package_dirs:
    _ensure_package_marker(pkg_dir)
print("✅ Package structure initialized successfully.\n")


from step1x3d_texture.utils.render import load_mesh, render, NVDiffRastContextWrapper
from step1x3d_texture.utils.camera import get_orthogonal_camera
from step1x3d_texture.utils.saving import tensor_to_image
from step1x3d_geometry.models.pipelines.pipeline_utils import preprocess_image
from step1x3d_texture.pipelines.step1x_3d_texture_synthesis_pipeline import Step1X3DTexturePipeline, Step1X3DTextureConfig

# --- Texture-data preprocessing configuration ---
INPUT_DIR = '/kaggle/working/AI-Game-Engine/testing_dataset'
OUTPUT_DIR = 'processed_texture_data_complete'
RESOLUTION = 768  # square render size (pixels) used for every view

if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

# nvdiffrast rasterization context, passed to every render() call.
print("Initializing GPU rendering context...")
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
ctx = NVDiffRastContextWrapper(context_type="cuda", device=device)
print(f"Context initialized on {device}.")

# --- UV OPTIMIZATION FUNCTION (REVISED) ---
def generate_optimized_uvs(mesh):
    """Re-parameterize ``mesh`` with xatlas and return a new mesh carrying the UVs.

    Args:
        mesh: trimesh.Trimesh, or trimesh.Scene (flattened into one mesh first).

    Returns:
        A new trimesh.Trimesh whose vertices are duplicated along UV seams (through
        the xatlas vertex mapping), with TextureVisuals holding the generated UVs.
        The original vertex normals are carried over through the same mapping.
        If xatlas fails, the (possibly flattened) input mesh is returned unchanged.

    NOTE(review): only UVs are attached -- any material/texture on the input mesh
    is not transferred to the returned mesh.
    """
    print("Generating optimized UVs with xAtlas...")
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.dump(concatenate=True)

    try:
        # vmapping[i] is the original vertex index of new vertex i; indices are the
        # new faces referencing the new vertices.
        vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
        # Reuse the original smooth normals for the split seam vertices. Letting
        # trimesh recompute them from the split topology gives each seam copy only
        # the faces on its own side, which creates hard normal seams along every
        # UV chart boundary (visible in the rendered normal maps).
        uv_optimized_mesh = trimesh.Trimesh(
            vertices=mesh.vertices[vmapping],
            faces=indices,
            vertex_normals=mesh.vertex_normals[vmapping],
            process=False,
        )
        uv_optimized_mesh.visual = trimesh.visual.texture.TextureVisuals(uv=uvs)
    except Exception as e:
        print(f"    xAtlas failed with error: {e}. Returning original mesh.")
        return mesh  # best effort: fall back to the untouched input

    return uv_optimized_mesh

# --- CAMERA SETUP ---
# Six axis-aligned views (4 around the equator + top + bottom) built with
# get_orthogonal_camera at a fixed distance and a symmetric +/-0.55 frustum.
view_names = ['front', 'back', 'right', 'left', 'top', 'bottom']
elevations = [0, 0, 0, 0, 90, -90]
azimuths = [0, 180, 90, -90, 0, 0]

cameras = get_orthogonal_camera(
    elevation_deg=elevations,
    azimuth_deg=azimuths,
    distance=[1.8 for _ in view_names],
    left=-0.55,
    right=0.55,
    bottom=-0.55,
    top=0.55,
    device=device,
)

# Camera-to-world matrices per view as nested lists (JSON-serializable).
camera_poses_dict = {}
for view_idx, view_name in enumerate(view_names):
    camera_poses_dict[view_name] = cameras.c2w[view_idx].cpu().numpy().tolist()

# Sorted, case-insensitive listing: os.listdir order is arbitrary, so sorting makes
# the manifest order reproducible across runs; ".GLB" files are no longer skipped.
glb_files = sorted(
    f for f in os.listdir(INPUT_DIR)
    if f.lower().endswith('.glb') and os.path.isfile(os.path.join(INPUT_DIR, f))
)
batch_data = []
failed_files = []

for glb_filename in tqdm(glb_files, desc="Processing Meshes"):
    glb_path = os.path.join(INPUT_DIR, glb_filename)
    model_id = os.path.splitext(glb_filename)[0]

    model_output_dir = os.path.join(OUTPUT_DIR, model_id)
    os.makedirs(model_output_dir, exist_ok=True)

    try:
        raw_mesh = trimesh.load(glb_path, force='mesh', process=False)
        # NOTE(review): generate_optimized_uvs attaches new UVs but does not carry
        # over the source material; confirm that the "albedo" maps rendered below
        # (render_output.attr) actually reflect the original texture.
        uv_optimized_mesh = generate_optimized_uvs(raw_mesh)
        uv_mesh_path = os.path.join(model_output_dir, "mesh_with_uvs.glb")
        uv_optimized_mesh.export(uv_mesh_path)

        mesh_for_render, _ = load_mesh(uv_optimized_mesh, rescale=True, device=device)

        render_output = render(
            ctx, mesh_for_render, cameras, height=RESOLUTION, width=RESOLUTION,
            render_attr=True, render_normal=True
        )

        # Map positions (+0.5 shift) and normals ([-1, 1] -> [0, 1]) into image range.
        position_maps = (render_output.pos + 0.5).clamp(0, 1)
        normal_maps = (render_output.normal / 2 + 0.5).clamp(0, 1)
        albedo_maps = render_output.attr

        output_paths = { 'albedo': {}, 'normal': {}, 'position': {}, 'reference_image': '' }

        # The front-view albedo, preprocessed, serves as the conditioning reference image.
        front_albedo_tensor = albedo_maps[view_names.index('front')]
        front_albedo_pil = tensor_to_image(front_albedo_tensor)

        reference_image_pil = preprocess_image(front_albedo_pil)
        ref_image_path = os.path.join(model_output_dir, "reference_image.png")
        reference_image_pil.save(ref_image_path)
        output_paths['reference_image'] = ref_image_path

        for i, view_name in enumerate(view_names):
            albedo_path = os.path.join(model_output_dir, f"{view_name}_albedo.png")
            tensor_to_image(albedo_maps[i]).save(albedo_path)
            output_paths['albedo'][view_name] = albedo_path

            normal_path = os.path.join(model_output_dir, f"{view_name}_normal.png")
            tensor_to_image(normal_maps[i]).save(normal_path)
            output_paths['normal'][view_name] = normal_path

            position_path = os.path.join(model_output_dir, f"{view_name}_position.png")
            tensor_to_image(position_maps[i]).save(position_path)
            output_paths['position'][view_name] = position_path

        camera_pose_path = os.path.join(model_output_dir, 'camera_poses.json')
        with open(camera_pose_path, 'w') as f:
            json.dump(camera_poses_dict, f, indent=4)

        record = {
            'model_id': model_id,
            'uv_optimized_mesh_path': uv_mesh_path,
            'reference_image_path': output_paths['reference_image'],
            'albedo_map_paths': output_paths['albedo'],
            'normal_map_paths': output_paths['normal'],
            'position_map_paths': output_paths['position'],
            'camera_pose_path': camera_pose_path
        }
        batch_data.append(record)

    except Exception as e:
        # Best effort: keep going, but remember the failure for the summary below.
        failed_files.append(glb_filename)
        print(f"Failed to process {glb_filename}: {e}")
        traceback.print_exc()

batch_file_path = os.path.join(OUTPUT_DIR, 'dataset_manifest.json')
with open(batch_file_path, 'w', encoding='utf-8') as f:
    json.dump(batch_data, f, indent=4)

print(f"\n✅ Processing complete. {len(batch_data)} assets processed.")
if failed_files:
    print(f"⚠️ {len(failed_files)} assets failed: {', '.join(failed_files)}")
print(f"Processed data saved to: {OUTPUT_DIR}")
print(f"Dataset manifest saved to: {batch_file_path}")

In [None]:
from mmfreelm.models.hgrn_bit.texture_dit import TernaryMVAdapter
import os
import json
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm.auto import tqdm
from accelerate import Accelerator
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from torchvision import transforms
import numpy as np

# Texture-model training settings. The latent grid is the rendered image grid
# divided by the VAE downsampling factor of 8 (768 / 8 = 96).
_TEXTURE_IMAGE_RESOLUTION = 768

CONFIG = dict(
    # Data and outputs
    dataset_manifest="./processed_texture_data_complete/dataset_manifest.json",
    output_dir="./texture_model_output",
    # Rendered views and their VAE latents
    image_resolution=_TEXTURE_IMAGE_RESOLUTION,
    latent_resolution=_TEXTURE_IMAGE_RESOLUTION // 8,
    # Optimisation (one batch item = one 3D asset with all of its views)
    train_batch_size=1,
    num_train_epochs=100,
    learning_rate=1e-4,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_weight_decay=1e-2,
    adam_epsilon=1e-08,
    mixed_precision="fp16",
    gradient_accumulation_steps=1,
    save_steps=1000,
    num_views=6,
    # Pretrained components
    vae_model_id="madebyollin/sdxl-vae-fp16-fix",
    text_encoder_id="stabilityai/stable-diffusion-xl-base-1.0",
)

class TextureDataset(Dataset):
    """
    Dataset to load the pre-processed texture generation data.

    Each manifest record is expected to provide:
      - 'albedo_map_paths', 'normal_map_paths', 'position_map_paths': dicts mapping
        every name in VIEW_NAMES to an image path;
      - 'reference_image_path': path to a single conditioning image.

    Each item is a dict with:
      - "albedos": (NumViews, 3, H, W) target images normalized to [-1, 1];
      - "control_images": (NumViews, 6, H, W) normal + position maps in [0, 1];
      - "reference_image": (3, H, W) normalized to [-1, 1];
      - "prompt": a fixed placeholder string.
    """

    # View order used for every per-view tensor; must match the manifest's dict keys.
    VIEW_NAMES = ('front', 'back', 'right', 'left', 'top', 'bottom')

    def __init__(self, manifest_path, resolution):
        print(f"Loading dataset manifest from: {manifest_path}")
        with open(manifest_path, 'r') as f:
            self.manifest = json.load(f)
        print(f"Found {len(self.manifest)} 3D assets.")

        # Target / reference images are mapped to [-1, 1] to match the VAE input range.
        self.transform = transforms.Compose([
            transforms.Resize((resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]) # Normalize to [-1, 1]
        ])
        # Geometry guidance maps stay in [0, 1] (ToTensor only).
        self.control_transform = transforms.Compose([
            transforms.Resize((resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.manifest)

    @staticmethod
    def _load_rgb(path):
        """Open an image as RGB and close the underlying file immediately.

        A bare ``Image.open(path).convert("RGB")`` keeps the source file handle open
        until garbage collection; with 19 images per item that leaks descriptors
        over a long training run.
        """
        with Image.open(path) as img:
            return img.convert("RGB")

    def _stack_views(self, path_by_view, transform):
        """Load every view in VIEW_NAMES, transform it, and stack to (NumViews, C, H, W)."""
        return torch.stack([transform(self._load_rgb(path_by_view[v])) for v in self.VIEW_NAMES])

    def __getitem__(self, idx):
        item = self.manifest[idx]

        # Multi-view albedos are the diffusion targets.
        albedos_tensor = self._stack_views(item['albedo_map_paths'], self.transform)

        # Normals and positions together form the geometric guidance.
        normals_tensor = self._stack_views(item['normal_map_paths'], self.control_transform)
        positions_tensor = self._stack_views(item['position_map_paths'], self.control_transform)
        control_images = torch.cat([normals_tensor, positions_tensor], dim=1) # Shape: (NumViews, 6, H, W)

        # Single reference image used for conditioning.
        ref_image_tensor = self.transform(self._load_rgb(item['reference_image_path']))

        # Text prompts are not used for this task, so a constant placeholder is supplied.
        text_prompt = "a high-quality texture"

        return {
            "albedos": albedos_tensor,
            "control_images": control_images,
            "reference_image": ref_image_tensor,
            "prompt": text_prompt,
        }


# --- 3. Training Function ---
def _encode_batch(batch, vae, tokenizer, text_encoder, device):
    """Fold the view axis into the batch axis and encode with the frozen encoders.

    Args:
        batch: collated dict from ``TextureDataset`` ("albedos", "control_images",
            "reference_image", "prompt").
        vae: frozen ``AutoencoderKL`` used to map images to scaled latents.
        tokenizer: CLIP tokenizer for the prompt strings.
        text_encoder: frozen CLIP text model.
        device: device the token ids are moved to before text encoding.

    Returns:
        (clean_latents, ref_latents, prompt_embeds, control_images). clean_latents and
        control_images are (Batch * NumViews, ...); ref_latents and prompt_embeds keep
        the plain Batch axis.
    """
    # (Batch, NumViews, C, H, W) -> (Batch * NumViews, C, H, W). Unlike a hard-coded
    # .view(-1, 3, res, res) this does not rely on CONFIG matching the stored images.
    clean_images = batch["albedos"].flatten(0, 1)
    control_images = batch["control_images"].flatten(0, 1)
    ref_images = batch["reference_image"]  # (Batch, 3, H, W)

    with torch.no_grad():
        scale = vae.config.scaling_factor
        clean_latents = vae.encode(clean_images.to(dtype=vae.dtype)).latent_dist.sample() * scale
        ref_latents = vae.encode(ref_images.to(dtype=vae.dtype)).latent_dist.sample() * scale

        text_inputs = tokenizer(
            batch["prompt"],
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        prompt_embeds = text_encoder(text_inputs.input_ids.to(device))[0]

    return clean_latents, ref_latents, prompt_embeds, control_images


def _save_checkpoint(accelerator, name):
    """Save the full Accelerator state under CONFIG["output_dir"]/name and report it."""
    save_path = os.path.join(CONFIG["output_dir"], name)
    accelerator.save_state(save_path)
    accelerator.print(f"✅ Saved checkpoint to {save_path}")
    return save_path


def main():
    """Train TernaryMVAdapter to predict DDPM noise on multi-view albedo latents.

    Albedos and the reference image are encoded by a frozen VAE, and the placeholder
    prompt by a frozen CLIP text encoder. Noise is added at random timesteps, and the
    model's noise prediction is regressed with MSE. Checkpoints are written every
    CONFIG["save_steps"] optimizer steps and once more after the last epoch.
    """
    # Accelerator handles device placement, mixed precision and gradient accumulation.
    accelerator = Accelerator(
        gradient_accumulation_steps=CONFIG["gradient_accumulation_steps"],
        mixed_precision=CONFIG["mixed_precision"],
        log_with="tensorboard",
        project_dir=os.path.join(CONFIG["output_dir"], "logs")
    )

    # Frozen pre-trained encoders (fp16) and the noise scheduler.
    vae = AutoencoderKL.from_pretrained(CONFIG["vae_model_id"], torch_dtype=torch.float16)
    tokenizer = CLIPTokenizer.from_pretrained(CONFIG["text_encoder_id"], subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(CONFIG["text_encoder_id"], subfolder="text_encoder", torch_dtype=torch.float16)
    noise_scheduler = DDPMScheduler.from_pretrained(CONFIG["text_encoder_id"], subfolder="scheduler")

    latent_channels = 4  # VAE latent space has 4 channels
    model = TernaryMVAdapter(
        input_size=CONFIG["latent_resolution"],
        patch_size=2,
        in_channels=latent_channels,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        cond_channels=6, # 3 for normal map + 3 for position map
        learn_sigma=True,
    )

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=CONFIG["learning_rate"],
        betas=(CONFIG["adam_beta1"], CONFIG["adam_beta2"]),
        weight_decay=CONFIG["adam_weight_decay"],
        eps=CONFIG["adam_epsilon"],
    )

    train_dataset = TextureDataset(CONFIG["dataset_manifest"], CONFIG["image_resolution"])
    train_dataloader = DataLoader(train_dataset, batch_size=CONFIG["train_batch_size"], shuffle=True)

    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )
    # The frozen encoders are not wrapped by prepare(), so move them explicitly.
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)

    global_step = 0  # counts optimizer steps, not micro-batches

    accelerator.print("🚀 Starting training...")
    for epoch in range(CONFIG["num_train_epochs"]):
        progress_bar = tqdm(
            total=len(train_dataloader),
            desc=f"Epoch {epoch+1}/{CONFIG['num_train_epochs']}",
            disable=not accelerator.is_local_main_process,
        )

        for batch in train_dataloader:
            clean_latents, ref_latents, prompt_embeds, control_images = _encode_batch(
                batch, vae, tokenizer, text_encoder, accelerator.device
            )

            # Forward diffusion: noise the clean latents at random timesteps.
            noise = torch.randn_like(clean_latents)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (clean_latents.shape[0],), device=accelerator.device)
            noisy_latents = noise_scheduler.add_noise(clean_latents, noise, timesteps)

            with accelerator.accumulate(model):
                # NOTE(review): prompt_embeds / ref_latents have Batch rows while the
                # latents have Batch * NumViews rows; presumably the model repeats
                # them per view using num_views - confirm against TernaryMVAdapter.
                noise_pred = model(
                    x=noisy_latents,
                    t=timesteps,
                    num_views=CONFIG["num_views"],
                    encoder_hidden_states=prompt_embeds,
                    control_image_feature=control_images,
                    ref_hidden_states=ref_latents
                )

                # With learn_sigma=True, DiT-style models output 2 * in_channels (eps +
                # variance). Only the eps half is supervised by this MSE objective;
                # slicing is a no-op if the model already returns `latent_channels`.
                noise_pred = noise_pred[:, :noise.shape[1]]
                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            progress_bar.set_postfix(loss=loss.detach().item())

            # Only advance the step counter when gradients were synced (i.e. an
            # actual optimizer update happened under gradient accumulation).
            if accelerator.sync_gradients:
                global_step += 1
                if global_step % CONFIG["save_steps"] == 0:
                    _save_checkpoint(accelerator, f"checkpoint-{global_step}")

        progress_bar.close()

    # Always persist the final weights; the periodic save above only fires on
    # multiples of save_steps and would otherwise lose the tail of training.
    accelerator.wait_for_everyone()
    _save_checkpoint(accelerator, "checkpoint-final")
    accelerator.print("✅ Training complete!")

if __name__ == "__main__":
    # Create the output root first: the Accelerator logs directory and all
    # checkpoints are placed beneath it.
    output_root = CONFIG["output_dir"]
    os.makedirs(output_root, exist_ok=True)
    main()