# Example Notebook
### This notebook is an example that shows how to use [VariantSpark](https://github.com/aehrc/VariantSpark) with the [Hail v0.2 library](https://hail.is/) and compares the results with PCA and logistic regression.
### For demonstration purposes this notebook uses the sample dataset available in [ViGWAS](https://github.com/aehrc/VIGWAS).

## [We're always looking for suggestions and feedback. Please click here for a 1 minute survey](https://docs.google.com/forms/d/e/1FAIpQLScWoazw3-rgNFrZ5vcHL9JUmO0AX6Ji2P54Z2jNJZ-RAObuPg/viewform?usp=sf_link)

# User Block

In [0]:
// Azure Blob Storage connection details: replace every <...> placeholder with your own values.
val containerName = "<Container Name>"
val storageAccountName = "<StorageAccount Name>"
// SAS token for the container. Prefer reading it from a secret scope
// (dbutils.secrets.get) over pasting the key into the notebook.
val sas = "<Generated SAS Key>"
// Spark/Hadoop configuration key under which the SAS token is supplied for this container.
val config = "fs.azure.sas." + containerName + "." + storageAccountName + ".blob.core.windows.net"

// Mount the input folder at /mnt, which is where the later cells read
// /mnt/V1.vcf.bgz and /mnt/hipster.csv from.
// Replace sample_input with your input folder.
dbutils.fs.mount(
  source = "wasbs://" + containerName + "@" + storageAccountName + ".blob.core.windows.net/sample_input",
  mountPoint = "/mnt",
  extraConfigs = Map(config -> sas))

## Some configs
# Cluster sizing knobs; adjust these to match the cluster the notebook runs on.
numCPU = 32                  # number of CPU cores available
memory = '60G'               # memory setting (not referenced by the cells shown here)
numPartitions = 4 * numCPU   # four partitions per core, used as min_partitions when loading the VCF

# Environment initialization

In [0]:
## Environment init
# Obtain the Spark context and initialise the VariantSpark/Hail integration on it.
# This cell must run before any cell that uses `hl` or `vshl`.

import os
from pyspark import SparkContext
# Reuse the SparkContext that is already running, or create one if none exists.
sc = SparkContext.getOrCreate()

import hail as hl
import varspark.hail as vshl
# Initialise through VariantSpark's Hail wrapper, passing the Spark context obtained above.
vshl.init(sc=sc)

In [0]:
from bokeh.io import show, output_notebook
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, FactorRange, LabelSet, Label
from bokeh.transform import factor_cmap
from bokeh.palettes import d3
from bokeh.core.properties import value
from bokeh.embed import file_html
from bokeh.resources import CDN
from bokeh.layouts import gridplot
from bokeh.models.mappers import CategoricalColorMapper

from pprint import pprint
# Configure Bokeh so that subsequent show(...) calls render inline in the notebook.
output_notebook()

import re
import numpy as np
import math as math
import sys
import operator
from collections import OrderedDict
import subprocess
from itertools import cycle
import shutil

# Load VCF files

In [0]:
# Load the block-gzipped VCF from the mounted storage into `mt`.
# min_partitions is taken from the numPartitions setting defined in the config cell.
mt = hl.import_vcf(
    path='/mnt/V1.vcf.bgz',
    min_partitions=int(numPartitions),
    skip_invalid_loci=True,
)

# Sample Annotation Data Analysis

In [0]:
# Read the per-sample annotation CSV (column types imputed from the data),
# then key it by the 'Sample' column so it can be joined on sample ID.
Annot = hl.import_table(
    '/mnt/hipster.csv',
    delimiter=',',
    impute=True,
)
Annot = Annot.key_by('Sample')

# Annotate dataset with sample annotation

In [0]:
# Attach each sample's annotation row to the dataset as the column field `pheno`,
# looked up in Annot by the sample ID `mt.s`.
mt = mt.annotate_cols(pheno = Annot[mt.s])

# PCA analysis

In [0]:
# Run HWE-normalised PCA on the genotype calls, keeping the first two components (k=2).
eigenvalues, pcs, loadings = hl.hwe_normalized_pca(mt.GT, k=2)
# Store each sample's PC scores on the dataset as the column field `pcs`
# (used later as covariates in the logistic regression).
mt = mt.annotate_cols(pcs = pcs[mt.s].scores)

In [0]:
# Scatter plot of PC1 against PC2, one point per sample,
# labelled by the sample's `Hipster` phenotype value.
p = hl.plot.scatter(
    pcs.scores[0],
    pcs.scores[1],
    label=mt.cols()[pcs.s].pheno.Hipster,
    xlabel='PC1',
    ylabel='PC2',
    title='PCA Case/Control',
    collect_all=True,
)
show(p)

# Logistic Regression

In [0]:
# Per-variant logistic regression (Wald test) of case status on alternate-allele count.
#
# Hail does not add an intercept implicitly: it must be listed as an explicit
# covariate (the constant 1.0). Without it the model is forced through the
# origin, which distorts the effect estimates and p-values.
covariates = [1.0, mt.pheno.isFemale, mt.pcs[0], mt.pcs[1]]

result = hl.logistic_regression_rows(test='wald',
                                     y=mt.pheno.isCase,
                                     x=mt.GT.n_alt_alleles(),
                                     covariates=covariates)

# Join the regression results back onto the dataset rows as the field `logreg`,
# keyed by (locus, alleles).
mt = mt.annotate_rows(logreg=result[mt.locus, mt.alleles])

In [0]:
# Manhattan plot of the logistic-regression p-values computed in the previous cell.
p = hl.plot.manhattan(result.p_value)
show(p)

# Variant-Spark RandomForest

In [0]:
# Configure a VariantSpark random forest: case status as the response,
# alternate-allele count as the features. The seed is fixed for repeatable runs.
rf_model = vshl.random_forest_model(
    y=mt.pheno.isCase,
    x=mt.GT.n_alt_alleles(),
    seed=13,
    mtry_fraction=0.1,
    min_node_size=10,
    max_depth=15,
)

# Grow 100 trees, built in batches of 25.
rf_model.fit_trees(n_trees=100, batch_size=25)

# Fetch the per-variant importance table from the fitted model.
impTable = rf_model.variable_importance()

# Attach the importance to the dataset rows as `vs_score`, matched on (locus, alleles).
mt = mt.annotate_rows(vs_score=impTable[mt.locus, mt.alleles].importance)

In [0]:
# Plot the VariantSpark importance scores in Manhattan-plot form.

# Summary statistics of vs_score over all rows, stored on every row as `vs_stats`.
mt = mt.annotate_rows(vs_stats = mt.aggregate_rows(hl.agg.stats(mt['vs_score'])))
# Standardise the importance score: (score - mean) / standard deviation.
mt = mt.annotate_rows(z_score = (mt['vs_score'] - mt.vs_stats.mean)/mt.vs_stats.stdev)
# hl.plot.manhattan draws -log10(pvals) on the y axis, so passing 10 ** (-z)
# makes the plotted height equal to the z-score itself.
mt = mt.annotate_rows(vs_score_converted = 10** -mt.z_score)
title = 'Variant-Spark Manhattan plot'
hover_fields = {'rsid': mt.rsid, 'vs_score': mt.vs_score}
p = hl.plot.manhattan(pvals=mt.vs_score_converted, hover_fields=hover_fields, title=title)
# Relabel the y axis: the values shown are z-scores, not -log10 p-values.
p.yaxis.axis_label = 'Z score of importance score by VS'
show(p)

# Describe the MatrixTable

In [0]:
# Print the dataset's schema, including the fields added by the cells above
# (pheno, pcs, logreg, vs_score, vs_stats, z_score, vs_score_converted).
mt.describe()