In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.models.multimodal import MultimodalPointDiffusionTransformer
from point_e.models.configs import MODEL_CONFIGS
from point_e.models.download import load_checkpoint


# Define a dataset for (image, text, point cloud) triplets
class MultimodalPointCloudDataset(Dataset):
    """Dataset of (image, caption, point cloud) triplets, one sample per sub-directory.

    Each immediate sub-directory of ``data_dir`` is one sample if it contains all
    of ``render.png``, ``caption.txt`` and ``points.npz``; sub-directories that
    are missing any of them (and plain files in ``data_dir``) are skipped.
    Samples are ordered by sub-directory name so indices are stable across runs.

    Decoding images and point clouds depends on your on-disk format, so it is
    delegated to the ``image_loader`` / ``point_cloud_loader`` callables (or to
    overrides of ``_load_image`` / ``_load_point_cloud`` in a subclass). Their
    return values should be batchable by the default ``DataLoader`` collate
    function (e.g. tensors of identical shape).

    Args:
        data_dir: Root directory containing one sub-directory per sample.
        transform: Optional callable applied to each loaded image.
        image_loader: Callable ``path -> image``; required before indexing.
        point_cloud_loader: Callable ``path -> point cloud``; required before indexing.
    """

    IMAGE_FILENAME = "render.png"
    TEXT_FILENAME = "caption.txt"
    POINT_CLOUD_FILENAME = "points.npz"

    def __init__(self, data_dir, transform=None, image_loader=None, point_cloud_loader=None):
        self.data_dir = data_dir
        self.transform = transform
        self.image_loader = image_loader
        self.point_cloud_loader = point_cloud_loader
        self.samples = []

        # os.listdir order is arbitrary; sort so index -> sample is deterministic.
        for item in sorted(os.listdir(data_dir)):
            item_dir = os.path.join(data_dir, item)
            if not os.path.isdir(item_dir):
                continue
            paths = {
                "image_path": os.path.join(item_dir, self.IMAGE_FILENAME),
                "text_path": os.path.join(item_dir, self.TEXT_FILENAME),
                "pc_path": os.path.join(item_dir, self.POINT_CLOUD_FILENAME),
            }
            # Only complete triplets are usable for training.
            if all(os.path.exists(path) for path in paths.values()):
                self.samples.append(paths)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{"image", "text", "point_cloud"}`` for sample ``idx``."""
        sample = self.samples[idx]

        image = self._load_image(sample["image_path"])
        if self.transform is not None:
            image = self.transform(image)

        with open(sample["text_path"], "r", encoding="utf-8") as f:
            text = f.read().strip()

        pc = self._load_point_cloud(sample["pc_path"])

        return {
            "image": image,
            "text": text,
            "point_cloud": pc,
        }

    def _load_image(self, path):
        """Decode the image at ``path`` with ``image_loader``."""
        if self.image_loader is None:
            raise NotImplementedError(
                "MultimodalPointCloudDataset needs an image_loader callable "
                f"(path -> image) to read {path!r}"
            )
        return self.image_loader(path)

    def _load_point_cloud(self, path):
        """Decode the point cloud at ``path`` with ``point_cloud_loader``."""
        if self.point_cloud_loader is None:
            raise NotImplementedError(
                "MultimodalPointCloudDataset needs a point_cloud_loader callable "
                f"(path -> point cloud) to read {path!r}"
            )
        return self.point_cloud_loader(path)


def _model_kwargs_from_config(model_cls, config):
    """Build constructor kwargs for ``model_cls`` from a Point-E model config dict.

    The ``"name"`` entry in Point-E model configs is not a constructor argument;
    passing it raises ``TypeError: ... unexpected keyword argument 'name'``, so it
    is always removed. If the constructor does not accept ``**kwargs``, any other
    keys it does not name are dropped as well (with a warning) rather than
    failing with the same kind of TypeError. The input dict is not modified.
    """
    import inspect
    import warnings

    kwargs = {key: value for key, value in config.items() if key != "name"}
    params = inspect.signature(model_cls.__init__).parameters.values()
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return kwargs

    accepted = {p.name for p in params}
    dropped = sorted(set(kwargs) - accepted)
    if dropped:
        warnings.warn(
            f"Ignoring config keys not accepted by {model_cls.__name__}: {dropped}"
        )
    return {key: value for key, value in kwargs.items() if key in accepted}


def main(
    data_dir="path/to/your/dataset",
    num_epochs=100,
    batch_size=4,
    lr=1e-4,
    weight_decay=1e-4,
    num_workers=4,
    save_every=10,
):
    """Fine-tune a multimodal Point-E ``base40M`` model on (image, text, point cloud) data.

    The model is built with ``frozen_transformer=True``, so only parameters left
    with ``requires_grad=True`` (presumably the fusion / cross-attention layers —
    confirm in ``MultimodalPointDiffusionTransformer``) are optimized.
    Checkpoints are written to the current working directory every
    ``save_every`` epochs, plus a final one at the end.

    Args:
        data_dir: Dataset root passed to ``MultimodalPointCloudDataset``.
        num_epochs: Number of passes over the dataset.
        batch_size: Examples per optimizer step.
        lr: AdamW learning rate, cosine-annealed to 0 over all training steps.
        weight_decay: AdamW weight decay.
        num_workers: DataLoader worker processes.
        save_every: Save an intermediate checkpoint every this many epochs;
            0 disables intermediate checkpoints.

    Raises:
        ValueError: If ``data_dir`` yields no usable samples.
    """
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create model. Config keys the constructor rejects (notably 'name') are stripped.
    model = MultimodalPointDiffusionTransformer(
        device=device,
        dtype=torch.float32,
        frozen_transformer=True,  # Only train fusion
        use_cross_attention=True,
        **_model_kwargs_from_config(
            MultimodalPointDiffusionTransformer, MODEL_CONFIGS['base40M']
        ),
    )
    diffusion = diffusion_from_config(DIFFUSION_CONFIGS['base40M'])

    # Load pre-trained weights. strict=False tolerates layers absent from the
    # checkpoint (presumably the new fusion layers), but report mismatches so a
    # wrong or partial load does not go unnoticed.
    load_result = model.load_state_dict(load_checkpoint('base40M', device), strict=False)
    if load_result.missing_keys:
        print(f"Model keys not found in checkpoint: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Checkpoint keys not used by model: {load_result.unexpected_keys}")

    # Setup dataset and dataloader
    dataset = MultimodalPointCloudDataset(
        data_dir=data_dir,
        transform=None  # Add necessary transforms
    )
    if len(dataset) == 0:
        raise ValueError(
            f"No complete (image, caption, point cloud) samples found in {data_dir!r}"
        )
    # NOTE: with num_workers > 0 on platforms that spawn workers (Windows/macOS),
    # a Dataset class defined inside a notebook may fail to pickle; use 0 there.
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    # Setup optimizer (only for trainable parameters)
    optimizer = optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=lr,
        weight_decay=weight_decay,
    )
    # The scheduler is stepped once per batch, so anneal over the real number of
    # optimizer steps; a mismatched T_max either never reaches the minimum or
    # makes the cosine rise again past T_max.
    total_steps = max(1, num_epochs * len(dataloader))
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

    # Training loop
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        with tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
            for batch in pbar:
                # Extract data; images may be non-tensors with a custom collate.
                images = batch["image"]
                if torch.is_tensor(images):
                    images = images.to(device)
                texts = batch["text"]
                point_clouds = batch["point_cloud"].to(device)

                # One random diffusion timestep per example in the batch.
                t = torch.randint(
                    0, diffusion.num_timesteps, (point_clouds.shape[0],), device=device
                )
                model_kwargs = {"images": images, "texts": texts}

                # Compute training loss
                losses = diffusion.training_losses(
                    model=model,
                    x_start=point_clouds,
                    t=t,
                    model_kwargs=model_kwargs,
                )
                loss = losses["loss"].mean()

                # Backprop and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                # Update progress bar (single .item() -> single device sync)
                loss_value = loss.item()
                total_loss += loss_value
                num_batches += 1
                pbar.set_postfix(loss=loss_value)

        print(f"Epoch {epoch+1}/{num_epochs}: mean loss {total_loss / max(num_batches, 1):.6f}")

        # Save checkpoint
        if save_every and (epoch + 1) % save_every == 0:
            torch.save(
                model.state_dict(),
                f"multimodal_point_e_epoch_{epoch+1}.pt"
            )

    # Save final model
    torch.save(model.state_dict(), "multimodal_point_e_final.pt")


if __name__ == "__main__":
    # Start training when run as a script. Inside a notebook cell __name__ is
    # also "__main__", so executing the cell starts training immediately.
    main()

TypeError: PointDiffusionTransformer.__init__() got an unexpected keyword argument 'name'