# Common utility functions

This notebook has common utility functions that are independent of whether we're running on datalab-abafar, datalab-bespin or a local computer.

In [None]:
# standard library imports
import sys
import os
import re
from pathlib import Path
import base64
from collections import namedtuple, defaultdict
from collections.abc import Mapping
import functools
import random
import subprocess
import traceback

# third party imports
import pandas as pd
from pyprojroot import here
import dask
import dask.array as da
from dask import delayed
import distributed
from distributed import LocalCluster
from dask.diagnostics import ProgressBar
import zarr
import fsspec
from fsspec.implementations.zip import ZipFileSystem
import gcsfs
import yaml
import requests
import malariagen_data
import numcodecs

In [None]:
def log(*msg):
    """Print the given values to stdout, space-separated, and flush at once.

    Parameters
    ----------
    *msg : object
        Values to print; they are formatted exactly as ``print`` would.

    """
    print(*msg, file=sys.stdout, flush=True)


In [None]:
def read_original_samples(*, sample_set):
    """Load the original samples table for one sample set.

    Reads ``tracking/<sample_set>/original_samples.tsv`` below the directory
    returned by ``here()`` and tags every row with the sample set name.

    Parameters
    ----------
    sample_set : str
        Sample set identifier; also the name of its tracking sub-directory.

    Returns
    -------
    df : DataFrame
        Contents of the TSV file with a ``sample_set`` column set to
        `sample_set` on every row.

    """
    tsv_path = here() / 'tracking' / sample_set / 'original_samples.tsv'

    # tab-delimited file -> DataFrame, then record which sample set it is from
    return pd.read_csv(tsv_path, sep='\t').assign(sample_set=sample_set)
    

In [None]:
def read_wgs_derived_samples(*, sample_set):
    """Load the WGS derived samples table, joined to the original samples.

    Reads ``tracking/<sample_set>/wgs_derived_samples.tsv`` and left-joins
    the original samples table onto it via ``original_sample_id``, so every
    derived sample row is kept.

    Parameters
    ----------
    sample_set : str
        Sample set identifier; also the name of its tracking sub-directory.

    Returns
    -------
    df : DataFrame
        Derived sample rows with the matching original sample columns.

    """
    tsv_path = here() / 'tracking' / sample_set / 'wgs_derived_samples.tsv'
    derived = pd.read_csv(tsv_path, sep='\t')

    # bring in the metadata of the original sample each derived sample came from
    originals = read_original_samples(sample_set=sample_set)
    return derived.merge(originals, on="original_sample_id", how="left")
    

In [None]:
def read_wgs_lanelets(*, sample_set):
    """Load the WGS lanelets table, joined to the derived samples.

    Reads ``tracking/<sample_set>/wgs_lanelets.tsv``, renames its ``sample``
    column to ``derived_sample_id`` and left-joins the derived samples table
    (which itself carries the original sample columns) on that column.

    Parameters
    ----------
    sample_set : str
        Sample set identifier; also the name of its tracking sub-directory.

    Returns
    -------
    df : DataFrame
        One row per lanelet row in the TSV, plus the joined sample columns.

    """
    tsv_path = here() / 'tracking' / sample_set / 'wgs_lanelets.tsv'
    lanelets = pd.read_csv(tsv_path, sep='\t')

    # the file calls the join key "sample"; align it with the derived samples table
    lanelets = lanelets.rename(columns={'sample': 'derived_sample_id'})

    derived = read_wgs_derived_samples(sample_set=sample_set)
    return lanelets.merge(derived, on="derived_sample_id", how="left")

In [None]:
def read_wgs_snp_data(*, sample_set):
    """Load the WGS SNP data table, joined to the derived samples.

    Reads ``tracking/<sample_set>/wgs_snp_data.tsv``, renames its ``sample``
    column to ``derived_sample_id`` and left-joins the derived samples table
    (which itself carries the original sample columns) on that column.

    Parameters
    ----------
    sample_set : str
        Sample set identifier; also the name of its tracking sub-directory.

    Returns
    -------
    df : DataFrame
        One row per row in the TSV, plus the joined sample columns.

    """
    tsv_path = here() / 'tracking' / sample_set / 'wgs_snp_data.tsv'
    snp_data = pd.read_csv(tsv_path, sep='\t')

    # the file calls the join key "sample"; align it with the derived samples table
    snp_data = snp_data.rename(columns={'sample': 'derived_sample_id'})

    derived = read_wgs_derived_samples(sample_set=sample_set)
    return snp_data.merge(derived, on="derived_sample_id", how="left")

In [None]:
def read_release_config(*, release):
    """Load the sample set configuration of a release.

    Looks in ``tracking/release/<release>/`` below the directory returned by
    ``here()`` for ``release_sample_sets.yaml`` (the new file name) and falls
    back to ``config.yml`` (the old file name) when the former does not exist.

    Parameters
    ----------
    release : str
        Release identifier; also the name of its tracking sub-directory.

    Returns
    -------
    config : object
        Whatever ``yaml.safe_load`` produces for the file contents.

    Raises
    ------
    FileNotFoundError
        If neither file exists (raised by ``open`` on the fallback path).

    """
    release_dir = here() / "tracking" / "release" / release

    # New name
    file_path = release_dir / "release_sample_sets.yaml"
    if not file_path.exists():
        # Old name
        file_path = release_dir / "config.yml"

    with open(file_path, mode="r") as f:
        return yaml.safe_load(f)

In [None]:
def read_species_group_config():
    """Load ``species-group_config.yaml`` from the directory given by ``here()``.

    Returns
    -------
    config : defaultdict
        The parsed YAML content wrapped so that looking up a key that is not
        present yields ``None`` instead of raising ``KeyError``.

    """
    config_path = here() / "species-group_config.yaml"
    with open(config_path, mode="r") as f:
        parsed = yaml.safe_load(f)

    # missing keys resolve to None rather than raising
    return defaultdict(lambda: None, parsed)

In [None]:
def _captured_text(output):
    """Return captured subprocess output as str, or None if nothing was captured."""
    if output is None:
        return None
    if isinstance(output, (bytes, bytearray)):
        # undecodable bytes are replaced so that dumping output never fails
        return output.decode(errors="replace")
    return str(output)


def subprocess_dump(result):
    """Echo the captured output of a subprocess to this process's streams.

    The ``stdout`` attribute of `result` is printed to ``sys.stdout`` and the
    ``stderr`` attribute to ``sys.stderr``; each stream is flushed after
    printing. An attribute that is missing or ``None`` (nothing captured) is
    skipped. Output may be bytes (decoded as UTF-8, undecodable bytes
    replaced) or already text.

    Parameters
    ----------
    result : object
        Typically a ``subprocess.CompletedProcess`` or a
        ``subprocess.CalledProcessError``; any object is accepted.

    """
    stdout = _captured_text(getattr(result, 'stdout', None))
    if stdout is not None:
        print(stdout, file=sys.stdout)
        sys.stdout.flush()

    stderr = _captured_text(getattr(result, 'stderr', None))
    if stderr is not None:
        print(stderr, file=sys.stderr)
        sys.stderr.flush()


def bash(cmd, check=True):
    """Run a command string with /bin/bash and echo its captured output.

    The command is logged, prefixed with ``set -xeuo pipefail; `` and run
    through ``subprocess.run`` with output captured. Whatever comes back is
    passed to ``subprocess_dump`` and then returned.

    Note that exceptions are returned, not raised: any ``Exception`` coming
    out of ``subprocess.run`` (including ``CalledProcessError`` for a
    non-zero exit status when `check` is true) is caught and handed back to
    the caller as the return value.

    Parameters
    ----------
    cmd : str
        Shell command(s) to run.
    check : bool
        Forwarded to ``subprocess.run``.

    Returns
    -------
    result : subprocess.CompletedProcess or Exception

    """
    log(cmd)
    script = "set -xeuo pipefail; " + cmd

    try:
        outcome = subprocess.run(
            script,
            shell=True,
            executable="/bin/bash",
            capture_output=True,
            check=check,
        )
    except Exception as e:
        # keep hold of the exception so it can be dumped and returned below
        outcome = e

    subprocess_dump(outcome)
    return outcome

In [None]:
@functools.lru_cache(maxsize=None)
def http_head(url):
    """Issue an HTTP HEAD request for `url` and return the response.

    Results are memoised per URL (unbounded cache) because HEAD requests can
    be slow, especially when there are many of them to do. Only returned
    responses end up in the cache; a call that raises is not stored.

    NOTE(review): no timeout is passed to ``requests.head`` - confirm that
    waiting indefinitely on an unresponsive server is acceptable.

    Parameters
    ----------
    url : str

    Returns
    -------
    response : requests.Response

    Raises
    ------
    RuntimeError
        If the response status code is anything other than 200.

    """
    response = requests.head(url)
    if response.status_code == 200:
        return response
    raise RuntimeError(f"status {response.status_code} for {url}")
    

In [None]:
def delayed_curl_to_gcs(*, gcs, source_url, gcs_url, raw_md5=None, verbose=False, force=False):
    """Copy an object from an input URL (e.g., Sanger S3) to GCS."""

    if raw_md5 is not None:

        # Convert MD5 to base64-encoded 128-bit MD5 hash
        expected_md5 = base64.b64encode(bytes.fromhex(raw_md5)).decode('utf-8')

        if not force:
            
            # Check whether this file already exists
            if gcs.exists(gcs_url):

                # Check whether the existing file's MD5 hash matches the one specified 
                file_info = gcs.info(gcs_url)
                existing_md5 = file_info['md5Hash']
                if existing_md5 == expected_md5:

                    # Don't copy this file.
                    if verbose:
                        log(f'{gcs_url} - skipping, file exists at target location and has expected MD5 and size')
                    return None

        # Compose the Bash copy command
        copy_command = f'curl -s {source_url} | gcloud storage cp -q --content-md5={expected_md5} - {gcs_url}'
    
    else:
        
        if not force:

            # Check whether this file already exists
            if gcs.exists(gcs_url):

                # Check file sizes
                expected_size = int(http_head(source_url).headers['content-length'])
                file_info = gcs.info(gcs_url)
                existing_size = file_info['size']
                if existing_size == expected_size:

                    # Don't copy this file.
                    if verbose:
                        log(f'{gcs_url} - skipping, file exists at target location and has expected size')
                    return None

        # Compose the Bash copy command, without MD5 check
        copy_command = f'curl -s {source_url} | gsutil -q cp - {gcs_url}'
            
    # Setup delayed computation
    task = delayed(bash)(copy_command)
    
    return task
    

In [None]:
# SafeStore wraps a key/value store handed to zarr so that a failed key lookup
# surfaces as a RuntimeError instead of a KeyError. Two variants are defined,
# chosen at import time by whether zarr.storage provides KVStore.
try:
    
    # newer versions of zarr
    
    from zarr.storage import KVStore
    
    class SafeStore(KVStore):
        """KVStore whose item lookups raise RuntimeError for missing keys."""
        
        def __getitem__(self, key):
            """Return the value stored under `key` in the wrapped mapping.

            A KeyError from the wrapped mapping is re-raised as RuntimeError.
            """
            try:
                return self._mutable_mapping[key]
            except KeyError as e:
                # always raise a runtime error to ensure zarr propagates the exception
                raise RuntimeError(e)

        def __contains__(self, key):
            """Report whether `key` is present in the wrapped mapping."""
            return key in self._mutable_mapping

                
except ImportError:
    
    # older versions of zarr

    class SafeStore(Mapping):
        """Read-only mapping wrapper whose item lookups raise RuntimeError for missing keys."""

        ## This helps to ensure that no missing data are silently filled in.

        def __init__(self, store):
            """Wrap `store`, the underlying mapping all operations delegate to."""
            self.store = store

        def __getitem__(self, key):
            """Return the value stored under `key` in the wrapped store.

            A KeyError from the wrapped store is re-raised as RuntimeError.
            """
            try:
                return self.store[key]
            except KeyError as e:
                # always raise a runtime error to ensure zarr propagates the exception
                raise RuntimeError(e)

        def __contains__(self, key):
            """Report whether `key` is present in the wrapped store."""
            return key in self.store

        def __iter__(self):
            """Iterate over the keys of the wrapped store."""
            return iter(self.store)

        def __len__(self):
            """Return the number of keys in the wrapped store."""
            return len(self.store)


In [None]:
def open_gcs_zip_zarr(*, gcs_url, gcs):
    """Open a zarr that is stored as a single zip file on GCS, read-only.

    The zip file is opened through `gcs`, exposed as a file system, and its
    root is wrapped in a ``SafeStore`` before being handed to ``zarr.open``.
    The underlying file object is left open, since the returned zarr object
    reads from it.

    Parameters
    ----------
    gcs_url : str
        URL of the zip file.
    gcs : GCSFileSystem

    Returns
    -------
    zarr_data : zarr.core.Array or zarr.hierarchy.Group

    """
    zip_fs = ZipFileSystem(gcs.open(gcs_url))
    store = SafeStore(zip_fs.get_mapper("/"))
    return zarr.open(store=store, mode='r')


In [None]:
def open_gcs_zarr(*, gcs_url, gcs):
    """Open the zarr group stored at a GCS URL, read-only.

    The location is mapped through `gcs` and wrapped in a ``SafeStore``
    before the group is constructed.

    Parameters
    ----------
    gcs_url : str
        URL of the root of the zarr hierarchy.
    gcs : GCSFileSystem

    Returns
    -------
    zarr_data : zarr.Group

    """
    store = SafeStore(gcs.get_mapper(gcs_url))
    return zarr.Group(store, read_only=True)


In [None]:
def get_sample_set_qc_all_pass(*, release, sample_set):
    """Flag, for each sample in a sample set, whether it passed every QC filter.

    The sequence, replicate and anomaly QC filter tables of the release are
    read (plus the PCA filter table when that file exists), placed side by
    side on their first-column index, and a row counts as passing only if
    every column holds the string "PASS".

    Parameters
    ----------
    release : str
    sample_set : str

    Returns
    -------
    s : Series
        Boolean values indexed like the combined filter tables.

    """
    release_dir = here() / "tracking" / "release" / release

    # filter tables that are always expected to be present
    required_paths = [
        release_dir / "wgs_sequence_qc" / f"sequence_qc_filters_{sample_set}.tsv",
        release_dir / "wgs_replicate_qc" / f"replicate_qc_filters_{sample_set}.tsv",
        release_dir / "wgs_population_qc" / f"anomaly_qc_filters_{sample_set}.tsv",
    ]
    filter_tables = [
        pd.read_csv(path, sep="\t", index_col=0)
        for path in required_paths
    ]

    # We stopped using PCA outlier filters circa Ag3.11
    pca_path = release_dir / "wgs_population_qc" / f"pca_qc_filters_{sample_set}.tsv"
    if pca_path.is_file():
        filter_tables.append(pd.read_csv(pca_path, sep="\t", index_col=0))

    combined = pd.concat(filter_tables, axis=1, sort=False)
    return combined.eq("PASS").all(axis=1)