In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import json
import os
import glob


class ABODatasetSimplified(Dataset):
    """
    Simplified PyTorch Dataset pairing ABO captions, images and point clouds.

    Expected layout under ``data_root``:
        final_dataset.jsonl   -- one entry per line: a JSON object with a 'uid'
                                 key, or a bare JSON string used as the UID.
        captions/<uid>.txt    -- optional caption text.
        images/<uid>...       -- image file(s); see ``_find_image_path``.
        pointclouds/<uid>.npz -- arrays named by ``pc_keys``.

    Loading is best-effort: a missing or malformed component falls back to an
    empty caption or an all-zero tensor, and any exception raised while loading
    an item is printed rather than propagated.

    Args:
        data_root (str): Path to the root directory.
        n_ctx (int): Number of points each returned point cloud is resampled to.
        clip_preprocessor (callable): Maps a PIL RGB image to an image tensor;
            expected to produce shape [3, image_size, image_size].
        image_subdir (str): Optional subdirectory appended to each image path.
        pc_keys (tuple): (coordinates key, colors key) inside the NPZ files.
        image_size (int): Side length of the all-zero fallback image. It must
            match the preprocessor output so fallback and real images can be
            batched together by the default collate function. Defaults to 224.

    Raises:
        FileNotFoundError: If ``final_dataset.jsonl`` does not exist.
        ValueError: If the JSONL file yields no UIDs.
    """

    # File extensions accepted as images (compared case-insensitively).
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

    def __init__(self, data_root, n_ctx, clip_preprocessor, image_subdir='',
                 pc_keys=('coords', 'colors'), image_size=224):
        super().__init__()
        self.data_root = data_root
        self.n_ctx = n_ctx
        self.image_subdir = image_subdir
        self.pc_keys = pc_keys
        self.clip_preprocess = clip_preprocessor
        self.image_size = image_size

        self.captions_dir = os.path.join(data_root, 'captions')
        self.images_dir = os.path.join(data_root, 'images')
        self.pointclouds_dir = os.path.join(data_root, 'pointclouds')
        self.jsonl_path = os.path.join(data_root, 'final_dataset.jsonl')

        if not os.path.exists(self.jsonl_path):
            raise FileNotFoundError(f"Error: JSONL file not found at {self.jsonl_path}")

        self.uids = self._load_uids()
        if not self.uids:
            raise ValueError(f"No UIDs loaded from {self.jsonl_path}.")
        print(f"Found {len(self.uids)} UIDs in jsonl file.")

    def _load_uids(self):
        """Read UIDs from the JSONL index, skipping blank and malformed lines.

        Returns:
            list: UIDs in file order. Entries whose UID is missing or falsy are
            dropped.
        """
        uids = []
        skipped = 0
        # JSON text is UTF-8. Be explicit so the platform default (e.g. cp1252
        # on Windows) is not used; '-sig' also tolerates a leading BOM.
        with open(self.jsonl_path, 'r', encoding='utf-8-sig') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    skipped += 1
                    continue
                if isinstance(entry, dict):
                    uid = entry.get('uid')
                elif isinstance(entry, str):
                    uid = entry
                else:
                    uid = None
                if uid:
                    uids.append(uid)
        if skipped:
            print(f"Warning: skipped {skipped} malformed line(s) in {self.jsonl_path}.")
        return uids

    def __len__(self):
        """Return the number of UIDs (samples) in the dataset."""
        return len(self.uids)

    def __getitem__(self, idx):
        """Load sample ``idx`` as a dict with 'caption', 'image' and 'pointcloud'.

        'image' is whatever ``clip_preprocessor`` returns, or zeros of shape
        [3, image_size, image_size] if no image was loaded. 'pointcloud' is a
        float tensor of shape [6, n_ctx] (xyz rows then color rows), or zeros if
        no valid point cloud was loaded.

        Components are loaded in the order caption -> image -> point cloud. If
        one raises, the error is printed, and that component and every later one
        keep their defaults.
        """
        uid = self.uids[idx]
        caption = ""
        image_tensor = torch.zeros((3, self.image_size, self.image_size))
        pointcloud_tensor = torch.zeros((6, self.n_ctx))

        try:
            caption = self._load_caption(uid)

            loaded_image = self._load_image(uid)
            if loaded_image is not None:
                image_tensor = loaded_image

            loaded_pointcloud = self._load_pointcloud(uid)
            if loaded_pointcloud is not None:
                pointcloud_tensor = loaded_pointcloud
        except Exception as e:
            # Best-effort: keep the DataLoader running by returning defaults.
            # WARNING: frequent errors mean training on placeholder data.
            print(f"Warning: Error processing item for UID {uid}. Returning default tensors. Error: {e}")

        return {
            'caption': caption,
            'image': image_tensor,
            'pointcloud': pointcloud_tensor
        }

    def _load_caption(self, uid):
        """Return the stripped caption for ``uid``, or "" when no caption file exists."""
        caption_path = os.path.join(self.captions_dir, f"{uid}.txt")
        if not os.path.exists(caption_path):
            return ""
        with open(caption_path, 'r', encoding='utf-8') as f:
            return f.read().strip()

    def _find_image_path(self, uid):
        """Locate an image file for ``uid``.

        With ``base = images/<uid>[/<image_subdir>]``, files matching
        ``<base>.*`` are tried first, then files matching ``<base>/*.*``.
        Within each group, candidates are sorted so the choice is deterministic.
        The first candidate with an accepted image extension wins.

        Returns:
            str or None: Path of the chosen image, or None if none was found.
        """
        base = os.path.join(self.images_dir, str(uid))
        if self.image_subdir:
            base = os.path.join(base, self.image_subdir)
        # Escape glob metacharacters ('[', '*', '?') that may occur in the path.
        pattern_base = glob.escape(base)
        for pattern in (pattern_base + '.*', os.path.join(pattern_base, '*.*')):
            for path in sorted(glob.glob(pattern)):
                if path.lower().endswith(self._IMAGE_EXTENSIONS):
                    return path
        return None

    def _load_image(self, uid):
        """Return the preprocessed image for ``uid``, or None if no image file exists."""
        image_path = self._find_image_path(uid)
        if image_path is None:
            return None
        # Close the file handle promptly; convert() returns an independent copy.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        return self.clip_preprocess(image)

    def _load_pointcloud(self, uid):
        """Load, validate and resample the point cloud for ``uid``.

        Coordinates and colors must both be 2-D with 3 columns, have equal,
        non-zero row counts, and be stored under ``pc_keys``. Integer colors, or
        any colors with a maximum above 1, are treated as 0-255 and divided by 255.
        Points are sampled without replacement down to ``n_ctx``, or with
        replacement up to it.

        Returns:
            torch.Tensor or None: Float tensor of shape [6, n_ctx], or None if
            the file is missing or its contents fail validation.
        """
        pc_path = os.path.join(self.pointclouds_dir, f"{uid}.npz")
        if not os.path.exists(pc_path):
            return None

        coord_key, color_key = self.pc_keys[0], self.pc_keys[1]
        # NpzFile keeps the archive open until closed; the context manager closes it.
        with np.load(pc_path) as data:
            if coord_key not in data or color_key not in data:
                return None
            coords = data[coord_key]
            colors = data[color_key]

        shapes_ok = (
            coords.ndim == 2 and coords.shape[1] == 3
            and colors.ndim == 2 and colors.shape[1] == 3
            and coords.shape[0] == colors.shape[0] and coords.shape[0] > 0
        )
        if not shapes_ok:
            return None

        # Integer colors are 0-255 even when every value happens to be <= 1.
        if np.issubdtype(colors.dtype, np.integer) or colors.max() > 1.0:
            colors = colors / 255.0

        pointcloud = np.concatenate([coords.astype(np.float32),
                                     colors.astype(np.float32)], axis=1)

        num_points = pointcloud.shape[0]
        if num_points == self.n_ctx:
            indices = np.arange(num_points)
        else:
            # Subsample without replacement, or pad by resampling with replacement.
            indices = np.random.choice(num_points, self.n_ctx, replace=num_points < self.n_ctx)

        return torch.from_numpy(pointcloud[indices, :]).float().transpose(0, 1)  # [6, n_ctx]

In [2]:
# # --- Example Usage / Test Block ---
# if __name__ == '__main__':
#     print("\n--- Testing ABODatasetSimplified ---")

#     # --- Configuration ---
#     # !!! IMPORTANT: Set this to the correct path to your dataset !!!
#     DATASET_ROOT = r'C:\Users\lvbab\OneDrive\Documents\GitHub\point-e\point_e\abo_integrated' # Use raw string for Windows paths
#     TARGET_POINTS = 1024 # Or 4096, match your model's n_ctx
#     BATCH_SIZE = 4
#     NUM_WORKERS = 0 # Set > 0 for parallel loading if needed (use 0 for basic testing)

#     # --- Get CLIP Preprocessor ---
#     # You MUST provide a valid CLIP preprocessor function here.
#     # Example: Load it using the 'clip' library (install if needed)
#     clip_preprocess_func = None
#     try:
#         import clip
#         # Load on CPU for this test setup phase
#         clip_model, clip_preprocess_func = clip.load("ViT-L/14", device="cpu")
#         print("Loaded CLIP preprocessor.")
#     except ImportError:
#         print("Error: 'clip' library not found. Cannot get preprocessor.")
#         print("Please install it: pip install git+https://github.com/openai/CLIP.git")
#         # Exit or raise error if preprocessor is mandatory
#         exit()
#     except Exception as e:
#         print(f"Error loading CLIP: {e}")
#         exit()

#     if clip_preprocess_func is None:
#         print("CLIP preprocessor could not be loaded. Exiting.")
#         exit()

#     # --- Instantiate Dataset ---
#     print(f"Initializing dataset from: {DATASET_ROOT}")
#     try:
#         # Remember to adjust image_subdir and pc_keys if needed
#         abo_dataset = ABODatasetSimplified(
#             data_root=DATASET_ROOT,
#             n_ctx=TARGET_POINTS,
#             clip_preprocessor=clip_preprocess_func,
#             image_subdir='',
#             pc_keys=('coords', 'rgb')
#         )

#         # --- Create DataLoader ---
#         if len(abo_dataset) > 0:
#             data_loader = DataLoader(
#                 dataset=abo_dataset,
#                 batch_size=BATCH_SIZE,
#                 shuffle=True,
#                 num_workers=NUM_WORKERS
#             )
#             print(f"Created DataLoader with {len(abo_dataset)} samples.")

#             # --- Load and Inspect One Batch ---
#             print("\nLoading one batch...")
#             try:
#                 first_batch = next(iter(data_loader)) # Get the first batch

#                 print("\n--- Batch Data Inspection ---")
#                 # Check if keys exist before accessing
#                 if 'caption' in first_batch:
#                     print(f"Captions in batch (first one): '{first_batch['caption'][0]}'")
#                 if 'image' in first_batch:
#                     print(f"Image tensor shape: {first_batch['image'].shape}")
#                     print(f"Image tensor dtype: {first_batch['image'].dtype}")
#                 if 'pointcloud' in first_batch:
#                     print(f"Pointcloud tensor shape: {first_batch['pointcloud'].shape}")
#                     print(f"Pointcloud tensor dtype: {first_batch['pointcloud'].dtype}")
#                 print("-----------------------------")
#                 print("DataLoader test successful!")

#             except StopIteration:
#                  print("DataLoader is empty. Check dataset initialization.")
#             except Exception as e:
#                  print(f"Error loading batch from DataLoader: {e}")
#                  import traceback
#                  traceback.print_exc()
#         else:
#             print("Dataset is empty. Cannot create or test DataLoader.")

#     except FileNotFoundError as e:
#         print(f"Error: {e}") # Specifically catch file not found for jsonl
#     except Exception as e:
#         print(f"Error initializing ABODatasetSimplified: {e}")
#         import traceback
#         traceback.print_exc()



In [3]:
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
from point_e.models.fusion import TextImageFusionModule
from point_e.models.multimodal import SimpleMultimodalTransformer
from point_e.models.configs import MODEL_CONFIGS
from point_e.models.download import load_checkpoint
import traceback
from point_e.diffusion.sampler import PointCloudSampler
from point_e.diffusion.configs import DIFFUSION_CONFIGS
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.models.download import load_checkpoint
from point_e.util.plotting import plot_point_cloud
from torch.optim import AdamW
import torch.nn as nn
from point_e.diffusion.gaussian_diffusion import GaussianDiffusion
from torchvision import transforms


# Optional dependency: OpenAI's CLIP package supplies the image preprocessor.
# If it cannot be imported, `clip` stays bound to None so later code can test for it.
clip = None
try:
    import clip
except ImportError:
    print("Warning: 'clip' library not found. CLIP preprocessing might fail.")
    print("Install using: pip install git+https://github.com/openai/CLIP.git")

# == Device ==
# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# == Model Initialization ==
print("Initializing model...")
# Passed to the model as `use_cross_attention`; set to True or False as preferred.
use_cross_attn = False
# NOTE(review): on CUDA the model is built in float16, while the training loop
# steps a plain AdamW without autocast/GradScaler. Confirm the trainable
# parameters are kept in float32, otherwise fp16 updates may be unstable.
model = SimpleMultimodalTransformer(
    device=device,
    dtype=torch.float16 if device.type != 'cpu' else torch.float32,
    cache_dir="./point_e_cache_train",
    use_cross_attention=use_cross_attn,
)
model.to(device)
print("Model initialized.")

# == Data ==
print("Initializing dataset and dataloader...")
DATASET_ROOT = r'C:\Users\lvbab\OneDrive\Documents\GitHub\point-e\point_e\abo_integrated'
TARGET_POINTS = 1024
BATCH_SIZE = 16  # Adjust based on GPU memory
NUM_WORKERS = 0  # Windows recommendation

# Setup failures below raise instead of calling exit(): in a notebook exit()
# shuts the kernel down, and in a plain script it ends the process with exit
# status 0, as though the run had succeeded.

# Get CLIP Preprocessor (only the preprocessing transform is kept; loaded on CPU).
if clip is None:
    raise RuntimeError("Error: 'clip' library not available. Install using: "
                       "pip install git+https://github.com/openai/CLIP.git")
# Use the CLIP variant the model reports, if any; otherwise ViT-L/14.
clip_name_to_load = getattr(model, 'clip_model_name', "ViT-L/14")
try:
    _, clip_preprocess_func = clip.load(clip_name_to_load, device="cpu", jit=False)
except Exception as e:
    raise RuntimeError(f"Error loading CLIP preprocessor: {e}") from e
print(f"Loaded CLIP preprocessor for {clip_name_to_load}.")

# Create Dataset and DataLoader
try:
    dataset = ABODatasetSimplified(
        data_root=DATASET_ROOT,
        n_ctx=TARGET_POINTS,
        clip_preprocessor=clip_preprocess_func,
        pc_keys=('coords', 'colors')
    )
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        # Compare the device type so indexed devices like 'cuda:0' also pin memory.
        pin_memory=(device.type == 'cuda')
    )
except Exception as e:
    # Chaining keeps the original traceback attached to the raised error.
    raise RuntimeError(f"Error creating dataset/dataloader: {e}") from e
print(f"DataLoader created with {len(dataset)} samples.")

# == Optimizer ==
# Only parameters that still require gradients (i.e. are not frozen) are optimized.
trainable_params = list(filter(lambda param: param.requires_grad, model.parameters()))
print(f"Number of trainable parameters: {sum(map(torch.numel, trainable_params))}")
if len(trainable_params) == 0:
    # NOTE(review): torch optimizers reject an empty parameter list, so this
    # warning is followed by an error from the AdamW constructor below.
    print("Warning: No trainable parameters found.")
learning_rate = 1e-4
optimizer = AdamW(trainable_params, lr=learning_rate)

# == Loss Function ==
# Mean-squared error with the default 'mean' reduction.
loss_fn = nn.MSELoss()

# == Noise Schedule ==
# Builds a GaussianDiffusion with a hand-made linear beta schedule and exposes
# its q_sample (forward-noising) function. Any failure raises RuntimeError, so
# after this cell both `noise_scheduler` and `q_sample_fn` are set.
print("Setting up noise schedule...")
noise_scheduler = None
q_sample_fn = None
try:
    diffusion_steps = 1024
    diffusion_config_name = 'base40M'
    diffusion_kwargs = DIFFUSION_CONFIGS.get(diffusion_config_name, {})
    # NOTE(review): point_e's config dicts may spell these keys differently
    # (e.g. 'mean_type', 'schedule', 'timesteps'). When a key is absent, the
    # default on the right is used silently. Verify against
    # point_e/diffusion/configs.py.
    model_var_type = diffusion_kwargs.get('model_var_type', 'fixed_large')
    model_mean_type = diffusion_kwargs.get('model_mean_type', 'epsilon')
    loss_type = diffusion_kwargs.get('loss_type', 'mse')
    configured_schedule = diffusion_kwargs.get('beta_schedule', diffusion_kwargs.get('schedule', 'linear'))
    configured_steps = diffusion_kwargs.get('timesteps', diffusion_steps)

    # The betas below are always a linear schedule over `diffusion_steps`.
    # Report that, rather than the config value, and warn when the config asks
    # for something else.
    beta_schedule_name = 'linear'
    if configured_schedule != beta_schedule_name or configured_steps != diffusion_steps:
        print(f"Warning: config '{diffusion_config_name}' specifies schedule={configured_schedule}, "
              f"steps={configured_steps}; using a linear schedule with {diffusion_steps} steps instead.")
    print(f"Using diffusion parameters: var_type={model_var_type}, mean_type={model_mean_type}, loss_type={loss_type}, schedule={beta_schedule_name}")

    # Linear beta schedule defined manually
    beta_start = 0.0001
    beta_end = 0.02
    betas = np.linspace(beta_start, beta_end, diffusion_steps)

    init_args = {
        'betas': betas,
        'model_var_type': model_var_type,
        'model_mean_type': model_mean_type,
        'loss_type': loss_type,
    }
    noise_scheduler = GaussianDiffusion(**init_args)
    print(f"Using GaussianDiffusion with T={noise_scheduler.num_timesteps}")
    q_sample_fn = noise_scheduler.q_sample
except Exception as e:
    # Chaining preserves the underlying traceback for debugging.
    raise RuntimeError(f"Noise scheduler setup failed: {e}") from e

# == Training Parameters ==
# Train for `epochs` epochs, writing a checkpoint every `save_interval` epochs
# (and after the final epoch) into `output_dir`, which is created if needed.
epochs, save_interval = 100, 10
output_dir = "./checkpoints_multimodal"
os.makedirs(output_dir, exist_ok=True)

# == Training Loop ==
# Standard epsilon-prediction diffusion training. Each step picks random
# timesteps, noises x_0 with q_sample, predicts the noise conditioned on
# images and captions, and minimizes MSE against the true noise.
print("\n--- Starting Training ---")
model.train()
for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    total_loss = 0.0
    steps_in_epoch = 0
    avg_loss = float('inf')  # stays inf if no batch succeeds this epoch
    progress_bar = tqdm(enumerate(data_loader), total=len(data_loader), desc=f"Epoch {epoch+1}")

    for step, batch in progress_bar:
        try:
            captions = batch['caption']
            images = batch['image'].to(device, dtype=model.dtype)
            x_0 = batch['pointcloud'].to(device, dtype=model.dtype)

            # The dataset substitutes all-zero tensors for samples it failed to
            # load. Drop such samples individually: testing only whether the
            # whole batch is zero lets partially failed batches train on
            # placeholder data.
            valid = (x_0 != 0).flatten(1).any(dim=1) & (images != 0).flatten(1).any(dim=1)
            if not bool(valid.any()):
                continue
            if not bool(valid.all()):
                x_0 = x_0[valid]
                images = images[valid]
                keep = valid.tolist()
                captions = [caption for caption, kept in zip(captions, keep) if kept]

            current_batch_size = x_0.shape[0]
            t = torch.randint(0, noise_scheduler.num_timesteps, (current_batch_size,), device=device).long()
            epsilon = torch.randn_like(x_0)
            x_t = q_sample_fn(x_start=x_0, t=t, noise=epsilon)

            predicted_epsilon = model(x=x_t, t=t, images=images, texts=captions)
            loss = loss_fn(predicted_epsilon, epsilon)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step_loss = loss.item()  # single device sync per step
            total_loss += step_loss
            steps_in_epoch += 1
            progress_bar.set_postfix(loss=step_loss)
        except Exception as e:
            # Best-effort: report the failing step and continue with the next batch.
            print(f"\nError during training step {step} for epoch {epoch+1}: {e}")
            traceback.print_exc()
            continue

    if steps_in_epoch > 0:
        avg_loss = total_loss / steps_in_epoch
        print(f"Epoch {epoch+1} finished. Average Loss: {avg_loss:.4f} ({steps_in_epoch}/{len(data_loader)} steps)")
    else:
        print(f"Epoch {epoch+1} finished. No batches processed successfully.")

    if (epoch + 1) % save_interval == 0 or epoch == epochs - 1:
        save_path = os.path.join(output_dir, f"model_epoch_{epoch+1}.pt")
        save_dict = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
            'use_cross_attention': use_cross_attn
        }
        torch.save(save_dict, save_path)
        print(f"Model checkpoint saved to {save_path}")

print("\n--- Training Finished ---")

Using device: cpu
Initializing model...
Model config expects input_channels: 6
Initializing with simple Linear fusion
Model initialized.
Initializing dataset and dataloader...
Loaded CLIP preprocessor for ViT-L/14.
Found 7890 UIDs in jsonl file.
DataLoader created with 7890 samples.
Number of trainable parameters: 1444864
Setting up noise schedule...
Using diffusion parameters: var_type=fixed_large, mean_type=epsilon, loss_type=mse, schedule=linear
Using GaussianDiffusion with T=1024

--- Starting Training ---

Epoch 1/100


Epoch 1:   0%|          | 1/494 [00:38<5:12:30, 38.03s/it, loss=1.01]


KeyboardInterrupt: 

In [None]:
# import argparse
# import os
# import torch
# import torch.nn.functional as F
# from torch.utils.data import DataLoader

# from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
# from point_e.models.multimodal import SimpleMultimodalTransformer  # assumes multimodal.py is in PYTHONPATH

# # TODO: replace this with your real dataset class or import
# class MultimodalPointCloudDataset(torch.utils.data.Dataset):
#     """Simple placeholder dataset.
#     Expects each sample as a dict with keys:
#         'points'  : torch.Tensor [C, N]  (original clean point cloud)
#         'images'  : List[PIL.Image] or a single PIL.Image for conditioning
#         'captions': str (text prompt)
#     """
#     def __init__(self, root: str, split: str = "train"):
#         super().__init__()
#         # load your metadata here
#         self.meta = []  # list of file paths / annotations
#         # ...

#     def __len__(self):
#         return len(self.meta)

#     def __getitem__(self, idx):
#         # load point cloud, image and caption
#         raise NotImplementedError("Implement your dataset loading logic here")


# def collate_fn(batch):
#     """Keeps images & captions as Python lists so CLIP can handle them."""
#     points = torch.stack([b["points"] for b in batch], dim=0)  # [B, C, N]
#     images = [b["images"] for b in batch]
#     captions = [b["captions"] for b in batch]
#     return {"points": points, "images": images, "captions": captions}


# def parse_args():
#     p = argparse.ArgumentParser("Train multimodal fusion (Point‑E)")
#     p.add_argument("--data", required=True, help="Dataset root directory")
#     p.add_argument("--epochs", type=int, default=50)
#     p.add_argument("--batch_size", type=int, default=8)
#     p.add_argument("--lr", type=float, default=1e-4)
#     p.add_argument("--outdir", default="checkpoints")
#     p.add_argument("--cache_dir", default=None, help="CLIP/Point‑E cache")
#     p.add_argument("--use_cross_attention", action="store_true")
#     p.add_argument("--mixed_precision", action="store_true")
#     return p.parse_args()


# def main():
#     args = parse_args()

#     # Device & dtype --------------------------------------------------------
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     dtype = torch.float16 if args.mixed_precision else torch.float32

#     # Data ------------------------------------------------------------------
#     train_ds = MultimodalPointCloudDataset(args.data, split="train")
#     train_loader = DataLoader(
#         train_ds,
#         batch_size=args.batch_size,
#         shuffle=True,
#         num_workers=4,
#         pin_memory=True,
#         drop_last=True,
#         collate_fn=collate_fn,
#     )

#     # Diffusion process -----------------------------------------------------
#     diffusion = diffusion_from_config(DIFFUSION_CONFIGS["base"],)  # base schedule
#     diffusion = diffusion.to(device)

#     # Model -----------------------------------------------------------------
#     model = SimpleMultimodalTransformer(
#         device=device,
#         dtype=dtype,
#         cache_dir=args.cache_dir,
#         use_cross_attention=args.use_cross_attention,
#     ).to(device)

#     # Optimizer -------------------------------------------------------------
#     optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
#     scaler = torch.cuda.amp.GradScaler(enabled=args.mixed_precision)

#     os.makedirs(args.outdir, exist_ok=True)
#     global_step = 0

#     # Training loop ---------------------------------------------------------
#     for epoch in range(1, args.epochs + 1):
#         model.train()
#         for batch in train_loader:
#             pc0 = batch["points"].to(device)         # x_0
#             images = batch["images"]                 # list of PIL.Image
#             captions = batch["captions"]             # list[str]

#             bs = pc0.size(0)
#             t = torch.randint(0, diffusion.num_timesteps, (bs,), device=device, dtype=torch.long)
#             noise = torch.randn_like(pc0)
#             x_t = diffusion.q_sample(pc0, t, noise=noise)

#             with torch.cuda.amp.autocast(enabled=args.mixed_precision):
#                 pred_noise = model(x_t, t, images=images, texts=captions)
#                 loss = F.mse_loss(pred_noise, noise)

#             scaler.scale(loss).backward()
#             scaler.unscale_(optim)
#             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#             scaler.step(optim)
#             scaler.update()
#             optim.zero_grad(set_to_none=True)

#             if global_step % 100 == 0:
#                 print(f"Epoch {epoch} | Step {global_step} | Loss {loss.item():.4f}")
#             global_step += 1

#         # Save checkpoint ----------------------------------------------------
#         ckpt_path = os.path.join(args.outdir, f"epoch_{epoch:03d}.pt")
#         torch.save({
#             "model": model.state_dict(),
#             "optimizer": optim.state_dict(),
#             "epoch": epoch,
#             "global_step": global_step,
#         }, ckpt_path)
#         print(f"Saved checkpoint to {ckpt_path}")

#     print("Training complete.")


# if __name__ == "__main__":
#     main()
