In [1]:
pip install -U "accelerate>=0.26.0"

[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
!pip install einops timm

import sys, os
!git clone https://github.com/Sakib323/AI-Game-Engine.git
sys.path.append('/workspace/AI-Game-Engine') 
from mmfreelm.models.hgrn_bit.mesh_dit import MeshDiT_models

[0mfatal: destination path 'AI-Game-Engine' already exists and is not an empty directory.


In [3]:
import bisect
import os, json, shutil, logging, random, signal, time
import numpy as np
import torch
import torch.nn.functional as F
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, Subset, RandomSampler, SequentialSampler, DataLoader
from transformers import AutoTokenizer, Trainer, TrainingArguments
from diffusers import DDIMScheduler
from safetensors.torch import load_file as safetensors_load
import wandb

import torch._dynamo
# When torch.compile cannot trace a frame, fall back to eager instead of raising.
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.verbose = False

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Seed every RNG the notebook touches so model initialisation and any
# stochastic preprocessing are reproducible across Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
print(f"Using device: {device}, dtype: {dtype}")

# The import itself happens in the clone/import cell; only report success
# when it actually produced `MeshDiT_models`.
if "MeshDiT_models" in globals():
    print("SUCCESS: MeshDiT models imported!")
else:
    print("WARNING: MeshDiT_models is not defined - run the clone/import cell first.")

Using device: cuda, dtype: torch.float32
SUCCESS: MeshDiT models imported!


In [4]:
!pip install kagglehub
import kagglehub

path = kagglehub.dataset_download("sakibahmed2022/meshdit-all-data-5-pt")
print("Dataset downloaded to:", path)
cache_dir = "/tmp/meshdit_cache"
os.makedirs(cache_dir, exist_ok=True)

file_names = ["all_data.pt", "all_data_1.pt", "all_data_2.pt", "all_data_3.pt", "all_data_4.pt"]
local_paths = []

for name in file_names:
    src = os.path.join(path, name)
    dst = os.path.join(cache_dir, name)
    
    if not os.path.exists(dst):
        print(f"Copying {name} → {dst}")
        shutil.copy2(src, dst)
    else:
        print(f"Already cached: {dst}")
        
    local_paths.append(dst)

print("All .pt files ready in", cache_dir)

[0mDownloading from https://www.kaggle.com/api/v1/datasets/download/sakibahmed2022/meshdit-all-data-5-pt?dataset_version_number=1...


100%|██████████| 44.1G/44.1G [05:51<00:00, 135MB/s]   

Extracting files...





Dataset downloaded to: /root/.cache/kagglehub/datasets/sakibahmed2022/meshdit-all-data-5-pt/versions/1
Copying all_data.pt → /tmp/meshdit_cache/all_data.pt
Copying all_data_1.pt → /tmp/meshdit_cache/all_data_1.pt
Copying all_data_2.pt → /tmp/meshdit_cache/all_data_2.pt
Copying all_data_3.pt → /tmp/meshdit_cache/all_data_3.pt
Copying all_data_4.pt → /tmp/meshdit_cache/all_data_4.pt
All .pt files ready in /tmp/meshdit_cache


In [5]:
print("\n--- Freeing Disk Space ---")

# Delete the original Kaggle download only after every shard has a complete
# local copy; deleting it earlier would force a full re-download.
all_cached = all(os.path.exists(p) for p in local_paths)
if path and os.path.exists(path):
    if all_cached:
        print(f"Deleting original Kaggle cache at: {path}")
        shutil.rmtree(path)
        print("✅ Redundant dataset deleted.")
    else:
        print(f"⚠️ Keeping {path}: some .pt files are missing from {cache_dir}.")

_total, _used, free = shutil.disk_usage("/")
print(f"\nDISK SPACE REMAINING: {free // (2**30)} GB")
if all_cached:
    print("All .pt files ready in", cache_dir)
else:
    print("WARNING: not all .pt files are present in", cache_dir)


--- Freeing Disk Space ---
Deleting original Kaggle cache at: /root/.cache/kagglehub/datasets/sakibahmed2022/meshdit-all-data-5-pt/versions/1
✅ Redundant dataset deleted.

DISK SPACE REMAINING: 70 GB
All .pt files ready in /tmp/meshdit_cache


In [None]:
def train_model():
    """Train MeshDiT-S from scratch on the cached ``.pt`` shards and save it.

    Steps: optional W&B login, a dataset over the shards in ``local_paths``
    (first 10% held out for validation), a DDIM noise scheduler with 2000
    training timesteps, an HF ``Trainer`` optimising a Huber noise-prediction
    loss, and finally saving model + tokenizer to ``./mesh_dit_final``.

    Relies on notebook globals from earlier cells: ``local_paths``,
    ``MeshDiT_models``, ``device`` and ``dtype``.
    """
    seed = 42
    wandb_project = "mesh-dit-3d-generation"

    # ------------------- W&B -------------------
    # Never hardcode the API key (the one previously pasted here is public and
    # should be revoked). With key=None, wandb uses its own credential lookup
    # (WANDB_API_KEY env var, ~/.netrc, or an interactive prompt).
    try:
        wandb.login(key=os.environ.get("WANDB_API_KEY"))
        use_wandb = True
    except Exception as e:
        print("W&B disabled:", e)
        use_wandb = False

    # ------------------- Dataset -------------------
    class CustomDataCollator:
        """Stack a list of samples into one batch with the same nested layout.

        Each feature is ``{"x": Tensor, "y": {name: Tensor}}``; every tensor is
        stacked along a new leading batch dimension (shapes must match).
        """

        def __call__(self, features):
            batch = {}
            batch['x'] = torch.stack([f['x'] for f in features])
            y_features = [f['y'] for f in features]
            batch['y'] = {k: torch.stack([d[k] for d in y_features]) for k in y_features[0]}
            return batch

    class MeshDiTTrainer(Trainer):
        """``Trainer`` for an epsilon-prediction diffusion model.

        Each batch is noised with ``noise_scheduler`` at uniformly sampled
        timesteps; the loss is the Huber loss (delta=1.0) between the model's
        prediction and the injected noise.
        """

        def __init__(self, *args, noise_scheduler, **kwargs):
            super().__init__(*args, **kwargs)
            self.noise_scheduler = noise_scheduler
            # Set only while ``evaluate`` runs, so validation noise is reproducible.
            self._eval_generator = None

        def _diffusion_loss(self, model, inputs, generator=None):
            """Noise ``inputs["x"]`` at random timesteps; return the noise-prediction Huber loss.

            ``generator`` (optional) drives both the noise and the timestep draw.
            """
            x_start = inputs["x"]
            model_kwargs = inputs["y"]

            noise = torch.randn(x_start.shape, generator=generator,
                                dtype=x_start.dtype, device=x_start.device)
            timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps,
                                      (x_start.shape[0],), generator=generator,
                                      device=x_start.device).long()
            noisy_latents = self.noise_scheduler.add_noise(x_start, noise, timesteps)

            noise_pred = model(noisy_latents, timesteps, model_kwargs)
            return F.huber_loss(noise_pred, noise, delta=1.0)

        def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
            loss = self._diffusion_loss(model, inputs)
            return (loss, {"loss": loss}) if return_outputs else loss

        def evaluate(self, *args, **kwargs):
            """Evaluate with a freshly seeded RNG so every evaluation draws identical noise/timesteps.

            Otherwise the validation loss also fluctuates with the random noise,
            and ``load_best_model_at_end`` partly picks checkpoints by chance.
            """
            self._eval_generator = torch.Generator(device=self.args.device).manual_seed(self.args.seed)
            try:
                return super().evaluate(*args, **kwargs)
            finally:
                self._eval_generator = None

        def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
            """Return ``(loss, None, None)`` for one validation batch; only the loss is tracked."""
            # Move inputs to the training device, as the default prediction_step does.
            inputs = self._prepare_inputs(inputs)
            # Mirror Trainer's own prediction_step, which wraps the loss in
            # compute_loss_context_manager().
            with torch.no_grad(), self.compute_loss_context_manager():
                loss = self._diffusion_loss(model, inputs, generator=self._eval_generator)
            return (loss.detach(), None, None)

    class ConcatenatedMeshDataset(Dataset):
        """Present several ``.pt`` shard files as one indexable dataset.

        Each file is opened with ``torch.load(..., mmap=True)`` and must hold a
        dict whose ``"data"`` entry is an indexable sequence of samples with
        ``"x"`` and ``"y"`` keys.
        """

        def __init__(self, pt_paths):
            self.pt_paths = pt_paths
            self.cumsum = [0]  # cumsum[i] = global index of the first sample of file i
            self.lengths = []
            self.file_handles = []
            for p in pt_paths:
                # NOTE(review): torch.load unpickles the file; only load trusted shards.
                handle = torch.load(p, map_location='cpu', mmap=True)
                data = handle['data']
                L = len(data)
                self.lengths.append(L)
                self.cumsum.append(self.cumsum[-1] + L)
                self.file_handles.append(data)
            print(f"Total samples: {sum(self.lengths):,}")

        def __len__(self):
            return self.cumsum[-1]

        def __getitem__(self, idx):
            total = self.cumsum[-1]
            if idx < 0:
                idx += total
            if not 0 <= idx < total:
                raise IndexError(f"index {idx} out of range for dataset of size {total}")
            # Last file whose start offset is <= idx (empty files are skipped).
            file_idx = bisect.bisect_right(self.cumsum, idx) - 1
            local_idx = idx - self.cumsum[file_idx]
            sample = self.file_handles[file_idx][local_idx]
            return {'x': sample['x'], 'y': sample['y']}

    full_dataset = ConcatenatedMeshDataset(local_paths)

    total = len(full_dataset)
    eval_size = max(1, int(total * 0.10))
    # Hold out the first 10% (in shard order) for validation; the rest trains.
    # NOTE(review): this contiguous prefix comes only from the first shard(s),
    # not from all of them - confirm it is representative before relying on it.
    eval_subset = Subset(full_dataset, range(eval_size))
    train_subset = Subset(full_dataset, range(eval_size, total))

    collator = CustomDataCollator()
    scheduler = DDIMScheduler(num_train_timesteps=2000,
                              beta_schedule="linear",
                              prediction_type="epsilon",
                              clip_sample=False)

    # ------------------- Tokenizer -------------------
    # Used here only for its vocabulary size and saved next to the model.
    tokenizer = AutoTokenizer.from_pretrained("Sakib323/MMfreeLM-370M")
    tokenizer.pad_token = tokenizer.eos_token

    # ------------------- PHASE 1 (scratch) -------------------
    print("\n=== PHASE 1 : MeshDiT-S from scratch ===")

    if use_wandb:
        # Keep the env var and the explicitly initialised run in one project;
        # the Trainer's W&B callback logs into an already-active run.
        os.environ["WANDB_PROJECT"] = wandb_project
        wandb.init(project=wandb_project, name="MeshDiT-S", reinit=True)

    try:
        # The Trainer seeds RNGs from args.seed when it is constructed, which is
        # after the model weights exist; seed here so initialisation is
        # reproducible as well.
        torch.manual_seed(seed)
        model_p1 = MeshDiT_models['MeshDiT-S'](
            input_tokens=2048,
            vocab_size=tokenizer.vocab_size,
            use_rope=True,
            use_ternary_rope=True,
            image_condition=False,
            full_precision=True,
            optimized_bitlinear=False,
        ).to(device, dtype=dtype)

        # The Trainer builds its own train/eval DataLoaders from these arguments
        # (assigning trainer.train_dataloader / eval_dataloader has no effect),
        # so every loader option is configured here.
        args_p1 = TrainingArguments(
            output_dir="./phase1_ckpt",
            num_train_epochs=10,
            per_device_train_batch_size=64,
            gradient_accumulation_steps=8,
            learning_rate=1e-4,
            lr_scheduler_type="cosine",
            weight_decay=0.01,
            warmup_ratio=0.1,
            logging_steps=50,
            eval_strategy="steps",
            eval_steps=500,
            save_strategy="steps",
            save_steps=1000,
            save_total_limit=1,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            fp16=True,
            max_grad_norm=5.0,
            seed=seed,
            dataloader_num_workers=4,
            dataloader_pin_memory=True,
            dataloader_prefetch_factor=4,
            dataloader_persistent_workers=True,
            remove_unused_columns=False,
            report_to="wandb" if use_wandb else "tensorboard",
            run_name="MeshDiT-S-Phase1-Final",
            torch_compile=True,
            torch_compile_backend="inductor",
        )

        trainer_p1 = MeshDiTTrainer(
            model=model_p1,
            args=args_p1,
            train_dataset=train_subset,
            eval_dataset=eval_subset,
            data_collator=collator,
            noise_scheduler=scheduler,
        )

        trainer_p1.train()
        trainer_p1.save_model("./mesh_dit_final")
        tokenizer.save_pretrained("./mesh_dit_final")
    finally:
        # Close the W&B run even if training fails, so a re-run starts cleanly.
        if use_wandb:
            wandb.finish()
    print("FINAL MODEL → ./mesh_dit_final")

    print("\nTRAINING COMPLETE – MeshDiT-S ready!")

# Run the full training pipeline: loads the shards, trains MeshDiT-S for the
# configured number of epochs and saves the result to ./mesh_dit_final.
# Long-running cell.
train_model()

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msakibahmed2018go[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Total samples: 97,308


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/437 [00:00<?, ?B/s]


=== PHASE 1 : MeshDiT-S from scratch ===




INFO:mmfreelm.models.hgrn_bit.mesh_dit:Absolute positional embeddings are disabled for this model.


Initializing RotaryEmbedding with theta=10000.0 and ternary=True

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=True

Initializing RotaryEmbedding with theta=10000.0 and ternary=True

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=True

Initializing RotaryEmbedding with theta=10000.0 and ternary=True

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=True

Initializing RotaryEmbedding with theta=10000.0 and ternary=True

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=True

Initializing RotaryEmbedding with theta=10000.0 and ternary=True

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=True

Initializing RotaryEmbedding with theta=10000.0 and ternary=True

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=True

Initializing RotaryEmbedding with theta=10000.0 and ternary=True

[RotaryEmbedding] Init

The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here.
W1120 17:54:01.958000 5356 site-packages/torch/_dynamo/convert_frame.py:1233] WON'T CONVERT torch_dynamo_resume_in_forward_at_136 /workspace/AI-Game-Engine/mmfreelm/layers/hgrn_bit.py line 136 
W1120 17:54:01.958000 5356 site-packages/torch/_dynamo/convert_frame.py:1233] due to: 
W1120 17:54:01.958000 5356 site-packages/torch/_dynamo/convert_frame.py:1233] Traceback (most recent call last):
W1120 17:54:01.958000 5356 site-packages/torch/_dynamo/convert_frame.py:1233]   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1164, in __call__
W1120 17:54:01.958000 5356 site-packages/torch/_dynamo/convert_frame.py:1233]     result = self._inner_convert(
W1120 17:54:01.958000 5356 site-packages/torch/_dynamo/convert_frame.py:1233]              ^^^^^^^^^^^^^^^^^^^^
W1120 17:54:01.958000 5356 site-packages/torch/_dynamo/convert_frame.py:1233]   File "/opt/conda/lib/

Step,Training Loss,Validation Loss
500,0.0774,0.079304
1000,0.0751,0.075536
1500,0.0744,0.073858


In [None]:
import os
import json
import shlex
import subprocess
from datetime import datetime

# --------------------------------------------------------------
# 1. CREDENTIALS
# --------------------------------------------------------------
# Never hardcode the API key (the one previously pasted here is public and
# should be revoked). The Kaggle client reads KAGGLE_USERNAME / KAGGLE_KEY from
# the environment or from kaggle.json in the Kaggle config directory.
KAGGLE_USERNAME = os.environ.setdefault("KAGGLE_USERNAME", "sakibahmed2022")
_kaggle_json = os.path.join(
    os.environ.get("KAGGLE_CONFIG_DIR", os.path.expanduser("~/.kaggle")), "kaggle.json"
)
if "KAGGLE_KEY" not in os.environ and not os.path.exists(_kaggle_json):
    print(f"WARNING: no Kaggle credentials found - set KAGGLE_KEY or create {_kaggle_json}")

# --------------------------------------------------------------
# 2. CONFIG
# --------------------------------------------------------------
DESIRED_SLUG    = f"{KAGGLE_USERNAME}/meshdit-trained-model"
LOCAL_DIR       = "./mesh_dit_final"
TITLE           = "MeshDiT S Trained Model 10 epoch"
VERSION_MSG     = f"Upload final model – {datetime.utcnow().isoformat()} UTC"
PUBLIC          = False          # private dataset

# --------------------------------------------------------------
# 3. Helper – write dataset-metadata.json (required)
# --------------------------------------------------------------
def write_metadata(folder: str, dataset_id: str):
    """Create the ``dataset-metadata.json`` that Kaggle's API and CLI require.

    The file is written into *folder* and declares *dataset_id* with the
    notebook's ``TITLE``, a CC-BY-4.0 license and a description stamped with
    the current UTC time. The written path is printed.
    """
    stamp = datetime.utcnow().isoformat()
    payload = dict(
        title=TITLE,
        id=dataset_id,
        licenses=[{"name": "CC-BY-4.0"}],
        description=f"{TITLE} – uploaded {stamp} UTC",
    )
    target = os.path.join(folder, "dataset-metadata.json")
    with open(target, "w", encoding="utf-8") as fh:
        json.dump(payload, fh, indent=2)
    print(f"[meta] {target}")

# --------------------------------------------------------------
# 4. Try Python API (ignore harmless ApiStartBlobUploadRequest warning)
# --------------------------------------------------------------
def _raise_on_api_error(result, action: str) -> None:
    """Raise RuntimeError when a Kaggle API dataset call reports failure.

    The create/version calls can signal a server-side failure through the
    returned object's ``error`` field, or return ``None``, instead of raising.
    NOTE(review): this mirrors what the kaggle CLI wrappers inspect; confirm
    for the installed kaggle version.
    """
    if result is None:
        raise RuntimeError(f"{action} returned no result (see output above)")
    error = getattr(result, "error", None)
    if error:
        raise RuntimeError(f"{action} failed: {error}")


def try_api(folder: str, dataset_id: str, msg: str):
    """Publish *folder* through the Kaggle Python API.

    First tries to add a new version; if the dataset is missing, creates it.

    Returns:
        ``(True, dataset_id)`` on success.
        ``(False, dataset_id)`` when the CLI should be tried next (403, or a
        failed create).
        ``(False, None)`` on any other error.
    """
    from kaggle.api.kaggle_api_extended import KaggleApi
    api = KaggleApi()
    api.authenticate()

    # Try version first
    try:
        result = api.dataset_create_version(
            folder=folder,
            version_notes=msg,
            convert_to_csv=False,
            delete_old_versions=False,
        )
        _raise_on_api_error(result, "dataset_create_version")
        print("[api] version created")
        return True, dataset_id
    except Exception as e:
        err = str(e)
        # 403 after files are uploaded → dataset already exists, fall back to CLI
        if "403" in err:
            print("[api] 403 after file upload → will use CLI")
            return False, dataset_id
        # Dataset missing → create new
        if "does not exist" in err.lower() or "not found" in err.lower() or "404" in err:
            print("[api] dataset missing → creating new")
            try:
                result = api.dataset_create_new(folder=folder, public=PUBLIC, convert_to_csv=False)
                _raise_on_api_error(result, "dataset_create_new")
            except Exception as create_err:
                # Let the caller fall back to the CLI instead of crashing here.
                print(f"[api] create failed: {create_err}")
                return False, dataset_id
            print("[api] new dataset created")
            return True, dataset_id
        print(f"[api] unexpected error: {e}")
        return False, None

# --------------------------------------------------------------
# 5. CLI fallback (no --private flag – CLI defaults to private)
# --------------------------------------------------------------
def try_cli(folder: str, dataset_id: str, msg: str):
    """Publish *folder* with the ``kaggle`` CLI: add a version, else create the dataset.

    Commands run without a shell (argument lists), so *folder* and *msg* are
    never shell-interpreted. A missing ``kaggle`` executable counts as a failed
    command rather than raising.

    Returns ``(True, dataset_id)`` on success, ``(False, None)`` when both
    commands fail. In that case the output of both commands is printed.
    """
    def _run(label, argv):
        # Print the equivalent shell command for copy/paste debugging.
        print(f"[cli] {label}: {shlex.join(argv)}")
        try:
            return subprocess.run(argv, capture_output=True, text=True)
        except FileNotFoundError as exc:  # kaggle CLI not installed / not on PATH
            return subprocess.CompletedProcess(argv, 127, stdout="", stderr=str(exc))

    # 1. version
    version = _run("version", ["kaggle", "datasets", "version", "-p", folder, "-m", msg])
    if version.returncode == 0:
        print("[cli] version OK")
        return True, dataset_id

    # 2. create - attempted whenever versioning fails (the CLI output is not
    #    parsed, so this is not limited to the "dataset missing" case).
    create = _run("create", ["kaggle", "datasets", "create", "-p", folder])
    if create.returncode == 0:
        print("[cli] create OK")
        return True, dataset_id

    print("[cli] both commands failed")
    for label, proc in (("version", version), ("create", create)):
        print(f"[{label}] STDOUT:", proc.stdout)
        print(f"[{label}] STDERR:", proc.stderr)
    return False, None

# --------------------------------------------------------------
# 6. Main – slug-collision safe
# --------------------------------------------------------------
def _slug_exists(dataset_id: str) -> bool:
    """Return True if *dataset_id* can be listed with authenticated Kaggle credentials.

    Any failure (missing dataset, auth error, network) is reported and treated
    as "does not exist".
    """
    try:
        from kaggle.api.kaggle_api_extended import KaggleApi
        api = KaggleApi()
        api.authenticate()                 # authenticate the SAME instance we query
        api.dataset_list_files(dataset_id)  # raises if missing
        return True
    except Exception as e:
        print(f"[slug] could not find '{dataset_id}' ({e}) → keeping it")
        return False


def upload(new_slug_if_exists: bool = False):
    """Upload LOCAL_DIR to Kaggle, via the Python API first and the CLI as fallback.

    Args:
        new_slug_if_exists: when True and DESIRED_SLUG already exists, publish
            to a fresh timestamped slug instead of adding a version to it.
            The previous existence check authenticated one ``KaggleApi()``
            instance but queried a second, unauthenticated one, so in practice
            it kept DESIRED_SLUG; that effective behaviour stays the default.

    Returns:
        The dataset URL.

    Raises:
        FileNotFoundError: LOCAL_DIR does not exist.
        RuntimeError: both the API and the CLI failed.
    """
    if not os.path.isdir(LOCAL_DIR):
        raise FileNotFoundError(LOCAL_DIR)

    files = [f for f in os.listdir(LOCAL_DIR) if os.path.isfile(os.path.join(LOCAL_DIR, f))]
    print("[main] files:", files)

    # Resolve slug
    dataset_id = DESIRED_SLUG
    if new_slug_if_exists and _slug_exists(dataset_id):
        ts = datetime.utcnow().strftime("%Y%m%d%H%M%S")
        print(f"[slug] '{dataset_id}' exists → using timestamped")
        dataset_id = f"{KAGGLE_USERNAME}/meshdit-trained-model-{ts}"

    write_metadata(LOCAL_DIR, dataset_id)

    # 1. API
    print("[main] trying Python API …")
    ok, final_id = try_api(LOCAL_DIR, dataset_id, VERSION_MSG)
    if ok:
        url = f"https://www.kaggle.com/datasets/{final_id}"
        print(f"[SUCCESS] API → {url}")
        return url

    # 2. CLI
    print("[main] API failed → CLI fallback …")
    ok, final_id = try_cli(LOCAL_DIR, dataset_id, VERSION_MSG)
    if ok:
        url = f"https://www.kaggle.com/datasets/{final_id}"
        print(f"[SUCCESS] CLI → {url}")
        return url

    raise RuntimeError("Both API and CLI failed")

# --------------------------------------------------------------
# 7. RUN
# --------------------------------------------------------------
if __name__ == "__main__":
    final_url = upload()
    # "<owner>/<slug>" is the last two path segments of the dataset URL.
    dataset_ref = "/".join(final_url.split("/")[-2:])
    print("\n=== FINAL URL ===")
    print(final_url)
    print("\nDownload with:")
    print(f"kaggle datasets download -d {dataset_ref}")