# RingGen: Practical Text-to-3D Ring Generation Pipeline

This notebook provides a streamlined, practical implementation of the RingGen pipeline for training and generating 3D ring models on Google Colab with GPU acceleration. It focuses on:

1. Setting up the environment efficiently
2. Accessing training data from Google Drive
3. Training Shap-E and CAP3D models with GPU acceleration
4. Generating high-quality 3D ring models from text prompts
5. Visualizing and exporting the results

## Step 1: Check GPU Availability and Set Up the Environment

In [None]:
# Check if GPU is available
!nvidia-smi

# Check Python version
!python --version

# Import torch and check CUDA availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"Number of CUDA devices: {torch.cuda.device_count()}")
    device = "cuda"
else:
    print("WARNING: No GPU detected. Training will be slow on CPU.")
    device = "cpu"

# Install required packages
!pip install torch trimesh numpy matplotlib tqdm requests plotly

## Step 2: Clone the Repository and Install Dependencies

Clone the RingGen repository and install all necessary dependencies including Shap-E.

In [None]:
# Clone the repository
!git clone https://github.com/abhyodaya1011/text-to-3d-pipeline.git
%cd text-to-3d-pipeline

# Install Shap-E and other dependencies
!pip install -e .
!python setup_shap_e.py

# Create necessary directories
!mkdir -p data/rings
!mkdir -p data/labeled_meshes
!mkdir -p data/latents
!mkdir -p shap_e_model_cache
!mkdir -p outputs/training
!mkdir -p outputs/generated

## Step 3: Mount Google Drive and Access Data

Mount Google Drive to access training data and save results.
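
Copying meshes out of a mounted Drive folder can also be done in plain Python rather than with shell `cp`. The sketch below is one possible approach, not part of the RingGen repository: the helper name `copy_obj_files` is hypothetical, and it assumes that OBJ files sharing a basename should be kept side by side (with a numeric suffix) instead of overwriting one another.

```python
import shutil
from pathlib import Path


def copy_obj_files(src_dir, dst_dir):
    """Copy every .obj file under src_dir into dst_dir (hypothetical helper).

    Files that share a basename get a numeric suffix instead of being overwritten.
    Returns the list of destination paths that were written.
    """
    src, dst = Path(src_dir), Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    copied = []
    # sorted() materializes the search first, so files copied below are not re-scanned
    for path in sorted(src.rglob("*.obj")):
        target = dst / path.name
        counter = 1
        while target.exists():
            target = dst / f"{path.stem}_{counter}{path.suffix}"
            counter += 1
        shutil.copy2(path, target)
        copied.append(target)
    return copied
```

In this notebook it could be called as `copy_obj_files("/content/drive/MyDrive/RingGen/data/rings", "data/rings")`. Because existing files are never overwritten, running it twice produces suffixed duplicates, so it is best run once per fresh session.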

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Path to your data on Google Drive
# Update these paths to match your Google Drive structure
DRIVE_DATA_PATH = "/content/drive/MyDrive/RingGen/data/rings"
DRIVE_MODEL_CACHE_PATH = "/content/drive/MyDrive/RingGen/shap_e_model_cache"
DRIVE_OUTPUTS_PATH = "/content/drive/MyDrive/RingGen/outputs"

# Create directories in Google Drive if they don't exist
!mkdir -p "/content/drive/MyDrive/RingGen/data/rings"
!mkdir -p "/content/drive/MyDrive/RingGen/shap_e_model_cache"
!mkdir -p "/content/drive/MyDrive/RingGen/outputs/generated"
!mkdir -p "/content/drive/MyDrive/RingGen/outputs/training"

# Check if the paths exist and copy data if available
import glob
import os
import shutil

# Function to find and copy ring data
def find_and_copy_ring_data():
    """Copy ring meshes from Google Drive into data/rings; return True if any were copied."""
    if os.path.isdir(DRIVE_DATA_PATH) and os.listdir(DRIVE_DATA_PATH):
        print(f"Copying ring data from {DRIVE_DATA_PATH}...")
        !cp -r $DRIVE_DATA_PATH/* data/rings/
        return True

    # Fall back to searching the whole Drive; this can be slow on a large Drive
    print("Searching for OBJ files in Google Drive...")
    ring_files = glob.glob("/content/drive/MyDrive/**/*.obj", recursive=True)
    print(f"Found {len(ring_files)} OBJ files")
    if not ring_files:
        return False

    for file in ring_files[:5]:  # Show first 5 files
        print(f"Found: {file}")

    # Ask before copying, since the search may pick up OBJ files that are not rings
    user_input = input(f"Copy all {len(ring_files)} files to the data directory? (y/n): ")
    if user_input.strip().lower() != 'y':
        return False

    for file in ring_files:
        # Note: files that share a basename overwrite each other in data/rings/
        shutil.copy(file, "data/rings/")
    print(f"Copied {len(ring_files)} files to data/rings/")
    return True

# Copy Shap-E model cache if available
if os.path.exists(DRIVE_MODEL_CACHE_PATH) and os.listdir(DRIVE_MODEL_CACHE_PATH):
    print(f"Copying Shap-E model cache from {DRIVE_MODEL_CACHE_PATH}...")
    !cp -r $DRIVE_MODEL_CACHE_PATH/* shap_e_model_cache/

# Copy existing outputs if available
if os.path.exists(DRIVE_OUTPUTS_PATH) and os.listdir(DRIVE_OUTPUTS_PATH):
    print(f"Copying existing outputs from {DRIVE_OUTPUTS_PATH}...")
    !cp -r $DRIVE_OUTPUTS_PATH/* outputs/

# Find and copy ring data
has_data = find_and_copy_ring_data()

if not has_data:
    print("\nNo ring data found. You can upload OBJ files to Google Drive and run this cell again.")
    print("Alternatively, you can use the sample prompts to generate rings without training.")

## Step 4: Training Pipeline

This section runs the repository's `train_with_labeled_data.py` script to train the Shap-E and CAP3D models on the OBJ files in `data/rings`, then copies the training outputs to Google Drive.
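
Outside a notebook, where the `!` shell syntax is unavailable, the same script can be launched with `subprocess`. The sketch below only assumes the command-line flags that this notebook itself passes (`--input`, `--output`, `--max-files`, `--epochs`, `--batch-size`, `--device`); the helper names `build_training_command` and `run_training` are hypothetical.

```python
import subprocess
import sys


def build_training_command(input_dir, output_dir, max_files, epochs, batch_size, device):
    """Build the argument list for train_with_labeled_data.py (hypothetical helper)."""
    for name, value in (("max_files", max_files), ("epochs", epochs), ("batch_size", batch_size)):
        if not isinstance(value, int) or value < 1:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")
    # Passing a list avoids shell quoting problems with paths that contain spaces
    return [
        sys.executable, "train_with_labeled_data.py",
        "--input", str(input_dir),
        "--output", str(output_dir),
        "--max-files", str(max_files),
        "--epochs", str(epochs),
        "--batch-size", str(batch_size),
        "--device", str(device),
    ]


def run_training(cmd):
    """Run the training command and raise CalledProcessError if the script fails."""
    return subprocess.run(cmd, check=True)
```

Validating the numbers before launching means a typo such as `epochs=0` fails immediately instead of after the script has started loading data.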

In [None]:
# Check if we have data for training
import os
import glob

ring_files = glob.glob("data/rings/*.obj")
print(f"Found {len(ring_files)} OBJ files for training")

# Ask if the user wants to run training
run_training = True
if len(ring_files) == 0:
    print("No training data available. Skipping training.")
    run_training = False
else:
    user_input = input("Do you want to run the training pipeline? (y/n): ")
    run_training = user_input.lower() == 'y'

if run_training:
    # Set training parameters
    max_files = 10  # Limit number of files for faster training
    epochs = 5      # Number of training epochs
    batch_size = 4  # Batch size for training
    
    # Allow user to customize parameters (press Enter to keep the current value)
    custom_params = input("Do you want to customize training parameters? (y/n): ")
    if custom_params.lower() == 'y':
        try:
            new_max_files = int(input(f"Max files to use for training (current: {max_files}): ") or max_files)
            new_epochs = int(input(f"Number of training epochs (current: {epochs}): ") or epochs)
            new_batch_size = int(input(f"Batch size (current: {batch_size}): ") or batch_size)
            # Apply the new values only once all three have parsed successfully
            max_files, epochs, batch_size = new_max_files, new_epochs, new_batch_size
        except ValueError:
            print("Invalid input. Using default values.")
    
    print("\n=== Starting Training Pipeline ===\n")
    print(f"Parameters: max_files={max_files}, epochs={epochs}, batch_size={batch_size}, device={device}")
    
    # Run the training script
    !python train_with_labeled_data.py \
        --input data/rings \
        --output outputs/training \
        --max-files {max_files} \
        --epochs {epochs} \
        --batch-size {batch_size} \
        --device {device}
    
    # Copy training outputs to Google Drive
    print("\nCopying training outputs to Google Drive...")
    !mkdir -p "/content/drive/MyDrive/RingGen/outputs/training"
    !cp -r outputs/training/* "/content/drive/MyDrive/RingGen/outputs/training/"
else:
    print("Skipping training phase. Will use existing models for generation if available.")

## Step 5: Generate 3D Ring Models

Generate 3D ring models from text prompts using the trained models or pre-trained models.
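
Choosing between a freshly trained checkpoint and one saved by an earlier session amounts to picking the first path that exists. The sketch below is a minimal illustration of that lookup; the helper name `first_existing` is hypothetical, and it makes no assumption about what `generate_rings.py` does when no checkpoint is found.

```python
from pathlib import Path


def first_existing(candidates):
    """Return the first candidate path that is an existing file, or None (hypothetical helper)."""
    for candidate in candidates:
        path = Path(candidate)
        if path.is_file():
            return path
    return None
```

For example, passing the local `shap_e_model.pt` path followed by its Google Drive copy returns whichever one is present, preferring the local file.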

In [None]:
# Check if we have models for generation
import os

shap_e_model = "outputs/training/shap_e_checkpoints/checkpoints/shap_e_model.pt"
cap3d_model = "outputs/training/cap3d_checkpoints/checkpoints/cap3d_model.pt"

models_exist = os.path.exists(shap_e_model) and os.path.exists(cap3d_model)

if not models_exist:
    # Fall back to checkpoints saved to Google Drive by an earlier session
    drive_shap_e = "/content/drive/MyDrive/RingGen/outputs/training/shap_e_checkpoints/checkpoints/shap_e_model.pt"
    drive_cap3d = "/content/drive/MyDrive/RingGen/outputs/training/cap3d_checkpoints/checkpoints/cap3d_model.pt"

    if os.path.exists(drive_shap_e) and os.path.exists(drive_cap3d):
        print("Trained models not found locally. Copying them from Google Drive...")
        !mkdir -p outputs/training/shap_e_checkpoints/checkpoints
        !mkdir -p outputs/training/cap3d_checkpoints/checkpoints
        !cp "$drive_shap_e" "$shap_e_model"
        !cp "$drive_cap3d" "$cap3d_model"
        models_exist = True
    else:
        print("Warning: Trained models not found. Will use default models for generation.")

# Set up prompts for generation
default_prompts = [
    "A classic solitaire engagement ring with a round diamond and thin gold band",
    "A vintage-inspired ring with three diamonds and intricate gallery details",
    "A modern minimalist ring with a princess cut diamond and platinum band",
    "An art deco style ring with emerald and diamond accents",
    "A men's wedding band with brushed titanium and carbon fiber inlay"
]

# Allow user to customize prompts
custom_prompts = input("Do you want to use custom prompts? (y/n): ")
if custom_prompts.lower() == 'y':
    prompts = []
    print("Enter your prompts (one per line, press Enter twice to finish):")
    while True:
        prompt = input()
        if not prompt:
            break
        prompts.append(prompt)
    
    if not prompts:  # If no prompts were entered
        print("No prompts entered. Using default prompts.")
        prompts = default_prompts
else:
    prompts = default_prompts

# Set number of samples per prompt
num_samples = 2
try:
    num_samples = int(input(f"Number of samples per prompt (current: {num_samples}): ") or num_samples)
except ValueError:
    print("Invalid input. Using default value.")

# Run generation
print("\n=== Starting Ring Generation ===\n")
print(f"Generating {len(prompts)} prompts with {num_samples} samples each")
for i, prompt in enumerate(prompts):
    print(f"Prompt {i+1}: {prompt}")

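# Generate rings using the standalone generate_rings.py script.
# Each prompt is shell-quoted, then the prompts are joined with commas into a single argument,
# so quotes or `$` characters typed into a custom prompt cannot break the command.
import shlex
prompt_arg = ",".join(shlex.quote(p) for p in prompts)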
!python generate_rings.py \
    --shap-e-model "$shap_e_model" \
    --cap3d-model "$cap3d_model" \
    --output outputs/generated \
    --prompts {prompt_arg} \
    --num-samples {num_samples} \
    --device {device}

# Copy generated outputs to Google Drive
print("\nCopying generated outputs to Google Drive...")
!mkdir -p "/content/drive/MyDrive/RingGen/outputs/generated"
!cp -r outputs/generated/* "/content/drive/MyDrive/RingGen/outputs/generated/"

## Step 6: Visualize Generated Rings

Visualize the generated ring models using both Matplotlib and Plotly for interactive 3D viewing.
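
Before plotting, it can be useful to confirm that an exported OBJ file actually contains geometry and to see how large it is. The following is a dependency-free sketch, not the method used by the pipeline: the helper name `obj_bounds` is hypothetical, and it reads only the `v` (vertex) and `f` (face) records of the OBJ text format.

```python
def obj_bounds(path):
    """Return (n_vertices, n_faces, min_xyz, max_xyz) for an OBJ file (hypothetical helper).

    min_xyz and max_xyz are None when the file contains no vertices.
    """
    lo = [float("inf")] * 3
    hi = [float("-inf")] * 3
    n_vertices = 0
    n_faces = 0
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            if parts[0] == "v" and len(parts) >= 4:
                # Vertex record: "v x y z" (an optional fourth value is ignored)
                xyz = [float(value) for value in parts[1:4]]
                lo = [min(a, b) for a, b in zip(lo, xyz)]
                hi = [max(a, b) for a, b in zip(hi, xyz)]
                n_vertices += 1
            elif parts[0] == "f":
                n_faces += 1
    if n_vertices == 0:
        return 0, n_faces, None, None
    return n_vertices, n_faces, lo, hi
```

A file that reports zero faces, or whose bounding box is flat along one axis, is likely to render as an empty or degenerate plot.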

In [None]:
import os
import glob
import json
import numpy as np
import trimesh
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def visualize_mesh_matplotlib(mesh_path, title):
    """Visualize a 3D mesh using Matplotlib."""
    try:
        # Load the file as a single mesh (a multi-object OBJ would otherwise load as a Scene)
        mesh = trimesh.load(mesh_path, force='mesh')

        # Create a figure and 3D axes
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')
        ax.set_title(title, fontsize=14)

        vertices = mesh.vertices
        faces = mesh.faces

        # Draw all faces in one call. Passing `triangles` reuses the mesh connectivity;
        # calling plot_trisurf once per face is very slow and fails on faces whose
        # vertices are collinear in the x-y plane (for example, vertical faces).
        ax.plot_trisurf(
            vertices[:, 0], vertices[:, 1], vertices[:, 2],
            triangles=faces, color='gold', alpha=0.8, shade=True
        )

        # Set equal aspect ratio
        ax.set_box_aspect([1, 1, 1])

        # Hide axis tick labels
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])

        # Center the view on the bounding box so no part of the mesh is clipped
        bounds = mesh.bounds
        center = bounds.mean(axis=0)
        max_range = np.max(bounds[1] - bounds[0])
        ax.set_xlim(center[0] - max_range/2, center[0] + max_range/2)
        ax.set_ylim(center[1] - max_range/2, center[1] + max_range/2)
        ax.set_zlim(center[2] - max_range/2, center[2] + max_range/2)

        plt.tight_layout()
        plt.show()
        return True
    except Exception as e:
        print(f"Error visualizing mesh: {e}")
        return False

# Find all generated mesh files
mesh_files = glob.glob("outputs/generated/*.obj")
print(f"Found {len(mesh_files)} generated mesh files")

# Visualize each mesh
for mesh_path in mesh_files:
    filename = os.path.splitext(os.path.basename(mesh_path))[0]
    
    # Try to get the prompt from the metadata file saved next to the mesh
    metadata_path = os.path.splitext(mesh_path)[0] + "_metadata.json"
    if os.path.exists(metadata_path):
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
            prompt = metadata.get("prompt", filename)
    else:
        prompt = filename
        
    print(f"Visualizing: {prompt}")
    visualize_mesh_matplotlib(mesh_path, prompt)

In [None]:
# Now let's visualize with Plotly for interactive 3D viewing
import plotly.graph_objects as go

def visualize_mesh_plotly(mesh_path, title):
    """Visualize a 3D mesh using Plotly for interactive viewing."""
    try:
        # Load the file as a single mesh (a multi-object OBJ would otherwise load as a Scene)
        mesh = trimesh.load(mesh_path, force='mesh')
        
        # Get vertices and faces
        vertices = mesh.vertices
        faces = mesh.faces
        
        # Create the plotly figure
        fig = go.Figure(data=[
            go.Mesh3d(
                x=vertices[:, 0],
                y=vertices[:, 1],
                z=vertices[:, 2],
                i=faces[:, 0],
                j=faces[:, 1],
                k=faces[:, 2],
                color='gold',
                opacity=0.9,
                flatshading=True
            )
        ])
        
        # Update layout
        fig.update_layout(
            title=title,
            scene=dict(
                xaxis=dict(showticklabels=False),
                yaxis=dict(showticklabels=False),
                zaxis=dict(showticklabels=False),
            ),
            width=700,
            height=700,
            margin=dict(l=0, r=0, t=40, b=0)
        )
        
        fig.show()
        return True
    except Exception as e:
        print(f"Error visualizing mesh with Plotly: {e}")
        return False

# Visualize each mesh with Plotly
for mesh_path in mesh_files:
    filename = os.path.splitext(os.path.basename(mesh_path))[0]
    
    # Try to get the prompt from the metadata file saved next to the mesh
    metadata_path = os.path.splitext(mesh_path)[0] + "_metadata.json"
    if os.path.exists(metadata_path):
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
            prompt = metadata.get("prompt", filename)
    else:
        prompt = filename
        
    print(f"Interactive visualization: {prompt}")
    visualize_mesh_plotly(mesh_path, prompt)

## Step 7: Analyze Ring Components and Metadata

Examine the component labels and metadata of the generated rings.
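
When many rings have been generated, a compact summary is easier to scan than one table per file. The sketch below is an illustration under stated assumptions: the helper name `summarize_metadata` is hypothetical, and it assumes only what this notebook reads from each `*_metadata.json` file, namely an optional `prompt` string and an optional `components` collection whose internal structure is not assumed.

```python
import json
from pathlib import Path


def summarize_metadata(directory):
    """Return one summary row per *_metadata.json file in directory (hypothetical helper)."""
    rows = []
    for path in sorted(Path(directory).glob("*_metadata.json")):
        row = {"file": path.name, "prompt": None, "num_components": None, "error": None}
        try:
            with open(path, "r", encoding="utf-8") as f:
                metadata = json.load(f)
        except (OSError, json.JSONDecodeError) as exc:
            # Record unreadable files instead of aborting the whole summary
            row["error"] = str(exc)
            rows.append(row)
            continue
        if isinstance(metadata, dict):
            components = metadata.get("components")
            row["prompt"] = metadata.get("prompt")
            row["num_components"] = len(components) if isinstance(components, (list, dict)) else 0
        else:
            row["error"] = "top-level JSON value is not an object"
        rows.append(row)
    return rows
```

The resulting list of dictionaries can be passed directly to `pd.DataFrame(...)` for display.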

In [None]:
import glob
import json
import os

import pandas as pd
from IPython.display import display

# Find the metadata JSON files saved alongside the generated meshes
component_files = glob.glob('outputs/generated/*_metadata.json')

if component_files:
    print(f"Found {len(component_files)} metadata files")
    for comp_file in component_files:
        with open(comp_file, 'r') as f:
            metadata = json.load(f)
        
        print(f"\nMetadata for {os.path.basename(comp_file)}:")
        print(f"Prompt: {metadata.get('prompt', 'N/A')}")
        
        if 'components' in metadata:
            print("Components:")
            df = pd.DataFrame(metadata['components'])
            display(df)
        else:
            print("No component data found in this file.")
else:
    print("No metadata files found.")

## Step 8: Create Downloadable Archive

Create a zip file of all outputs for easy download.
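
The archive can also be built from Python with the standard library, which avoids depending on the `zip` command-line tool. This is a minimal sketch; the helper name `archive_outputs` is hypothetical, and it assumes that an empty output folder should raise an error rather than produce an empty archive.

```python
import shutil
from pathlib import Path


def archive_outputs(source_dir, archive_stem):
    """Zip source_dir into '<archive_stem>.zip' and return the archive path (hypothetical helper)."""
    source = Path(source_dir).resolve()
    if not source.is_dir() or not any(source.iterdir()):
        raise FileNotFoundError(f"Nothing to archive in {source}")
    # root_dir/base_dir keep the top-level folder name (e.g. 'outputs/') inside the archive
    return shutil.make_archive(
        str(archive_stem), "zip", root_dir=str(source.parent), base_dir=source.name
    )
```

Calling `archive_outputs("outputs", "ringgen_outputs")` writes `ringgen_outputs.zip` to the working directory, and the returned path can then be passed to `files.download`.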

In [None]:
# Create a zip file for download
print("Creating zip file of outputs...")
!zip -r ringgen_outputs.zip outputs/

# Download the outputs
from google.colab import files
files.download('ringgen_outputs.zip')

print("All results saved to Google Drive and available for download!")

## Conclusion

You've successfully run the streamlined RingGen pipeline on Google Colab with GPU acceleration! This notebook provides a practical implementation for:

1. Training Shap-E and CAP3D models on your own ring data
2. Generating high-quality 3D ring models from text prompts
3. Visualizing and analyzing the generated models
4. Saving all results for further use

### Tips for Best Results

1. **For training**: Use at least 10-20 high-quality ring models for better results
2. **For generation**: Be specific in your prompts about style, materials, and design elements
3. **For performance**: Use GPU acceleration whenever possible
4. **For storage**: Save important models and results to Google Drive

### Next Steps

1. **Fine-tune the models**: Experiment with different training parameters
2. **Expand your dataset**: Add more diverse ring models for training
3. **Export to CAD**: Use the generated models in CAD software for manufacturing
4. **Customize the pipeline**: Modify the scripts to suit your specific needs

Remember that Colab sessions are time-limited: free-tier runtimes last at most about 12 hours and can disconnect sooner when idle. For longer training runs, consider Colab Pro or another GPU platform.