<a href="https://colab.research.google.com/github/Nobobi-Hasan/PointNeXt-PartSegmentation-FallenTrees/blob/main/PointNeXt_02_04_Training_shapenetpart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
import os
import shutil
import torch
import subprocess
import sys

In [2]:
# Attach Google Drive to the Colab filesystem; the project paths used below
# all live under this mount point.
drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Everything the notebook reads from / writes to Drive lives under this root.
DRIVE_PROJECT_ROOT = "/content/drive/MyDrive/ML_Projects/PointNeXt"

# Drive subfolders: dataset archive and saved models
DRIVE_DATA_DIR, DRIVE_MODELS_DIR = (
    os.path.join(DRIVE_PROJECT_ROOT, subfolder) for subfolder in ("Data", "Models")
)

# Zipped dataset on Drive, and the local directory the data is used from
LOCAL_DATA_DIR = "/content/processed_data"
DRIVE_ZIP_PATH = os.path.join(DRIVE_DATA_DIR, "processed_data.zip")

print(f"Project Root: {DRIVE_PROJECT_ROOT}")

Project Root: /content/drive/MyDrive/ML_Projects/PointNeXt


In [4]:
# Copy processed data from Drive
if not os.path.exists("/content/processed_data"):
    if os.path.exists(DRIVE_ZIP_PATH):
        print("Copying processed_data.zip from Drive...")
        shutil.copy(DRIVE_ZIP_PATH, "/content/processed_data.zip")
        print("Unzipping...")
        !unzip -q -o /content/processed_data.zip -d /
        print("Data Ready at /content/processed_data")
    else:
        print(f"Error: Could not find processed_data.zip at {DRIVE_ZIP_PATH}")

# Clone PointNeXt
if not os.path.exists("/content/PointNeXt"):
    print("Cloning PointNeXt...")
    %cd /content
    !git clone https://github.com/guochengqian/PointNeXt.git

## Fetch the `openpoints` submodule

In [5]:
# Fetch the `openpoints` git submodule of the cloned PointNeXt repository.
# The three steps below are order-dependent: rewrite URL -> sync -> update.

%cd /content/PointNeXt

# 1. Replace SSH url with HTTPS url in .gitmodules
#    (rewrites the first `git@github.com:` on each line to `https://github.com/`)
!sed -i 's/git@github.com:/https:\/\/github.com\//' .gitmodules

# 2. Sync the new URL from .gitmodules into the local git configuration
!git submodule sync

# 3. Download the submodule contents (recursively) using the rewritten URL
print("Downloading openpoints via HTTPS...")
!git submodule update --init --recursive

# NOTE(review): this message is printed unconditionally; it does not check
# whether the git commands above actually succeeded.
print("\u2705 Submodule 'openpoints' downloaded successfully.")

/content/PointNeXt
Synchronizing submodule url for 'openpoints'
Downloading openpoints via HTTPS...
✅ Submodule 'openpoints' downloaded successfully.


## Config File

In [6]:
# Training configuration for the custom fallen-tree part-segmentation run.
# The YAML text below is written verbatim into the cloned PointNeXt repository
# and later passed to examples/shapenetpart/main.py via --cfg.
config_path = "/content/PointNeXt/cfgs/shapenetpart/custom_fallen_trees.yaml"

# Every top-level key appears exactly once (a YAML mapping must not repeat keys);
# `epochs` is therefore defined only at the top, not again in the scheduler section.
config_content = """
num_classes: 4
shape_classes: 2
epochs: 100

# --- DATA CONFIG ---
# Explicitly define feature keys to match the custom data loader
feature_keys: 'pos,x'

model:
  NAME: BasePartSeg
  encoder_args:
    NAME: PointNextEncoder
    blocks: [1, 1, 1, 1, 1]
    strides: [1, 2, 2, 2, 2]
    width: 32
    in_channels: 7
    sa_layers: 3
    sa_use_res: True
    radius: 0.1
    radius_scaling: 2.5
    nsample: 32
    expansion: 4
    aggr_args:
      feature_type: 'dp_fj'
    reduction: 'max'
    group_args:
      NAME: 'ballquery'
      normalize_dp: True
    conv_args:
      order: conv-norm-act
    act_args:
      act: 'relu'
    norm_args:
      norm: 'bn'
  decoder_args:
    NAME: PointNextPartDecoder
    cls_map: curvenet
  cls_args:
    NAME: SegHead
    global_feat: max,avg
    num_classes: 4
    shape_classes: 2
    in_channels: null
    norm_args:
      norm: 'bn'

dataset:
  common:
    NAME: FallenTreePart
    data_root: /content/processed_data
    use_normal: False
    use_xyz: True
    num_points: 2048
  train:
    split: train
  val:
    split: val

batch_size: 16
dataloader:
  num_workers: 4

lr: 0.001
min_lr: null
optimizer:
  NAME: adamw
  weight_decay: 1.0e-4

criterion_args:
  NAME: Poly1FocalLoss

# sched:
#   # NAME: MultiStepLR
#   NAME: MultiStepLRScheduler
#   milestones: [70, 90]  # Drop LR at epoch 70 and 90
#   gamma: 0.1
#   warmup_epochs: 0

# scheduler
sched: multistep
decay_epochs: [70, 90]
decay_rate: 0.1
warmup_epochs: 0

datatransforms:
  train: [PointsToTensor, PointCloudScaling, PointCloudCenterAndNormalize, PointCloudJitter, ChromaticDropGPU]
  val: [PointsToTensor, PointCloudCenterAndNormalize]
  kwargs:
    jitter_sigma: 0.001
    jitter_clip: 0.005
    scale: [0.8, 1.2]
    gravity_dim: 1
    angle: [0, 1.0, 0]

log_dir: /content/PointNeXt/log/shapenetpart/custom_trees
"""

with open(config_path, 'w') as f:
    f.write(config_content)
print("\u2705 Config Updated.")

✅ Config Updated.


## Custom Dataset Loader

In [7]:
# Package directory that will hold the custom dataset module inside openpoints
new_dataset_dir = "/content/PointNeXt/openpoints/dataset/fallentree"
os.makedirs(new_dataset_dir, exist_ok=True)
print(f"\U0001F4BE Created folder: {new_dataset_dir}")

# Paths of the two files that make up the package
init_path = os.path.join(new_dataset_dir, "__init__.py")
code_path = os.path.join(new_dataset_dir, "fallentree.py")  # the custom loader itself

# __init__.py turns the folder into a package and re-exports the dataset class
with open(init_path, 'w') as f:
    f.write("from .fallentree import FallenTreePart\n")
print(f"\u2705 Created: {init_path}")

💾 Created folder: /content/PointNeXt/openpoints/dataset/fallentree
✅ Created: /content/PointNeXt/openpoints/dataset/fallentree/__init__.py


In [8]:
# Define Path
code_path = "/content/PointNeXt/openpoints/dataset/fallentree/fallentree.py"

# Source of the custom dataset module. It is kept as a string and written into the
# cloned repository, where it is imported as the `fallentree` dataset package.
dataset_code = """
import os
import glob
import json
import logging
import numpy as np
import torch
from torch.utils.data import Dataset
from ..build import DATASETS


def proportional_sample(xyz, part_labels, npoint):
    \"\"\"
    Proportional stratified sampling of point indices.

    Each part keeps (approximately) the same share of the `npoint` output points
    that it has in the full cloud, and every part that is present gets at least
    one point whenever `npoint` allows it.

    Args:
        xyz: array-like with one entry per point; only its length is used.
        part_labels: 1-D numpy array holding one part label per point.
        npoint: number of indices to return.

    Returns:
        1-D integer numpy array with `npoint` indices into the cloud. When the
        cloud has more than `npoint` points the indices are unique; otherwise
        they are drawn with replacement (upsampling).

    Raises:
        ValueError: if the point cloud is empty.
    \"\"\"
    total_points = len(xyz)
    if total_points == 0:
        raise ValueError("Cannot sample from an empty point cloud.")
    if total_points <= npoint:
        # Small cloud: repeat points (upsample) to reach npoint
        return np.random.choice(total_points, npoint, replace=True)

    unique_parts, counts = np.unique(part_labels, return_counts=True)

    # Per-part quota: a part holding 5% of the cloud gets 5% of npoint (rounded
    # down), but never less than 1 point. Because total_points > npoint, a quota
    # can never exceed the number of points available in its part.
    target_counts = np.maximum(1, (counts / total_points * npoint).astype(int))

    # The 'at least 1' rule can overshoot npoint by a few points. Take the excess
    # from the largest quotas so that small parts are not dropped by a blind trim.
    for _ in range(int(target_counts.sum()) - npoint):
        target_counts[np.argmax(target_counts)] -= 1

    # Draw each part's quota without replacement from that part's own points
    chosen = [
        np.random.choice(np.where(part_labels == part)[0], count, replace=False)
        for part, count in zip(unique_parts, target_counts)
    ]
    final_indices = np.concatenate(chosen)

    # Rounding down usually leaves a small gap. Fill it only from points that
    # were not selected yet, so no point appears twice in the sample.
    needed = npoint - len(final_indices)
    if needed > 0:
        unused = np.ones(total_points, dtype=bool)
        unused[final_indices] = False
        extra = np.random.choice(np.flatnonzero(unused), needed, replace=False)
        final_indices = np.concatenate([final_indices, extra])

    return final_indices


@DATASETS.register_module()
class FallenTreePart(Dataset):
    \"\"\"Part-segmentation dataset of standing / fallen tree point clouds.

    Each sample is a .npy array read as float32: columns 0-2 are used as xyz,
    columns 3-6 as per-point features and column 7 as the part label. The shape
    class is 0 when the file path contains a '/0/' directory and 1 otherwise.
    \"\"\"

    classes = ['standing', 'fallen']
    num_classes = 4
    shape_classes = 2

    # --- FIX: Add dummy key -1 for part_seg_refinement compatibility ---
    # -1 points to all parts (0,1,2,3) so the code can find the max index
    cls2parts = {
        0: [0],
        1: [1, 2, 3],
        -1: [0, 1, 2, 3]
    }

    part_start = [0, 1]

    # Pre-compute embedding: one row per shape class, 1 where a part belongs to it
    cls2partembed = torch.zeros(shape_classes, num_classes)
    for i in [0, 1]: # Iterate only real classes
        idx = cls2parts[i]
        cls2partembed[i].scatter_(0, torch.LongTensor(idx), 1)

    def __init__(self,
                 data_root,
                 split=None,
                 num_points=2048,
                 use_normal=False,
                 use_xyz=True,
                 **kwargs):
        \"\"\"Read the split list and resolve every entry to an existing file.

        Args:
            data_root: dataset root; must contain train_test_split/ with a
                shuffled_<split>_file_list.json file.
            split: split name used to pick the list file (e.g. 'train', 'val').
            num_points: number of points returned per sample.
            use_normal, use_xyz, **kwargs: accepted for config compatibility,
                not used by this loader.

        Raises:
            FileNotFoundError: if the split list file does not exist.
        \"\"\"
        self.root = data_root
        self.npoints = num_points
        self.split = split

        split_file = os.path.join(self.root, 'train_test_split', f'shuffled_{split}_file_list.json')
        if not os.path.exists(split_file):
             raise FileNotFoundError(f"Split list not found: {split_file}")

        logging.info(f"Loading {split} split from: {split_file}")
        with open(split_file, 'r') as f:
            raw_list = json.load(f)

        self.file_list = []
        missing = 0
        for item in raw_list:
            # Normalise Windows-style separators before building candidates
            clean_item = item.replace(\"\\\\\", \"/\")
            fname = os.path.basename(clean_item)
            candidates = [
                clean_item,
                os.path.join(self.root, clean_item),
                os.path.join(self.root, '0', fname),
                os.path.join(self.root, '1', fname)
            ]
            for path in candidates:
                if os.path.exists(path):
                    self.file_list.append(path)
                    break
            else:
                # None of the candidate locations exists for this entry
                missing += 1

        if missing:
            logging.warning(f"Skipped {missing} entries of the {split} split: files not found under {self.root}")
        logging.info(f"Found {len(self.file_list)} valid files for {split} split.")

    def __getitem__(self, index):
        \"\"\"Load one cloud and return npoints sampled points as a dict.

        Returns:
            dict with 'pos' (xyz), 'x' (features), 'y' (part labels, int64)
            and 'cls' (shape class as a length-1 int64 array).
        \"\"\"
        file_path = self.file_list[index]
        cls_idx = 0 if '/0/' in file_path.replace('\\\\', '/') else 1

        data = np.load(file_path).astype(np.float32)
        xyz = data[:, 0:3]
        features = data[:, 3:7]
        part_label = data[:, 7].astype(np.int64)
        cls_label = np.array([cls_idx]).astype(np.int64)

        choice = proportional_sample(xyz, part_label, self.npoints)

        xyz = xyz[choice]
        features = features[choice]
        part_label = part_label[choice]

        return {'pos': xyz, 'x': features, 'y': part_label, 'cls': cls_label}

    def __len__(self):
        return len(self.file_list)
"""

with open(code_path, 'w') as f:
    f.write(dataset_code)
print(f"\u2705 Updated: {code_path}.")

✅ Updated: /content/PointNeXt/openpoints/dataset/fallentree/fallentree.py.


In [9]:
# Make the new dataset package visible to the main library:
# openpoints/dataset/__init__.py must import FallenTreePart from .fallentree.

main_init = "/content/PointNeXt/openpoints/dataset/__init__.py"
with open(main_init, 'r') as f:
    content = f.read()

# Appending only when the name is absent keeps re-runs from duplicating the import
if "fallentree" in content:
    print("\u2705 Already registered.")
else:
    print("Registering new dataset in main __init__.py...")
    with open(main_init, 'a') as f:
        f.write("\nfrom .fallentree import FallenTreePart\n")
    print("\u2705 Registration Complete.")

✅ Already registered.


## Install Dependencies

In [10]:
# Install Dependencies
# All commands run from the repository root so requirements.txt resolves.
print("Installing Dependencies...")
%cd /content/PointNeXt

# A. Install PyTorch 2.4.0 (Compatible) from the CUDA 12.1 wheel index
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121

# B. Install torch-scatter/sparse from the wheel page matching torch 2.4.0 + cu121
#    (the two packages themselves are not version-pinned)
!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.4.0+cu121.html

# C. Fix requirements.txt: strip `==<version>` and everything after it on each
#    line, i.e. remove the exact version pins (edits the file in place)
!sed -i 's/==.*//g' requirements.txt

# D. Install requirements (now unpinned, so pip picks current releases)
!pip install -r requirements.txt

# E. SKIP 'pip install -e .' (Because setup.py is missing)
# NOTE(review): the message below is printed unconditionally; pip failures
# above are not detected by this cell.
print("\u2705 Dependencies Installed. (Skipped setup.py)")

Installing Dependencies...
/content/PointNeXt
Looking in indexes: https://download.pytorch.org/whl/cu121
Looking in links: https://data.pyg.org/whl/torch-2.4.0+cu121.html
Collecting ninja (from -r requirements.txt (line 3))
  Using cached ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (5.1 kB)
Collecting multimethod (from -r requirements.txt (line 11))
  Using cached multimethod-2.0.2-py3-none-any.whl.metadata (8.4 kB)
Collecting pyvista (from -r requirements.txt (line 15))
  Using cached pyvista-0.46.4-py3-none-any.whl.metadata (15 kB)
Collecting deepspeed (from -r requirements.txt (line 19))
  Using cached deepspeed-0.18.3.tar.gz (1.6 MB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting shortuuid (from -r requirements.txt (line 20))
  Using cached shortuuid-1.0.13-py3-none-any.whl.metadata (5.8 kB)
Collecting mkdocs-material (from -r requirements.txt (line 23))
  Using cached mkdocs_material-9.7.0-py3-none-any.whl.metadata (19 kB)
Collect

In [11]:
# Build the C++ kernels shipped with openpoints.
# Every directory below openpoints/cpp that contains a setup.py is built in turn.
# The build is best-effort: one failing extension does not stop the others, but
# all failures are collected and reported together at the end.
print("Searching for C++ kernel setup.py...")
target_dir = "/content/PointNeXt/openpoints/cpp"
found_setup = False
failed_dirs = []  # directories whose build exited with a non-zero status

for root, dirs, files in os.walk(target_dir):
    if "setup.py" not in files:
        continue

    print(f"\u2705 Found setup.py at: {root}")
    found_setup = True

    # Force compile with the same interpreter that runs this kernel
    print(f"Compiling kernels in {root}...")
    try:
        subprocess.check_call([sys.executable, "setup.py", "install"], cwd=root)
        print("Compilation Successful!")
    except subprocess.CalledProcessError as e:
        failed_dirs.append(root)
        print(f"\u274c Compilation Failed: {e}")

if not found_setup:
    print("\u274c Critical Error: Could not find setup.py anywhere in openpoints/cpp!")
    print("Did the 'git submodule update' step finish successfully?")
elif failed_dirs:
    # Surface failures that would otherwise be buried in the build log above
    print(f"\u26A0\uFE0F {len(failed_dirs)} kernel build(s) failed:")
    for failed_dir in failed_dirs:
        print(f"   - {failed_dir}")
    print("Training may fail later if the model or transforms need one of these kernels.")
else:
    print("\u2705 All C++ kernels compiled.")

Searching for C++ kernel setup.py...
✅ Found setup.py at: /content/PointNeXt/openpoints/cpp/pointnet2_batch
Compiling kernels in /content/PointNeXt/openpoints/cpp/pointnet2_batch...
Compilation Successful!
✅ Found setup.py at: /content/PointNeXt/openpoints/cpp/emd
Compiling kernels in /content/PointNeXt/openpoints/cpp/emd...
Compilation Successful!
✅ Found setup.py at: /content/PointNeXt/openpoints/cpp/pointops
Compiling kernels in /content/PointNeXt/openpoints/cpp/pointops...
Compilation Successful!
✅ Found setup.py at: /content/PointNeXt/openpoints/cpp/subsampling
Compiling kernels in /content/PointNeXt/openpoints/cpp/subsampling...
❌ Compilation Failed: Command '['/usr/bin/python3', 'setup.py', 'install']' returned non-zero exit status 1.
✅ Found setup.py at: /content/PointNeXt/openpoints/cpp/chamfer_dist
Compiling kernels in /content/PointNeXt/openpoints/cpp/chamfer_dist...
Compilation Successful!


In [12]:
# Fix Missing Validation List
# When no validation list exists yet, the test list is copied to stand in for it.

split_dir = "/content/processed_data/train_test_split"
val_file = os.path.join(split_dir, "shuffled_val_file_list.json")
test_file = os.path.join(split_dir, "shuffled_test_file_list.json")

print("Checking validation list...")

if not os.path.exists(test_file):
    # Nothing to copy from
    print("\u274c Error: Test list not found! Did the previous copy step work?")
elif os.path.exists(val_file):
    print("\u2705 Validation list already exists.")
else:
    print("Creating dummy validation list (copy of test list)...")
    shutil.copy(test_file, val_file)
    print(f"\u2705 Created: {val_file}")

Checking validation list...
✅ Validation list already exists.


## Train On shapenetpart

In [13]:
# Patch main.py to fix Tensor Key Error
# The cell is safe to re-run: an already-patched file is recognised and left alone.
import os

target_file = "/content/PointNeXt/examples/shapenetpart/main.py"
print(f"Patching {target_file} to fix Tensor Key Error...")

with open(target_file, 'r') as f:
    code = f.read()

# The problematic line indexes the cls2parts dict with cls[shape_idx] directly;
# the replacement casts that value to int before the dictionary lookup.
bad_line = "parts = cls2parts[cls[shape_idx]]"
good_line = "parts = cls2parts[int(cls[shape_idx])]"

if bad_line in code:
    code = code.replace(bad_line, good_line)
    # Write back only when something actually changed
    with open(target_file, 'w') as f:
        f.write(code)
    print("\u2705 Patched: Cast tensor to int for dictionary lookup.")
elif good_line in code:
    print("\u2705 Already patched.")
else:
    print("\u26A0\uFE0F Warning: Could not find exact line. Check if file changed.")

Patching /content/PointNeXt/examples/shapenetpart/main.py to fix Tensor Key Error...
✅ Patched: Cast tensor to int for dictionary lookup.


In [14]:
# NOW Run Training
print("\nStarting Training...")
%cd /content/PointNeXt
!PYTHONPATH=. python examples/shapenetpart/main.py --cfg cfgs/shapenetpart/custom_fallen_trees.yaml mode=train


Starting Training...
/content/PointNeXt
2025-12-12 09:35:21.284839: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765532121.565238    6108 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765532121.643725    6108 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1765532122.241789    6108 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1765532122.241826    6108 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1765532122.241832    6108 computat

## Save to Drive

In [15]:
# Smart Backup using Glob
# Copies the best checkpoint of the most recent training run to Google Drive.
import glob

# 1. Define the pattern (the trailing * matches the run-specific suffix)
# We look for any folder starting with the project name inside the log dir
search_pattern = "/content/PointNeXt/log/shapenetpart/shapenetpart-train-custom_fallen_trees*"

print(f" Searching for runs matching: {search_pattern}...")

# 2. Find all matching folders
found_folders = glob.glob(search_pattern)

if not found_folders:
    print("\u274c Error: No training folders found matching that pattern.")
else:
    # 3. Pick the LATEST folder (in case trained multiple times)
    latest_run_dir = max(found_folders, key=os.path.getmtime)
    print(f"\u2705 Found latest run: {os.path.basename(latest_run_dir)}")

    # 4. Construct the checkpoint path; the destination reuses the Models
    #    folder defined in the path-configuration cell instead of repeating it
    source_ckpt_dir = os.path.join(latest_run_dir, "checkpoint")
    drive_model_dir = DRIVE_MODELS_DIR

    # 5. Perform Backup
    if os.path.exists(source_ckpt_dir):
        os.makedirs(drive_model_dir, exist_ok=True)
        files = glob.glob(os.path.join(source_ckpt_dir, "*_best.pth"))

        if files:
            # glob order is arbitrary; if several best checkpoints exist,
            # take the most recently written one
            best_model = max(files, key=os.path.getmtime)
            filename = os.path.basename(best_model)
            dest_path = os.path.join(drive_model_dir, filename)
            shutil.copy(best_model, dest_path)
            print(f"\U0001F4BE Success! Model saved to: {dest_path}")
        else:
            print("\u26A0\uFE0F Warning: No '_best.pth' file found.")
    else:
        print(f"\u274c Error: Checkpoint folder missing in {latest_run_dir}")

 Searching for runs matching: /content/PointNeXt/log/shapenetpart/shapenetpart-train-custom_fallen_trees*...
✅ Found latest run: shapenetpart-train-custom_fallen_trees-ngpus1-seed1066-20251212-093536-S85TXDdERtSiQuih8nf377
💾 Success! Model saved to: /content/drive/MyDrive/ML_Projects/PointNeXt/Models/shapenetpart-train-custom_fallen_trees-ngpus1-seed1066-20251212-093536-S85TXDdERtSiQuih8nf377_ckpt_best.pth
