diff --git a/haptools/__main__.py b/haptools/__main__.py index e50697a8..57a64f49 100755 --- a/haptools/__main__.py +++ b/haptools/__main__.py @@ -339,14 +339,6 @@ def simphenotype( show_default="all variants", help="If using a PGEN file, read genotypes in chunks of X variants; reduces memory", ) -@click.option( - "-C", - "--haps-chunk-size", - type=int, - default=None, - show_default="all haplotypes", - help="Transform in chunks of X haplotypes; reduces memory", -) @click.option( "--discard-missing", is_flag=True, @@ -379,7 +371,6 @@ def transform( ids: tuple[str] = tuple(), ids_file: Path = None, chunk_size: int = None, - haps_chunk_size: int = None, discard_missing: bool = False, output: Path = Path("-"), verbosity: str = 'CRITICAL', @@ -423,7 +414,7 @@ def transform( ids = None transform_haps( - genotypes, haplotypes, region, samples, ids, chunk_size, haps_chunk_size, + genotypes, haplotypes, region, samples, ids, chunk_size, discard_missing, output, log ) diff --git a/haptools/data/haplotypes.py b/haptools/data/haplotypes.py index 9afaa1e2..ab58e66c 100644 --- a/haptools/data/haplotypes.py +++ b/haptools/data/haplotypes.py @@ -989,7 +989,6 @@ def transform( self, gts: GenotypesRefAlt, hap_gts: GenotypesRefAlt = None, - chunk_size: int = None, ) -> GenotypesRefAlt: """ Transform a genotypes matrix via the current haplotype @@ -1004,11 +1003,6 @@ def transform( hap_gts: GenotypesRefAlt An empty GenotypesRefAlt object into which the haplotype genotypes should be stored - chunk_size: int, optional - The max number of haplotypes to transform at any given time - - If this value is provided, haplotypes will be transformed in chunks so as - to use less memory Returns ------- @@ -1023,67 +1017,41 @@ def transform( [(hap.id, hap.chrom, hap.start, 0, "A", "T") for hap in self.data.values()], dtype=hap_gts.variants.dtype, ) - # how many haplotypes should we transform at once? - chunks = chunk_size - if chunks is None or chunks > len(self.data): - chunks = len(self.data) + # index the genotypes for fast look-ups of the variant IDs + var_IDs = tuple(vID for hap in self.data.values() for vID in hap.varIDs) + gts = gts.subset(variants=var_IDs) # initialize arrays needed for proper broadcasting - shape = (1, len(self.data), gts.data.shape[1], 1) - # create a np mask array denoting which alleles belong to each haplotype - idx = np.zeros(shape, dtype=np.bool_) + shape = (1, gts.data.shape[1], 1) # and a np array denoting the allele integer in each haplotype allele_arr = np.zeros(shape, dtype=gts.data.dtype) - # index the genotypes for fast look-ups of the variant IDs - gts.index(variants=True) + hap_size = {} + end = 0 + for hap in self.data.values(): + start = end + end = start + len(hap.varIDs) + hap_size[hap.id] = (start, end) # fill out arrays -- iterate through each haplotype - for i, hap in enumerate(self.data.values()): - try: - # obtain the indices of each variant ID - ids = [gts._var_idx[vID] for vID in hap.varIDs] - except KeyError: - # check: were any of the variants absent from the genotypes? - missing_IDs = set(hap.varIDs) - set(gts.variants["id"]) - raise ValueError( - f"Variants {missing_IDs} are present in haplotype '{hap.id}' but " - "absent in the provided genotypes" - ) - idx[0, i, ids, 0] = True - allele_arr[0, i, ids, 0] = np.array([ - int(var.allele != gts.variants[j]["ref"]) - for j, var in zip(ids, hap.variants) + for hap in self.data.values(): + start, end = hap_size[hap.id] + allele_arr[0, start:end, 0] = np.array([ + int(var.allele != gts.variants[start+j]["ref"]) + for j, var in enumerate(hap.variants) ]) # finally, obtain and merge the haplotype genotypes self.log.info( - f"Transforming a set of genotypes of {len(gts.variants)} variants " - f"in {len(self.data)} haplotypes in chunks of size {chunks} haplotypes" + f"Transforming a set of genotypes of {len(gts.variants)} total variants " + f"in {len(self.data)} haplotypes" ) + equality_arr = np.equal(allele_arr, gts.data) self.log.debug( - f"Attempting to create array with dtype {gts.data.dtype} and size " + f"Allocating array with dtype {gts.data.dtype} and size " f"{(len(gts.samples), len(self.data), 2)}" ) hap_gts.data = np.empty( (gts.data.shape[0], len(self.data), 2), dtype=gts.data.dtype ) - for start in range(0, len(self.data), chunks): - end = start + chunks - if end > len(self.data): - end = len(self.data) - size = end - start - self.log.debug(f"Loading from haplotype #{start} to haplotype #{end}") - # allele_arr has shape (1, h, p, 1) and contains allele ints - # gts.data[:, np.newaxis] has shape (n, 1, p, 2) and contains allele ints - # idx has shape (1, h, p, 1) and contains a bool mask - # n = # of samples h = # of haplotypes p = # of variants - # Note that these shapes are broadcasted, so the result has shape (n, h, 2) - try: - hap_gts.data[:, start:end] = np.all( - np.equal(allele_arr[:, start:end], gts.data[:, np.newaxis]), - axis=2, where=idx[:, start:end], - ).astype(gts.data.dtype) - except np.core._exceptions._ArrayMemoryError as e: - raise ValueError( - "You don't have enough memory to transform these haplotypes! Try" - " specifying a value to the chunks_size parameter, instead" - ) from e + for i, hap in enumerate(self.data.values()): + start, end = hap_size[hap.id] + hap_gts.data[:, i] = np.all(equality_arr[:, start:end], axis=1).astype(gts.data.dtype) return hap_gts diff --git a/haptools/transform.py b/haptools/transform.py index eb8ce3de..f5388afc 100644 --- a/haptools/transform.py +++ b/haptools/transform.py @@ -83,7 +83,6 @@ def transform_haps( samples: list[str] = None, haplotype_ids: set[str] = None, chunk_size: int = None, - haps_chunk_size: int = None, discard_missing: bool = False, output: Path = Path("-"), log: Logger = None, @@ -112,11 +111,6 @@ def transform_haps( If this value is provided, variants from the PGEN file will be loaded in chunks so as to use less memory. This argument is ignored if the genotypes are not in PGEN format. - haps_chunk_size: int, optional - The max number of haplotypes to transform together at any given time - - If this value is provided, haplotypes from the .hap file will be transformed in - chunks so as to use less memory. discard_missing : bool, optional Discard any samples that are missing any of the required genotypes @@ -179,7 +173,7 @@ def transform_haps( out_file_type = "VCF/BCF" hp_gt = data.GenotypesRefAlt(fname=output, log=log) log.info("Transforming genotypes via haplotypes") - hp.transform(gt, hp_gt, chunk_size=haps_chunk_size) + hp.transform(gt, hp_gt) log.info(f"Writing haplotypes to {out_file_type} file") hp_gt.write()