fix: static block size rather than dynamic
amanas committed May 19, 2022
1 parent 4595ee2 commit 6b81af0
Showing 1 changed file with 6 additions and 19 deletions.
src/dnarecords/writer.py (25 changes: 6 additions & 19 deletions)
@@ -62,6 +62,7 @@ class DNARecordsWriter:
     ...
     :param expr: a Hail expression. Currently, only expressions coercible to numeric are supported
+    :param block_size: block size to handle transposing the matrix
     :param staging: path to staging directory to use for intermediate data. Default: /tmp/dnarecords/staging.
     """
 
     from typing import TYPE_CHECKING
@@ -74,15 +74,15 @@ class DNARecordsWriter:
     _j_blocks: set
     _nrows: int
     _ncols: int
-    _sparsity: float
     _chrom_ranges: dict
     _mt: 'MatrixTable'
     _skeys: 'DataFrame'
     _vkeys: 'DataFrame'
 
-    def __init__(self, expr: 'Expression', staging: str = '/tmp/dnarecords/staging'):
+    def __init__(self, expr: 'Expression', block_size=(10000, 10000), staging: str = '/tmp/dnarecords/staging'):
         self._assert_expr_type(expr)
         self._expr = expr
+        self._block_size = block_size
         self._kv_blocks_path = f'{staging}/kv-blocks'
         self._vw_dna_staging = f'{staging}/vw-dnaparquet'
         self._sw_dna_staging = f'{staging}/sw-dnaparquet'
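
For context, a minimal usage sketch of the new constructor signature. The dataset path and entry expression are hypothetical, and the import path is assumed from src/dnarecords/writer.py; block_size is the new per-block (rows, cols) shape, defaulting to (10000, 10000):

    import hail as hl
    from dnarecords.writer import DNARecordsWriter  # assumed import path for src/dnarecords/writer.py

    # Hypothetical input: any Hail expression coercible to numeric works
    mt = hl.read_matrix_table('/data/my-dataset.mt')
    expr = mt.GT.n_alt_alleles()

    # Blocks of 5000 variants x 2000 samples instead of the 10000 x 10000 default
    writer = DNARecordsWriter(expr, block_size=(5000, 2000), staging='/tmp/dnarecords/staging')
    writer.write('/data/dnarecords', sparse=True)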
@@ -143,25 +144,12 @@ def _set_max_nrows_ncols(self):
         self._nrows = self._mt.count_rows()
         self._ncols = self._mt.count_cols()
 
-    def _set_sparsity(self):
-        mts = self._mt.head(10000, None)
-        entries = mts.key_cols_by().key_rows_by().entries().to_spark().filter('v is not null').count()
-        self._sparsity = entries / (mts.count_rows() * mts.count_cols())
-
-    def _get_block_size(self):
-        import math
-        M, N, S = self._nrows + 1, self._ncols + 1, self._sparsity + 1e-6
-        B = 1e7 / S  # Recommended # entries per block
-        m = math.ceil(M / math.sqrt(B * M / N))
-        n = math.ceil(N / math.sqrt(B * N / M))
-        return m, n
-
     def _build_ij_blocks(self):
         import pyspark.sql.functions as F
-        m, n = self._get_block_size()
+        m, n = self._block_size
         df = self._mt.key_cols_by().key_rows_by().entries().to_spark().filter('v is not null')
-        df = df.withColumn('ib', F.col('i') % m)
-        df = df.withColumn('jb', F.col('j') % n)
+        df = df.withColumn('ib', F.floor(F.col('i')/F.lit(m)))
+        df = df.withColumn('jb', F.floor(F.col('j')/F.lit(n)))
         df.write.partitionBy('ib', 'jb').mode('overwrite').parquet(self._kv_blocks_path)
 
     def _set_ij_blocks(self):
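
A plain-Python sketch of the block-indexing change above (values hypothetical): under the old scheme, m was a computed number of blocks and i % m scattered consecutive rows round-robin across them, so a block was not a contiguous slab; under the new scheme, m is the block size and floor(i / m) keeps each block contiguous:

    # Old scheme: m was a computed *block count*; i % m interleaves rows.
    m_old = 3                                    # hypothetical count from the removed _get_block_size()
    print([i % m_old for i in range(8)])         # [0, 1, 2, 0, 1, 2, 0, 1]

    # New scheme: m is the *block size*; floor(i / m) yields contiguous blocks.
    m_new = 10000                                # default from block_size=(10000, 10000)
    print([i // m_new for i in (0, 9999, 10000, 25000)])   # [0, 0, 1, 2]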
@@ -341,7 +329,6 @@ def write(self, output: str, sparse: bool = True, sample_wise: bool = True, variant_wise
         if sparse:
             self._filter_out_zeroes()
         self._set_max_nrows_ncols()
-        self._set_sparsity()
         self._build_ij_blocks()
         self._set_ij_blocks()
 
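
For reference, the removed heuristic can be reproduced in plain Python to see the grid it produced; the dimensions and sparsity below are hypothetical:

    import math

    # Removed _get_block_size() logic with assumed inputs:
    # 1e6 rows (variants), 1e5 cols (samples), sparsity 0.5.
    M, N, S = 1_000_000 + 1, 100_000 + 1, 0.5 + 1e-6
    B = 1e7 / S                              # target total entries per block (~2e7),
                                             # i.e. ~1e7 non-null entries at sparsity S
    m = math.ceil(M / math.sqrt(B * M / N))  # 71 row groups
    n = math.ceil(N / math.sqrt(B * N / M))  # 71 column groups
    print(m, n)                              # -> 71 71

The commit replaces this data-dependent grid with a fixed block shape, which presumably trades adaptivity for predictable partition sizes.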
