# In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import itertools as it
import pandas as pd
import numpy as np
import pysam


class Genotypes(object):
    """A class for storing genotype data which at its core
    consists of a dense p x n genotype matrix, a dataframe
    of metadata associated with each snp, and a list of
    individual ids. We also provide computation options for
    setting up the dataset for commonly used analyses.

    The filtering methods (``keep_autosomes``, ``remove_rare_common``,
    ``remove_strand_ambiguous_alleles``) only drop rows of ``snp_df``;
    call ``reindex`` afterwards to subset ``Y`` accordingly. ``reindex``
    relies on an ``"idx"`` column of ``snp_df`` holding, for every snp,
    its row position in ``Y``.

    Arguments
    ---------
    Y : np.array
        p x n genotype matrix

    snp_df : pd.DataFrame
        Dataframe storing snp level meta data

    inds : list
        list of individual ids

    Attributes
    ----------
    Y : np.array
        p x n genotype matrix

    snp_df : pd.DataFrame
        Dataframe storing snp level meta data

    inds : list
        list of individual ids

    n : int
        number of individuals (samples), None if the object
        was created without data

    p : int
        number of snps (features), None if the object
        was created without data

    base_complement : dict
        dictionary where keys are bases and values
        are the bases complement

    bases : list
        list of bases

    strand_ambiguous_alleles : dict
        dictionary of strand ambiguous alleles
        combinations as keys and bools as values

    Z : np.array
        p x n normalized genotype matrix, only available
        after calling ``normalize``
    """
    def __init__(self, Y=None, snp_df=None, inds=None):

        # p x n genotype matrix
        self.Y = Y

        # DataFrame of SNP level information
        self.snp_df = snp_df

        # list of individual ids
        self.inds = inds

        # dimensions are only defined when all three inputs are provided
        if Y is not None and snp_df is not None and inds is not None:
            # number of individuals
            self.n = len(self.inds)

            # number of snps
            self.p = self.snp_df.shape[0]
        else:
            self.n = None
            self.p = None

        # below is adapted from
        # https://github.com/bulik/ldsc/blob/master/munge_sumstats.py

        # complement of each base
        self.base_complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}

        # bases (a real list, not a dict view, so it can be indexed/reused)
        self.bases = list(self.base_complement)

        # dict mapping every ordered pair of distinct bases (e.g. "AT") to
        # True when the pair is strand ambiguous, i.e. the two alleles are
        # complements of each other
        self.strand_ambiguous_alleles = {
            a1 + a2: a1 == self.base_complement[a2]
            for a1, a2 in it.product(self.bases, self.bases)
            if a1 != a2
        }

    @staticmethod
    def _chrom_to_int(chrom):
        """Convert one chromosome label to an int, dropping a 3 character
        prefix (e.g. "chr") when the label is not already an integer.
        """
        try:
            # already prefix free (e.g. 7 or "7")
            return int(chrom)
        except ValueError:
            # drop the first three characters, e.g. "chr7" -> 7
            return int(chrom[3:])

    def remove_chrom_prefix(self):
        """Remove the chromosome prefix from the CHROM column of
        snp_df and convert it to integers. This helps for making
        datasets consistent before merging. Labels that are already
        integers are left unchanged, so calling this twice is safe.

        Raises
        ------
        ValueError
            if a label is not an integer after removing its first
            three characters (e.g. "chrX")
        """
        self.snp_df["CHROM"] = self.snp_df["CHROM"].apply(self._chrom_to_int)

    def keep_autosomes(self):
        """Keep snps only present on the autosomes i.e. snps
        whose CHROM value is at most 22. The CHROM column must
        be numeric (see ``remove_chrom_prefix``).
        """
        # filter out snps not on the autosome
        on_autosome = self.snp_df["CHROM"] <= 22
        self.snp_df = self.snp_df[on_autosome]

    def remove_rare_common(self, eps):
        """Removes SNPs that are too rare or
        too common given some threshold eps. The allele
        frequency of a snp is computed over its non-missing
        genotypes only; snps with no observed genotype are
        removed.

        Arguments
        ---------
        eps : float
            frequency threshold eps, snps with frequency in
            [eps, 1 - eps] are kept
        """
        # number of observed (non-nan) genotypes at each row of Y
        n_obs = self.Y.shape[1] - np.isnan(self.Y).sum(axis=1)

        # 0 / 0 for fully missing snps yields nan, which fails both
        # comparisons below and therefore drops the snp
        with np.errstate(divide="ignore", invalid="ignore"):
            # compute allele frequencies at each row of Y
            f = np.nansum(self.Y, axis=1) / (2. * n_obs)

            # snp_df may already have been filtered without reindexing,
            # so look up each snp's frequency via its row position in Y
            if "idx" in self.snp_df.columns:
                f = f[self.snp_df["idx"].values]

            # snps that pass the threshold
            passes = (f >= eps) & (f <= (1. - eps))

        # keep only these SNPs
        self.snp_df = self.snp_df.iloc[np.flatnonzero(passes)]

    def remove_strand_ambiguous_alleles(self):
        """Removes any snps that have alleles that
        are strand ambiguous (A/T, T/A, C/G, G/C). Snps whose
        alleles are not a pair of two distinct bases from
        ``bases`` (e.g. indels, identical or missing alleles) are
        removed as well. Alleles are matched case sensitively.
        """
        # allele pairs that are known and not strand ambiguous
        unambiguous = [pair for pair, is_ambiguous
                       in self.strand_ambiguous_alleles.items()
                       if not is_ambiguous]

        a1 = self.snp_df["A1"]
        pairs = a1.astype(str) + self.snp_df["A2"].astype(str)

        # requiring A1 to be a single base guards against concatenations
        # such as "" + "AC" being mistaken for the pair A/C
        keep = a1.isin(self.bases) & pairs.isin(unambiguous)

        # index the snps
        self.snp_df = self.snp_df[keep]

    def reindex(self):
        """Subset the rows of Y to the snps remaining in snp_df
        and reset the idx column of snp_df to 0, ..., p - 1. The
        row labels (pandas index) of snp_df are left untouched.
        """
        # row positions in Y of the snps to keep
        rows = self.snp_df["idx"].values

        # reindex data matrix
        self.Y = self.Y[rows, :]

        # re-compute number of snps
        self.p = self.snp_df.shape[0]

        # reset idx column; assign builds a new frame which avoids writing
        # into a filtered view of the original dataframe
        self.snp_df = self.snp_df.assign(idx=np.arange(self.p))

    def binarize(self):
        """Convert genotype data to binary
        by mapping 2->1, 0->0 and randomly
        selecting a 0 or 1 for heterozygotes.
        This emulates the read sampled data found
        in ancient DNA studies. Y is modified in place
        and missing (nan) entries are left untouched.
        """
        # randomly sample hets; count them with an integer count rather
        # than a float sum which loses precision on large float32 matrices
        is_het = self.Y == 1.
        n_het = int(np.count_nonzero(is_het))
        if n_het > 0:
            self.Y[is_het] = np.random.binomial(1, .5, n_het)

        # set 2s to 1s
        self.Y[self.Y == 2.] = 1.

    def normalize(self, scale, impute):
        """Mean center the data so each snp has mean 0 and
        store the result in the attribute Z (Y is not modified).
        If scale true then scale each SNP to have var 1.
        If impute is true it sets missing sites to 0. i.e.
        the approx mean. Note for all computations we
        ignore nans. Snps with zero variance yield nan when
        scaled (0 when additionally imputed).

        Arguments
        ---------
        scale : bool
            scale the data to have variance 1
        impute : bool
            predict missing sites with the expected mean 0
        """
        # compute mean genotype for every SNP
        mu = np.nanmean(self.Y, axis=1)

        # center every SNP (row) by its mean
        Z = self.Y - mu[:, np.newaxis]

        if scale:
            # compute std dev for every SNP
            std = np.nanstd(self.Y, axis=1)

            # scale
            Z = Z / std[:, np.newaxis]

        if impute:
            # set nans to zero
            Z = np.nan_to_num(Z)

        self.Z = Z