Release the source code of TorchSparse v2.1 (#254)
* [Major] Add v2.1 source code.

* [Minor] Update README.md

* [Minor] Add PCEngine kernel citation information.
kentang-mit committed Oct 30, 2023
1 parent b55506a commit afa2e3b
Showing 104 changed files with 16,218 additions and 1,973 deletions.
37 changes: 34 additions & 3 deletions README.md
@@ -1,8 +1,15 @@
# TorchSparse

<p align="center">
  <img
    src="./docs/figs/torchsparse.png"
    height="300"
  />
</p>

TorchSparse is a high-performance neural network library for point cloud processing.

### [website](http://torchsparse.mit.edu/) | [paper](https://arxiv.org/abs/2204.10319) | [presentation](https://www.youtube.com/watch?v=IIh4EwmcLUs) | [documents](http://torchsparse-docs.github.io/) | [pypi server](http://pypi.hanlab.ai/simple/torchsparse)
### [website](http://torchsparse.mit.edu/) | [paper (MICRO 2023)](https://www.dropbox.com/scl/fi/obdku0kqxjlkvuom2opk4/paper.pdf?rlkey=0zmy8eq9fzllgkx54zsvwsecf&dl=0) | [paper (MLSys 2022)](https://arxiv.org/abs/2204.10319) | [presentation](https://www.youtube.com/watch?v=IIh4EwmcLUs) | [documents](http://torchsparse-docs.github.io/) | [pypi server](http://pypi.hanlab.ai/simple/torchsparse)


## Introduction
@@ -11,6 +18,8 @@ Point cloud computation has become an increasingly important workload for a

## News

**\[2023/10/30\]** We presented TorchSparse++ at the 56th IEEE/ACM International Symposium on Microarchitecture (MICRO 2023). We have also fully released the source code of TorchSparse++.

**\[2023/6/18\]** TorchSparse++ was released and presented at the CVPR 2023 workshops on autonomous driving. It achieves a 1.7-2.9x inference speedup over previous state-of-the-art systems.

**\[2022/8/29\]** TorchSparse was presented at MLSys 2022. The talk video is available [here](https://www.youtube.com/watch?v=IIh4EwmcLUs).
@@ -46,6 +55,13 @@ We provide pre-built torchsparse v2.1.0 packages (recommended) with different Py

If the PyPI server does not work as expected, you can still download the wheels manually; they are listed on [this page](http://pypi.hanlab.ai/simple/torchsparse). Our installation script can automatically determine the version string used to index the wheels. For example, with PyTorch 1.11.0 and CUDA 11.5, the version string is 2.1.0+torch111cu115. You can then select the proper wheel for your Python version.
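
If you prefer to compute that version string yourself, the following minimal sketch assumes the `2.1.0+torch<MAJOR><MINOR>cu<CUDA>` pattern from the example above; the installation script remains the authoritative source:

```python
import torch

# "1.11.0+cu115" -> "111" (drop any local build suffix, keep major+minor).
torch_tag = "".join(torch.__version__.split("+")[0].split(".")[:2])
if torch.version.cuda is not None:
    backend_tag = "cu" + torch.version.cuda.replace(".", "")  # "11.5" -> "cu115"
else:
    backend_tag = "cpu"  # assumption: CPU-only wheels use a "cpu" suffix

print(f"2.1.0+torch{torch_tag}{backend_tag}")  # e.g. 2.1.0+torch111cu115
```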


Alternatively, you can install our library from source:

```bash
python setup.py install
```
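
After installation, a quick smoke test can confirm that the build works end to end. This is a minimal sketch that mirrors the calls in `examples/example.py` shown later in this commit; the random input is illustrative, not an official check:

```python
import numpy as np
import torch

from torchsparse import SparseTensor
from torchsparse import nn as spnn
from torchsparse.utils.collate import sparse_collate_fn
from torchsparse.utils.quantize import sparse_quantize

# Build one random voxelized point cloud (same recipe as examples/example.py).
inputs = np.random.uniform(-100, 100, size=(1000, 4)).astype(np.float32)
coords = inputs[:, :3] - inputs[:, :3].min(axis=0, keepdims=True)
coords, indices = sparse_quantize(coords, 0.2, return_index=True)
x = SparseTensor(
    coords=torch.tensor(coords, dtype=torch.int),
    feats=torch.tensor(inputs[indices]),
)
batch = sparse_collate_fn([{"input": x}])["input"]

device = "cuda:0" if torch.cuda.is_available() else "cpu"
conv = spnn.Conv3d(4, 16, kernel_size=3).to(device)
print(conv(batch.to(device)).feats.shape)  # expect (num_voxels, 16)
```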

## Benchmarks

### Inference benchmarks
@@ -73,6 +89,7 @@ TorchSparse is developed by the following wonderful team:
- [Ke Hong](https://ieeexplore.ieee.org/author/37089419138): Graduate student (2021-) at Tsinghua University EE, v2.1 core developer, authored PCEngine kernels;
- [Zhongming Yu](https://fishmingyu.github.io/): Ph.D. student (2022-) at UCSD CS, v2.1 core developer, authored PCEngine kernels;
- [Yujun Lin](https://yujunlin.com/): Ph.D. student (2018-) at MIT EECS, v2.0 core developer;
- [Yingqi Cao](https://github.com/ioeddk): Undergraduate student at UC San Diego, currently working on integrating TorchSparse++ into algorithm frameworks;
- [Guohao Dai](https://scholar.google.com/citations?user=gz3Tkl0AAAAJ&hl=en): Associate Professor at Shanghai Jiao Tong University, mentor of the project;
- [Yu Wang](http://nicsefc.ee.tsinghua.edu.cn/): Professor at Tsinghua University, mentor of the project;
- [Song Han](https://songhan.mit.edu): Associate Professor at MIT EECS, mentor of the project.
@@ -82,6 +99,17 @@ TorchSparse is developed by the following wonderful team:

If you use TorchSparse, please cite it with the following BibTeX entries:

TorchSparse++ (TorchSparse v2.1) is presented at MICRO 2023:

```bibtex
@inproceedings{tangandyang2023torchsparse,
title={TorchSparse++: Efficient Training and Inference Framework for Sparse Convolution on GPUs},
author={Tang, Haotian and Yang, Shang and Liu, Zhijian and Hong, Ke and Yu, Zhongming and Li, Xiuyu and Dai, Guohao and Wang, Yu and Han, Song},
booktitle={IEEE/ACM International Symposium on Microarchitecture (MICRO)},
year={2023}
}
```

A preliminary version of TorchSparse++ (TorchSparse v2.1) was presented at CVPR Workshops 2023:

@@ -128,6 +156,9 @@ The PCEngine paper was accepted at MLSys 2023:

## Acknowledgement

We thank Yan Yan from TuSimple for helpful discussions.
We thank Yan Yan from TuSimple for helpful discussions. Please also have a look at the [dgSparse](https://dgsparse.github.io/) library, which is designed for fast and efficient sparse computation on graphs and point clouds. The work from the PCEngine (MLSys 2023) team is also closely related to ours.

TorchSparse is inspired by many existing open-source libraries, including (but not limited to) [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine), [SECOND](https://github.com/traveller59/second.pytorch) and [SparseConvNet](https://github.com/facebookresearch/SparseConvNet).

We also thank [AttributeDict](https://github.com/grimen/python-attributedict/tree/master) for providing an elegant way to manage the kernel/model configurations.

Please also have a look at the [dgSparse](https://dgsparse.github.io/) library, which is designed for fast and efficient sparse computation on graphs and point clouds. The work from the PCEngine (MLSys 2023) team is also closely related to ours.
98 changes: 98 additions & 0 deletions cython_setup.py
@@ -0,0 +1,98 @@
# cython: language_level=3
import glob
import os
import sys

import torch
import torch.cuda
from setuptools import find_packages, setup
from torch.utils.cpp_extension import (
CUDA_HOME,
BuildExtension,
CppExtension,
CUDAExtension,
)

# from torchsparse import __version__

from Cython.Build import cythonize

cython_clean_flag = False

with open("./torchsparse/version.py") as version_file:
    version = version_file.read().split("'")[1]
print("torchsparse version:", version)

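# Build the CUDA backend when PyTorch detects a GPU and CUDA_HOME is set,
# or when FORCE_CUDA=1 is exported; otherwise fall back to the CPU-only backend.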
if (torch.cuda.is_available() and CUDA_HOME is not None) or (
os.getenv("FORCE_CUDA", "0") == "1"
):
device = "cuda"
pybind_fn = f"pybind_{device}.cu"
else:
device = "cpu"
pybind_fn = f"pybind_{device}.cpp"

sources = [os.path.join("torchsparse", "backend", pybind_fn)]
for fpath in glob.glob(os.path.join("torchsparse", "backend", "**", "*")):
if (fpath.endswith("_cpu.cpp") and device in ["cpu", "cuda"]) or (
fpath.endswith("_cuda.cu") and device == "cuda"
):
sources.append(fpath)

pyx_files = []
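# Rename each .py module to .pyx and prepend the Cython language-level
# directive so that cythonize() compiles it (note: this mutates the tree in place).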
for root, dirnames, filenames in os.walk("torchsparse"):
for filename in filenames:
file_path = os.path.join(root, filename)
if file_path.endswith(".py"):
file_path2 = file_path + "x"
os.system("mv " + file_path + " " + file_path2)
os.system("sed -i '1s/^/# cython: language_level=3\\n/' " + file_path2)
pyx_files.append(file_path2)

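# If nothing was renamed (e.g. a previous run already produced .pyx files),
# collect the existing .pyx files instead.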
if not pyx_files:
for root, dirnames, filenames in os.walk("torchsparse"):
for filename in filenames:
file_path = os.path.join(root, filename)
if file_path.endswith(".pyx"):
pyx_files.append(file_path)

extension_type = CUDAExtension if device == "cuda" else CppExtension
extra_compile_args = {
"cxx": ["-g", "-O3", "-fopenmp", "-lgomp"],
"nvcc": ["-O3", "-std=c++17"],
}

setup(
name="torchsparse",
version=version,
packages=find_packages(),
ext_modules=cythonize(
[
extension_type(
"torchsparse.backend", sources, extra_compile_args=extra_compile_args
),
]
+ pyx_files
),
install_requires=[
"numpy",
"backports.cached_property",
"tqdm",
"typing-extensions",
"wheel",
"rootpath",
"attributedict",
],
cmdclass={"build_ext": BuildExtension},
zip_safe=False,
)

# Clean up
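# (runs only when cython_clean_flag above is True; removes the Cython-generated
# .c files and the renamed .pyx sources)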
if cython_clean_flag:
for root, dirnames, filenames in os.walk("torchsparse"):
for filename in filenames:
file_path = os.path.join(root, filename)
if file_path.endswith(".c"):
os.system("rm " + file_path)
if file_path.endswith(".pyx"):
os.system("rm " + file_path)
Binary file added docs/figs/torchsparse.png
11 changes: 6 additions & 5 deletions examples/backbones.py
@@ -9,12 +9,13 @@

@torch.no_grad()
def main() -> None:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = "cuda:0" if torch.cuda.is_available() else "cpu"
from torchsparse.nn import functional as F
F.set_kmap_mode('hashmap')

F.set_kmap_mode("hashmap")

for backbone in [SparseResNet21D, SparseResUNet42]:
print(f'{backbone.__name__}:')
print(f"{backbone.__name__}:")
model: nn.Module = backbone(in_channels=4, width_multiplier=1.0)
model = model.to(device).eval()

@@ -36,8 +37,8 @@ def main() -> None:

# print feature shapes
for k, output in enumerate(outputs):
print(f'output[{k}].F.shape = {output.feats.shape}')
print(f"output[{k}].F.shape = {output.feats.shape}")


if __name__ == '__main__':
if __name__ == "__main__":
main()
32 changes: 17 additions & 15 deletions examples/example.py
@@ -11,12 +11,12 @@
import torchsparse
from torchsparse import SparseTensor
from torchsparse import nn as spnn
from torchsparse.nn import functional as F
from torchsparse.utils.collate import sparse_collate_fn
from torchsparse.utils.quantize import sparse_quantize


class RandomDataset:

def __init__(self, input_size: int, voxel_size: float) -> None:
self.input_size = input_size
self.voxel_size = voxel_size
@@ -27,26 +27,28 @@ def __getitem__(self, _: int) -> Dict[str, Any]:

coords, feats = inputs[:, :3], inputs
coords -= np.min(coords, axis=0, keepdims=True)
coords, indices = sparse_quantize(coords,
self.voxel_size,
return_index=True)
coords, indices = sparse_quantize(coords, self.voxel_size, return_index=True)

coords = torch.tensor(coords, dtype=torch.int)
feats = torch.tensor(feats[indices], dtype=torch.float)
labels = torch.tensor(labels[indices], dtype=torch.long)

input = SparseTensor(coords=coords, feats=feats)
label = SparseTensor(coords=coords, feats=labels)
return {'input': input, 'label': label}
return {"input": input, "label": label}

def __len__(self):
return 100


if __name__ == '__main__':
if __name__ == "__main__":
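# v2.1 exposes a global sparse-conv config: fetch the defaults, optionally
# tweak them (e.g. the dataflow), then install them globally.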
conv_config = F.get_default_conv_config()
# conv_config.dataflow = F.Dataflow.GatherScatter
F.set_global_conv_config(conv_config)

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--amp_enabled', action='store_true')
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--amp_enabled", action="store_true")
args = parser.parse_args()

random.seed(0)
@@ -81,14 +83,14 @@ def __len__(self):
scaler = amp.GradScaler(enabled=args.amp_enabled)

for k, feed_dict in enumerate(dataflow):
inputs = feed_dict['input'].to(device=args.device)
labels = feed_dict['label'].to(device=args.device)
inputs = feed_dict["input"].to(device=args.device)
labels = feed_dict["label"].to(device=args.device)

with amp.autocast(enabled=args.amp_enabled):
outputs = model(inputs)
loss = criterion(outputs.feats, labels.feats)

print(f'[step {k + 1}] loss = {loss.item()}')
print(f"[step {k + 1}] loss = {loss.item()}")

optimizer.zero_grad()
scaler.scale(loss).backward()
@@ -99,14 +101,14 @@ def __len__(self):
model.eval()
# enable fused and locality-aware memory access optimization
torchsparse.backends.benchmark = True # type: ignore

with torch.no_grad():
for k, feed_dict in enumerate(dataflow):
inputs = feed_dict['input'].to(device=args.device).half()
labels = feed_dict['label'].to(device=args.device)
inputs = feed_dict["input"].to(device=args.device).half()
labels = feed_dict["label"].to(device=args.device)

with amp.autocast(enabled=True):
outputs = model(inputs)
loss = criterion(outputs.feats, labels.feats)

print(f'[inference step {k + 1}] loss = {loss.item()}')
print(f"[inference step {k + 1}] loss = {loss.item()}")
38 changes: 18 additions & 20 deletions examples/performance.py
@@ -28,14 +28,12 @@ def generate_random_point_cloud(size=100000, voxel_size=0.2):
input = SparseTensor(coords=coords, feats=feats)
label = SparseTensor(coords=coords, feats=labels)

feed_dict = {'input': input, 'label': label}
feed_dict = {"input": input, "label": label}

return feed_dict


def generate_batched_random_point_clouds(size=100000,
voxel_size=0.2,
batch_size=2):
def generate_batched_random_point_clouds(size=100000, voxel_size=0.2, batch_size=2):
batch = []
for _ in range(batch_size):
batch.append(generate_random_point_cloud(size, voxel_size))
@@ -56,25 +54,25 @@ def dummy_train_3x3(device):
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss().to(device)

print('Starting dummy_train_3x3...')
print("Starting dummy_train_3x3...")
time = datetime.now()
with profiler.profile(profile_memory=True, use_cuda=True) as prof:
with profiler.record_function('model_inference'):
with profiler.record_function("model_inference"):
for _ in range(10):
feed_dict = generate_batched_random_point_clouds()
inputs = feed_dict['input'].to(device)
targets = feed_dict['label'].F.to(device).long()
inputs = feed_dict["input"].to(device)
targets = feed_dict["label"].F.to(device).long()
outputs = model(inputs)
optimizer.zero_grad()
loss = criterion(outputs.F, targets)
loss.backward()
optimizer.step()
# print('[step %d] loss = %f.'%(i, loss.item()))
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))
prof.export_chrome_trace('trace_dummy_3x3.json')
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace("trace_dummy_3x3.json")

time = datetime.now() - time
print('Finished dummy_train_3x3 in ', time)
print("Finished dummy_train_3x3 in ", time)


def dummy_train_3x1(device):
@@ -91,29 +89,29 @@ def dummy_train_3x1(device):
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss().to(device)

print('Starting dummy_train_3x1 ...')
print("Starting dummy_train_3x1 ...")
time = datetime.now()
with profiler.profile(profile_memory=True, use_cuda=True) as prof:
with profiler.record_function('model_inference'):
with profiler.record_function("model_inference"):
for _ in range(10):
feed_dict = generate_batched_random_point_clouds()
inputs = feed_dict['input'].to(device)
targets = feed_dict['label'].F.to(device).long()
inputs = feed_dict["input"].to(device)
targets = feed_dict["label"].F.to(device).long()
outputs = model(inputs)
optimizer.zero_grad()
loss = criterion(outputs.F, targets)
loss.backward()
optimizer.step()
# print('[step %d] loss = %f.'%(i, loss.item()))
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))
prof.export_chrome_trace('trace_dummy_3x1.json')
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace("trace_dummy_3x1.json")

time = datetime.now() - time
print('Finished dummy_train_3x1 in ', time)
print("Finished dummy_train_3x1 in ", time)


if __name__ == '__main__':
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if __name__ == "__main__":
device = "cuda:0" if torch.cuda.is_available() else "cpu"

dummy_train_3x1(device)
dummy_train_3x3(device)
