# WaveMesh-Diff - Google Colab Quick Start

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/HoangNguyennnnnnn/WaveMeshDf/blob/main/colab_quickstart.ipynb)

**3D Mesh Generation using Diffusion Models in Wavelet Domain**

---

## ⚡ Quick Overview

This notebook demonstrates:
1. ✅ Set up WaveMesh-Diff in Google Colab
2. 🧪 Test all 4 modules (Wavelet, U-Net, Diffusion, Multi-view)
3. 📊 Visualize the sparse wavelet representation
4. 🎨 Run quick demos

**Estimated time: 10-15 minutes**

---

## 🚀 Setup

### 1. Clone Repository

In [None]:
!git clone https://github.com/HoangNguyennnnnnn/WaveMeshDf.git
%cd WaveMeshDf

### 2. Install Dependencies

In [None]:
# Install core dependencies
!pip install -q PyWavelets trimesh matplotlib rtree scipy

# PyTorch is usually preinstalled in Colab
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Also install scikit-image (needed for marching cubes)
!pip install -q scikit-image

print("✅ All dependencies installed!")

In [None]:
# Quick verification - check imports work
try:
    import pywt
    import trimesh
    import matplotlib
    from skimage import measure
    print("✅ PyWavelets:", pywt.__version__)
    print("✅ Trimesh:", trimesh.__version__)
    print("✅ Matplotlib:", matplotlib.__version__)
    print("✅ scikit-image: OK")
    print("\n🎉 All core dependencies ready!")
except ImportError as e:
    print(f"❌ Missing dependency: {e}")
    print("Run the install cells above to fix this.")

### 3. Optional: Install Advanced Features

In [None]:
# Install transformers for DINOv2 (optional; improves quality)
!pip install -q transformers huggingface_hub

# Log in to Hugging Face (requires a token from https://huggingface.co/settings/tokens)
# from huggingface_hub import login
# login(token="your_token_here")

---

## 🧪 Test Installation

In [None]:
# Test all modules
!python test_all_modules.py

**Note:** If you hit import errors, restart the runtime and rerun the notebook from the top.

**Expected output:**
```
Results: 4/4 modules passed
  Module A ✅ PASS
  Module B ✅ PASS
  Module C ✅ PASS
  Module D ✅ PASS
```

---

## 📊 Quick Demo

### Module A: Wavelet Transform

In [None]:
from data.wavelet_utils import mesh_to_sdf_simple, sdf_to_sparse_wavelet, sparse_wavelet_to_sdf
import trimesh
import numpy as np
import matplotlib.pyplot as plt

# Create a sample mesh (unit cube)
mesh = trimesh.creation.box(extents=[1, 1, 1])
print(f"Mesh: {len(mesh.vertices)} vertices, {len(mesh.faces)} faces")

# Convert to a signed distance field (SDF)
sdf = mesh_to_sdf_simple(mesh, resolution=32)
print(f"SDF shape: {sdf.shape}")

# Wavelet transform (returns a dictionary)
sparse_data = sdf_to_sparse_wavelet(sdf, threshold=0.01)
print(f"Sparse indices: {sparse_data['indices'].shape}")
print(f"Sparse features: {sparse_data['features'].shape}")

# Calculate sparsity relative to the number of SDF voxels (avoids hardcoding the resolution)
total_elements = int(np.prod(sdf.shape))
non_zero = len(sparse_data['features'])
sparsity = 100 * (1 - non_zero / total_elements)
print(f"Sparsity: {sparsity:.1f}%")

# Reconstruct
sdf_recon = sparse_wavelet_to_sdf(sparse_data)
mse = np.mean((sdf - sdf_recon) ** 2)
print(f"Reconstruction MSE: {mse:.6f}")
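
The internals of `sdf_to_sparse_wavelet` and `sparse_wavelet_to_sdf` are not shown in this notebook. As a minimal, self-contained sketch of the general idea, the block below uses PyWavelets directly: a forward 3D DWT, hard thresholding of small coefficients, storage as `indices`/`features`, then scatter and inverse DWT. The analytic box SDF, the `haar` wavelet, `level=2`, and the extra `shape`/`slices`/`wavelet` keys are assumptions for illustration, not the repository's implementation. Variables are prefixed `demo_` so they do not overwrite `sdf`, `mse`, or `sparsity` used by later cells.

```python
import numpy as np
import pywt

def box_sdf(resolution=32, half_extent=0.5, bound=1.0):
    """Analytic SDF of an axis-aligned cube centred at the origin (negative inside)."""
    lin = np.linspace(-bound, bound, resolution)
    x, y, z = np.meshgrid(lin, lin, lin, indexing="ij")
    q = np.abs(np.stack([x, y, z], axis=-1)) - half_extent
    outside = np.linalg.norm(np.maximum(q, 0.0), axis=-1)
    inside = np.minimum(q.max(axis=-1), 0.0)
    return outside + inside

def to_sparse(sdf, wavelet="haar", level=2, threshold=0.01):
    """Forward 3D DWT, then keep only coefficients with |c| > threshold."""
    coeffs = pywt.wavedecn(sdf, wavelet=wavelet, level=level)
    arr, slices = pywt.coeffs_to_array(coeffs)  # pack all subbands into one dense array
    mask = np.abs(arr) > threshold
    return {
        "indices": np.argwhere(mask),  # (N, 3) positions in the packed array
        "features": arr[mask],         # (N,) kept coefficient values
        "shape": arr.shape,
        "slices": slices,
        "wavelet": wavelet,
    }

def from_sparse(sparse):
    """Scatter kept coefficients back into a dense array and invert the DWT."""
    arr = np.zeros(sparse["shape"])
    arr[tuple(sparse["indices"].T)] = sparse["features"]
    coeffs = pywt.array_to_coeffs(arr, sparse["slices"], output_format="wavedecn")
    return pywt.waverecn(coeffs, wavelet=sparse["wavelet"])

demo_sdf = box_sdf()
demo_sparse = to_sparse(demo_sdf, threshold=0.01)
demo_recon = from_sparse(demo_sparse)
demo_sparsity = 100 * (1 - len(demo_sparse["features"]) / demo_sdf.size)
demo_mse = np.mean((demo_sdf - demo_recon) ** 2)
print(f"{demo_sparsity:.1f}% sparse, MSE={demo_mse:.2e}")
```

Because the Haar transform is orthonormal, every dropped coefficient has magnitude at most `threshold`, so the reconstruction MSE is bounded by `threshold**2`; with `threshold=0.0` the round trip is exact up to floating-point error.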

### Visualize SDF

In [None]:
# Visualize SDF slice
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].imshow(sdf[16, :, :], cmap='RdBu')
axes[0].set_title('Original SDF (slice)')
axes[0].axis('off')

axes[1].imshow(sdf_recon[16, :, :], cmap='RdBu')
axes[1].set_title('Reconstructed SDF')
axes[1].axis('off')

diff = np.abs(sdf - sdf_recon)
axes[2].imshow(diff[16, :, :], cmap='hot')
axes[2].set_title(f'Error (MSE={mse:.6f})')
axes[2].axis('off')

plt.tight_layout()
plt.show()
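
scikit-image is installed above for marching cubes, which extracts a triangle mesh from the zero level set of an SDF grid. The sketch below runs `skimage.measure.marching_cubes` on an analytic sphere SDF, chosen (as an assumption) so the result is easy to check. The same call should apply to `sdf` or `sdf_recon` only if they use a negative-inside convention and known grid bounds, which this notebook does not state.

```python
import numpy as np
from skimage import measure

resolution, bound, radius = 32, 1.0, 0.5
lin = np.linspace(-bound, bound, resolution)
x, y, z = np.meshgrid(lin, lin, lin, indexing="ij")
sphere_sdf = np.sqrt(x**2 + y**2 + z**2) - radius  # analytic SDF, negative inside

# Extract the zero level set; spacing maps voxel indices to world units
spacing = (lin[1] - lin[0],) * 3
verts, faces, normals, values = measure.marching_cubes(sphere_sdf, level=0.0, spacing=spacing)
verts = verts - bound  # shift from [0, 2*bound] back to [-bound, bound]

print(f"{len(verts)} vertices, {len(faces)} faces")
print(f"mean vertex radius: {np.linalg.norm(verts, axis=1).mean():.3f}")
```

The resulting arrays can be wrapped as `trimesh.Trimesh(vertices=verts, faces=faces)` for export or inspection.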

---

### Module D: Multi-view Encoder

In [None]:
from models import create_multiview_encoder
import torch

# Create the encoder
encoder = create_multiview_encoder(preset='small')
print(f"Encoder created: {sum(p.numel() for p in encoder.parameters()):,} params")

# Test with random data
batch_size = 2
num_views = 4
images = torch.randn(batch_size, num_views, 3, 224, 224)
poses = torch.randn(batch_size, num_views, 3, 4)

# Forward pass
with torch.no_grad():
    conditioning = encoder(images, poses)

print(f"Input images: {images.shape}")
print(f"Input poses: {poses.shape}")
print(f"Output conditioning: {conditioning.shape}")
print("✅ Multi-view encoder working!")
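
The cell above feeds random `(B, V, 3, 4)` tensors as poses, which only checks shapes; real camera extrinsics are `[R | t]` matrices with an orthonormal rotation `R`. The sketch below builds `V` look-at cameras on a circle around the origin. The world-to-camera convention, the OpenGL-style camera axes (looking down -z), the z-up world, and the `look_at_poses` helper itself are assumptions; check which convention `create_multiview_encoder` expects before relying on it.

```python
import math
import torch

def look_at_poses(num_views=4, radius=2.0, elevation_deg=30.0):
    """Hypothetical helper: world-to-camera [R | t] matrices, shape (num_views, 3, 4)."""
    up = torch.tensor([0.0, 0.0, 1.0])  # assumed z-up world
    elev = math.radians(elevation_deg)
    poses = []
    for i in range(num_views):
        azim = 2 * math.pi * i / num_views  # evenly spaced around the object
        eye = radius * torch.tensor([
            math.cos(elev) * math.cos(azim),
            math.cos(elev) * math.sin(azim),
            math.sin(elev),
        ])
        forward = -eye / eye.norm()  # camera looks at the origin
        right = torch.linalg.cross(forward, up)
        right = right / right.norm()
        true_up = torch.linalg.cross(right, forward)
        R = torch.stack([right, true_up, -forward])  # rows = camera axes (x right, y up, z back)
        t = -R @ eye  # world origin expressed in camera coordinates
        poses.append(torch.cat([R, t.unsqueeze(1)], dim=1))
    return torch.stack(poses)

# Broadcast the same V cameras across a batch of 2: shape (2, 4, 3, 4)
cam_poses = look_at_poses(num_views=4).unsqueeze(0).expand(2, -1, -1, -1)
print(cam_poses.shape)
```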

---

### Module B + C: U-Net + Diffusion

In [None]:
from models import WaveMeshUNet, GaussianDiffusion

# Create the U-Net
unet = WaveMeshUNet(
    in_channels=1,
    encoder_channels=[16, 32, 64],
    decoder_channels=[64, 32, 16],
    time_emb_dim=128,
    use_attention=True,
    context_dim=384  # Match Module D output
)
print(f"U-Net: {sum(p.numel() for p in unet.parameters()):,} params")

# Create the diffusion process
diffusion = GaussianDiffusion(
    timesteps=1000,
    beta_schedule='linear'
)
print(f"Diffusion: {diffusion.timesteps} timesteps")
print(f"Beta range: [{diffusion.betas[0]:.6f}, {diffusion.betas[-1]:.6f}]")
print("✅ U-Net + Diffusion ready!")
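
The cell above prints the beta range of `GaussianDiffusion(beta_schedule='linear')`. As a reference point, the original DDPM setup spaces beta linearly from 1e-4 to 0.02 over 1000 steps and noises clean data in closed form: `a_bar_t = prod_{s<=t} (1 - beta_s)` and `x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps`. The sketch below implements that formula in plain PyTorch; the endpoint values are the DDPM defaults and may differ from what this repository's `GaussianDiffusion` uses, and `q_sample` is an illustrative name.

```python
import torch

def linear_beta_schedule(timesteps=1000, beta_start=1e-4, beta_end=0.02):
    return torch.linspace(beta_start, beta_end, timesteps)

def q_sample(x0, t, alphas_cumprod, noise=None):
    """Sample x_t ~ q(x_t | x_0) in closed form for a batch of timesteps t."""
    if noise is None:
        noise = torch.randn_like(x0)
    # Reshape (B,) -> (B, 1, 1, ...) so the per-example scalars broadcast over data dims
    a_bar = alphas_cumprod[t].view(-1, *([1] * (x0.dim() - 1)))
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise

betas = linear_beta_schedule(1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
x0 = torch.randn(2, 1, 16, 16, 16)  # placeholder batch of dense grids
t = torch.tensor([0, 999])
x_t = q_sample(x0, t, alphas_cumprod)
print(f"alpha_bar at t=0: {alphas_cumprod[0]:.4f}, at t=999: {alphas_cumprod[-1]:.6f}")
```

At `t=0` the sample is almost the clean input; by `t=999` it is almost pure noise, which is what the denoiser learns to invert.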

---

## 📊 Download Data

### Option 1: ModelNet40 (Quick - 500MB)

In [None]:
# Download ModelNet40
!python scripts/download_data.py --dataset modelnet40

# Check downloaded data
!ls -lh data/ModelNet40/ 2>/dev/null || echo "data/ModelNet40/ not found. Check scripts/download_data.py for manual download instructions"
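
Once the data is unpacked, a dataset loader has to enumerate the meshes. The standard ModelNet40 release stores one OFF file per shape as `<class>/<split>/<class>_<id>.off`, with `train` and `test` splits. The hypothetical helper below indexes such a tree using only `pathlib`; whether `scripts/download_data.py` preserves that layout under `data/ModelNet40/` is an assumption.

```python
from pathlib import Path

def index_modelnet40(root, split=None):
    """Hypothetical helper: sorted (path, class_name, split) tuples from a ModelNet40-style tree."""
    entries = []
    for off_path in sorted(Path(root).glob("*/*/*.off")):
        file_split = off_path.parent.name          # "train" or "test"
        class_name = off_path.parent.parent.name   # e.g. "airplane"
        if file_split not in ("train", "test"):
            continue  # skip files outside the expected layout
        if split is None or split == file_split:
            entries.append((off_path, class_name, file_split))
    return entries

# Usage in this notebook (path taken from the `ls` above):
# train_set = index_modelnet40("data/ModelNet40", split="train")
# classes = sorted({c for _, c, _ in train_set})
```

Each returned path can then be loaded with `trimesh.load(path)` and converted with `mesh_to_sdf_simple`.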

### Option 2: ShapeNet (Manual)

To download ShapeNet:
1. Register at https://shapenet.org/
2. Download ShapeNetCore.v2
3. Upload it to Google Drive
4. Mount Drive and copy the data

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Copy ShapeNet data (if it is already in your Drive)
# !cp -r /content/drive/MyDrive/ShapeNetCore.v2 ./data/

---

## 🎨 Visualize Pipeline

In [None]:
# Visualize the complete pipeline
# Note: visualize_results.py must be created first;
# alternatively, use the simple code below:

import matplotlib.pyplot as plt
import numpy as np

# Simple visualization
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
fig.suptitle('WaveMesh-Diff Pipeline Overview', fontsize=16, fontweight='bold')

# 1. Input SDF
axes[0, 0].imshow(sdf[16, :, :], cmap='RdBu')
axes[0, 0].set_title('1. Input SDF (slice)')
axes[0, 0].axis('off')

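# 2. Wavelet Coefficients (visualize sparsity)
# Assumes each index row starts with (i, j, k) in the same axis order as `sdf`
sparse_indices = np.asarray(sparse_data['indices'])
sparse_viz = np.zeros((32, 32))
for idx in sparse_indices:
    if idx[0] == 16:  # same slice as sdf[16, :, :] in the other panels
        sparse_viz[idx[1], idx[2]] += 1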
axes[0, 1].imshow(sparse_viz, cmap='hot')
axes[0, 1].set_title(f'2. Sparse Wavelet ({sparsity:.1f}% sparse)')
axes[0, 1].axis('off')

# 3. Reconstructed SDF
axes[1, 0].imshow(sdf_recon[16, :, :], cmap='RdBu')
axes[1, 0].set_title('3. Reconstructed SDF')
axes[1, 0].axis('off')

# 4. Reconstruction Error
error = np.abs(sdf - sdf_recon)
axes[1, 1].imshow(error[16, :, :], cmap='Reds')
axes[1, 1].set_title(f'4. Error (MSE={mse:.6f})')
axes[1, 1].axis('off')

plt.tight_layout()
plt.show()
print("✅ Pipeline visualization complete!")

---

## 🏋️ Training Example (Conceptual)

⚠️ **Note:** Full training takes substantial time and a GPU. See `ROADMAP.md` for the complete code.

In [None]:
# Training setup (conceptual: builds the models and optimizer, no training loop)
# See ROADMAP.md for the complete implementation

from models import create_multiview_encoder, WaveMeshUNet, GaussianDiffusion
import torch

# 1. Prepare models
encoder = create_multiview_encoder(preset='small')
unet = WaveMeshUNet(
    in_channels=1,
    encoder_channels=[16, 32, 64],
    decoder_channels=[64, 32, 16],
    time_emb_dim=128,
    context_dim=384
)
diffusion = GaussianDiffusion(timesteps=1000)

# 2. Optimizer
optimizer = torch.optim.AdamW([
    {'params': encoder.parameters(), 'lr': 1e-5},
    {'params': unet.parameters(), 'lr': 1e-4}
])

print("✅ Training setup ready!")
print(f"Encoder params: {sum(p.numel() for p in encoder.parameters()):,}")
print(f"U-Net params: {sum(p.numel() for p in unet.parameters()):,}")
print("\n📖 See ROADMAP.md for full training code with:")
print("  • Dataset implementation")
print("  • Training loop")
print("  • Evaluation metrics")
print("  • Checkpointing")
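
The cell above only builds the models and optimizer. For orientation, one denoising-diffusion training step typically samples a timestep per example, noises the clean data in closed form, and regresses the added noise with an MSE loss. The sketch below shows that step in plain PyTorch with a hypothetical `TinyDenoiser` standing in for `WaveMeshUNet` (whose forward signature is not shown here), a DDPM-style linear schedule, and a dense random placeholder batch instead of sparse wavelet data. It is not the training code from `ROADMAP.md`; names are prefixed `demo_` so they do not overwrite `optimizer` or `unet` above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)  # DDPM-style linear schedule (assumed)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

class TinyDenoiser(nn.Module):
    """Hypothetical stand-in for WaveMeshUNet: predicts the added noise from (x_t, t)."""
    def __init__(self, hidden=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv3d(2, hidden, 3, padding=1), nn.SiLU(),
            nn.Conv3d(hidden, 1, 3, padding=1),
        )

    def forward(self, x_t, t):
        # Append the normalized timestep as a constant extra channel
        t_channel = (t.float() / T).view(-1, 1, 1, 1, 1).expand_as(x_t)
        return self.net(torch.cat([x_t, t_channel], dim=1))

def training_step(model, optimizer, x0):
    t = torch.randint(0, T, (x0.shape[0],))  # one random timestep per example
    noise = torch.randn_like(x0)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise  # closed-form q(x_t | x_0)
    loss = F.mse_loss(model(x_t, t), noise)  # epsilon-prediction objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

demo_model = TinyDenoiser()
demo_optimizer = torch.optim.AdamW(demo_model.parameters(), lr=1e-3)
demo_x0 = torch.randn(2, 1, 16, 16, 16)  # placeholder batch, not real wavelet data
demo_losses = [training_step(demo_model, demo_optimizer, demo_x0) for _ in range(5)]
print(demo_losses)
```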

---

## 📚 Next Steps

1. **Read the Documentation:**
   - [README.md](README.md) - Project overview
   - [QUICKSTART.md](QUICKSTART.md) - Quick start guide
   - [ROADMAP.md](ROADMAP.md) - Training roadmap
   - [ARCHITECTURE.md](ARCHITECTURE.md) - Technical details

2. **Download Data:**
   - ModelNet40 (500MB) - Quick start
   - ShapeNet (50GB) - Better quality

3. **Train Model:**
   - See `ROADMAP.md` for the complete training code
   - Implement dataset loader
   - Run training loop

4. **Improve:**
   - Mixed precision training
   - Classifier-free guidance
   - EMA for better quality
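
Two of the improvements listed above are compact enough to sketch. EMA keeps a slowly updated copy of the weights, `shadow = decay * shadow + (1 - decay) * weights`, which is commonly used for sampling. Classifier-free guidance combines conditional and unconditional noise predictions as `eps = eps_uncond + w * (eps_cond - eps_uncond)`, which assumes the model was also trained with its conditioning randomly dropped. The `EMA` class, the `guided_noise` function, and the default decay are illustrative assumptions, not part of this repository.

```python
import copy
import torch
import torch.nn as nn

class EMA:
    """Exponential moving average of a model's parameters (buffers are copied as-is)."""
    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = copy.deepcopy(model).eval()
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: nn.Module):
        for s, p in zip(self.shadow.parameters(), model.parameters()):
            s.mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)
        for s_buf, buf in zip(self.shadow.buffers(), model.buffers()):
            s_buf.copy_(buf)  # e.g. normalization statistics are copied, not averaged

def guided_noise(eps_uncond, eps_cond, guidance_scale):
    """Classifier-free guidance: extrapolate from the unconditional toward the conditional prediction."""
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

# Usage sketch: call ema.update(unet) after each optimizer.step(), then sample with ema.shadow.
```

With `guidance_scale=1.0` guidance reduces to the plain conditional prediction; larger values push samples harder toward the conditioning.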

---

## 🐛 Troubleshooting

### "ModuleNotFoundError: No module named 'pywt'"
```python
!pip install PyWavelets
```

### "ModuleNotFoundError: No module named 'rtree'"
```python
!pip install rtree
```

### "ValueError: too many values to unpack (expected 2)"
This error comes from the old API: `sdf_to_sparse_wavelet()` returns a **dictionary**, not a tuple:
```python
# ❌ Wrong:
coeffs, coords = sdf_to_sparse_wavelet(sdf)

# ✅ Correct:
sparse_data = sdf_to_sparse_wavelet(sdf, threshold=0.01)
print(sparse_data['indices'].shape)
print(sparse_data['features'].shape)
```

### "CUDA out of memory"
```python
# Reduce the batch size or resolution
batch_size = 2
resolution = 16
```

### "transformers not available"
```python
!pip install transformers huggingface_hub
# The code automatically falls back to a CNN encoder if transformers is not installed
```

### "No module named 'skimage'"
```python
!pip install scikit-image
```

See [TROUBLESHOOTING.md](TROUBLESHOOTING.md) for the complete list.

---

**Happy 3D Generation! üé®**