Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Build CPP extension with clang++
run: |
export CXX=$(brew --prefix llvm@15)/bin/clang++
export LDFLAGS="-L/usr/local/opt/libomp/lib"
export CPPFLAGS="-I/usr/local/opt/libomp/include"
export CXX=$(brew --prefix llvm@18)/bin/clang++
export LDFLAGS="-L/opt/homebrew/opt/libomp/lib"
export CPPFLAGS="-I/opt/homebrew/opt/libomp/include"
pip install -e .[dev]
- name: Test with pytest
run: |
Expand Down
10 changes: 3 additions & 7 deletions .github/workflows/version.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
name: Display version

on:
push:
branches: [ "dev", "main", "alpha", "beta" ]
pull_request:
branches: [ "dev", "main", "alpha", "beta" ]
on: [push, pull_request]

permissions:
contents: read
Expand All @@ -26,8 +22,8 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build "setuptools-git-versioning>=2,<3"
pip install build "setuptools-git-versioning>=2,<3" numpy numba
pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Display version
run: |
setuptools-git-versioning -v >> $GITHUB_STEP_SUMMARY
setuptools-git-versioning -vv >> $GITHUB_STEP_SUMMARY
17 changes: 9 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@ requires = [
"setuptools-git-versioning>=2.0,<3",
"wheel",
"torch",
"numba",
"numpy",
]
build-backend = "setuptools.build_meta"
build-backend = "setuptools.build_meta:__legacy__"

[tool.setuptools-git-versioning]
enabled = true
# change the file path
version_file = "torchlpc/VERSION.txt"
count_commits_from_version_file = true # <--- enable commits tracking
dev_template = "{tag}.{branch}{ccount}" # suffix for versions will be .dev
dirty_template = "{tag}.{branch}{ccount}" # same thing here
# Temporarily disable branch formatting due to issues with regex in _version.py
# branch_formatter = "torchlpc._version:format_branch_name"
count_commits_from_version_file = true # <--- enable commits tracking
dev_template = "{tag}.{branch}{ccount}" # suffix for versions will be .dev
dirty_template = "{tag}.{branch}{ccount}" # same thing here
branch_formatter = "torchlpc._version:format_branch_name"

[tool.setuptools.package-data]
# include VERSION file to a package
Expand All @@ -29,6 +29,7 @@ exclude = ["tests", "tests.*"]
[tool.setuptools]
# this package will read some included files in runtime, avoid installing it as .zip
zip-safe = false
license-files = ["LICENSE"]

[project]
dynamic = ["version"]
Expand All @@ -39,8 +40,8 @@ authors = [{ name = "Chin-Yun Yu", email = "chin-yun.yu@qmul.ac.uk" }]
maintainers = [{ name = "Chin-Yun Yu", email = "chin-yun.yu@qmul.ac.uk" }]
description = "Fast, efficient, and differentiable time-varying LPC filtering in PyTorch."
readme = "README.md"
license = "MIT"
license-files = ["LICENSE"]
license = { text = "MIT" }

classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
Expand Down
121 changes: 63 additions & 58 deletions torchlpc/csrc/scan_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,29 @@
#include <utility>
#include <vector>

extern "C" {
/* Creates a dummy empty _C module that can be imported from Python.
The import from Python will load the .so associated with this extension
built from this file, so that all the TORCH_LIBRARY calls below are run.*/
PyObject *PyInit__C(void) {
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"_C", /* name of module */
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
NULL, /* methods */
};
return PyModule_Create(&module_def);
}
extern "C"
{
/* Creates a dummy empty _C module that can be imported from Python.
The import from Python will load the .so associated with this extension
built from this file, so that all the TORCH_LIBRARY calls below are run.*/
PyObject *PyInit__C(void)
{
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"_C", /* name of module */
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
NULL, /* methods */
};
return PyModule_Create(&module_def);
}
}

template <typename scalar_t>
void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
const at::Tensor &initials, const at::Tensor &output) {
const at::Tensor &initials, const at::Tensor &output)
{
TORCH_CHECK(input.dim() == 2, "Input must be 2D");
TORCH_CHECK(initials.dim() == 1, "Initials must be 1D");
TORCH_CHECK(weights.sizes() == input.sizes(),
Expand All @@ -50,39 +53,33 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
auto T = input.size(1);
auto total_size = input.numel();

std::pair<scalar_t, scalar_t> buffer[total_size];

const scalar_t *input_ptr = input_contiguous.const_data_ptr<scalar_t>();
const scalar_t *initials_ptr =
initials_contiguous.const_data_ptr<scalar_t>();
const scalar_t *weights_ptr = weights_contiguous.const_data_ptr<scalar_t>();
scalar_t *output_ptr = output.mutable_data_ptr<scalar_t>();

std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer,
[](const scalar_t &a, const scalar_t &b) {
return std::make_pair(a, b);
});

at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end) {
for (auto b = start; b < end; b++) {
std::inclusive_scan(
buffer + b * T, buffer + (b + 1) * T, buffer + b * T,
[](const std::pair<scalar_t, scalar_t> &a,
const std::pair<scalar_t, scalar_t> &b) {
return std::make_pair(a.first * b.first,
a.second * b.first + b.second);
},
std::make_pair((scalar_t)1.0, initials_ptr[b]));
}
});

std::transform(
buffer, buffer + total_size, output_ptr,
[](const std::pair<scalar_t, scalar_t> &a) { return a.second; });
at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end)
{
for (auto b = start; b < end; b++)
{
auto initial = initials_ptr[b];
auto weights_offset = weights_ptr + b * T;
auto input_offset = input_ptr + b * T;
auto output_offset = output_ptr + b * T;
for (int64_t t = 0; t < T; t++)
{
auto w = weights_offset[t];
auto x = input_offset[t];
initial = initial * w + x;
output_offset[t] = initial;
}
}; });
}

template <typename scalar_t>
void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out) {
void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out)
{
// Ensure input dimensions are correct
TORCH_CHECK(a.dim() == 3, "a must be 3-dimensional");
TORCH_CHECK(padded_out.dim() == 2, "out must be 2-dimensional");
Expand All @@ -106,24 +103,27 @@ void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out) {
const scalar_t *a_ptr = a_contiguous.const_data_ptr<scalar_t>();
scalar_t *out_ptr = padded_out.mutable_data_ptr<scalar_t>();

at::parallel_for(0, B, 1, [&](int64_t start, int64_t end) {
for (auto b = start; b < end; b++) {
auto out_offset = b * (T + order) + order;
auto a_offset = b * T * order;
for (int64_t t = 0; t < T; t++) {
scalar_t y = out_ptr[out_offset + t];
for (int64_t i = 0; i < order; i++) {
y -= a_ptr[a_offset + t * order + i] *
out_ptr[out_offset + t - i - 1];
at::parallel_for(0, B, 1, [&](int64_t start, int64_t end)
{
for (auto b = start; b < end; b++)
{
auto out_offset = out_ptr + b * (T + order) + order;
auto a_offset = a_ptr + b * T * order;
for (int64_t t = 0; t < T; t++)
{
scalar_t y = out_offset[t];
for (int64_t i = 0; i < order; i++)
{
y -= a_offset[t * order + i] * out_offset [t - i - 1];
}
out_ptr[out_offset + t] = y;
out_offset[t] = y;
}
}
});
}; });
}

at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
const at::Tensor &initials) {
const at::Tensor &initials)
{
TORCH_CHECK(input.is_floating_point() || input.is_complex(),
"Input must be floating point or complex");
TORCH_CHECK(initials.scalar_type() == input.scalar_type(),
Expand All @@ -135,12 +135,14 @@ at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
input.scalar_type(), "scan_cpu",
[&] { scan_cpu<scalar_t>(input, weights, initials, output); });
[&]
{ scan_cpu<scalar_t>(input, weights, initials, output); });
return output;
}

at::Tensor lpc_cpu(const at::Tensor &x, const at::Tensor &a,
const at::Tensor &zi) {
const at::Tensor &zi)
{
TORCH_CHECK(x.is_floating_point() || x.is_complex(),
"Input must be floating point or complex");
TORCH_CHECK(a.scalar_type() == x.scalar_type(),
Expand All @@ -156,16 +158,19 @@ at::Tensor lpc_cpu(const at::Tensor &x, const at::Tensor &a,
auto out = at::cat({zi.flip(1), x}, 1).contiguous();

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
x.scalar_type(), "lpc_cpu", [&] { lpc_cpu_core<scalar_t>(a, out); });
x.scalar_type(), "lpc_cpu", [&]
{ lpc_cpu_core<scalar_t>(a, out); });
return out.slice(1, zi.size(1), out.size(1)).contiguous();
}

TORCH_LIBRARY(torchlpc, m) {
TORCH_LIBRARY(torchlpc, m)
{
m.def("torchlpc::scan(Tensor a, Tensor b, Tensor c) -> Tensor");
m.def("torchlpc::lpc(Tensor a, Tensor b, Tensor c) -> Tensor");
}

TORCH_LIBRARY_IMPL(torchlpc, CPU, m) {
TORCH_LIBRARY_IMPL(torchlpc, CPU, m)
{
m.impl("scan", &scan_cpu_wrapper);
m.impl("lpc", &lpc_cpu);
}
Loading