In [1]:
import numba

In [2]:
import zarr
import allel
import pandas as pd
import yaml
from pathlib import Path
import dask.array as da
import numpy as np

  data = yaml.load(f.read()) or {}


In [3]:
from scipy.special import comb
import scipy.spatial.distance as dist

In [4]:
import gcsfs
import logging

# needed for analysis
from itertools import combinations

In [5]:
# Identifier of the sample set whose combined callset is analysed in this notebook.
sampleset = "AG1000G-UG"

In [6]:
# Build the location of this sample set's Zarr callset inside the production
# bucket, ending up with a POSIX-style path string for the GCS mapping.
prod_bucket = Path('vo_agam_production')
sampleset_path = prod_bucket.joinpath('combined_samplesets', sampleset)
storage_path = sampleset_path.joinpath('callset.zarr').as_posix()

In [7]:
# Authenticate using cached credentials. This filesystem object is only used
# as a source of credentials for the filesystem created in the next cell.
gcs_orig = gcsfs.GCSFileSystem(token='cache', project='malariagen-jupyterhub')

In [8]:
# Second filesystem, constructed from the credentials held by gcs_orig's
# session; this is the one used to access the store.
gcs = gcsfs.GCSFileSystem(
    token=gcs_orig.session.credentials,
    project='malariagen-jupyterhub',
)

In [9]:
# Wrap the callset location in a GCS-backed mapping that zarr can use as a store.
gcsmap = gcsfs.mapping.GCSMap(
    storage_path,
    gcs=gcs,
)

In [10]:
# Open the root zarr group read-only so the notebook cannot modify the store.
calldata = zarr.Group(
    gcsmap,
    read_only=True,
)

In [11]:
# Display the root group's repr to confirm the store opened read-only.
calldata

<zarr.hierarchy.Group '/' read-only>

In [12]:
# Genotype calls from the '3L' group, restricted to the first 1,000,000 rows
# (variants) and the first two columns (samples) to keep this exploration small.
gt = allel.GenotypeArray(
    calldata['3L']['calldata/GT'][:1000000, :2]
)
gt

Unnamed: 0,0,1,Unnamed: 3
0,./.,./.,
1,./.,./.,
2,./.,./.,
...,...,...,...
999997,./.,./.,
999998,./.,./.,
999999,./.,./.,


In [13]:
%%time
# Convert genotype calls to per-allele counts. max_allele=3 presumably yields
# 4 allele columns (alleles 0-3), the width that ma_dist asserts.
gac = gt.to_allele_counts(max_allele=3)

CPU times: user 171 ms, sys: 15.1 ms, total: 186 ms
Wall time: 185 ms


In [14]:
def ma_dist(x1, x2):
    """Sum of per-row absolute allele-count differences between two samples.

    Parameters
    ----------
    x1, x2 : array_like, shape (n_rows, 4)
        Allele counts for two samples (in this notebook, per-sample slices
        of ``gac``). Signed or unsigned integer dtypes are accepted.

    Returns
    -------
    numpy integer
        ``sum_i sum_j |x1[i, j] - x2[i, j]|`` taken only over rows ``i``
        where both samples have at least one non-zero count. An all-zero
        row is treated as a missing call and contributes 0.

    Raises
    ------
    AssertionError
        If the shapes differ or there are not exactly 4 allele columns.
    """
    x1 = np.asarray(x1)
    x2 = np.asarray(x2)
    assert x1.shape == x2.shape, (x1.shape, x2.shape)
    assert x1.shape[1] == 4, x1.shape

    # Promote to a wide signed type before subtracting. On unsigned inputs
    # (e.g. uint8 allele counts) ``x1 - x2`` would wrap around instead of
    # going negative, inflating the distance.
    a1 = x1.astype(np.int64)
    a2 = x2.astype(np.int64)

    per_row = np.abs(a1 - a2).sum(axis=1)

    # Only rows where both samples are called contribute.
    both_called = (a1.sum(axis=1) != 0) & (a2.sum(axis=1) != 0)

    return per_row[both_called].sum()

In [15]:
# Allele counts for the first and second sample, each shaped (variants, alleles).
x1, x2 = gac[:, 0, :], gac[:, 1, :]

In [16]:
@numba.njit(numba.float64(numba.uint8[:, :], numba.uint8[:, :]))
def ma_dist_opt(x1, x2):
    """Compiled equivalent of ``ma_dist`` for uint8 allele-count arrays.

    Sums ``|x1[i, j] - x2[i, j]|`` over every row ``i`` and column ``j``,
    skipping rows in which either array is all zeros (treated as a
    missing call).

    Parameters
    ----------
    x1, x2 : 2-D uint8 arrays of the same shape
        Presumably (variants, alleles) allele counts, judging by the
        callers -- confirm.

    Returns
    -------
    float
        The summed distance.

    Raises
    ------
    ValueError
        If ``x1`` and ``x2`` do not have the same shape.
    """
    # numba does not bounds-check indexing by default. Without this check,
    # mismatched shapes would silently read past the end of x2, or ignore
    # part of it.
    if x1.shape[0] != x2.shape[0] or x1.shape[1] != x2.shape[1]:
        raise ValueError("x1 and x2 must have the same shape")
    out = 0.0
    for i in range(x1.shape[0]):
        x1_called = False
        x2_called = False
        d = 0.0
        for j in range(x1.shape[1]):
            c1 = x1[i, j]
            c2 = x2[i, j]
            if c1 > 0:
                x1_called = True
            if c2 > 0:
                x2_called = True
            # Order the operands before subtracting: c1 - c2 on uint8 would wrap.
            if c2 > c1:
                d += float(c2) - float(c1)
            else:
                d += float(c1) - float(c2)
        if x1_called and x2_called:
            out += d
    return out


In [17]:
def ma_countable(x1, x2):
    """Count rows where both inputs have a non-zero row total.

    These are the rows (called in both samples) that contribute to
    ``ma_dist``.
    """
    totals_first = x1.sum(axis=1)
    totals_second = x2.sum(axis=1)
    joint = totals_first * totals_second
    return np.count_nonzero(joint)

In [18]:
@numba.njit(numba.int64(numba.uint8[:, :], numba.uint8[:, :]))
def ma_countable_opt(x1, x2):
    """Compiled equivalent of ``ma_countable`` for uint8 allele-count arrays.

    Counts the rows in which both ``x1`` and ``x2`` contain at least one
    non-zero value, i.e. the rows that ``ma_dist_opt`` includes in its sum.

    Parameters
    ----------
    x1, x2 : 2-D uint8 arrays of the same shape

    Returns
    -------
    int
        Number of rows called in both arrays.

    Raises
    ------
    ValueError
        If ``x1`` and ``x2`` do not have the same shape.
    """
    # numba does not bounds-check indexing by default. Without this check,
    # mismatched shapes would silently read out of bounds or ignore data.
    if x1.shape[0] != x2.shape[0] or x1.shape[1] != x2.shape[1]:
        raise ValueError("x1 and x2 must have the same shape")
    out = 0
    for i in range(x1.shape[0]):
        x1_called = False
        x2_called = False
        for j in range(x1.shape[1]):
            if x1[i, j] > 0:
                x1_called = True
            if x2[i, j] > 0:
                x2_called = True
        if x1_called and x2_called:
            out += 1
    return out


In [19]:
%%time
# Vectorised NumPy reference. Cast to signed int8 ('i1') so that x1 - x2 cannot
# wrap around if the allele counts are unsigned (they are presumably uint8:
# ma_dist_opt, compiled for uint8, accepts x1/x2 as-is).
ma_dist(x1.astype('i1'), x2.astype('i1'))

CPU times: user 97.8 ms, sys: 27 ms, total: 125 ms
Wall time: 175 ms


19630

In [20]:
%%time
# numba-compiled version of the same distance. It should match the ma_dist
# result above, but is returned as a float.
ma_dist_opt(x1, x2)

CPU times: user 11.4 ms, sys: 1.06 ms, total: 12.4 ms
Wall time: 12.2 ms


19630.0

In [21]:
%%time
# Number of rows where both samples have at least one non-zero allele count.
ma_countable(x1, x2)

CPU times: user 51.5 ms, sys: 10.5 ms, total: 61.9 ms
Wall time: 103 ms


559285

In [22]:
%%time
# numba-compiled version of the same count; should match ma_countable above.
ma_countable_opt(x1, x2)

CPU times: user 6.16 ms, sys: 19 µs, total: 6.18 ms
Wall time: 6.03 ms


559285