# Produce provisional taxa

This notebook uses classifiers trained on the results of cohorts analysis to predict the taxa of samples from their SNP genotypes. It then produces a TSV for each sample set in the specified release.

- **Note:** You will probably need a high-RAM machine to run this (e.g. one with 27 GB).
- **Note:** This notebook will probably need to install older versions of Python packages in order to import and use the stored classifiers. These packages are usually installed in `~/.local/` and should be removed afterwards, so that they do not override the standard environment.

## Setup
This will import various commonly used modules and functions.

In [1]:
from pyprojroot import here
%run {here()}/DataLab_bespin_functions.ipynb

conda_prefix:  /home/conda/global/envs/global-mgenv-5.1.0
current_environment:  global/global-mgenv-5.1.0


In [2]:
# See git repo clone root
here()

PosixPath('/home/leehart/gitRepos/vector-ops')

In [3]:
# See cwd
Path.cwd()

PosixPath('/home/leehart/gitRepos/vector-ops/tracking/release/v3.13/wgs_population_qc')

## Additional imports

In [4]:
import joblib
from datetime import datetime
import importlib.metadata

## Settings

In [5]:
# Determine the release we are working on.
release_version = Path.cwd().parent.name
release_version

'v3.13'

In [6]:
repo_clone_path = here()

In [7]:
# Get species-group config variables
sg_config = read_species_group_config()
production_bucket = sg_config['production_bucket']
release_bucket = sg_config['release_bucket']
contigs = sg_config['contigs']
allsites_zip_path = sg_config['allsites_zip_path']
major_version = sg_config['major_version']
prov_taxa_classifier_set = sg_config['prov_taxa_classifier_set']

In [8]:
# The path to the tracking directory in this repo clone
tracking_dir_path = repo_clone_path / 'tracking'

## Settings for classifiers and classifications

In [9]:
contigs_to_include = contigs

In [10]:
# Note: this might eventually need to be set in config
partition_size = 1_000_000

In [11]:
# E.g. gs://vo_agam_production/v3.x/taxon_classifiers/RF_20231019
classifier_dir_gcs_path = f'gs://{production_bucket}/v{major_version}.x/taxon_classifiers/{prov_taxa_classifier_set}'

In [12]:
diploid_genotype_encodings_gcs_path = f'{classifier_dir_gcs_path}/diploid_genotype_encodings.yaml'

In [13]:
# Note: this might eventually need to be set in config
classifier_id_template = 'RF_{contig}_{start_pos}-{stop_pos}'

In [14]:
# Note: to use the released snp_genotypes instead (e.g. post-release, after the production genotypes have been deleted), use this path:
#snp_genotypes_gcs_path_template = f'gs://{release_bucket}/{release_version}/snp_genotypes/all/' + '{sample_set}'
snp_genotypes_gcs_path_template = f'gs://{production_bucket}/v{major_version}.x/curation/' + '{sample_set}/snp_genotypes_combined.zarr'

In [15]:
taxon_classes_local_output_path_template = 'provisional_taxa_classes.txt'

In [16]:
provisional_taxa_output_path_template = 'provisional_taxa_stats_{sample_set}.tsv'

In [17]:
predicted_taxa_probs_output_gcs_path_template = f'gs://{production_bucket}/tracking/release/{release_version}/wgs_population_qc/predicted_taxa_probs/predicted_taxa_probs_' + '{sample_set}.tsv'

In [18]:
classifier_set_dir_gcs_path = f'{production_bucket}/v{major_version}.x/taxon_classifiers/{prov_taxa_classifier_set}'

In [19]:
classifier_requirements_file_gcs_path = f'{classifier_set_dir_gcs_path}/classifier_requirements.txt'
classifier_requirements_file_gcs_path

'vo_agam_production/v3.x/taxon_classifiers/RF_20231019/classifier_requirements.txt'

## Get a GCS connection

In [20]:
gcs = init_gcs()

## Check that we're using an environment that is compatible with the classifiers

**Note:** We need to use the same version of classifier packages as those used by the stored classifiers, otherwise we will see `InconsistentVersionWarning` and errors might occur.

In [21]:
with gcs.open(classifier_requirements_file_gcs_path, 'r') as fh:
    classifier_requirements = [line.strip() for line in fh.readlines()]

In [22]:
for classifier_requirement in classifier_requirements:
    
    # Get the required package and version
    required_package, required_version = classifier_requirement.split('==')
    
    print()
    print('The classifiers require', classifier_requirement)
    
    # Get the installed version of the package
    installed_version = importlib.metadata.version(required_package)
    
    print('- The installed version is:', installed_version)
    
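    if installed_version != required_version:
        print('- Installing', classifier_requirement)
        %pip install -q {classifier_requirement}
        # Note: installed_version is not re-read after installing, and the new version generally
        #       only takes effect after restarting the kernel, so the assertion below will still fail until then.
    
    # Ensure that the installed version matches the classifier's requirement
    assert installed_version == required_version, f'- but version {installed_version} is installed'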
    
    # Get the location of the package
    distribution = importlib.metadata.distribution(required_package)
    package_location = distribution.locate_file('')
    
    print('- Location:', package_location)


The classifiers require scikit-learn==1.3.0
- The installed version is: 1.3.0
- Location: /home/leehart/.local/lib/python3.10/site-packages
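
The check above installs packages as it goes and stops at the first mismatch. Below is a minimal sketch of a side-effect-free alternative that only reports every unmet requirement, including packages that are not installed at all. You can then decide what to install before restarting the kernel. `find_requirement_mismatches` and the stand-in version lookup are hypothetical. The sketch assumes every requirement line is pinned with `==`.

```python
import importlib.metadata

def find_requirement_mismatches(requirement_lines, get_installed_version=importlib.metadata.version):
    # Return {package: (required_version, installed_version)} for each pinned requirement
    # ('package==version') that is not satisfied; installed_version is None if not installed.
    mismatches = {}
    for line in requirement_lines:
        line = line.strip()
        # Skip blank lines and comments
        if not line or line.startswith('#'):
            continue
        package, required_version = line.split('==')
        try:
            installed_version = get_installed_version(package)
        except importlib.metadata.PackageNotFoundError:
            installed_version = None
        if installed_version != required_version:
            mismatches[package] = (required_version, installed_version)
    return mismatches

# Usage with a hypothetical stand-in for the environment's installed versions
fake_installed_versions = {'scikit-learn': '1.2.2'}

def fake_version_lookup(package):
    if package not in fake_installed_versions:
        raise importlib.metadata.PackageNotFoundError(package)
    return fake_installed_versions[package]

mismatches = find_requirement_mismatches(['scikit-learn==1.3.0', 'joblib==1.3.2'], fake_version_lookup)
print(mismatches)  # {'scikit-learn': ('1.3.0', '1.2.2'), 'joblib': ('1.3.2', None)}
```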


## Functions

### Functions for samples

In [23]:
@functools.lru_cache(maxsize=None)
def get_samples_sets_tuple(*, release_str):
    # Get all of the sample sets for this release from the release sample sets config
    rss_config = read_release_config(release=release_str)
    return tuple(rss_config['sample_sets'])

In [24]:
def get_derived_samples_df(*, release_strings_tuple):
    # Return a DataFrame containing all of the derived sample ids for the specified release strings.
    # Include the corresponding release string and sample set for each sample id.
    
    # Collect the DataFrames for each sample set
    sample_set_dfs = []
    
    for release_str in release_strings_tuple:
        
        sample_sets = get_samples_sets_tuple(release_str=release_str)
        
        for sample_set in sample_sets:
            
            # Get a DataFrame containing all of the derived sample ids for this sample_set
            sample_set_derived_samples_df = read_wgs_derived_samples(sample_set=sample_set)
            
            # Rename the derived_sample_id column to sample_id
            sample_set_derived_samples_df.rename(columns={'derived_sample_id': 'sample_id'}, inplace=True)
            
            # Set the sample_id as the index
            sample_set_derived_samples_df.set_index('sample_id', inplace=True)
            
            # Add the release_str as the first index
            sample_set_derived_samples_df['release_str'] = release_str
            sample_set_derived_samples_df.set_index('release_str', append=True, inplace=True)
            sample_set_derived_samples_df = sample_set_derived_samples_df.swaplevel('sample_id', 'release_str', axis=0)
            
            # Add the sample_set as the second index
            sample_set_derived_samples_df['sample_set'] = sample_set
            sample_set_derived_samples_df.set_index('sample_set', append=True, inplace=True)
            sample_set_derived_samples_df = sample_set_derived_samples_df.swaplevel('sample_id', 'sample_set', axis=0)
            
            # We're not interested in the other columns
            # Note: we use [[]] to get a DataFrame rather than a Series, which maintains expectations.
            # Note: this will keep the multi-index (release_str, sample_set, sample_id)
            sample_set_dfs.append(sample_set_derived_samples_df[[]])
    
    return pd.concat(sample_set_dfs)
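
The repeated `set_index`/`swaplevel` steps above build the `(release_str, sample_set, sample_id)` index one level at a time. Here is a minimal sketch of an equivalent one-step construction for a single sample set, using a toy DataFrame. `build_sample_index`, the sample ids and the sample set name are hypothetical. The sketch assumes the input has a `derived_sample_id` column, as the code above expects.

```python
import pandas as pd

def build_sample_index(derived_samples_df, release_str, sample_set):
    # Return an empty-column DataFrame indexed by (release_str, sample_set, sample_id)
    return (
        derived_samples_df
        .rename(columns={'derived_sample_id': 'sample_id'})
        .assign(release_str=release_str, sample_set=sample_set)
        .set_index(['release_str', 'sample_set', 'sample_id'])
        .loc[:, []]  # drop all remaining columns but keep a DataFrame
    )

# Hypothetical toy input with an extra column that gets dropped
toy_derived_samples_df = pd.DataFrame({'derived_sample_id': ['S1', 'S2'], 'country': ['Ghana', 'Mali']})
sample_index_df = build_sample_index(toy_derived_samples_df, 'v3.13', 'toy-sample-set')
print(sample_index_df.index.tolist())
# [('v3.13', 'toy-sample-set', 'S1'), ('v3.13', 'toy-sample-set', 'S2')]
```

Concatenating these per-sample-set frames with `pd.concat` should then give the same multi-indexed result as `get_derived_samples_df()`.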

### Functions for genotypes

In [25]:
def get_contig_max_pos(*, genomic_positions_zarr, contig):
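    # Return the maximum position for the specified contig.
    # Load the positions into memory once, rather than iterating over the Zarr array element by element.
    return int(genomic_positions_zarr[contig]["variants"]["POS"][:].max())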

In [26]:
def get_contig_partitions(*, genomic_positions_zarr, contig, partition_size):
    # Return a list of (start_pos, stop_pos) tuples for the specified contig, where both positions are inclusive
    
    # Get the maximum position for this contig
    contig_max_pos = get_contig_max_pos(genomic_positions_zarr=genomic_positions_zarr, contig=contig)
    
    # Start with the partition (0, size - 1)
    partition_start_pos = 0
    partition_end_pos = partition_size - 1
    
    # Collect the partitions for this contig
    partitions = []
    
    # While the partition's end position does not exceed the maximum position
    while partition_end_pos <= contig_max_pos:
        
        # Add this partition
        partitions.append((partition_start_pos, partition_end_pos))
        
        # Start the next partition just after the end of this partition
        partition_start_pos = partition_end_pos + 1
        
        # Move the end position forward by one partition size
        partition_end_pos += partition_size

    # If any positions remain from start_pos onwards (even a single one)
    if partition_start_pos <= contig_max_pos:
        # Add the last partition
        partitions.append((partition_start_pos, contig_max_pos))
    
    return partitions
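
To make the inclusive boundaries concrete, here is a sketch of the same partitioning loop applied to plain integers. `partition_positions` is a hypothetical standalone variant that takes the maximum position directly instead of reading it from Zarr. It keeps a trailing single-position partition (`start_pos <= contig_max_pos`).

```python
def partition_positions(contig_max_pos, partition_size):
    # Return (start_pos, stop_pos) pairs covering 0..contig_max_pos, both ends inclusive
    partitions = []
    start_pos, stop_pos = 0, partition_size - 1
    # Emit full-size partitions while they fit entirely within the contig
    while stop_pos <= contig_max_pos:
        partitions.append((start_pos, stop_pos))
        start_pos = stop_pos + 1
        stop_pos += partition_size
    # Emit a final, shorter partition for any remaining positions
    if start_pos <= contig_max_pos:
        partitions.append((start_pos, contig_max_pos))
    return partitions

print(partition_positions(25, 10))  # [(0, 9), (10, 19), (20, 25)]
print(partition_positions(20, 10))  # [(0, 9), (10, 19), (20, 20)]
```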

In [27]:
@functools.lru_cache(maxsize=None)
def get_snp_genotypes_zarr(*, gcs, snp_genotypes_gcs_path_template, release_str, sample_set):
    # Return the snp_genotypes zarr for the specified sample set.
    # Raise an exception if the relevant path is not found.
    
    # Note: in some contexts, the release_str placeholder will not be in the path template,
    #       in which case, only the sample_set placeholder will be replaced.
    snp_genotypes_gcs_path = snp_genotypes_gcs_path_template.format(
        release_str=release_str, sample_set=sample_set
    )

    # Check this path exists. This sample set might have had 0 samples pass QC.
    if not gcs.isdir(snp_genotypes_gcs_path):
        raise Exception(f'ERROR from get_snp_genotypes_zarr(): snp_genotypes_gcs_path not found {snp_genotypes_gcs_path}')

    return open_gcs_zarr(gcs_url=snp_genotypes_gcs_path, gcs=gcs)

In [28]:
def get_diploid_genotypes_na(*, genomic_positions_zarr, gcs, snp_genotypes_gcs_path_template, release_str, sample_set, partition_tuple):
    # Return the diploid genotypes (GT array) for all the samples in the specified sample set and partition tuple.
    # Also return the array of sample ids aligned with the retrieved genotypes.
    # Raise an exception if the relevant snp_genotypes were not found.
    
    # Extract the contig, start and stop position from the partition_tuple
    (contig, start_pos, stop_pos) = partition_tuple
    
    # Get the snp_genotypes Zarr for this sample set.
    # Note: this function should raise an exception if there are no genotypes found for the sample set.
    snp_genotypes_zarr = get_snp_genotypes_zarr(
        gcs=gcs,
        snp_genotypes_gcs_path_template=snp_genotypes_gcs_path_template,
        release_str=release_str,
        sample_set=sample_set
    )

    # Handle error when snp_genotypes zarr not found, in case it was not raised above.
    if snp_genotypes_zarr is None:
        raise Exception(f'ERROR from get_diploid_genotypes_na(): snp_genotypes_zarr was None for sample_set {sample_set}')
    
    # Get the list of sample ids, which correspond to the genotypes for this sample_set
    aligned_sample_ids = snp_genotypes_zarr['samples'][:]
    
    # Note: the sample ids are stored as byte strings when released, which we can decode
    #   but while in production, they are stored as normal strings, which will raise an error if we try to decode    
    if any(isinstance(element, (bytes, bytearray)) for element in aligned_sample_ids):
        aligned_sample_ids = np.char.decode(aligned_sample_ids, 'utf-8')
    
    # Get the contig genotypes Zarr array for the specified contig
    contig_genotypes_za = snp_genotypes_zarr[contig]['calldata/GT']

    # Get the genomic positions for the specified contig
    genomic_positions = allel.SortedIndex(genomic_positions_zarr[contig]['variants/POS'])

    # Get the positions slice for the specified range
    pos_slice = genomic_positions.locate_range(start_pos, stop_pos)

    # Get the computed diploid genotypes as a Numpy array for the specified slice of positions
    # TODO: Can probably eliminate the use of Dask here?
    diploid_genotypes_na = da.from_zarr(contig_genotypes_za)[pos_slice].compute()
    
    # Check that the number of aligned samples matches the samples dimension of the genotypes array
    assert diploid_genotypes_na.shape[1] == len(aligned_sample_ids)
        
    return diploid_genotypes_na, aligned_sample_ids
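
`SortedIndex.locate_range()` turns an inclusive position range into a slice of variant indices. Below is a sketch of the same logic using `np.searchsorted` on a toy sorted array. The function name `locate_inclusive_range` is hypothetical. One difference is that scikit-allel raises a `KeyError` when no positions fall in the range, whereas this sketch returns an empty slice. On the TODO above: indexing the Zarr array directly with such a slice (`contig_genotypes_za[pos_slice]`) should already return a NumPy array without going through Dask.

```python
import numpy as np

def locate_inclusive_range(sorted_positions, start_pos, stop_pos):
    # Return a slice selecting positions p with start_pos <= p <= stop_pos from a sorted array
    left = np.searchsorted(sorted_positions, start_pos, side='left')
    right = np.searchsorted(sorted_positions, stop_pos, side='right')
    return slice(int(left), int(right))

positions = np.array([5, 12, 15, 20, 31])
pos_slice = locate_inclusive_range(positions, 10, 20)
print(pos_slice, positions[pos_slice])  # slice(1, 4, None) [12 15 20]
```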

In [29]:
import ast

def get_diploid_genotype_encodings(*, gcs, diploid_genotype_encodings_gcs_path):
    # Return the mapping of diploid genotype tuples to their uint8 encodings, as stored in YAML
    
    # Load data from YAML file
    with gcs.open(diploid_genotype_encodings_gcs_path, 'r') as yaml_file:
        diploid_genotypes_as_str_encodings = yaml.safe_load(yaml_file)
        
    # Convert string keys back to tuples
    # Note: ast.literal_eval() only parses Python literals, so unlike eval() it cannot run arbitrary code
    # Convert int() values back to np.uint8()
    diploid_genotype_encodings = {
        tuple(ast.literal_eval(key)): np.uint8(value) for key, value in diploid_genotypes_as_str_encodings.items()
    }
    
    return diploid_genotype_encodings

In [30]:
def diploid_genotype_encoder(diploid_genotype_encodings, first_allele_arr, second_allele_arr) -> np.ndarray:
    # Return array containing the encoded values for the two given parallel diploid genotype arrays 
    return np.vectorize(lambda a, b: diploid_genotype_encodings[(a, b)])(first_allele_arr, second_allele_arr)

In [31]:
def encode_diploid_genotypes_na(*, diploid_genotype_encodings, diploid_genotypes_na):
    # Return the encoded genotypes (uint8) for the given diploid genotypes (GT array)
    return diploid_genotype_encoder(diploid_genotype_encodings, diploid_genotypes_na[:, :, 0], diploid_genotypes_na[:, :, 1])
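
The NumPy documentation describes `np.vectorize` as a convenience rather than a performance tool, since it essentially loops in Python. That cost may matter for large partitions. Below is a minimal sketch of a lookup-table alternative. It assumes, without checking against this data, that allele codes are small integers from -1 (missing) to 3. `encode_with_lookup_table` is hypothetical, and its output is compared with the `np.vectorize` approach on toy data.

```python
import itertools
import numpy as np

def encode_with_lookup_table(diploid_genotype_encodings, gt, min_allele=-1, max_allele=3):
    # Encode a GT array of shape (n_variants, n_samples, 2) with a 2D lookup table.
    # Assumes every allele code lies in [min_allele, max_allele].
    size = max_allele - min_allele + 1
    lut = np.zeros((size, size), dtype=np.uint8)
    has_code = np.zeros((size, size), dtype=bool)
    for (a, b), code in diploid_genotype_encodings.items():
        lut[a - min_allele, b - min_allele] = code
        has_code[a - min_allele, b - min_allele] = True
    # Shift allele codes so they can be used as non-negative array indices
    first = gt[:, :, 0].astype(np.intp) - min_allele
    second = gt[:, :, 1].astype(np.intp) - min_allele
    if first.size and (min(first.min(), second.min()) < 0 or max(first.max(), second.max()) >= size):
        raise ValueError('allele code outside the assumed range')
    # Mirror the KeyError a dictionary lookup would raise for a genotype without an encoding
    if not has_code[first, second].all():
        raise KeyError('genotype without an encoding')
    return lut[first, second]

# Toy encodings: every ordered pair of alleles in -1..3 gets its own code
toy_encodings = {pair: np.uint8(i) for i, pair in enumerate(itertools.product(range(-1, 4), repeat=2))}
rng = np.random.default_rng(0)
toy_gt = rng.integers(-1, 4, size=(5, 3, 2)).astype(np.int8)

# Compare with the np.vectorize approach used in the notebook
vectorized = np.vectorize(lambda a, b: toy_encodings[(a, b)])(toy_gt[:, :, 0], toy_gt[:, :, 1])
lut_encoded = encode_with_lookup_table(toy_encodings, toy_gt)
print(np.array_equal(vectorized, lut_encoded))  # True
```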

In [32]:
def get_encoded_genotypes_na(
    *, diploid_genotype_encodings, genomic_positions_zarr, gcs, snp_genotypes_gcs_path_template,
    samples_df, partition_tuple, warn_missing_genotypes=True
):
    # Return the encoded genotypes (uint8) for the specified samples and partition.
    # Return a DataFrame of the sample ids matching the retrieved genotypes.
    # Optionally warn if a target sample id was not found in the set of available genotypes.
    
    # Get DataFrames grouped by release_str
    release_str_dfgb = samples_df.groupby('release_str')
    
    # Collect the encoded genotype arrays for each release (as a list, so we can concatenate them later)
    encoded_genotype_narrs_per_release = []
    
    # Collect the sample ids as a list of dictionaries {'release_str': foo, 'sample_set': bar, 'sample_id': baz}
    # This should be in the same order as the encoded genotypes.
    got_sample_id_dicts_for_all_releases = []
    
    # For each release, get the encoded genotypes for each of its sample sets
    for release_str, release_str_samples_df in release_str_dfgb:
        
        # Get the unique list of sample_set values for this release.
        # Note: Index.unique() returns values in order of first appearance, but we don't rely on that order here.
        sample_sets = release_str_samples_df.index.get_level_values('sample_set').unique().tolist()
        
        # Collect the encoded genotype arrays for each sample set (as a list, so we can concatenate them later)
        encoded_genotype_narrs_per_sample_set = []
        
        for sample_set in sample_sets:
            
            # Extract the contig, start and stop position from the partition_tuple
            (contig, start_pos, stop_pos) = partition_tuple
            
            # Get the diploid genotypes, which should also provide the aligned sample ids
            diploid_genotypes_na, aligned_sample_ids = get_diploid_genotypes_na(
                genomic_positions_zarr=genomic_positions_zarr,
                gcs=gcs,
                snp_genotypes_gcs_path_template=snp_genotypes_gcs_path_template,
                release_str=release_str,
                sample_set=sample_set,
                partition_tuple=partition_tuple
            )
            
            # Encode the diploid genotypes
            encoded_genotypes_na = encode_diploid_genotypes_na(
                diploid_genotype_encodings=diploid_genotype_encodings,
                diploid_genotypes_na=diploid_genotypes_na
            )
            
            # Get the target samples for this sample_set
            # This should preserve the order of the sample ids.
            sample_set_df = release_str_samples_df.xs(sample_set, level='sample_set')
            target_sample_ids = sample_set_df.index.get_level_values('sample_id').tolist()
            
            # Warn if a target sample id was not found in the set of available genotypes.
            # Note: This might produce a lot of output when using the list of derived samples
            #       because unsequenced samples have not been filtered out.
            if warn_missing_genotypes:
                for target_sample_id in target_sample_ids:
                    if target_sample_id not in aligned_sample_ids:
                        print('WARNING get_encoded_genotypes_na(): target_sample_id not found', release_str, sample_set, target_sample_id)

            # Create a boolean mask to select the target samples from those available
            target_sample_selection_mask = [sample_id in target_sample_ids for sample_id in aligned_sample_ids]
            
            # Get the encoded genotypes for the specified samples
            sample_selection_encoded_genotypes_na = encoded_genotypes_na[:, target_sample_selection_mask]
            
            # Add the encoded genotypes for this sample_set to the list
            encoded_genotype_narrs_per_sample_set.append(sample_selection_encoded_genotypes_na)
            
            # Get the list of obtained sample ids, which can differ from the target and those available
            # This should still preserve the order of the sample_ids.
            got_sample_ids = aligned_sample_ids[target_sample_selection_mask]
            
            # Get the list of sample id dictionaries for this sample set.
            # This should preserve the order of the sample ids.
            got_sample_id_dicts = [{'release_str': release_str, 'sample_set': sample_set, 'sample_id': got_sample_id} for got_sample_id in got_sample_ids]
            
            # Add the list of obtained sample id dicts to the list.
            # This should preserve the order of the sample ids.
            got_sample_id_dicts_for_all_releases.extend(got_sample_id_dicts)
        
        
        # Concatenate the arrays of encoded genotype for all the sample sets in this release.
        # We concatenate along the samples dimension (axis=1).
        # The number of genotypes should be the same length for all arrays for all sample sets.
        encoded_genotypes_na_for_all_sample_sets = np.concatenate(encoded_genotype_narrs_per_sample_set, axis=1)
        
        # Add the encoded genotypes for this release to the list
        encoded_genotype_narrs_per_release.append(encoded_genotypes_na_for_all_sample_sets)
    
    
    # Concatenate the arrays of encoded genotypes for all the specified releases
    # We concatenate along the samples dimension (axis=1).
    # The number of genotypes should be the same length for all arrays for all releases.
    encoded_genotypes_for_all_releases_na = np.concatenate(encoded_genotype_narrs_per_release, axis=1)
    
    # Transpose the array into the shape (n_samples, n_genotypes)
    encoded_genotypes_for_all_releases_na = encoded_genotypes_for_all_releases_na.T
    
    # Convert the list of sample id dictionaries to a DataFrame
    # Note: this should retain the order of the got sample_ids so they remain aligned with the genotypes.
    aligned_samples_df = pd.DataFrame(got_sample_id_dicts_for_all_releases)
    aligned_samples_df.set_index(['release_str', 'sample_set', 'sample_id'], inplace=True)
    
    return encoded_genotypes_for_all_releases_na, aligned_samples_df
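
The mask-and-index steps above keep genotype columns and sample ids aligned because both use the same mask over the stored sample order. Below is a toy sketch with hypothetical ids. It uses `np.isin`, which builds the same boolean mask without a Python-level membership test per sample. The selected samples come out in the stored (Zarr) order, not in the order of `target_sample_ids`.

```python
import numpy as np

aligned_sample_ids = np.array(['S3', 'S2', 'S4', 'S1'])  # order as stored alongside the genotypes
target_sample_ids = ['S1', 'S2', 'S5']                    # S5 has no genotypes
encoded_genotypes_na = np.arange(8).reshape(2, 4)         # shape (n_variants, n_samples)

# Boolean mask over the stored samples, equivalent to [s in target_sample_ids for s in aligned_sample_ids]
mask = np.isin(aligned_sample_ids, target_sample_ids)

got_sample_ids = aligned_sample_ids[mask]
selected_genotypes = encoded_genotypes_na[:, mask]
missing_sample_ids = sorted(set(target_sample_ids) - set(aligned_sample_ids.tolist()))

print(got_sample_ids.tolist(), selected_genotypes.tolist(), missing_sample_ids)
# ['S2', 'S1'] [[1, 3], [5, 7]] ['S5']
```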

### Functions for classifiers

In [33]:
@functools.lru_cache(maxsize=None)
def get_classifier_id(*, classifier_id_template, partition_tuple):
    # Return the classifier_id based on the given partition_tuple
    (contig, start_pos, stop_pos) = partition_tuple
    return classifier_id_template.format(contig=contig, start_pos=start_pos, stop_pos=stop_pos)

In [34]:
@functools.lru_cache(maxsize=None)
def get_classifier_joblib_gcs_path(*, classifier_dir_gcs_path, classifier_id_template, partition_tuple):
    
    classifier_id = get_classifier_id(classifier_id_template=classifier_id_template, partition_tuple=partition_tuple)
    
    (contig, start_pos, stop_pos) = partition_tuple
    
    return classifier_dir_gcs_path + f'/{contig}/{classifier_id}.joblib'
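
For reference, the following shows the path this builds for one partition. It uses the example classifier directory from the settings comment above and a hypothetical partition of contig `2L`. Whether such a file exists depends on the classifier set.

```python
classifier_dir_gcs_path = 'gs://vo_agam_production/v3.x/taxon_classifiers/RF_20231019'
classifier_id_template = 'RF_{contig}_{start_pos}-{stop_pos}'
partition_tuple = ('2L', 0, 999999)  # hypothetical partition

contig, start_pos, stop_pos = partition_tuple
classifier_id = classifier_id_template.format(contig=contig, start_pos=start_pos, stop_pos=stop_pos)
classifier_joblib_gcs_path = classifier_dir_gcs_path + f'/{contig}/{classifier_id}.joblib'
print(classifier_joblib_gcs_path)
# gs://vo_agam_production/v3.x/taxon_classifiers/RF_20231019/2L/RF_2L_0-999999.joblib
```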

In [35]:
# Warning: Don't cache this function!
def classifier_exists(*, classifier_dir_gcs_path, classifier_id_template, partition_tuple):
    # If the classifier for the specified partition_tuple exists, return True.
    # Else return False.
    
    # Get the path to the classifier
    classifier_file_gcs_path = get_classifier_joblib_gcs_path(
        classifier_dir_gcs_path=classifier_dir_gcs_path,
        classifier_id_template=classifier_id_template,
        partition_tuple=partition_tuple
    )
    
    return gcs.exists(classifier_file_gcs_path)

In [36]:
def import_classifier(*, classifier_dir_gcs_path, classifier_id_template, partition_tuple):
    # Return the classifier for the specified partition_tuple
    
    # Get the path to the classifier
    classifier_file_gcs_path = get_classifier_joblib_gcs_path(
        classifier_dir_gcs_path=classifier_dir_gcs_path,
        classifier_id_template=classifier_id_template,
        partition_tuple=partition_tuple
    )
    
    # Open a file-handle to the classifier's GCS path using read-binary mode
    with gcs.open(classifier_file_gcs_path, 'rb') as fh:
        
        # Use joblib to load the classifier from the file-handle
        classifier = joblib.load(fh)
    
    return classifier

In [37]:
def get_predicted_probs_via_classifier(*, classifier, genotypes_arr):
    # Return the predicted class probabilities for the given classifier.
    return classifier.predict_proba(genotypes_arr)
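
In scikit-learn, the columns returned by `predict_proba()` follow the order of the classifier's `classes_` attribute. Below is a minimal sketch of turning that output into the `{taxon}_prob` columns used by the functions below. The class labels, probabilities and sample ids are hypothetical stand-ins for `classifier.classes_` and `classifier.predict_proba(genotypes_arr)`.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins for classifier.classes_ and classifier.predict_proba(genotypes_arr)
classes_ = np.array(['arabiensis', 'coluzzii', 'gambiae'])
predicted_probs = np.array([[0.05, 0.80, 0.15],
                            [0.70, 0.10, 0.20]])
sample_index = pd.Index(['S1', 'S2'], name='sample_id')

# One '{taxon}_prob' column per class, in the same order as classes_
probs_df = pd.DataFrame(predicted_probs, index=sample_index, columns=[f'{taxon}_prob' for taxon in classes_])
max_prob_columns = probs_df.idxmax(axis=1).tolist()
print(max_prob_columns)  # ['coluzzii_prob', 'arabiensis_prob']
```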

### Functions for probabilities

In [38]:
def get_taxon_by_prob_column_name_dict(*, taxon_classes):
    # Map the names of the taxon probability columns with corresponding taxon classes
    taxon_by_prob_column_name_dict = {
        f'{taxon}_prob': taxon for taxon in taxon_classes
    }
    return taxon_by_prob_column_name_dict

In [39]:
def get_max_taxon_from_df_row(df_row, taxon_classes):
    
    # Get the dictionary mapping taxon prob columns to class labels
    taxon_by_prob_column_name_dict = get_taxon_by_prob_column_name_dict(taxon_classes=taxon_classes)
    
    # Get the list of taxon probability column names
    taxon_prob_column_names = list(taxon_by_prob_column_name_dict.keys())
    
    # Get the name of the prob column with the highest value
    max_taxon_column_name = df_row[taxon_prob_column_names].idxmax()
    
    # Get the taxon represented by that column
    max_taxon = taxon_by_prob_column_name_dict[max_taxon_column_name]
    
    return max_taxon

In [40]:
def get_max_taxon_prob_diff_from_df_row(df_row, taxon_classes):
    
    # Get the dictionary mapping taxon prob columns to class labels
    taxon_by_prob_column_name_dict = get_taxon_by_prob_column_name_dict(taxon_classes=taxon_classes)
    
    # Get the list of taxon probability columns
    taxon_prob_columns = list(taxon_by_prob_column_name_dict.keys())
    
    # Get the values from the taxon probability columns
    taxon_probs = df_row[taxon_prob_columns]

    # Get the probability values in descending order
    taxon_probs_descending = sorted(taxon_probs, reverse=True)

    # Get the difference between the two highest values
    conf_by_prob_diff = taxon_probs_descending[0] - taxon_probs_descending[1]
    
    return conf_by_prob_diff

In [41]:
def get_PL_from_prob(prob, epsilon=1e-10):
    
    # Method: https://gatk.broadinstitute.org/hc/en-us/articles/360035890451-Calculation-of-PL-and-GQ-by-HaplotypeCaller-and-GenotypeGVCFs
    
    # PL = Phred-scaled likelihood
    
    # "low PL values mean [the thing] is more likely, and high PL values means it’s less likely"
    
    # Using a very small value (epsilon) to avoid taking the logarithm of zero.
    
    raw_PL_from_prob = -10 * np.log10(prob + epsilon)
    
    return raw_PL_from_prob
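
Each tenfold drop in probability adds 10 to the PL, and epsilon caps the PL of a zero probability at 100. The following quick check prints a few values. It uses `phred_scale`, a hypothetical copy of the formula, so the example runs on its own.

```python
import numpy as np

def phred_scale(prob, epsilon=1e-10):
    # Same formula as get_PL_from_prob(): -10 * log10(prob + epsilon)
    return -10 * np.log10(prob + epsilon)

pl_values = {prob: float(phred_scale(prob)) for prob in [1.0, 0.1, 0.01, 0.0]}
print(pl_values)  # approximately {1.0: 0, 0.1: 10, 0.01: 20, 0.0: 100}
```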

In [42]:
def get_max_taxon_prob_qual_from_df_row(df_row, taxon_classes, cap=99):
    
    # Method: https://gatk.broadinstitute.org/hc/en-us/articles/360035890451-Calculation-of-PL-and-GQ-by-HaplotypeCaller-and-GenotypeGVCFs
    
    # "Quality" calculated in a similar way to GQ (genotype quality)
    
    # Get the dictionary mapping taxon prob columns to class labels
    taxon_by_prob_column_name_dict = get_taxon_by_prob_column_name_dict(taxon_classes=taxon_classes)
    
    # Get the list of taxon probability columns
    taxon_prob_columns = list(taxon_by_prob_column_name_dict.keys())
    
    # Get the values from the taxon probability columns
    taxon_probs = df_row[taxon_prob_columns]
    
    # Get the "raw PL" (Phred-scaled likelihood) for each taxon probability
    taxon_raw_PLs = [get_PL_from_prob(taxon_prob) for taxon_prob in taxon_probs]
    
    # Find the lowest value in the taxon_raw_PL list
    min_taxon_raw_PL = min(taxon_raw_PLs)
    
    # Subtract the lowest value from each value in the taxon_raw_PL list
    normalized_taxon_PLs = [taxon_raw_PL - min_taxon_raw_PL for taxon_raw_PL in taxon_raw_PLs]
    
    # Sort the normalized PL values in ascending order
    sorted_normalized_taxon_PLs = sorted(normalized_taxon_PLs)
    
    # Get the lowest and second lowest normalized PL values
    lowest_normalized_taxon_PL = sorted_normalized_taxon_PLs[0]
    second_lowest_normalized_taxon_PL = sorted_normalized_taxon_PLs[1]
    
    # Get the difference between the lowest and second lowest normalized PL values
    conf_by_prob_qual = second_lowest_normalized_taxon_PL - lowest_normalized_taxon_PL
    
    # Cap the value "for practical reasons"
    if cap is not None:
        conf_by_prob_qual = min(conf_by_prob_qual, cap)
    
    return conf_by_prob_qual
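
Written out, the quality above is a capped log ratio. Let $\mathrm{PL}_i$ be as in `get_PL_from_prob()`, let $p_{(1)} \ge p_{(2)}$ be the two largest taxon probabilities, and let $\mathrm{PL}_{(k)}$ be the PL of $p_{(k)}$. Subtracting the minimum PL shifts every value equally, so it cancels in the difference:

$$
\mathrm{PL}_i = -10\,\log_{10}\left(p_i + \varepsilon\right)
$$

$$
Q = \min\left(\mathrm{PL}_{(2)} - \mathrm{PL}_{(1)},\ \mathrm{cap}\right) = \min\left(10\,\log_{10}\frac{p_{(1)} + \varepsilon}{p_{(2)} + \varepsilon},\ \mathrm{cap}\right)
$$

For example, probabilities $(0.9, 0.09, 0.01)$ give $Q \approx 10\,\log_{10}(0.9/0.09) = 10$. A tie between the two most probable taxa gives $Q = 0$.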

In [43]:
def get_samples_taxon_prob_averages_df(*, taxon_probs_by_partition_df, taxon_classes):
    # Return a DataFrame containing the average value for each taxon probability column across all partitions.
    
    # Get the dictionary mapping taxon prob columns to class labels
    taxon_by_prob_column_name_dict = get_taxon_by_prob_column_name_dict(taxon_classes=taxon_classes)
    
    # Get the list of taxon probability columns
    taxon_prob_columns = list(taxon_by_prob_column_name_dict.keys())
    
    # Compose the aggregation dictionary, specifying the aggregation function for each column
    taxon_prob_column_agg_dict = {col: 'mean' for col in taxon_probs_by_partition_df.columns if col in taxon_prob_columns}
    
    # Group by (release_str, sample_set, sample_id) and aggregate by taking the mean of each taxon probability
    sample_taxon_prob_averages_df = taxon_probs_by_partition_df.groupby(
        level=['release_str', 'sample_set', 'sample_id']
    ).agg(taxon_prob_column_agg_dict)

    # Rename the columns to indicate that they represent mean probabilities
    sample_taxon_prob_averages_df.columns = [f'{col}_mean' for col in sample_taxon_prob_averages_df.columns]
    
    return sample_taxon_prob_averages_df

In [44]:
def get_samples_max_taxon_counts_df(*, taxon_probs_by_partition_df, taxon_classes):
    # Return a DataFrame containing the max_taxon counts for each max_taxon value across all partitions.
    
    # Get the counts of each max_taxon value per sample as a Pandas Series using size() and groupby() 
    samples_max_taxon_counts_srs = taxon_probs_by_partition_df.groupby(
        ['release_str', 'sample_set', 'sample_id', 'max_taxon']
    ).size()
    
    # Convert the unique max_taxon values to columns and fill missing counts with 0.
    samples_max_taxon_counts_df = samples_max_taxon_counts_srs.unstack(fill_value=0)
    
    # Include counts for taxon classes that did not appear in the data by reindexing and filling with 0.
    samples_max_taxon_counts_df = samples_max_taxon_counts_df.reindex(columns=taxon_classes, fill_value=0)
    
    # Rename the columns to include the suffix "_votes"
    samples_max_taxon_counts_df.columns.name = None
    samples_max_taxon_counts_df.columns = [f'{taxon}_votes' for taxon in taxon_classes]
    
    return samples_max_taxon_counts_df
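
The mean-probability and vote summaries can disagree for the same sample. Here is a toy example with hypothetical data for one sample over three partitions. The code mirrors the aggregation steps in the two functions above rather than calling them. Two partitions narrowly favour coluzzii, while one strongly favours gambiae.

```python
import pandas as pd

taxon_classes = ['arabiensis', 'coluzzii', 'gambiae']

# Hypothetical per-partition results for a single sample across three partitions
index = pd.MultiIndex.from_tuples(
    [('v3.13', 'toy-sample-set', 'S1')] * 3,
    names=['release_str', 'sample_set', 'sample_id'],
)
taxon_probs_by_partition_df = pd.DataFrame(
    {
        'arabiensis_prob': [0.0, 0.0, 0.0],
        'coluzzii_prob': [0.55, 0.55, 0.1],
        'gambiae_prob': [0.45, 0.45, 0.9],
        'max_taxon': ['coluzzii', 'coluzzii', 'gambiae'],
    },
    index=index,
)

# Mean probability per taxon, as in get_samples_taxon_prob_averages_df()
prob_columns = [f'{taxon}_prob' for taxon in taxon_classes]
means_df = taxon_probs_by_partition_df.groupby(level=['release_str', 'sample_set', 'sample_id'])[prob_columns].mean()

# Votes per taxon, as in get_samples_max_taxon_counts_df()
votes_df = (
    taxon_probs_by_partition_df
    .groupby(['release_str', 'sample_set', 'sample_id', 'max_taxon'])
    .size()
    .unstack(fill_value=0)
    .reindex(columns=taxon_classes, fill_value=0)
)

print(means_df.iloc[0].idxmax(), votes_df.iloc[0].idxmax())  # gambiae_prob coluzzii
```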

In [45]:
def get_taxon_by_prob_mean_column_name_dict(*, taxon_classes):
    # Map the names of the taxon probability average columns with corresponding taxon classes
    taxon_by_prob_mean_column_name_dict = {
        f'{taxon}_prob_mean': taxon for taxon in taxon_classes
    }
    return taxon_by_prob_mean_column_name_dict

In [46]:
def get_max_mean_prob_taxon_from_df_row(df_row, taxon_classes):

    # Get the dictionary mapping taxon prob mean columns to class labels
    taxon_by_prob_mean_column_name_dict = get_taxon_by_prob_mean_column_name_dict(taxon_classes=taxon_classes)
    
    # Get the list of taxon probability average columns
    taxon_prob_mean_column_names = list(taxon_by_prob_mean_column_name_dict.keys())
    
    # Get the name of the mean prob column with the highest value
    max_mean_prob_taxon_column_name = df_row[taxon_prob_mean_column_names].idxmax()
    
    # Get the taxon represented by that max column
    max_mean_prob_taxon = taxon_by_prob_mean_column_name_dict[max_mean_prob_taxon_column_name]
    
    return max_mean_prob_taxon

In [47]:
def get_max_mean_prob_taxon_diff_from_df_row(df_row, taxon_classes):
    
    # TODO: I expect we could merge this with get_max_taxon_prob_diff_from_df_row() 
    
    # Get the dictionary mapping taxon prob mean columns to class labels
    taxon_by_prob_mean_column_name_dict = get_taxon_by_prob_mean_column_name_dict(taxon_classes=taxon_classes)
    
    # Get the list of taxon probability average columns
    taxon_prob_mean_column_names = list(taxon_by_prob_mean_column_name_dict.keys())
    
    # Get the values from the taxon probability average columns
    taxon_prob_averages = df_row[taxon_prob_mean_column_names]
    
    # Get the probability averages in descending order
    taxon_probs_averages_descending = sorted(taxon_prob_averages, reverse=True)
    
    # Get the difference between the two highest values
    conf_by_prob_average_diff = taxon_probs_averages_descending[0] - taxon_probs_averages_descending[1]
    
    return conf_by_prob_average_diff
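
Following the TODO, below is a sketch of a single helper that takes the column names as a parameter. It could replace both `get_max_taxon_prob_diff_from_df_row()` and `get_max_mean_prob_taxon_diff_from_df_row()`, as well as a votes equivalent. It also works on the whole DataFrame at once instead of row by row. `top_two_diff` is hypothetical and requires at least two columns.

```python
import numpy as np
import pandas as pd

def top_two_diff(df, column_names):
    # For each row, return the gap between the largest and second-largest values in column_names
    values = np.sort(df[column_names].to_numpy(dtype=float), axis=1)
    return pd.Series(values[:, -1] - values[:, -2], index=df.index)

# Hypothetical toy data: one row is clear-cut, the other is a tie
toy_df = pd.DataFrame(
    {'a_prob_mean': [0.7, 0.4], 'b_prob_mean': [0.2, 0.4], 'c_prob_mean': [0.1, 0.2]},
    index=['S1', 'S2'],
)
diffs = top_two_diff(toy_df, ['a_prob_mean', 'b_prob_mean', 'c_prob_mean'])
print(diffs.round(6).tolist())  # [0.5, 0.0]
```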

In [48]:
def get_max_mean_prob_taxon_qual_from_df_row(df_row, taxon_classes, cap=99):
    
    # TODO: I expect we could merge this with get_max_taxon_prob_qual_from_df_row()
    
    
    # Method: https://gatk.broadinstitute.org/hc/en-us/articles/360035890451-Calculation-of-PL-and-GQ-by-HaplotypeCaller-and-GenotypeGVCFs
    
    # "Quality" calculated in a similar way to GQ (genotype quality)
    
    # Get the dictionary mapping taxon prob mean columns to class labels
    taxon_by_prob_mean_column_name_dict = get_taxon_by_prob_mean_column_name_dict(taxon_classes=taxon_classes)
    
    # Get the list of taxon probability average columns
    taxon_prob_mean_column_names = list(taxon_by_prob_mean_column_name_dict.keys())
    
    # Get the values from the taxon probability average columns
    taxon_prob_averages = df_row[taxon_prob_mean_column_names]
    
    # Get the "raw PL" (Phred-scaled likelihood) for each taxon probability average
    taxon_raw_PLs = [get_PL_from_prob(taxon_prob) for taxon_prob in taxon_prob_averages]
    
    # Find the lowest value in the taxon_raw_PL list
    min_taxon_raw_PL = min(taxon_raw_PLs)
    
    # Subtract the lowest value from each value in the taxon_raw_PL list
    normalized_taxon_PLs = [taxon_raw_PL - min_taxon_raw_PL for taxon_raw_PL in taxon_raw_PLs]
    
    # Sort the normalized PL values in ascending order
    sorted_normalized_taxon_PLs = sorted(normalized_taxon_PLs)
    
    # Get the lowest and second lowest normalized PL values
    lowest_normalized_taxon_PL = sorted_normalized_taxon_PLs[0]
    second_lowest_normalized_taxon_PL = sorted_normalized_taxon_PLs[1]
    
    # Get the difference between the lowest and second lowest normalized PL values
    conf_by_prob_qual = second_lowest_normalized_taxon_PL - lowest_normalized_taxon_PL
    
    # Cap the value "for practical reasons"
    if cap is not None:
        conf_by_prob_qual = min(conf_by_prob_qual, cap)
    
    return conf_by_prob_qual
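`get_PL_from_prob()` is defined elsewhere in this notebook. Assuming it uses the usual Phred scaling, PL = -10 log10(p), the quality above reduces to 10 log10(p1 / p2), where p1 and p2 are the two largest mean probabilities. Subtracting the minimum PL shifts every value by the same amount, so only the gap between the two smallest PLs survives. The sketch below works under that assumption. `phred_from_prob` and `gq_like_quality` are hypothetical names, and the probability floor is an addition to avoid log10(0).

```python
import math


def phred_from_prob(prob, floor=1e-300):
    # Assumed scaling: PL = -10 * log10(p); the floor avoids log10(0)
    return -10.0 * math.log10(max(prob, floor))


def gq_like_quality(probs, cap=99):
    """Gap between the best and second-best Phred-scaled likelihoods."""
    pls = sorted(phred_from_prob(p) for p in probs)
    # Normalizing (subtracting the minimum PL) leaves this gap unchanged:
    # the best PL becomes 0 and the quality is the second-best normalized PL
    qual = pls[1] - pls[0]
    return qual if cap is None else min(qual, cap)


print(gq_like_quality([0.9, 0.09, 0.01]))  # about 10.0, i.e. 10 * log10(0.9 / 0.09)
print(gq_like_quality([1.0, 0.0]))         # 99, the cap
```

With the default cap of 99, any sample whose top taxon is more than about 10^9.9 times as probable as the runner-up reports the same quality.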

In [49]:
def get_taxon_by_votes_column_name_dict(*, taxon_classes):
    # Map the names of the taxon votes columns with corresponding taxon classes
    taxon_by_votes_column_name_dict = {
        f'{taxon}_votes': taxon for taxon in taxon_classes
    }
    return taxon_by_votes_column_name_dict

In [50]:
def get_max_votes_taxon_from_df_row(df_row, taxon_classes):
    
    # TODO: this looks very similar to get_max_mean_prob_taxon_from_df_row()

    # Get the dictionary mapping taxon votes columns to class labels
    taxon_by_votes_column_name_dict = get_taxon_by_votes_column_name_dict(taxon_classes=taxon_classes)
    
    # Get the list of taxon votes columns
    taxon_votes_column_names = list(taxon_by_votes_column_name_dict.keys())
    
    # Get the name of the votes column with the highest value
    max_votes_taxon_column_name = df_row[taxon_votes_column_names].idxmax()
    
    # Get the taxon represented by that max column
    max_votes_taxon = taxon_by_votes_column_name_dict[max_votes_taxon_column_name]
    
    return max_votes_taxon
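Ties matter here: `pandas.Series.idxmax()` returns the label of the first occurrence of the maximum, so when two taxa receive the same number of votes, the one listed first in `taxon_classes` wins. A small sketch with hypothetical taxon labels and a hypothetical function name:

```python
import pandas as pd


def max_votes_taxon(row, taxon_classes):
    # Hypothetical stand-in for get_max_votes_taxon_from_df_row()
    votes_column_to_taxon = {f'{taxon}_votes': taxon for taxon in taxon_classes}
    # Series.idxmax returns the FIRST maximum, so ties are broken by the
    # order of taxon_classes
    max_column = row[list(votes_column_to_taxon)].idxmax()
    return votes_column_to_taxon[max_column]


taxa = ['taxon_a', 'taxon_b', 'taxon_c']  # hypothetical class labels
row = pd.Series({'taxon_a_votes': 3, 'taxon_b_votes': 7, 'taxon_c_votes': 7})
print(max_votes_taxon(row, taxa))  # taxon_b (tie with taxon_c broken by order)
```

This tie-breaking rule is worth keeping in mind when `max_votes_taxon` is later compared with `max_mean_prob_taxon`.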

In [51]:
def get_max_votes_taxon_ratio_diff_from_df_row(df_row, taxon_classes, epsilon=1e-10):
    
    # Get the dictionary mapping taxon votes columns to class labels
    taxon_by_votes_column_name_dict = get_taxon_by_votes_column_name_dict(taxon_classes=taxon_classes)
    
    # Get the list of taxon votes columns
    taxon_votes_column_names = list(taxon_by_votes_column_name_dict.keys())
    
    # Get the values from the columns
    taxon_vote_counts = df_row[taxon_votes_column_names]
    
    # Get the total number of votes
    taxon_votes_total = sum(taxon_vote_counts)
    
    # Get the vote ratios (as decimal fractions) using list comprehension.
    # Use a very small value (epsilon) to avoid division by zero.
    taxon_vote_ratios = [taxon_vote_count / (taxon_votes_total + epsilon) for taxon_vote_count in taxon_vote_counts]
    
    # Get the vote ratios (as decimal fractions) in descending order
    taxon_vote_ratios_descending = sorted(taxon_vote_ratios, reverse=True)
    
    # Get the difference between the two highest values
    conf_by_ratio_diff = taxon_vote_ratios_descending[0] - taxon_vote_ratios_descending[1]
    
    return conf_by_ratio_diff
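Dividing by the total turns vote counts into fractions, so the margin is (top votes - second votes) / total votes. A minimal pure-Python sketch, with a hypothetical function name, showing the effect of `epsilon`:

```python
def vote_ratio_margin(vote_counts, epsilon=1e-10):
    """Hypothetical stand-in for get_max_votes_taxon_ratio_diff_from_df_row()."""
    if len(vote_counts) < 2:
        raise ValueError('Need at least two taxa to compute a margin')
    total = sum(vote_counts)
    # epsilon avoids dividing by zero when a row has no votes; it also shrinks
    # every ratio very slightly, which is negligible for real vote counts
    ratios = sorted((count / (total + epsilon) for count in vote_counts), reverse=True)
    return ratios[0] - ratios[1]


print(vote_ratio_margin([6, 3, 1]))  # about 0.3, i.e. (6 - 3) / 10
print(vote_ratio_margin([0, 0, 0]))  # 0.0 rather than a division error
```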

In [52]:
def get_predicted_probs_df(*, derived_samples_df):
    
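    """Predict per-partition taxon probabilities for the samples in derived_samples_df.

    For each partition of each contig in contigs_to_include, load the stored
    classifier (skipping partitions without one), encode the samples' SNP
    genotypes and predict class probabilities. Return a DataFrame indexed by
    (release_str, sample_set, sample_id, contig, start_pos, stop_pos) with one
    <taxon>_prob column per class, plus the list of aligned taxon classes.
    """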
    
    # Check the consistency of the classifier classes
    aligned_taxon_classes = None

    # Check the consistency of the number of samples,
    #   e.g. when we get the genotypes or classifications for each partition
    samples_count = None

    # Get the aligned_samples_df, checked for consistency with each get_encoded_genotypes_na()
    consistent_aligned_samples_df = None

    # Collect the DataFrames of predicted probabilities per partition,
    #   i.e. (release_str, sample_set, sample_id, contig, start_pos, stop_pos, [taxon]_prob)
    predicted_prob_dfs_per_partition = []

    # For each contig in the list of contigs to include
    for contig in contigs_to_include:

        print()
        print(contig)

        contig_partitions = get_contig_partitions(
            genomic_positions_zarr=genomic_positions_zarr,
            contig=contig,
            partition_size=partition_size
        )

        for start_pos, stop_pos in contig_partitions:

            print('- Partition', start_pos, stop_pos)

            print(' - Time', datetime.utcnow().strftime("%H:%M:%S"))

            # Compose the partition tuple
            partition_tuple = (contig, start_pos, stop_pos)

            # Skip if classifier doesn't exist
            if not classifier_exists(
                classifier_dir_gcs_path=classifier_dir_gcs_path,
                classifier_id_template=classifier_id_template,
                partition_tuple=partition_tuple
            ):
                print(' - WARNING: Classifier not found. Skipping.')
                continue


            print(' - Getting classifier')

            # Import the classifier from GCS
            classifier = import_classifier(
                classifier_dir_gcs_path=classifier_dir_gcs_path,
                classifier_id_template=classifier_id_template,
                partition_tuple=partition_tuple
            )

            # Check the consistency of the classifier classes
            if aligned_taxon_classes is None:
                aligned_taxon_classes = classifier.classes_
                
                # Get the local output path for the file
                taxon_classes_local_output_path = taxon_classes_local_output_path_template
                
                # Write each taxon class to the file
                with open(taxon_classes_local_output_path, 'w') as fh:
                    for taxon_class in aligned_taxon_classes:
                        fh.write(taxon_class + '\n')
                
                # Read the file back, for reproducibility
                with open(taxon_classes_local_output_path, 'r') as fh:
                    # Read all lines from the file
                    aligned_taxon_classes = [line.strip() for line in fh.readlines()]
                
            else:
                assert np.array_equal(classifier.classes_, aligned_taxon_classes)
            
            
            print(' - Getting encoded genotypes')

            # Get the encoded genotypes narr for the target samples for this partition
            # Also get the aligned samples DataFrame
            genotypes_na, aligned_samples_df = get_encoded_genotypes_na(
                diploid_genotype_encodings=diploid_genotype_encodings,
                genomic_positions_zarr=genomic_positions_zarr,
                gcs=gcs,
                snp_genotypes_gcs_path_template=snp_genotypes_gcs_path_template,
                samples_df=derived_samples_df,
                partition_tuple=partition_tuple,
                warn_missing_genotypes=False
            )

            # Check the returned number of aligned samples corresponds to the genotypes arr (n_samples, n_genotypes)
            assert len(aligned_samples_df) == genotypes_na.shape[0]

            # Check the consistency of the number of samples
            if samples_count is None:
                samples_count = len(aligned_samples_df)
            else:
                assert len(aligned_samples_df) == samples_count

            # Check that aligned_samples_df is consistent with previous partitions
            if consistent_aligned_samples_df is None:
                consistent_aligned_samples_df = aligned_samples_df
            else:
                assert aligned_samples_df.equals(consistent_aligned_samples_df)

            print(' - Getting predicted class probabilities')
            # The result should have shape (n_samples, n_classes), with rows in the same order
            # as consistent_aligned_samples_df (and genotypes_na)
            classifier_predicted_probs_arr = get_predicted_probs_via_classifier(classifier=classifier, genotypes_arr=genotypes_na)

            # Check the consistency of the number of samples and classes
            assert classifier_predicted_probs_arr.shape[0] == samples_count
            assert classifier_predicted_probs_arr.shape[1] == len(aligned_taxon_classes)

            print(' - Making a DataFrame')

            ## Get the predicted probabilities for this partition as a DataFrame

            # Get a copy of the aligned sample ids (release_str, sample_set, sample_id)
            copy_of_aligned_samples_df = consistent_aligned_samples_df.copy().reset_index()

            # Make a record of the partition tuple for every sample (contig, start_pos, stop_pos)
            partition_df = pd.DataFrame(
                {'contig': contig, 'start_pos': start_pos, 'stop_pos': stop_pos},
                index=range(len(copy_of_aligned_samples_df))
            )

            # Get the list of column names for each taxon probability
            aligned_taxon_prob_columns = [
                f'{taxon}_prob' for taxon in aligned_taxon_classes
            ]

            # Convert the predicted probabilities array for this partition to a DataFrame
            classifier_predicted_probs_df = pd.DataFrame(
                classifier_predicted_probs_arr,
                columns=aligned_taxon_prob_columns
            )

            # Concatenate the sub-DataFrames together into one, for this partition
            predicted_probs_df = pd.concat(
                [copy_of_aligned_samples_df, partition_df, classifier_predicted_probs_df],
                axis=1
            )

            # Add the DataFrame of predicted probabilities for this partition to the list
            predicted_prob_dfs_per_partition.append(predicted_probs_df)

            
    ## Get a DataFrame of all predicted probs
            
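    # Fail clearly if no partition could be classified (e.g. no classifiers were found)
    if not predicted_prob_dfs_per_partition:
        raise ValueError('No predicted probabilities: no partitions were classified.')

    # Check that the number of samples in every predicted_prob DataFrame is consistent
    assert all(len(df) == len(consistent_aligned_samples_df) for df in predicted_prob_dfs_per_partition)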

    # Concatenate all of the predicted_prob DataFrames
    all_predicted_probs_df = pd.concat(predicted_prob_dfs_per_partition)
    
    # Set the key columns as indexes
    all_predicted_probs_df = all_predicted_probs_df.set_index(
        ['release_str', 'sample_set', 'sample_id', 'contig', 'start_pos', 'stop_pos']
    )
    
    return all_predicted_probs_df, aligned_taxon_classes
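`get_contig_partitions()` is defined elsewhere, but the partition log printed by the main process (`0 999999`, `1000000 1999999`, ...) suggests inclusive `(start_pos, stop_pos)` pairs of `partition_size` positions. The sketch below assumes a known contig length and clips the last partition to it, which the real helper may handle differently. It also rebuilds, with toy data, the per-partition layout that `get_predicted_probs_df()` concatenates. `contig_partitions`, `partition_probs_df` and all sample and taxon names are hypothetical.

```python
import numpy as np
import pandas as pd


def contig_partitions(contig_length, partition_size):
    """Hypothetical: split positions [0, contig_length) into inclusive (start, stop) pairs."""
    return [
        (start, min(start + partition_size, contig_length) - 1)
        for start in range(0, contig_length, partition_size)
    ]


def partition_probs_df(samples_df, contig, start_pos, stop_pos, probs_arr, taxon_classes):
    """One partition's frame: sample keys + partition keys + <taxon>_prob columns."""
    keys_df = samples_df.reset_index()
    partition_df = pd.DataFrame(
        {'contig': contig, 'start_pos': start_pos, 'stop_pos': stop_pos},
        index=range(len(keys_df)),
    )
    probs_df = pd.DataFrame(probs_arr, columns=[f'{taxon}_prob' for taxon in taxon_classes])
    return pd.concat([keys_df, partition_df, probs_df], axis=1)


# Toy samples indexed like derived_samples_df
samples_df = pd.DataFrame(
    index=pd.MultiIndex.from_tuples(
        [('v3.13', 'set_1', 's1'), ('v3.13', 'set_1', 's2')],
        names=['release_str', 'sample_set', 'sample_id'],
    )
)
taxa = ['taxon_a', 'taxon_b']

frames = []
for start, stop in contig_partitions(contig_length=2_500_000, partition_size=1_000_000):
    # Toy probabilities standing in for classifier.predict_proba() output
    probs = np.tile([0.8, 0.2], (len(samples_df), 1))
    frames.append(partition_probs_df(samples_df, 'X', start, stop, probs, taxa))

all_probs_df = pd.concat(frames).set_index(
    ['release_str', 'sample_set', 'sample_id', 'contig', 'start_pos', 'stop_pos']
)
print(all_probs_df)
```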

In [53]:
def outputs_already_exist(*, sample_set):

    # Get the expected output paths
    predicted_taxa_probs_output_gcs_path = predicted_taxa_probs_output_gcs_path_template.format(sample_set=sample_set)
    provisional_taxa_output_path = provisional_taxa_output_path_template.format(sample_set=sample_set)

    # Return True only if all the outputs exist
    return gcs.exists(predicted_taxa_probs_output_gcs_path) and Path(provisional_taxa_output_path).exists()

## Get a DataFrame of the samples that we want to run through the classifiers

In [54]:
release_strings_tuple = (release_version,)

In [55]:
derived_samples_df = get_derived_samples_df(release_strings_tuple=release_strings_tuple)
derived_samples_df

release_str,sample_set,sample_id
v3.13,1324-VO-ET-GOLASSA-VMF00257,VBS83156-7466STDY14206595
v3.13,1324-VO-ET-GOLASSA-VMF00257,VBS83157-7466STDY14206596
v3.13,1324-VO-ET-GOLASSA-VMF00257,VBS83158-7466STDY14206597
v3.13,1324-VO-ET-GOLASSA-VMF00257,VBS83159-7466STDY14206598
v3.13,1324-VO-ET-GOLASSA-VMF00257,VBS83160-7466STDY14206599
v3.13,1324-VO-ET-GOLASSA-VMF00257,...
v3.13,1324-VO-ET-GOLASSA-VMF00257,VBS83580-7466STDY14207410
v3.13,1324-VO-ET-GOLASSA-VMF00257,VBS83581-7466STDY14207411
v3.13,1324-VO-ET-GOLASSA-VMF00257,VBS83582-7466STDY14207412
v3.13,1324-VO-ET-GOLASSA-VMF00257,VBS83583-7466STDY14207413


In [56]:
derived_samples_df.index.get_level_values('sample_set').value_counts()

sample_set
1324-VO-ET-GOLASSA-VMF00257    425
Name: count, dtype: int64

## Get the genomic positions

In [57]:
genomic_positions_zarr = open_gcs_zip_zarr(gcs_url=allsites_zip_path, gcs=gcs)

## Get the diploid genotype encodings

In [58]:
diploid_genotype_encodings = get_diploid_genotype_encodings(
    gcs=gcs,
    diploid_genotype_encodings_gcs_path=diploid_genotype_encodings_gcs_path
)

## Main process

In [59]:
# Collect, per sample_set, a DataFrame of samples whose aggregates differ, i.e. max_votes_taxon != max_mean_prob_taxon
agg_diff_df_by_sample_set = {}

# Group by sample_set and loop over each group
for sample_set, sample_set_group_df in derived_samples_df.groupby('sample_set'):
    
    print()
    print(sample_set)
    
    # If the outputs for this sample_set already exist, skip it (but still collect its aggregate differences)
    if outputs_already_exist(sample_set=sample_set):
        print('- WARNING: outputs already exist. Skipping.')
        
        # Get the path
        provisional_taxa_output_path = provisional_taxa_output_path_template.format(
            sample_set=sample_set
        )
        
        # Get the aggregated data
        samples_taxon_prob_agg_df_import = pd.read_csv(provisional_taxa_output_path, sep='\t')
        
        # Compare max_votes_taxon with max_mean_prob_taxon
        agg_diff_df = samples_taxon_prob_agg_df_import[samples_taxon_prob_agg_df_import['max_votes_taxon'] != samples_taxon_prob_agg_df_import['max_mean_prob_taxon']]
        print('- agg_diff_df len', len(agg_diff_df))
        
        # Collect the differences
        agg_diff_df_by_sample_set[sample_set] = agg_diff_df
        
        # Skip to the next sample_set
        continue
    
    # Get the predictions; if this fails (e.g. SNP genotypes are missing for this sample_set), skip it
    try:
        # Get a DataFrame of the predicted taxon probs for each partition.
        # Also get the list of aligned taxon classes.
        all_predicted_probs_df, aligned_taxon_classes = get_predicted_probs_df(derived_samples_df=sample_set_group_df)
    except Exception as err:
        print(err)
        print('- WARNING: Skipping.')
        
        # Skip to the next sample_set
        continue
    
    # Add the max_taxon column to the DataFrame using the function
    all_predicted_probs_df['max_taxon'] = all_predicted_probs_df.apply(
        get_max_taxon_from_df_row,
        taxon_classes=aligned_taxon_classes,
        axis=1
    )
    
    # Add the max_taxon_prob_diff column using the function
    all_predicted_probs_df['max_taxon_prob_diff'] = all_predicted_probs_df.apply(
        get_max_taxon_prob_diff_from_df_row,
        taxon_classes=aligned_taxon_classes,
        axis=1
    )
    
    # Add the max_taxon_prob_qual column using the function
    all_predicted_probs_df['max_taxon_prob_qual'] = all_predicted_probs_df.apply(
        get_max_taxon_prob_qual_from_df_row,
        taxon_classes=aligned_taxon_classes,
        axis=1
    )
    
    # Export the DataFrame of all the predicted probs for this sample set
    predicted_taxa_probs_output_gcs_path = predicted_taxa_probs_output_gcs_path_template.format(sample_set=sample_set)
    with gcs.open(predicted_taxa_probs_output_gcs_path, 'w') as fh:
        all_predicted_probs_df.to_csv(fh, sep='\t', index=True)
    
    # Reload for reproducibility
    with gcs.open(predicted_taxa_probs_output_gcs_path, 'r') as fh:
        all_predicted_probs_df = pd.read_csv(
            fh,
            sep='\t',
            index_col=['release_str', 'sample_set', 'sample_id', 'contig', 'start_pos', 'stop_pos']
        )
    
    
    # Get the averages for each taxon probability for each sample over all partitions
    samples_taxon_prob_averages_df = get_samples_taxon_prob_averages_df(
        taxon_probs_by_partition_df=all_predicted_probs_df,
        taxon_classes=aligned_taxon_classes
    )
    
    # Add the max_mean_prob_taxon column to the DataFrame using the function
    samples_taxon_prob_averages_df['max_mean_prob_taxon'] = samples_taxon_prob_averages_df.apply(
        get_max_mean_prob_taxon_from_df_row,
        taxon_classes=aligned_taxon_classes,
        axis=1
    )
    
    # Add the max_mean_taxon_prob_diff column using the function
    samples_taxon_prob_averages_df['max_mean_taxon_prob_diff'] = samples_taxon_prob_averages_df.apply(
        get_max_mean_prob_taxon_diff_from_df_row,
        taxon_classes=aligned_taxon_classes,
        axis=1
    )
    
    # Add the max_mean_taxon_prob_qual column using the function
    samples_taxon_prob_averages_df['max_mean_taxon_prob_qual'] = samples_taxon_prob_averages_df.apply(
        get_max_mean_prob_taxon_qual_from_df_row,
        taxon_classes=aligned_taxon_classes,
        axis=1
    )
    
    # Get the max_taxon value counts for each sample over all partitions  
    samples_max_taxon_counts_df = get_samples_max_taxon_counts_df(
        taxon_probs_by_partition_df=all_predicted_probs_df,
        taxon_classes=aligned_taxon_classes
    )
    
    # Add the max_votes_taxon column to the DataFrame using the function
    samples_max_taxon_counts_df['max_votes_taxon'] = samples_max_taxon_counts_df.apply(
        get_max_votes_taxon_from_df_row,
        taxon_classes=aligned_taxon_classes,
        axis=1
    )
    
    # Add the max_votes_taxon_ratio_diff column using the function
    samples_max_taxon_counts_df['max_votes_taxon_ratio_diff'] = samples_max_taxon_counts_df.apply(
        get_max_votes_taxon_ratio_diff_from_df_row,
        taxon_classes=aligned_taxon_classes,
        axis=1
    )
    
    # Merge the probability averages DataFrame with the votes DataFrame
    samples_taxon_prob_agg_df = samples_taxon_prob_averages_df.merge(samples_max_taxon_counts_df, left_index=True, right_index=True)
    
    # Compare max_votes_taxon with max_mean_prob_taxon
    agg_diff_df = samples_taxon_prob_agg_df[samples_taxon_prob_agg_df['max_votes_taxon'] != samples_taxon_prob_agg_df['max_mean_prob_taxon']]
    print('agg_diff_df len', len(agg_diff_df))
    
    # Collect the differences
    agg_diff_df_by_sample_set[sample_set] = agg_diff_df
    
    
    ## Export the aggregated provisional taxon data for this sample set
    
    # Convert the indexes to columns
    samples_taxon_prob_agg_df_export = samples_taxon_prob_agg_df.reset_index()
    
    # Drop the release_str and sample_set columns
    samples_taxon_prob_agg_df_export.drop(columns=['release_str', 'sample_set'], inplace=True)
    
    # Rename the sample_id column to derived_sample_id
    samples_taxon_prob_agg_df_export.rename(columns={'sample_id': 'derived_sample_id'}, inplace=True)
    
    # Get the path
    provisional_taxa_output_path = provisional_taxa_output_path_template.format(
        sample_set=sample_set
    )
    
    # Export the aggregated data from the DataFrame to a TSV file
    samples_taxon_prob_agg_df_export.to_csv(provisional_taxa_output_path, sep='\t', index=False)


1324-VO-ET-GOLASSA-VMF00257

X
- Partition 0 999999
 - Time 11:54:40
 - Getting classifier
 - Getting encoded genotypes
 - Getting predicted class probabilities
 - Making a DataFrame
- Partition 1000000 1999999
 - Time 11:56:50
 - Getting classifier
 - Getting encoded genotypes
 - Getting predicted class probabilities
 - Making a DataFrame
- Partition 2000000 2999999
 - Time 11:58:41
 - Getting classifier
 - Getting encoded genotypes
 - Getting predicted class probabilities
 - Making a DataFrame
- Partition 3000000 3999999
 - Time 12:00:31
 - Getting classifier
 - Getting encoded genotypes
 - Getting predicted class probabilities
 - Making a DataFrame
- Partition 4000000 4999999
 - Time 12:02:28
 - Getting classifier
 - Getting encoded genotypes
 - Getting predicted class probabilities
 - Making a DataFrame
- Partition 5000000 5999999
 - Time 12:04:22
 - Getting classifier
 - Getting encoded genotypes
 - Getting predicted class probabilities
 - Making a DataFrame
- Partition 6000000 6
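The main process relies on `get_samples_taxon_prob_averages_df()` and `get_samples_max_taxon_counts_df()`, which are defined elsewhere in the notebook. The sketch below shows one plausible way to compute the same kind of per-sample aggregates from the partition-level DataFrame, assuming the averages are plain means over partitions and that each partition casts one vote for its `max_taxon`. The function names and the `_prob_mean` suffix are hypothetical; `<taxon>_votes` follows `get_taxon_by_votes_column_name_dict()`.

```python
import pandas as pd

SAMPLE_KEYS = ['release_str', 'sample_set', 'sample_id']


def taxon_prob_means(probs_by_partition_df, taxon_classes):
    """Hypothetical: mean of each <taxon>_prob column over a sample's partitions."""
    prob_columns = [f'{taxon}_prob' for taxon in taxon_classes]
    means_df = probs_by_partition_df[prob_columns].groupby(level=SAMPLE_KEYS).mean()
    return means_df.add_suffix('_mean')  # hypothetical column naming


def max_taxon_votes(probs_by_partition_df, taxon_classes):
    """Hypothetical: count, per sample, the partitions won by each taxon."""
    votes_df = (
        probs_by_partition_df.groupby(level=SAMPLE_KEYS)['max_taxon']
        .value_counts()
        .unstack(fill_value=0)
        # Taxa that never win a partition still get a column of zeros
        .reindex(columns=taxon_classes, fill_value=0)
    )
    votes_df.columns = [f'{taxon}_votes' for taxon in taxon_classes]
    return votes_df


# Toy partition-level data for one sample and two hypothetical taxa
taxa = ['taxon_a', 'taxon_b']
index = pd.MultiIndex.from_tuples(
    [('v3.13', 'set_1', 's1', 'X', 0, 999_999),
     ('v3.13', 'set_1', 's1', 'X', 1_000_000, 1_999_999),
     ('v3.13', 'set_1', 's1', 'X', 2_000_000, 2_999_999)],
    names=SAMPLE_KEYS + ['contig', 'start_pos', 'stop_pos'],
)
toy_df = pd.DataFrame(
    {'taxon_a_prob': [0.9, 0.45, 0.45], 'taxon_b_prob': [0.1, 0.55, 0.55]},
    index=index,
)
toy_df['max_taxon'] = (
    toy_df[['taxon_a_prob', 'taxon_b_prob']].idxmax(axis=1).str.replace('_prob', '', regex=False)
)

agg_df = taxon_prob_means(toy_df, taxa).merge(
    max_taxon_votes(toy_df, taxa), left_index=True, right_index=True
)
print(agg_df)
```

Here `taxon_a` has the higher mean probability (0.6 vs 0.4), but `taxon_b` wins two of the three partition votes, so a sample like this would end up in `agg_diff_df`.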

In [60]:
exemplar_sample_set = list(agg_diff_df_by_sample_set.keys())[0]
exemplar_sample_set

'1324-VO-ET-GOLASSA-VMF00257'

In [61]:
agg_diff_df_by_sample_set[exemplar_sample_set][['max_mean_prob_taxon', 'max_votes_taxon']]

release_str,sample_set,sample_id,max_mean_prob_taxon,max_votes_taxon
