diff --git a/haptools/data/haplotypes.py b/haptools/data/haplotypes.py index ab58e66c..e243b15d 100644 --- a/haptools/data/haplotypes.py +++ b/haptools/data/haplotypes.py @@ -1017,32 +1017,31 @@ def transform( [(hap.id, hap.chrom, hap.start, 0, "A", "T") for hap in self.data.values()], dtype=hap_gts.variants.dtype, ) - # index the genotypes for fast look-ups of the variant IDs - var_IDs = tuple(vID for hap in self.data.values() for vID in hap.varIDs) - gts = gts.subset(variants=var_IDs) - # initialize arrays needed for proper broadcasting - shape = (1, gts.data.shape[1], 1) - # and a np array denoting the allele integer in each haplotype - allele_arr = np.zeros(shape, dtype=gts.data.dtype) - hap_size = {} - end = 0 + # build a fast data structure for querying the alleles in each haplotype: + # a dict mapping (variant ID, allele) -> a unique index + alleles = {} + # and a dict mapping hap ID -> an array with the indices of the hap's alleles + idxs = {} + count = 0 for hap in self.data.values(): - start = end - end = start + len(hap.varIDs) - hap_size[hap.id] = (start, end) - # fill out arrays -- iterate through each haplotype - for hap in self.data.values(): - start, end = hap_size[hap.id] - allele_arr[0, start:end, 0] = np.array([ - int(var.allele != gts.variants[start+j]["ref"]) - for j, var in enumerate(hap.variants) - ]) + idxs[hap.id] = np.empty(len(hap.variants), dtype=np.uintc) + for i, variant in enumerate(hap.variants): + key = (variant.id, variant.allele) + if key not in alleles: + alleles[key] = count + count += 1 + idxs[hap.id][i] = alleles[key] + self.log.debug(f"Copying genotypes for {len(alleles)} distinct alleles") + gts = gts.subset(variants=tuple(k[0] for k in alleles)) + self.log.debug(f"Creating array denoting alt allele status") + # initialize a np array denoting the allele integer in each haplotype + # with shape (1, gts.data.shape[1], 1) for broadcasting later + allele_arr = np.array([ + int(allele != gts.variants[i]["ref"]) + for i, (vID, allele) in enumerate(alleles) + ], dtype=gts.data.dtype)[np.newaxis, :, np.newaxis] # finally, obtain and merge the haplotype genotypes - self.log.info( - f"Transforming a set of genotypes of {len(gts.variants)} total variants " - f"in {len(self.data)} haplotypes" - - ) + self.log.info(f"Transforming genotypes for {len(self.data)} haplotypes") equality_arr = np.equal(allele_arr, gts.data) self.log.debug( f"Allocating array with dtype {gts.data.dtype} and size " @@ -1051,7 +1050,7 @@ def transform( hap_gts.data = np.empty( (gts.data.shape[0], len(self.data), 2), dtype=gts.data.dtype ) + self.log.debug("Computing haplotype genotypes. This may take a while") for i, hap in enumerate(self.data.values()): - start, end = hap_size[hap.id] - hap_gts.data[:, i] = np.all(equality_arr[:, start:end], axis=1).astype(gts.data.dtype) + hap_gts.data[:, i] = np.all(equality_arr[:, idxs[hap.id]], axis=1) return hap_gts diff --git a/haptools/transform.py b/haptools/transform.py index f5388afc..26e7c5bb 100644 --- a/haptools/transform.py +++ b/haptools/transform.py @@ -178,4 +178,5 @@ def transform_haps( log.info(f"Writing haplotypes to {out_file_type} file") hp_gt.write() + log.debug("Done!") return hp_gt