diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md
index cf5f86a..7443e53 100644
--- a/RELEASE_NOTES.md
+++ b/RELEASE_NOTES.md
@@ -2,6 +2,10 @@
 
 ## [Unreleased]
 
+### Added
+
+- Add left_multiply and right_multiply to ThreadingInstructions (#99)
+
 ### Changed
 
 - Build wheels on macOS 14 for arm64 and macOS 15 for x86_64 (#108)
diff --git a/src/ThreadingInstructions.cpp b/src/ThreadingInstructions.cpp
index 121fde0..4ba5be4 100644
--- a/src/ThreadingInstructions.cpp
+++ b/src/ThreadingInstructions.cpp
@@ -15,7 +15,9 @@
 // along with this program. If not, see <https://www.gnu.org/licenses/>.
 
 #include "ThreadingInstructions.hpp"
+#include "GenotypeIterator.hpp"
 
+#include <cmath>
 #include
 #include
 #include
@@ -259,3 +261,164 @@ ThreadingInstructions ThreadingInstructions::sub_range(const int range_start, co
     std::move(range_positions)
   };
 }
+
+namespace {
+
+// Mean and population standard deviation of the genotypes at one site.
+struct SiteMoments {
+  double mean;
+  double sd;
+};
+
+// Number of rows of the genotype matrix: one per haplotype, or one per
+// individual (pair of consecutive haplotypes) in diploid mode.
+std::size_t row_count(const std::size_t n_haploid, const bool diploid) {
+  if (!diploid) {
+    return n_haploid;
+  }
+  if (n_haploid % 2 != 0) {
+    // Haplotypes 2i and 2i + 1 are paired, so an odd count would leave one
+    // unpaired and make the per-individual indexing run out of bounds.
+    throw std::runtime_error("Diploid mode requires an even number of haplotypes.");
+  }
+  return n_haploid / 2;
+}
+
+// Moments of the haploid genotypes at one site.
+template <typename Genotype>
+SiteMoments haploid_moments(const Genotype& g, const std::size_t n_haploid) {
+  double allele_count = 0.0;
+  for (const auto allele : g) {
+    allele_count += allele;
+  }
+  const double mean = allele_count / static_cast<double>(n_haploid);
+  // mean * (1 - mean) is the population variance only for 0/1-coded alleles.
+  return {mean, std::sqrt(mean * (1.0 - mean))};
+}
+
+// Moments of the diploid dosages g[2i] + g[2i + 1] at one site.
+template <typename Genotype>
+SiteMoments diploid_moments(const Genotype& g, const std::size_t n_diploid) {
+  double dosage_sum = 0.0;
+  for (std::size_t i = 0; i < n_diploid; ++i) {
+    dosage_sum += g[2 * i] + g[2 * i + 1];
+  }
+  const double mean = dosage_sum / static_cast<double>(n_diploid);
+
+  // Dosages are sums of two alleles, so the variance has no p * (1 - p)
+  // shortcut; accumulate the squared deviations directly.
+  double squared_dev = 0.0;
+  for (std::size_t i = 0; i < n_diploid; ++i) {
+    const double dev = (g[2 * i] + g[2 * i + 1]) - mean;
+    squared_dev += dev * dev;
+  }
+  return {mean, std::sqrt(squared_dev / static_cast<double>(n_diploid))};
+}
+
+// Moments used to standardize one site as (g - mean) / sd, or {0, 1} (which
+// leaves g unchanged) when normalization is not requested.
+template <typename Genotype>
+SiteMoments site_moments(const Genotype& g, const std::size_t n_rows, const bool diploid, const bool normalize) {
+  if (!normalize) {
+    return {0.0, 1.0};
+  }
+  return diploid ? diploid_moments(g, n_rows) : haploid_moments(g, n_rows);
+}
+
+} // namespace
+
+// Left-multiplies the genotype matrix G (rows: samples, columns: sites) by a
+// row vector, i.e. returns x^T G with one entry per site.
+//
+//   x          one weight per haplotype, or per individual if `diploid`
+//   diploid    treat consecutive haplotypes (2i, 2i + 1) as one individual
+//              whose genotype is the sum of the two alleles
+//   normalize  standardize each site to zero mean and unit (population)
+//              variance before multiplying; zero-variance sites yield 0
+//
+// Throws std::runtime_error if x has the wrong length, or if `diploid` is set
+// and the number of haplotypes is odd.
+std::vector<double> ThreadingInstructions::left_multiply(const std::vector<double>& x, bool diploid, bool normalize) {
+  const std::size_t n_rows = row_count(static_cast<std::size_t>(num_samples), diploid);
+  if (x.size() != n_rows) {
+    std::ostringstream oss;
+    oss << "Input vector must have length " << n_rows << ".";
+    throw std::runtime_error(oss.str());
+  }
+
+  GenotypeIterator gi = GenotypeIterator(*this);
+  std::vector<double> out(num_sites);
+  std::size_t site = 0;
+  while (gi.has_next_genotype()) {
+    const auto& g = gi.next_genotype();
+    const SiteMoments moments = site_moments(g, n_rows, diploid, normalize);
+
+    // Accumulate sum_i x[i] * (g_i - mean); the division by sd is applied once
+    // per site below.
+    double entry = 0.0;
+    if (diploid) {
+      for (std::size_t i = 0; i < n_rows; ++i) {
+        entry += x[i] * ((g[2 * i] + g[2 * i + 1]) - moments.mean);
+      }
+    } else {
+      for (std::size_t i = 0; i < n_rows; ++i) {
+        entry += x[i] * (g[i] - moments.mean);
+      }
+    }
+
+    // A monomorphic site has sd == 0; report 0 rather than the NaN/inf the
+    // division would produce.
+    out[site] = moments.sd > 0.0 ? entry / moments.sd : 0.0;
+    ++site;
+  }
+  return out;
+}
+
+// Right-multiplies the genotype matrix G (rows: samples, columns: sites) by a
+// column vector, i.e. returns G x with one entry per haplotype, or per
+// individual if `diploid`.
+//
+//   x          one weight per site
+//   diploid    treat consecutive haplotypes (2i, 2i + 1) as one individual
+//              whose genotype is the sum of the two alleles
+//   normalize  standardize each site to zero mean and unit (population)
+//              variance before multiplying; zero-variance sites are skipped
+//
+// Throws std::runtime_error if x has the wrong length, or if `diploid` is set
+// and the number of haplotypes is odd.
+std::vector<double> ThreadingInstructions::right_multiply(const std::vector<double>& x, bool diploid, bool normalize) {
+  if (x.size() != num_sites) {
+    std::ostringstream oss;
+    oss << "Input vector must have length " << num_sites << ".";
+    throw std::runtime_error(oss.str());
+  }
+  const std::size_t n_rows = row_count(static_cast<std::size_t>(num_samples), diploid);
+
+  GenotypeIterator gi = GenotypeIterator(*this);
+  std::vector<double> out(n_rows, 0.0);
+  std::size_t site = 0;
+  while (gi.has_next_genotype()) {
+    const auto& g = gi.next_genotype();
+    const SiteMoments moments = site_moments(g, n_rows, diploid, normalize);
+    const double weight = x[site];
+    ++site;
+
+    // A monomorphic site has sd == 0; skip it so it contributes 0 instead of
+    // turning every output entry into NaN/inf.
+    if (!(moments.sd > 0.0)) {
+      continue;
+    }
+
+    const double scale = weight / moments.sd;
+    if (diploid) {
+      for (std::size_t i = 0; i < n_rows; ++i) {
+        out[i] += scale * ((g[2 * i] + g[2 * i + 1]) - moments.mean);
+      }
+    } else {
+      for (std::size_t i = 0; i < n_rows; ++i) {
+        out[i] += scale * (g[i] - moments.mean);
+      }
+    }
+  }
+  return out;
+}
diff --git a/src/ThreadingInstructions.hpp b/src/ThreadingInstructions.hpp
index ae0136b..f404be2 100644
--- a/src/ThreadingInstructions.hpp
+++ b/src/ThreadingInstructions.hpp
@@ -85,6 +85,15 @@ class ThreadingInstructions {
 
   ThreadingInstructions sub_range(const int range_start, const int range_end) const;
 
+  // Common operations
+  //
+  // Products with the genotype matrix G (rows: samples, columns: sites).
+  // left_multiply returns x^T G (one entry per site); right_multiply returns
+  // G x (one entry per sample). With `diploid`, consecutive haplotype pairs are
+  // summed into one sample; with `normalize`, each site is standardized first.
+  std::vector<double> left_multiply(const std::vector<double>& x, bool diploid=false, bool normalize=false);
+  std::vector<double> right_multiply(const std::vector<double>& x, bool diploid=false, bool normalize=false);
+
 public:
   int start = 0;
   int end = 0;
diff --git a/src/threads_arg_pybind.cpp b/src/threads_arg_pybind.cpp
index b8ca12d..c829a20 100644
--- a/src/threads_arg_pybind.cpp
+++ b/src/threads_arg_pybind.cpp
@@ -140,7 +140,9 @@ PYBIND11_MODULE(threads_arg_python_bindings, m) {
         .def("sub_range", &ThreadingInstructions::sub_range)
         .def(py::pickle(
             &threading_instructions_get_state,
-            &threading_instructions_set_state));
+            &threading_instructions_set_state))
+        .def("left_multiply", &ThreadingInstructions::left_multiply, py::arg("x"), py::arg("diploid") = false, py::arg("normalize") = false)
+        .def("right_multiply", &ThreadingInstructions::right_multiply, py::arg("x"), py::arg("diploid") = false, py::arg("normalize") = false);
 
     py::class_(m, "ConsistencyWrapper")
         .def(py::init>&, const std::vector>&, const std::vector>&,
diff --git a/test/test_multiply.py b/test/test_multiply.py
new file mode 100644
index 0000000..9b2e1e5
--- /dev/null
+++ b/test/test_multiply.py
@@ -0,0 +1,117 @@
+# This file is part of the Threads software suite.
+# Copyright (C) 2024-2025 Threads Developers.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import numpy as np
+import pgenlib
+import pytest
+
+from threads_arg.serialization import load_instructions
+
+from snapshot_runners import TEST_DATA_DIR
+
+def _col_normalize(x):
+    """Standardize each column to zero mean and unit population std (ddof=0)."""
+    mu = x.mean(axis=0, keepdims=True)
+    std = x.std(axis=0, keepdims=True)
+    return (x - mu) / std
+
+def test_left_multiply():
+    """left_multiply must agree with x @ G for every diploid/normalize combination."""
+    # Read ground truth genotypes
+    pgen_path = str(TEST_DATA_DIR / "panel.pgen")
+    reader = pgenlib.PgenReader(str(pgen_path).encode())
+    expected_num_variants = reader.get_variant_ct()
+    num_samples = reader.get_raw_sample_ct()
+    expected_gt = np.empty((expected_num_variants, 2 * num_samples), dtype=np.int32)
+    reader.read_alleles_range(0, expected_num_variants, expected_gt)
+    gt_matrix = expected_gt.transpose()
+    gt_matrix_dip = gt_matrix[::2] + gt_matrix[1::2]
+    gt_matrix_norm = _col_normalize(gt_matrix)
+    gt_matrix_dip_norm = _col_normalize(gt_matrix_dip)
+
+    # Read threading instructions
+    threads_path = str(TEST_DATA_DIR / "expected_infer_snapshot.threads")
+    instructions = load_instructions(threads_path)
+
+    # Random vector to multiply with
+    rng = np.random.default_rng(130222)
+    x_hap = rng.normal(0, 1, 2 * num_samples)
+    x_dip = rng.normal(0, 1, num_samples)
+
+    # Make sure length checks are performed
+    with pytest.raises(RuntimeError):
+        instructions.left_multiply(x_dip)
+    with pytest.raises(RuntimeError):
+        instructions.left_multiply(x_hap, diploid=True)
+
+    # Do normal left-multiplication
+    expected = x_hap @ gt_matrix
+    expected_norm = x_hap @ gt_matrix_norm
+    expected_dip = x_dip @ gt_matrix_dip
+    expected_dip_norm = x_dip @ gt_matrix_dip_norm
+
+    # Do threads left-multiplication and confirm results are correct
+    found = instructions.left_multiply(x_hap)
+    assert np.allclose(expected, found)
+    found_norm = instructions.left_multiply(x_hap, normalize=True)
+    assert np.allclose(expected_norm, found_norm)
+    found_dip = instructions.left_multiply(x_dip, diploid=True)
+    assert np.allclose(expected_dip, found_dip)
+    found_dip_norm = instructions.left_multiply(x_dip, normalize=True, diploid=True)
+    assert np.allclose(expected_dip_norm, found_dip_norm)
+
+def test_right_multiply():
+    """right_multiply must agree with G @ x for every diploid/normalize combination."""
+    # Read ground truth genotypes
+    pgen_path = str(TEST_DATA_DIR / "panel.pgen")
+    reader = pgenlib.PgenReader(str(pgen_path).encode())
+    expected_num_variants = reader.get_variant_ct()
+    num_samples = reader.get_raw_sample_ct()
+    expected_gt = np.empty((expected_num_variants, 2 * num_samples), dtype=np.int32)
+    reader.read_alleles_range(0, expected_num_variants, expected_gt)
+    gt_matrix = expected_gt.transpose()
+    gt_matrix_dip = gt_matrix[::2] + gt_matrix[1::2]
+    gt_matrix_norm = _col_normalize(gt_matrix)
+    gt_matrix_dip_norm = _col_normalize(gt_matrix_dip)
+
+    # Read threading instructions
+    threads_path = str(TEST_DATA_DIR / "expected_infer_snapshot.threads")
+    instructions = load_instructions(threads_path)
+
+    # Random vector to multiply with
+    rng = np.random.default_rng(130222)
+    x = rng.normal(0, 1, expected_num_variants)
+    x_wrong_length = rng.normal(0, 1, expected_num_variants + 1)
+
+    # Make sure length check is performed
+    with pytest.raises(RuntimeError):
+        instructions.right_multiply(x_wrong_length)
+
+    # Do normal right-multiplication
+    expected = gt_matrix @ x
+    expected_norm = gt_matrix_norm @ x
+    expected_dip = gt_matrix_dip @ x
+    expected_dip_norm = gt_matrix_dip_norm @ x
+
+    # Do threads right-multiplication and confirm results are correct
+    found = instructions.right_multiply(x)
+    assert np.allclose(expected, found)
+    found_norm = instructions.right_multiply(x, normalize=True)
+    assert np.allclose(expected_norm, found_norm)
+    found_dip = instructions.right_multiply(x, diploid=True)
+    assert np.allclose(expected_dip, found_dip)
+    found_dip_norm = instructions.right_multiply(x, normalize=True, diploid=True)
+    assert np.allclose(expected_dip_norm, found_dip_norm)