In [None]:
!pip install -U git+https://github.com/Sakib323/AI-Game-Engine.git
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
!pip install transformers
!pip install datasets
!pip install wandb
!pip install -U datasets
!pip install objaverse
!pip install diffusers
!pip install trimesh
!pip install jaxtyping
!pip install pytorch-lightning
!pip install ijson
!pip install triton==3.2.0
!pip install wandb

In [None]:
!git clone --depth 1 --branch main https://github.com/stepfun-ai/Step1X-3D.git
import sys
sys.path.append("./Step1X-3D")  
import os
print(os.listdir("./Step1X-3D"))
!pip install -r ./Step1X-3D/requirements.txt --verbose

In [1]:
pip install torch-cluster -f https://data.pyg.org/whl/torch-$(python -c "import torch; print(torch.__version__)").html

Looking in links: https://data.pyg.org/whl/torch-2.6.0+cu124.html
Note: you may need to restart the kernel to use updated packages.


# =========================================================================================
# TRAIN MESH GENERATION MODEL: DOWNLOAD DATASET & TRAIN THE MODEL
# This script downloads the dataset, processes it, and finally trains the model.
# =========================================================================================

In [2]:
import requests
import gzip
import json
import objaverse
import gc
import os
import ijson


# Annotation fields that are dropped from every datapoint before it is written out.
KEYS_TO_EXCLUDE = {
    "tags", "name", "staffpickedAt", "viewCount", "likeCount", "animationCount",
    "commentCount", "publishedAt", "user", "description", "faceCount", "createdAt",
    "vertexCount", "license", "uri", "viewerUrl", "embedUrl", "isDownloadable",
    "categories", "isAgeRestricted", "archives"
}

print("Starting dataset preparation...")

# Check if captions file exists; download only if necessary
url = "https://huggingface.co/datasets/tiange/Cap3D/resolve/main/Objaverse_files/cap3d_captions.json.gz"
captions_file = 'cap3d_captions.json.gz'
if not os.path.exists(captions_file):
    print("Downloading captions file...")
    # Stream into a sibling ".part" file and rename only on success. Otherwise an
    # interrupted download would leave a truncated archive under the final name and
    # the existence check above would skip the download forever after.
    partial_file = captions_file + ".part"
    try:
        # Without an explicit timeout a stalled connection would block this cell indefinitely.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(partial_file, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(partial_file, captions_file)
        print("Captions file downloaded successfully.")
    except requests.exceptions.RequestException as e:
        print(f"Error downloading captions: {e}")
        # Stop this cell with an error instead of calling exit(), which would shut
        # the notebook kernel down.
        raise SystemExit(1) from e
    finally:
        # Remove leftovers from any failed/interrupted attempt (no-op after a successful rename).
        if os.path.exists(partial_file):
            os.remove(partial_file)
else:
    print("Captions file already exists, skipping download.")

def process_batch(batch, f, total_counter):
    """Write one JSONL line per (uid, caption) pair, enriched with filtered Objaverse metadata.

    Args:
        batch: list of ``(uid, caption)`` tuples.
        f: writable text file object; each datapoint is written as one JSON line.
        total_counter: number of datapoints written before this batch; only used to
            decide whether a datapoint is among the first 10 overall and gets printed.

    Returns:
        ``len(batch)`` on success, or ``0`` if the annotations for the batch could not
        be loaded (the whole batch is skipped in that case).
    """
    uids = [uid for uid, _ in batch]
    try:
        annotations = objaverse.load_annotations(uids)
    except Exception as e:
        print(f"Error loading annotations for batch: {e}")
        return 0  # Skip this batch on error

    for i, (uid, caption) in enumerate(batch):
        # A uid without annotations (missing or None) still yields a caption-only datapoint.
        metadata = annotations.get(uid) or {}
        filtered_metadata = {k: v for k, v in metadata.items() if k not in KEYS_TO_EXCLUDE}

        # Filter thumbnails to keep only 1024x576 images.
        # `filtered_metadata` is a shallow copy, so build a new "thumbnails" dict rather
        # than editing the nested one that still belongs to `annotations`; the type
        # checks keep a malformed entry from aborting a half-written batch.
        thumbnails = filtered_metadata.get("thumbnails")
        if isinstance(thumbnails, dict) and "images" in thumbnails:
            filtered_metadata["thumbnails"] = {
                **thumbnails,
                "images": [
                    img for img in (thumbnails["images"] or [])
                    if isinstance(img, dict)
                    and img.get("width") == 1024 and img.get("height") == 576
                ],
            }

        datapoint = {
            "uid": uid,
            "caption": caption,
            **filtered_metadata
        }
        f.write(json.dumps(datapoint) + '\n')
        if total_counter + i < 10:
            print(f"Datapoint {uid}:")
            print(json.dumps(datapoint, indent=2))
            print("-" * 50)
    # Release the (potentially large) annotation dict before the next batch is loaded.
    del annotations
    gc.collect()
    return len(batch)


# Stream the gzipped captions JSON (a single top-level object of uid -> caption) and
# write the enriched datapoints to 'extended_dataset.jsonl' in batches.
batch_size = 100000
total_counter = 0
# The captions archive is opened in a `with` block so its handle is always closed, and
# in binary mode because ijson parses bytes (text mode only adds a re-encoding step).
with open('extended_dataset.jsonl', 'w', encoding='utf-8') as f, \
        gzip.open(captions_file, 'rb') as captions_stream:
    batch = []
    for uid, caption in ijson.kvitems(captions_stream, ''):
        batch.append((uid, caption))
        if len(batch) >= batch_size:
            processed = process_batch(batch, f, total_counter)
            total_counter += processed
            print(f"Processed {total_counter} captions so far.")
            batch = []
    # Flush the final, partially filled batch.
    if batch:
        processed = process_batch(batch, f, total_counter)
        total_counter += processed

print(f"Total entries in extended dataset: {total_counter}")
print("Dataset saved to 'extended_dataset.jsonl'")
print("Script finished.")

Starting dataset preparation...
Captions file already exists, skipping download.


 99%|█████████▉| 159/160 [01:00<00:00,  2.63it/s]


Datapoint ed51a51909ee46c780db3a85e821feb2:
{
  "uid": "ed51a51909ee46c780db3a85e821feb2",
  "caption": "Matte green rifle with a long barrel, stock, and detailed magazine.",
  "thumbnails": {
    "images": [
      {
        "uid": "0d6d45c0ee174cca9323a183ebbe2ef3",
        "size": 10719,
        "width": 1024,
        "url": "https://media.sketchfab.com/models/ed51a51909ee46c780db3a85e821feb2/thumbnails/f46b00b385d4449d923aff78348b04c6/eb41be4a1dcd450a974db9d24dc13d16.jpeg",
        "height": 576
      }
    ]
  }
}
--------------------------------------------------
Datapoint 9110b606f6c547b2980fcb3c8c4b6a1c:
{
  "uid": "9110b606f6c547b2980fcb3c8c4b6a1c",
  "caption": "Rustic single-story building with a weathered green gable roof, exposed wooden beams, brick walls, and a partly visible glass window front.",
  "thumbnails": {
    "images": [
      {
        "uid": "32b5631f607543a38e10188b03e3d7c1",
        "size": 80170,
        "width": 1024,
        "url": "https://media.sketchfab

 66%|██████▋   | 106/160 [00:40<00:20,  2.64it/s]


KeyboardInterrupt: 

In [3]:
import json
import objaverse
import pathlib
import shutil
import os

# --- Settings ---------------------------------------------------------------
jsonl_path = "extended_dataset.jsonl"
download_limit = 150
output_dir = pathlib.Path("downloads")
output_dir.mkdir(exist_ok=True)
metadata_output_path = output_dir / "metadata.json"

# --- Pick the first `download_limit` distinct uids from the JSONL dataset ----
uid_metadata = {}
with open(jsonl_path, "r", encoding="utf-8") as f:
    for line in f:
        if len(uid_metadata) >= download_limit:
            break
        entry = json.loads(line)
        uid = entry["uid"]
        caption = entry.get("caption", "")
        thumbnails = entry.get("thumbnails", {}).get("images", [])
        # Keep only the URL of every thumbnail record that has one.
        thumbnail_urls = [img["url"] for img in thumbnails if "url" in img]
        uid_metadata[uid] = {"caption": caption, "thumbnails": thumbnail_urls}

# --- Download through objaverse (returns a uid -> local path mapping) --------
uids = list(uid_metadata)
print(f"Downloading {len(uids)} models...")
paths = objaverse.load_objects(uids=uids)

# --- Copy each model next to the metadata file as <uid>.glb ------------------
final_metadata = {}
for uid, src_path in paths.items():
    dst_path = output_dir / f"{uid}.glb"
    try:
        shutil.copy(src_path, dst_path)
        # Only models that were copied successfully make it into the metadata file.
        final_metadata[uid] = uid_metadata[uid]
        print(f"Saved {uid} to {dst_path}")
    except Exception as e:
        print(f"Failed to save {uid}: {e}")

metadata_output_path.write_text(json.dumps(final_metadata, indent=2), encoding="utf-8")

print(f"\nDownloaded {len(final_metadata)} models.")
print(f"Metadata saved to {metadata_output_path}")


Downloading 150 models...
Downloaded 1 / 150 objects
Downloaded 2 / 150 objects
Downloaded 3 / 150 objects
Downloaded 4 / 150 objects
Downloaded 5 / 150 objects
Downloaded 6 / 150 objects
Downloaded 7 / 150 objects
Downloaded 8 / 150 objects
Downloaded 9 / 150 objects
Downloaded 10 / 150 objects
Downloaded 11 / 150 objects
Downloaded 12 / 150 objects
Downloaded 13 / 150 objects
Downloaded 14 / 150 objects
Downloaded 15 / 150 objects
Downloaded 16 / 150 objects
Downloaded 17 / 150 objects
Downloaded 18 / 150 objects
Downloaded 19 / 150 objects
Downloaded 20 / 150 objects
Downloaded 21 / 150 objects
Downloaded 22 / 150 objects
Downloaded 23 / 150 objects
Downloaded 24 / 150 objects
Downloaded 25 / 150 objects
Downloaded 26 / 150 objects
Downloaded 27 / 150 objects
Downloaded 28 / 150 objects
Downloaded 29 / 150 objects
Downloaded 30 / 150 objects
Downloaded 31 / 150 objects
Downloaded 32 / 150 objects
Downloaded 33 / 150 objects
Downloaded 34 / 150 objects
Downloaded 35 / 150 objects
Dow

In [4]:
import requests
import os
import tempfile
import trimesh
import numpy as np
import torch
from transformers import AutoTokenizer
from tqdm import tqdm
import random
import traceback
import logging
import sys
from diffusers import AutoencoderKL


# Model initialisation: sets up logging, picks the compute device/dtype, and loads the
# Step1X-3D geometry VAE, the text tokenizer and the Stable Diffusion image VAE into
# the notebook-level names `vae`, `tokenizer` and `image_vae`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use the GPU when one is available; all models below are loaded in float32.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
print(f"Using device: {device}, dtype: {dtype}")

# Put the local Step1X-3D checkout on the import path for the import that follows.
sys.path.append("./Step1X-3D")

from step1x3d_geometry.models.pipelines.pipeline import Step1X3DGeometryPipeline

from huggingface_hub import hf_hub_download
print("Initializing Step1X-3D VAE...")
try:
    # Load the full geometry pipeline, then keep only its VAE on `device` in eval mode.
    geometry_pipeline = Step1X3DGeometryPipeline.from_pretrained(
        "stepfun-ai/Step1X-3D",
        subfolder='Step1X-3D-Geometry-1300m',
        torch_dtype=dtype
    )
    vae = geometry_pipeline.vae.to(device)
    vae.eval()
    print("Step1X-3D VAE initialized successfully.")
except Exception as e:
    print(f"Error initializing pipeline. Make sure you have the Step1X-3D repo cloned and in your sys.path. Error: {e}")
    # NOTE(review): execution continues with `vae = None`; any later `vae.encode(...)`
    # call will then fail on None — confirm this best-effort behaviour is intended.
    vae = None

# The tokenizer reuses its EOS token as the padding token.
tokenizer = AutoTokenizer.from_pretrained("Sakib323/MMfreeLM-370M")
tokenizer.pad_token = tokenizer.eos_token


# Image VAE taken from the "vae" subfolder of Stable Diffusion 2.1, moved to `device`, eval mode.
image_vae = AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-2-1",
        subfolder="vae",
        torch_dtype=dtype
    ).to(device).eval()
print("Image VAE model loaded.")


2025-07-13 11:57:19.919995: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752407840.180790    1359 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752407840.240181    1359 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Using device: cuda, dtype: torch.float32
Initializing Step1X-3D VAE...


Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

config.json:   0%|          | 0.00/411 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/465 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/457 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/770 [00:00<?, ?B/s]

model_index.json:   0%|          | 0.00/670 [00:00<?, ?B/s]

scheduler_config.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/5.27G [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/766M [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

/root/.cache/huggingface/hub/models--stepfun-ai--Step1X-3D/snapshots/bf7084495b3a72222f36549b7942948aa4d9daa7


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

Loading Dinov2 model from facebook/dinov2-with-registers-large


config.json: 0.00B [00:00, ?B/s]

preprocessor_config.json:   0%|          | 0.00/436 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Step1X-3D VAE initialized successfully.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/437 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/611 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

Image VAE model loaded.


In [5]:
from diffusers import AutoencoderKL
import requests
from PIL import Image
from io import BytesIO
import torch
import torchvision.transforms as transforms
from torchvision.transforms.functional import crop
from diffusers import AutoencoderKL
import torchvision.transforms as T


def get_image_latent(
    url: str,
    vae: AutoencoderKL,
    device: "str | torch.device" = "cuda",
    timeout: float = 30.0,
) -> "torch.Tensor | None":
    """Download an image and encode it into a scaled VAE latent.

    The image is converted to RGB, resized so its short side is 512 px (bicubic),
    center-cropped to 512x512, normalised to [-1, 1] and encoded with ``vae``. A
    latent is sampled from the posterior and multiplied by ``vae.config.scaling_factor``.

    Args:
        url: HTTP(S) location of the image.
        vae: image autoencoder used for encoding.
        device: device the image tensor is moved to before encoding.
        timeout: seconds to wait for the server. Without a timeout a single stalled
            host would hang the whole dataset build.

    Returns:
        The sampled latent with a leading batch dimension of 1 (presumably
        ``(1, 4, 64, 64)`` for the SD VAE — not verified here), or ``None`` if the
        image could not be downloaded or opened.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        img_data = BytesIO(response.content)
        raw_image = Image.open(img_data).convert("RGB")
    except requests.exceptions.RequestException as e:
        # Covers HTTP errors, connection failures and timeouts.
        print(f"Error downloading image: {e}")
        return None
    except IOError as e:
        print(f"Error opening image: {e}")
        return None

    preprocess = T.Compose([
        T.Resize(512, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(512),
        T.ToTensor(),
        # Map [0, 1] pixel values to [-1, 1].
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # unsqueeze(0) adds the batch dimension; dtype follows the VAE weights.
    image_tensor = preprocess(raw_image).unsqueeze(0).to(device=device, dtype=vae.dtype)

    with torch.no_grad():
        latent_dist = vae.encode(image_tensor).latent_dist
        latents = latent_dist.sample()
        latents = latents * vae.config.scaling_factor

    return latents
    

In [None]:
import os
import json
import torch
import traceback
import numpy as np
import trimesh
from tqdm import tqdm
import logging
from torch_cluster import fps 

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def sample_uniform_and_sharp_points(mesh, num_uniform_points=16384, num_sharp_points=16384, sharp_threshold_deg=60):
    """
    Draw two surface point clouds from ``mesh``: one over the whole surface and one
    restricted to faces that border a sharp edge.

    An edge counts as sharp when the angle between its two adjacent faces exceeds
    ``sharp_threshold_deg``. If no such region exists, or extracting it fails, the
    sharp set is drawn from the whole surface instead.

    Returns ``((uniform_points, uniform_normals), (sharp_points, sharp_normals))``,
    where the normals are the face normals of the sampled faces.
    """
    # Whole-surface pass.
    uniform_points, uniform_face_ids = trimesh.sample.sample_surface(mesh, num_uniform_points)
    uniform_normals = mesh.face_normals[uniform_face_ids]

    sharp_points = sharp_normals = None
    try:
        # In trimesh, a larger angle between adjacent face normals means a sharper edge.
        sharp_pair_ids = np.where(mesh.face_adjacency_angles > np.deg2rad(sharp_threshold_deg))[0]

        if len(sharp_pair_ids) > 0:
            # Every sharp edge contributes both of its neighbouring faces.
            sharp_face_ids = np.unique(mesh.face_adjacency[sharp_pair_ids].flatten())
            if len(sharp_face_ids) > 0:
                sharp_mesh = mesh.submesh([sharp_face_ids], append=True)
                has_enough_geometry = sharp_mesh.vertices.shape[0] > 3 and sharp_mesh.faces.shape[0] > 1
                if has_enough_geometry:
                    sharp_points, sharp_face_sample_ids = trimesh.sample.sample_surface(sharp_mesh, num_sharp_points)
                    sharp_normals = sharp_mesh.face_normals[sharp_face_sample_ids]
    except Exception as e:
        logger.warning(f"Could not perform sharp sampling due to a mesh error: {e}. Falling back to uniform.")
        sharp_points = None

    # Fallback: no sharp region, or the sharp pass did not complete.
    if sharp_points is None or sharp_normals is None:
        logger.info("No sharp regions found or submesh failed. Using uniform samples for the sharp set as a fallback.")
        sharp_points, fallback_face_ids = trimesh.sample.sample_surface(mesh, num_sharp_points)
        sharp_normals = mesh.face_normals[fallback_face_ids]

    uniform_set = (uniform_points, uniform_normals)
    sharp_set = (sharp_points, sharp_normals)
    return uniform_set, sharp_set



def process_mesh_to_vae_input(mesh_path, num_points=32768):
    """
    Processes a mesh file to generate the 'surface' and 'sharp_surface' point
    clouds that the MichelangeloAutoencoder.encode function expects.

    The mesh is loaded, hole-filled if not watertight, centred on its bounding-box
    centre and scaled so the farthest vertex lies at distance 1 from the origin.

    Args:
        mesh_path: path of the mesh file to load with trimesh.
        num_points: number of points sampled for EACH of the two point clouds.

    Returns:
        A dict with keys "surface" and "sharp_surface". Each value is a tensor with a
        leading batch dimension of 1 whose rows are a sampled point followed by its
        face normal (i.e. ``(1, num_points, 6)``, assuming trimesh returns ``(N, 3)``
        arrays). The tensor dtype is the notebook-level ``dtype``. Returns ``None``
        when the mesh cannot be loaded or processed; the error is logged.
    """
    try:
        mesh = trimesh.load(mesh_path, force='mesh', process=True)

        if isinstance(mesh, trimesh.Scene):
            if not mesh.geometry:
                 logger.warning(f"Skipping {mesh_path}: Trimesh scene is empty.")
                 return None
            # Merge all scene geometry into one mesh. `dump().sum()` only works where
            # dump() returns a numpy array; with a plain list it raises AttributeError.
            mesh = mesh.dump(concatenate=True)

        if not isinstance(mesh, trimesh.Trimesh) or len(mesh.vertices) == 0 or len(mesh.faces) == 0:
            logger.warning(f"Skipping {mesh_path}: No valid mesh data found after loading.")
            return None

        if not mesh.is_watertight:
            trimesh.repair.fill_holes(mesh)

        # Normalise: centre on the bounding-box midpoint, then scale into the unit sphere.
        center = mesh.bounds.mean(axis=0)
        mesh.apply_translation(-center)
        max_extent = np.max(np.linalg.norm(mesh.vertices, axis=1))
        if max_extent > 1e-6:  # skip scaling for (near-)degenerate meshes to avoid dividing by ~0
            mesh.apply_scale(1.0 / max_extent)

        # Get the two separate point clouds
        (uniform_points, uniform_normals), (sharp_points, sharp_normals) = sample_uniform_and_sharp_points(
            mesh, num_uniform_points=num_points, num_sharp_points=num_points
        )

        # Each row becomes [x, y, z, nx, ny, nz].
        surface_cloud = np.hstack([uniform_points, uniform_normals])
        sharp_cloud = np.hstack([sharp_points, sharp_normals])

        return {
            "surface": torch.tensor(surface_cloud, dtype=dtype).unsqueeze(0),
            "sharp_surface": torch.tensor(sharp_cloud, dtype=dtype).unsqueeze(0)
        }
    except Exception as e:
        logger.error(f"CRITICAL ERROR processing {mesh_path}: {str(e)}")
        logger.error(traceback.format_exc())
        return None


def process_all_meshes(download_dir="downloads", metadata_file="downloads/metadata.json"):
    """Build the in-memory training set from the downloaded meshes and their metadata.

    For every uid in ``metadata_file`` that has a thumbnail URL and a ``<uid>.glb``
    file in ``download_dir``, this encodes the first thumbnail with ``image_vae``,
    encodes the mesh with ``vae`` and tokenizes the caption.

    Relies on the notebook-level ``device``, ``vae``, ``image_vae`` and ``tokenizer``.

    Returns:
        A list of samples shaped as
        ``{"x": latent_3d, "y": {"image_latent", "input_ids", "attention_mask"}}``
        with all tensors on the CPU. The list is empty when the metadata file is
        missing, the mesh VAE is not initialised, or every entry was skipped.
    """
    # Ensure global variables are accessible or passed as arguments
    global device, vae, image_vae, tokenizer

    logger.info("Starting dataset creation with the updated process_all_meshes function...")
    processed_data = []

    if not os.path.exists(metadata_file):
        logger.error(f"Metadata file not found at: {metadata_file}")
        return []

    # Fail early and explicitly rather than with an AttributeError in the middle of
    # the loop when the mesh VAE was never (successfully) initialised.
    if globals().get("vae") is None:
        logger.error("Step1X-3D VAE is not initialized (`vae` is None or undefined); cannot encode meshes.")
        return []

    with open(metadata_file, 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    for uid, info in tqdm(metadata.items(), desc="Processing Raw Data"):
        mesh_path = os.path.join(download_dir, f"{uid}.glb")
        caption = info.get('caption', '')

        thumbnails = info.get('thumbnails')
        if not thumbnails or not isinstance(thumbnails, list) or not thumbnails[0]:
            continue
        image_url = thumbnails[0]

        if not os.path.isfile(mesh_path):
            continue

        # Encode the thumbnail BEFORE the mesh: the download is cheap and fails often
        # (e.g. HTTP 403), whereas mesh sampling + encoding is the expensive step and
        # would be wasted on a sample that is dropped anyway.
        latent_image = get_image_latent(image_url, image_vae, device=device)
        if latent_image is None:
            continue
        # NOTE(review): this keeps the leading batch dimension of 1 added inside
        # get_image_latent, while latent_3d below has its batch dimension squeezed
        # away — confirm this is the layout the model expects.
        latent_image = latent_image.cpu()

        mesh_inputs = process_mesh_to_vae_input(mesh_path, num_points=16384)
        if mesh_inputs is None:
            continue
        mesh_inputs_on_device = {k: v.to(device) for k, v in mesh_inputs.items()}

        logger.info("Inputs prepared. Calling vae.encode with keys: %s", list(mesh_inputs_on_device.keys()))

        with torch.no_grad():
            _shape_embeds, kl_embed, _posterior = vae.encode(
                sample_posterior=True, **mesh_inputs_on_device)
            latent_3d = kl_embed.squeeze(0).cpu()

        # Fixed-length (128 tokens) caption encoding so samples can be stacked into batches.
        tokens = tokenizer(caption, padding="max_length", max_length=128, truncation=True, return_tensors="pt")

        processed_data.append({
            "x": latent_3d,
            "y": {
                "image_latent": latent_image,
                "input_ids": tokens["input_ids"].squeeze(0),
                "attention_mask": tokens["attention_mask"].squeeze(0),
            }
        })

    logger.info(f"Successfully created dataset with {len(processed_data)} samples.")

    if not processed_data:
        logger.critical("The `processed_data` list is empty after the loop. This means every file failed. Check the error logs above for CRITICAL ERROR messages to see the root cause.")

    return processed_data


In [None]:
!apt-get install -y libaio-dev > /dev/null

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# --- 1. Required Imports ---
import json
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import Trainer, TrainingArguments
import random
from diffusion_model import GaussianDiffusion, ModelMeanType, ModelVarType, LossType, get_named_beta_schedule
from mmfreelm.models.hgrn_bit.mesh_dit import MeshDiT_models
import wandb


# SECURITY: an API key must never be committed inside a notebook. The key that used to
# be hardcoded on this line should be treated as compromised and revoked in the
# Weights & Biases account settings. Provide the key through the environment instead;
# if it is missing, ask for it interactively (the input is not echoed).
if not os.environ.get("WANDB_API_KEY"):
    import getpass
    os.environ["WANDB_API_KEY"] = getpass.getpass("Weights & Biases API key: ")

class MeshDataset(Dataset):
    """Map-style dataset that serves samples straight from an in-memory sequence."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

class CustomDataCollator:
    """Collate samples of the form ``{"x": tensor, "y": {name: tensor, ...}}`` into a batch.

    Every tensor is stacked along a new leading batch dimension; the keys of ``y`` are
    taken from the first sample.
    """

    def __call__(self, features):
        stacked_x = torch.stack([sample['x'] for sample in features])
        conditions = [sample['y'] for sample in features]
        stacked_y = {}
        for name in conditions[0]:
            stacked_y[name] = torch.stack([cond[name] for cond in conditions])
        return {'x': stacked_x, 'y': stacked_y}

class MeshDiTTrainer(Trainer):
    """`Trainer` whose loss is produced by a Gaussian diffusion process.

    Batches are expected to be dicts with the clean latent under "x" and the
    conditioning dict under "y" (the layout produced by `CustomDataCollator`).

    Args:
        diffusion: keyword-only; object providing ``num_timesteps`` and
            ``training_losses(model, x_start, t, model_kwargs, noise=...)``.
        *args, **kwargs: forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, *args, diffusion, **kwargs):
        super().__init__(*args, **kwargs)
        self.diffusion = diffusion

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the mean diffusion training loss for one batch.

        A timestep is drawn uniformly per sample and Gaussian noise is drawn once; both
        are handed to ``diffusion.training_losses``. Returns ``loss`` or, when
        ``return_outputs`` is true, ``(loss, loss_dict)``.
        """
        x_start = inputs.get("x")
        model_kwargs = {"y": inputs.get("y")}
        t = torch.randint(0, self.diffusion.num_timesteps, (x_start.shape[0],), device=x_start.device).long()
        noise = torch.randn_like(x_start)

        # The model is deliberately NOT called here. An earlier version ran
        # `q_sample` + `model(x_t, t, y=...)` and then threw the output away, which
        # cost a second full forward pass (and its activations) on every step.
        # `training_losses` receives the model, the clean sample, `t` and the noise
        # and is presumably responsible for noising + the forward pass itself
        # (as in guided-diffusion style implementations) — confirm in `diffusion_model`.
        loss_dict = self.diffusion.training_losses(model, x_start, t, model_kwargs, noise=noise)
        loss = loss_dict["loss"].mean()

        return (loss, loss_dict) if return_outputs else loss

    def prediction_step(
        self,
        model,
        inputs,
        prediction_loss_only,
        ignore_keys = None,
    ):
        """
        Perform an evaluation step on the model.

        Only the loss is computed; the returned tuple is ``(loss, None, None)``
        because there are no logits or labels to report.
        """
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            loss = self.compute_loss(model, inputs)

        return (loss.detach(), None, None)



Initializing MeshDiT model for training...
Initializing RotaryEmbedding with theta=10000.0 and ternary=False

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=False

Initializing RotaryEmbedding with theta=10000.0 and ternary=False

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=False

Initializing RotaryEmbedding with theta=10000.0 and ternary=False

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=False

Initializing RotaryEmbedding with theta=10000.0 and ternary=False

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=False

Initializing RotaryEmbedding with theta=10000.0 and ternary=False

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=False

Initializing RotaryEmbedding with theta=10000.0 and ternary=False

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=False

Initializing RotaryEmbedding with

Processing Raw Data:   5%|▌         | 8/150 [01:06<24:44, 10.45s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/d7340a5b05b6460facbd90aaafb7f1f1/thumbnails/868c75b689ac4da681e10d9d857ea075/108718dab17e4802827a6c1752cd3929.jpeg


Processing Raw Data:   7%|▋         | 10/150 [01:17<17:43,  7.59s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/3764a2442eca43c7bbe64d8297d84905/thumbnails/2f02d5a090de4cc1bf790f8023c82dca/a7e0a494e1db40b99709b67b9223a573.jpeg


Processing Raw Data:   7%|▋         | 11/150 [01:18<12:40,  5.47s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/9ff331513cea44a2938d09a03d7b0493/thumbnails/59e0d3c53d0c4c76aac836eafccee485/349e5b6fb5f9467ab32fc9623fb5e8ae.jpeg


Processing Raw Data:  14%|█▍        | 21/150 [02:30<32:38, 15.18s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/d40f6a7e7dc049a6b9e63c2ac31918f0/thumbnails/52b270d8988a4cefaa7e85a7585f23d6/327f77b5459b419e84ae410e25444783.jpeg


Processing Raw Data:  39%|███▊      | 58/150 [03:49<01:13,  1.24it/s]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/39c23be7af5848408ad74a744a127880/thumbnails/b8f74f22c2b24142b7803ed231dc6704/7e7d47678f9c41a2bf7ee7cbde956049.jpeg


Processing Raw Data:  40%|████      | 60/150 [03:53<02:37,  1.75s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/0e315229b94543b3b510403f02c797d3/thumbnails/8cadd0cf53c549fda9c3da6628575f88/4cb6c28801e84e0ba7ff1015a0f94d71.jpeg


Processing Raw Data:  41%|████      | 61/150 [03:56<02:43,  1.84s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/28de7aac69ab41c6be2371e144a38c28/thumbnails/5951f068190d44968f1169524a1a191c/9cfa501fcc314f13a88d3adfaf34056f.jpeg


Processing Raw Data:  47%|████▋     | 71/150 [04:16<01:31,  1.16s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/74d467eb13b44d3890862046c3539992/thumbnails/7ffddcf165f64340b16cc627341841f9/ff34e059ae484889b4ef18b6bd2d8d45.jpeg


Processing Raw Data:  57%|█████▋    | 86/150 [04:44<01:59,  1.87s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/c5e8fc0ce4f14e75b29e42b62edf70fe/thumbnails/b11ccde98c5d43b49d7dd8804731c598/06441ad9abd9436fb24c7969d78aa1c4.jpeg


Processing Raw Data:  59%|█████▊    | 88/150 [04:45<01:17,  1.25s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/9fe9043ea98d46179ffa896fbab9fdce/thumbnails/84edfd4596e9458286193592c16868e4/f1a45a452fca4284b943947ac8e1e593.jpeg


Processing Raw Data:  63%|██████▎   | 94/150 [04:50<00:43,  1.30it/s]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/c7f9e145f7ba4744925c2a26fae01176/thumbnails/2311249598c9480397061df7bf6714f7/487086369e9041ad84cb856267d91d73.jpeg


Processing Raw Data:  63%|██████▎   | 95/150 [04:50<00:42,  1.28it/s]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/c640aca2e346421aa9ac1ec494b2567e/thumbnails/e4d4759c9bbb4f999fd83574de7820e5/8b0e7abc78cb4ba8906a753199913de2.jpeg


Processing Raw Data:  69%|██████▊   | 103/150 [05:07<01:45,  2.25s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/5b3096461e3646d18b231c614018c358/thumbnails/179adc51c0364a9f91188e236ff04319/6a495d1c087e43578b3cd075e8daca0f.jpeg


Processing Raw Data:  74%|███████▍  | 111/150 [05:23<00:58,  1.50s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/16a37dc15d204ab9b3c3052030fe6fce/thumbnails/87de4f9974a04423bfc3af29839c6798/5e71d4805f0e4352b6cdd0d8954ef291.jpeg


Processing Raw Data:  76%|███████▌  | 114/150 [05:32<01:32,  2.58s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/8129182bb18a4f99b2b657e1b49513a3/thumbnails/622c1dd806e74ad9a7ae850d670a6456/fdc038b714eb457985d6aabded53aee0.jpeg


Processing Raw Data:  82%|████████▏ | 123/150 [05:53<00:46,  1.73s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/08dd63723554492094aca4c7816d4b1d/thumbnails/f27b1b5d59b04aefbdf2625daa0ac166/9e15e77ce9d4456788a8c95fd4c9e798.jpeg


Processing Raw Data:  85%|████████▌ | 128/150 [05:57<00:20,  1.09it/s]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/d03fa1c3ddf4478f997b03f486887f19/thumbnails/d8ca7bb0609b4e3eae2cc857d16f6fcf/b34c13f44bd64b739e61b2f064a24dc2.jpeg


Processing Raw Data: 100%|██████████| 150/150 [06:41<00:00,  2.68s/it]

Data split: 126 training samples, 6 evaluation samples.
[2025-07-13 12:05:59,395] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)






      Starting MeshDiT Model Training



[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msakibahmed2018go[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss
0,1.1072,0.999844
1,1.107,1.002036
2,1.1056,1.000109
3,1.1064,0.999032
4,1.1055,0.998316
5,1.1048,0.998247
6,1.1042,0.992885
7,1.1028,0.996335
8,1.1017,0.992254
9,1.101,0.989853


In [None]:
# Training entry point: builds the model and diffusion process, creates the dataset
# from the downloaded meshes, splits off a small evaluation set and runs the Trainer.
# `dtype` and `AutoTokenizer` are expected to exist from an earlier cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("Sakib323/MMfreeLM-370M")
tokenizer.pad_token = tokenizer.eos_token

print("Initializing MeshDiT model for training...")
model = MeshDiT_models['MeshDiT-S'](
    vocab_size=tokenizer.vocab_size,
    image_latent_channels=4,
    image_latent_height=64,
    image_latent_width=64,
    use_rope=True,
    use_ternary_rope=False,
).to(device, dtype=dtype)

print("Initializing Gaussian diffusion process...")
diffusion = GaussianDiffusion(
    betas=get_named_beta_schedule("linear", 1000),
    model_mean_type=ModelMeanType.EPSILON,
    model_var_type=ModelVarType.FIXED_SMALL,
    loss_type=LossType.MSE,
)

all_data = process_all_meshes(download_dir="downloads", metadata_file="downloads/metadata.json")

if not all_data:
    print("FATAL: Dataset creation failed. No data to train on.")
else:
    # Hold out 5% of the shuffled samples for evaluation, but at least one sample
    # whenever there is more than one sample in total.
    random.shuffle(all_data)
    eval_size = int(len(all_data) * 0.05)
    if len(all_data) > 1 and eval_size == 0:
        eval_size = 1

    train_data = all_data[eval_size:]
    eval_data = all_data[:eval_size]
    print(f"Data split: {len(train_data)} training samples, {len(eval_data)} evaluation samples.")
    train_dataset = MeshDataset(train_data)
    eval_dataset = MeshDataset(eval_data)
    data_collator = CustomDataCollator()

    training_args = TrainingArguments(
        # Output locations and reporting
        output_dir="./mesh_dit_checkpoint",
        logging_dir='./logs',
        report_to=["wandb", "tensorboard"],
        run_name="MeshDiT-S-training-float32",
        # Optimisation schedule
        num_train_epochs=200,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=1e-4,
        lr_scheduler_type="cosine",
        weight_decay=0.01,
        warmup_steps=200,
        max_grad_norm=1.0,
        fp16=False,
        # Log, evaluate and checkpoint once per epoch; keep the best checkpoint by eval loss
        logging_strategy="epoch",
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        # Data loading: keep the custom "x"/"y" batch keys untouched
        dataloader_num_workers=2,
        remove_unused_columns=False,
    )
    trainer = MeshDiTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        diffusion=diffusion,
    )
    print("\n" + "="*50)
    print("      Starting MeshDiT Model Training")
    print("="*50 + "\n")
    trainer.train()
    print("\nTraining complete. Saving the final model.")
    trainer.save_model("./mesh_dit_final")
    tokenizer.save_pretrained("./mesh_dit_final")
    print("Model saved to ./mesh_dit_final")

# =======================================================================================================
# DEMO TRAINER: TRAINS THE MODEL ON RANDOMLY GENERATED DUMMY DATAPOINTS, SKIPPING DATASET PROCESSING
# =======================================================================================================

In [None]:
import os
import json
import torch
import random
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import Trainer, TrainingArguments, AutoTokenizer
from diffusion_model import GaussianDiffusion, ModelMeanType, ModelVarType, LossType, get_named_beta_schedule
from mmfreelm.models.hgrn_bit.mesh_dit import MeshDiT_models
import wandb

os.environ["TOKENIZERS_PARALLELISM"] = "false"
# SECURITY: an API key must never be committed inside a notebook. The key that used to
# be hardcoded on this line should be treated as compromised and revoked in the
# Weights & Biases account settings. Provide the key through the environment instead;
# if it is missing, ask for it interactively (the input is not echoed).
if not os.environ.get("WANDB_API_KEY"):
    import getpass
    os.environ["WANDB_API_KEY"] = getpass.getpass("Weights & Biases API key: ")
def process_all_meshes_demo(num_samples, tokenizer):
    """
    Build `num_samples` random placeholder samples for smoke-testing the trainer.

    Each sample is a dict with:
      - "x": float32 tensor of shape (1024, 768) drawn from a standard normal.
      - "y": dict holding
          * "image_latent": float32 standard-normal tensor of shape (4, 64, 64),
          * "input_ids": int64 tensor of shape (128,) with values in
            [0, tokenizer.vocab_size),
          * "attention_mask": int64 tensor of ones with shape (128,).

    Args:
        num_samples: Number of samples to generate.
        tokenizer: Any object exposing an integer `vocab_size` attribute.

    Returns:
        A list with `num_samples` sample dicts.
    """
    print(f"Generating {num_samples} dummy data samples for testing...")

    latent_dims = (1024, 768)
    image_dims = (4, 64, 64)
    token_dims = (128,)

    def _random_sample():
        # Tensors are drawn in the same order for every sample: x, image latent, token ids.
        mesh_latent = torch.randn(latent_dims, dtype=torch.float32)
        conditioning = {
            "image_latent": torch.randn(image_dims, dtype=torch.float32),
            "input_ids": torch.randint(0, tokenizer.vocab_size, token_dims, dtype=torch.long),
            "attention_mask": torch.ones(token_dims, dtype=torch.long),
        }
        return {"x": mesh_latent, "y": conditioning}

    all_data = [_random_sample() for _ in range(num_samples)]

    print("Dummy data generation complete.")
    return all_data


# --- 2. Dataset, Collator, and Corrected Trainer Class ---

class MeshDataset(Dataset):
    """Thin map-style dataset wrapper around an in-memory, indexable collection."""

    def __init__(self, data):
        # `data` is stored as-is (no copy); it must support len() and indexing.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

class CustomDataCollator:
    """Collate samples shaped like {"x": tensor, "y": {name: tensor}} into one batch."""

    def __call__(self, features):
        """
        Stack every per-sample tensor along a new leading batch dimension.

        Args:
            features: Sequence of sample dicts with a tensor under "x" and a dict
                of tensors under "y". The keys of the first sample's "y" define
                which conditioning entries are stacked.

        Returns:
            {"x": stacked tensor, "y": {key: stacked tensor}}.
        """
        stacked_x = torch.stack([sample['x'] for sample in features])

        conditioning = [sample['y'] for sample in features]
        stacked_y = {}
        for key in conditioning[0]:
            stacked_y[key] = torch.stack([entry[key] for entry in conditioning])

        return {'x': stacked_x, 'y': stacked_y}

class MeshDiTTrainer(Trainer):
    """
    `transformers.Trainer` subclass that optimizes a diffusion training loss.

    The usual "model returns a loss" contract is replaced: the loss comes from
    `diffusion.training_losses`, evaluated at a random timestep per sample.
    """

    def __init__(self, *args, diffusion, **kwargs):
        """
        Args:
            *args, **kwargs: Forwarded unchanged to `Trainer.__init__`.
            diffusion: Keyword-only. Object exposing `num_timesteps` and
                `training_losses(model, x_start, t, model_kwargs, noise=...)`.
        """
        super().__init__(*args, **kwargs)
        self.diffusion = diffusion

    # **kwargs absorbs extra keyword arguments that the Trainer may pass to compute_loss.
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Compute the mean diffusion loss for one batch.

        Args:
            model: The denoising network being trained.
            inputs: Batch dict with "x" (clean latents) and "y" (conditioning dict).
            return_outputs: If True, also return the raw dict from `training_losses`.

        Returns:
            `loss`, or `(loss, loss_dict)` when `return_outputs` is True.
        """
        x_start = inputs["x"]
        model_kwargs = {"y": inputs["y"]}

        # One uniformly sampled timestep per batch element.
        t = torch.randint(
            0, self.diffusion.num_timesteps, (x_start.shape[0],), device=x_start.device
        ).long()
        noise = torch.randn_like(x_start)

        # `training_losses` is handed the model, the clean sample and the noise, so it
        # is expected to do the noising and the forward pass itself. The previous
        # version additionally called `q_sample` and `model(...)` here and discarded
        # the result, which cost a second full forward pass per step.
        loss_dict = self.diffusion.training_losses(model, x_start, t, model_kwargs, noise=noise)
        loss = loss_dict["loss"].mean()

        return (loss, loss_dict) if return_outputs else loss

    def prediction_step(
        self,
        model,
        inputs,
        prediction_loss_only,
        ignore_keys = None,
    ):
        """
        Perform an evaluation step on the model.

        Only the loss is produced; logits and labels are returned as None.
        Timesteps and noise are re-sampled on every call, so the evaluation loss
        is itself a random estimate.
        """
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            loss = self.compute_loss(model, inputs)

        return (loss.detach(), None, None)


# --- 3. Main Setup and Execution ---

# Setup device and data type
# Use the GPU when one is visible; everything in this demo runs in float32.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Sakib323/MMfreeLM-370M")
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Generate dummy data
# 100 random samples; see process_all_meshes_demo for the per-sample layout.
all_data = process_all_meshes_demo(num_samples=100, tokenizer=tokenizer)

# Initialize Model
print("Initializing MeshDiT model for training...")
# NOTE(review): the dummy "x" tensors are (1024, 768) while input_tokens=512 is
# passed here and no input_dim is given; confirm these agree with what the
# MeshDiT-S constructor expects.
model = MeshDiT_models['MeshDiT-S'](
    input_tokens=512,
    vocab_size=tokenizer.vocab_size, image_latent_channels=4,
    image_latent_height=64, image_latent_width=64,
    use_rope=True, use_ternary_rope=False
).to(device, dtype=dtype)

# Initialize Diffusion Process
print("Initializing Gaussian diffusion process...")
# 1000-step linear beta schedule; the model predicts the noise (EPSILON) and is
# trained with an MSE loss under the FIXED_SMALL variance setting.
diffusion = GaussianDiffusion(
    betas=get_named_beta_schedule("linear", 1000),
    model_mean_type=ModelMeanType.EPSILON,
    model_var_type=ModelVarType.FIXED_SMALL,
    loss_type=LossType.MSE,
)

if not all_data:
    print("FATAL: Dummy data creation failed.")
else:
    # Shuffle in place, then hold out the first five samples for evaluation.
    random.shuffle(all_data)
    eval_data, train_data = all_data[:5], all_data[5:]
    print(f"Data split: {len(train_data)} training samples, {len(eval_data)} evaluation samples.")

    train_dataset, eval_dataset = MeshDataset(train_data), MeshDataset(eval_data)
    data_collator = CustomDataCollator()

    # Short smoke-test configuration: log every step, evaluate and checkpoint
    # once per epoch, report to Weights & Biases.
    training_args = TrainingArguments(
        run_name="MeshDiT-S-test-run-fixed",
        output_dir="./mesh_dit_test_checkpoint",
        report_to="wandb",
        # optimisation
        num_train_epochs=5,
        learning_rate=1e-4,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=2,
        fp16=False,
        # logging / evaluation / checkpoint cadence
        logging_strategy="steps",
        logging_steps=1,
        eval_strategy="epoch",
        save_strategy="epoch",
    )

    trainer = MeshDiTTrainer(
        model=model,
        args=training_args,
        diffusion=diffusion,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    # Same three-line banner as separate prints would produce.
    print("\n".join([
        "\n" + "="*50,
        "      Starting MeshDiT Model Training Test",
        "="*50 + "\n",
    ]))
    trainer.train()
    print("\n--- Test training finished successfully! ---")

# =========================================================================================
# MESH GENERATION SCRIPT: LOAD MODEL AND GENERATE MESH
# This script loads your trained model and generates a mesh.
# =========================================================================================

In [None]:
import torch
import numpy as np
import trimesh
from PIL import Image
import requests
from io import BytesIO
from transformers import AutoTokenizer
from diffusers import AutoencoderKL
from tqdm import tqdm
import os 
from safetensors.torch import load_file
import torchvision.transforms as transforms
import sys

# --- Local Imports ---
from step1x3d_geometry.models.pipelines.pipeline import Step1X3DGeometryPipeline
from mmfreelm.models.hgrn_bit.mesh_dit import MeshDiT_models
from diffusion_model import GaussianDiffusion, ModelMeanType, ModelVarType, LossType, get_named_beta_schedule, _extract_into_tensor


print("--- Starting Inference Setup ---")

# --- 1. Configuration ---
# MODEL_PATH is the directory the full training cell saves the model and tokenizer to.
MODEL_PATH = "./mesh_dit_final"
OUTPUT_FILENAME = "generated_mesh_final.glb"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float32

# Conditioning inputs: a text prompt and a reference image fetched over HTTP.
TEXT_PROMPT = "a green rifle with a long barrel"
IMAGE_URL = "https://media.sketchfab.com/models/ed51a51909ee46c780db3a85e821feb2/thumbnails/f46b00b385d4449d923aff78348b04c6/eb41be4a1dcd450a974db9d24dc13d16.jpeg"

# Guidance scales handed to generation_model.forward_with_cfg, and the number of
# sampling steps used to subsample the 1000-step schedule.
CFG_SCALE_TEXT = 7.5
CFG_SCALE_IMAGE = 1.5
NUM_SAMPLING_STEPS = 250

print(f"Loading models onto {DEVICE}...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
print("-> Tokenizer loaded.")

# NOTE(review): these constructor arguments must match the ones used when the
# checkpoint in MODEL_PATH was trained, otherwise load_state_dict below will
# fail -- confirm against the training cell.
generation_model = MeshDiT_models['MeshDiT-S'](
    input_tokens=2048,
    input_dim=64,
    vocab_size=tokenizer.vocab_size,
    image_latent_channels=4,
    image_latent_height=64,
    image_latent_width=64,
    use_rope=True,
    use_ternary_rope=False
).to(DEVICE, dtype=DTYPE).eval()

# Load the trained weights from the safetensors file directly onto DEVICE.
state_dict = load_file(os.path.join(MODEL_PATH, "model.safetensors"), device=DEVICE)
generation_model.load_state_dict(state_dict)
print("-> Custom DiT model loaded successfully.")

# Only the VAE of the Step1X-3D geometry pipeline is kept; it decodes the
# generated latent sequence further down.
geometry_pipeline = Step1X3DGeometryPipeline.from_pretrained("stepfun-ai/Step1X-3D", subfolder='Step1X-3D-Geometry-1300m', torch_dtype=DTYPE)
vae_3d = geometry_pipeline.vae.to(DEVICE).eval()
print("-> 3D VAE loaded.")

# Image VAE used to encode the conditioning image into a latent.
vae_image = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="vae", torch_dtype=DTYPE).to(DEVICE).eval()
print("-> Image VAE loaded.")

# Same diffusion hyper-parameters as the demo training cell (linear schedule, 1000 steps).
diffusion = GaussianDiffusion(betas=get_named_beta_schedule("linear", 1000), model_mean_type=ModelMeanType.EPSILON, model_var_type=ModelVarType.FIXED_SMALL, loss_type=LossType.MSE)
print("--- All components loaded successfully! ---")

print("\n--- Preparing Generation Inputs ---")

# Pad/truncate the prompt to exactly 128 tokens.
tokens = tokenizer(TEXT_PROMPT, padding="max_length", max_length=128, truncation=True, return_tensors="pt")
input_ids = tokens["input_ids"].to(DEVICE)
attention_mask = tokens["attention_mask"].to(DEVICE)

def get_image_latent(url: str, vae: AutoencoderKL, device: str, timeout: float = 30.0) -> torch.Tensor:
    """
    Download an image and encode it into a scaled VAE latent.

    The image is converted to RGB, resized so its shorter side is 512 px,
    center-cropped to 512x512, mapped to [-1, 1] and encoded with `vae`. The
    latent is *sampled* from the encoder's posterior (so repeated calls differ)
    and multiplied by `vae.config.scaling_factor`.

    Args:
        url: HTTP(S) address of the image.
        vae: Image autoencoder used for encoding; its dtype is used for the input.
        device: Device the image tensor is moved to before encoding.
        timeout: Seconds to wait for the HTTP request. Without a timeout a
            stalled server would block the cell forever.

    Returns:
        The sampled, scaled latent with a leading batch dimension of 1.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the request exceeds `timeout`.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()

    # Decode inside a context manager so the underlying image handle is released;
    # convert() returns an independent, fully loaded image.
    with Image.open(BytesIO(response.content)) as downloaded:
        raw_image = downloaded.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(512),
        transforms.CenterCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # [0, 1] -> [-1, 1]
    ])
    image_tensor = transform(raw_image).unsqueeze(0).to(device=device, dtype=vae.dtype)
    with torch.no_grad():
        image_latent = vae.encode(image_tensor).latent_dist.sample() * vae.config.scaling_factor
    return image_latent

image_latent = get_image_latent(IMAGE_URL, vae_image, DEVICE)
print(f"-> Conditioning image latent created with shape: {image_latent.shape}")

model_kwargs = {"y": {"image_latent": image_latent, "input_ids": input_ids, "attention_mask": attention_mask}}

print("\n--- Starting Denoising Process ---")

# Start from pure Gaussian noise in the latent space of the generation model.
z = torch.randn(1, 2048, 64, device=DEVICE, dtype=DTYPE)

# Evenly strided subset of the 1000 training timesteps. The stride is clamped to
# at least 1 so NUM_SAMPLING_STEPS > 1000 cannot produce a zero step.
ddim_stride = max(1000 // NUM_SAMPLING_STEPS, 1)
ddim_timesteps = np.asarray(list(range(0, 1000, ddim_stride)))
ddim_steps = torch.from_numpy(ddim_timesteps).long().to(DEVICE)

with torch.no_grad():
    # Iterate over the timesteps actually generated: when NUM_SAMPLING_STEPS does
    # not divide 1000, len(ddim_steps) differs from NUM_SAMPLING_STEPS.
    for i in tqdm(range(len(ddim_steps) - 1, -1, -1), desc="DDIM Sampling"):
        t = ddim_steps[i].expand(z.shape[0])

        # Duplicate the batch for classifier-free guidance.
        z_in = torch.cat([z, z], dim=0)
        t_in = torch.cat([t, t], dim=0)
        y_in = {
            "image_latent": torch.cat([model_kwargs['y']['image_latent']] * 2, dim=0),
            "input_ids": torch.cat([model_kwargs['y']['input_ids']] * 2, dim=0),
            "attention_mask": torch.cat([model_kwargs['y']['attention_mask']] * 2, dim=0)
        }

        # NOTE(review): noise_pred is combined with `z` (batch size 1) below; confirm
        # that forward_with_cfg returns the guided prediction with z's batch size
        # rather than the doubled batch.
        noise_pred = generation_model.forward_with_cfg(
            z_in, t_in, y_in, CFG_SCALE_TEXT, CFG_SCALE_IMAGE
        )

        alpha_t = _extract_into_tensor(diffusion.alphas_cumprod, t, z.shape)

        # Deterministic DDIM (eta = 0) must use the cumulative alpha of the PREVIOUS
        # timestep in the strided schedule. The earlier code read
        # alphas_cumprod_prev[t], i.e. timestep t-1, which is only correct for
        # stride 1 and leaves the sample far noisier than the next step assumes.
        if i > 0:
            t_prev = ddim_steps[i - 1].expand(z.shape[0])
            alpha_t_prev = _extract_into_tensor(diffusion.alphas_cumprod, t_prev, z.shape)
        else:
            # Final step: jump to the fully denoised sample (cumulative alpha = 1).
            alpha_t_prev = torch.ones_like(alpha_t)

        pred_x0 = (z - (1 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
        dir_xt = (1 - alpha_t_prev).sqrt() * noise_pred
        z = alpha_t_prev.sqrt() * pred_x0 + dir_xt

generated_latent_seq = z
print("--- Denoising complete. ---")

print("\n--- Decoding Latent and Preparing Mesh ---")
# Decode the generated latent sequence with the 3D VAE, then extract geometry
# inside the [-1, 1]^3 box; only the first result of the batch is kept.
with torch.no_grad():
    decoded_latents = vae_3d.decode(generated_latent_seq)
    
    mesh_result = vae_3d.extract_geometry(
        decoded_latents,
        mc_level=0.5,
        bounds=[-1, -1, -1, 1, 1, 1],
        octree_resolution=256
    )[0]

# Move vertices and faces to host memory and wrap them in a trimesh object.
final_mesh = trimesh.Trimesh(
    vertices=mesh_result.verts.cpu().numpy(),
    faces=mesh_result.faces.cpu().numpy()
)

print("-> Applying white texture to the mesh...")
# 128x128 solid-white image used as the material's texture.
white_texture = Image.new('RGB', (128, 128), (255, 255, 255))
material = trimesh.visual.texture.SimpleMaterial(image=white_texture)
# NOTE(review): the mesh above is created without UV coordinates; confirm that
# assigning `.material` on its visual is actually honoured by the exporter.
final_mesh.visual.material = material
print("-> Texture applied successfully.")

# The export format is chosen from the file extension of OUTPUT_FILENAME.
final_mesh.export(OUTPUT_FILENAME)
print(f"\n--- ✨ Success! Mesh saved to {OUTPUT_FILENAME} ---")


# =========================================================================================
# TRAIN TEXTURE GENERATION MODEL: DOWNLOAD DATASET & TRAIN THE MODEL
# This script downloads the dataset, processes it, and finally trains the texture generation model.
# =========================================================================================

In [None]:
pip install trimesh pyrender xatlas opencv-python torch scipy

In [None]:
import os

from huggingface_hub import login

# SECURITY: an access token used to be hard-coded here. Any token that was
# committed must be treated as leaked and revoked. The token is now read from
# the HF_TOKEN environment variable; when it is unset, `login` is called with
# token=None and asks for the token interactively.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
import urllib.request
import os

url = "https://huggingface.co/datasets/stepfun-ai/Step1X-3D-obj-data/resolve/main/objaverse_texture_30k.json"
file_name = "objaverse_texture_30k.json"

# Download only when the manifest is missing, so re-running the notebook is cheap
# (same skip-if-present pattern used for the other downloads in this notebook).
if os.path.exists(file_name):
    print(f"'{file_name}' already exists; skipping download.")
else:
    # Write to a temporary name and rename on success, so an interrupted transfer
    # never leaves a truncated file that a later run would mistake for complete.
    partial_name = file_name + ".part"
    try:
        print(f"Downloading {file_name} from {url}...")
        urllib.request.urlretrieve(url, partial_name)
        os.replace(partial_name, file_name)
        print(f"Successfully downloaded and saved as '{file_name}'")
    except Exception as e:
        # Best effort, as before: report the problem and let later cells decide.
        if os.path.exists(partial_name):
            os.remove(partial_name)
        print(f"An error occurred while trying to download the file: {e}")

In [None]:
import json
import os
import urllib.request

JSON_FILE_PATH = 'objaverse_texture_30k.json'
DOWNLOAD_DIR = 'objaverse_dataset'
BASE_URL = 'https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/'  # fixed from 'blob' to 'resolve'

MAX_FILES = 10 # Set to an integer like 1000 to limit, or None for all

os.makedirs(DOWNLOAD_DIR, exist_ok=True)

# Load the manifest of relative .glb paths. On failure the error is re-raised
# after the message: calling exit() inside a notebook shuts the kernel down,
# whereas an exception just stops this cell (and any "Run All").
try:
    with open(JSON_FILE_PATH, 'r') as f:
        file_paths = json.load(f)
except FileNotFoundError:
    print(f"Error: The file '{JSON_FILE_PATH}' was not found.")
    raise
except json.JSONDecodeError:
    print(f"Error: The file '{JSON_FILE_PATH}' is not a valid JSON file.")
    raise

total_files = len(file_paths)
if MAX_FILES is not None:
    file_paths = file_paths[:MAX_FILES]
print(f"Found {total_files} files to download. Downloading {len(file_paths)} files...")

for i, relative_path in enumerate(file_paths):
    full_url = f"{BASE_URL}{relative_path}"

    # Mirror the remote directory layout below DOWNLOAD_DIR.
    local_folder = os.path.join(DOWNLOAD_DIR, os.path.dirname(relative_path))
    os.makedirs(local_folder, exist_ok=True)

    file_name = os.path.basename(relative_path)
    local_file_path = os.path.join(local_folder, file_name)

    if os.path.exists(local_file_path):
        print(f"Skipping ({i+1}/{len(file_paths)}): {local_file_path} (already exists)")
        continue

    print(f"Downloading ({i+1}/{len(file_paths)}): {full_url}")
    # Download under a temporary name and rename on success. Otherwise an
    # interrupted transfer leaves a truncated .glb that the "already exists"
    # check above would silently accept on the next run.
    partial_path = local_file_path + ".part"
    try:
        urllib.request.urlretrieve(full_url, partial_path)
        os.replace(partial_path, local_file_path)
        print(f" -> Saved to {local_file_path}")
    except Exception as e:
        # Best effort: report the failure and continue with the next file.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        print(f" -> An error occurred: {e}")

print("\nDownload process finished.")
print(f"All 3D models are now in the '{DOWNLOAD_DIR}' folder.")


In [None]:
import os
os.environ['PYOPENGL_PLATFORM'] = 'egl'
import glob
import json
import trimesh
import numpy as np
import pyrender
import xatlas
import cv2
import torch
from scipy.spatial.transform import Rotation

# Folder that is searched recursively for downloaded .glb files.
DOWNLOAD_DIR = 'objaverse_dataset'
# Root folder for per-model outputs and the batch_data.json manifest.
OUTPUT_DIR = 'processed_texture_data'
# Width and height, in pixels, of the offscreen renderer.
RESOLUTION = 768
# NOTE(review): SUBDIVISIONS and LAPLACIAN_ITERATIONS are not referenced by any
# code visible in this cell -- confirm whether they are still needed.
SUBDIVISIONS = 1
LAPLACIAN_ITERATIONS = 1

def process_mesh(glb_path, output_mesh_path):
    """
    Load a GLB file, fill holes if needed, generate UVs, and export the result.

    Steps: load the file as a single mesh, call `fill_holes()` when the mesh is
    not watertight, compute a UV parametrization with xatlas, rebuild the mesh
    with the remapped vertices, normals and new UVs, and export it to
    `output_mesh_path`.

    NOTE(review): the rebuilt mesh carries only UV coordinates; the source
    material/texture is not copied over -- confirm that is intended.

    Args:
        glb_path: Path of the source .glb file.
        output_mesh_path: Destination path for the processed mesh.

    Returns:
        The processed `trimesh.Trimesh`, or None if any step raised (the error
        is printed, not propagated).
    """
    print(f"  - Post-processing mesh: {os.path.basename(glb_path)}")
    try:
        loaded = trimesh.load(glb_path, force='mesh')
        # Collapse a multi-geometry scene into one mesh if one is returned anyway.
        mesh = loaded.dump(concatenate=True) if isinstance(loaded, trimesh.Scene) else loaded

        if not mesh.is_watertight:
            print("    - Filling holes to make watertight")
            mesh.fill_holes()

        print("    - Generating UVs with xAtlas...")
        # vmapping maps every new vertex to its source vertex; indices are the new faces.
        vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)

        remapped_vertices = mesh.vertices[vmapping]
        if mesh.vertex_normals is not None:
            remapped_normals = mesh.vertex_normals[vmapping]
        else:
            remapped_normals = None

        processed_mesh = trimesh.Trimesh(
            vertices=remapped_vertices,
            faces=indices,
            vertex_normals=remapped_normals,
            visual=trimesh.visual.TextureVisuals(uv=uvs),
        )
        processed_mesh.export(output_mesh_path)
        return processed_mesh
    except Exception as e:
        print(f"    - Error processing mesh: {str(e)}")
        return None

def render_albedo_map(scene, camera_node, renderer, view_pose):
    """
    Render an RGB colour image of `scene` from `view_pose`, lit from the camera.

    Side effects on `scene`: the camera node is moved to `view_pose`, every
    existing light node is removed, and one white directional light
    (intensity 10) is added. That light stays in the scene after the call.

    NOTE(review): only the RGBA render flag is used, so the image still includes
    lighting; confirm whether an unlit (flat) render was intended for "albedo".

    Args:
        scene: Scene to render (mutated, see above).
        camera_node: Scene node holding the camera.
        renderer: Offscreen renderer exposing `render(scene, flags=...)`.
        view_pose: 4x4 camera pose.

    Returns:
        The first three channels of the rendered colour image.
    """
    scene.set_pose(camera_node, pose=view_pose)

    # Iterate over a snapshot, because removing nodes changes scene.light_nodes.
    for stale_light in list(scene.light_nodes):
        scene.remove_node(stale_light)

    # Place the light at the camera pose, shifted by 0.1 along the pose's third
    # rotation column.
    headlight_pose = np.copy(view_pose)
    headlight_pose[:3, 3] += headlight_pose[:3, 2] * 0.1
    scene.add(
        pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=10.0),
        pose=headlight_pose,
    )

    rendered = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
    rgba = rendered[0]
    return rgba[..., :3]

def vectorized_unproject(depth, camera, view_pose):
    """
    Convert a depth image into a per-pixel world-space position map.

    Fixes relative to the previous implementation:
      * The depth image is treated as linear distance along the viewing axis
        (which is what pyrender's offscreen renderer returns), not as an NDC z
        value to push through the inverse projection.
      * Camera-space points are transformed with `view_pose` itself (the
        camera-to-world matrix) rather than with its inverse.
      * Rays go through pixel centres, and a 1-pixel-wide or 1-pixel-high image
        no longer divides by zero.
      * Runs on the GPU when available and falls back to the CPU otherwise.

    Args:
        depth: (H, W) array of linear depth values; row 0 is the top image row.
        camera: Object whose `get_projection_matrix()` returns a 4x4 OpenGL-style
            perspective matrix; only entries [0, 0] and [1, 1] are used
            (a symmetric frustum is assumed).
        view_pose: 4x4 camera-to-world pose (camera looks along its local -Z).

    Returns:
        (H, W, 3) float32 numpy array of world coordinates. Pixels with depth 0
        map to the camera position.
    """
    height, width = depth.shape
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    depth_t = torch.as_tensor(np.asarray(depth), dtype=torch.float32, device=device)
    cam_to_world = torch.as_tensor(np.asarray(view_pose), dtype=torch.float32, device=device)

    # For a symmetric perspective matrix P: ndc_x = P[0,0] * x_cam / depth, and
    # likewise for y, so x_cam = ndc_x * depth / P[0,0].
    projection = np.asarray(camera.get_projection_matrix())
    focal_x = float(projection[0, 0])
    focal_y = float(projection[1, 1])

    rows = torch.arange(height, device=device, dtype=torch.float32)
    cols = torch.arange(width, device=device, dtype=torch.float32)
    y, x = torch.meshgrid(rows, cols, indexing='ij')

    # Pixel centres in normalized device coordinates; +y points up, row 0 is the top.
    ndc_x = (2.0 * (x + 0.5) / width) - 1.0
    ndc_y = 1.0 - (2.0 * (y + 0.5) / height)

    cam_points = torch.stack(
        [
            ndc_x * depth_t / focal_x,
            ndc_y * depth_t / focal_y,
            -depth_t,  # the camera looks down its negative Z axis
            torch.ones_like(depth_t),
        ],
        dim=-1,
    )

    # Row-vector form of world = cam_to_world @ point.
    world_coords = cam_points @ cam_to_world.T
    return world_coords[..., :3].cpu().numpy()

def save_position_map(position_map, min_bound, max_bound, path):
    """
    Saves position map as 16-bit PNG with XYZ in RGB channels.

    Positions are normalized per axis to [0, 1] using the given bounds, clipped,
    scaled to the uint16 range and written with OpenCV (which expects BGR order).

    Args:
        position_map: (H, W, 3) array of XYZ positions.
        min_bound: Per-axis lower bound (length 3).
        max_bound: Per-axis upper bound (length 3).
        path: Output file path.
    """
    extent = np.asarray(max_bound) - np.asarray(min_bound)
    # A flat mesh has zero extent along one axis; dividing by it would fill that
    # channel with NaN/inf. Such an axis is divided by 1 instead.
    safe_extent = np.where(extent > 0, extent, 1.0)

    normalized = (position_map - min_bound) / safe_extent
    normalized = np.clip(normalized, 0, 1)
    position_16bit = (normalized * 65535).astype(np.uint16)
    bgr_16bit = position_16bit[..., [2, 1, 0]]  # XYZ to BGR
    cv2.imwrite(path, bgr_16bit)

def save_normal_map(normal_map, path):
    """
    Write a normal map as an 8-bit image.

    Components are mapped from [-1, 1] to [0, 255] (values outside are clipped),
    the channels are swapped from RGB to BGR for OpenCV, and the image is written
    to `path`.
    """
    scaled = (normal_map + 1) * 127.5
    rgb_u8 = scaled.clip(0, 255).astype(np.uint8)
    cv2.imwrite(path, cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2BGR))

def get_camera_poses():
    """
    Build the six axis-aligned camera poses used for rendering.

    Every camera sits 2.0 units from the origin on one coordinate axis and looks
    at the origin.

    Returns:
        Dict mapping 'front', 'back', 'right', 'left', 'top', 'bottom' (in that
        order) to a 4x4 float32 pose matrix whose columns are the camera's
        right, up and backward axes followed by its position.
    """
    def look_at(eye, target, up):
        # Work in float32 so the resulting pose matrix is float32 throughout.
        eye_v, target_v, up_v = (np.asarray(v, dtype=np.float32) for v in (eye, target, up))

        forward = target_v - eye_v
        forward = forward / np.linalg.norm(forward)
        right = np.cross(forward, up_v)
        right = right / np.linalg.norm(right)
        true_up = np.cross(right, forward)

        pose = np.eye(4, dtype=np.float32)
        # Third column is -forward: the camera looks along its negative Z axis.
        pose[:3, :3] = np.stack([right, true_up, -forward], axis=1)
        pose[:3, 3] = eye_v
        return pose

    cam_dist = 2.0
    origin = [0, 0, 0]
    # (view name, eye position, up hint); top and bottom need a non-parallel up hint.
    view_specs = (
        ('front', [0, 0, cam_dist], [0, 1, 0]),
        ('back', [0, 0, -cam_dist], [0, 1, 0]),
        ('right', [cam_dist, 0, 0], [0, 1, 0]),
        ('left', [-cam_dist, 0, 0], [0, 1, 0]),
        ('top', [0, cam_dist, 0], [0, 0, -1]),
        ('bottom', [0, -cam_dist, 0], [0, 0, 1]),
    )
    poses = {name: look_at(eye=eye, target=origin, up=up) for name, eye, up in view_specs}
    return poses

def render_world_maps(mesh, camera_poses, output_dir):
    """
    Render albedo, position and normal maps of `mesh` for every camera pose.

    For each view, three PNG files are written to `output_dir`:
    `<view>_albedo.png`, `<view>_position.png` and `<view>_normal.png`.

    Args:
        mesh: Mesh to render; its bounds are used to normalize position maps.
        camera_poses: Dict mapping view name to a 4x4 camera pose.
        output_dir: Existing directory the images are written to.

    Returns:
        Dict with keys 'albedo_maps', 'normal_maps' and 'position_maps', each
        mapping view name to the written file path.
    """
    print("  - Rendering world-space maps...")
    scene = pyrender.Scene(ambient_light=[0.1, 0.1, 0.1])
    scene.add(pyrender.Mesh.from_trimesh(mesh, smooth=True))

    camera = pyrender.PerspectiveCamera(yfov=np.pi/3.0, aspectRatio=1.0)
    camera_node = scene.add(camera, pose=np.eye(4))

    min_bound, max_bound = mesh.bounds

    output_paths = {
        'albedo_maps': {},
        'normal_maps': {},
        'position_maps': {}
    }

    renderer = pyrender.OffscreenRenderer(RESOLUTION, RESOLUTION)
    # try/finally: the offscreen renderer is released even if a view fails to
    # render or an image cannot be written (it used to leak in that case).
    try:
        for view_name, view_pose in camera_poses.items():
            print(f"    - Rendering {view_name} view...")

            # Render Albedo Map (this also moves the camera to view_pose)
            albedo_map = render_albedo_map(scene, camera_node, renderer, view_pose)
            albedo_path = os.path.join(output_dir, f"{view_name}_albedo.png")
            cv2.imwrite(albedo_path, cv2.cvtColor(albedo_map, cv2.COLOR_RGB2BGR))
            output_paths['albedo_maps'][view_name] = albedo_path

            # Render Depth for Position (single return value)
            depth = renderer.render(scene, flags=pyrender.RenderFlags.DEPTH_ONLY)

            # GPU Position Map Calculation
            position_map = vectorized_unproject(depth, camera, view_pose)
            position_path = os.path.join(output_dir, f"{view_name}_position.png")
            save_position_map(position_map, min_bound, max_bound, position_path)
            output_paths['position_maps'][view_name] = position_path

            # Normal Map
            # NOTE(review): in pyrender the VERTEX_NORMALS flag appears to draw the
            # normals as a debug overlay on the colour image instead of encoding
            # them per pixel -- confirm this render really yields a normal buffer.
            scene.set_pose(camera_node, pose=view_pose)
            normal_map = renderer.render(scene, flags=pyrender.RenderFlags.VERTEX_NORMALS)[0]
            # The render is a uint8 colour image: scale to [0, 1] before mapping
            # to [-1, 1]. (Previously `* 2 - 1` was applied to raw 0..255 values,
            # producing a range of [-1, 509].)
            normal_map = normal_map[..., :3].astype(np.float32) / 255.0 * 2.0 - 1.0
            # NOTE(review): this multiplies by the rotation of the inverse pose,
            # transposed; for a camera-to-world conversion of row vectors one
            # would expect `view_pose[:3, :3].T` -- confirm the intended direction.
            rot_matrix = np.linalg.inv(view_pose)[:3, :3]
            world_normal_map = normal_map @ rot_matrix.T
            normal_path = os.path.join(output_dir, f"{view_name}_normal.png")
            save_normal_map(world_normal_map, normal_path)
            output_paths['normal_maps'][view_name] = normal_path
    finally:
        renderer.delete()

    return output_paths

# --- Main Execution ---
if __name__ == "__main__":
    # Batch driver: post-process every downloaded .glb, render its maps and
    # write a JSON manifest describing all outputs.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    glb_files = glob.glob(os.path.join(DOWNLOAD_DIR, '**', '*.glb'), recursive=True)

    if not glb_files:
        # Nothing to do. The remaining work is simply skipped: calling exit()
        # inside a notebook would shut the kernel down.
        print(f"Error: No .glb files found in '{DOWNLOAD_DIR}'.")
    else:
        print(f"Found {len(glb_files)} .glb files to process.")
        batch_data = []

        # The six camera poses are identical for every model, so build them once.
        camera_poses = get_camera_poses()
        serialized_poses = {name: pose.tolist() for name, pose in camera_poses.items()}

        for i, glb_path in enumerate(glb_files):
            print(f"\nProcessing ({i+1}/{len(glb_files)}): {os.path.basename(glb_path)}")
            model_id = os.path.splitext(os.path.basename(glb_path))[0]
            model_output_dir = os.path.join(OUTPUT_DIR, model_id)
            os.makedirs(model_output_dir, exist_ok=True)

            output_mesh_path = os.path.join(model_output_dir, f"{model_id}_processed.glb")
            mesh = process_mesh(glb_path, output_mesh_path)
            if mesh is None:
                # process_mesh already printed the reason.
                print(f"    - Skipping due to processing error")
                continue

            rendered_paths = render_world_maps(mesh, camera_poses, model_output_dir)

            record = {
                'model_id': model_id,
                'original_path': glb_path,
                'processed_mesh_path': output_mesh_path,
                'albedo_maps': rendered_paths['albedo_maps'],
                'normal_maps': rendered_paths['normal_maps'],
                'position_maps': rendered_paths['position_maps'],
                'camera_poses': serialized_poses,
                'bounds': [mesh.bounds[0].tolist(), mesh.bounds[1].tolist()]
            }
            batch_data.append(record)

        batch_file_path = os.path.join(OUTPUT_DIR, 'batch_data.json')
        with open(batch_file_path, 'w') as f:
            json.dump(batch_data, f, indent=4)

        print(f"\nProcessing complete. {len(batch_data)} assets processed successfully.")
        print(f"Assets saved to: {OUTPUT_DIR}")
        print(f"Batch manifest: {batch_file_path}")