From ccadc8063b0bd2fbf07eaf4bff37c851d8e15173 Mon Sep 17 00:00:00 2001 From: Arni Gunnarsson Date: Tue, 2 Sep 2025 00:13:04 +0100 Subject: [PATCH 1/6] left multiplication --- test/test_multiply.py | 76 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 test/test_multiply.py diff --git a/test/test_multiply.py b/test/test_multiply.py new file mode 100644 index 0000000..7ee790e --- /dev/null +++ b/test/test_multiply.py @@ -0,0 +1,76 @@ +# This file is part of the Threads software suite. +# Copyright (C) 2024-2025 Threads Developers. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import numpy as np +import pgenlib +import pytest + +from threads_arg.serialization import load_instructions + +from snapshot_runners import ( + TEST_DATA_DIR +) + +def _col_normalize(x): + z = x.copy() + mu = z.mean(axis=0, keepdims=True) + std = z.std(axis=0, keepdims=True) + return (z - mu) / std + +def test_left_multiply(): + # Read ground truth genotypes + pgen_path = str(TEST_DATA_DIR / "panel.pgen") + reader = pgenlib.PgenReader(str(pgen_path).encode()) + expected_num_variants = reader.get_variant_ct() + num_samples = reader.get_raw_sample_ct() + expected_gt = np.empty((expected_num_variants, 2 * num_samples), dtype=np.int32) + reader.read_alleles_range(0, expected_num_variants, expected_gt) + gt_matrix = expected_gt.transpose() + gt_matrix_dip = gt_matrix[::2] + gt_matrix[1::2] + gt_matrix_norm = _col_normalize(gt_matrix) + gt_matrix_dip_norm = _col_normalize(gt_matrix_dip) + + # Read threading instructions + threads_path = str(TEST_DATA_DIR / "expected_infer_snapshot.threads") + instructions = load_instructions(threads_path) + + # Random vector to multiply with + rng = np.random.default_rng(130222) + x_hap = rng.normal(0, 1, 2 * num_samples) + x_dip = rng.normal(0, 1, num_samples) + + # Make sure length checks are performed + with pytest.raises(RuntimeError): + instructions.left_multiply(x_dip) + with pytest.raises(RuntimeError): + instructions.left_multiply(x_hap, diploid=True) + + # Do normal left-multiplication + expected = x_hap @ gt_matrix + expected_norm = x_hap @ gt_matrix_norm + expected_dip = x_dip @ gt_matrix_dip + expected_dip_norm = x_dip @ gt_matrix_dip_norm + + # Do threads left-multiplication and confirm results are correct + found = instructions.left_multiply(x_hap) + assert np.allclose(expected, found) + found_norm = instructions.left_multiply(x_hap, normalize=True) + assert np.allclose(expected_norm, found_norm) + found_dip = instructions.left_multiply(x_dip, diploid=True) + assert np.allclose(expected_dip, found_dip) + breakpoint() + found_dip_norm = instructions.left_multiply(x_dip, normalize=True, diploid=True) + assert np.allclose(expected_dip_norm, found_dip_norm) From e399d7bc8699d0912b0ecda9e93d11c95236e6c6 Mon Sep 17 00:00:00 2001 From: Arni Gunnarsson Date: Tue, 2 Sep 2025 00:13:14 +0100 Subject: [PATCH 2/6] left multiplication --- src/ThreadingInstructions.cpp | 74 +++++++++++++++++++++++++++++++++++ src/ThreadingInstructions.hpp | 4 ++ src/threads_arg_pybind.cpp | 3 +- 3 files changed, 80 insertions(+), 1 deletion(-) diff --git a/src/ThreadingInstructions.cpp b/src/ThreadingInstructions.cpp index 121fde0..5f93805 100644 --- a/src/ThreadingInstructions.cpp +++ b/src/ThreadingInstructions.cpp @@ -15,7 +15,9 @@ // along with this program. If not, see . #include "ThreadingInstructions.hpp" +#include "GenotypeIterator.hpp" +#include #include #include #include @@ -259,3 +261,75 @@ ThreadingInstructions ThreadingInstructions::sub_range(const int range_start, co std::move(range_positions) }; } +std::vector ThreadingInstructions::left_multiply(const std::vector& x, bool diploid, bool normalize) { + // Left-multiplication of the genotype matrix by a vector of doubles + + // Check input vector lengths are correct + if (diploid) { + if (x.size() != num_samples / 2) { + std::ostringstream oss; + oss << "Input vector must have length " << num_samples / 2 << "."; + throw std::runtime_error(oss.str()); + } + } else { + if (x.size() != num_samples) { + std::ostringstream oss; + oss << "Input vector must have length " << num_samples << "."; + throw std::runtime_error(oss.str()); + } + } + + // Initialize genotype traversal + GenotypeIterator gi = GenotypeIterator(*this); + std::size_t site_counter = 0; + std::vector out(num_sites); + + while (gi.has_next_genotype()) { + // Fetch the next genotype + const std::vector& g = gi.next_genotype(); + + // Initialize the next entry + double entry = 0.0; + + if (normalize) { + // If we want to normalize, we need the mean and standard deviation of g. + double ac = 0.0; + for (auto a : g) { + ac += a; + } + if (diploid) { + // We do the diploid standard deviation by hand + double mu = 2.0 * ac / num_samples; + double sample_var = 0.0; + for (std::size_t i=0; i < x.size(); i++) { + int h = g.at(2 * i) + g.at(2 * i + 1); + double d = h - mu; + sample_var += d * d; + } + sample_var /= (num_samples / 2); + + double std = std::sqrt(sample_var); + for (std::size_t i=0; i < x.size(); i++) { + int h = g.at(2 * i) + g.at(2 * i + 1); + double w = x.at(i); + entry += w * (h - mu) / std; + } + } else { + double mu = ac / num_samples; + double std = std::sqrt(mu * (1 - mu)); + for (std::size_t i=0; i < g.size(); i++) { + double w = x.at(i); + entry += w * (g.at(i) - mu) / std; + } + } + } else { + for (std::size_t i=0; i < g.size(); i++) { + double w = diploid ? x.at(i / 2) : x.at(i); + entry += w * g.at(i); + } + } + out[site_counter] = entry; + site_counter++; + } + return out; +} diff --git a/src/ThreadingInstructions.hpp b/src/ThreadingInstructions.hpp index ae0136b..f404be2 100644 --- a/src/ThreadingInstructions.hpp +++ b/src/ThreadingInstructions.hpp @@ -85,6 +85,10 @@ class ThreadingInstructions { ThreadingInstructions sub_range(const int range_start, const int range_end) const; + // Common operations + std::vector left_multiply(const std::vector& x, bool diploid=false, bool normalize=false); + std::vector right_multiply(const std::vector& x, bool diploid=false, bool normalize=false); + public: int start = 0; int end = 0; diff --git a/src/threads_arg_pybind.cpp b/src/threads_arg_pybind.cpp index b8ca12d..c5bdd2f 100644 --- a/src/threads_arg_pybind.cpp +++ b/src/threads_arg_pybind.cpp @@ -140,7 +140,8 @@ PYBIND11_MODULE(threads_arg_python_bindings, m) { .def("sub_range", &ThreadingInstructions::sub_range) .def(py::pickle( &threading_instructions_get_state, - &threading_instructions_set_state)); + &threading_instructions_set_state)) + .def("left_multiply", &ThreadingInstructions::left_multiply, py::arg("x"), py::arg("diploid") = false, py::arg("normalize") = false); py::class_(m, "ConsistencyWrapper") .def(py::init>&, const std::vector>&, const std::vector>&, From f4ee56a87db34da609929f880fc8fe4761d61f74 Mon Sep 17 00:00:00 2001 From: Arni Gunnarsson Date: Tue, 2 Sep 2025 00:30:47 +0100 Subject: [PATCH 3/6] remove stray breakpoint --- test/test_multiply.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_multiply.py b/test/test_multiply.py index 7ee790e..05ba296 100644 --- a/test/test_multiply.py +++ b/test/test_multiply.py @@ -71,6 +71,5 @@ def test_left_multiply(): assert np.allclose(expected_norm, found_norm) found_dip = instructions.left_multiply(x_dip, diploid=True) assert np.allclose(expected_dip, found_dip) - breakpoint() found_dip_norm = instructions.left_multiply(x_dip, normalize=True, diploid=True) assert np.allclose(expected_dip_norm, found_dip_norm) From 311da6549891cae0b3c503117583ba2dfec21e02 Mon Sep 17 00:00:00 2001 From: Arni Gunnarsson Date: Tue, 2 Sep 2025 23:06:42 +0100 Subject: [PATCH 4/6] add right-multiplication --- src/ThreadingInstructions.cpp | 94 +++++++++++++++++++++++++++++++++++ src/threads_arg_pybind.cpp | 3 +- test/test_multiply.py | 42 ++++++++++++++++ 3 files changed, 138 insertions(+), 1 deletion(-) diff --git a/src/ThreadingInstructions.cpp b/src/ThreadingInstructions.cpp index 5f93805..cb148fe 100644 --- a/src/ThreadingInstructions.cpp +++ b/src/ThreadingInstructions.cpp @@ -261,6 +261,7 @@ ThreadingInstructions ThreadingInstructions::sub_range(const int range_start, co std::move(range_positions) }; } + std::vector ThreadingInstructions::left_multiply(const std::vector& x, bool diploid, bool normalize) { // Left-multiplication of the genotype matrix by a vector of doubles @@ -333,3 +334,96 @@ std::vector ThreadingInstructions::left_multiply(const std::vector ThreadingInstructions::right_multiply(const std::vector& x, bool diploid, bool normalize) { + // Right-multiplication of the genotype matrix by a vector of doubles + + // Check input vector lengths are correct + if (x.size() != num_sites) { + std::ostringstream oss; + oss << "Input vector must have length " << num_samples / 2 << "."; + throw std::runtime_error(oss.str()); + } + + GenotypeIterator gi = GenotypeIterator(*this); + std::size_t site_counter = 0; + if (diploid) { + // Initialize output + std::vector out(num_samples / 2, 0.0); + if (normalize) { + while (gi.has_next_genotype()) { + // Fetch the next genotype + const std::vector& g = gi.next_genotype(); + + // If we want to normalize, we need the mean and standard deviation of g. + double ac = 0.0; + for (auto a : g) { + ac += a; + } + + // We do the diploid standard deviation by hand + const double mu = 2.0 * ac / num_samples; + double sample_var = 0.0; + for (std::size_t i=0; i < out.size(); i++) { + int h = g.at(2 * i) + g.at(2 * i + 1); + double d = h - mu; + sample_var += d * d; + } + sample_var /= (num_samples / 2); + const double std = std::sqrt(sample_var); + + const double w = x.at(site_counter) / std; + for (std::size_t i=0; i < out.size(); i++) { + const int h = g.at(2 * i) + g.at(2 * i + 1); + out[i] += w * (h - mu); + } + site_counter++; + } + } else { + while (gi.has_next_genotype()) { + // Fetch the next genotype + const std::vector& g = gi.next_genotype(); + const double w = x.at(site_counter); + for (std::size_t i=0; i < out.size(); i++) { + const int h = g.at(2 * i) + g.at(2 * i + 1); + out[i] += w * h; + } + site_counter++; + } + } + return out; + } else { + // Initialize output + std::vector out(num_samples, 0.0); + if (normalize) { + while (gi.has_next_genotype()) { + // Fetch the next genotype + const std::vector& g = gi.next_genotype(); + double ac = 0.0; + for (auto a : g) { + ac += a; + } + + // Normalization constants + double mu = ac / num_samples; + double std = std::sqrt(mu * (1 - mu)); + const double w = x.at(site_counter) / std; + for (std::size_t i=0; i < out.size(); i++) { + out[i] += w * (g.at(i) - mu); + } + site_counter++; + } + } else { + while (gi.has_next_genotype()) { + // Fetch the next genotype + const std::vector& g = gi.next_genotype(); + const double w = x.at(site_counter); + for (std::size_t i=0; i < out.size(); i++) { + out[i] += w * g.at(i); + } + site_counter++; + } + } + return out; + } +} diff --git a/src/threads_arg_pybind.cpp b/src/threads_arg_pybind.cpp index c5bdd2f..c829a20 100644 --- a/src/threads_arg_pybind.cpp +++ b/src/threads_arg_pybind.cpp @@ -141,7 +141,8 @@ PYBIND11_MODULE(threads_arg_python_bindings, m) { .def(py::pickle( &threading_instructions_get_state, &threading_instructions_set_state)) - .def("left_multiply", &ThreadingInstructions::left_multiply, py::arg("x"), py::arg("diploid") = false, py::arg("normalize") = false); + .def("left_multiply", &ThreadingInstructions::left_multiply, py::arg("x"), py::arg("diploid") = false, py::arg("normalize") = false) + .def("right_multiply", &ThreadingInstructions::right_multiply, py::arg("x"), py::arg("diploid") = false, py::arg("normalize") = false); py::class_(m, "ConsistencyWrapper") .def(py::init>&, const std::vector>&, const std::vector>&, diff --git a/test/test_multiply.py b/test/test_multiply.py index 05ba296..9b2e1e5 100644 --- a/test/test_multiply.py +++ b/test/test_multiply.py @@ -73,3 +73,45 @@ def test_left_multiply(): assert np.allclose(expected_dip, found_dip) found_dip_norm = instructions.left_multiply(x_dip, normalize=True, diploid=True) assert np.allclose(expected_dip_norm, found_dip_norm) + +def test_right_multiply(): + # Read ground truth genotypes + pgen_path = str(TEST_DATA_DIR / "panel.pgen") + reader = pgenlib.PgenReader(str(pgen_path).encode()) + expected_num_variants = reader.get_variant_ct() + num_samples = reader.get_raw_sample_ct() + expected_gt = np.empty((expected_num_variants, 2 * num_samples), dtype=np.int32) + reader.read_alleles_range(0, expected_num_variants, expected_gt) + gt_matrix = expected_gt.transpose() + gt_matrix_dip = gt_matrix[::2] + gt_matrix[1::2] + gt_matrix_norm = _col_normalize(gt_matrix) + gt_matrix_dip_norm = _col_normalize(gt_matrix_dip) + + # Read threading instructions + threads_path = str(TEST_DATA_DIR / "expected_infer_snapshot.threads") + instructions = load_instructions(threads_path) + + # Random vector to multiply with + rng = np.random.default_rng(130222) + x = rng.normal(0, 1, expected_num_variants) + x_wrong_length = rng.normal(0, 1, expected_num_variants + 1) + + # Make sure length check is performed + with pytest.raises(RuntimeError): + instructions.left_multiply(x_wrong_length) + + # Do normal right-multiplication + expected = gt_matrix @ x + expected_norm = gt_matrix_norm @ x + expected_dip = gt_matrix_dip @ x + expected_dip_norm = gt_matrix_dip_norm @ x + + # Do threads right-multiplication and confirm results are correct + found = instructions.right_multiply(x) + assert np.allclose(expected, found) + found_norm = instructions.right_multiply(x, normalize=True) + assert np.allclose(expected_norm, found_norm) + found_dip = instructions.right_multiply(x, diploid=True) + assert np.allclose(expected_dip, found_dip) + found_dip_norm = instructions.right_multiply(x, normalize=True, diploid=True) + assert np.allclose(expected_dip_norm, found_dip_norm) From 8dbbd85562624fea1f4c8b25e93ec1dbb30bffd2 Mon Sep 17 00:00:00 2001 From: Arni Gunnarsson Date: Fri, 5 Sep 2025 23:34:57 +0100 Subject: [PATCH 5/6] replace .at() with brackets --- src/ThreadingInstructions.cpp | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/ThreadingInstructions.cpp b/src/ThreadingInstructions.cpp index cb148fe..4ba5be4 100644 --- a/src/ThreadingInstructions.cpp +++ b/src/ThreadingInstructions.cpp @@ -303,7 +303,7 @@ std::vector ThreadingInstructions::left_multiply(const std::vector ThreadingInstructions::left_multiply(const std::vector ThreadingInstructions::right_multiply(const std::vector ThreadingInstructions::right_multiply(const std::vector& g = gi.next_genotype(); - const double w = x.at(site_counter); + const double w = x[site_counter]; for (std::size_t i=0; i < out.size(); i++) { - const int h = g.at(2 * i) + g.at(2 * i + 1); + const int h = g[2 * i] + g[2 * i + 1]; out[i] += w * h; } site_counter++; @@ -407,9 +407,9 @@ std::vector ThreadingInstructions::right_multiply(const std::vector ThreadingInstructions::right_multiply(const std::vector& g = gi.next_genotype(); - const double w = x.at(site_counter); + const double w = x[site_counter]; for (std::size_t i=0; i < out.size(); i++) { - out[i] += w * g.at(i); + out[i] += w * g[i]; } site_counter++; } From 7b77a6bb28eded8e614a2464b7f3043f311e1f97 Mon Sep 17 00:00:00 2001 From: Alex Allmont Date: Fri, 21 Nov 2025 17:26:32 +0000 Subject: [PATCH 6/6] docs: update changelog left/right multiplication (#99) --- RELEASE_NOTES.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index cf5f86a..7443e53 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -2,6 +2,10 @@ ## [Unreleased] +### Added + +- Add left_multiplication and right_multiplication to ThreadingInstructions (#99) + ### Changed - Build wheels on macOS 14 for arm64 and macOS 15 for x86_64 (#108)