<a href="https://colab.research.google.com/github/Wizorld/sparp_iit/blob/main/Sparp_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch torchvision
!pip install diffusers transformers
!pip install numpy scipy



In [None]:
!pip install opencv-python



In [2]:
!pip install datasets
!pip install huggingface_hub
!pip install tqdm
!pip install trimesh
!pip install pyrender
!apt-get update && apt-get install -y xvfb
!pip install pyglet

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m33.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [None]:
!pip install objaverse

Collecting objaverse
  Downloading objaverse-0.1.7-py3-none-any.whl.metadata (4.6 kB)
Collecting loguru (from objaverse)
  Downloading loguru-0.7.2-py3-none-any.whl.metadata (23 kB)
Collecting gputil==1.4.0 (from objaverse)
  Downloading GPUtil-1.4.0.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading objaverse-0.1.7-py3-none-any.whl (32 kB)
Downloading loguru-0.7.2-py3-none-any.whl (62 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.5/62.5 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: gputil
  Building wheel for gputil (setup.py) ... [?25l[?25hdone
  Created wheel for gputil: filename=GPUtil-1.4.0-py3-none-any.whl size=7392 sha256=2311ed06919a3b84960316461bd8046c31492ee5af01cbd4ad49823b3c337453
  Stored in directory: /root/.cache/pip/wheels/a9/8a/bd/81082387151853ab8b6b3ef33426e98f5cbfebc3c397a9d4d0
Successfully built gputil
Installing collected packages: gputil, loguru, objaverse
Success

In [3]:
import os
import torch
import numpy as np
from PIL import Image, ImageDraw
import json
from pathlib import Path
from typing import List, Dict
import random
import matplotlib.pyplot as plt
from math import sin, cos, pi
import numpy as np
from scipy.spatial.transform import Rotation
import shutil

class SimpleDataLoader:
    """Builds and serves a small synthetic multi-view dataset.

    Each object is a procedurally drawn shape ('cube', 'cylinder' or
    'pyramid') rendered from ``NUM_VIEWS`` view angles. For every view the
    loader stores a textured RGB image (JPEG) under ``<data_dir>/images`` and a
    synthetic NOCS map (PNG) under ``<data_dir>/nocs_maps``. Per-object
    metadata (file paths, category and, after ``estimate_poses``, poses) is
    persisted as JSON in ``<data_dir>/dataset_info.json``.
    """

    # Number of evenly spaced views rendered per object.
    NUM_VIEWS = 6

    def __init__(self,
                 data_dir: str = './test_data',
                 image_size: int = 512):
        """Create the directory layout, then load or generate the dataset.

        Args:
            data_dir: root directory for images, NOCS maps and the JSON index.
            image_size: side length in pixels of every generated square image.
        """
        print("Initializing SimpleDataLoader...")

        self.data_dir = Path(data_dir)
        self.image_size = image_size

        # Create directories
        print("Creating directories...")
        self.data_dir.mkdir(exist_ok=True, parents=True)
        self.image_dir = self.data_dir / 'images'
        self.image_dir.mkdir(exist_ok=True)
        self.nocs_dir = self.data_dir / 'nocs_maps'
        self.nocs_dir.mkdir(exist_ok=True)

        # Initialize dataset
        print("Initializing dataset...")
        self.dataset = {}

        # Load the persisted index if present; otherwise write an empty one.
        self.dataset_file = self.data_dir / 'dataset_info.json'
        if self.dataset_file.exists():
            print("Loading existing dataset...")
            self._load_dataset()
        else:
            print("No existing dataset found.")
            self._save_dataset()

        # Create test data if dataset is empty
        if not self.dataset:
            print("Creating new test data...")
            self._create_test_data()

        print(f"Initialization complete. Dataset contains {len(self.dataset)} objects.")

    @classmethod
    def _view_angle(cls, view_idx: int) -> float:
        """Angle in radians of view ``view_idx`` on an evenly spaced circle of views."""
        return (view_idx * 2 * pi) / cls.NUM_VIEWS

    def _generate_nocs_map(self, shape_type: str, view_angle: float, size: int) -> Image.Image:
        """Generate a synthetic NOCS map for ``shape_type`` seen at ``view_angle``.

        Foreground pixels hold normalised object coordinates in [0, 1], stored
        as 0-255 RGB; background pixels are black. The map is computed with
        vectorised NumPy operations over the whole pixel grid. Values match the
        former per-pixel loop up to floating-point rounding.

        Args:
            shape_type: 'cube', 'cylinder' or 'pyramid'; any other value yields
                an all-black map.
            view_angle: in-plane view rotation in radians.
            size: side length of the square map in pixels.

        Returns:
            An RGB ``PIL.Image`` of ``size`` x ``size`` pixels.
        """
        nocs_map = np.zeros((size, size, 3), dtype=np.float32)
        center_x = size // 2
        center_y = size // 2
        shape_size = size // 3
        if shape_size == 0:
            # Too small to contain a shape; the coordinate scaling below would divide by zero.
            return Image.fromarray((nocs_map * 255).astype(np.uint8))

        # ys[r, c] == r and xs[r, c] == c, matching the nocs_map[y, x] layout.
        ys, xs = np.mgrid[0:size, 0:size].astype(np.float64)
        cos_a, sin_a = cos(view_angle), sin(view_angle)

        def rotate(u, v):
            # 2-D rotation by view_angle, applied element-wise to the coordinate grids.
            return cos_a * u - sin_a * v, sin_a * u + cos_a * v

        if shape_type == 'cube':
            local_x, local_y = rotate((xs - center_x) / shape_size,
                                      (ys - center_y) / shape_size)
            inside = (np.abs(local_x) <= 1) & (np.abs(local_y) <= 1)
            nocs_map[inside, 0] = (local_x[inside] + 1) / 2
            nocs_map[inside, 1] = (local_y[inside] + 1) / 2
            nocs_map[inside, 2] = 0.5  # Middle of the cube

        elif shape_type == 'cylinder':
            local_x = (xs - center_x) / shape_size
            local_y = (ys - center_y) / shape_size
            radius = np.sqrt(local_x**2 + local_y**2)
            inside = radius <= 1
            angle = np.arctan2(local_y[inside], local_x[inside]) + view_angle
            nocs_map[inside, 0] = (radius[inside] * np.cos(angle) + 1) / 2
            nocs_map[inside, 1] = (radius[inside] * np.sin(angle) + 1) / 2
            nocs_map[inside, 2] = 0.5  # Middle of cylinder

        elif shape_type == 'pyramid':
            height = shape_size * 2
            base_width = shape_size * 2
            local_x, local_y = rotate((xs - center_x) / (base_width / 2),
                                      (ys - center_y) / height)
            inside = (np.abs(local_x) <= (1 - np.abs(local_y))) & (np.abs(local_y) <= 1)
            nocs_map[inside, 0] = (local_x[inside] + 1) / 2
            nocs_map[inside, 1] = (local_y[inside] + 1) / 2
            nocs_map[inside, 2] = (1 - np.abs(local_y[inside])) / 2  # Height decreases with y

        # Convert to PIL Image
        nocs_map = (nocs_map * 255).astype(np.uint8)
        return Image.fromarray(nocs_map)

    def _create_gradient_texture(self, size):
        """Create a size x size RGB gradient: red grows with x, green with y, blue with x + y."""
        img = Image.new('RGB', (size, size))
        pixels = img.load()
        for i in range(size):
            for j in range(size):
                r = int((i / size) * 255)
                g = int((j / size) * 255)
                b = int(((i + j) / (2 * size)) * 255)
                pixels[i, j] = (r, g, b)
        return img

    def _create_pattern_texture(self, size, pattern_type='grid'):
        """Create a black-on-white patterned texture.

        Args:
            size: side length of the square texture in pixels.
            pattern_type: 'grid', 'dots' or 'stripes'; any other value yields
                a plain white image.
        """
        img = Image.new('RGB', (size, size), 'white')
        draw = ImageDraw.Draw(img)

        # Spacing is clamped to >= 1: range() rejects a zero step, which
        # previously crashed for sizes below 8 (grid/dots) or 16 (stripes).
        if pattern_type == 'grid':
            spacing = max(1, size // 8)
            for i in range(0, size, spacing):
                draw.line([(i, 0), (i, size)], fill='black', width=2)
                draw.line([(0, i), (size, i)], fill='black', width=2)

        elif pattern_type == 'dots':
            spacing = max(1, size // 8)
            dot_size = spacing // 4
            for i in range(0, size, spacing):
                for j in range(0, size, spacing):
                    draw.ellipse([i-dot_size, j-dot_size, i+dot_size, j+dot_size],
                                 fill='black')

        elif pattern_type == 'stripes':
            # Diagonal stripes
            spacing = max(1, size // 16)
            for i in range(-size, size*2, spacing):
                draw.line([(i, 0), (i+size, size)], fill='black', width=2)

        return img

    def _apply_texture_to_shape(self, img, shape_mask, texture):
        """Copy ``texture`` pixels into ``img`` wherever ``shape_mask`` equals 255.

        ``img`` and ``texture`` are expected to be RGB images of the same size
        as the single-channel ``shape_mask``.
        """
        img_array = np.array(img)
        covered = np.array(shape_mask) == 255
        texture_array = np.array(texture)

        # Broadcast the 2-D mask over the colour channels.
        blended = np.where(covered[..., None], texture_array, img_array)
        return Image.fromarray(blended)

    def _create_textured_shape(self, size, shape_type, view_angle=0):
        """Draw a randomly textured and tinted silhouette of ``shape_type``.

        Only the 'cube' outline depends on ``view_angle``; the cylinder and
        pyramid silhouettes are drawn the same for every view.
        """
        # Create base image and mask
        img = Image.new('RGB', (size, size), 'white')
        mask = Image.new('L', (size, size), 0)
        draw = ImageDraw.Draw(mask)

        center_x = size // 2
        center_y = size // 2
        shape_size = size // 3

        # Create shape mask with rotation
        if shape_type == 'cube':
            # Draw a cube-like shape
            points = [
                (center_x - shape_size + shape_size * cos(view_angle),
                 center_y - shape_size + shape_size * sin(view_angle)),
                (center_x + shape_size + shape_size * cos(view_angle),
                 center_y - shape_size + shape_size * sin(view_angle)),
                (center_x + shape_size - shape_size * sin(view_angle),
                 center_y + shape_size + shape_size * cos(view_angle)),
                (center_x - shape_size - shape_size * sin(view_angle),
                 center_y + shape_size + shape_size * cos(view_angle))
            ]
            draw.polygon(points, fill=255)

        elif shape_type == 'cylinder':
            # Draw a cylinder-like shape
            draw.ellipse([center_x - shape_size, center_y - shape_size//2,
                          center_x + shape_size, center_y + shape_size//2], fill=255)

        elif shape_type == 'pyramid':
            # Draw a pyramid-like shape
            height = shape_size * 2
            base_width = shape_size * 2
            points = [
                (center_x, center_y - height//2),  # Top
                (center_x - base_width//2, center_y + height//2),  # Bottom left
                (center_x + base_width//2, center_y + height//2)   # Bottom right
            ]
            draw.polygon(points, fill=255)

        # Create and apply texture
        patterns = ['grid', 'dots', 'stripes']
        texture = self._create_pattern_texture(size, random.choice(patterns))

        # Apply a color tint
        tint = Image.new('RGB', (size, size),
                         (random.randint(50, 200),
                          random.randint(50, 200),
                          random.randint(50, 200)))
        texture = Image.blend(texture, tint, 0.5)

        # Apply texture to shape
        textured_img = self._apply_texture_to_shape(img, mask, texture)

        return textured_img

    def _create_test_data(self):
        """Render every test shape from ``NUM_VIEWS`` angles and record the files.

        Images and NOCS maps that already exist on disk are reused rather than
        regenerated. The JSON index is saved after each object so partial
        progress survives a failure; a failing shape is reported and skipped.
        """
        print("Creating test dataset...")

        shapes = ['cube', 'cylinder', 'pyramid']

        for obj_idx, shape_type in enumerate(shapes):
            try:
                obj_id = f"test_object_{obj_idx}"
                print(f"\nProcessing {shape_type} (ID: {obj_id})...")

                object_data = {
                    'uid': obj_id,
                    'category': shape_type,
                    'images': [],
                    'nocs_maps': []
                }

                for view_idx in range(self.NUM_VIEWS):
                    print(f"Creating view {view_idx}")
                    image_path = self.image_dir / f"{obj_id}_view_{view_idx}.jpg"
                    nocs_path = self.nocs_dir / f"{obj_id}_nocs_{view_idx}.png"

                    view_angle = self._view_angle(view_idx)

                    # Create textured view
                    if not image_path.exists():
                        img = self._create_textured_shape(self.image_size,
                                                          shape_type,
                                                          view_angle)
                        img.save(image_path)
                        print(f"Saved image: {image_path}")

                    # Create NOCS map
                    if not nocs_path.exists():
                        nocs_map = self._generate_nocs_map(shape_type,
                                                           view_angle,
                                                           self.image_size)
                        nocs_map.save(nocs_path)
                        print(f"Saved NOCS map: {nocs_path}")

                    object_data['images'].append(str(image_path))
                    object_data['nocs_maps'].append(str(nocs_path))

                self.dataset[obj_id] = object_data
                # Save after each object in case of errors
                self._save_dataset()
                print(f"Successfully processed {shape_type}")

            except Exception as e:
                print(f"Error processing shape {shape_type}: {str(e)}")
                continue

        print(f"\nCreated {len(self.dataset)} test objects with NOCS maps")

    def _save_dataset(self):
        """Write ``self.dataset`` to ``dataset_info.json`` as indented JSON."""
        print(f"Saving dataset with {len(self.dataset)} objects...")
        with open(self.dataset_file, 'w') as f:
            json.dump(self.dataset, f, indent=2)
        print("Dataset saved successfully.")

    def _load_dataset(self):
        """Load ``dataset_info.json``, repairing entries that lack NOCS maps.

        Entries without a ``'nocs_maps'`` key get one NOCS map per image.
        Existing map files are reused; missing ones are generated. The repaired
        index is written back so the repair happens only once. If the file
        cannot be read or repaired, the dataset is reset and regenerated with
        ``_create_test_data``.
        """
        try:
            print("Loading dataset from file...")
            with open(self.dataset_file, 'r') as f:
                loaded_data = json.load(f)

            repaired = False
            for obj_id, obj_data in loaded_data.items():
                if 'nocs_maps' not in obj_data:
                    print(f"Fixing missing nocs_maps for {obj_id}")
                    obj_data['nocs_maps'] = []
                    for idx in range(len(obj_data['images'])):
                        nocs_path = self.nocs_dir / f"{obj_id}_nocs_{idx}.png"
                        if not nocs_path.exists():
                            nocs_map = self._generate_nocs_map(obj_data['category'],
                                                               self._view_angle(idx),
                                                               self.image_size)
                            nocs_map.save(nocs_path)
                        obj_data['nocs_maps'].append(str(nocs_path))
                    repaired = True

            self.dataset = loaded_data
            print(f"Dataset loaded successfully with {len(self.dataset)} objects.")

        except Exception as e:
            print(f"Error loading dataset: {str(e)}")
            print("Creating new dataset...")
            self.dataset = {}
            self._create_test_data()
            return

        if repaired:
            # Persist the repair; best-effort so a failed write keeps the loaded data.
            try:
                self._save_dataset()
            except OSError as e:
                print(f"Warning: could not save repaired dataset: {e}")

    @staticmethod
    def _open_image(path) -> Image.Image:
        """Open an image and decode it immediately.

        Decoding eagerly surfaces read errors here rather than at first use.
        It also lets Pillow release a path-opened file once its data is loaded.
        """
        img = Image.open(path)
        img.load()
        return img

    def load_object_views(self, uid: str) -> tuple:
        """Load every view image and NOCS map recorded for object ``uid``.

        Returns:
            ``(images, nocs_maps)``: two lists of ``PIL.Image`` in view order.
            Paths are paired with ``zip``, so the shorter list bounds both.

        Raises:
            ValueError: if ``uid`` is not in the dataset.
        """
        if uid not in self.dataset:
            raise ValueError(f"Object {uid} not found in dataset")

        entry = self.dataset[uid]
        images = []
        nocs_maps = []
        for img_path, nocs_path in zip(entry['images'], entry['nocs_maps']):
            images.append(self._open_image(img_path))
            nocs_maps.append(self._open_image(nocs_path))

        return images, nocs_maps

    def get_random_object(self) -> tuple:
        """Pick a random object from the dataset.

        Returns:
            ``(uid, views)``, where ``views`` is the ``(images, nocs_maps)``
            tuple returned by ``load_object_views``.

        Raises:
            ValueError: if the dataset is empty.
        """
        if not self.dataset:
            raise ValueError("No objects in dataset")

        uid = random.choice(list(self.dataset.keys()))
        views = self.load_object_views(uid)
        return uid, views

    @staticmethod
    def _show_pose_result(uid, img, nocs, vis_img):
        """Display image, NOCS map and pose overlay side by side, then close the figure."""
        fig = plt.figure(figsize=(15, 5))
        panels = ((img, 'Original Image'), (nocs, 'NOCS Map'), (vis_img, 'Pose Estimation'))
        for position, (panel, title) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(panel)
            plt.title(title)
            plt.axis('off')
        plt.suptitle(f'Object {uid} - Pose Estimation Results')
        plt.show()
        plt.close(fig)

    def estimate_poses(self):
        """Estimate and store a pose for every view of every object.

        Each view's image and NOCS map go to ``PoseEstimator.estimate_pose``.
        A successful estimate is stored as ``{'rotation': ..., 'translation': ...}``
        (nested lists from ``.tolist()``). A failed estimate is stored as
        ``None``, so ``poses[i]`` always belongs to view ``i``. The results go
        to ``self.dataset[uid]['poses']`` and are persisted with
        ``_save_dataset``. Visualisation is best-effort: a plotting failure is
        reported but neither drops nor duplicates the stored pose.
        """
        pose_estimator = PoseEstimator(self.image_size)

        for uid in self.dataset.keys():
            try:
                images, nocs_maps = self.load_object_views(uid)
                poses = []

                for img, nocs in zip(images, nocs_maps):
                    try:
                        img_array = np.array(img)
                        nocs_array = np.array(nocs)
                        rotation, translation = pose_estimator.estimate_pose(img_array, nocs_array)
                    except Exception as e:
                        print(f"Error estimating pose for view: {str(e)}")
                        poses.append(None)
                        continue

                    poses.append({
                        'rotation': rotation.tolist(),
                        'translation': translation.tolist()
                    })

                    try:
                        vis_img = pose_estimator.visualize_pose_estimation(
                            img_array, nocs_array, rotation, translation)
                        self._show_pose_result(uid, img, nocs, vis_img)
                    except Exception as e:
                        print(f"Error visualizing pose for view: {str(e)}")

                self.dataset[uid]['poses'] = poses
                self._save_dataset()

            except Exception as e:
                print(f"Error processing object {uid}: {str(e)}")
                continue

    def generate_novel_views(self, num_views=8, num_epochs=100):
        """Train a ``MultiViewTrainer`` on this dataset and plot novel views per object.

        Args:
            num_views: number of target poses, evenly spaced around the object.
            num_epochs: training epochs passed to ``MultiViewTrainer.train``
                (defaults to the previously hard-coded 100).
        """
        trainer = MultiViewTrainer(self)
        print("Training multi-view generation model...")
        trainer.train(num_epochs=num_epochs)

        print("\nGenerating novel views...")
        for uid in self.dataset:
            try:
                # The first stored view is the source image.
                images, _ = self.load_object_views(uid)
                source_image = images[0]

                novel_views = []
                for i in range(num_views):
                    angle = (i * 2 * np.pi) / num_views
                    # 6-D target pose: (cos, sin, 0) labelled "rotation" and a fixed
                    # (0, 0, 2) "translation"; the encoding expected by
                    # MultiViewTrainer is not visible here -- TODO confirm.
                    target_pose = np.array([
                        np.cos(angle), np.sin(angle), 0,
                        0, 0, 2
                    ])

                    novel_view = trainer.generate_novel_view(source_image, target_pose)
                    novel_views.append(novel_view)

                fig = plt.figure(figsize=(20, 4))
                for i, view in enumerate(novel_views):
                    plt.subplot(1, num_views, i + 1)
                    plt.imshow(view)
                    plt.axis('off')
                    plt.title(f'View {i+1}')
                plt.suptitle(f'Novel Views for Object {uid}')
                plt.show()
                plt.close(fig)

            except Exception as e:
                print(f"Error generating novel views for object {uid}: {str(e)}")
                continue

def visualize_object_data(images, nocs_maps, uid):
    """Plot an object's views in one figure: RGB images on top, NOCS maps below.

    Args:
        images: indexable sequence of images, one per view; its length sets
            the number of columns.
        nocs_maps: NOCS maps indexed in parallel with ``images``.
        uid: object identifier shown in the figure title.
    """
    columns = len(images)
    plt.figure(figsize=(20, 8))

    # Row 0 holds the RGB views, row 1 the matching NOCS maps.
    rows = ((images, 'View'), (nocs_maps, 'NOCS'))
    for row_index, (panels, label) in enumerate(rows):
        for col in range(columns):
            plt.subplot(2, columns, row_index * columns + col + 1)
            plt.imshow(panels[col])
            plt.axis('off')
            plt.title(f'{label} {col+1}')

    plt.suptitle(f'Object ID: {uid}')
    plt.tight_layout()
    plt.show()

def visualize_views(images, uid):
    """Show every view of one object side by side in a single row.

    Args:
        images: sized iterable of images, one subplot each.
        uid: object identifier shown in the figure title.
    """
    total = len(images)
    plt.figure(figsize=(20, 4))
    for position, picture in enumerate(images, start=1):
        plt.subplot(1, total, position)
        plt.imshow(picture)
        plt.axis('off')
        plt.title(f'View {position}')
    plt.suptitle(f'Object ID: {uid}')
    plt.show()

# Test the implementation
def setup_and_test():
    """Create a SimpleDataLoader, print a summary of each object and plot its views.

    A failure for one object is reported, together with that object's stored
    metadata, and skipped. Any other failure prints a traceback.

    Returns:
        The initialised SimpleDataLoader, or None when the dataset is empty or
        setup raised.
    """
    try:
        print("Creating data loader...")
        loader = SimpleDataLoader(image_size=512)
        entries = loader.dataset

        if not entries:
            print("Error: dataset is empty")
            return None

        print(f"Dataset initialized with {len(entries)} objects")

        object_ids = list(entries.keys())
        print(f"Found objects with IDs: {object_ids}")

        for object_id in object_ids:
            try:
                print(f"\nLoading object {object_id}...")

                # Report the stored metadata before touching any files.
                record = entries[object_id]
                print(f"Category: {record['category']}")
                print(f"Number of images: {len(record['images'])}")
                print(f"Number of NOCS maps: {len(record['nocs_maps'])}")

                images, nocs_maps = loader.load_object_views(object_id)
                print(f"Loaded {len(images)} images and {len(nocs_maps)} NOCS maps")

                visualize_object_data(images, nocs_maps, object_id)
            except Exception as e:
                print(f"Error processing object {object_id}: {str(e)}")
                print("Object data:", entries[object_id])

        print("Success!")
        return loader

    except Exception as e:
        print(f"Error during setup: {str(e)}")
        import traceback
        traceback.print_exc()
        return None



# Main execution
# if __name__ == "__main__":
#     print("Starting test...")

#     # Clear the matplotlib plots
#     plt.close('all')

#     # Run setup and test with error handling
#     try:
#         data_dir = Path('./test_data')
#         if data_dir.exists():
#             shutil.rmtree(data_dir)
#         data_loader = setup_and_test()

#         # if data_loader is None:
#         #     print("Failed to initialize data loader")
#         # else:
#         #     print("\nData loader initialized successfully")
#         #     print(f"Number of objects: {len(data_loader.dataset)}")

#         #     # Try to show one object
#         #     if len(data_loader.dataset) > 0:
#         #         uid = list(data_loader.dataset.keys())[0]
#         #         print(f"\nTrying to show first object ({uid})...")
#         #         images, nocs_maps = data_loader.load_object_views(uid)
#         #         visualize_object_data(images, nocs_maps, uid)

#     except Exception as e:
#         print(f"Error in main execution: {str(e)}")
#         import traceback
#         traceback.print_exc()

In [4]:
import numpy as np
import cv2
from scipy.spatial.transform import Rotation
from typing import Tuple, List

class PoseEstimator:
    """Estimates object pose from an RGB image and its NOCS map via PnP.

    A pinhole camera is assumed, with focal length ``image_size`` pixels and
    the principal point at the image centre.
    """

    def __init__(self, image_size: int = 512):
        """
        Args:
            image_size: side length of the (square) input images in pixels.
        """
        self.image_size = image_size
        # Define camera intrinsics (can be adjusted based on your needs)
        self.focal_length = image_size
        self.camera_matrix = np.array([
            [self.focal_length, 0, image_size/2],
            [0, self.focal_length, image_size/2],
            [0, 0, 1]
        ])

    @staticmethod
    def _nocs_lookup(pixels, nocs_map: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Map integer ``(x, y)`` pixels to NOCS object coordinates in [-1, 1].

        Pixels outside ``nocs_map`` or on its all-zero background are dropped.

        Args:
            pixels: iterable of integer ``(x, y)`` pixel coordinates.
            nocs_map: H x W x C array with 0-255 NOCS values in the first
                three channels.

        Returns:
            ``(points_2d, points_3d)`` as float32 arrays of shapes (N, 2) and (N, 3).
        """
        height, width = nocs_map.shape[:2]
        points_2d = []
        points_3d = []
        for x, y in pixels:
            if not (0 <= x < width and 0 <= y < height):
                continue
            nocs_point = nocs_map[y, x] / 255.0  # Normalize to [0,1]
            # Check if point is valid (not background)
            if np.any(nocs_point > 0):
                points_2d.append([x, y])
                points_3d.append(nocs_point[:3] * 2 - 1)  # [0,1] -> [-1,1]

        return (np.array(points_2d, dtype=np.float32).reshape(-1, 2),
                np.array(points_3d, dtype=np.float32).reshape(-1, 3))

    def extract_correspondences(self,
                                rgb_image: np.ndarray,
                                nocs_map: np.ndarray,
                                min_points: int = 100) -> Tuple[np.ndarray, np.ndarray]:
        """
        Pair ORB keypoints in the RGB image with 3-D points read from the NOCS map.

        Args:
            rgb_image: image array (PIL images are converted).
            nocs_map: NOCS map aligned pixel-for-pixel with ``rgb_image``.
            min_points: minimum number of correspondences required.

        Returns:
            ``(points_2d, points_3d)`` float32 arrays of shapes (N, 2) and (N, 3).

        Raises:
            ValueError: if fewer than ``min_points`` correspondences are found.
        """
        # Convert images to numpy if they're PIL
        if not isinstance(rgb_image, np.ndarray):
            rgb_image = np.array(rgb_image)
        if not isinstance(nocs_map, np.ndarray):
            nocs_map = np.array(nocs_map)

        # Find features in RGB image
        orb = cv2.ORB_create(nfeatures=1000)
        keypoints = orb.detect(rgb_image, None)
        pixels = [tuple(map(int, kp.pt)) for kp in keypoints]

        points_2d, points_3d = self._nocs_lookup(pixels, nocs_map)

        if len(points_2d) < min_points:
            raise ValueError(f"Not enough correspondences found: {len(points_2d)} < {min_points}")

        return points_2d, points_3d

    def estimate_pose(self,
                      rgb_image: np.ndarray,
                      nocs_map: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Estimate the object pose from an RGB image and its NOCS map.

        Returns:
            ``(rotation_matrix, translation_vector)`` from ``cv2.solvePnP``
            (iterative method) followed by ``cv2.Rodrigues``.

        Raises:
            ValueError: if too few correspondences are found.
            RuntimeError: if ``solvePnP`` reports failure.
        """
        points_2d, points_3d = self.extract_correspondences(rgb_image, nocs_map)

        success, rvec, tvec = cv2.solvePnP(
            points_3d,
            points_2d,
            self.camera_matrix,
            None,
            flags=cv2.SOLVEPNP_ITERATIVE
        )

        if not success:
            raise RuntimeError("Failed to estimate pose")

        # Convert rotation vector to matrix
        rmat, _ = cv2.Rodrigues(rvec)

        return rmat, tvec

    def visualize_pose_estimation(self,
                                  rgb_image: np.ndarray,
                                  nocs_map: np.ndarray,
                                  rotation: np.ndarray,
                                  translation: np.ndarray,
                                  axis_length: float = 1.0) -> np.ndarray:
        """
        Draw the estimated object axes onto a copy of ``rgb_image``.

        Args:
            rgb_image: RGB image array (PIL images are converted).
            nocs_map: unused; kept for interface compatibility.
            rotation: rotation as returned by ``estimate_pose``; converted
                with ``cv2.Rodrigues`` before projection.
            translation: translation as returned by ``estimate_pose``.
            axis_length: axis length in object (NOCS) units. Correspondences
                live in [-1, 1], so the default 1.0 keeps the axes on the
                object's scale; the old fixed length of 100 projected far off-image.

        Returns:
            The annotated RGB image: X axis red, Y axis green, Z axis blue.
        """
        if not isinstance(rgb_image, np.ndarray):
            rgb_image = np.array(rgb_image)

        vis_img = rgb_image.copy()

        axis_points = np.float32([[0,0,0], [1,0,0], [0,1,0], [0,0,1]]) * axis_length

        # Project 3D axis points to 2D
        axis_2d, _ = cv2.projectPoints(
            axis_points,
            cv2.Rodrigues(rotation)[0],
            translation,
            self.camera_matrix,
            None
        )

        origin = tuple(map(int, axis_2d[0].ravel()))
        # The image is RGB (from PIL), so colours are RGB triples, not OpenCV's BGR.
        axis_colors = ((255, 0, 0), (0, 255, 0), (0, 0, 255))  # X red, Y green, Z blue
        for endpoint, color in zip(axis_2d[1:], axis_colors):
            vis_img = cv2.line(vis_img, origin, tuple(map(int, endpoint.ravel())), color, 3)

        return vis_img

# Standalone copy of SimpleDataLoader.estimate_poses (the class already defines this method); takes the loader as `self`.
def estimate_poses(self):
    """
    Estimate poses for all views of all objects.

    Standalone version of ``SimpleDataLoader.estimate_poses`` that takes the
    loader as ``self``. A successful estimate is stored as
    ``{'rotation': ..., 'translation': ...}`` (nested lists from ``.tolist()``).
    A failed estimate is stored as ``None``, so ``poses[i]`` always belongs to
    view ``i``. Visualisation is best-effort: a plotting failure is reported
    but neither drops nor duplicates the stored pose.
    """
    pose_estimator = PoseEstimator(self.image_size)

    def show_pose_result(uid, img, nocs, vis_img):
        # Image, NOCS map and pose overlay side by side; closing the figure
        # keeps repeated calls from accumulating open figures.
        fig = plt.figure(figsize=(15, 5))
        panels = ((img, 'Original Image'), (nocs, 'NOCS Map'), (vis_img, 'Pose Estimation'))
        for position, (panel, title) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(panel)
            plt.title(title)
            plt.axis('off')
        plt.suptitle(f'Object {uid} - Pose Estimation Results')
        plt.show()
        plt.close(fig)

    for uid in self.dataset.keys():
        try:
            images, nocs_maps = self.load_object_views(uid)
            poses = []

            for img, nocs in zip(images, nocs_maps):
                try:
                    img_array = np.array(img)
                    nocs_array = np.array(nocs)
                    rotation, translation = pose_estimator.estimate_pose(img_array, nocs_array)
                except Exception as e:
                    print(f"Error estimating pose for view: {str(e)}")
                    poses.append(None)
                    continue

                # Record the pose before any plotting so that a visualisation
                # failure cannot discard it or add a second entry for this view.
                poses.append({
                    'rotation': rotation.tolist(),
                    'translation': translation.tolist()
                })

                try:
                    vis_img = pose_estimator.visualize_pose_estimation(
                        img_array, nocs_array, rotation, translation)
                    show_pose_result(uid, img, nocs, vis_img)
                except Exception as e:
                    print(f"Error visualizing pose for view: {str(e)}")

            # Store poses in dataset
            self.dataset[uid]['poses'] = poses
            self._save_dataset()

        except Exception as e:
            print(f"Error processing object {uid}: {str(e)}")
            continue

# # Main execution
# if __name__ == "__main__":
#     print("Starting test...")

#     # Clear the matplotlib plots
#     plt.close('all')

#     # Run setup and test with error handling
#     try:
#         data_dir = Path('./test_data')
#         if data_dir.exists():
#             shutil.rmtree(data_dir)
#         data_loader = setup_and_test()

#         if data_loader is None:
#             print("Failed to initialize data loader")
#         else:
#             if data_loader is not None:
#               # Estimate poses for all objects
#               print("\nEstimating poses...")
#               data_loader.estimate_poses()

#               # You can access the poses through the dataset
#               for uid in data_loader.dataset:
#                   poses = data_loader.dataset[uid].get('poses', [])
#                   print(f"\nObject {uid} poses:")
#                   for i, pose in enumerate(poses):
#                       if pose:
#                           print(f"View {i}:")
#                           print(f"Rotation:\n{np.array(pose['rotation'])}")
#                           print(f"Translation:\n{np.array(pose['translation'])}")

#     except Exception as e:
#         print(f"Error in main execution: {str(e)}")
#         import traceback
#         traceback.print_exc()




In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from PIL import Image

# Select the compute device once for the whole cell: GPU when PyTorch can see one, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

class MultiViewGenerator(nn.Module):
    """Encoder/decoder that re-renders an input view from a requested target pose.

    The image is downsampled by five stride-2 convolutions to a ``latent_dim``-channel map.
    The 6-value target pose is embedded to ``latent_dim`` values, broadcast over that map and
    concatenated channel-wise. Five stride-2 transposed convolutions then upsample back to a
    3-channel image squashed into [-1, 1] by a final Tanh.

    Layer order fixes the ``nn.Sequential`` indices used as ``state_dict`` keys, so it must
    stay stable for saved checkpoints to keep loading.
    """

    def __init__(self, latent_dim=256):
        super().__init__()

        # Encoder: 3 -> 64 -> 128 -> 256 -> 512 -> latent_dim; each k4/s2/p1 conv halves H and W.
        enc = [nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.LeakyReLU(0.2)]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            enc += [
                nn.Conv2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        enc.append(nn.Conv2d(512, latent_dim, 4, stride=2, padding=1))
        self.encoder = nn.Sequential(*enc)

        # Pose MLP: 6 -> 32 -> 64 -> latent_dim.
        pose_layers = []
        for w_in, w_out in ((6, 32), (32, 64)):
            pose_layers += [nn.Linear(w_in, w_out), nn.ReLU()]
        pose_layers.append(nn.Linear(64, latent_dim))
        self.pose_encoder = nn.Sequential(*pose_layers)

        # Decoder: (image + pose) channels -> 512 -> 256 -> 128 -> 64 -> 3;
        # each k4/s2/p1 transposed conv doubles H and W.
        dec = []
        for c_in, c_out in ((latent_dim * 2, 512), (512, 256), (256, 128), (128, 64)):
            dec += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        dec += [nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1), nn.Tanh()]
        self.decoder = nn.Sequential(*dec)

    def forward(self, x, target_pose):
        """Re-render the image batch ``x`` at ``target_pose``.

        Args:
            x: image batch of shape (N, 3, H, W); H and W divisible by 32 round-trip exactly.
            target_pose: pose batch of shape (N, 6).

        Returns:
            Tensor of shape (N, 3, H, W) with values in [-1, 1].
        """
        feats = self.encoder(x)
        height, width = feats.size(2), feats.size(3)

        pose_code = self.pose_encoder(target_pose)
        # Broadcast the pose embedding over every spatial cell of the image code.
        pose_map = pose_code.view(-1, pose_code.size(1), 1, 1).expand(-1, -1, height, width)

        return self.decoder(torch.cat([feats, pose_map], dim=1))

class MultiViewDataset(Dataset):
    """All ordered (source, target) view pairs of each object that has estimated poses.

    For every uid in ``data_loader.dataset``, views whose pose is ``None`` are skipped. The
    remaining views are loaded once and paired with every other remaining view of the same
    object, so an object with n posed views contributes n * (n - 1) samples.

    Args:
        data_loader: object whose ``dataset`` maps uid -> dict holding an ``'images'`` list of
            image paths and an optional ``'poses'`` list aligned with it. Each pose is ``None``
            or a dict with ``'rotation'`` and ``'translation'`` entries.
        transform: optional callable applied to each (RGB) PIL image.
    """

    def __init__(self, data_loader, transform=None):
        self.data_loader = data_loader
        self.transform = transform
        self.samples = []
        self._prepare_dataset()

    @staticmethod
    def _pose_to_vector(pose_data):
        """Flatten a pose dict into a 6-value vector: first 3 rotation entries + translation."""
        rotation = np.asarray(pose_data['rotation'])
        translation = np.asarray(pose_data['translation'])
        return np.concatenate([rotation.flatten()[:3], translation.flatten()])

    def _load_image(self, img_path):
        """Open ``img_path`` as RGB, close the file, and apply ``self.transform`` if set."""
        # convert() returns an in-memory copy, so the source file can be closed right away.
        # Forcing RGB drops alpha/palette channels that would otherwise yield a 4-channel
        # tensor the 3-channel normalisation and encoder cannot handle.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

    def _prepare_dataset(self):
        """Populate ``self.samples`` with every ordered pair of posed views per object."""
        for uid in self.data_loader.dataset:
            obj_data = self.data_loader.dataset[uid]
            images = []
            poses = []

            for img_path, pose_data in zip(obj_data['images'], obj_data.get('poses', [])):
                if pose_data is None:
                    continue  # pose estimation failed for this view
                images.append(self._load_image(img_path))
                poses.append(self._pose_to_vector(pose_data))

            for i in range(len(images)):
                for j in range(len(images)):
                    if i != j:
                        self.samples.append({
                            'source_image': images[i],
                            'source_pose': poses[i],
                            'target_pose': poses[j],
                            'target_image': images[j]
                        })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one pair with poses as float32 tensors."""
        sample = self.samples[idx]
        return {
            'source_image': sample['source_image'],
            'source_pose': torch.as_tensor(sample['source_pose'], dtype=torch.float32),
            'target_pose': torch.as_tensor(sample['target_pose'], dtype=torch.float32),
            'target_image': sample['target_image']
        }

class MultiViewTrainer:
    """Train a MultiViewGenerator on view pairs and render novel views with it.

    Args:
        data_loader: object whose ``dataset`` mapping is turned into source/target view
            pairs by ``MultiViewDataset``.
        device: torch device for the generator and the VGG loss network. ``None`` (the
            default) selects CUDA when available, otherwise CPU.

    Raises:
        ValueError: if ``data_loader`` yields no training pairs.
    """

    def __init__(self, data_loader, device=None):
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.data_loader = data_loader
        self.device = device

        # Initialize model
        self.model = MultiViewGenerator().to(device)

        # Resize views to 512x512 and map pixels to [-1, 1], the range of the generator's Tanh.
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Initialize dataset
        print("Initializing dataset...")
        self.dataset = MultiViewDataset(data_loader, self.transform)
        print(f"Dataset size: {len(self.dataset)} pairs")
        if len(self.dataset) == 0:
            # DataLoader(shuffle=True) would otherwise fail with an opaque sampler error.
            raise ValueError(
                "No training pairs: at least one object needs two or more views with "
                "estimated poses (run estimate_poses() first)."
            )

        self.train_loader = DataLoader(
            self.dataset,
            batch_size=4,
            shuffle=True,
            num_workers=0  # in-process loading avoids multiprocessing issues in notebooks
        )

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0002)

        # Loss functions
        self.reconstruction_loss = nn.L1Loss()
        self.perceptual_loss = self._get_perceptual_loss()

    def _get_perceptual_loss(self):
        """Return the first 16 layers of an ImageNet VGG16 as a frozen feature extractor."""
        weights_enum = getattr(torchvision.models, 'VGG16_Weights', None)
        if weights_enum is not None:
            vgg = torchvision.models.vgg16(weights=weights_enum.DEFAULT)
        else:  # torchvision < 0.13 only understands the deprecated `pretrained=` flag
            vgg = torchvision.models.vgg16(pretrained=True)
        vgg = vgg.features[:16].to(self.device).eval()
        for param in vgg.parameters():
            param.requires_grad = False
        return vgg

    def train(self, num_epochs=100):
        """Train with L1 + 0.1 * perceptual (L1 on VGG features) loss for ``num_epochs`` epochs.

        A sample batch is visualised every 10 epochs. Returns immediately when the loader
        has no batches.
        """
        print("Starting training...")
        num_batches = len(self.train_loader)
        if num_batches == 0:
            print("No training batches available; skipping training.")
            return

        # Ensure BatchNorm layers update regardless of any earlier eval-mode call.
        self.model.train()
        for epoch in range(num_epochs):
            total_loss = 0.0
            for batch in tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                source_images = batch['source_image'].to(self.device)
                target_poses = batch['target_pose'].to(self.device)
                target_images = batch['target_image'].to(self.device)

                generated_images = self.model(source_images, target_poses)

                recon_loss = self.reconstruction_loss(generated_images, target_images)

                gen_features = self.perceptual_loss(generated_images)
                target_features = self.perceptual_loss(target_images)
                percep_loss = self.reconstruction_loss(gen_features, target_features)

                loss = recon_loss + 0.1 * percep_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()

            avg_loss = total_loss / num_batches
            print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

            if (epoch + 1) % 10 == 0:
                self.visualize_results()

    def visualize_results(self):
        """Plot source / generated / target rows for up to four pairs from one batch."""
        self.model.eval()
        try:
            with torch.no_grad():
                batch = next(iter(self.train_loader))
                source_images = batch['source_image'].to(self.device)
                target_poses = batch['target_pose'].to(self.device)
                target_images = batch['target_image'].to(self.device)

                generated_images = self.model(source_images, target_poses)

                plt.figure(figsize=(15, 5))
                for i in range(min(4, len(source_images))):
                    plt.subplot(3, 4, i + 1)
                    self._show_tensor_image(source_images[i])
                    plt.title('Source')

                    plt.subplot(3, 4, i + 5)
                    self._show_tensor_image(generated_images[i])
                    plt.title('Generated')

                    plt.subplot(3, 4, i + 9)
                    self._show_tensor_image(target_images[i])
                    plt.title('Target')

                plt.tight_layout()
                plt.show()
        finally:
            self.model.train()

    def _show_tensor_image(self, tensor):
        """Show a [-1, 1]-normalised (C, H, W) tensor on the current matplotlib axes."""
        img = tensor.detach().cpu() * 0.5 + 0.5
        # Clamp so matplotlib does not warn about (and clip) slight overshoots.
        plt.imshow(img.clamp(0, 1).permute(1, 2, 0).numpy())
        plt.axis('off')

    def generate_novel_view(self, source_image, target_pose):
        """Render ``source_image`` from ``target_pose`` and return a PIL image.

        Args:
            source_image: PIL image, numpy image array (anything ``Image.fromarray`` accepts),
                or an already-normalised tensor without a batch dimension.
            target_pose: pose vector; the generator's pose encoder takes 6 values.

        Returns:
            The generated view as a ``PIL.Image``. The model's train/eval mode is restored.
        """
        import numpy as np  # numpy is not imported by this cell

        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                if not isinstance(source_image, torch.Tensor):
                    if isinstance(source_image, np.ndarray):
                        # Resize/ToTensor in self.transform expect a PIL image, not an array.
                        source_image = Image.fromarray(source_image)
                    # Drop alpha/palette channels so the 3-channel encoder always gets RGB.
                    source_image = self.transform(source_image.convert('RGB'))
                source_image = source_image.unsqueeze(0).to(self.device)
                pose = torch.as_tensor(target_pose, dtype=torch.float32).unsqueeze(0).to(self.device)

                generated = self.model(source_image, pose).cpu().squeeze(0)
                # Undo the [-1, 1] normalisation; clamp guards ToPILImage against overshoot.
                generated = (generated * 0.5 + 0.5).clamp(0, 1)
                return transforms.ToPILImage()(generated)
        finally:
            self.model.train(was_training)

def generate_novel_views(data_loader, num_views=8, num_epochs=100):
    """Train a multi-view generator on ``data_loader`` and plot novel views of every object.

    Target poses are evenly spaced on a circle: the pose vector is
    ``(cos a, sin a, 0, 0, 0, 2)`` for ``a = 2*pi*i/num_views``.

    Args:
        data_loader: provides ``dataset`` (iterable of uids) and ``load_object_views(uid)``.
        num_views: number of target poses rendered per object; must be >= 1.
        num_epochs: training epochs run before rendering.

    Raises:
        ValueError: if ``num_views`` < 1 (checked before any training starts).
    """
    if num_views < 1:
        raise ValueError(f"num_views must be >= 1, got {num_views}")

    print("Initializing trainer...")
    trainer = MultiViewTrainer(data_loader)

    print("Training multi-view generation model...")
    trainer.train(num_epochs=num_epochs)

    print("\nGenerating novel views...")
    for uid in data_loader.dataset:
        try:
            images, _ = data_loader.load_object_views(uid)
            if len(images) == 0:
                print(f"No views available for object {uid}; skipping.")
                continue
            source_image = images[0]

            novel_views = []
            for i in range(num_views):
                angle = (i * 2 * np.pi) / num_views
                target_pose = np.array([
                    np.cos(angle), np.sin(angle), 0,
                    0, 0, 2
                ])
                novel_views.append(trainer.generate_novel_view(source_image, target_pose))

            # squeeze=False keeps a 2-D axes array even when num_views == 1.
            fig, axes = plt.subplots(1, num_views, figsize=(20, 4), squeeze=False)
            for i, (ax, view) in enumerate(zip(axes[0], novel_views)):
                ax.imshow(view)
                ax.axis('off')
                ax.set_title(f'View {i+1}')
            fig.suptitle(f'Novel Views for Object {uid}')
            plt.show()
            plt.close(fig)  # avoid accumulating open figures across objects

        except Exception as e:
            print(f"Error generating novel views for object {uid}: {str(e)}")
            continue

# Test the implementation
# if __name__ == "__main__":
#     # First make sure we have pose estimates
#     data_loader = setup_and_test()
#     data_loader.estimate_poses()

#     # Generate novel views
#     print("\nGenerating novel views...")
#     generate_novel_views(data_loader, num_views=8)

Using device: cuda


In [6]:
!pip install tqdm
# First, install dependencies
!pip install fvcore iopath
!pip install 'git+https://github.com/facebookresearch/pytorch3d.git'
!pip install PyMCubes

Collecting fvcore
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting iopath
  Downloading iopath-0.1.10.tar.gz (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting yacs>=0.1.6 (from fvcore)
  Downloading yacs-0.1.8-py3-none-any.whl.metadata (639 bytes)
Collecting portalocker (from iopath)
  Downloading portalocker-2.10.1-py3-none-any.whl.metadata (8.5 kB)
Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Downloading portalocker-2.10.1-py3-none-any.whl (18 kB)
Building wheels for collected packages: fvcore, iopath
  Building wheel for fvcore (setup.py) ... [?25l[?25hdone
  Created wheel for fvcore: filename=fvcore-0.1.5.post20221221-py3-none-any.whl size

In [1]:
# First, install required packages
!pip install scikit-image

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from skimage import measure
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import torchvision.transforms as transforms

class SimpleReconstruction3D(nn.Module):
    """Map a set of views of one object to a voxel occupancy grid.

    Each view goes through a shared 4-layer conv encoder. Per-view features are max-pooled
    across views and an MLP decodes a ``voxel_size``**3 grid with values in [0, 1] (Sigmoid).

    Args:
        image_size: side length of the square input views; must be >= 16.
        voxel_size: side length of the output cube.
        hidden_dim: width of the fusion MLP's hidden layer.

    Raises:
        ValueError: if ``image_size`` < 16 (the encoder would reduce it to nothing).
    """

    def __init__(self, image_size=512, voxel_size=64, hidden_dim=2048):
        super().__init__()
        self.image_size = image_size
        self.voxel_size = voxel_size

        # Feature extraction
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.ReLU()
        )

        # Each k4/s2/p1 conv maps side n -> n // 2, so four of them give image_size // 16.
        # (Previously hard-coded to 32, which only matched image_size=512.)
        feat_side = image_size // 16
        if feat_side < 1:
            raise ValueError(f"image_size must be at least 16, got {image_size}")

        # Feature fusion
        # NOTE(review): with the defaults this MLP is very large (~1.6B parameters);
        # consider a smaller image_size or hidden_dim if memory is tight.
        self.fusion = nn.Sequential(
            nn.Linear(512 * feat_side * feat_side, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, voxel_size * voxel_size * voxel_size),
            nn.Sigmoid()  # occupancy values in [0, 1]
        )

    def forward(self, x):
        """Args: x of shape (batch, num_views, 3, image_size, image_size).

        Returns: volume of shape (batch, 1, voxel_size, voxel_size, voxel_size) in [0, 1].
        """
        batch_size, num_views = x.shape[:2]

        # Fold views into the batch dimension so the encoder sees every view at once;
        # reshape (not view) also accepts non-contiguous input.
        flat = x.reshape(batch_size * num_views, 3, self.image_size, self.image_size)
        features = self.encoder(flat).reshape(batch_size, num_views, -1)

        # Order-independent fusion: max over the view axis.
        pooled = features.max(dim=1).values

        volume = self.fusion(pooled)
        return volume.view(batch_size, 1, self.voxel_size, self.voxel_size, self.voxel_size)


class Reconstruction3DTrainer:
    """Hold a SimpleReconstruction3D model plus the image preprocessing it expects.

    NOTE(review): nothing in this cell trains ``self.model`` (the optimizer is created but
    never stepped here), so ``generate_3d_volume`` runs a randomly initialised network.
    """

    def __init__(self, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        self.device = device
        self.model = SimpleReconstruction3D().to(device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001)
        # 512x512 matches SimpleReconstruction3D's default image_size; pixels map to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def preprocess_images(self, images):
        """Stack a sequence of views (PIL images or numpy arrays) into one (V, 3, H, W) tensor."""
        from PIL import Image  # this cell does not import PIL itself

        processed = []
        for img in images:
            if isinstance(img, np.ndarray):
                img = Image.fromarray(img)
            # Force RGB so RGBA/palette views still give the 3 channels the encoder expects.
            processed.append(self.transform(img.convert('RGB')))
        return torch.stack(processed)

    def generate_3d_volume(self, images):
        """Run the model on all views of one object and return the volume as a numpy array.

        Returns:
            The model output with singleton dimensions squeezed away.
        """
        self.model.eval()
        with torch.no_grad():
            # Add the batch dimension: (1, V, 3, H, W).
            images_tensor = self.preprocess_images(images).unsqueeze(0).to(self.device)

            volume = self.model(images_tensor)

            # SimpleReconstruction3D already ends in nn.Sigmoid, so values are in [0, 1].
            # A second sigmoid would squash them into [0.5, 0.73], and the default 0.5
            # iso-level used for meshing would never be crossed.
            return volume.squeeze().cpu().numpy()

def volume_to_mesh(volume, threshold=0.5):
    """Convert a scalar volume to a triangle mesh with marching cubes.

    Args:
        volume: 3-D numpy array or torch tensor; singleton dimensions are squeezed away.
        threshold: iso-level. Volumes outside [0, 1] are min-max normalised first. If the
            volume does not straddle ``threshold``, the mid-range value is used instead.

    Returns:
        ``(verts, faces)`` from ``skimage.measure.marching_cubes``, or ``(None, None)`` when
        the volume is constant or both marching-cubes attempts fail.
    """
    if isinstance(volume, torch.Tensor):
        volume = volume.detach().cpu().numpy()
    volume = np.asarray(volume).squeeze()

    # Print volume statistics for debugging
    print(f"Volume min: {volume.min()}, max: {volume.max()}, mean: {volume.mean()}")

    vmin, vmax = volume.min(), volume.max()
    if vmin == vmax:
        # A constant field has no iso-surface, and the min-max normalisation below
        # would divide by zero.
        print("Volume is constant; no surface to extract.")
        return None, None

    # Normalize volume to [0, 1] if needed
    if vmin < 0 or vmax > 1:
        volume = (volume - vmin) / (vmax - vmin)

    # marching_cubes rejects an iso-level outside the data range.
    if volume.max() <= threshold or volume.min() >= threshold:
        print("Warning: Volume values don't cross threshold. Adjusting threshold...")
        threshold = (volume.max() + volume.min()) / 2

    try:
        verts, faces, _normals, _values = measure.marching_cubes(volume, threshold)
        return verts, faces
    except Exception as e:
        print(f"Error in marching cubes: {e}")

    # Fall back to the mean value as iso-level.
    new_threshold = volume.mean()
    print(f"Retrying with threshold = {new_threshold}")
    try:
        verts, faces, _normals, _values = measure.marching_cubes(volume, new_threshold)
        return verts, faces
    except Exception as e:
        print(f"Second attempt failed: {e}")
        return None, None

def plot_volume_slices(volume, num_slices=3):
    """Show ``num_slices`` evenly spaced slices along the first axis of a 3-D volume.

    Args:
        volume: 3-D numpy array or torch tensor (singleton dimensions are squeezed away).
        num_slices: number of panels; must be >= 1.

    Raises:
        ValueError: if ``num_slices`` < 1.
    """
    if num_slices < 1:
        raise ValueError(f"num_slices must be >= 1, got {num_slices}")
    if isinstance(volume, torch.Tensor):
        volume = volume.detach().cpu().numpy()
    volume = np.asarray(volume).squeeze()

    # squeeze=False keeps a 2-D axes array, so a single panel is still indexable.
    fig, axes = plt.subplots(1, num_slices, figsize=(15, 5), squeeze=False)
    step = volume.shape[0] // (num_slices + 1)
    for i, ax in enumerate(axes[0]):
        slice_idx = step * (i + 1)
        ax.imshow(volume[slice_idx], cmap='viridis')
        ax.set_title(f'Slice {slice_idx}')
        ax.axis('off')

    fig.suptitle('Volume Slices')
    plt.show()

def reconstruct_3d_shape(data_loader, object_id, trainer=None, return_volume=False):
    """Reconstruct a mesh for one object from its views.

    Args:
        data_loader: object with ``load_object_views(object_id)`` returning ``(images, _)``.
        object_id: key of the object to reconstruct.
        trainer: optional ``Reconstruction3DTrainer`` to reuse across objects. When ``None``
            a new one (and thus a new model) is built for this call.
        return_volume: when True, also return the voxel volume the mesh was extracted from.

    Returns:
        ``(verts, faces)``, or ``(verts, faces, volume)`` when ``return_volume`` is True.
        Errors are printed and reported as ``None`` entries.
    """
    verts, faces, volume = None, None, None
    try:
        images, _ = data_loader.load_object_views(object_id)

        if trainer is None:
            trainer = Reconstruction3DTrainer()

        volume = trainer.generate_3d_volume(images)
        verts, faces = volume_to_mesh(volume)

    except Exception as e:
        print(f"Error reconstructing shape: {str(e)}")
        verts, faces = None, None

    if return_volume:
        return verts, faces, volume
    return verts, faces


def process_object(data_loader, object_id):
    """Reconstruct one object, plot its mesh, and plot slices of its volume."""
    print(f"\nProcessing object {object_id}...")

    try:
        # Take the volume straight from the reconstruction: nothing in this cell stores a
        # 'volume' entry in data_loader.dataset for later lookup.
        verts, faces, volume = reconstruct_3d_shape(data_loader, object_id, return_volume=True)

        if verts is not None and faces is not None:
            # NOTE(review): plot_3d_mesh is not defined in this cell; it must come from an
            # earlier cell or this raises NameError (reported below).
            plot_3d_mesh(verts, faces)
            print("Reconstruction successful!")
        else:
            print("Reconstruction failed")

        # Volume slices help debug failed reconstructions too, so show them whenever available.
        if volume is not None:
            plot_volume_slices(volume)
    except Exception as e:
        print(f"Error processing object {object_id}: {str(e)}")
        import traceback
        traceback.print_exc()

# Run the reconstruction on the first few objects
if __name__ == "__main__":
    print("Starting 3D reconstruction...")

    NUM_OBJECTS_TO_PROCESS = 3  # objects beyond this are skipped; fewer is fine

    # NOTE(review): setup_and_test() is not defined in this cell (the saved output below is a
    # NameError on a fresh kernel); run the cells that define it before this one.
    data_loader = setup_and_test()
    if data_loader is None:
        print("Failed to initialize data loader")
    else:
        data_loader.estimate_poses()

        # Slicing instead of fixed [0], [1], [2] indices avoids IndexError with < 3 objects.
        for object_id in list(data_loader.dataset.keys())[:NUM_OBJECTS_TO_PROCESS]:
            process_object(data_loader, object_id)

Starting 3D reconstruction...


NameError: name 'setup_and_test' is not defined