Skip to content

Commit

Permalink
Merge pull request #12 from jrm5100/master
Browse files Browse the repository at this point in the history
Speed up plink loading by 3x-4x
  • Loading branch information
jrm5100 committed Apr 12, 2021
2 parents dc9c718 + f5f9044 commit 5d6e9df
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 61 deletions.
146 changes: 86 additions & 60 deletions pandas_genomics/io/plink.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import pandas as pd
import numpy as np
from ..arrays import GenotypeDtype, GenotypeArray
from ..scalars import Variant, MISSING_IDX, Genotype
from ..scalars import Variant, MISSING_IDX


def from_plink(
bed_file: str, swap_alleles: bool = False, max_variants: Optional[int] = None
):
"""
Load genetic data from plink files (.bed, .bim, and .fam) into a DataFrame.
Load genetic data from plink v1 files (.bed, .bim, and .fam) into a DataFrame.
Parameters
----------
Expand All @@ -31,7 +31,7 @@ def from_plink(
Notes
-----
Plink files encode all variants as diploid (2n) and utilize "missing" alleles if the variant is actually haploid
Plink v1 files encode all variants as diploid (2n) and utilize "missing" alleles if the variant is actually haploid
Examples
--------
Expand All @@ -52,17 +52,35 @@ def from_plink(

print(f"Loading genetic data from '{bed_file.stem}'")

# Load fam file (PLINK sample information file)
# Load fam file
df = load_sample_info(fam_file)
# Load bim file
variant_list = load_variant_info(bim_file, max_variants) # Load bed file
gt_array_dict = load_genotypes(
bed_file, variant_list, num_samples=len(df), swap_alleles=swap_alleles
)

# Merge with sample allele index
df = pd.concat([df, pd.DataFrame.from_dict(gt_array_dict)], axis=1)
df = df.set_index(["FID", "IID", "IID_father", "IID_mother", "sex", "phenotype"])

return df


def load_sample_info(fam_file):
    """
    Load a fam file (PLINK sample information file) into a DataFrame.

    Parameters
    ----------
    fam_file
        Path-like object pointing at the space-delimited, header-less fam file.
        It must expose a ``name`` attribute (e.g. a ``pathlib.Path``), which is
        used in the progress message.

    Returns
    -------
    DataFrame
        One row per line of the file with the columns
        "FID", "IID", "IID_father", "IID_mother", "sex" and "phenotype".
        "sex" is categorical with the codes 1, 2 and 0 renamed to
        "male", "female" and "unknown".
    """
    df = pd.read_table(fam_file, header=None, sep=" ")
    df.columns = ["FID", "IID", "IID_father", "IID_mother", "sex", "phenotype"]
    # Update 'sex'
    df["sex"] = df["sex"].astype("category")
    df["sex"] = df["sex"].cat.rename_categories({1: "male", 2: "female", 0: "unknown"})
    # 'Phenotype' is not encoded because it may be more complicated than just case/control
    # Report once; the sample count is simply the number of rows that were read
    print(f"\tLoaded information for {len(df)} samples from '{fam_file.name}'")
    return df

# Load bim file (PLINK extended MAP file)

def load_variant_info(bim_file, max_variants):
"""Load bim file (PLINK extended MAP file) into a list of variants"""
variant_info = pd.read_table(bim_file, header=None, sep="\t")
# Note 'position' is in centimorgans, 'coordinate' is what pandas-genomics refers to as 'position' (in base-pairs)
variant_info.columns = [
Expand All @@ -75,18 +93,43 @@ def from_plink(
]
# chromosome is a category
variant_info["chromosome"] = variant_info["chromosome"].astype("category")
num_variants = len(variant_info)

# Limit num_variants
if max_variants is not None:
if max_variants < 1:
raise ValueError(f"'max_variants' set to an invalid value: {max_variants}")
else:
num_variants = max_variants

print(f"\tLoaded information for {num_variants} variants from '{bim_file.name}'")

# Load bed file (PLINK binary biallelic genotype table)
variant_info = variant_info.iloc[:max_variants]
variant_list = [create_variant(row) for idx, row in variant_info.iterrows()]
print(
f"\tLoaded information for {len(variant_list)} variants from '{bim_file.name}'"
)
return variant_list


def create_variant(variant_info_row):
    """
    Build a diploid Variant from one row of the bim variant table.

    The row is read through the keys "variant_id", "allele1", "allele2",
    "chromosome" and "coordinate". allele2 is used as the reference allele and
    allele1 as the single alternate allele.
    """
    variant_id = str(variant_info_row["variant_id"])
    allele1 = str(variant_info_row["allele1"])
    allele2 = str(variant_info_row["allele2"])
    # An allele written as "0" is missing, so it is passed on as None
    ref_allele = None if allele2 == "0" else allele2
    # A present alternate allele is handed over wrapped in a list
    alt_alleles = None if allele1 == "0" else [allele1]
    return Variant(
        chromosome=str(variant_info_row["chromosome"]),
        position=int(variant_info_row["coordinate"]),
        id=variant_id,
        ref=ref_allele,
        alt=alt_alleles,
        ploidy=2,
    )


def load_genotypes(bed_file, variant_list, num_samples, swap_alleles):
"""Load bed file (PLINK binary biallelic genotype table) into a dictionary of name:GenotypeArray"""
gt_bytes = np.fromfile(bed_file, dtype="uint8")
# Ensure the file is valid
CORRECT_FIRST_BYTES = np.array([108, 27, 1], dtype="uint8")
Expand All @@ -96,57 +139,40 @@ def from_plink(
)
gt_bytes = gt_bytes[3:]
# Divide array into one row per variant
# (have to reshape using num_samples since num_variants may be set lower)
chunk_size = num_samples // 4
if num_samples % 4 > 0:
chunk_size += 1
gt_bytes = gt_bytes.reshape(-1, chunk_size)
# Process each variant
for v_idx in range(num_variants):
variant_info_dict = variant_info.iloc[v_idx].to_dict()
variant_id = str(variant_info_dict["variant_id"])
a1 = str(variant_info_dict["allele1"])
a2 = str(variant_info_dict["allele2"])
# 0 indicates a missing allele
if a2 == "0":
a2 = None
if a1 == "0":
a1 = None
else:
a1 = [a1] # pass as list
variant = Variant(
chromosome=str(variant_info_dict["chromosome"]),
position=int(variant_info_dict["coordinate"]),
id=variant_id,
ref=a2,
alt=a1,
ploidy=2,
)
# Each byte (8 bits) is a concatenation of two bits per sample for 4 samples
# These are ordered from right to left, like (sample4, sample3, sample2, sample1)
# Convert each byte into 4 2-bits and flip them to order samples correctly
genotypes = np.flip(np.unpackbits(gt_bytes[v_idx]).reshape(-1, 4, 2), axis=1)
# flatten the middle dimension to give a big list of genotypes in the correct order and
# remove excess genotypes at the end that are padding rather than real samples
genotypes = genotypes.reshape(-1, 2)[:num_samples]
# Replace 0, 1 with missing (1, 0 is heterozygous)
missing_gt = (genotypes == (0, 1)).all(axis=1)
genotypes[missing_gt] = (MISSING_IDX, MISSING_IDX)
# Replace 1, 0 with 0, 1 for heterozygous so the reference allele is first
het_gt = (genotypes == (1, 0)).all(axis=1)
genotypes[het_gt] = (0, 1)
# Create GenotypeArray representation of the data
dtype = GenotypeDtype(variant)
scores = np.empty(num_samples)
scores[:] = np.nan
data = np.array(list(zip(genotypes, scores)), dtype=dtype._record_type)
gt_array = GenotypeArray(values=data, dtype=dtype)
gt_array_dict = {}
for v_idx, variant in enumerate(variant_list):
variant_gt_bytes = gt_bytes[v_idx]
gt_array = create_gt_array(num_samples, variant_gt_bytes, variant)
if swap_alleles:
gt_array.set_reference(1)
df[f"{v_idx}_{variant_id}"] = gt_array
gt_array_dict[f"{v_idx}_{gt_array.variant.id}"] = gt_array
print(f"\tLoaded genotypes from '{bed_file.name}'")

# Set sample info as the index
df = df.set_index(["FID", "IID", "IID_father", "IID_mother", "sex", "phenotype"])

return df
return gt_array_dict


def create_gt_array(num_samples, variant_gt_bytes, variant):
    """
    Decode the packed bed bytes of one variant into a GenotypeArray.

    Parameters
    ----------
    num_samples
        Number of real samples; decoded entries beyond this count are padding
        and are dropped.
    variant_gt_bytes
        uint8 array holding the packed genotypes of a single variant.
    variant
        Variant used to build the GenotypeDtype of the result.

    Returns
    -------
    GenotypeArray
        One genotype per sample, each with a NaN score.
    """
    # A byte holds four samples at two bits each, stored as
    # (sample4, sample3, sample2, sample1): split every byte into its four
    # bit pairs and reverse them so the samples come out in file order
    bit_pairs = np.unpackbits(variant_gt_bytes).reshape(-1, 4, 2)
    in_sample_order = bit_pairs[:, ::-1, :]
    # Collapse to one (bit, bit) row per sample and cut off the trailing padding
    allele_idxs = in_sample_order.reshape(-1, 2)[:num_samples]
    # (0, 1) encodes a missing genotype
    is_missing = (allele_idxs == (0, 1)).all(axis=1)
    allele_idxs[is_missing] = (MISSING_IDX, MISSING_IDX)
    # (1, 0) encodes a heterozygote; store it as (0, 1) so the reference allele is first
    is_het = (allele_idxs == (1, 0)).all(axis=1)
    allele_idxs[is_het] = (0, 1)
    # Pair every genotype with a NaN score and wrap the records in a GenotypeArray
    gt_dtype = GenotypeDtype(variant)
    nan_scores = np.full(num_samples, np.nan)
    records = np.array(
        list(zip(allele_idxs, nan_scores)), dtype=gt_dtype._record_type
    )
    return GenotypeArray(values=records, dtype=gt_dtype)
1 change: 0 additions & 1 deletion tests/io/test_plink.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

def test_loaded_small():
"""Validate the small dataset"""
# was 5.76 seconds
bed_file = DATA_DIR / "plink_test_small.bed"
result = io.from_plink(bed_file)
assert result.shape == (150, 3020)
Expand Down

0 comments on commit 5d6e9df

Please sign in to comment.