In [None]:
##luigi-vars
# NOTE(review): the tag above presumably marks this cell for parameter injection
# by the luigi pipeline — confirm before editing or moving it.
# Path of the filtered callset (HDF5) that this notebook opens read-only.
FILTERED_HD5 = '/nbi/Research-Groups/JIC/Diane-Saunders/FP_project/FP_pipeline/PST130/data/0.3/Callset/2013/2013_filtered.hd5'
# Number of workers started in the local Dask cluster.
NCPU = 1
# Passed as the Dask cluster's memory_limit (presumably bytes per worker — TODO confirm).
MEM_PER_CPU = 1e9

In [None]:
import vcfnp
import numpy as np
import h5py
import os
import matplotlib.pyplot as plt
import matplotlib as mpl
import allel
import seaborn as sns
import pandas as pd
from sklearn.mixture import GaussianMixture
from collections import Counter

import dask
import dask.array as da
from dask.distributed import Client, LocalCluster, progress

import bootstrapped.bootstrap as bootstrap
from bootstrapped.stats_functions import mean

from fieldpathogenomics.utils import reference_dir, index_variants

# Notebook-wide plot styling: seaborn 'whitegrid' theme and a 12x6 default figure size.
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12,6)
# Render matplotlib figures inline in the cell output.
# NOTE(review): on some older IPython/ipykernel versions activating the inline
# backend may reset rcParams such as figure.figsize — confirm the default above sticks.
%matplotlib inline

In [None]:
# Start a local Dask cluster sized from the notebook parameters (single-threaded
# workers) and connect a client to it; later cells submit work through `client`.
lc = LocalCluster(
    n_workers=NCPU,
    threads_per_worker=1,
    memory_limit=MEM_PER_CPU,
)
client = Client(lc)

In [None]:
# Open the filtered callset read-only. The handle is left open on purpose:
# the wrappers below read from it lazily in later cells.
callset = h5py.File(FILTERED_HD5, mode='r')
calldata = callset['calldata']

# Variant annotations as a chunked table.
variants = allel.VariantChunkedTable(callset['variants'])
# Sample identifiers converted to unicode strings (used later to index the
# mapping statistics by "Library").
samples = np.array(callset['samples']).astype('U')
# Genotype calls wrapped as a Dask-backed genotype array.
genotypes = allel.GenotypeDaskArray(calldata['GT'])

In [None]:
# The first two genotype axes are (variants, samples).
n_variants, n_samples = genotypes.shape[:2]

# Per-sample percentages, reducing over the variants axis (axis=0).
# These are Dask expressions; they are evaluated in the next cell.
pc_missing = 100 * genotypes.count_missing(axis=0) / n_variants
pc_het = 100 * genotypes.count_het(axis=0) / n_variants

# Per-site percentage of samples with a missing call (reducing over axis=1).
pc_miss_per_site = 100 * genotypes.count_missing(axis=1) / n_samples

In [None]:
%%time
# Evaluate the three lazy statistics in a single da.compute call and rebind the
# same names to the computed results.
pc_missing, pc_het, pc_miss_per_site = da.compute(pc_missing, pc_het, pc_miss_per_site)

# Mapping Statistics 

In [None]:
# SECURITY: the fallback connection string below embeds database credentials in
# the notebook. Prefer supplying the URL through the FIELDPATHOGENOMICS_DB_URL
# environment variable; the hard-coded value is kept only so existing runs keep
# working and should be removed once the environment is configured.
connection_string = os.environ.get(
    "FIELDPATHOGENOMICS_DB_URL",
    "mysql+pymysql://tgac:tgac_bioinf@tgac-db1.hpccluster/buntingd_fieldpathogenomics",
)

# Read the AlignmentStats table, convert columns to numeric where possible
# (columns that cannot be converted are left untouched via 'ignore'), and drop
# duplicated rows. One drop_duplicates() suffices: the call is idempotent.
df = (
    pd.read_sql('AlignmentStats', connection_string)
    .apply(pd.to_numeric, args=('ignore',))
    .drop_duplicates()
)

# Sanity check: shape of the stats table restricted to the callset's samples.
df.set_index("Library").loc[samples].shape

In [None]:
# Keep only the mapping statistics most relevant to this callset: find the
# longest prefix that any stored path shares with FILTERED_HD5 (the first such
# prefix wins on ties; '' when nothing matches) ...
path = max(
    (os.path.commonprefix((p, FILTERED_HD5)) for p in df['path']),
    key=len,
    default='',
)
# ... and retain the rows whose path starts with that prefix.
df = df[df['path'].str[:len(path)] == path]

In [None]:
# Three stacked distributions: missingness per library, missingness per site,
# and mapped reads per library.
plt.figure(figsize=(16, 9))
ax_library = plt.subplot(311)
ax_site = plt.subplot(312)
ax_reads = plt.subplot(313)

sns.distplot(pc_missing, ax=ax_library)
sns.distplot(pc_miss_per_site, ax=ax_site)
sns.distplot(df.set_index("Library").loc[samples]['mapped_reads'], ax=ax_reads)

# pc_missing and pc_miss_per_site are percentages (count * 100 / total), so the
# axes are labelled in percent rather than as a fraction.
ax_library.set_xlabel("% missing")
ax_site.set_xlabel("% missing")

ax_library.set_title("Per Library", fontsize=24)
ax_site.set_title("Per Site", fontsize=24)
ax_reads.set_title("Mapped Reads", fontsize=24)

plt.tight_layout()

# Site and Sample Coverage Thresholds

In [None]:
%%time
percentiles, N80, Nbases, het_rates, het_std = [], [], [], [], []
sample_cov = np.linspace(100-pc_missing.max(), 100-pc_missing.min(), 10)
site_thresholds = np.linspace(0.1, 0.9, 9)

site_coverage, site_het, site_called = [], [], []

for i,x in enumerate(sample_cov):
    # Apply a filter at min sample coverage x
    filtered = genotypes[:, pc_missing <= (100 - x)]
    
    # Calculate site level statistics at this filter level
    site_called.append(filtered.count_called(axis=1))
    site_coverage.append(site_called[i]/filtered.shape[1])
    site_het.append(filtered.is_het().sum(axis=1)/site_called[i])

In [None]:
%%time
site_called = client.persist(site_called)
site_coverage = client.persist(site_coverage)
site_het = client.persist(site_het)

In [None]:
%%time
for i,x in enumerate(sample_cov):
    # Summerise site level stats
    percentiles.append(da.percentile(site_coverage[i], np.linspace(0, 100, 20)))
    Nbases.append( da.stack([site_called[i][site_coverage[i] > t].sum() for t in site_thresholds]) )
    N80.append((site_coverage[i] > 0.8).sum())

In [None]:
%time
futures = client.compute((percentiles, N80, Nbases), optimize_graph=False)

In [None]:
%%time
het_rates = []
for i,x in enumerate(sample_cov):
    het = [site_het[i][site_coverage[i] > t] for t in site_thresholds]
    het_rates.append( [dask.delayed(bootstrap.bootstrap)(h, mean, num_iterations=100, iteration_batch_size=1) for h in het])

In [None]:
# Submit the nested list of delayed bootstrap estimates to the cluster.
het_futures = client.compute(het_rates)

In [None]:
# Collect the computed results from the cluster, rebinding the names that
# previously held the lazy expressions.
percentiles, N80, Nbases = client.gather(futures)
het_rates = client.gather(het_futures)

In [None]:
# 2x2 summary of how the sample- and site-coverage thresholds trade off.
fig, ax = plt.subplots(2,2)
ax1,ax2,ax3,ax4 = ax.flatten()
fig.set_size_inches((16, 12))

# Samples sorted from best to worst coverage. Requiring the coverage of the
# k-th best sample accepts k samples, so the count starts at 1, not 0.
ax1.plot(100-np.sort(pc_missing), np.arange(1, len(pc_missing) + 1), '.')
ax1.set_xlabel("Sample minimum Coverage %")
ax1.set_ylabel("Number of accepted samples")

ax2.plot(sample_cov, N80, '.-')
ax2.set_xlabel("Sample minimum Coverage %")
ax2.set_ylabel("N sites at >80% coverage");

# Convert once: rows follow sample_cov, columns follow site_thresholds.
nbases_grid = np.array(Nbases)
het_grid = np.array(het_rates)

pal = sns.cubehelix_palette(len(site_thresholds), start=0.5, rot=-1)
for i, t in enumerate(site_thresholds):
    ax3.plot(sample_cov, nbases_grid[:, i], '.-', label = str(t), color=pal[i])

    # Bootstrap results for this site threshold across all sample thresholds;
    # error bars are the asymmetric (lower, upper) distances from the estimate.
    estimates = het_grid[:, i]
    ax4.errorbar(x=sample_cov, y=[e.value for e in estimates],
                 yerr=np.transpose([[e.value-e.lower_bound, e.upper_bound-e.value] for e in estimates]),
                 fmt='.-', label = str(t), color=pal[i], capsize=10, capthick=3)

ax3.set_xlabel("Sample minimum Coverage %")
ax3.set_ylabel("Bases Called");
ax3.legend(loc='best', title="Min site coverage")

# ax4 uses the same colour per site threshold as ax3, so it has no legend of its own.
ax4.set_xlabel("Sample minimum Coverage %")
ax4.set_ylabel("Heterozygosity");

In [None]:
plt.figure(figsize=(18,8))

# One colour per sample-coverage threshold, shared by both panels.
pal = sns.cubehelix_palette(len(sample_cov), start=.5, rot=-.75)

# Left: for each sample threshold, percentage of sites reaching a given site
# coverage. The two highest sample thresholds are skipped ([:-2]).
ax1 = plt.subplot(121)
for i, (x, p) in enumerate(zip(sample_cov[:-2], percentiles[:-2])):
    ax1.plot(p, 100-np.linspace(0, 100, 20), '.-',
             label="coverage {0:.1f}%, {1} samples ".format(x, np.sum(pc_missing <= (100 - x))),
             color=pal[i])

# Right: bootstrap heterozygosity estimate against the site-coverage threshold,
# one line per sample threshold.
het_grid = np.array(het_rates)  # rows follow sample_cov, columns follow site_thresholds
ax2 = plt.subplot(122)
for i, t in enumerate(sample_cov[:-2]):
    estimates = het_grid[i]
    ax2.errorbar(x=site_thresholds, y=[e.value for e in estimates],
                 yerr=np.transpose([[e.value-e.lower_bound, e.upper_bound-e.value] for e in estimates]),
                 fmt='.-', label = str(t), color=pal[i], capsize=10, capthick=3)

# site_thresholds are fractions (0.1 to 0.9), so the axis is not labelled in %.
ax2.set_xlabel("Min site Coverage (fraction)")
ax2.set_ylabel("Heterozygosity");

ax1.legend(loc='best', title='Sample Coverage')
ax1.set_xlabel("Min Site Coverage")
ax1.set_ylabel("Percent sites at min coverage")

# Called sites vs Mapped reads

In [None]:
# Percentage of sites called per library against its mapping statistics.
plt.figure(figsize=(18,8))
ax_pc_mapped = plt.subplot(121)
ax_n_mapped = plt.subplot(122)

# Mapping statistics aligned to the callset's sample order.
mapping_stats = df.set_index("Library").loc[samples]

ax_pc_mapped.plot(mapping_stats['mapped_reads_pc'], 100-pc_missing, '.')
ax_pc_mapped.set_xlabel("% reads mapped")
ax_pc_mapped.set_ylabel("% sites called")

ax_n_mapped.plot(mapping_stats['mapped_reads'], 100-pc_missing, '.')
ax_n_mapped.set_xlabel("Mapped reads")
ax_n_mapped.set_ylabel("% sites called")

plt.tight_layout()