# ViGWAS
### This notebook is written using Hail v0.2 and VariantSpark.

# User Block

In [None]:
## read from sample annotations and vcf files
# Delimiter of the sample annotation file(s)
sample_annot_deli = ',' #'\t' for tsv and ',' for csv

# Base S3 location: the example inputs below live under <S3_Path>sample_input/,
# and the results are copied to <S3_Path><analysis_name>/ at the end of the notebook.
#S3_Path='s3://YOUR_BUCKET_OR_PATH/ViGWAS/'
S3_Path='s3://csiro-tb/notebooks/Hail2/ViGWAS/'

# Accept multiple sample annotation files
sample_annot_file_list = [S3_Path+'sample_input/hipster.csv']

# support 'vcf' or 'plink'
mt_file_type = 'vcf'

# 'variant' for different variants in different files (files are row-unioned),
# 'sample' for same variants in different files (files are column-unioned)
mt_merge_type = 'sample' 

if (mt_merge_type=='sample'):
    mt_file_list = [S3_Path+'sample_input/S1.vcf.bgz', S3_Path+'sample_input/S2.vcf.bgz']
    # The name of result directory in your S3_Path
    analysis_name = 'my_analysis_s'
else:
    mt_file_list = [S3_Path+'sample_input/V1.vcf.bgz', S3_Path+'sample_input/V2.vcf.bgz']
    # The name of result directory in your S3_Path
    analysis_name = 'my_analysis_v'
    
## For plink files (bed bim fam), provide prefix only
#mt_file_list = ['paths/to/plink-1', 'paths/to/plink-2']



# QC histogram bin ratio: bins = int(#distinct values * downsample_percent) + 1
# (despite the name, no random sampling of variants is performed)
downsample_percent = 0.01 # ratio used to size QC histogram bins
graph_type = 'stack' # or 'group': representation when 2 variables are plotted at the same time
fields_to_plot = ['isFemale', 'Population', 'isCase'] # list of fields in sample annotations for QC plotting

n_factor = 4 # number of principal components (k) computed by hl.hwe_normalized_pca

## variant-spark
# Random-forest hyper-parameters passed to VariantSpark
mtry_fraction=0.1
num_of_tree=100
batch_size=25
min_node_size=50
max_depth=10

## Some configs
numCPU = 256 # import partitions are numCPU * 4
memory = '60g' # NOTE(review): not referenced elsewhere in this notebook - confirm whether it is still needed

# Environment initialization

In [None]:
## Environment init
# Create a SparkContext and hand it to VariantSpark's Hail integration.

import os
from pyspark import SparkContext
sc = SparkContext()

import hail as hl
import varspark.hail as vshl
# presumably initialises Hail on top of the SparkContext above - confirm against VariantSpark docs
vshl.init(sc=sc)

from bokeh.io import output_notebook, show
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, FactorRange, LabelSet, Label
from bokeh.transform import factor_cmap
from bokeh.palettes import d3
from bokeh.core.properties import value
from bokeh.embed import file_html
from bokeh.resources import CDN
from bokeh.layouts import gridplot
from bokeh.models.mappers import CategoricalColorMapper

from pprint import pprint
#output_notebook()

import re
import numpy as np
import math as math
import sys
import operator
from collections import OrderedDict
import subprocess
from itertools import cycle
import shutil

In [None]:
%%sh
# Install git, then replace any existing ~/VIGWAS checkout with a fresh clone.
# NOTE(review): the following cell chdirs to /home/hadoop/VIGWAS/, i.e. it assumes ~ is /home/hadoop.
sudo yum install -y git
cd ~
rm -rf VIGWAS
git clone https://github.com/aehrc/VIGWAS.git
echo "=========="
ls -l

In [None]:
# Work inside the cloned repository; results go to <repo>/<analysis_name with whitespace -> '_'>.
os.chdir('/home/hadoop/VIGWAS/')
cwd = os.getcwd()
result_dir = cwd + '/' + re.sub(r'\s', '_', analysis_name)
print(result_dir)
# Minimum number of partitions used when importing genotype files: four per CPU.
numPartitions = 4 * numCPU

## Some general functions

In [None]:
## Map the top-level row fields of a table to their type names.
## INPUT: table
## OUTPUT: dict {field name (str): type name (str)}
def get_table_fields(t):
    """Return ``{field name: type string}`` for every row field of ``t``.

    The type string is ``str(dtype)`` with any occurrence of the substring
    ``'dtype'`` removed. Field order is preserved.
    """
    row_type = t.row.dtype
    names = row_type.keys()
    types = row_type.values()
    return {name: str(typ).replace('dtype', '') for name, typ in zip(names, types)}

In [None]:
## save plot in a html file
## and write a plot division with proper headings in html
## INPUT: bokeh figure object p, filename_to_save(str), folder_to_save(str), id for the division in html(str)
## OUTPUT: division html(str)
def save_plot(p, title, folder, id_link):
    """Save figure ``p`` as ``plots/<folder>/<title>.html`` and return an HTML snippet embedding it.

    INPUT:  p - bokeh figure (its ``plot_height`` sizes the embedding <object>)
            title - plot title, also used as the file name
            folder - sub-directory of ``plots/`` (relative to the current directory)
            id_link - ``id`` attribute of the snippet's <h3> heading
    OUTPUT: str - a <div> with an <h3> heading and an <object> loading the saved
            file through the relative path ``../plots/<folder>/<title>.html``
    """
    html = file_html(p, CDN, title)
    filepath = 'plots/' + folder + '/' + title + '.html'
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    # The context manager closes the handle even if the write fails.
    with open(filepath, "w+") as f:
        f.write(html)
    plot_html = "<div>\n<h3 id='" + id_link + "' class='w3-text-teal'>" + title + "</h3>\n"
    plot_html += '<object width="100%" height="' + str(p.plot_height + 100) + '" data="../' + filepath + '"></object>\n'
    plot_html += "</div>\n<br><br>\n"
    return plot_html

In [None]:
## given a float/int field in rows/cols table
## generate a histogram of the field with specified downsample percent
## INPUT: matrixtable, field(expression), field_name(str), downsample_percent(float), analyse_name
## OUTPUT: html division (str)


def _hist_stats_unplottable(stats):
    """Return True when ``stats`` cannot provide finite histogram bounds (missing or NaN min/max)."""
    import math
    if stats is None:
        return True
    for bound in (stats.min, stats.max):
        if bound is None:
            return True
        if isinstance(bound, float) and math.isnan(bound):
            return True
    return False


def _no_histogram_html(id_link, title):
    """HTML division explaining that no histogram could be drawn for this field."""
    plot_html = "<div>\n<h3 id='" + id_link + "' class='w3-text-teal'>" + title + "</h3>\n"
    plot_html = plot_html + '<p> divisor equals 0 exist. Unable to generate a histogram with NA values </p><br>'
    plot_html = plot_html + "</div>\n<br><br>\n"
    return plot_html


def get_hist_graph(mt, field, field_name, downsample_percent, isVariant, analyse_name):
    """Histogram of a numeric row (isVariant=True) or column (isVariant=False) field.

    The number of bins is ``int(#distinct values * downsample_percent) + 1`` over
    the range ``[min - 1, max + 1]``. The plot is saved with save_plot and its
    HTML snippet is returned. If the field's stats have no usable min/max, a
    short explanatory division is returned instead.
    """
    title = field_name + ' Histogram'
    id_link = analyse_name + '-' + field_name
    # Row and column fields are aggregated identically; only the entry point differs.
    aggregate = mt.aggregate_rows if isVariant else mt.aggregate_cols

    stats = aggregate(hl.expr.aggregators.stats(field))
    # Hail's stats aggregator returns a struct (not None); presumably its min/max
    # are missing or NaN when the field has no usable values (e.g. ratios that
    # divided by zero) - the old `stats == None` check never caught that and
    # `stats.min - 1` then failed.
    if _hist_stats_unplottable(stats):
        return _no_histogram_html(id_link, title)

    unique_count = len(aggregate(hl.expr.aggregators.counter(field)))
    n_bins = int(unique_count * downsample_percent) + 1
    hist = aggregate(hl.expr.aggregators.hist(field, stats.min - 1, stats.max + 1, n_bins))

    p = hl.plot.histogram(hist, legend=field_name, title=title)
    print(title)
    return save_plot(p, title, analyse_name, id_link)

In [None]:
### get unique value counts of given field
def get_unique_values(mt, field, isVariant):
    """Return ``{value: count}`` for ``field``.

    The aggregation runs over rows when isVariant is True and over columns otherwise.
    """
    counter_agg = hl.expr.aggregators.counter(field)
    if not isVariant:
        return mt.aggregate_cols(counter_agg)
    return mt.aggregate_rows(counter_agg)



In [None]:
## given a matrixtable and a row/col field
## find if the field is categorical or not
def isCat(mt, field, isVariant):
    """Treat ``field`` as categorical when it has at most 10 distinct values."""
    n_levels = len(get_unique_values(mt, field, isVariant))
    return n_levels <= 10

In [None]:
## given a matrix table and a field
## produce a bar graph of the field
## INPUT: matrixtable, field<expression>, field_name<str>, isVariant<boolean>, analyse_name<str>
## OUTPUT: html division of the saved bar graph
def get_bar_graph(mt, field, field_name, isVariant, analyse_name):
    """Frequency bar graph (one bar per distinct value) of a row/column field, saved via save_plot."""
    title = field_name + ' bar graph'
    counts = get_unique_values(mt, field, isVariant)
    categories = [str(key) for key in counts.keys()]
    heights = list(counts.values())

    source = ColumnDataSource(data=dict(x_names=categories, tops=heights))
    p = figure(title=title, x_range=categories, x_axis_label=field_name, y_axis_label='Frequency',
               background_fill_color='#EEEEEE', tooltips="@x_names: @tops", plot_height=600)
    p.vbar(x='x_names', width=0.5, top='tops', bottom=0, source=source)
    return save_plot(p, title, analyse_name, analyse_name + '-' + field_name)

In [None]:
## for get_bar_graph_A_by_B
## INPUT: A_values(list), filtered matrix table, selected filtered fieldA(expr), cols/rows(boolean)
## OUTPUT: list of counts aligned with A_values (0 where a value does not occur)
def get_top_values(A_values, mt_filtered, field_filtered, isVariant):
    """Counts of each entry of ``A_values`` in the filtered table, in ``A_values`` order."""
    counts = get_unique_values(mt_filtered, field_filtered, isVariant)
    return [counts.get(a, 0) for a in A_values]

In [None]:
## for get_bar_graph_A_by_B
## map data from multi series into one tuple orderly
## for grouped bar graph input
## INPUT:  counts of each combo keyed by series (dict of equal-length lists)
## OUTPUT: A tuple of counts ordered by x (combo key variable)
def map_counts(tops):
    """Interleave per-series counts into one flat tuple ordered by x position.

    For ``{'b1': [a1b1, a2b1], 'b2': [a1b2, a2b2]}`` this returns
    ``(a1b1, a1b2, a2b1, a2b2)``: for each index, the count of every series in
    dict key order. That matches the ``[(A, B) for A in ... for B in ...]``
    factor order used for grouped bars. The length of the first series sets the
    number of positions.

    Returns ``()`` for an empty dict (previously an IndexError).
    """
    if not tops:
        return ()
    series = list(tops.values())
    n_positions = len(series[0])
    return tuple(counts[i] for i in range(n_positions) for counts in series)

In [None]:
## get A grouped by B
## INPUT: matrixtable, field_names(list<str>) [A,B](full name after mt), graph_type(str) :group or stack, analyse_name(str)
## OUTPUT: html division of a bar graph of A by B
def get_bar_graph_A_by_B(mt, field_names, graph_type, analyse_name):
    """Bar graph of column field A broken down by column field B.

    INPUT:  field_names - [A, B] as 'struct.field' names (e.g. 'pheno.Population')
            graph_type - 'group' (side-by-side bars) or 'stack' (stacked bars)
    OUTPUT: html snippet returned by save_plot
    RAISES: ValueError when graph_type is neither 'group' nor 'stack'
    """
    # Fail fast: an unknown graph_type used to run every per-value filter and
    # then crash with UnboundLocalError on `p`.
    if graph_type not in ('group', 'stack'):
        raise ValueError("graph_type must be 'group' or 'stack', got %r" % (graph_type,))
    palette = d3['Category10'][10]

    name_indexes = [field_name.split('.') for field_name in field_names]
    a_struct, a_name = name_indexes[0][0], name_indexes[0][1]
    b_struct, b_name = name_indexes[1][0], name_indexes[1][1]
    a_values = list(get_unique_values(mt, mt[a_struct][a_name], False).keys())
    b_values = list(get_unique_values(mt, mt[b_struct][b_name], False).keys())
    title = a_name + ' by ' + b_name
    id_link = analyse_name + '-' + a_name + 'by' + b_name

    # For every B value: counts of each A value among the samples having that B value.
    tops = {}
    for v in b_values:
        mt_filtered = mt.filter_cols(mt[b_struct][b_name] == v, keep=True)
        field_filtered = mt_filtered[a_struct][a_name]
        tops[str(v)] = list(get_top_values(a_values, mt_filtered, field_filtered, False))

    # Bokeh categorical axes need string factors.
    a_values = [str(x) for x in a_values]
    b_values = [str(x) for x in b_values]

    if graph_type == 'group':
        x = [(a, b) for a in a_values for b in b_values]
        counts = map_counts(tops)
        labels = [b for a in a_values for b in b_values]
        source = ColumnDataSource(data=dict(x=x, counts=counts, labels=labels))
        p = figure(x_range=FactorRange(*x), plot_height=600, title=title, tooltips="@x: @counts")
        p.vbar(x='x', top='counts', width=0.9, source=source,
               fill_color=factor_cmap('x', palette=palette, factors=b_values, start=1, end=2),
               legend='labels')
    else:  # 'stack'
        data = {'A_values': a_values}
        data.update(tops)
        source = ColumnDataSource(data)
        # Repeat the palette when there are more than 10 B categories; slicing it
        # left vbar_stack with fewer colours than stacks.
        colors = [palette[i % len(palette)] for i in range(len(b_values))]
        print(a_values)
        p = figure(x_range=a_values, plot_height=600, title=title,
                   toolbar_location='right', tooltips="$name @A_values: @$name")
        stackers = list(tops.keys())
        p.vbar_stack(stackers, x='A_values', width=0.9, color=colors, source=source,
                     legend=[value(s) for s in stackers])

    p.y_range.start = 0
    p.xgrid.grid_line_color = None
    p.axis.minor_tick_line_color = None
    p.outline_line_color = None
    p.legend.location = "top_right"
    p.legend.orientation = "horizontal"
    p.xaxis.major_label_orientation = 1

    return save_plot(p, title, analyse_name, id_link)

In [None]:
def create_side_bar(field_list, analyse_name):
    """Build the w3.css sidebar <nav> plus the small-screen overlay <div>.

    Each field gets an entry linking to '#<analyse_name>-<field>' labelled
    '<field> Distributions'.
    """
    header = ('<nav class="w3-sidebar w3-bar-block w3-collapse w3-large w3-theme-l5 w3-animate-left" id="mySidebar">'
              '  <a href="javascript:void(0)" onclick="w3_close()" class="w3-right w3-xlarge w3-padding-large w3-hover-black w3-hide-large" title="Close Menu">'
              '    <i class="fa fa-remove"></i>'
              '  </a>'
              '  <h4 class="w3-bar-item"><b>' + analyse_name + '</b></h4>\n')
    links = ''.join('<a class="w3-bar-item w3-button" href="#' + analyse_name + '-' + f + '">' + f + ' Distributions</a>\n'
                    for f in field_list)
    footer = ('</nav>'
              '    <!-- Overlay effect when opening sidebar on small screens -->'
              '<div class="w3-overlay w3-hide-large" onclick="w3_close()" style="cursor:pointer" title="close side menu" id="myOverlay"></div>')
    return header + links + footer

# Read Files

## Load Functions

In [None]:
## read in sample annotation files 
## and check compulsory fields and field values
## INPUT: path-to-file (or list of paths), delimiter
## OUTPUT: table t
def read_sample_annot(file, deli):
    """Import sample annotations and validate the expected fields.

    INPUT:  file - path or list of paths passed to ``hl.import_table``
            deli - field delimiter (e.g. ',' or a tab)
    OUTPUT: table keyed by ``Sample`` with an added ``CaseControl`` field
            ('Case'/'Control') and, when ``isFemale`` is present, ``Gender``
            ('Female'/'Male')

    Calls sys.exit when ``Sample`` or ``isCase`` is missing, or when
    ``isCase``/``isFemale`` is not boolean.
    """
    t = hl.import_table(file, impute=True, delimiter=deli)
    col_dict = get_table_fields(t)

    # mandatory fields
    if 'Sample' not in col_dict:
        sys.exit('Sample does not exist!')
    t = t.key_by('Sample')

    if 'isCase' not in col_dict:
        sys.exit('isCase does not exist!')
    if col_dict['isCase'] != 'bool':
        sys.exit('isCase is not bool')
    t = t.annotate(CaseControl=hl.cond(t.isCase, 'Case', 'Control'))

    # Optional fields: validated only when present. isFemale used to be
    # mandatory here, although it is documented as optional and
    # logistic_regression already checks whether it exists.
    if 'isFemale' in col_dict:
        if col_dict['isFemale'] != 'bool':
            sys.exit('isFemale is not bool!')
        t = t.annotate(Gender=hl.cond(t.isFemale, 'Female', 'Male'))

    return t

In [None]:
def read_metadata(metadata, merge_type, file_type):
    """Import genotype files, checkpoint each as a MatrixTable, and merge them.

    INPUT:  metadata - list of VCF paths, or PLINK prefixes (without extension)
            merge_type - 'variant' (union rows: different variants per file) or
                         'sample' (union cols: same variants, different samples)
            file_type - 'vcf' or 'plink'
    OUTPUT: (merged MatrixTable, list of per-file MatrixTables)

    Uses the module-level ``numPartitions`` as ``min_partitions`` on import.
    Calls sys.exit on an empty file list or an unsupported file/merge type.
    """
    if not metadata:
        sys.exit('No metadata files provided!')
    if merge_type not in ('variant', 'sample'):
        sys.exit('Merge type not provided or supported!')

    mts = []
    if file_type == 'plink':
        for path in metadata:
            print('plink file: ' + path)
            mt_name = path + '.mt'
            hl.import_plink(bed=path + '.bed', bim=path + '.bim', fam=path + '.fam',
                            skip_invalid_loci=True,
                            min_partitions=int(numPartitions)).write(mt_name, overwrite=True)
            mts.append(hl.read_matrix_table(mt_name))
    elif file_type == 'vcf':
        for path in metadata:
            print('vcf file: ' + path)
            # '<name>.vcf[.bgz|.gz]' -> '<name>.mt'. The old pattern only matched
            # '.vcf.bgz'; for any other name the output path equalled the input
            # path and overwrite=True would clobber the source VCF.
            mt_name = re.sub(r'\.vcf(\.b?gz)?$', '.mt', path)
            if mt_name == path:
                mt_name = path + '.mt'
            hl.import_vcf(path=path, skip_invalid_loci=True,
                          min_partitions=int(numPartitions)).write(mt_name, overwrite=True)
            mts.append(hl.read_matrix_table(mt_name))
    else:
        sys.exit('Metadata file type not provided or supported!')

    ## merge files
    if merge_type == 'variant':
        mt = hl.MatrixTable.union_rows(*mts)
    else:  # 'sample'
        mt = mts[0]
        for cur_mt in mts[1:]:
            mt = mt.union_cols(cur_mt)

    return (mt, mts)
    

In [None]:
## combine sample annotation table to the matrix table
## INPUT: t(table) ,mt(MatrixTable)
## OUTPUT: matrixtable with t joined to the cols
def join_rows_cols(t, mt):
    """Attach sample annotations to the columns of ``mt`` and run standard QC.

    INPUT:  t - annotation table keyed by a single sample-ID field
            mt - MatrixTable whose column key is ``s``
    OUTPUT: MatrixTable with a ``pheno`` column struct, multi-allelic sites
            split (split_multi_hts), ``variant_qc``/``sample_qc`` computed and
            duplicate rows/columns removed

    Calls sys.exit when a sample of ``mt`` has no row in ``t``.
    """
    # Extract the quoted key value from each collected struct's string form
    # (assumes a repr like "Struct(Sample='HG001')" - TODO confirm for IDs containing quotes).
    key_value = re.compile('^(.*=)\'(.*)\'(.*)$')
    annot_ids = {key_value.sub('\\2', str(w)) for w in t.select().collect()}
    mt_ids = [key_value.sub('\\2', str(w)) for w in mt.cols().key_by('s').select().collect()]

    # Set membership keeps the subset check linear (was O(n*m) list scans).
    missing = [sid for sid in mt_ids if sid not in annot_ids]
    if missing:
        print('Samples without annotations (first 10): ' + str(missing[:10]))
        sys.exit('Sample annotations missing!')

    mt = mt.annotate_cols(pheno=t[mt.s])
    # break multiallelic
    mt = hl.split_multi_hts(mt)
    mt = hl.variant_qc(mt)
    mt = hl.sample_qc(mt)
    # remove duplicates
    mt = mt.distinct_by_row()
    mt = mt.distinct_by_col()
    return mt
        

In [None]:
def print_init_summary_html(sample_annot, metadata, mt, mts):
    """Write ``htmls/summary.html`` summarising the inputs and the joined matrix table.

    INPUT:  sample_annot - annotation file paths
            metadata - genotype file paths
            mt - joined MatrixTable
            mts - per-file MatrixTables, in the same order as ``metadata``
    The "Call rate" line is (#variants in mt) / (#variants in that file), or
    'NA' when that file has no variants.
    """
    # Each count is a full Spark job: count the joined table once, not once per
    # input file plus twice more at the end.
    total_rows = mt.count_rows()
    total_cols = mt.count_cols()

    parts = ['<p> <b>Annotation File: </b>', '<ul style="list-style-type: disc;">\n']
    parts += ['<li>' + path + '</li>\n' for path in sample_annot]
    parts += ['</ul>\n', '<b>VCF file: </b>', '<ul style="list-style-type: disc;">\n']
    parts += ['<li>' + path + '</li>\n' for path in metadata]
    parts.append('</ul>\n')

    for path, part_mt in zip(metadata, mts):
        part_rows = part_mt.count_rows()
        call_rate = str(total_rows / part_rows) if part_rows else 'NA'
        parts.append('<b>metadata filename: </b>' + path + '<br>\n')
        parts.append('<b># samples: </b>' + str(part_mt.count_cols()) + '<br>\n')
        parts.append('<b># variants: </b>' + str(part_rows) + '<br>\n')
        parts.append('<b>Call rate: </b>' + call_rate + '<br>')
    parts.append('<b><i>After joining sample annotations and vcf files....</i></b><br>')
    parts.append('<b>Total # of Sample analysed: </b>' + str(total_cols) + '<br>')
    parts.append('<b>Total # of Variant analysed: </b>' + str(total_rows) + '<br>')
    parts.append('</p>')

    filepath = "htmls/summary.html"
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    with open(filepath, "w+") as f:
        f.write(''.join(parts))


In [None]:
def print_saved_summary_html(mt, mt_name):
    """Write ``htmls/summary.html`` for a MatrixTable reloaded from ``mt_name``.

    The summary lists the source path and the sample/variant counts of ``mt``.
    """
    summary_html = ('<p> '
                    + '<b>MatrixTable File: </b>' + mt_name + '<br>'
                    + '<b># of Sample analysed: </b>' + str(mt.count_cols()) + '<br>'
                    + '<b># of Variants analysed: </b>' + str(mt.count_rows()) + '<br>'
                    + '</p>')

    filepath = "htmls/summary.html"
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    # Context manager closes the file even if the write fails.
    with open(filepath, "w+") as f:
        f.write(summary_html)


In [None]:

if sample_annot_file_list and mt_file_list:
    # Option 1: build the matrix table from annotation + genotype files.
    t = read_sample_annot(sample_annot_file_list, sample_annot_deli)
    mt, mts = read_metadata(mt_file_list, mt_merge_type, mt_file_type)
    mt = join_rows_cols(t, mt)
    # create result directory and change directory
    shutil.copytree('templates', result_dir)
    os.chdir(result_dir)
    print_init_summary_html(sample_annot_file_list, mt_file_list, mt, mts)
elif 'mt_name' in globals():
    # Option 2: reload a previously saved matrix table. The old test
    # `mt_name in globals()` evaluated the (usually undefined) variable itself,
    # raising NameError, instead of checking whether the name exists.
    mt = hl.read_matrix_table(mt_name)
    shutil.copytree('templates', result_dir)
    os.chdir(result_dir)
    print_saved_summary_html(mt, mt_name)
else:
    sys.exit('Provide sample annotation and genotype files, or set mt_name to a saved MatrixTable.')
    

## Input file information and file reading (2 options)

### read from sample annotation files and vcf files
#### sample annotations:
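- delimited text (comma or tab; set `sample_annot_deli` accordingly)
- a header row is required
- compulsory fields: `Sample` (sample ID, matching the IDs in the genotype files) and `isCase` (bool)
- `isFemale` (bool) is used for the Gender plots and as a covariate in the logistic regression
- other fields (e.g. Population, SuperPopulation) are analysed when listed in `fields_to_plot`
- field names in the header must match the names above exactly (case-sensitive)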
#### vcf:
- standard VCF format; name the files `*.vcf.bgz` (block-gzipped) so that the intermediate MatrixTable path (`*.mt`) is derived correctly

# Sample Annotation Data Analysis

## Load Functions

In [None]:
def get_sample_annot_info(mt, downsample_percent, graph_type, preset_fields):
    """Plot the preset ``pheno`` fields and return the sample-annotation page HTML.

    For every ``pheno.<field>`` whose name is in ``preset_fields``:
    numeric (int/float) fields with more than 10 distinct values get a
    histogram, and everything else gets a bar graph. Then, for each processed
    field A and each *categorical* processed field B != A, an "A by B" bar graph
    is appended to A's section.

    INPUT:  mt - MatrixTable with a ``pheno`` column struct
            downsample_percent - histogram bin ratio passed to get_hist_graph
            graph_type - 'group' or 'stack' for the A-by-B graphs
            preset_fields - field names (without the 'pheno.' prefix) to analyse
    OUTPUT: str - sidebar + main-content HTML
    """
    analyse_name = 'Sample Annotations'
    isVariant = False
    plot_html = {}
    for f in preset_fields:
        plot_html[f] = '<div>\
        <h2 id="' + f +'" class="w3-text-teal">' + f + '</h2>\n'
    col_dict = get_table_fields(mt.cols().select('pheno').key_by('pheno').select().flatten())

    cat_fields = {}
    fns = {}
    for k, dtype in col_dict.items():
        field_names = k.split('.')
        print(field_names)
        if len(field_names) != 2:
            print('error too many field depth')
            break
        field = mt[field_names[0]][field_names[1]]
        field_name = field_names[1]
        fns[k] = field_name
        if field_name not in preset_fields:
            print('Found non-preset fields: ' + field_name)
            continue

        # isCat runs a Hail aggregation; compute it once and reuse it.
        is_categorical = isCat(mt, field, isVariant)
        cat_fields[k] = is_categorical
        is_numeric = 'int' in dtype or 'float' in dtype
        # The old condition `int or (float and not categorical)` (`and` binds
        # tighter than `or`) sent categorical int fields to the histogram branch.
        if is_numeric and not is_categorical:
            plot_html[field_name] = plot_html[field_name] + get_hist_graph(mt, field, field_name, downsample_percent, isVariant, analyse_name)
        else:
            # categorical values and non-numeric types get bar graphs
            plot_html[field_name] = plot_html[field_name] + get_bar_graph(mt, field, field_name, isVariant, analyse_name)

    # combined graphs: each processed field broken down by every other categorical one
    for f in cat_fields.keys():
        for f2 in cat_fields.keys():
            if cat_fields[f2] and f != f2:
                print('f:' + f + ' f2: ' + f2)
                plot_html[fns[f]] = plot_html[fns[f]] + get_bar_graph_A_by_B(mt, [f, f2], graph_type, analyse_name)

    plot_html_all = create_side_bar(preset_fields, analyse_name)
    plot_html_all = plot_html_all + '<!-- Main content: shift it to the right by 250 pixels when the sidebar is visible -->\
<div class="w3-main" style="margin-left:250px">\
  <div class="w3-row w3-padding-64">\
    <div class="w3-container">\
      <h1 id="sa" class="w3-text-teal">Analysis of Sample annotations data</h1>\
      <p>some description if you want</p>\n'

    for s in preset_fields:
        plot_html_all = plot_html_all + plot_html[s] + '</div>'
    plot_html_all = plot_html_all + '</div>\n</div>\n</div>\n'
    return plot_html_all


## Conduct Analysis

In [None]:
# Plot the derived label fields instead of the raw booleans:
# isCase -> CaseControl, isFemale -> Gender (each appended to the end of the list).
for raw_name, derived_name in (('isCase', 'CaseControl'), ('isFemale', 'Gender')):
    if raw_name in fields_to_plot:
        fields_to_plot.remove(raw_name)
        fields_to_plot.append(derived_name)

In [None]:
plot_html = get_sample_annot_info(mt, downsample_percent, graph_type, fields_to_plot)

# write into html file
filepath = "htmls/sample_annot.html"
with open(filepath, "w+") as f:
    f.write(plot_html)

# QC analysis

## Load functions

In [None]:
## given a matrixTable
## plot all plots for each field in sample_qc and variant_qc
## INPUT: matrixTable, isVaraiant(boolean), downsample_percent, analyse_name
## OUTPUT: html page (str)
def get_qc_info(mt, isVariant, downsample_percent, analyse_name):
    """Histogram every numeric top-level field of ``variant_qc`` (rows) or ``sample_qc`` (cols).

    Array-valued numeric fields are plotted as two histograms: element 0
    ('_ref_allele') and element 1 ('_alt_allele'). Fields of nested sub-structs
    are skipped.
    Returns sidebar + main-content HTML.
    """
    plot_html = '<!-- Main content: shift it to the right by 250 pixels when the sidebar is visible -->\
<div class="w3-main" style="margin-left:250px">\
  <div class="w3-row w3-padding-64">\
    <div class="w3-container">\
      <h1 id="sa" class="w3-text-teal">Per '+ analyse_name +' analysis</h1>\
      <p>some description if you want</p>\n'
    if isVariant:
        col_dict = get_table_fields(mt.rows().select('variant_qc').key_by('variant_qc').select().flatten())
    else:
        col_dict = get_table_fields(mt.cols().select('sample_qc').key_by('sample_qc').select().flatten())

    sidebar_field_list = []
    for k, dtype in col_dict.items():
        field_names = k.split('.')
        if len(field_names) == 3:
            # nested sub-struct fields (qc.sub_struct.field) are not plotted
            continue
        if len(field_names) != 2:
            print('error too many field depth')
            break
        if not ('int' in dtype or 'float' in dtype):
            # only numeric fields get histograms
            continue
        field = mt[field_names[0]][field_names[1]]
        field_name = field_names[1]

        if dtype.startswith('array'):
            # Per-allele arrays: plot element 0 (ref) and element 1 (alt).
            # Assumes bi-allelic rows (e.g. after split_multi_hts) - TODO confirm for other inputs.
            for i, suffix in enumerate(('_ref_allele', '_alt_allele')):
                field_name_x = field_name + suffix
                sidebar_field_list.append(field_name_x)
                plot_html = plot_html + get_hist_graph(mt, field[i], field_name_x, downsample_percent, isVariant, analyse_name)
        else:
            print(field_name)
            sidebar_field_list.append(field_name)
            plot_html = plot_html + get_hist_graph(mt, field, field_name, downsample_percent, isVariant, analyse_name)

    sidebar_html = create_side_bar(sidebar_field_list, analyse_name)
    # Three divs (w3-main, w3-row, w3-container) are opened above; close all three.
    plot_html = sidebar_html + plot_html + '</div>\n</div>\n</div>\n'

    return plot_html

## Conduct variant QC analysis

In [None]:
# Per-variant QC histograms -> htmls/variant_qc.html
plot_html = get_qc_info(mt, True, downsample_percent, 'Variant QC')
filepath = "htmls/variant_qc.html"
with open(filepath, "w+") as f:
    f.write(plot_html)

## Conduct sample QC analysis

In [None]:
# Per-sample QC histograms -> htmls/sample_qc.html. A bin ratio of 1 is used
# here (one bin per distinct value, plus one) instead of downsample_percent.
plot_html = get_qc_info(mt, False, 1, 'Sample QC')
filepath = "htmls/sample_qc.html"
with open(filepath, "w+") as f:
    f.write(plot_html)

# Principal Component Analysis (PCA)

## Load functions

In [None]:
def get_PCA_scatter(x, y, label, title=None, xlabel=None, ylabel=None, collect_all=False, n_divisions=500, size=4):
    """Scatter plot of two PC expressions coloured by a categorical label expression.

    With collect_all=True every point is collected. Otherwise the points are
    downsampled with ``hl.agg.downsample`` on an ``n_divisions`` grid.
    Returns the bokeh figure (not saved).
    """
    palette = d3['Category10'][10]

    # enlist data from expression
    if collect_all:
        res = hl.tuple([x, y, label]).collect()
        labels = [point[2] for point in res]
    else:
        agg_f = x._aggregation_method()
        res = agg_f(hl.agg.downsample(x, y, label=label, n_divisions=n_divisions))
        labels = [point[2][0] for point in res]
    xs = [point[0] for point in res]
    ys = [point[1] for point in res]

    p = figure(title=title, x_axis_label=xlabel, y_axis_label=ylabel, tooltips='@label', background_fill_color='#EEEEEE')
    # Sorted factors give each category the same colour on every run; set
    # iteration order depends on string hash seeding. key=str tolerates None.
    factors = sorted(set(labels), key=str)
    source = ColumnDataSource(dict(x=xs, y=ys, label=labels))
    # Cycle through the palette when there are more categories than colours.
    colors = [palette[i % len(palette)] for i in range(len(factors))]

    color_mapper = CategoricalColorMapper(factors=factors, palette=colors)
    p.circle('x', 'y', alpha=0.5, source=source, size=size,
             color={'field': 'label', 'transform': color_mapper}, legend='label')
    return p

In [None]:
## Principal Component Analysis
## INPUT: mt(matrixTable), number of factors(int), list of pheno field names
## OUTPUT: annotated mt, plot_html
def get_PCA_graph(mt, n_factor, field_list):
    """Run HWE-normalised PCA and plot the sample scores.

    Plots every pair of the first ``n_factor`` PCs coloured by CaseControl, then
    PC2 vs PC1 coloured by each ``pheno`` field named in ``field_list``.
    OUTPUT: (mt with column field ``pcs`` holding the PC scores, page HTML)
    """
    from itertools import combinations

    analyse_name = 'PCA'
    plot_html = create_side_bar(['CaseControl'] + field_list, analyse_name)
    plot_html = plot_html + '<div class="w3-main" style="margin-left:250px">\
  <div class="w3-row w3-padding-64">\
    <div class="w3-container">\
      <h1 id="vqc" class="w3-text-teal">PCA Analysis</h1>\
      <p>PCA analysis of factor '+ str(n_factor) + ' </p>'
    eigenvalues, pcs, loadings = hl.hwe_normalized_pca(mt.GT, k=n_factor)
    print('Done PCA analysis')

    mt = mt.annotate_cols(pcs=pcs[mt.s].scores)

    # every pair of PCs, coloured by case/control status
    for i, j in combinations(range(n_factor), 2):
        x = 'PC' + str(i + 1)
        y = 'PC' + str(j + 1)
        print('Starting plotting ' + x + ' vs ' + y)
        title = 'CaseControl PCA - ' + y + ' vs ' + x
        # the first plot carries the sidebar anchor for CaseControl
        if i == 0 and j == 1:
            id_link = analyse_name + '-' + 'CaseControl'
        else:
            id_link = 'pca-cc-' + str(i) + str(j)
        p = get_PCA_scatter(mt.pcs[i], mt.pcs[j], label=mt.pheno.CaseControl,
                            title=title, xlabel=x, ylabel=y)
        plot_html = plot_html + save_plot(p, title, analyse_name, id_link)

    # PC2 vs PC1 for each other requested pheno field (needs at least two PCs)
    if n_factor >= 2:
        col_dict = get_table_fields(mt.cols().select('pheno').key_by('pheno').select().flatten())
        for f in col_dict:
            field_names = f.split('.')
            if len(field_names) != 2:
                # previously field_name stayed unset (or stale from the previous
                # field) here, so a wrong field could be plotted or a NameError raised
                continue
            field_name = field_names[1]
            if field_name not in field_list:
                continue
            title = field_name + ' PCA - PC2 vs PC1'
            id_link = analyse_name + '-' + field_name
            p = get_PCA_scatter(mt.pcs[0], mt.pcs[1], label=mt[field_names[0]][field_names[1]],
                                title=title, xlabel='PC1', ylabel='PC2')
            plot_html = plot_html + save_plot(p, title, analyse_name, id_link)
    # Three divs (w3-main, w3-row, w3-container) were opened above; close all three.
    plot_html = plot_html + '</div>\n</div>\n</div>'

    return (mt, plot_html)

## Conduct analysis

In [None]:
# CaseControl already gets dedicated PCA plots, so drop it from the generic field list.
try:
    fields_to_plot.remove('CaseControl')
except ValueError:
    pass

In [None]:
# PCA (annotates mt with `pcs`) -> htmls/pca.html
mt, plot_html = get_PCA_graph(mt, n_factor, fields_to_plot)
filepath = "htmls/pca.html"
with open(filepath, "w+") as f:
    f.write(plot_html)
    

# Logistic Regression and Manhattan Plot (must be run after PCA)

## Load functions

In [None]:
def get_manhattan_graph(mt, method, n_divisions=500):
    """Manhattan plot of the p-values in row field ``logreg_<method>``, saved via save_plot.

    Hovering shows rsid, locus and p-value.
    """
    pvals = mt['logreg_' + method]['p_value']
    title = 'Manhattan plot (' + method + ')'
    hover = {'rsid': mt.rsid, 'locus': mt.locus, 'p_value': pvals}
    p = hl.plot.manhattan(pvals=pvals, hover_fields=hover, title=title, n_divisions=n_divisions)
    return save_plot(p, title, 'Logistic Regression', 'Logistic Regression-' + method)

In [None]:
def get_qq_graph(mt, method, n_divisions=500):
    """Q-Q plot of the p-values in row field ``logreg_<method>``, saved via save_plot."""
    pvals = mt['logreg_' + method]['p_value']
    p = hl.plot.qq(pvals=pvals, n_divisions=n_divisions)
    return save_plot(p, 'Q-Q plot (' + method + ')', 'Logistic Regression', 'lgqq-' + method)

In [None]:
def logistic_regression(mt, method, n_factor):
    """Per-variant logistic regression of ``pheno.isCase`` on the alt-allele count.

    Covariates: an intercept, the first ``n_factor`` PC scores (``mt.pcs``,
    added by get_PCA_graph) and, when present, ``pheno.isFemale``.
    The result struct is stored as row field ``logreg_<method>``, replacing any
    earlier field of that name.
    INPUT:  method - Hail test name, e.g. 'wald', 'score' or 'lrt'
    """
    # Hail does not add an intercept to its regression models; it must be
    # supplied explicitly as the constant covariate 1.0.
    covariates = [1.0] + [mt.pcs[i] for i in range(n_factor)]
    if 'pheno.isFemale' in get_table_fields(mt.cols().select('pheno').flatten()):
        covariates.append(mt.pheno.isFemale)

    result = hl.logistic_regression_rows(test=method,
                                         y=mt.pheno.isCase,
                                         x=mt.GT.n_alt_alleles(),
                                         covariates=covariates)
    field_name = 'logreg_' + method
    mt = mt.annotate_rows(logreg=result[mt.locus, mt.alleles])
    if field_name in mt._fields:
        mt = mt.drop(mt[field_name])
    mt = mt.rename({'logreg': field_name})

    return mt

## Conduct analysis

In [None]:
# wald
mt = logistic_regression(mt, 'wald', n_factor)
mt.count()
plot_html_wald = get_manhattan_graph(mt, 'wald', 150)
plot_html_wald = plot_html_wald + get_qq_graph(mt, 'wald', 150)

In [None]:
#score
mt = logistic_regression(mt, 'score', n_factor)
mt.count()
plot_html_score = get_manhattan_graph(mt, 'score', 450)
plot_html_score = plot_html_score + get_qq_graph(mt, 'score', 450)

In [None]:
#lrt
mt = logistic_regression(mt, 'lrt', n_factor)
mt.count()
plot_html_lrt = get_manhattan_graph(mt, 'lrt', 150)
plot_html_lrt = plot_html_lrt + get_qq_graph(mt, 'lrt', 150)

In [None]:
# Assemble the logistic-regression page from the three per-test sections.
plot_html = create_side_bar(['wald', 'score', 'lrt'], 'Logistic Regression')
plot_html += '<div class="w3-main" style="margin-left:250px">\
    <div class="w3-row w3-padding-64">\
    <div class="w3-container">\
      <h1 id="vqc" class="w3-text-teal">Logistic Regression</h1>\
      <p>Manhattan Plots of p-values from Logistic Regressions(3 methods)</p>\n'
plot_html += ''.join([plot_html_wald, plot_html_score, plot_html_lrt]) + '</div>\n</div>\n'
# write into html file
filepath = "htmls/manhattan.html"
with open(filepath, "w+") as f:
    f.write(plot_html)

# Variant-Spark Analysis

### Run VariantSpark on the current matrix table (already annotated with the logistic-regression results)

## Load Functions

In [None]:
def vs_analysis(mt):
    """Fit a VariantSpark random forest predicting isCase from alt-allele counts.

    Hyper-parameters come from the module-level mtry_fraction, min_node_size,
    max_depth, num_of_tree and batch_size; the seed is fixed at 13. Prints the
    out-of-bag error and returns ``mt`` with row field ``vs_score`` holding each
    variant's importance.
    """
    model = vshl.random_forest_model(y=mt.pheno.isCase,
                                     x=mt.GT.n_alt_alleles(),
                                     seed=13,
                                     mtry_fraction=mtry_fraction,
                                     min_node_size=min_node_size,
                                     max_depth=max_depth)
    model.fit_trees(num_of_tree, batch_size)
    print("OOB: ", model.oob_error())
    importance = model.variable_importance()
    return mt.annotate_rows(vs_score=importance[mt.locus, mt.alleles].importance)

In [None]:
def get_vs_manhattan(mt):
    """Manhattan-style plot of VariantSpark importance z-scores, saved via save_plot.

    Importances are z-scored across all rows and converted to pseudo p-values
    ``10 ** -z``, so the -log10(p) axis of hl.plot.manhattan shows the z-score.
    """
    # Aggregate the summary statistics once into a local struct instead of
    # annotating the same struct onto every row.
    vs_stats = mt.aggregate_rows(hl.agg.stats(mt['vs_score']))
    z_score = (mt['vs_score'] - vs_stats.mean) / vs_stats.stdev
    mt = mt.annotate_rows(vs_score_converted=10 ** -z_score)

    title = 'Variant-Spark Manhattan plot'
    id_link = 'man-vs'
    folder = 'Variant Spark'
    hover_fields = {'rsid': mt.rsid, 'vs_score': mt.vs_score}
    p = hl.plot.manhattan(pvals=mt.vs_score_converted, hover_fields=hover_fields, title=title)
    p.yaxis.axis_label = 'Z score of importance score by VS'
    return save_plot(p, title, folder, id_link)


## Conduct Analysis

In [None]:
# Local folder for VariantSpark output, created if it does not exist yet.
filename = re.sub(r'\s', '_', analysis_name)
filepath = result_dir + '/output/variant-spark/' + filename
directory = os.path.dirname(filepath)
if not os.path.exists(directory):
    os.makedirs(directory)

# Random forest -> importance scores -> Manhattan-style plot page.
mt = vs_analysis(mt)
plot_html = get_vs_manhattan(mt)

filepath = "htmls/variant-spark.html"
with open(filepath, "w+") as f:
    f.write(plot_html)
    

# Save Results

## Copy Plots and HTML files to S3

In [None]:
# Copy the local result directory (plots + htmls) to <S3_Path><analysis_name>/.
# An argument list (no shell) keeps paths containing spaces intact; the
# destination uses analysis_name unmodified, so spaces are possible there.
cmd = ['aws', 's3', 'cp', '--recursive', result_dir + '/', S3_Path + analysis_name + '/']
print(' '.join(cmd))
ret = subprocess.call(cmd)
if ret != 0:
    print('WARNING: aws s3 cp exited with status ' + str(ret))

## export as a matrix table - ready to use for later analysis

In [None]:
# S3 prefix for the exported annotated data: <S3_Path><analysis_name>/output/annotated_<name with whitespace -> '_'>
file_name = 'annotated_{}'.format(re.sub(r'\s', '_', analysis_name))
output_prefix = '{}{}/output/{}'.format(S3_Path, analysis_name, file_name)
print(output_prefix)

In [None]:
# Persist the annotated matrix table (overwriting any previous export). The name
# mt_name is also referenced by the input-reading cell's saved-table branch.
mt_name = output_prefix + '.mt'
mt.write(mt_name,overwrite=True)

## export as vcf

In [None]:
# NOTE(review): Hail presumably chooses block-gzip output from the .bgz extension - confirm in hl.export_vcf docs
output_path = output_prefix + '.vcf.bgz' # extension needed
hl.export_vcf(mt, output_path) # 1 output file

## export as plink

In [None]:
# NOTE(review): no call/is_female/pheno arguments are passed, so Hail's defaults fill the
# .fam columns - confirm that is intended.
output_path = output_prefix # no extension needed
hl.export_plink(mt, output_path) # 3 output files 