In [None]:
!pip install -U git+https://github.com/Sakib323/AI-Game-Engine.git
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
!pip install transformers
!pip install datasets
!pip install wandb
!pip install -U datasets
!pip install objaverse
!pip install diffusers
!pip install trimesh
!pip install jaxtyping
!pip install pytorch-lightning
!pip install ijson
!pip install triton==3.2.0
!pip install wandb

In [None]:
!git clone --depth 1 --branch main https://github.com/stepfun-ai/Step1X-3D.git
import sys
sys.path.append("./Step1X-3D")  
import os
print(os.listdir("./Step1X-3D"))
!pip install -r ./Step1X-3D/requirements.txt --verbose

In [None]:
pip install torch-cluster -f https://data.pyg.org/whl/torch-$(python -c "import torch; print(torch.__version__)").html

# =========================================================================================
# TRAIN MESH GENERATION MODEL: DOWNLOAD DATASET & TRAIN THE MODEL
# This section downloads the dataset, processes it, and then trains the model.
# =========================================================================================

In [None]:
import requests
import gzip
import json
import objaverse
import gc
import os
import ijson


# Objaverse annotation fields dropped from every datapoint written to the JSONL.
KEYS_TO_EXCLUDE = {
    "tags", "name", "staffpickedAt", "viewCount", "likeCount", "animationCount",
    "commentCount", "publishedAt", "user", "description", "faceCount", "createdAt",
    "vertexCount", "license", "uri", "viewerUrl", "embedUrl", "isDownloadable",
    "categories", "isAgeRestricted", "archives"
}

print("Starting dataset preparation...")

# Cap3D captions for Objaverse: a gzipped JSON object mapping uid -> caption (as consumed below).
url = "https://huggingface.co/datasets/tiange/Cap3D/resolve/main/Objaverse_files/cap3d_captions.json.gz"
captions_file = 'cap3d_captions.json.gz'
if not os.path.exists(captions_file):
    print("Downloading captions file...")
    # Stream into a ".part" file and rename only after a complete download, so an interrupted
    # transfer never leaves a truncated archive that the exists() check above would later trust.
    partial_file = captions_file + ".part"
    try:
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(partial_file, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(partial_file, captions_file)
        print("Captions file downloaded successfully.")
    except requests.exceptions.RequestException as e:
        print(f"Error downloading captions: {e}")
        if os.path.exists(partial_file):
            os.remove(partial_file)
        # Abort this cell only; calling exit() inside Jupyter would shut down the kernel.
        raise RuntimeError("Could not download the Cap3D captions file.") from e
else:
    print("Captions file already exists, skipping download.")

def process_batch(batch, f, total_counter):
    """Enrich (uid, caption) pairs with Objaverse metadata and append them to a JSONL file.

    Args:
        batch: list of ``(uid, caption)`` tuples.
        f: writable text file; one JSON object per line is written to it.
        total_counter: number of datapoints written before this batch. Only used to
            pretty-print the first 10 datapoints overall.

    Returns:
        ``len(batch)`` on success, or 0 when the annotations could not be loaded
        (the whole batch is then skipped).
    """
    batch_uids = [uid for uid, _caption in batch]
    try:
        annotations = objaverse.load_annotations(batch_uids)
    except Exception as e:
        print(f"Error loading annotations for batch: {e}")
        return 0  # Skip this batch on error

    for position, (uid, caption) in enumerate(batch):
        kept_metadata = {
            key: value
            for key, value in annotations.get(uid, {}).items()
            if key not in KEYS_TO_EXCLUDE
        }

        # Keep only the 1024x576 thumbnail images.
        if "thumbnails" in kept_metadata:
            thumbnail_info = kept_metadata["thumbnails"]
            if "images" in thumbnail_info:
                thumbnail_info["images"] = [
                    image for image in thumbnail_info["images"]
                    if (image.get("width"), image.get("height")) == (1024, 576)
                ]

        datapoint = {"uid": uid, "caption": caption, **kept_metadata}
        f.write(json.dumps(datapoint) + '\n')

        # Show the first few datapoints of the whole run for a quick sanity check.
        if total_counter + position < 10:
            print(f"Datapoint {uid}:")
            print(json.dumps(datapoint, indent=2))
            print("-" * 50)

    # Release the annotation dict before the next batch is loaded.
    del annotations
    gc.collect()
    return len(batch)


batch_size = 100000
total_counter = 0
# Stream the top-level JSON object entry by entry with ijson, holding at most one batch of
# captions in memory. The archive is opened in binary mode because ijson parses bytes
# (text-mode input is deprecated), and the `with` closes both files even if a batch raises.
with gzip.open(captions_file, 'rb') as captions_stream, \
        open('extended_dataset.jsonl', 'w', encoding='utf-8') as f:
    batch = []
    for uid, caption in ijson.kvitems(captions_stream, ''):
        batch.append((uid, caption))
        if len(batch) >= batch_size:
            processed = process_batch(batch, f, total_counter)
            total_counter += processed
            print(f"Processed {total_counter} captions so far.")
            batch = []
    # Flush the final, partially filled batch.
    if batch:
        processed = process_batch(batch, f, total_counter)
        total_counter += processed

print(f"Total entries in extended dataset: {total_counter}")
print("Dataset saved to 'extended_dataset.jsonl'")
print("Script finished.")

In [None]:
import json
import objaverse
import pathlib
import shutil
import os
# Restrict this process to the first GPU (only takes effect if set before CUDA is initialised).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

jsonl_path = "extended_dataset.jsonl"
output_dir = pathlib.Path("downloads")
output_dir.mkdir(exist_ok=True)
metadata_output_path = output_dir / "metadata.json"
download_limit = 100

# Collect caption + thumbnail URLs for the first `download_limit` distinct uids in the JSONL.
uid_metadata = {}
with open(jsonl_path, "r", encoding="utf-8") as jsonl_file:
    for raw_line in jsonl_file:
        if len(uid_metadata) >= download_limit:
            break
        record = json.loads(raw_line)
        record_uid = record["uid"]
        image_infos = record.get("thumbnails", {}).get("images", [])
        uid_metadata[record_uid] = {
            "caption": record.get("caption", ""),
            "thumbnails": [info["url"] for info in image_infos if "url" in info],
        }

uids = list(uid_metadata.keys())
print(f"Downloading {len(uids)} models...")
# Mapping uid -> local path of the downloaded model; each file is copied into output_dir.
paths = objaverse.load_objects(uids=uids)

final_metadata = {}
for model_uid, cached_path in paths.items():
    target = output_dir / f"{model_uid}.glb"
    try:
        shutil.copy(cached_path, target)
        final_metadata[model_uid] = uid_metadata[model_uid]
        print(f"Saved {model_uid} to {target}")
    except Exception as e:
        print(f"Failed to save {model_uid}: {e}")

# Only successfully copied models are recorded, so metadata.json matches the files on disk.
with open(metadata_output_path, "w", encoding="utf-8") as metadata_file:
    json.dump(final_metadata, metadata_file, indent=2)

print(f"\nDownloaded {len(final_metadata)} models.")
print(f"Metadata saved to {metadata_output_path}")


In [2]:
import requests
import os
import tempfile
import trimesh
import numpy as np
import torch
from transformers import AutoTokenizer
from tqdm import tqdm
import random
import traceback
import logging
import sys
from diffusers import AutoencoderKL


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
print(f"Using device: {device}, dtype: {dtype}")
# Make the cloned Step1X-3D repo importable; guarded so re-running the cell does not grow sys.path.
if "./Step1X-3D" not in sys.path:
    sys.path.append("./Step1X-3D")

from step1x3d_geometry.models.pipelines.pipeline import Step1X3DGeometryPipeline

from huggingface_hub import hf_hub_download
print("Initializing Step1X-3D VAE...")
try:
    geometry_pipeline = Step1X3DGeometryPipeline.from_pretrained(
        "stepfun-ai/Step1X-3D",
        subfolder='Step1X-3D-Geometry-1300m',
        torch_dtype=dtype
    )
    # Only the pipeline's 3D VAE is used (to encode meshes into training latents).
    vae = geometry_pipeline.vae.to(device)
    vae.eval()
    print("Step1X-3D VAE initialized successfully.")
except Exception as e:
    print(f"Error initializing pipeline. Make sure you have the Step1X-3D repo cloned and in your sys.path. Error: {e}")
    vae = None
    # Fail fast: dataset creation calls vae.encode, so silently continuing with vae=None
    # would only resurface later as a confusing AttributeError.
    raise

# Text tokenizer for captions; EOS doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained("Sakib323/MMfreeLM-370M")
tokenizer.pad_token = tokenizer.eos_token

# Image VAE used to encode thumbnails into conditioning latents.
image_vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    subfolder="vae",
    torch_dtype=dtype
).to(device).eval()
print("Image VAE model loaded.")


Using device: cuda, dtype: torch.float32
Initializing Step1X-3D VAE...


Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

/root/.cache/huggingface/hub/models--stepfun-ai--Step1X-3D/snapshots/bf7084495b3a72222f36549b7942948aa4d9daa7


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

Loading Dinov2 model from facebook/dinov2-with-registers-large
Step1X-3D VAE initialized successfully.
Image VAE model loaded.


In [3]:
from diffusers import AutoencoderKL
import requests
from PIL import Image
from io import BytesIO
import torch
import torchvision.transforms as transforms
from torchvision.transforms.functional import crop
from diffusers import AutoencoderKL
import torchvision.transforms as T


def get_image_latent(url: str, vae: AutoencoderKL, device: str = "cuda", timeout: float = 30.0) -> "torch.Tensor | None":
    """Download an image and encode it into a scaled latent with an ``AutoencoderKL``.

    Preprocessing: RGB conversion, bicubic resize of the shorter side to 512,
    512x512 center crop, normalization to [-1, 1]. The latent is *sampled* from
    the encoder's distribution (so repeated calls differ) and multiplied by
    ``vae.config.scaling_factor``.

    Args:
        url: HTTP(S) URL of the image.
        vae: image autoencoder; its ``dtype`` is used for the input tensor.
        device: device the input tensor is moved to (should match ``vae``).
        timeout: seconds to wait on the HTTP server. Without a timeout, a stalled
            server can hang dataset creation indefinitely.

    Returns:
        Latent tensor with a leading batch dimension of 1 (presumably
        ``(1, 4, 64, 64)`` for the SD-2.1 VAE — confirm), or ``None`` if the image
        could not be downloaded or decoded (the error is printed).
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        raw_image = Image.open(BytesIO(response.content)).convert("RGB")
    except requests.exceptions.RequestException as e:
        print(f"Error downloading image: {e}")
        return None
    except OSError as e:  # also covers PIL.UnidentifiedImageError
        print(f"Error opening image: {e}")
        return None

    preprocess = T.Compose([
        T.Resize(512, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(512),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    image_tensor = preprocess(raw_image).unsqueeze(0).to(device=device, dtype=vae.dtype)

    with torch.no_grad():
        latent_dist = vae.encode(image_tensor).latent_dist
        latents = latent_dist.sample()
        latents = latents * vae.config.scaling_factor

    return latents
    

In [4]:
import os
import json
import torch
import traceback
import numpy as np
import trimesh
from tqdm import tqdm
import logging
from diffusion_model import GaussianDiffusion, ModelMeanType, ModelVarType, LossType, get_named_beta_schedule


# Logger for the mesh-processing helpers below. basicConfig is a no-op when the root logger
# was already configured (e.g. by an earlier cell).
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)


def compute_dihedral_angle(normal1, normal2):
    """Return the angle in degrees between two (assumed unit-length) normal vectors.

    The dot product is clipped to [-1, 1] before ``arccos`` so floating-point
    overshoot on (anti-)parallel vectors cannot produce NaN. Inputs are not
    normalized here.
    """
    cosine = np.clip(np.dot(normal1, normal2), -1.0, 1.0)
    return np.rad2deg(np.arccos(cosine))

def robust_sharp_edge_sampling(mesh, num_points=16384, sharp_threshold=30):
    """Sample surface points/normals with extra density along sharp edges.

    ``int(num_points * 0.9)`` points are sampled uniformly over the surface. The
    remaining points are placed on "sharp" edges: edges shared by two faces whose
    normals differ by more than ``sharp_threshold`` degrees. Edges are picked with
    probability proportional to their length. If the mesh has no sharp edges, the
    remaining points are sampled uniformly over the surface instead.

    Args:
        mesh: a ``trimesh.Trimesh``.
        num_points: total number of rows in the returned ``all_points``/``all_normals``.
        sharp_threshold: angle between adjacent face normals, in degrees, above
            which an edge is considered sharp.

    Returns:
        ``(all_points, all_normals, sharp_points, sharp_normals)``. The first two
        have ``num_points`` rows. The last two have ``num_points - int(num_points * 0.9)``
        rows and are also contained in the first two. Each row is an xyz vector.
    """
    uniform_count = int(num_points * 0.9)
    sharp_count = num_points - uniform_count

    uniform_points, face_indices = trimesh.sample.sample_surface(mesh, uniform_count)
    uniform_normals = mesh.face_normals[face_indices]

    # trimesh adjacency data:
    #   face_adjacency        -> pairs of faces sharing an edge
    #   face_adjacency_angles -> angle between their normals, in radians
    #   face_adjacency_edges  -> vertex indices of the shared edge
    # (The previous loop over `edges_face` never found a sharp edge: `edges_face` holds a
    # single face index per edge, so every entry hit the `continue`.)
    sharp_mask = np.asarray(mesh.face_adjacency_angles) > np.radians(sharp_threshold)
    sharp_edges = np.asarray(mesh.face_adjacency_edges)[sharp_mask]
    sharp_faces = np.asarray(mesh.face_adjacency)[sharp_mask]

    if len(sharp_edges) > 0 and sharp_count > 0:
        edge_start = mesh.vertices[sharp_edges[:, 0]]
        edge_end = mesh.vertices[sharp_edges[:, 1]]
        edge_lengths = np.linalg.norm(edge_end - edge_start, axis=1)
        total_length = edge_lengths.sum()
        # Length-weighted edge choice spreads the points evenly along the creases;
        # fall back to uniform choice if every edge is degenerate.
        probabilities = edge_lengths / total_length if total_length > 0 else None
        chosen = np.random.choice(len(sharp_edges), size=sharp_count, p=probabilities)
        t = np.random.rand(sharp_count, 1)
        sharp_points = (1.0 - t) * edge_start[chosen] + t * edge_end[chosen]

        # Normal on a crease: mean of the two adjacent face normals, re-normalized to unit
        # length. For a fully folded edge the mean vanishes, so use the first face's normal.
        pair_normals = mesh.face_normals[sharp_faces[chosen]]  # (k, 2, 3)
        mean_normals = pair_normals.mean(axis=1)
        norms = np.linalg.norm(mean_normals, axis=1, keepdims=True)
        sharp_normals = np.where(
            norms > 1e-12,
            mean_normals / np.maximum(norms, 1e-12),
            pair_normals[:, 0],
        )
    else:
        sharp_points, sharp_face_idx = trimesh.sample.sample_surface(mesh, sharp_count)
        sharp_normals = mesh.face_normals[sharp_face_idx]

    all_points = np.vstack([uniform_points, sharp_points])[:num_points]
    all_normals = np.vstack([uniform_normals, sharp_normals])[:num_points]
    return all_points, all_normals, np.asarray(sharp_points), np.asarray(sharp_normals)

def process_mesh_to_vae_input(mesh_path, num_points=32768):
    """Load a mesh file and build the point-cloud inputs for the Step1X-3D VAE.

    Steps:
      1. Load the mesh; scenes are merged into a single mesh.
      2. Try to fill holes when the mesh is not watertight.
      3. Sample points/normals, with extra density on sharp edges.
      4. Wrap-pad both clouds to ``num_points`` rows.
      5. Centre on the centroid of the surface samples and scale into the unit sphere.

    The sharp cloud receives exactly the same translation and scale as the surface cloud.

    Args:
        mesh_path: path to the mesh file; its extension selects the loader.
        num_points: number of rows in each returned cloud.

    Returns:
        ``{"surface": tensor (1, num_points, 6), "sharp_surface": tensor (1, num_points, 6) or None}``
        where the last axis is xyz + normal, in the notebook-global ``dtype``.
        Returns ``None`` if any step fails; the error is logged.
    """
    try:
        ext = os.path.splitext(mesh_path)[1].lower()
        mesh = trimesh.load(mesh_path, force='mesh', file_type=ext[1:] if ext else None)
        if isinstance(mesh, trimesh.Scene):
            mesh = mesh.dump().sum()
        if not mesh.is_watertight:
            try:
                mesh.fill_holes()
            except Exception:
                logger.warning(f"Hole filling failed for {mesh_path}")

        points, normals, sharp_points, sharp_normals = robust_sharp_edge_sampling(mesh, num_points=num_points)

        # Repeat samples (wrap-around) until each cloud has exactly num_points rows.
        if points.shape[0] < num_points:
            pad = num_points - points.shape[0]
            points = np.pad(points, ((0, pad), (0, 0)), mode='wrap')
            normals = np.pad(normals, ((0, pad), (0, 0)), mode='wrap')
        # np.pad(mode='wrap') cannot extend an empty array, so only pad non-empty sharp samples.
        if 0 < sharp_points.shape[0] < num_points:
            pad = num_points - sharp_points.shape[0]
            sharp_points = np.pad(sharp_points, ((0, pad), (0, 0)), mode='wrap')
            sharp_normals = np.pad(sharp_normals, ((0, pad), (0, 0)), mode='wrap')

        # Normalize into the unit sphere around the surface centroid. The sharp cloud must get
        # the same transform; previously it was only centred when scaling happened, leaving it
        # offset from the surface cloud for (near) zero-extent meshes.
        centroid = np.mean(points, axis=0)
        points = points - centroid
        max_distance = np.max(np.linalg.norm(points, axis=1))
        scale = max_distance if max_distance > 1e-6 else 1.0
        points = points / scale
        if len(sharp_points) > 0:
            sharp_points = (sharp_points - centroid) / scale

        point_cloud = np.hstack([points, normals])
        sharp_cloud = np.hstack([sharp_points, sharp_normals]) if len(sharp_points) > 0 else None

        return {
            "surface": torch.tensor(point_cloud, dtype=dtype).unsqueeze(0),
            "sharp_surface": torch.tensor(sharp_cloud, dtype=dtype).unsqueeze(0) if sharp_cloud is not None else None
        }
    except Exception as e:
        logger.error(f"Error processing {mesh_path}: {str(e)}")
        logger.debug(traceback.format_exc())
        return None

def process_all_meshes(download_dir="downloads", metadata_file="downloads/metadata.json"):
    """Build the clean (un-noised) training set from the downloaded meshes.

    For every uid in ``metadata_file`` that has ``<uid>.glb`` in ``download_dir`` and at
    least one thumbnail URL that can be downloaded and encoded, one sample is produced::

        {"x": 3D latent from the Step1X-3D VAE (leading batch dim removed),
         "y": {"image_latent": image latent (leading batch dim removed),
               "input_ids": (128,) caption token ids,
               "attention_mask": (128,) attention mask}}

    ``x`` is x_start; noise is added later by the trainer. Samples whose mesh or
    image cannot be processed are skipped.

    Uses the notebook globals ``vae``, ``image_vae``, ``tokenizer`` and ``device``.

    Returns:
        List of samples. Empty if the metadata file is missing or ``vae`` is not loaded.
    """
    logger.info("Starting dataset creation with the updated process_all_meshes function...")
    processed_data = []

    if not os.path.exists(metadata_file):
        logger.error(f"Metadata file not found at: {metadata_file}")
        return []
    if vae is None:
        logger.error("Step1X-3D VAE is not loaded (vae is None); cannot encode meshes.")
        return []

    with open(metadata_file, 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    for uid, info in tqdm(metadata.items(), desc="Processing Raw Data"):
        mesh_path = os.path.join(download_dir, f"{uid}.glb")
        if not os.path.isfile(mesh_path):
            continue
        thumbnails = info.get('thumbnails')
        if not thumbnails or not isinstance(thumbnails, list):
            continue
        caption = info.get('caption', '')

        # 1. Image -> image latent. Done first because it is cheap to fail (thumbnail URLs
        #    can return 403), which avoids wasting the expensive 3D encode below.
        #    Each thumbnail is tried in order until one works.
        latent_image = None
        for image_url in thumbnails:
            if not image_url:
                continue
            latent_image = get_image_latent(image_url, image_vae, device=device)
            if latent_image is not None:
                break
        if latent_image is None:
            continue
        # Drop the leading batch dim so each sample is (C, H, W), like "x" and the token
        # tensors below; the collator then stacks a batch to (B, C, H, W).
        latent_image = latent_image.squeeze(0).cpu()

        # 2. Mesh -> clean 3D latent (x_start).
        mesh_inputs = process_mesh_to_vae_input(mesh_path, num_points=32768)
        if mesh_inputs is None:
            continue

        mesh_inputs_on_device = {
            "surface": mesh_inputs["surface"].to(device),
            "sharp_surface": mesh_inputs["sharp_surface"].to(device) if mesh_inputs.get("sharp_surface") is not None else None
        }

        with torch.no_grad():
            encode_output = vae.encode(**mesh_inputs_on_device)
            # Accept the encode output forms handled here: a tuple (first element used),
            # an object exposing `latent_dist`, or a plain tensor.
            if isinstance(encode_output, tuple):
                encode_output = encode_output[0]

            if hasattr(encode_output, "latent_dist"):
                latent_3d = encode_output.latent_dist.sample().squeeze(0).cpu()
            elif isinstance(encode_output, torch.Tensor):
                latent_3d = encode_output.squeeze(0).cpu()
            else:
                raise RuntimeError(f"Unexpected VAE encode_output type: {type(encode_output)}")

        # 3. Caption -> fixed-length token ids.
        tokens = tokenizer(
            caption, padding="max_length", max_length=128, truncation=True, return_tensors="pt"
        )

        processed_data.append({
            "x": latent_3d,
            "y": {
                "image_latent": latent_image,
                "input_ids": tokens["input_ids"].squeeze(0),
                "attention_mask": tokens["attention_mask"].squeeze(0),
            }
        })

    logger.info(f"Successfully created dataset with {len(processed_data)} samples.")
    return processed_data



In [None]:
!apt-get install -y libaio-dev > /dev/null


import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# --- 1. Required Imports ---
import json
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import Trainer, TrainingArguments
import random
from diffusion_model import GaussianDiffusion, ModelMeanType, ModelVarType, LossType, get_named_beta_schedule
from mmfreelm.models.hgrn_bit.mesh_dit import MeshDiT_models
import wandb


os.environ["WANDB_API_KEY"] = "89b06c10468af620747b4bd340f72fa5d56f6849"


# --- 2. Dataset, Collator, and Trainer Classes (Updated) ---

class MeshDataset(Dataset):
    """Map-style dataset over an in-memory list of pre-processed samples.

    Items are returned as stored (no copy, no transformation).
    """

    def __init__(self, data):
        # Reference to the sample list (e.g. the output of process_all_meshes).
        self.data = data

    def __len__(self):
        samples = self.data
        return len(samples)

    def __getitem__(self, idx):
        samples = self.data
        return samples[idx]

class CustomDataCollator:
    """Collate samples of the form ``{"x": Tensor, "y": {name: Tensor}}`` into one batch.

    ``x`` and every entry of ``y`` are stacked along a new leading batch dimension.
    The keys of ``y`` are taken from the first sample.
    """

    def __call__(self, features):
        batch = {'x': torch.stack([sample['x'] for sample in features])}
        conditions = [sample['y'] for sample in features]
        stacked_conditions = {}
        for name in conditions[0]:
            stacked_conditions[name] = torch.stack([cond[name] for cond in conditions])
        batch['y'] = stacked_conditions
        return batch

class MeshDiTTrainer(Trainer):
    """Hugging Face ``Trainer`` that optimizes a diffusion loss for MeshDiT.

    Batches are ``{"x": clean latents (x_start), "y": conditioning dict}``, as built
    by ``CustomDataCollator``.
    """

    def __init__(self, *args, diffusion, **kwargs):
        super().__init__(*args, **kwargs)
        # Diffusion process providing `num_timesteps` and `training_losses(...)`.
        self.diffusion = diffusion

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Mean diffusion loss for one batch at uniformly sampled timesteps.

        ``**kwargs`` absorbs extra arguments that newer ``transformers`` versions pass
        (e.g. ``num_items_in_batch``); they are unused.
        """
        x_start = inputs.get("x")
        model_kwargs = {"y": inputs.get("y")}
        # One timestep per sample, uniform over [0, num_timesteps).
        t = torch.randint(0, self.diffusion.num_timesteps, (x_start.shape[0],), device=x_start.device).long()
        noise = torch.randn_like(x_start)
        # training_losses receives model, x_start, t and noise. As in the DiT/guided-diffusion
        # implementation (NOTE(review): confirm for this diffusion_model), it builds x_t and runs
        # the forward pass itself. The separate q_sample + model(...) call that used to be here
        # produced an output that was never used; it only doubled compute and activation memory.
        loss_dict = self.diffusion.training_losses(model, x_start, t, model_kwargs, noise=noise)
        loss = loss_dict["loss"].mean()

        return (loss, loss_dict) if return_outputs else loss

    def prediction_step(
        self,
        model,
        inputs,
        prediction_loss_only,
        ignore_keys=None,
    ):
        """Evaluation step: return only the diffusion loss as ``(loss, None, None)``.

        There are no logits/labels to compare. ``t`` and the noise are re-sampled on
        every call, so ``eval_loss`` is stochastic.
        """
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            loss = self.compute_loss(model, inputs)

        return (loss.detach(), None, None)



import math

print("Initializing MeshDiT model for training...")
model = MeshDiT_models['MeshDiT-S'](
    vocab_size=tokenizer.vocab_size, image_latent_channels=4,
    image_latent_height=64, image_latent_width=64,
    use_rope=True, use_ternary_rope=False
).to(device, dtype=dtype)

print("Initializing Gaussian diffusion process...")
diffusion = GaussianDiffusion(
    betas=get_named_beta_schedule("linear", 1000),
    model_mean_type=ModelMeanType.EPSILON,
    model_var_type=ModelVarType.FIXED_SMALL,
    loss_type=LossType.MSE,
)

all_data = process_all_meshes(download_dir="downloads", metadata_file="downloads/metadata.json")

# Training schedule. The step intervals are upper bounds that get clamped to the real run
# length below, so small datasets still get warm-up, evaluation and checkpoints.
NUM_TRAIN_EPOCHS = 100
PER_DEVICE_BATCH_SIZE = 4
GRAD_ACCUM_STEPS = 4
MAX_WARMUP_STEPS = 500
MAX_EVAL_SAVE_STEPS = 1000

if all_data:
    random.shuffle(all_data)
    # Hold out 5% for evaluation, with at least one sample when there are two or more.
    eval_size = int(len(all_data) * 0.05)
    if eval_size == 0 and len(all_data) > 1:
        eval_size = 1

    train_data = all_data[eval_size:]
    eval_data = all_data[:eval_size]
    print(f"Data split: {len(train_data)} training samples, {len(eval_data)} evaluation samples.")
    train_dataset = MeshDataset(train_data)
    eval_dataset = MeshDataset(eval_data) if eval_data else None
    has_eval = eval_dataset is not None
    data_collator = CustomDataCollator()

    # Approximate optimizer steps for the whole run on a single device. With ~80 training
    # samples this is only ~600 steps, so fixed 1000-step eval/save intervals would never
    # evaluate or checkpoint, and 500 warm-up steps would cover almost the entire run.
    batches_per_epoch = math.ceil(len(train_dataset) / PER_DEVICE_BATCH_SIZE)
    steps_per_epoch = max(1, math.ceil(batches_per_epoch / GRAD_ACCUM_STEPS))
    total_steps = steps_per_epoch * NUM_TRAIN_EPOCHS
    eval_save_steps = min(MAX_EVAL_SAVE_STEPS, max(1, total_steps // 10))
    warmup_steps = min(MAX_WARMUP_STEPS, total_steps // 10)

    training_args = TrainingArguments(
        output_dir="./mesh_dit_checkpoint",
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM_STEPS,
        learning_rate=1e-4,
        weight_decay=0.01,
        warmup_steps=warmup_steps,
        logging_dir='./logs',
        logging_strategy="steps",
        logging_steps=20,
        eval_strategy="steps" if has_eval else "no",
        eval_steps=eval_save_steps,
        save_strategy="steps",
        save_steps=eval_save_steps,
        save_total_limit=2,
        load_best_model_at_end=has_eval,
        metric_for_best_model="eval_loss" if has_eval else None,
        greater_is_better=False,
        fp16=False,  # float32 training
        max_grad_norm=1.0,
        dataloader_num_workers=2,
        remove_unused_columns=False,  # batches are custom dicts, not model kwargs
        report_to=["wandb", "tensorboard"],
        run_name="MeshDiT-S-training-float32",
    )

    trainer = MeshDiTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        diffusion=diffusion,
    )

    print("\n" + "="*50)
    print("      Starting MeshDiT Model Training")
    print("="*50 + "\n")

    trainer.train()

    print("\nTraining complete. Saving the final model.")
    # With an eval set, load_best_model_at_end restores the checkpoint with the lowest
    # eval_loss before this save. Without one, the last-step weights are saved.
    trainer.save_model("./mesh_dit_final")
    tokenizer.save_pretrained("./mesh_dit_final")
    print("Model saved to ./mesh_dit_final")
else:
    print("FATAL: Dataset creation failed. No data to train on.")

Initializing MeshDiT model for training...
Initializing RotaryEmbedding with theta=10000.0 and ternary=False

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=False

Initializing RotaryEmbedding with theta=10000.0 and ternary=False

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=False

Initializing RotaryEmbedding with theta=10000.0 and ternary=False

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=False

Initializing RotaryEmbedding with theta=10000.0 and ternary=False

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=False

Initializing RotaryEmbedding with theta=10000.0 and ternary=False

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=False

Initializing RotaryEmbedding with theta=10000.0 and ternary=False

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=False

Initializing RotaryEmbedding with

Processing Raw Data:   8%|▊         | 8/100 [00:59<14:41,  9.58s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/d7340a5b05b6460facbd90aaafb7f1f1/thumbnails/868c75b689ac4da681e10d9d857ea075/108718dab17e4802827a6c1752cd3929.jpeg


Processing Raw Data:  10%|█         | 10/100 [01:08<09:50,  6.56s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/3764a2442eca43c7bbe64d8297d84905/thumbnails/2f02d5a090de4cc1bf790f8023c82dca/a7e0a494e1db40b99709b67b9223a573.jpeg


Processing Raw Data:  11%|█         | 11/100 [01:09<07:07,  4.81s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/9ff331513cea44a2938d09a03d7b0493/thumbnails/59e0d3c53d0c4c76aac836eafccee485/349e5b6fb5f9467ab32fc9623fb5e8ae.jpeg


Processing Raw Data:  21%|██        | 21/100 [02:24<20:22, 15.47s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/d40f6a7e7dc049a6b9e63c2ac31918f0/thumbnails/52b270d8988a4cefaa7e85a7585f23d6/327f77b5459b419e84ae410e25444783.jpeg


Processing Raw Data:  58%|█████▊    | 58/100 [03:41<00:36,  1.15it/s]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/39c23be7af5848408ad74a744a127880/thumbnails/b8f74f22c2b24142b7803ed231dc6704/7e7d47678f9c41a2bf7ee7cbde956049.jpeg


Processing Raw Data:  60%|██████    | 60/100 [03:47<01:18,  1.96s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/0e315229b94543b3b510403f02c797d3/thumbnails/8cadd0cf53c549fda9c3da6628575f88/4cb6c28801e84e0ba7ff1015a0f94d71.jpeg


Processing Raw Data:  61%|██████    | 61/100 [03:49<01:16,  1.96s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/28de7aac69ab41c6be2371e144a38c28/thumbnails/5951f068190d44968f1169524a1a191c/9cfa501fcc314f13a88d3adfaf34056f.jpeg


Processing Raw Data:  71%|███████   | 71/100 [04:09<00:36,  1.26s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/74d467eb13b44d3890862046c3539992/thumbnails/7ffddcf165f64340b16cc627341841f9/ff34e059ae484889b4ef18b6bd2d8d45.jpeg


Processing Raw Data:  86%|████████▌ | 86/100 [04:36<00:26,  1.88s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/c5e8fc0ce4f14e75b29e42b62edf70fe/thumbnails/b11ccde98c5d43b49d7dd8804731c598/06441ad9abd9436fb24c7969d78aa1c4.jpeg


Processing Raw Data:  88%|████████▊ | 88/100 [04:38<00:15,  1.33s/it]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/9fe9043ea98d46179ffa896fbab9fdce/thumbnails/84edfd4596e9458286193592c16868e4/f1a45a452fca4284b943947ac8e1e593.jpeg


Processing Raw Data:  94%|█████████▍| 94/100 [04:44<00:05,  1.07it/s]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/c7f9e145f7ba4744925c2a26fae01176/thumbnails/2311249598c9480397061df7bf6714f7/487086369e9041ad84cb856267d91d73.jpeg


Processing Raw Data:  95%|█████████▌| 95/100 [04:45<00:04,  1.06it/s]

Error downloading image: 403 Client Error: Forbidden for url: https://media.sketchfab.com/models/c640aca2e346421aa9ac1ec494b2567e/thumbnails/e4d4759c9bbb4f999fd83574de7820e5/8b0e7abc78cb4ba8906a753199913de2.jpeg


Processing Raw Data: 100%|██████████| 100/100 [04:52<00:00,  2.93s/it]

Data split: 83 training samples, 4 evaluation samples.
[2025-06-16 16:24:25,533] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)






      Starting MeshDiT Model Training



[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msakibahmed2018go[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss


# =======================================================================================================
# DEMO TRAINER: TRAINS THE MODEL ON RANDOM DUMMY DATAPOINTS, SKIPPING DATASET PROCESSING
# =======================================================================================================

In [None]:
import os
import json
import torch
import random
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import Trainer, TrainingArguments, AutoTokenizer
from diffusion_model import GaussianDiffusion, ModelMeanType, ModelVarType, LossType, get_named_beta_schedule
from mmfreelm.models.hgrn_bit.mesh_dit import MeshDiT_models
import wandb

# Ensure the tokenizer parallelism warning is suppressed
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Never hard-code credentials. Provide WANDB_API_KEY via the environment or a secrets
# manager; otherwise wandb.login() uses cached credentials or prompts interactively.
# NOTE(review): the key previously committed here must be considered leaked and revoked.
if not os.environ.get("WANDB_API_KEY"):
    wandb.login()

# --- 1. Demo Data Generation Function ---

def process_all_meshes_demo(num_samples, tokenizer):
    """Create ``num_samples`` random samples laid out like the real training data.

    Each sample is::

        {"x": randn(1024, 768),
         "y": {"image_latent": randn(4, 64, 64),
               "input_ids": random ids in [0, tokenizer.vocab_size), length 128,
               "attention_mask": ones(128)}}

    All tensors are float32, except the ids and mask, which are int64.
    Only ``tokenizer.vocab_size`` is read from the tokenizer.
    """
    print(f"Generating {num_samples} dummy data samples for testing...")
    x_shape, image_latent_shape, text_shape = (1024, 768), (4, 64, 64), (128,)

    def make_sample():
        # Random draws happen in order: x, image latent, token ids.
        return {
            "x": torch.randn(x_shape, dtype=torch.float32),
            "y": {
                "image_latent": torch.randn(image_latent_shape, dtype=torch.float32),
                "input_ids": torch.randint(0, tokenizer.vocab_size, text_shape, dtype=torch.long),
                "attention_mask": torch.ones(text_shape, dtype=torch.long),
            },
        }

    all_data = [make_sample() for _ in range(num_samples)]
    print("Dummy data generation complete.")
    return all_data


# --- 2. Dataset, Collator, and Corrected Trainer Class ---

class MeshDataset(Dataset):
    """Map-style dataset over an in-memory list of samples (returned as stored)."""

    def __init__(self, data):
        # Reference to the sample list (here: the dummy samples).
        self.data = data

    def __len__(self):
        samples = self.data
        return len(samples)

    def __getitem__(self, idx):
        samples = self.data
        return samples[idx]

class CustomDataCollator:
    """Stack ``{"x": Tensor, "y": {name: Tensor}}`` samples along a new batch dimension.

    The keys of ``y`` are taken from the first sample.
    """

    def __call__(self, features):
        batch = {'x': torch.stack([sample['x'] for sample in features])}
        conditions = [sample['y'] for sample in features]
        stacked_conditions = {}
        for name in conditions[0]:
            stacked_conditions[name] = torch.stack([cond[name] for cond in conditions])
        batch['y'] = stacked_conditions
        return batch

class MeshDiTTrainer(Trainer):
    """Hugging Face ``Trainer`` that optimizes a diffusion loss for MeshDiT.

    Batches are ``{"x": clean latents (x_start), "y": conditioning dict}``, as built
    by ``CustomDataCollator``.
    """

    def __init__(self, *args, diffusion, **kwargs):
        super().__init__(*args, **kwargs)
        # Diffusion process providing `num_timesteps` and `training_losses(...)`.
        self.diffusion = diffusion

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Mean diffusion loss for one batch at uniformly sampled timesteps.

        ``**kwargs`` absorbs extra arguments that newer ``transformers`` versions pass
        (e.g. ``num_items_in_batch``); they are unused.
        """
        x_start = inputs.get("x")
        model_kwargs = {"y": inputs.get("y")}
        # One timestep per sample, uniform over [0, num_timesteps).
        t = torch.randint(0, self.diffusion.num_timesteps, (x_start.shape[0],), device=x_start.device).long()
        noise = torch.randn_like(x_start)
        # training_losses receives model, x_start, t and noise. As in the DiT/guided-diffusion
        # implementation (NOTE(review): confirm for this diffusion_model), it builds x_t and runs
        # the forward pass itself. The separate q_sample + model(...) call that used to be here
        # produced an output that was never used; it only doubled compute and activation memory.
        loss_dict = self.diffusion.training_losses(model, x_start, t, model_kwargs, noise=noise)
        loss = loss_dict["loss"].mean()

        return (loss, loss_dict) if return_outputs else loss

    def prediction_step(
        self,
        model,
        inputs,
        prediction_loss_only,
        ignore_keys=None,
    ):
        """Evaluation step: return only the diffusion loss as ``(loss, None, None)``.

        ``t`` and the noise are re-sampled on every call, so ``eval_loss`` is stochastic.
        """
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            loss = self.compute_loss(model, inputs)

        return (loss.detach(), None, None)


# --- 3. Main Setup and Execution ---

# Setup device and data type
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

# Seed Python's and torch's RNGs so the dummy data, the train/eval shuffle and the model's
# initial weights are reproducible between runs.
DEMO_SEED = 42
random.seed(DEMO_SEED)
torch.manual_seed(DEMO_SEED)

NUM_DEMO_SAMPLES = 100
NUM_EVAL_SAMPLES = 5

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Sakib323/MMfreeLM-370M")
tokenizer.pad_token = tokenizer.eos_token

# Generate dummy data
all_data = process_all_meshes_demo(num_samples=NUM_DEMO_SAMPLES, tokenizer=tokenizer)

# Initialize Model
print("Initializing MeshDiT model for training...")
model = MeshDiT_models['MeshDiT-S'](
    vocab_size=tokenizer.vocab_size, image_latent_channels=4,
    image_latent_height=64, image_latent_width=64,
    use_rope=True, use_ternary_rope=False
).to(device, dtype=dtype)

# Initialize Diffusion Process
print("Initializing Gaussian diffusion process...")
diffusion = GaussianDiffusion(
    betas=get_named_beta_schedule("linear", 1000),
    model_mean_type=ModelMeanType.EPSILON,
    model_var_type=ModelVarType.FIXED_SMALL,
    loss_type=LossType.MSE,
)

if all_data:
    random.shuffle(all_data)
    train_data = all_data[NUM_EVAL_SAMPLES:]
    eval_data = all_data[:NUM_EVAL_SAMPLES]
    print(f"Data split: {len(train_data)} training samples, {len(eval_data)} evaluation samples.")
    train_dataset = MeshDataset(train_data)
    eval_dataset = MeshDataset(eval_data)
    data_collator = CustomDataCollator()

    training_args = TrainingArguments(
        output_dir="./mesh_dit_test_checkpoint",
        num_train_epochs=5,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=2,
        learning_rate=1e-4,
        logging_strategy="steps",
        logging_steps=1,
        eval_strategy="epoch",
        save_strategy="epoch",
        fp16=False,
        report_to="wandb",
        run_name="MeshDiT-S-test-run-fixed",
    )

    trainer = MeshDiTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        diffusion=diffusion,
    )

    print("\n" + "="*50)
    print("      Starting MeshDiT Model Training Test")
    print("="*50 + "\n")
    trainer.train()
    print("\n--- Test training finished successfully! ---")
else:
    print("FATAL: Dummy data creation failed.")

# =========================================================================================
# MESH GENERATION SCRIPT: LOAD MODEL AND GENERATE MESH
# This script loads your trained model and generates a mesh.
# =========================================================================================

In [None]:
import torch
import numpy as np
import trimesh
from PIL import Image
import requests
from io import BytesIO
from transformers import AutoTokenizer
from diffusers import AutoencoderKL
from tqdm import tqdm
import os 
from safetensors.torch import load_file
import torchvision.transforms as transforms
import tempfile 

# The Step1X-3D repo must already be on sys.path (it is appended in the setup cell at the top)
import sys

from step1x3d_geometry.models.pipelines.pipeline import Step1X3DGeometryPipeline
from mmfreelm.models.hgrn_bit.mesh_dit import MeshDiT_models
from diffusion_model import GaussianDiffusion, ModelMeanType, ModelVarType, LossType, get_named_beta_schedule, _extract_into_tensor


print("--- Starting Inference Setup ---")

# --- 1. Configuration ---
MODEL_PATH = "./mesh_dit_final"
OUTPUT_FILENAME = "generated_mesh_error_test.glb"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float32

# --- Your Generation Inputs ---
TEXT_PROMPT = "a green rifle with a long barrel"
IMAGE_URL = "https://media.sketchfab.com/models/ed51a51909ee46c780db3a85e821feb2/thumbnails/f46b00b385d4449d923aff78348b04c6/eb41be4a1dcd450a974db9d24dc13d16.jpeg"

# --- Generation Parameters ---
CFG_SCALE_TEXT = 7.5
CFG_SCALE_IMAGE = 1.5
NUM_SAMPLING_STEPS = 250

# --- Model Geometry ---
# Per-token feature width the MeshDiT is built with. It must match the checkpoint in
# MODEL_PATH (load_state_dict below uses PyTorch's default strict=True) and the width of the
# noise fed to the sampler.
MESH_LATENT_DIM = 64

# --- 2. Load All Models & Components ---
print(f"Loading models onto {DEVICE}...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# padding="max_length" below needs a pad token; reuse EOS, as the training cells do.
tokenizer.pad_token = tokenizer.eos_token
print("-> Tokenizer loaded.")

# --- Load the trained MeshDiT model ---
generation_model = MeshDiT_models['MeshDiT-S'](
    input_dim=MESH_LATENT_DIM,
    vocab_size=tokenizer.vocab_size,
    image_latent_channels=4,
    image_latent_height=64,
    image_latent_width=64,
    use_rope=True,
    use_ternary_rope=False
).to(DEVICE, dtype=DTYPE).eval()
state_dict = load_file(os.path.join(MODEL_PATH, "model.safetensors"), device=DEVICE)
generation_model.load_state_dict(state_dict)
print(f"-> Custom DiT model (dim={MESH_LATENT_DIM}) loaded successfully.")


# NOTE(review): white_image is not used anywhere in this cell; kept in case another cell relies on it.
white_image = Image.new('RGB', (512, 512), 'white')
print(f"-> Created a blank white PIL Image with size {white_image.size}.")


# Load the other necessary pipelines
geometry_pipeline = Step1X3DGeometryPipeline.from_pretrained("stepfun-ai/Step1X-3D", subfolder='Step1X-3D-Geometry-1300m', torch_dtype=DTYPE)
vae_3d = geometry_pipeline.vae.to(DEVICE).eval()
print("-> 3D VAE loaded.")

vae_image = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="vae", torch_dtype=DTYPE).to(DEVICE).eval()
print("-> Image VAE loaded.")

diffusion = GaussianDiffusion(betas=get_named_beta_schedule("linear", 1000), model_mean_type=ModelMeanType.EPSILON, model_var_type=ModelVarType.FIXED_SMALL, loss_type=LossType.MSE)
print("--- All components loaded successfully! ---")


# --- 3. Prepare Inputs for the Model ---
print("\n--- Preparing Generation Inputs ---")

tokens = tokenizer(TEXT_PROMPT, padding="max_length", max_length=128, truncation=True, return_tensors="pt")
input_ids = tokens["input_ids"].to(DEVICE)
attention_mask = tokens["attention_mask"].to(DEVICE)
print(f"-> Text prompt: '{TEXT_PROMPT}'")
# Process the real image from the URL to get its latent representation for conditioning
def get_image_latent(
    url: str,
    vae: AutoencoderKL,
    device: str,
    image_size: int = 512,
    timeout: float = 30.0,
) -> torch.Tensor:
    """Download an image and encode it into a scaled VAE latent for conditioning.

    Args:
        url: HTTP(S) location of the conditioning image.
        vae: image autoencoder; its ``dtype`` and ``config.scaling_factor`` are used.
        device: device the pixel tensor is moved to before encoding.
        image_size: the short side is resized to this, then the image is center-cropped
            to ``image_size x image_size``.
        timeout: seconds to wait for the server. Without a timeout ``requests.get``
            can block indefinitely.

    Returns:
        ``vae.encode(x).latent_dist.sample() * vae.config.scaling_factor`` for a batch of
        one image. This is a random draw from the latent distribution, not its mean.

    Raises:
        requests.HTTPError: the server answered with a non-success status.
        requests.Timeout: the server did not respond within ``timeout`` seconds.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    raw_image = Image.open(BytesIO(response.content)).convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(image_size),       # short side -> image_size, aspect ratio kept
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),               # uint8 [0, 255] -> float [0, 1]
        transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1]
    ])
    image_tensor = transform(raw_image).unsqueeze(0).to(device=device, dtype=vae.dtype)
    with torch.no_grad():
        image_latent = vae.encode(image_tensor).latent_dist.sample() * vae.config.scaling_factor
    return image_latent

image_latent = get_image_latent(IMAGE_URL, vae_image, DEVICE)
print(f"-> Conditioning image latent created with shape: {image_latent.shape}")


model_kwargs = {"y": {"image_latent": image_latent, "input_ids": input_ids, "attention_mask": attention_mask}}

# --- 4. Generate the 3D Latent via Denoising ---
print("\n--- Starting Denoising Process ---")

NUM_TRAIN_TIMESTEPS = 1000  # length of the "linear" beta schedule the diffusion was built with
NUM_LATENT_TOKENS = 1024
# The noise must have the same per-token width as the MeshDiT input (constructed with
# input_dim=64 above). A different width cannot pass through the model's input layer.
LATENT_DIM = 64

# Start from pure Gaussian noise.
z = torch.randn(1, NUM_LATENT_TOKENS, LATENT_DIM, device=DEVICE, dtype=DTYPE)

# Evenly strided DDIM timesteps 0, s, 2s, ... < NUM_TRAIN_TIMESTEPS. max(..., 1) guards
# against a zero stride when more sampling steps than training steps are requested.
ddim_stride = max(NUM_TRAIN_TIMESTEPS // NUM_SAMPLING_STEPS, 1)
ddim_timesteps = np.asarray(list(range(0, NUM_TRAIN_TIMESTEPS, ddim_stride)))
ddim_steps = torch.from_numpy(ddim_timesteps).long().to(DEVICE)

# The conditioning is identical at every step, so build the doubled CFG batch once.
y_in = {key: torch.cat([value] * 2, dim=0) for key, value in model_kwargs["y"].items()}

with torch.no_grad():
    # Walk every scheduled timestep from noisiest to cleanest. len(ddim_steps) can differ from
    # NUM_SAMPLING_STEPS when the latter does not divide NUM_TRAIN_TIMESTEPS.
    for i in tqdm(range(len(ddim_steps) - 1, -1, -1), desc="DDIM Sampling"):
        t = ddim_steps[i].expand(z.shape[0])

        z_in = torch.cat([z, z], dim=0)
        t_in = torch.cat([t, t], dim=0)

        noise_pred = generation_model.forward_with_cfg(
            z_in, t_in, y_in, CFG_SCALE_TEXT, CFG_SCALE_IMAGE
        )
        # Keep one guided prediction per sample. If forward_with_cfg returns the doubled batch,
        # this drops the duplicate half; if it already returns z's batch size, this is a no-op.
        # NOTE(review): assumes the output holds only epsilon (no learned-variance channels) — confirm.
        noise_pred = noise_pred[: z.shape[0]]

        alpha_t = _extract_into_tensor(diffusion.alphas_cumprod, t, z.shape)
        # DDIM needs alpha_bar at the *previous scheduled* timestep. alphas_cumprod_prev[t] is
        # alpha_bar at t-1, which is wrong for strided steps. After the final step the target is
        # the clean sample, so alpha_bar_prev is 1.
        if i > 0:
            t_prev = ddim_steps[i - 1].expand(z.shape[0])
            alpha_t_prev = _extract_into_tensor(diffusion.alphas_cumprod, t_prev, z.shape)
        else:
            alpha_t_prev = torch.ones_like(alpha_t)

        # Deterministic (eta = 0) DDIM update.
        pred_x0 = (z - (1 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
        dir_xt = (1 - alpha_t_prev).sqrt() * noise_pred
        z = alpha_t_prev.sqrt() * pred_x0 + dir_xt

generated_latent_seq = z
print("--- Denoising complete. ---")

# --- 5. Decode Latent and Save Mesh ---
print("\n--- Decoding Latent to Mesh (Error expected here) ---")
import math  # local import: only needed for the exact integer square root below

with torch.no_grad():
    # Lay the (B, N, C) token sequence out as a square (B, C, H, W) grid before decoding.
    # NOTE(review): the 3D VAE decoder is believed to expect 64 latent channels — confirm;
    # the decode call will fail if C (the sampler's token width) differs.
    B, N, C = generated_latent_seq.shape
    H = W = math.isqrt(N)
    if H * W != N:
        raise ValueError(
            f"Cannot lay out {N} latent tokens as a square grid; the token count must be a perfect square."
        )
    generated_latent_reshaped = generated_latent_seq.permute(0, 2, 1).reshape(B, C, H, W)

    # NOTE(review): confirm `vae_3d.decode(...).sample` and
    # `geometry_pipeline.point_cloud_from_logits(..., to_world=True)` exist with these signatures.
    decoded_output = vae_3d.decode(generated_latent_reshaped).sample
    points = geometry_pipeline.point_cloud_from_logits(decoded_output.squeeze(0), to_world=True)

# NOTE(review): confirm the installed trimesh's PointCloud provides `to_mesh()`.
final_mesh = trimesh.PointCloud(vertices=points).to_mesh()
final_mesh.export(OUTPUT_FILENAME)
print(f"\n--- ✨ Success! Mesh saved to {OUTPUT_FILENAME} ---")



In [None]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from diffusers import AutoencoderKL
from diffusion_model import GaussianDiffusion, ModelMeanType, ModelVarType, LossType, get_named_beta_schedule
from timm.models.vision_transformer import PatchEmbed
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F

# --- Config ---
# First CUDA device when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Latent / patch geometry
image_size = 512
latent_downscale = 8                          # spatial reduction factor assumed for the image VAE
latent_size = image_size // latent_downscale  # 512 // 8 = 64
patch_size = 4
patch_dim = 4 * patch_size ** 2               # 4 latent channels x (4 x 4) patch = 64 values
embed_dim = 1024
batch_size = 8

# Eval-mode components: image VAE, 1000-step linear noise schedule, patch embedder, tokenizer.
vae = (
    AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="vae")
    .to(device)
    .eval()
)
betas = get_named_beta_schedule("linear", 1000)
diffusion = GaussianDiffusion(
    betas=betas,
    model_mean_type=ModelMeanType.EPSILON,
    model_var_type=ModelVarType.FIXED_SMALL,
    loss_type=LossType.MSE,
)
patch_embed = (
    PatchEmbed(img_size=latent_size, patch_size=patch_size, in_chans=4, embed_dim=embed_dim)
    .to(device)
    .eval()
)
tokenizer = AutoTokenizer.from_pretrained("Sakib323/MMfreeLM-370M")
# padding='max_length' needs a pad token; EOS is reused (presumably the tokenizer ships none).
tokenizer.pad_token = tokenizer.eos_token


# Pixels -> float tensor in [-1, 1], resized (no crop) to image_size x image_size.
transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)



def preprocess(example):
    """Convert one dataset row into a cached, pre-noised training sample.

    The row's image is VAE-encoded into scaled latents and noised with ``diffusion.q_sample``
    at a random timestep drawn from [0, 1000). It is then cut into non-overlapping
    ``patch_size x patch_size`` patches (one row per patch). The row's ``'text'`` is tokenized
    and padded to 2048 tokens. Returns a dict of CPU tensors plus the integer timestep.
    """
    rgb = example['image'].convert("RGB")
    pixels = transform(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        latent = vae.encode(pixels).latent_dist.sample() * vae.config.scaling_factor
        latent = F.interpolate(latent, size=(latent_size, latent_size), mode='bilinear')

    timestep = torch.randint(0, 1000, (1,), device=device).long()
    eps = torch.randn_like(latent)
    noisy = diffusion.q_sample(x_start=latent, t=timestep, noise=eps)

    # unfold yields (1, C*p*p, num_patches); transpose to one row per patch, drop the batch dim.
    patch_rows = F.unfold(noisy, kernel_size=patch_size, stride=patch_size).permute(0, 2, 1).squeeze(0)

    encoded = tokenizer(example['text'], padding='max_length', truncation=True,
                        max_length=2048, return_tensors='pt')

    return {
        'patch_embeddings': patch_rows.cpu(),
        'noise': eps.squeeze(0).cpu(),
        'timestep': timestep.item(),
        'input_ids': encoded['input_ids'].squeeze(0),
        'attention_mask': encoded['attention_mask'].squeeze(0),
    }




# Take the first 10 rows and preprocess them eagerly into memory (VAE encode, noise, tokenize).
dataset = load_dataset("iamkaikai/GAME-MAP-ART", split="train").select(range(10))
processed = list(map(preprocess, dataset))



# Dataset and DataLoader
class DiTDataset(torch.utils.data.Dataset):
    """Map-style dataset that simply indexes into an in-memory list of preprocessed samples."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch):
    """Stack a list of per-sample dicts into one dict of batched tensors.

    Tensor fields are stacked along a new leading dimension; the integer ``timestep`` values
    become a ``torch.long`` tensor.
    """
    stacked = {
        key: torch.stack([sample[key] for sample in batch])
        for key in ('patch_embeddings', 'noise')
    }
    stacked['timestep'] = torch.tensor([sample['timestep'] for sample in batch], dtype=torch.long)
    for key in ('input_ids', 'attention_mask'):
        stacked[key] = torch.stack([sample[key] for sample in batch])
    return stacked

# Reshuffled every epoch; collate_fn stacks the per-sample dicts into batch tensors.
dataloader = DataLoader(
    dataset=DiTDataset(processed),
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

In [None]:
"""
A minimal training script for DiT using PyTorch DDP.
"""
import torch
# the first flag below was False when we tested this script but True makes A100 training a lot faster:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
import numpy as np
from collections import OrderedDict
from PIL import Image
from copy import deepcopy
from glob import glob
from time import time
import argparse
import logging
import os

from mmfreelm.models.hgrn_bit.ternary_dit import DiT_models
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL


#################################################################################
#                             Training Helper Functions                         #
#################################################################################

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Move each EMA parameter one step towards its counterpart in `model`, in place:
    ema = decay * ema + (1 - decay) * current. With decay=0 the EMA takes the current values.
    """
    ema_by_name = dict(ema_model.named_parameters())
    for name, current in model.named_parameters():
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_by_name[name].mul_(decay).add_(current.data, alpha=1 - decay)


def requires_grad(model, flag=True):
    """
    Enable (flag=True) or disable (flag=False) gradient tracking on every parameter of `model`.
    """
    for param in model.parameters():
        param.requires_grad_(flag)


def cleanup():
    """
    Tear down the default torch.distributed process group once DDP training has finished.
    """
    dist.destroy_process_group()


def create_logger(logging_dir):
    """
    Return this module's logger. On rank 0 the root logger is configured to write INFO
    messages to stdout and to `{logging_dir}/log.txt`. Every other rank only gets a
    NullHandler attached and no root configuration, so its INFO messages go nowhere.
    """
    if dist.get_rank() != 0:
        silent_logger = logging.getLogger(__name__)
        silent_logger.addHandler(logging.NullHandler())
        return silent_logger

    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")],
    )
    return logging.getLogger(__name__)


def center_crop_arr(pil_image, image_size):
    """
    Resize and center-crop a PIL image to image_size x image_size (ADM's implementation).
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # Halve with a box filter while the short side is still at least twice the target...
    while min(pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # ...then a single bicubic resize brings the short side to exactly image_size.
    ratio = image_size / min(pil_image.size)
    scaled = tuple(round(side * ratio) for side in pil_image.size)
    pil_image = pil_image.resize(scaled, resample=Image.BICUBIC)

    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    return Image.fromarray(pixels[top:top + image_size, left:left + image_size])


#################################################################################
#                                  Training Loop                                #
#################################################################################

def main(args):
    """
    Trains a new DiT model.

    Runs as one DistributedDataParallel process per GPU (NCCL backend). Each process seeds
    itself with `global_seed * world_size + rank` and trains on its DistributedSampler shard of
    an ImageFolder dataset. Images are encoded to latents by a Stable Diffusion VAE that is not
    optimized. The process keeps an EMA copy of the weights and logs the all-reduced mean loss
    every `args.log_every` steps. On rank 0 only, it saves model/EMA/optimizer checkpoints
    every `args.ckpt_every` steps.

    Args:
        args: argparse namespace; reads data_path, results_dir, model, image_size,
            num_classes, epochs, global_batch_size, global_seed, vae, num_workers,
            log_every and ckpt_every.

    NOTE(review): init_process_group("nccl") is called without an init_method, so it presumably
    relies on torchrun-style environment variables (MASTER_ADDR, RANK, ...) — confirm.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    # Setup DDP:
    dist.init_process_group("nccl")
    assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size."
    rank = dist.get_rank()
    # Map the global rank onto a local GPU index.
    device = rank % torch.cuda.device_count()
    # Distinct, reproducible seed per rank.
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # Setup an experiment folder:
    # NOTE(review): checkpoint_dir is only defined on rank 0; it is only read under `rank == 0` below.
    if rank == 0:
        os.makedirs(args.results_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
        experiment_index = len(glob(f"{args.results_dir}/*"))
        model_string_name = args.model.replace("/", "-")  # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
        experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}"  # Create an experiment folder
        checkpoint_dir = f"{experiment_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")
    else:
        logger = create_logger(None)

    # Create model:
    assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = args.image_size // 8
    model = DiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes
    )
    # Note that parameter initialization is done within the DiT constructor
    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
    requires_grad(ema, False)
    model = DDP(model.to(device), device_ids=[rank])
    diffusion = create_diffusion(timestep_respacing="")  # default: 1000 steps, linear noise schedule
    # The VAE is only used under no_grad below and is not passed to the optimizer.
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)

    # Setup data:
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
    dataset = ImageFolder(args.data_path, transform=transform)
    # Shuffling is done by the sampler (per epoch, via set_epoch), hence shuffle=False on the loader.
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=True,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=int(args.global_batch_size // dist.get_world_size()),
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")

    # Prepare models for training:
    update_ema(ema, model.module, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode

    # Variables for monitoring/logging purposes:
    train_steps = 0
    log_steps = 0
    running_loss = 0
    start_time = time()

    logger.info(f"Training for {args.epochs} epochs...")
    for epoch in range(args.epochs):
        sampler.set_epoch(epoch)
        logger.info(f"Beginning epoch {epoch}...")
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            with torch.no_grad():
                # Map input images to latent space + normalize latents:
                # (0.18215 is a hard-coded latent scale; other cells read vae.config.scaling_factor instead)
                x = vae.encode(x).latent_dist.sample().mul_(0.18215)
            # One uniformly random diffusion timestep per sample.
            t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
            model_kwargs = dict(y=y)
            loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
            loss = loss_dict["loss"].mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
            update_ema(ema, model.module)

            # Log loss values:
            running_loss += loss.item()
            log_steps += 1
            train_steps += 1
            if train_steps % args.log_every == 0:
                # Measure training speed:
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Reduce loss history over all processes:
                # (sum of per-rank means divided by world size = mean across ranks)
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                # Reset monitoring variables:
                running_loss = 0
                log_steps = 0
                start_time = time()

            # Save DiT checkpoint:
            if train_steps % args.ckpt_every == 0 and train_steps > 0:
                if rank == 0:
                    checkpoint = {
                        "model": model.module.state_dict(),
                        "ema": ema.state_dict(),
                        "opt": opt.state_dict(),
                        "args": args
                    }
                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")
                # All ranks wait here so none runs ahead while rank 0 is writing the checkpoint.
                dist.barrier()

    model.eval()  # important! This disables randomized embedding dropout
    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...

    logger.info("Done!")
    cleanup()


if __name__ == "__main__":
    # Defaults reproduce the DiT-XL/2 paper hyperparameters (apart from the number of training iterations).
    # NOTE(review): inside a notebook kernel __name__ is also "__main__", so this runs when the cell
    # executes, and parse_args() reads sys.argv[1:] — likely the kernel's launcher arguments.
    parser = argparse.ArgumentParser()
    # Data / output locations
    parser.add_argument("--data-path", required=True, type=str)
    parser.add_argument("--results-dir", default="results", type=str)
    # Model
    parser.add_argument("--model", default="DiT-XL/2", choices=list(DiT_models.keys()), type=str)
    parser.add_argument("--image-size", default=256, choices=[256, 512], type=int)
    parser.add_argument("--num-classes", default=1000, type=int)
    # Optimisation schedule
    parser.add_argument("--epochs", default=1400, type=int)
    parser.add_argument("--global-batch-size", default=256, type=int)
    parser.add_argument("--global-seed", default=0, type=int)
    # VAE variant (choice doesn't affect training)
    parser.add_argument("--vae", default="ema", choices=["ema", "mse"], type=str)
    # Infrastructure / bookkeeping
    parser.add_argument("--num-workers", default=4, type=int)
    parser.add_argument("--log-every", default=100, type=int)
    parser.add_argument("--ckpt-every", default=50_000, type=int)
    args = parser.parse_args()
    main(args)