# Assigning Ancestry Labels Using a Random Forest Model

## Index
1. [Setting Default Paths](#1.Setting-Default-Paths)
2. [Intersecting Two Datasets](#1.-Intersecting-HGDP+1kG-unrelateds-with-GGV)
3. [Applying gnomAD RF Model to HGDP+1kGP-GGV Intersect](#2.-Applying-gnomAD-RF-model-to-HGDP+1kGP+GGV-intersect)
    1. [Plotting PCA After Applying gnomAD RF to HGDP+1kGP-GGV Intersect](#-2a.-Plotting-PCA-after-applying-gnomAD-RF-to-HGDP+1kGP-Intersect)
4. [Building an RF Model Using HGDP+1KGP and Applying to a New Dataset](#3.-Building-a-random-forest-model-from-HGDP+1kGP-and-applying-to-a-new-dataset)
    1. [Plotting PCA After Applying HGDP+1kGP RF to GGV](#3a.-Plotting-PCA-after-building-RF-model-from-HGDP+1kGP-dataset-and-applying-it-to-GGV)

# General Overview:

The purpose of this script is to intersect HGDP+1kGP with the GGV dataset, apply gnomAD's RF model to the intersected dataset, create a RF model using HGDP+1kGP and finally apply that RF model to a new dataset. 

**This script contains information on how to:**
- Intersect two datasets
- Apply a random forest model 
- Build a random forest model
- Plot PCA after applying a RF model to a dataset  

Author: Lindo Nkambule

In [1]:
import hail as hl
import pickle
import pandas as pd
from gnomad.sample_qc.ancestry import assign_population_pcs, pc_project
from sklearn.ensemble import RandomForestClassifier
from typing import Tuple
from bokeh.io import show, output_notebook, output_file
from bokeh.layouts import column, row
from bokeh.plotting import figure
from bokeh.models.widgets import Panel, Tabs
from bokeh.models import ColumnDataSource, Legend, TableColumn, DataTable
from bokeh.transform import factor_cmap
output_notebook()

In [None]:
hl.init()

### tmp read_qc function 
to be removed once tutorials & function are complete and we can troubleshoot importing

In [None]:
import hail as hl

def read_qc(
        raw: bool = False,
        post_qc:bool = False,
        sample_qc: bool = False,
        variant_qc: bool = False,
        outlier_removal: bool = False,
        ld_pruning: bool = False,
        rel_unrel: str = 'default',
        n_partitions: int = 0) -> hl.MatrixTable:
    """
    Wrapper function to get HGDP+1kGP data as Matrix Table at different stages of QC/filtering.
    By raw, returns pre QC MatrixTable with qc filters annotated but not filtered.

    :param bool raw: if True will return a preQC version of the dataset
    :param bool post_qc: if True will return a post QC matrix table that has gone through:
        - sample QC
        - variant QC
        - duplicate removal
        - outlier removal
    :param bool sample_qc: if True will return a post sample QC matrix table
    :param bool variant_qc: if True will return a post variant QC matrix table
    :param bool outlier_removal: if True will return a matrix table with PCA outliers removed
    :param bool ld_pruning: if True will return a matrix table that has gone through:
        - sample QC
        - variant QC
        - duplicate removal
        - LD pruning
        - additional variant filtering
    :param bool rel_unrel: default will return same mt as ld pruned above
        if 'all' will return the same matrix table as if ld_pruning is True
        if 'related_pre_outlier' will return a matrix table with only related samples pre pca outlier removal
        if 'unrelated_pre_outlier' will return a matrix table with only unrelated samples pre pca outlier removal
        if 'related_post_outlier' will return a matrix table with only related samples post pca outlier removal
        if 'unrelated_post_outlier' wil return a matrix table with only unrelated samples post pca outlier removal
    :param int n_partitions: if specified, will read in dataset with given number of partitions for the following arguments:
        - ld_pruning
        - rel_unrel
    """
    # Reading in all the tables and matrix tables needed to generate the pre_qc matrix table
    sample_meta = hl.import_table('gs://hgdp-1kg/hgdp_tgp/qc_and_figure_generation/gnomad_meta_v1.tsv')
    sample_qc_meta = hl.read_table('gs://hgdp_tgp/output/gnomad_v3.1_sample_qc_metadata_hgdp_tgp_subset.ht')
    dense_mt = hl.read_matrix_table(
        'gs://gcp-public-data--gnomad/release/3.1.2/mt/genomes/gnomad.genomes.v3.1.2.hgdp_1kg_subset_dense.mt')
    
    dense_mt = dense_mt.naive_coalesce(5000)


    # Takes a list of dicts and converts it to a struct format (works with nested structs too)
    def dict_to_struct(d):
        fields = {}
        for k, v in d.items():
            if isinstance(v, dict):
                v = dict_to_struct(v)
            fields[k] = v
        return hl.struct(**fields)

    # un-flattening a hail table with nested structure
    # dict to hold struct names as well as nested field names
    d = {}

    # Getting the row field names
    row = sample_meta.row_value

    # returns a dict with the struct names as keys and their inner field names as values
    for name in row:
        def recur(dict_ref, split_name):
            if len(split_name) == 1:
                dict_ref[split_name[0]] = row[name]
                return
            existing = dict_ref.get(split_name[0])
            if existing is not None:
                assert isinstance(existing, dict), existing
                recur(existing, split_name[1:])
            else:
                existing = {}
                dict_ref[split_name[0]] = existing
                recur(existing, split_name[1:])
        recur(d, name.split('.'))

    # using the dict created from flattened struct, creating new structs now un-flattened
    sample_meta = sample_meta.select(**dict_to_struct(d))
    sample_meta = sample_meta.key_by('s')

    # grabbing the columns needed from HGDP metadata
    new_meta = sample_meta.select(sample_meta.hgdp_tgp_meta, sample_meta.bergstrom)

    # creating a table with gnomAD sample metadata and HGDP metadata
    ht = sample_qc_meta.annotate(**new_meta[sample_qc_meta.s])

    # stripping 'v3.1::' from the names to match with the densified MT
    ht = ht.key_by(s=ht.s.replace("v3.1::", ""))

    # Using hl.annotate_cols() method to annotate the gnomAD variant QC metadata onto the matrix table
    mt = dense_mt.annotate_cols(**ht[dense_mt.s])
    
    if raw:
        print("Returning default preQC matrix table")
        # returns preQC dataset
        return mt
    
    if post_qc:
        print("Returning post sample and variant QC matrix table with duplicates and PCA outliers removed")
        sample_qc = True
        variant_qc = True
        duplicate = True
        outlier_removal = True
    
    if sample_qc:
        print("Applying sample QC")
        # Apply sample QC filters to dataset
        # filtering samples to those who should pass gnomADs sample QC
        # this filters to only samples that passed gnomad sample QC hard filters
        mt = mt.filter_cols(~mt.sample_filters.hard_filtered)

    if variant_qc:
        print("Applying variant QC")
        # Apply variant QC filters to dataset
        # Subsetting the variants in the dataset to only PASS variants (those which passed gnomAD's variant QC)
        # PASS variants are variants which have an entry in the filters field.
        # This field contains an array which contains a bool if any variant qc filter was failed
        # This is the last step in the QC process
        mt = mt.filter_rows(hl.len(mt.filters) != 0, keep=False)

    if outlier_removal:
        print("Removing PCA outliers")
        # remove PCA outliers
        # reading in the PCA outlier list
        # To read in the PCA outlier list, first need to read the file in as a list
        # using hl.hadoop_open here which allows one to read in files into hail from Google cloud storage
        pca_outlier_path = 'gs://hgdp-1kg/hgdp_tgp/pca_outliers_v2.txt'
        with hl.utils.hadoop_open(pca_outlier_path) as file:
            outliers = [line.rstrip('\n') for line in file]

        # Using hl.literal here to convert the list from a python object to a hail expression so that it can be used
        # to filter out samples
        outliers_list = hl.literal(outliers)

        # Using the list of PCA outliers, using the ~ operator which is a negation operator and obtains the compliment
        # In this case the compliment is samples which are not contained in the pca outlier list
        mt = mt.filter_cols(~outliers_list.contains(mt['s']))

    if ld_pruning:
        print("Returning ld pruned post variant and sample QC matrix table pre PCA outlier removal ")
        # read in dataset which has additional variant filtering and ld pruning run
        # data has gone through:
        #   - sample QC
        #   - variant QC
        #   - duplicate removal
        if n_partitions != 0:
            mt = hl.read_matrix_table('gs://hgdp-1kg/hgdp_tgp/intermediate_files/filtered_n_pruned_output_updated.mt',
            _n_partitions = n_partitions)
        else:
            mt = hl.read_matrix_table('gs://hgdp-1kg/hgdp_tgp/intermediate_files/filtered_n_pruned_output_updated.mt')

    if rel_unrel == "default":
        # do nothing
        # created a default value because there are multiple options for rel/unrel datasets
        mt = mt

    elif rel_unrel == 'related_pre_outlier':
        print("Returning post sample and variant QC matrix table " \
              "pre PCA outlier removal with only related individuals")
        # data has gone through:
        #   - sample QC
        #   - variant QC
        #   - duplicate removal
        #   - LD pruning
        #   - pc_relate 
        #   - filter to only related individuals   
        if n_partitions != 0:
            mt = hl.read_matrix_table('gs://hgdp-1kg/hgdp_tgp/rel_updated.mt',
            _n_partitions = n_partitions)
        else:
            mt = hl.read_matrix_table('gs://hgdp-1kg/hgdp_tgp/rel_updated.mt')
        
    elif rel_unrel == 'unrelated_pre_outlier':
        print("Returning post QC matrix table with only unrelated individuals")
        # data has gone through:
        #   - sample QC
        #   - variant QC
        #   - duplicate removal
        #   - LD pruning
        #   - pc_relate 
        #   - filter to only unrelated individuals
        if n_partitions != 0:
            mt = hl.read_matrix_table('gs://hgdp-1kg/hgdp_tgp/unrel_updated.mt',
            _n_partitions = n_partitions)
        else:
            mt = hl.read_matrix_table('gs://hgdp-1kg/hgdp_tgp/unrel_updated.mt')

    elif rel_unrel == 'related_post_outlier':
        print("Returning post sample and variant QC matrix table " \
              "pre PCA outlier removal with only related individuals")
        # data has gone through:
        #   - sample QC
        #   - variant QC
        #   - duplicate removal
        #   - LD pruning
        #   - pc_relate 
        #   - filter to only related individuals
        #   - PCA outlier removal
        if n_partitions != 0:
            mt = hl.read_matrix_table('gs://hgdp-1kg/hgdp_tgp/datasets_for_others/lindo/ds_without_outliers/related.mt',
            _n_partitions = n_partitions)
        else:
            mt = hl.read_matrix_table('gs://hgdp-1kg/hgdp_tgp/datasets_for_others/lindo/ds_without_outliers/related.mt')

    elif rel_unrel == 'unrelated_post_outlier':
        print("Returning post sample and variant QC matrix table " \
              "pre PCA outlier removal with only related individuals")
        # data has gone through:
        #   - sample QC
        #   - variant QC
        #   - duplicate removal
        #   - LD pruning
        #   - pc_relate 
        #   - filter to only unrelated individuals
        #   - PCA outlier removal
        if n_partitions != 0:
            mt = hl.read_matrix_table('gs://hgdp-1kg/hgdp_tgp/datasets_for_others/lindo/ds_without_outliers/unrelated.mt',
            _n_partitions = n_partitions)
        else:
            mt = hl.read_matrix_table('gs://hgdp-1kg/hgdp_tgp/datasets_for_others/lindo/ds_without_outliers/unrelated.mt')
        
    # Calculating both variant and sample_qc metrics on the mt before returning
    # so the stats are up to date with the version being written out
    mt = hl.sample_qc(mt)
    mt = hl.variant_qc(mt)
    
    return mt

# 1. Set Default Paths
These default paths can be edited by users as needed. It is recommended to run these tutorials without writing out datasets. The read_qc() function is intended to take the place of needing to write out and read in datasets by the user. 

By default we have commented out all of the write steps of the tutorials, if you would like to write out your own datasets, uncomment those sections and replace the paths with your own. 

[Back to Index](#Index)

# 2. Intersecting HGDP+1kGP unrelateds with GGV
The first step in building the random forest model is to intersect the HGDP+1kGP dataset with the Gambian Genome Variation Project dataset.  
<br>
<details><summary> For more information on Hail methods and expressions click <u><span style="color:blue">here</span></u>.</summary> 
    
<ul>
<li><a href="https://hail.is/docs/0.2/hail.MatrixTable.html#hail.MatrixTable.key_rows_by"> More on  <i> key_rows_by() </i></a></li>

<li><a href="https://hail.is/docs/0.2/hail.expr.Expression.html#hail.expr.Expression.collect"> More on  <i> collect() </i></a></li>

<li><a href="https://hail.is/docs/0.2/hail.MatrixTable.html#hail.MatrixTable.union_rows"> More on  <i> union_cols() </i></a></li>
</ul>
    
</details>

[Back to Index](#Index)

In [2]:
# use large HGDP+1KG
mt_post_qc = read_qc(default=True, n_partitions=500)
                            
print(f'Number of variants in HGDP+1KG before intersecting: {mt_post_qc.count_rows()}')
mt_ggv = hl.read_matrix_table('gs://gnomaf/gambian-genomes/COMBINED_GVCFS/gambian_genomes_merged_gvcfs.mt',
                             _n_partitions=500)
# GGV dataset is a sparse MT from combining GVCFs. Hail still keeps the non-variant sites (contain only REF allele)
# so we have to filter to variant-sites only
mt_ggv = mt_ggv.filter_rows(hl.len(mt_ggv.alleles) > 1)
print(f'Number of variant sites only in GGV before intersecting: {mt_ggv.count_rows()}')
# only variants sites

Initializing Hail with default parameters...
Running on Apache Spark version 3.1.2
SparkUI available at http://qc-notebook4-m.c.diverse-pop-seq-ref.internal:38921
Welcome to
     __  __     <>__
    / /_/ /__  __/ /
   / __  / _ `/ / /
  /_/ /_/\_,_/_/_/   version 0.2.95-513139587f57
LOGGING: writing to /home/hail/hail-20220830-1818-0.2.95-513139587f57.log


Number of variants in HGDP+1KG before intersecting: 155648020
Number of variant sites only in GGV before intersecting: 61164017


### In order to combine two datasets, three requirements must be met:

1. The row keys must match.

2. The column key schemas and column schemas must match.

3. The entry schemas must match.

In [3]:
mt_post_qc_unkeyed = mt_post_qc.key_cols_by().key_rows_by()
mt_post_qc_clean = mt_post_qc_unkeyed.select_cols(mt_post_qc_unkeyed.s)
mt_post_qc_clean = mt_post_qc_clean.select_rows(
    mt_post_qc_clean.locus, mt_post_qc_clean.alleles, mt_post_qc_clean.rsid)
mt_post_qc_clean = mt_post_qc_clean.select_entries(
    mt_post_qc_clean.GT)

# put back the keys
mt_post_qc_clean = mt_post_qc_clean.key_cols_by('s').key_rows_by(*['locus', 'alleles'])
hgdp_tgp_samples = mt_post_qc_clean.s.collect()

In [4]:
mt_ggv_unkeyed = mt_ggv.key_cols_by().key_rows_by()
mt_ggv_clean = mt_ggv_unkeyed.select_cols(mt_ggv_unkeyed.s)
mt_ggv_clean = mt_ggv_clean.select_rows(
    mt_ggv_clean.locus, mt_ggv_clean.alleles, mt_ggv_clean.rsid)
mt_ggv_clean = mt_ggv_clean.select_entries(GT = mt_ggv_clean.LGT)

# put back the keys
mt_ggv_clean = mt_ggv_clean.key_cols_by('s').key_rows_by(*['locus', 'alleles'])

# collect GGV samples to list so we can later use this to check how they were classified by the RF model
ggv_samples = mt_ggv_clean.s.collect()

In [5]:
hgdp_tgp_ggv_intersect = mt_post_qc_clean.union_cols(mt_ggv_clean)

In [6]:
# This step takes a while, so I've already checkpointed to save time
# hgdp_tgp_ggv_intersect = hgdp_tgp_ggv_intersect.checkpoint('gs://hgdp-1kg/hgdp_tgp_ggv_intersect.mt')

In [7]:
hgdp_tgp_ggv_intersect = hl.read_matrix_table('gs://hgdp-1kg/hgdp_tgp_ggv_intersect.mt')

In [8]:
print(f'Number of variants after intersecting HGDP+1KG with GGV: {hgdp_tgp_ggv_intersect.count_rows()}')

Number of variants after intersecting HGDP+1KG with GGV: 26452039


# 2. Applying gnomAD RF model to HGDP+1kGP+GGV intersect
<br>
<details><summary> For more information on Hail methods and expressions click <u><span style="color:blue">here</span></u>.</summary> 
    
<ul>
<li><a href="https://hail.is/docs/0.2/experimental/index.html#hail.experimental.pc_project"> More on  <i> pc_project() </i></a></li>

<li><a href="https://hail.is/docs/0.2/utils/index.html#hail.utils.hadoop_open"> More on  <i> hadoop_open() </i></a></li>

</ul>
    
</details>

[Back to Index](#Index)

In [9]:
# gnomAD loadings Hail Table
loadings_ht = hl.read_table('gs://gcp-public-data--gnomad/release/3.1/pca/gnomad.v3.1.pca_loadings.ht')

# Project new genotypes onto loadings
ht = hl.experimental.pc_project(
    hgdp_tgp_ggv_intersect.GT,
    loadings_ht.loadings,
    loadings_ht.pca_af,
)

2022-08-30 18:24:26 Hail: WARN: cols(): Resulting column table is sorted by 'col_key'.
    To preserve matrix table column order, first unkey columns with 'key_cols_by()'


In [10]:
print(f'Number of variants in gnomAD loadings: {loadings_ht.count()}')

Number of variants in gnomAD loadings: 76399


In [11]:
hgdp_tgp_ggv_intersect = hgdp_tgp_ggv_intersect.annotate_rows(
        pca_loadings=loadings_ht[hgdp_tgp_ggv_intersect.row_key]['loadings'],
        pca_af=loadings_ht[hgdp_tgp_ggv_intersect.row_key]['pca_af'],
    )

In [12]:
# Get the number of variants found in gnomAD loadings and hgdp_tgp_ggv_intersect
# the higher the missingness, the less accurate the classification will be 
gnomad_loadings_data_interset_count = hgdp_tgp_ggv_intersect.filter_rows(hl.is_defined(hgdp_tgp_ggv_intersect.pca_loadings)
                                   & hl.is_defined(hgdp_tgp_ggv_intersect.pca_af)).count_rows()

In [13]:
print(f'Number of variants common between HGDP+1KG+GGV & gnomAD RF: {gnomad_loadings_data_interset_count}')

Number of variants common between HGDP+1KG+GGV & gnomAD RF: 39557


In [14]:
# Load gnomAD RF model
with hl.hadoop_open('gs://gcp-public-data--gnomad/release/3.1/pca/gnomad.v3.1.RF_fit.pkl', 'rb') as f:
    fit = pickle.load(f)



In [15]:
# Reduce the scores to only those used in the RF model, this was 6 for v2 and 16 for v3.1
num_pcs = fit.n_features_
ht = ht.annotate(scores=ht.scores[:num_pcs])

# assign population labels based on PCA results
ht, rf_model = assign_population_pcs(
    ht,
    pc_cols=[(i + 1) for i in range(num_pcs)],
    fit=fit,
)

2022-08-30 18:25:44 Hail: INFO: Coerced sorted dataset
INFO (gnomad.sample_qc.ancestry 230): Found the following sample count after population assignment: oth: 2989, amr: 366, afr: 1108, sas: 49, nfe: 1


In [16]:
gnomad_rf_output = ht.transmute(**{f'PC{i}': ht.pca_scores[i - 1] for i in range(1, num_pcs+1)})
gnomad_rf_output = gnomad_rf_output.to_pandas()
gnomad_rf_output['pop'] = gnomad_rf_output['pop'].str.upper()

2022-08-30 18:25:58 Hail: INFO: Coerced sorted dataset


## 2a. Plotting PCA after applying gnomAD RF to HGDP+1kGP+GGV Intersect

[Back to Index](#Index)

In [17]:
color_map = {'AFR': "#984EA3", 'EAS': "#4DAF4A", 'EUR': "#377EB8", 'CSA': "#FF7F00",
             'AMR': "#E41A1C", 'MID': "#A65628", 'OCE': "#000000", 'OTH': "#F0E442"}

tabs1 = []

ref_samples_df1 = gnomad_rf_output[gnomad_rf_output['s'].isin(hgdp_tgp_samples)]
ggv_samples_df1 = gnomad_rf_output[gnomad_rf_output['s'].isin(ggv_samples)]

def plot_pca(ref_df=None, data_df=None, pc1=None, pc2=None):
    pref = figure(width=600, height=500, background_fill_color='#fafafa', title = 'HGDP+1KG')
    pref.add_layout(Legend(), 'right')
    pref.xaxis.axis_label = pc1
    pref.yaxis.axis_label = pc2
    
    pdata = figure(width=600, height=500, background_fill_color='#fafafa', title = 'GGV')
    pdata.add_layout(Legend(), 'right')
    pdata.xaxis.axis_label = pc1
    pdata.yaxis.axis_label = pc2
    
    pcomb = figure(width=600, height=500, background_fill_color='#fafafa', title = 'HGDP+1KG+GGV')
    pcomb.add_layout(Legend(), 'right')
    pcomb.xaxis.axis_label = pc1
    pcomb.yaxis.axis_label = pc2
    pcomb.circle(ref_df[pc1].tolist(), ref_df[pc2].tolist(), size=3, color='grey', alpha=0.3)

    for pop, col in color_map.items():
        # reference
        pref.circle(ref_df[(ref_df['pop'] == pop)][pc1].tolist(), ref_df[(ref_df['pop'] == pop)][pc2].tolist(),
                    size=3, color=col, alpha=0.8, legend_label=pop)
        
        # data
        pdata.circle(data_df[(data_df['pop'] == pop)][pc1].tolist(), data_df[(data_df['pop'] == pop)][pc2].tolist(),
                     size=3, color=col, alpha=0.8, legend_label=pop)
        
        # ref+data combined
        pcomb.circle(data_df[(data_df['pop'] == pop)][pc1].tolist(), data_df[(data_df['pop'] == pop)][pc2].tolist(),
                     size=3, color=col, alpha=0.8, legend_label=pop)
        
    return pref, pdata, pcomb


for i in range(1, num_pcs, 2):
    xpc = f'PC{i}'
    ypc = f'PC{i + 1}'
    
    p1, p2, p3 = plot_pca(ref_df=ref_samples_df1, data_df=ggv_samples_df1, pc1=xpc, pc2=ypc)
        
    tab = Panel(child=column(row(p1, p2), row(p3)), title=f'{xpc}v{ypc}')

    tabs1.append(tab)

In [18]:
show(Tabs(tabs=tabs1))


# 3. Building a random forest model from HGDP+1kGP and applying to a new dataset
In the following steps we are building a random forest (RF) model with unrelated individuals from the 1kGP+HGDP dataset. This was done using global region labels. 
We then apply the model to the Gambian Genome Variation Project (GGV) dataset. 

[INSERT LINK] For more information on Random Forest models click [here]().

[INSERT LINK] For more information on the GGV dataset click [here]().
    

[Back to Index](#Index)

In [19]:
def intersect_ref(ref_mt: hl.MatrixTable = None, data_mt: hl.MatrixTable = None):
    data_in_ref = data_mt.filter_rows(hl.is_defined(ref_mt.rows()[data_mt.row_key]))
    print('sites in ref and data, inds in data: {}'.format(data_in_ref.count()))

    ref_in_data = ref_mt.filter_rows(hl.is_defined(data_mt.rows()[ref_mt.row_key]))
    print('sites in ref and data, inds in ref: {}'.format(ref_in_data.count()))
    
    return ref_in_data, data_in_ref


def run_ref_pca(mt: hl.MatrixTable = None, npcs: int = 20):
    pca_evals, pca_scores, pca_loadings = hl.hwe_normalized_pca(mt.GT, k=npcs, compute_loadings=True)
    pca_mt = mt.annotate_rows(pca_af=hl.agg.mean(mt.GT.n_alt_alleles()) / 2)
    pca_loadings = pca_loadings.annotate(pca_af=pca_mt.rows()[pca_loadings.key].pca_af)

    # individual-level PCs
    pca_scores = pca_scores.transmute(**{f'PC{i}': pca_scores.scores[i - 1] for i in range(1, npcs+1)})
    
    return pca_loadings, pca_scores


def merge_data_with_ref(ref_scores: hl.Table = None,
        ref_info: str = 'gs://hgdp-1kg/hgdp_tgp/datasets_for_others/lindo/ds_without_outliers/hgdp_1kg_sample_info.unrelateds.pca_outliers_removed.with_project.tsv',
        data_scores: hl.Table = None) -> pd.DataFrame:
    print('Merging data with ref')
    ref_info = hl.import_table(ref_info,
                           impute=True, key='Sample')
    ref_merge = ref_scores.annotate(SuperPop = ref_info[ref_scores.s].SuperPop)

    print('merging data and ref data')
    data_ref = ref_merge.union(data_scores, unify=True)
    print('Done merging data with ref')

    return data_ref


In [20]:
# use pruned postQC MT with unrelated individuals to speed up things
mt_unrel = read_qc(rel_unrel='unrelated_post_outlier', n_partitions=500)

In [21]:
hgdp_tgp_in_ggv_mt, ggv_in_hgdp_tgp_mt = intersect_ref(ref_mt=mt_unrel, data_mt=mt_ggv)

sites in ref and data, inds in data: (211741, 394)
sites in ref and data, inds in ref: (211741, 3380)


In [22]:
ref_pca_loadings, ref_pca_scores = run_ref_pca(mt=hgdp_tgp_in_ggv_mt, npcs=20)

2022-08-30 18:34:37 Hail: INFO: hwe_normalize: found 211741 variants after filtering out monomorphic sites.
2022-08-30 18:36:32 Hail: INFO: pca: running PCA with 20 components...
2022-08-30 18:41:42 Hail: INFO: Coerced sorted dataset


In [23]:
# project data
# the gnomAD pc_project function requires genotype to be encoded as GT, not LGT
ggv_in_hgdp_tgp_mt = ggv_in_hgdp_tgp_mt.select_entries(GT = ggv_in_hgdp_tgp_mt.LGT)

data_projections_ht = pc_project(mt=ggv_in_hgdp_tgp_mt, loadings_ht=ref_pca_loadings,
                                 loading_location='loadings', af_location='pca_af')

data_scores = data_projections_ht.transmute(**{f'PC{i}': data_projections_ht.scores[i - 1] for i in range(1, 20+1)})

In [24]:
data_ref = merge_data_with_ref(ref_scores=ref_pca_scores, data_scores=data_scores)

data_ref_df = data_ref.to_pandas()

Merging data with ref


2022-08-30 18:42:34 Hail: INFO: Reading table to impute column types
2022-08-30 18:42:36 Hail: INFO: Finished type imputation
  Loading field 'Sample' as type str (imputed)
  Loading field 'SuperPop' as type str (imputed)
  Loading field 'Project' as type str (imputed)


merging data and ref data
Done merging data with ref


2022-08-30 18:48:38 Hail: INFO: Ordering unsorted dataset with network shuffle
2022-08-30 18:48:38 Hail: INFO: Coerced sorted dataset
2022-08-30 18:48:38 Hail: INFO: Coerced sorted dataset


In [25]:
ht, rf_model = assign_population_pcs(
    data_ref_df,
    pc_cols=['PC{}'.format(i + 1) for i in range(20)],
    known_col="SuperPop",
)

INFO (gnomad.sample_qc.ancestry 230): Found the following sample count after population assignment: EUR: 662, oth: 330, EAS: 718, AMR: 387, CSA: 669, AFR: 843, OCE: 27, MID: 138


Random forest feature importances are as follows: [0.18639292 0.1884225  0.15411622 0.12593267 0.10580455 0.05179108
 0.05079407 0.03028211 0.02223709 0.01260428 0.01717959 0.0051473
 0.01624743 0.00080003 0.01039461 0.00760396 0.00246636 0.00647793
 0.00432485 0.00098046]
Estimated error rate for RF model is 0.004437869822485174


## 3a. Plotting PCA after building RF model from HGDP+1kGP dataset and applying it to GGV

[Back to Index](#Index)

In [26]:
color_map = {'AFR': "#984EA3", 'EAS': "#4DAF4A", 'EUR': "#377EB8", 'CSA': "#FF7F00",
             'AMR': "#E41A1C", 'MID': "#A65628", 'OCE': "#000000", 'oth': "#F0E442"}

tabs2 = []

ref_samples_df2 = data_ref_df[data_ref_df['s'].isin(hgdp_tgp_samples)]
ggv_samples_df2 = data_ref_df[data_ref_df['s'].isin(ggv_samples)]

def plot_pca(ref_df=None, data_df=None, pc1=None, pc2=None):
    
    pcomb = figure(width=600, height=500, background_fill_color='#fafafa', title = 'Reference+Projected')
    pcomb.add_layout(Legend(), 'right')
    pcomb.xaxis.axis_label = pc1
    pcomb.yaxis.axis_label = pc2
    pcomb.circle(ref_df[pc1].tolist(), ref_df[pc2].tolist(), size=3, color='grey', alpha=0.8)

    for pop, col in color_map.items():
        # ref+data combined
        pcomb.circle(data_df[(data_df['pop'] == pop)][pc1].tolist(), data_df[(data_df['pop'] == pop)][pc2].tolist(),
                     size=3, color=col, alpha=0.8, legend_label=pop)
        
    return pcomb

for i in range(1, 20, 2):
    xpc = f'PC{i}'
    ypc = f'PC{i + 1}'
    
    plot1 = plot_pca(ref_df=ref_samples_df2, data_df=ggv_samples_df2, pc1=xpc, pc2=ypc)
        
    tab = Panel(child=plot1, title=f'{xpc}v{ypc}')

    tabs2.append(tab)

In [27]:
show(Tabs(tabs=tabs2))


In [28]:
# Get counts by POP
ggv_samples_df2['pop'].value_counts()

AFR    394
Name: pop, dtype: int64