In [1]:
!pip install einops timm
import sys, os
!git clone https://github.com/Sakib323/AI-Game-Engine.git
sys.path.append('/workspace/AI-Game-Engine') 
from mmfreelm.models.hgrn_bit.mesh_dit import MeshDiT_models

[0mfatal: destination path 'AI-Game-Engine' already exists and is not an empty directory.


In [None]:
import os
import gc
import glob
import bisect
import torch
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer, 
    Trainer, 
    TrainingArguments, 
    default_data_collator
)
from huggingface_hub import snapshot_download
from tqdm import tqdm
import wandb
import torch.nn.functional as F  # <--- THIS WAS MISSING
from mmfreelm.models.hgrn_bit.video_gen import VideoDiT_models, flow_matching_loss

# --- Setup ---
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.cuda.empty_cache()
gc.collect()

MODEL_SAVE_DIR = "/kaggle/working/checkpoints"
TOKENIZER_NAME = "Sakib323/MMfreeLM-370M"
HF_DATASET_ID = "Sakib323/panda-70m-latents"

# --- CONFIG ---
BATCH_SIZE = 3
GRADIENT_ACCUMULATION_STEPS = 4
LEARNING_RATE = 5e-5
NUM_EPOCHS = 2
NUM_WORKERS = 2
INPUT_SIZE = (16, 72, 128)
PATCH_SIZE = (1, 2, 2)

# --- Credentials ---
# SECURITY: never commit API keys to a notebook. The key that used to be hard-coded here
# must be considered leaked and should be revoked. Credentials are now read from the
# environment (e.g. Kaggle/Colab secrets exported as env vars).
WANDB_TOKEN = os.environ.get("WANDB_API_KEY")  # None -> wandb falls back to netrc / prompt
HF_TOKEN = os.environ.get("HF_TOKEN")          # None -> cached HF login or anonymous access
wandb.login(key=WANDB_TOKEN)

device = "cuda" if torch.cuda.is_available() else "cpu"
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

# --- 1. Download Data from Hub ---
print(f"📥 Downloading {HF_DATASET_ID} to local cache...")
# This downloads the .pt files to a local folder managed by HF
local_data_dir = snapshot_download(
    repo_id=HF_DATASET_ID,
    repo_type="dataset",
    token=HF_TOKEN,  # Hub token, deliberately separate from the WandB key
    allow_patterns=["*.pt"],  # Only download the data files
)
print(f"✅ Data downloaded to: {local_data_dir}")

# --- 2. Robust Dataset Class (Fixed for variable frame counts) ---
class VideoLatentDataset(Dataset):
    """Map-style dataset over pre-encoded video latents stored in ``*.pt`` files.

    Each file in ``data_dir`` holds either a single sample dict or a list of sample
    dicts. A sample dict must provide the keys ``"video_latents"``, ``"input_ids"``
    and ``"attention_mask"``. ``"video_latents"`` is expected to be laid out as
    (T, C, H, W); every returned clip is truncated or padded to ``target_frames``
    along the first axis.

    Args:
        data_dir: Directory that directly contains the ``*.pt`` files (not recursive).
        target_frames: Number of frames every returned clip is forced to.
        scale_factor: Multiplier applied to the latents before they are returned.
    """

    def __init__(self, data_dir, target_frames=16, scale_factor=0.18215):
        self.files = sorted(glob.glob(os.path.join(data_dir, "*.pt")))
        self.scale_factor = scale_factor
        self.target_frames = target_frames

        # Index map: file_starts[i] is the global index of the first sample in
        # file_map[i]; both lists stay aligned and only contain loadable files.
        self.file_map = []
        self.file_starts = []
        self.total_samples = 0

        print(f"Scanning {len(self.files)} files to build index map...")
        for f_path in tqdm(self.files):
            try:
                data = torch.load(f_path, map_location="cpu")
            except Exception as e:
                # Best effort: an unreadable file is reported and left out of the index.
                print(f"Skipping broken file {f_path}: {e}")
                continue

            is_list = isinstance(data, list)
            count = len(data) if is_list else 1
            del data  # only the sample count is needed here

            self.file_starts.append(self.total_samples)
            self.file_map.append({"path": f_path, "is_list": is_list})
            self.total_samples += count

        print(f"Total samples found: {self.total_samples}")

    def __len__(self):
        """Return the number of samples across all indexed files."""
        return self.total_samples

    def _fit_frames(self, latents):
        """Truncate or pad ``latents`` along dim 0 to exactly ``self.target_frames``."""
        current_frames = latents.shape[0]
        if current_frames == 0:
            # latents[-1] would raise IndexError, which sequence-style iteration
            # interprets as "end of data"; fail loudly instead.
            raise ValueError("Cannot pad a clip with zero frames")

        if current_frames > self.target_frames:
            # Too long: keep the leading frames.
            return latents[:self.target_frames]
        if current_frames < self.target_frames:
            # Too short: repeat the last frame until the target length is reached.
            missing = self.target_frames - current_frames
            tail = latents[-1].unsqueeze(0).repeat(missing, 1, 1, 1)
            return torch.cat([latents, tail], dim=0)
        return latents

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Global sample index; negative values count from the end.

        Returns:
            Dict with ``"latents"`` (frame axis moved to position 1 and scaled by
            ``scale_factor``), ``"input_ids"`` and ``"attention_mask"`` (both with
            their leading axis squeezed if it has size 1).

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
            ValueError: If the stored clip has zero frames.
        """
        # Normalise and validate the index. Without this check an out-of-range index
        # on a single-sample file silently returned that file's sample, and negative
        # indices resolved to the wrong entry.
        if idx < 0:
            idx = idx + self.total_samples
        if not 0 <= idx < self.total_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset with {self.total_samples} samples"
            )

        # Right-most file whose first global index is <= idx.
        file_idx = bisect.bisect_right(self.file_starts, idx) - 1
        path_info = self.file_map[file_idx]

        # NOTE(review): the whole file is re-read for every sample; if loading becomes
        # the bottleneck, consider caching or memory-mapping the chunk files.
        item_data = torch.load(path_info["path"], map_location="cpu")

        if path_info["is_list"]:
            item = item_data[idx - self.file_starts[file_idx]]
        else:
            item = item_data

        # Stored layout is expected to be (T, C, H, W), e.g. (15, 4, 72, 128).
        latents = self._fit_frames(item["video_latents"].float())

        # Swap the first two axes -> (C, T, H, W) for the model, then rescale.
        latents = latents.permute(1, 0, 2, 3) * self.scale_factor

        return {
            "latents": latents,
            "input_ids": item["input_ids"].squeeze(0),
            "attention_mask": item["attention_mask"].squeeze(0)
        }


class FlowMatchingTrainer(Trainer):
    """Trainer whose loss is an MSE between predicted and target flow velocity.

    For each batch a noise sample ``x_0`` and a per-sample time ``t`` in [0, 1) are
    drawn, the model is evaluated at ``x_t = t * x_1 + (1 - t) * x_0`` and trained to
    predict ``x_1 - x_0``, where ``x_1`` is the batch's ``"latents"``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the flow-matching loss for one batch.

        Args:
            model: Callable invoked as ``model(x_t, t, cond)``.
            inputs: Batch dict with ``"latents"``, ``"input_ids"`` and
                ``"attention_mask"``. ``"latents"`` must be 5-D (the time value is
                broadcast over four trailing axes).
            return_outputs: If True, also return the model prediction.
            num_items_in_batch: Accepted for Trainer API compatibility; unused.

        Returns:
            The scalar loss, or ``(loss, {"v_pred": v_pred})`` when ``return_outputs``
            is True.
        """
        import warnings  # local import: keeps this block self-contained

        x_1 = inputs["latents"]
        cond_y = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"]
        }

        b = x_1.shape[0]
        device = x_1.device

        # Linear interpolation between noise (t=0) and data (t=1).
        x_0 = torch.randn_like(x_1)
        t_step = torch.rand(b, device=device)
        t_expand = t_step.view(b, 1, 1, 1, 1)
        x_t = t_expand * x_1 + (1 - t_expand) * x_0

        v_target = x_1 - x_0
        v_pred = model(x_t, t_step, cond_y)
        loss = F.mse_loss(v_pred, v_target)

        if not torch.isfinite(loss):
            # Best effort: substitute a zero loss that is detached from the model so this
            # batch contributes no gradient. Warn so divergence is not hidden silently.
            warnings.warn("Non-finite flow-matching loss encountered; batch skipped.")
            loss = torch.tensor(0.0, device=device, requires_grad=True)

        if return_outputs:
            # Trainer.prediction_step treats a non-dict `outputs` as a tuple whose first
            # element is the loss and takes `outputs[1:]`; with a bare tensor that slice
            # would silently drop the first sample of the batch. A dict avoids this.
            return loss, {"v_pred": v_pred}
        return loss

# --- 4. Execution ---
print("Initializing Dataset...")
# Point the dataset to the downloaded snapshot folder
full_dataset = VideoLatentDataset(local_data_dir)

if len(full_dataset) == 0:
    raise ValueError("Dataset is empty! Check your Hugging Face repo contains .pt files.")

# 90/10 train/eval split; the fixed generator seed keeps the split reproducible.
train_size = int(0.9 * len(full_dataset))
eval_size = len(full_dataset) - train_size
generator = torch.Generator().manual_seed(42)
train_dataset, eval_dataset = torch.utils.data.random_split(full_dataset, [train_size, eval_size], generator=generator)

print(f"Train: {len(train_dataset)} | Eval: {len(eval_dataset)}")

print("Loading Tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

# Must be set before the Trainer initialises its WandB reporting.
os.environ["WANDB_PROJECT"] = "video-dit-3d-generation"

print("Initializing VideoDiT Model...")
model = VideoDiT_models['VideoDiT-S'](
    input_size=INPUT_SIZE,
    patch_size=PATCH_SIZE,
    in_channels=4,
    vocab_size=tokenizer.vocab_size,
    use_rope=True,
    use_ternary_rope=True,
    first_frame_condition=False,
    full_precision=True,
    optimized_bitlinear=False,
    use_temporal=False,
    use_grid=False,
    use_resampling=False,
)

training_args = TrainingArguments(
    output_dir=MODEL_SAVE_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    warmup_ratio=0.1,
    max_grad_norm=1.0,
    lr_scheduler_type="cosine",
    weight_decay=0.01,
    # fp16 mixed precision needs a GPU; requesting it unconditionally makes
    # TrainingArguments raise on a CPU-only machine.
    fp16=torch.cuda.is_available(),
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    dataloader_num_workers=NUM_WORKERS,
    dataloader_pin_memory=True,
    save_total_limit=2,
    logging_steps=10,
    report_to="wandb",
    run_name="VideoDiT-S-HF-Load",
    # The custom compute_loss reads raw batch keys, so no column may be pruned.
    remove_unused_columns=False,
    # Lets the Trainer treat "latents" as the label so an eval loss is computed.
    label_names=["latents"],
)

trainer = FlowMatchingTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)

print("Starting Stable Training...")
trainer.train()
trainer.save_model(os.path.join(MODEL_SAVE_DIR, "final_model"))
print("Training complete.")

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msakibahmed2018go[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


ðŸ“¥ Downloading Sakib323/panda-70m-latents to local cache...


Fetching 28 files:   0%|          | 0/28 [00:00<?, ?it/s]

âœ… Data downloaded to: /root/.cache/huggingface/hub/datasets--Sakib323--panda-70m-latents/snapshots/50f4d6d292a176fee09d631edded256e01f6c0e4
Initializing Dataset...
Scanning 28 files to build index map...


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 28/28 [00:08<00:00,  3.39it/s]


Total samples found: 11597
Train: 10437 | Eval: 1160
Loading Tokenizer...
Initializing VideoDiT Model...
Initializing RotaryEmbedding with theta=10000.0 and ternary=True

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=True

Initializing RotaryEmbedding with theta=10000.0 and ternary=True

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=True

Initializing RotaryEmbedding with theta=10000.0 and ternary=True

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=True

Initializing RotaryEmbedding with theta=10000.0 and ternary=True

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=True

Initializing RotaryEmbedding with theta=10000.0 and ternary=True

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.0, ternary=True

Initializing RotaryEmbedding with theta=10000.0 and ternary=True

[RotaryEmbedding] Initialized with: dim=64, max_pos=2048, base=10000.

Step,Training Loss,Validation Loss
