In [None]:
# Fetch the 3DHighlighter sources into the current working directory.
!git clone https://github.com/amiralichangizi/3DHighlighter.git

In [None]:
import os

# Work from the repository root so the `src.*` imports and the relative
# `data/...` paths used in later cells resolve.
# NOTE(review): assumes the clone landed under /content (Colab layout) — confirm.
os.chdir('/content/3DHighlighter')

In [None]:
# CLIP is installed from the head of the OpenAI GitHub repository (unpinned).
!pip install git+https://github.com/openai/CLIP.git
# Kaolin 0.17.0 is pulled from the wheel index published for torch 2.5.1 + CUDA 12.1.
# NOTE(review): presumably the runtime's torch/CUDA build must match that index — verify.
!pip install kaolin==0.17.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.5.1_cu121.html

In [1]:

import sys
import torch

need_pytorch3d = False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d = True
if need_pytorch3d:
    pyt_version_str = torch.__version__.split("+")[0].replace(".", "")
    version_str = "".join([
        f"py3{sys.version_info.minor}_cu",
        torch.version.cuda.replace(".", ""),
        f"_pyt{pyt_version_str}"
    ])
    !pip install iopath
    if sys.platform.startswith("linux"):
        print("Trying to install wheel for PyTorch3D")
        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
        pip_list = !pip freeze
        need_pytorch3d = not any(i.startswith("pytorch3d==") for i in pip_list)
    if need_pytorch3d:
        print(f"failed to find/install wheel for {version_str}")
if need_pytorch3d:
    print("Installing PyTorch3D from source")
    !pip install ninja
    !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'

failed to find/install wheel for py39_cu113_pyt1121
Installing PyTorch3D from source


ERROR: Invalid requirement: "'git+https://github.com/facebookresearch/pytorch3d.git@stable'"


In [None]:
# Open3D is not referenced by the cells in this notebook.
# NOTE(review): presumably required by the project's `src` package — confirm.
!pip install open3d

In [None]:
# Download the PittsburghBridge sample point cloud (an .npz archive) into
# data/PittsburghBridge, relative to the current working directory.
!mkdir -p data/PittsburghBridge
!wget -P data/PittsburghBridge https://dl.fbaipublicfiles.com/pytorch3d/data/PittsburghBridge/pointcloud.npz

In [None]:

from src.mesh import Mesh
from pytorch3d.structures import Pointclouds


def load_3d_data(file_path, num_points=10000, device="cuda"):
    """
    Load 3D data from an NPZ point cloud or an OBJ mesh as a PyTorch3D point cloud.

    Args:
        file_path: Path to a .npz point cloud (read via its 'verts' and 'rgb'
            arrays) or a .obj mesh file. The extension check is case-insensitive.
        num_points: Maximum number of points to keep. Larger inputs are randomly
            subsampled without replacement.
        device: Device the returned tensors are moved to.

    Returns:
        A tuple ``(points, point_cloud)`` for both supported formats:
        ``points`` is the tensor of point coordinates and ``point_cloud`` is a
        ``Pointclouds`` object built from those points and their colours.

    Raises:
        ValueError: If the file extension is neither ``npz`` nor ``obj``.
    """
    file_ext = file_path.split('.')[-1].lower()

    if file_ext == 'npz':
        # Use a context manager so the archive's file handle is released as
        # soon as the arrays have been copied into tensors.
        with np.load(file_path) as pointcloud:
            verts = torch.Tensor(pointcloud['verts']).to(device)
            rgb = torch.Tensor(pointcloud['rgb']).to(device)

        # Subsample points and colours with the same permutation so they stay aligned.
        if len(verts) > num_points:
            idx = torch.randperm(len(verts))[:num_points]
            verts = verts[idx]
            rgb = rgb[idx]

        return verts, Pointclouds(points=[verts], features=[rgb])

    elif file_ext == 'obj':
        # Use (a random subset of) the mesh vertices as the point set.
        mesh = Mesh(file_path)
        vertices = mesh.vertices

        # Slicing the permutation caps the count at num_points and keeps every
        # vertex when the mesh has fewer.
        idx = torch.randperm(vertices.shape[0])[:num_points]
        points = vertices[idx].to(device)

        # Meshes carry no colour here, so start from a uniform gray.
        colors = torch.ones_like(points) * 0.7

        # Return the same (points, point_cloud) pair as the NPZ branch so that
        # callers can unpack the result regardless of the input format.
        return points, Pointclouds(points=[points], features=[colors])

    else:
        raise ValueError(f"Unsupported file format: {file_ext}. Only .npz and .obj are supported.")



In [None]:
from src.save_results import save_renders, save_results
from src.render.cloud_point_renderer import PointCloudRenderer
from src.neural_highlighter import NeuralHighlighter
from src.Clip.loss_function import clip_loss
from src.Clip.clip_model import get_clip_model, encode_text, setup_clip_transforms

import torch
import numpy as np
import random
from tqdm import tqdm

# Pin every random number generator used by the notebook to one fixed seed.
# (Author's note: some torch backward functions within CLIP are
# non-deterministic, so runs may still differ slightly.)
seed = 0  # You can use any integer value
for _seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    random.seed,
    np.random.seed,
):
    _seed_fn(seed)

# Ask cuDNN for deterministic kernels and disable its autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


def optimize_point_cloud(points, point_cloud, clip_model, renderer, encoded_text, log_dir: str, num_iterations=500,
                         learning_rate=1e-4, device="cuda"):
    """
    Train a NeuralHighlighter so that renders of the coloured points match a text prompt.

    Each iteration predicts two per-point weights, blends a highlight colour
    and a base gray with them, renders the coloured point cloud, and minimises
    the CLIP loss between the render and ``encoded_text``.

    Args:
        points: Point coordinates fed to the network and the renderer.
            NOTE(review): assumed to be an (N, 3) tensor already on ``device`` — confirm.
        point_cloud: Accepted for interface compatibility; not read by this
            function (a fresh point cloud is built from ``points`` every iteration).
        clip_model: CLIP model passed through to ``clip_loss``.
        renderer: Object providing ``setup_renderer()`` and
            ``create_point_cloud(points, colors)``.
        encoded_text: Encoded prompt passed through to ``clip_loss``.
        log_dir: Directory handed to ``save_renders`` for periodic snapshots.
        num_iterations: Number of optimisation steps.
        learning_rate: Adam learning rate.
        device: Device for the network and the colour constants.

    Returns:
        The trained ``NeuralHighlighter`` network.
    """
    # Initialize network and optimizer
    net = NeuralHighlighter(
        depth=5,  # Number of hidden layers
        width=256,  # Width of each layer
        out_dim=2,  # Binary classification (highlight/no-highlight)
        input_dim=3,  # 3D coordinates (x,y,z)
        positional_encoding=False  # As recommended in the paper
    ).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    render_engine = renderer.setup_renderer()

    # Set up the transforms
    clip_transform, augment_transform = setup_clip_transforms()

    # The two blend colours never change, so build them once instead of
    # re-allocating and copying them to the device on every iteration.
    highlight_color = torch.tensor([204 / 255, 1.0, 0.0], device=device)
    base_color = torch.tensor([180 / 255, 180 / 255, 180 / 255], device=device)

    # Training loop
    for i in tqdm(range(num_iterations)):
        optimizer.zero_grad()

        # Predict the two per-point weights (highlight, base)
        pred_class = net(points)

        # Column 0 weights the highlight colour, column 1 the base colour;
        # the [:, k:k+1] slices keep a trailing axis so they broadcast over RGB.
        colors = pred_class[:, 0:1] * highlight_color + pred_class[:, 1:2] * base_color

        # Build and render the coloured cloud (separate name so the
        # `point_cloud` argument is not shadowed).
        colored_cloud = renderer.create_point_cloud(points, colors)
        rendered_images = render_engine(colored_cloud)

        # Convert rendered images to CLIP format
        rendered_images = rendered_images.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]

        # Calculate CLIP loss
        loss = clip_loss(rendered_images, encoded_text, clip_transform,
                         augment_transform, clip_model)

        loss.backward()
        optimizer.step()

        # Log the loss and save a snapshot of the renders every 100 iterations.
        if i % 100 == 0:
            print(f"Iteration {i}, Loss: {loss.item():.4f}")
            save_renders(log_dir, i, rendered_images)

    return net


def main(input_path, object_name, highlight_region,
         num_points=10000, device="cuda", output_dir="./output"):
    """
    Run the complete highlighting pipeline for a single 3D input.

    Steps, in order: create the output directory, load the points, set up
    CLIP, encode the text prompt, build the renderer, optimise the
    highlighter network, and save the results.

    Args:
        input_path: Path to input 3D file (mesh or point cloud)
        object_name: Object name inserted into the prompt
        highlight_region: Region name inserted into the prompt
        num_points: Number of points to use
        device: Device to run on
        output_dir: Directory that receives renders and results

    Returns:
        Tuple ``(net, points)``: the optimised network and the loaded points.

    Raises:
        Exception: Any failure is printed and then re-raised unchanged.
    """
    try:
        # Make sure there is somewhere to write to before doing any work.
        os.makedirs(output_dir, exist_ok=True)

        # Stage 1: data.
        print(f"Loading 3D data from {input_path}...")
        points, point_cloud = load_3d_data(input_path, num_points=num_points, device=device)
        print(f"Loaded {len(points)} points")

        # Stage 2: CLIP model (the two extra return values are not used here).
        print("Setting up CLIP model...")
        clip_model, _preprocess, _resolution = get_clip_model()

        # Stage 3: text prompt and its encoding.
        text_prompt = f"A 3D render of a gray {object_name} with highlighted {highlight_region}"
        print(f"Using prompt: {text_prompt}")
        encoded_prompt = encode_text(clip_model, text_prompt, device)

        # Stage 4: renderer.
        print("Setting up renderer...")
        cloud_renderer = PointCloudRenderer(device=device)

        # Stage 5: optimisation.
        print("Starting optimization...")
        highlighter = optimize_point_cloud(
            points=points,
            point_cloud=point_cloud,
            clip_model=clip_model,
            renderer=cloud_renderer,
            encoded_text=encoded_prompt,
            log_dir=output_dir,
            device=device,
        )

        # Stage 6: persist the outcome.
        print("Saving results...")
        save_results(
            net=highlighter,
            points=points,
            point_cloud=point_cloud,
            prompt=text_prompt,
            output_dir=output_dir,
            renderer=cloud_renderer,
        )

        print("Processing complete!")
        return highlighter, points

    except Exception as err:
        # Report the failure, then let it propagate to the caller.
        print(f"Error in processing: {str(err)}")
        raise



In [None]:
# Run the full pipeline on the downloaded PittsburghBridge point cloud, using
# main()'s defaults for num_points, device and output_dir. The prompt built
# from these arguments is
# "A 3D render of a gray Bridge with highlighted floor".
main(
    input_path="/content/3DHighlighter/data/PittsburghBridge/pointcloud.npz",
    object_name="Bridge",
    highlight_region="floor",
)