Notebook 5: F2 and FST

1. FST in Hail is still a work in progress, so the second half of the notebook is a bit messy
2. In the meantime, FST is calculated using PLINK
3. Heat maps for this analysis were plotted using R - will be kept that way

## Index
- [General Overview](#General-Overview)
- [F2 analysis](#F2-analysis)
- [FST](#FST)
    - [FST with PLINK](#1.-FST-with-PLINK)
    - [FST with HAIL](#2.-FST-with-HAIL)
- [Additional Notes](#Notes)

# General Overview 
The purpose of this notebook is to show two population genetics analyses (F2 and FST) to understand recent and deep history. It contains steps on how to: 

- Read in a matrix table and filter using sample IDs that were obtained from another matrix table 
- Separate a field into multiple fields
- Filter using the call rate field 
- Extract doubletons and check if they are the reference or alternate allele
- Count how many times a sample or a sample pair appears in a field 
- Combine two dictionaries and add up the values for identical keys
- Format list as pandas table 
- Export a matrix table as PLINK2 BED, BIM and FAM files 
- Set up a pair-wise comparison
- Drop certain fields
- Calculate FST (once there is progress on this, I will elaborate more)

Author: Mary T. Yohannes

In [None]:
# import Hail, plus NumPy (FST arithmetic) and pandas (tabular output) used further down
import hail as hl
import numpy as np
import pandas as pd

# F2 analysis

In [None]:
# Load the pre-variant-QC dataset that the F2 analysis starts from
mt_filt = hl.read_matrix_table('gs://hgdp-1kg/hgdp_tgp/intermediate_files/pre_running_varqc.mt')

# Keep only unrelated samples. Their IDs come from the matrix table exported for Lindo
# (unrelated samples without outliers - 3380 samples).
mt_unrel = hl.read_matrix_table('gs://hgdp-1kg/hgdp_tgp/datasets_for_others/lindo/ds_without_outliers/unrelated.mt')

# Collect the IDs locally, then broadcast them back into Hail as a literal array expression
unrel_samples = hl.literal(mt_unrel.s.collect())
is_unrelated = unrel_samples.contains(mt_filt['s'])
mt_filt_unrel = mt_filt.filter_cols(is_unrelated)

mt_filt_unrel.count()  # (155648020, 3380)

In [None]:
# Per-variant QC metrics (call rate, AC, AF, AN, ...) - more info:
# https://hail.is/docs/0.2/methods/genetics.html#hail.methods.variant_qc
mt_unrel_varqc = hl.variant_qc(mt_filt_unrel)

# Pull the first two entries of the allele-count array out into their own row fields
allele_counts = mt_unrel_varqc.variant_qc.AC
mt_unrel_interm = mt_unrel_varqc.annotate_rows(AC1=allele_counts[0], AC2=allele_counts[1])

# Doubletons: variants where either of the first two alleles is observed exactly twice
is_doubleton = (mt_unrel_interm.AC1 == 2) | (mt_unrel_interm.AC2 == 2)
mt_unrel_only2 = mt_unrel_interm.filter_rows(is_doubleton)
#mt_unrel_only2.count() # (18018978, 3380)

In [None]:
#### check how many doubletons each filter condition yields
# Only the number of variants matters here, so count_rows() is used instead of count().

# variants where the first allele is a doubleton
mt_unrel_only2_ac1 = mt_unrel_interm.filter_rows(mt_unrel_interm.AC1 == 2)
n_ac1 = mt_unrel_only2_ac1.count_rows()
print(n_ac1)  # 3354

# variants where the second allele is a doubleton
mt_unrel_only2_ac2 = mt_unrel_interm.filter_rows(mt_unrel_interm.AC2 == 2)
n_ac2 = mt_unrel_only2_ac2.count_rows()
print(n_ac2)  # 18015720

# variants where both alleles are doubletons
mt_unrel_only2_both = mt_unrel_interm.filter_rows((mt_unrel_interm.AC1 == 2) & (mt_unrel_interm.AC2 == 2))
n_both_doubl = mt_unrel_only2_both.count_rows()
print(n_both_doubl)  # 96

# Sanity check (inclusion-exclusion): |AC1 == 2| + |AC2 == 2| - |both| must equal the number of
# rows in mt_unrel_only2, which used the OR of the two conditions.
# 18018978 is the recorded row count of mt_unrel_only2. The computed counts are used instead of
# re-typed literals, so the check actually tests the data.
N_UNREL_ONLY2 = 18018978
n_ac1 + n_ac2 - n_both_doubl == N_UNREL_ONLY2  # True

In [None]:
# Write the doubleton mt out so the filtering above need not be recomputed (saving took ~23 min).
# Kept commented out - presumably because the file has already been written (overwrite=False).
#mt_unrel_only2.write('gs://hgdp-1kg/hgdp_tgp/FST_F2/F2/doubleton.mt', overwrite=False)

In [None]:
# read the saved doubleton mt back in instead of recomputing it
DOUBLETON_MT_PATH = 'gs://hgdp-1kg/hgdp_tgp/FST_F2/F2/doubleton.mt'
mt_unrel_only2 = hl.read_matrix_table(DOUBLETON_MT_PATH)

In [None]:
# Keep variants whose call rate is above MIN_CALL_RATE.
# NOTE(review): the original comment described this as "no more than 5% missingness", but
# call_rate > 0.05 keeps variants with up to 95% missing genotypes. A 5%-missingness cutoff would
# be call_rate > 0.95 - confirm which is intended. The recorded count below is for 0.05.
MIN_CALL_RATE = 0.05
passes_call_rate = mt_unrel_only2.variant_qc.call_rate > MIN_CALL_RATE
mt_unrel_only2_filtered = mt_unrel_only2.filter_rows(passes_call_rate)
#mt_unrel_only2_filtered.count() # (17997741, 3380)

In [None]:
# Use the allele frequencies to decide whether the doubleton is the ref or the alt allele
af = mt_unrel_only2_filtered.variant_qc.AF

# AF[0] < AF[1] -> the doubleton is the 1st allele (ref)
mt_doubl_ref = mt_unrel_only2_filtered.filter_rows(af[0] < af[1])
#print(mt_doubl_ref.count()) # (3159, 3380)

# AF[0] > AF[1] -> the doubleton is the 2nd allele (alt)
mt_doubl_alt = mt_unrel_only2_filtered.filter_rows(af[0] > af[1])
#print(mt_doubl_alt.count()) # (17994582, 3380)

# sanity check: the two strict comparisons split the filtered mt (variants with AF[0] == AF[1]
# would land in neither subset)
# mt_doubl_ref.count()[0] + mt_doubl_alt.count()[0] = mt_unrel_only2_filtered.count()[0]
#print(3159 + 17994582 == 17997741) # True

In [None]:
# For each doubleton variant, collect the set of samples carrying the doubleton allele.
# A set is collected rather than a list because the counter in the next step needs hashable
# values (a list would make that step fail).
# NOTE(review): hl.call(0, 1) etc. build unphased calls; confirm GT is unphased, otherwise
# equality with phased calls may not match.

# Doubleton is the 1st allele (ref): carriers are 0/1 or 0/0.
# One sample in the set -> that sample is 0/0; two samples -> both are 0/1.
ref_gt = mt_doubl_ref.GT
carries_ref_doubleton = (ref_gt == hl.call(0, 1)) | (ref_gt == hl.call(0, 0))
mt_ref_collected = mt_doubl_ref.annotate_rows(
    samples_with_doubletons=hl.agg.filter(carries_ref_doubleton,
                                          hl.agg.collect_as_set(mt_doubl_ref.s)))

# Doubleton is the 2nd allele (alt): carriers are 0/1 or 1/1.
# One sample in the set -> that sample is 1/1; two samples -> both are 0/1.
alt_gt = mt_doubl_alt.GT
carries_alt_doubleton = (alt_gt == hl.call(0, 1)) | (alt_gt == hl.call(1, 1))
mt_alt_collected = mt_doubl_alt.annotate_rows(
    samples_with_doubletons=hl.agg.filter(carries_alt_doubleton,
                                          hl.agg.collect_as_set(mt_doubl_alt.s)))


In [None]:
# peek at the first few carrier sets collected for ref-allele doubletons
mt_ref_collected['samples_with_doubletons'].show(5)

In [None]:
# Count how often each carrier set (a single sample or a sample pair) occurs across variants.
# hl.agg.counter returns a Python dict mapping each distinct set to its number of occurrences.
ref_doubl_count, alt_doubl_count = (
    collected.aggregate_rows(hl.agg.counter(collected.samples_with_doubletons))
    for collected in (mt_ref_collected, mt_alt_collected)
)

In [None]:
# Combine the ref and alt counters, summing the counts of carrier sets that appear in both
all_doubl_count = {k: ref_doubl_count.get(k, 0) + alt_doubl_count.get(k, 0)
                   for k in set(ref_doubl_count) | set(alt_doubl_count)}

# Sanity check (previously only described in a comment): every key from both inputs appears
# exactly once, so the combined size is len(ref) + len(alt) - |shared keys|.
n_shared_keys = len(ref_doubl_count.keys() & alt_doubl_count.keys())
assert len(all_doubl_count) == len(ref_doubl_count) + len(alt_doubl_count) - n_shared_keys

len(all_doubl_count)  # 3183039

In [None]:
# sample IDs in the filtered doubleton mt
mt_sample_list = mt_unrel_only2_filtered.s.collect()

# All unordered sample pairs: n(n-1)/2 = 5710510 pairs for 3380 samples.
# itertools.combinations yields each pair once, in the same i < j order as the former hand-rolled
# double enumerate. That version also referenced the undefined name `sample_list`.
from itertools import combinations
mt_sample_pairs = [{x, y} for x, y in combinations(mt_sample_list, 2)]

In [None]:
# Split the combined counts by key size: one sample vs. anything else (sample pairs)
dict_single_samples = {k: v for k, v in all_doubl_count.items() if len(k) == 1}
dict_pair_samples = {k: v for k, v in all_doubl_count.items() if len(k) != 1}

# sanity check: splitting loses no keys
print(len(dict_single_samples) + len(dict_pair_samples) == len(all_doubl_count))  # True

# further investigation
# does every sample in the mt show up as a single-sample key?
print(len(mt_sample_list) == len(dict_single_samples))  # True
# only pairs that actually share a doubleton are keys, so the mt yields more pairs than the dict
print(len(mt_sample_pairs) == len(dict_pair_samples))  # False

In [None]:
# Single samples: one row [sample, sample, count] per sample in the mt. The count comes from the
# single-sample dict; samples that never occur alone get 0.
single_sample_final_list = [[s, s, dict_single_samples.get(frozenset([s]), 0)] for s in mt_sample_list]

# # equivalent for-loop version, on a small trial input
# sample_ids = ['NA12546B', 'NA12830A', 'HG02688', 'HG02334', 'NA21130'] # trial list
# sample_dict = {frozenset({'HG02334'}): 639, frozenset({'NA21130'}): 497, frozenset({'HG02688'}): 83} # trial dict
# final_data = []
# for s in sample_ids:
#     if sample_dict.get(frozenset([s])) is None:
#         final_data.append([s, s, 0])
#     else:
#         final_data.append([s, s, sample_dict[frozenset([s])]])

# sanity check: dict, mt sample list and final list must all have one entry per sample
print(len(single_sample_final_list) == len(mt_sample_list) == len(dict_single_samples))  # True
# every count in the final list must match the dict (True when all comparisons hold)
all([row[2] == dict_single_samples[frozenset([row[0]])] for row in single_sample_final_list])  # True

In [None]:
# Sample pairs: one row [sample_a, sample_b, count] per possible pair; pairs that never share a
# doubleton get 0. frozenset(pair) matches the frozenset dict keys regardless of which order
# the set yields its two members.
sample_pair_final_list = []
for pair in mt_sample_pairs:
    sample_a, sample_b = pair
    sample_pair_final_list.append([sample_a, sample_b, dict_pair_samples.get(frozenset(pair), 0)])

# # for loop version - trial input
# sample_ids = [{'NA12546B', 'NA12830A'}, {'HG02757', 'NA12546B'}, {'HG02184', 'HGDP00863'}, {'LP6005443-DNA_D02', 'NA19068'}, {'HG02611', 'NA12546B'}] # trial list
# sample_dict = {frozenset({'HG02184', 'HGDP00863'}): 639, frozenset({'LP6005443-DNA_D02', 'NA19068'}): 497, frozenset({'NA19982', 'NA20356'}): 83} # trial dict

# sanity checks
assert len(sample_pair_final_list) == len(mt_sample_pairs)
# Every count in the final list must match the dict, with 0 for pairs absent from it. The
# earlier draft indexed the dict directly, which raised KeyError for those pairs.
all(row[2] == dict_pair_samples.get(frozenset([row[0], row[1]]), 0) for row in sample_pair_final_list)  # True

In [None]:
# stack the single-sample rows on top of the sample-pair rows
final_list = [*single_sample_final_list, *sample_pair_final_list]

# sanity check
len(final_list) == len(single_sample_final_list) + len(sample_pair_final_list)  # True

In [None]:
# tabulate the combined rows with pandas (columns are 0, 1, 2 until renamed in the next cell)
import pandas as pd
df = pd.DataFrame(data=final_list)
df

In [None]:
# give the columns descriptive names (rename returns a new frame, which is bound back to df)
df = df.rename(columns={0: 'sample1', 1: 'sample2', 2: 'count'})
df

In [None]:
# Save the table to the bucket so it can be plotted in R.
# Note: despite the .csv extension the file is tab-separated (sep='\t').
F2_COUNTS_PATH = 'gs://hgdp-1kg/hgdp_tgp/FST_F2/F2/doubleton_sample_pair_count_tbl.csv'
df.to_csv(F2_COUNTS_PATH, sep='\t', index=False)

In [None]:
# Export sample ID, population and genetic region (no header) for the heatmap annotation
sample_meta = mt_unrel_only2_filtered['hgdp_tgp_meta']
sampleID_pop_reg = mt_unrel_only2_filtered.select_cols(sample_meta['Population'], sample_meta['Genetic']['region']).cols()
sampleID_pop_reg.export('gs://hgdp-1kg/hgdp_tgp/FST_F2/F2/sampleID_pop_reg.txt', header=False)

--------------------------------------------------------

# Notes

In [None]:
# Notes for working with the counter dicts above, whose keys are frozensets of sample IDs

# get the first value
#list(ref_doubl_count.values())[0]

# get the first key as a list of sample IDs
#list(list(ref_doubl_count.keys())[0])

# get one sample ID out of the third key
#list(set(list(ref_doubl_count.keys())[2]))[0]

In [None]:
# Scratch check (kept commented out): filtering on GT.is_het() and on GT == hl.call(0, 1)
# collected the same samples for this dataset (result recorded below as True).
# # sanity check to make sure that mt.GT.is_het() and mt.GT == hl.call(0, 1) produce the same results 

# trial_1 = mt_unrel_only2_filtered.annotate_rows(
#     samples_with_doubletons = hl.agg.filter(
#         mt_unrel_only2_filtered.GT.is_het(), hl.agg.collect(mt_unrel_only2_filtered.s)))

# trial_2 = mt_unrel_only2_filtered.annotate_rows(
#     samples_with_doubletons = hl.agg.filter(
#         mt_unrel_only2_filtered.GT == hl.call(0, 1), hl.agg.collect(mt_unrel_only2_filtered.s)))

# trial_1_list = trial_1.samples_with_doubletons.collect()
# trial_2_list = trial_2.samples_with_doubletons.collect()

# trial_1_list == trial_2_list # True 

In [None]:
# Hail call predicates:
# is_hom_var() - the call has two identical alternate alleles (1/1, but also e.g. 2/2 at multi-allelic sites)
# is_hom_ref() - only 0/0 - the call has two reference alleles
# is_het() - the call has two different alleles (0/1, but also e.g. 1/2 at multi-allelic sites)

# # samples with doubletons and how many per variant  
# p = mt_unrel_only2_filtered.annotate_rows(
#     samples_with_doubletons = hl.agg.filter(
#         mt_unrel_only2_filtered.GT == hl.call(0, 1), 
#         hl.agg.collect(mt_unrel_only2_filtered.s)),
#     n_doubleton = hl.agg.filter(
#         mt_unrel_only2_filtered.GT == hl.call(0, 1), 
#         hl.agg.count()))

# # n_doubleton by sample - maybe can be used for sanity check 
# k = mt_unrel_only2_filtered.annotate_cols(
#     n_doubleton = hl.agg.filter(
#         mt_unrel_only2_filtered.GT.is_non_ref(), 
#         hl.agg.count()))

In [None]:
# list(fs) - if fs is a frozenset holding one string, this gives a one-element list with that string intact
# frozenset(s) vs frozenset([s]) for a string s: frozenset(s) builds a set of the string's individual
# characters, while frozenset([s]) builds a one-element set holding the whole string (the form used
# for the single-sample lookups above)

--------------------------------------------------------

# FST

For FST, we are using the data we had prior to running pc_relate (*filtered_n_pruned_output_updated.mt*)

In [None]:
# FST input: the dataset exported for Lindo - data before running pc_relate, without outliers
FST_INPUT_MT = 'gs://hgdp-1kg/hgdp_tgp/datasets_for_others/lindo/ds_without_outliers/whole_dataset.mt'
mt_to_plink = hl.read_matrix_table(FST_INPUT_MT)

### 1. FST with PLINK

#### FOR ZAN -- convert mt to PLINK files

In [None]:
# Export the mt as PLINK2 BED, BIM and FAM files with the 'hgdp_tgp' prefix.
# The population label is written as the FAM family ID (presumably to group samples by
# population in the PLINK FST run - confirm).
plink_family_ids = mt_to_plink.hgdp_tgp_meta.Population
hl.export_plink(mt_to_plink, 'gs://hgdp-1kg/hgdp_tgp/FST_F2/FST/PLINK/hgdp_tgp', fam_id=plink_family_ids)

## FST was then calculated with PLINK on a VM (instructions are in the offline notebook: FST -> PLINK for Zan)

### 2. FST with HAIL 

#### 2a. *Hudson Estimator* 

In [None]:
# alias the FST input under the name used by the code below
mt_for_fst = mt_to_plink

# previous code ##########################
# same dataset as gs://hgdp-1kg/filtered_n_pruned_output_updated.mt after removing the outliers (gs://hgdp-1kg/hgdp_tgp/pca_outliers_v2.txt)
##with hl.utils.hadoop_open('gs://hgdp-1kg/hgdp_tgp/pca_outliers_v2.txt') as file: # read the outliers file into a list
    ##outliers = [line.rstrip('\n') for line in file]

##outliers_list = hl.literal(outliers) # capture and broadcast the list as an expression

##mt_var_pru_filt = mt_var_pru_filt.filter_cols(~outliers_list.contains(mt_var_pru_filt['s'])) # remove 22 outliers 
##########################

# double-check the dimensions (variants, samples)
mt_for_fst_dims = mt_for_fst.count()
mt_for_fst_dims  # (248634, 4097)

##### *pair-wise comparison*

Number of pair-wise comparisons among $k$ populations: $\binom{k}{2} = \frac{k(k-1)}{2}$

With our 78 populations we therefore expect $\frac{78 \times 77}{2} = \frac{6006}{2} = 3003$ pair-wise comparisons.

In [None]:
# unique population labels, in order of first appearance among the samples
pop = list(dict.fromkeys(mt_for_fst['hgdp_tgp_meta']['Population'].collect()))
len(pop)  # 78 populations in total

In [None]:
# toy example of the pair-wise comparison on three labels
ex = ['a', 'b', 'c']
# each index i is paired with every later index j, so every unordered pair appears once
ex_pair_com = [[ex[i], ex[j]] for i in range(len(ex)) for j in range(i + 1, len(ex))]
ex_pair_com

In [None]:
# Pair-wise comparison as a list of [pop_a, pop_b] lists. Each population is paired only with
# populations that come after it in `pop`, so every unordered pair appears exactly once and
# self-pairs are excluded (e.g. CEU-YRI is kept; YRI-CEU and CEU-CEU are not).
pair_com = [[pop[i], pop[j]] for i in range(len(pop)) for j in range(i + 1, len(pop))]

In [None]:
# peek at the first five population pairs
pair_com[:5]

In [None]:
# sanity check: k populations give k(k-1)/2 unordered pairs (78 -> 3003), asserted rather than eyeballed
n_expected_pairs = len(pop) * (len(pop) - 1) // 2
assert len(pair_com) == n_expected_pairs, (len(pair_com), n_expected_pairs)
len(pair_com)  # 3003

##### *subset mt into popns according to the pair-wise comparisons and run common variant statistics*

In [None]:
# the first population pair, used as the worked example below
first_pair = pair_com[0]
first_pair

In [None]:
## worked example with the first pair: pair_com[0] = ['CEU', 'YRI'], so pair_com[0][0] = 'CEU'
first_pop, second_pop = pair_com[0]
fst_population = mt_for_fst['hgdp_tgp_meta']['Population']
CEU_mt = mt_for_fst.filter_cols(fst_population == first_pop)
YRI_mt = mt_for_fst.filter_cols(fst_population == second_pop)

# previous code ##########################
##CEU_YRI_mt = mt_for_fst.filter_cols((mt_for_fst['hgdp_tgp_meta']['Population'] == pair_com[0][0]) | (mt_for_fst['hgdp_tgp_meta']['Population'] == pair_com[0][1]))

# sanity check 
##CEU_mt.count()[1] + YRI_mt.count()[1] == CEU_YRI_mt.count()[1] # 176 + 175 = 351
##########################

In [None]:
# per-variant QC (AC, AF, AN, ...) computed within each population separately
CEU_var, YRI_var = hl.variant_qc(CEU_mt), hl.variant_qc(YRI_mt)  # 176 and 175 individuals

##CEU_YRI_var = hl.variant_qc(CEU_YRI_mt) # total - 351

##### *Set up mt table for FST calculation - the next code is run for each population and their combos*

##### a. *population 1*

In [None]:
# Shrink the mt to what the FST calculation needs:
# - drop every entry field
# - keep only the column key ('s')
# - keep only the row keys (locus, alleles) plus AF and AN from variant_qc
# select_entries()/select_cols() with no arguments keep only the key fields. This replaces the
# earlier positional drop (list(...)[1:], list(...)[2:-1]), which relied on field order.
CEU_interm = CEU_var.select_entries().select_cols()

# AN is kept next to AF because the FST cells below read pop1.AN. The earlier version selected
# only AF, so pop1.AN.collect() failed. The YRI mt is built the same way.
CEU_interm2 = CEU_interm.select_rows(CEU_interm['variant_qc']['AF'], CEU_interm['variant_qc']['AN'])

# quick look at the condensed mt
CEU_interm2.describe()

In [None]:
# keep only the second entry of AF (frequency of the second allele), reusing the name AF
CEU_interm3 = CEU_interm2.transmute_rows(AF=CEU_interm2.AF[1])

# previous code ##########################
# key the rows only by 'locus' so that the 'allele' row field can be split into two row fields (one for each allele)
# also, only include the second entry of the array from 'AF' row field  
##CEU_interm3 = CEU_interm2.key_rows_by('locus')
##CEU_interm3 = CEU_interm3.transmute_rows(AF = CEU_interm3.AF[1], A1 = CEU_interm3.alleles[0], A2 = CEU_interm3.alleles[1])
##########################

# tag every row with the population name so its source mt stays identifiable
ceu_label = pair_com[0][0]
CEU_final = CEU_interm3.annotate_rows(pop=ceu_label)
CEU_final.rows().show(5)

##### b. *population 2*

In [None]:
# Disabled: the same slimming steps as for CEU, applied to the pooled CEU+YRI mt.
# NOTE(review): CEU_YRI_var only exists if the commented-out CEU_YRI lines in the subsetting and
# variant_qc cells above are re-enabled.
# drop fields  

# drop all entry fields
# everything except for 's' (key) from the column fields
# everything from the row fields except for the keys -'locus' and 'alleles' and row field 'variant_qc'  
##CEU_YRI_interm = CEU_YRI_var.drop(*list(CEU_YRI_var.entry), *list(CEU_YRI_var.col)[1:], *list(CEU_YRI_var.row)[2:-1])

# only select the row field keys (locus and allele) and row fields 'AF' & 'AN' which are under 'variant_qc'
##CEU_YRI_interm2 = CEU_YRI_interm.select_rows(CEU_YRI_interm['variant_qc']['AF'], CEU_YRI_interm['variant_qc']['AN'])  

# quick look at the condensed mt 
##CEU_YRI_interm2.describe()

In [None]:
# Disabled: the rest of the CEU+YRI setup. CEU_YRI_final, which later cells use as pop2, is
# only defined here, so those cells fail with NameError on a fresh kernel while this stays commented out.
# only include the second entry of the array from the row field 'AF' 
##CEU_YRI_interm3 = CEU_YRI_interm2.transmute_rows(AF = CEU_YRI_interm2.AF[1])

# previous code ##########################
# key the rows only by 'locus' so that the 'allele' row field can be split into two row fields (one for each allele)
# also, only include the second entry of the array from 'AF' row field  
##CEU_YRI_interm3 = CEU_YRI_interm2.key_rows_by('locus')
##CEU_YRI_interm3 = CEU_YRI_interm3.transmute_rows(AF = CEU_YRI_interm3.AF[1], A1 = CEU_YRI_interm3.alleles[0], A2 = CEU_YRI_interm3.alleles[1])
##########################

# add a row field with population name to keep track of which mt it came from 
##CEU_YRI_final = CEU_YRI_interm3.annotate_rows(pop = f'{pair_com[0][0]}-{pair_com[0][1]}')
##CEU_YRI_final.rows().show(5)

### FST formula pre-setup - trial run

##### *Variables needed for FST calculation* 

In [None]:
# Collect each population's per-variant allele numbers (AN) and second-allele frequencies (AF)
# into numpy arrays for the FST formula (arrays are easier to work with than lists).
# NOTE(review): CEU_YRI_final is only created by commented-out code above, so on a fresh kernel
# `pop2 = CEU_YRI_final` raises NameError. pop1.AN also requires AN to have been kept on CEU_final.
# NOTE(review): the arrays of the two populations are assumed to be in the same variant order
# (both derive from mt_for_fst) - confirm.

# assign populations to formula variables 
pop1 = CEU_final
pop2 = CEU_YRI_final

# number of alleles (AN)
n1 = np.array(pop1.AN.collect())
n2 = np.array(pop2.AN.collect())

# allele frequencies (second allele)
FREQpop1 = np.array(pop1.AF.collect()) 
FREQpop2 = np.array(pop2.AF.collect())  

##### *Weighted average allele frequency*

In [None]:
# average allele frequency across the two populations, weighted by allele number
FREQ = (n1 * FREQpop1 + n2 * FREQpop2) / (n1 + n2)

# sanity checks: element 0 recomputed from scalars matches, and no elements were lost
print((n1[0] * FREQpop1[0] + n2[0] * FREQpop2[0]) / (n1[0] + n2[0]) == FREQ[0])
print(len(FREQ) == len(FREQpop1))

##### *Filter to only freqs between 0 and 1*

In [None]:
# Keep only variants whose average frequency lies strictly between 0 and 1.
# We started with 248634 values; 246984 pass and 1650 are dropped.
INCLUDE = (FREQ > 0) & (FREQ < 1)
print(np.count_nonzero(INCLUDE))  # 246984

# apply the same mask to the per-population and average frequencies
FREQpop1, FREQpop2, FREQ = FREQpop1[INCLUDE], FREQpop2[INCLUDE], FREQ[INCLUDE]
print(len(FREQpop1) == np.count_nonzero(INCLUDE))  # TRUE

# ...and to the allele numbers
n1, n2 = n1[INCLUDE], n2[INCLUDE]
print(len(n1) == np.count_nonzero(INCLUDE))  # TRUE

#### 2b. *W&C ESTIMATOR*

In [None]:
# Weir & Cockerham FST, vectorized over all variants that passed the frequency filter.
# s is the number of populations (always 2 for pairwise FST). It used to be defined only in a
# later cell, so this cell raised NameError on a fresh kernel.
s = 2

## average sample size that incorporates variance
nc = ((1/(s-1)) * (n1+n2)) - ((np.square(n1) + np.square(n2))/(n1+n2))

# mean square between populations
msa = (1/(s-1))*((n1*(np.square(FREQpop1-FREQ)))+(n2*(np.square(FREQpop2-FREQ))))

# mean square within populations
msw = (1/((n1-1)+(n2-1))) * ((n1*(FREQpop1*(1-FREQpop1))) + (n2*(FREQpop2*(1-FREQpop2))))

numer = msa-msw

denom = msa + ((nc-1)*msw)

FST_val = numer/denom

# sanity check: recompute the first element with scalars
nc_0 = ((1/(s-1)) * (n1[0]+n2[0])) - ((np.square(n1[0]) + np.square(n2[0]))/(n1[0]+n2[0]))

msa_0 = (1/(s-1))*((n1[0]*(np.square(FREQpop1[0]-FREQ[0])))+(n2[0]*(np.square(FREQpop2[0]-FREQ[0]))))

msw_0 = (1/((n1[0]-1)+(n2[0]-1))) * ((n1[0]*(FREQpop1[0]*(1-FREQpop1[0]))) + (n2[0]*(FREQpop2[0]*(1-FREQpop2[0]))))

numer_0 = msa_0-msw_0

denom_0 = msa_0 + ((nc_0-1)*msw_0)

FST_0 = numer_0/denom_0

print(FST_0 == FST_val[0]) # TRUE

In [None]:
# per-variant FST values from the vectorized calculation above
FST_val

### *Which FST value is for which locus-allele?* - actual run

In [None]:
# Reset the inputs for the keyed FST run: the trial cells above overwrote n1, n2, FREQpop1
# and FREQpop2 with their filtered versions.
# NOTE(review): as in the trial cell, CEU_YRI_final is only defined by commented-out code, so
# `pop2 = CEU_YRI_final` raises NameError on a fresh kernel.

# assign populations to formula variables 
pop1 = CEU_final
pop2 = CEU_YRI_final

# number of alleles 
n1 = np.array(pop1.AN.collect())
n2 = np.array(pop2.AN.collect())

# allele frequencies 
FREQpop1 = np.array(pop1.AF.collect()) 
FREQpop2 = np.array(pop2.AF.collect())  

# locus + alleles = keys - needed for reference purposes - these values are uniform across all populations 
locus = np.array(hl.str(pop1.locus).collect())
alleles = np.array(hl.str(pop1.alleles).collect())
# one "locus alleles" string per variant, used as the dict key for its FST value
key = np.array([i + ' ' + j for i, j in zip(locus, alleles)])

In [None]:
s = 2  # number of populations - pairwise FST always compares exactly two
key_FST = {}  # "locus alleles" -> FST
for idx, variant_key in enumerate(key):
    na, nb = n1[idx], n2[idx]
    pa, pb = FREQpop1[idx], FREQpop2[idx]
    FREQ = ((na*pa) + (nb*pb)) / (na+nb)

    # only include variants whose average frequency is strictly between 0 and 1
    if not ((FREQ>0) & (FREQ<1)):
        continue

    ## average sample size that incorporates variance
    nc = ((1/(s-1)) * (na+nb)) - ((np.square(na) + np.square(nb))/(na+nb))
    msa = (1/(s-1))*((na*(np.square(pa-FREQ)))+(nb*(np.square(pb-FREQ))))
    msw = (1/((na-1)+(nb-1))) * ((na*(pa*(1-pa))) + (nb*(pb*(1-pb))))
    numer = msa-msw
    denom = msa + ((nc-1)*msw)
    FST = numer/denom

    key_FST[variant_key] = FST

In [None]:
# variant key ("locus alleles") -> FST, for variants with 0 < average frequency < 1
key_FST

In [None]:
# sanity checks: the per-variant loop reproduces the vectorized FST values element for element
loop_fst = np.array(list(key_FST.values()))
print(all(loop_fst == FST_val))  # True
print(len(key_FST) == len(FST_val))  # True

### *other pair*

##### c. population 3

In [None]:
# population - YRI: the same slimming steps that were applied to CEU
yri_entry_fields = list(YRI_var.entry)
yri_col_fields = list(YRI_var.col)[1:]    # every column field after the first
yri_row_fields = list(YRI_var.row)[2:-1]  # row fields between the first two and the last
YRI_interm = YRI_var.drop(*yri_entry_fields, *yri_col_fields, *yri_row_fields)

# keep the row keys (locus and alleles) plus 'AF' & 'AN' from 'variant_qc'
yri_qc = YRI_interm['variant_qc']
YRI_interm2 = YRI_interm.select_rows(yri_qc['AF'], yri_qc['AN'])

# keep only the second entry of AF
YRI_interm3 = YRI_interm2.transmute_rows(AF=YRI_interm2.AF[1])

# tag rows with the population name (second population of the first pair)
yri_label = pair_com[0][1]
YRI_final = YRI_interm3.annotate_rows(pop=yri_label)
YRI_final.rows().show(5)

### FST

In [None]:
# Inputs for the YRI FST run.
# NOTE(review): pop2 = CEU_YRI_final is only defined by commented-out code above (NameError on a
# fresh kernel). This also compares YRI with the pooled CEU+YRI sample, not with CEU alone.

# assign populations to formula variables 
pop1 = YRI_final
pop2 = CEU_YRI_final

# number of alleles 
n1 = np.array(pop1.AN.collect())
n2 = np.array(pop2.AN.collect())

# allele frequencies 
FREQpop1 = np.array(pop1.AF.collect()) 
FREQpop2 = np.array(pop2.AF.collect())  

# locus + alleles = keys - needed for reference purposes - these values are uniform across all populations 
locus = np.array(hl.str(pop1.locus).collect())
alleles = np.array(hl.str(pop1.alleles).collect())
key = np.array([i + ' ' + j for i, j in zip(locus, alleles)])

In [None]:
s = 2  # number of populations - pairwise FST always compares exactly two
key_FST_YRI = {}  # "locus alleles" -> FST
for idx, variant_key in enumerate(key):
    na, nb = n1[idx], n2[idx]
    pa, pb = FREQpop1[idx], FREQpop2[idx]
    FREQ = ((na*pa) + (nb*pb)) / (na+nb)

    # only include variants whose average frequency is strictly between 0 and 1
    if not ((FREQ>0) & (FREQ<1)):
        continue

    ## average sample size that incorporates variance
    nc = ((1/(s-1)) * (na+nb)) - ((np.square(na) + np.square(nb))/(na+nb))
    msa = (1/(s-1))*((na*(np.square(pa-FREQ)))+(nb*(np.square(pb-FREQ))))
    msw = (1/((na-1)+(nb-1))) * ((na*(pa*(1-pa))) + (nb*(pb*(1-pb))))
    numer = msa-msw
    denom = msa + ((nc-1)*msw)
    FST = numer/denom

    key_FST_YRI[variant_key] = FST

### *three popn pairs*

In [None]:
## Prototype with the first three population pairs (['CEU', 'YRI'], ['CEU', 'LWK'], ['CEU', 'ESN'])
## before running all pairs.
# Fixes relative to the earlier draft:
#  - it filtered `mt_var_pru_filt`, which only exists in commented-out code; the FST dataset here is mt_for_fst
#  - a failed sample-count check silently skipped the pair; it now raises
example_pairs = pair_com[0:3]
fst_population = mt_for_fst['hgdp_tgp_meta']['Population']

ex_dict = {}  # "POP1-POP2" -> [variant_qc mt of pop1, of pop2, of pop1+pop2]
for pairs in example_pairs:
    l = [
        mt_for_fst.filter_cols(fst_population == pairs[0]),  # first population
        mt_for_fst.filter_cols(fst_population == pairs[1]),  # second population
        mt_for_fst.filter_cols((fst_population == pairs[0]) | (fst_population == pairs[1])),  # first + second = total
    ]

    # sanity check: the two single-population subsets must add up to the combined subset
    # (count_cols() is enough here - the row count is not needed)
    n_pop1, n_pop2, n_total = (subset.count_cols() for subset in l)
    if n_pop1 + n_pop2 != n_total:
        raise ValueError(f"sample counts for {pairs} do not add up: {n_pop1} + {n_pop2} != {n_total}")

    # per-variant QC for the first population, the second population and both together
    v = [hl.variant_qc(subset) for subset in l]
    ex_dict["-".join(pairs)] = v

In [None]:
# three mt subsets per comparison pair - set up as a dictionary 
# three variant_qc mts per comparison pair ([pop1, pop2, pop1+pop2]), keyed by "POP1-POP2"
ex_dict

In [None]:
# population - YRI, taken from the prototype dict: same steps as for CEU.
# ex_dict['CEU-YRI'] holds [CEU, YRI, CEU+YRI], so YRI is index 1; index 0 is CEU (see the next
# cell, which compares ex_dict['CEU-YRI'][0] with CEU_var).
# The earlier draft had three problems:
#  - it used index 0
#  - it labelled rows with the leaked loop variable `pairs`
#  - it contained an unclosed `drop(` call, which made the cell a SyntaxError
yri_from_dict = ex_dict['CEU-YRI'][1]

YRI_interm = yri_from_dict.drop(*list(yri_from_dict.entry), *list(yri_from_dict.col)[1:], *list(yri_from_dict.row)[2:-1])

# only select the row field keys (locus and alleles) and row fields 'AF' & 'AN' under 'variant_qc'
YRI_interm2 = YRI_interm.select_rows(YRI_interm['variant_qc']['AF'], YRI_interm['variant_qc']['AN'])

# only include the second entry of the array from the row field 'AF'
YRI_interm3 = YRI_interm2.transmute_rows(AF=YRI_interm2.AF[1])

# add a row field with the population name to keep track of which mt it came from
YRI_final = YRI_interm3.annotate_rows(pop='YRI')
YRI_final.rows().show(5)

In [None]:
# CEU entry of the prototype dict - should match CEU_var['variant_qc'].show(5)
ceu_from_dict = ex_dict['CEU-YRI'][0]
ceu_from_dict['variant_qc'].show(5)

In [None]:
# trying things out for the main function (scratch)

# toy version of a dict of lists keyed by pair name
a = ['CEU-YRI','CEU-LWK', 'CEU-ESN']
b = [0,1,2]
dc = {}
for i in a:
    li = []
    for j in b:
        li.append(str(j) + i)
    dc[i] = li 
    
#########################################

# NOTE(review): `v` here is whatever the last iteration of the ex_dict loop left behind
for i in range(len(v)-1):
    print(i)

#########################################

# merging several {variant key: FST} dicts into one {variant key: [FST, ...]} dict
from collections import defaultdict

dd = defaultdict(list)

for d in (key_FST, key_FST_YRI):
    print(d)
    #for key, value in d.items():
        #dd[key].append(value)

In [None]:
# Prototype of the FST run over the example pairs in ex_dict.
# For each pair:
#  - the three variant_qc mts [pop1, pop2, pop1+pop2] are slimmed down to AF/AN
#  - FST is computed per variant for (pop1 vs total) and (pop2 vs total)
#  - the two results are merged per variant key into final_dic[pair]
# NOTE(review): each population is compared with the pooled pop1+pop2 sample, which contains that
# population's own samples. A pairwise W&C FST between pop1 and pop2 would use (n1, FREQpop1) and
# (n2, FREQpop2) directly - confirm the intended comparison.
final_dic = {}
for pair in ex_dict.keys(): # for each population pair 
    u = [] # list to hold updated mts  
    for i in range(len(ex_dict[pair])): # for each population (each mt)
        # each of pop1, pop2 and total in turn
        # drop certain fields and only keep the ones we need 
        interm = ex_dict[pair][i].drop(*list(ex_dict[pair][i].entry), *list(ex_dict[pair][i].col)[1:], *list(ex_dict[pair][i].row)[2:-1])
        interm2 = interm.select_rows(interm['variant_qc']['AF'], interm['variant_qc']['AN'])  
        interm3 = interm2.transmute_rows(AF = interm2.AF[1])
        #final = interm3.annotate_rows(pop = pair) # keep track of which mt it came from
        u.append(interm3) # add updated mt to list 
    
    # variables for FST run 

    # assign populations to formula variables 
    pop1 = u[0]
    pop2 = u[1]
    total = u[2]
        
    # number of alleles 
    n1 = np.array(pop1.AN.collect())
    n2 = np.array(pop2.AN.collect())
    total_n = np.array(total.AN.collect())

    # allele frequencies 
    FREQpop1 = np.array(pop1.AF.collect()) 
    FREQpop2 = np.array(pop2.AF.collect())
    total_FREQ = np.array(total.AF.collect()) 
    
    # locus + alleles = keys - needed for reference purposes during FST calculations - these values are uniform across all populations 
    locus = np.array(hl.str(pop1.locus).collect())
    alleles = np.array(hl.str(pop1.alleles).collect())
    key = np.array([i + ' ' + j for i, j in zip(locus, alleles)])
    
    s=2   # s is the number of populations - since we are calculating pair-wise FSTs, this is always 2 
    
    # FST pop1 and total popn
    key_pop1_total = {}
    for i in range(len(key)):
        FREQ = ((n1[i]*FREQpop1[i]) + (total_n[i]*total_FREQ[i])) / (n1[i]+total_n[i])

        if (FREQ>0) & (FREQ<1): # only include ave freq between 0 and 1

        ## average sample size that incorporates variance
            nc = ((1/(s-1)) * (n1[i]+total_n[i])) - ((np.square(n1[i]) + np.square(total_n[i]))/(n1[i]+total_n[i]))

            msa= (1/(s-1))*((n1[i]*(np.square(FREQpop1[i]-FREQ)))+(total_n[i]*(np.square(total_FREQ[i]-FREQ))))

            msw = (1/((n1[i]-1)+(total_n[i]-1))) * ((n1[i]*(FREQpop1[i]*(1-FREQpop1[i]))) + (total_n[i]*(total_FREQ[i]*(1-total_FREQ[i]))))

            numer = msa-msw

            denom = msa + ((nc-1)*msw)

            FST = numer/denom

            key_pop1_total[key[i]] = FST
            
    # FST pop2 and total popn
    key_pop2_total = {}
    for i in range(len(key)):
        FREQ = ((n2[i]*FREQpop2[i]) + (total_n[i]*total_FREQ[i])) / (n2[i]+total_n[i])

        if (FREQ>0) & (FREQ<1): # only include ave freq between 0 and 1

        ## average sample size that incorporates variance
            nc = ((1/(s-1)) * (n2[i]+total_n[i])) - ((np.square(n2[i]) + np.square(total_n[i]))/(n2[i]+total_n[i]))

            msa= (1/(s-1))*((n2[i]*(np.square(FREQpop2[i]-FREQ)))+(total_n[i]*(np.square(total_FREQ[i]-FREQ))))

            msw = (1/((n2[i]-1)+(total_n[i]-1))) * ((n2[i]*(FREQpop2[i]*(1-FREQpop2[i]))) + (total_n[i]*(total_FREQ[i]*(1-total_FREQ[i]))))

            numer = msa-msw

            denom = msa + ((nc-1)*msw)

            FST = numer/denom

            key_pop2_total[key[i]] = FST
    
    # Merge the two FST results: variant key -> [FST(pop1, total), FST(pop2, total)].
    # A variant that passed the frequency filter in only one comparison gets a one-element list.
    from collections import defaultdict

    dd = defaultdict(list)

    for d in (key_pop1_total, key_pop2_total):
        # NOTE: this rebinds `key` (the variant-key array) to a string. The array is rebuilt
        # before the next pair uses it.
        for key, value in d.items():
            dd[key].append(value)
    
    final_dic[pair] = dd

In [None]:
# Tabulate the merged FST results: one column per population pair, one row per variant key.
# Each cell holds the list of FST values for that variant.
import pandas as pd

df = pd.DataFrame.from_dict(final_dic)

len(final_dic['CEU-YRI'])  # 246984

In [None]:
# Run Hail's variant_qc on each population of every pair and on the two combined.
# Results are stored per pair name (e.g. 'CEU-YRI') as
#   [pop1 variant_qc mt, pop2 variant_qc mt, combined variant_qc mt];
# these per-population allele frequencies are the inputs to the pairwise FST step.
# Stored in `pair_variant_qc` rather than `dict` so the built-in `dict` is not shadowed
# (shadowing it breaks every later `dict(...)` call in the notebook).
import warnings

pair_variant_qc = {}
pop_label = mt_var_pru_filt['hgdp_tgp_meta']['Population']
for pop_pair in pair_com:
    pair_name = "-".join(pop_pair)
    pop1_mt = mt_var_pru_filt.filter_cols(pop_label == pop_pair[0])  # first population
    pop2_mt = mt_var_pru_filt.filter_cols(pop_label == pop_pair[1])  # second population
    both_mt = mt_var_pru_filt.filter_cols(
        (pop_label == pop_pair[0]) | (pop_label == pop_pair[1])
    )  # first + second = total population

    # sanity check: the two single-population subsets must add up to the combined subset.
    # count_cols() avoids the (expensive) row count that count() also performs.
    n_pop1, n_pop2, n_both = pop1_mt.count_cols(), pop2_mt.count_cols(), both_mt.count_cols()
    if n_pop1 + n_pop2 != n_both:
        # previously such pairs were dropped silently
        warnings.warn(
            f'{pair_name}: {n_pop1} + {n_pop2} samples != {n_both} in the combined subset; pair skipped'
        )
        continue

    pair_variant_qc[pair_name] = [
        hl.variant_qc(pop1_mt),  # first population
        hl.variant_qc(pop2_mt),  # second population
        hl.variant_qc(both_mt),  # both/total population
    ]

In [None]:
# Preview the variant_qc struct of the first population in the first stored pair.
pair_names = list(ex_dict)
ex_dict[pair_names[0]][0]['variant_qc'].show(5)

In [None]:
# Preview the alternate-allele frequency (variant_qc AF[1]) of the second entry of
# every pair (presumably the second population - confirm ex_dict's layout).
# `.show()` renders the table itself and returns None, so it is not wrapped in print()
# (that printed a stray "None" after every table).
for pair_name in list(ex_dict):
    ex_dict[pair_name][1]['variant_qc']['AF'][1].show(5)

In [None]:
# Alternate-allele frequency (variant_qc AF[1]) of the first population in the first pair.
# NOTE(review): the name assumes that population is CEU - confirm ex_dict's ordering.
first_pair_qc = ex_dict[list(ex_dict)[0]]
CEU_af_freq = first_pair_qc[0]['variant_qc']['AF'][1]

In [None]:
# toy matrix table (0 rows, 6 columns) for trying out Hail expressions
play_mt = hl.utils.range_matrix_table(n_rows=0, n_cols=6)

## Scratch code below (exploratory; not part of the final analysis)

In [None]:
# Scratch code (disabled): for the case where alleles are split into separate A1/A2
# fields and may be oriented differently in the two populations.

# remove indels - only include single-letter variants for each allele in both populations
# this is because the FST formula is set up for single-letter (SNP) alleles
#pop1 = CEU_final.filter_rows((CEU_final.A1.length() == 1) & (CEU_final.A2.length() == 1))
#pop2 = CEU_YRI_final.filter_rows((CEU_YRI_final.A1.length() == 1) & (CEU_YRI_final.A2.length() == 1))


# sanity check
#A1 = pop1.A1.collect()
#A1 = list(set(A1))  # OR, to keep first-seen order:
### from collections import OrderedDict
### A1 = list(OrderedDict.fromkeys(A1))

#print(A1)
#len(A1) == 4

# total # of SNPs at the beginning - 255666
# unique A1 allele strings before removing indels - 2712
# total # of SNPs after removing indels - 221017 (34649 were indels for A1, A2 or both)
# unique A1 allele strings after removing indels - 4 ['C', 'A', 'T', 'G'] - which is what we expect



## *use the same reference allele - A2 is the minor allele here*

# get the minor alleles from both populations
#pop1_A2 = pop1.A2.collect()
#pop2_A2 = pop2.A2.collect()


# find positions where the minor alleles differ between the two populations
#import numpy as np
#switch1 = (np.array(pop1_A2) != np.array(pop2_A2))
# use .any(), not .all(): False from .any() means no variant has a mismatch, whereas
# .all() is only True when *every* variant mismatches. (An earlier run printed False
# for switch1.all(), which does not by itself show that all variants match.)
#print(switch1.any())

# sanity check
#print(len(pop1_A2) == len(pop2_A2) == len(switch1)) # True


### *if there is a variant mismatch among the minor alleles of the two populations*
# if a minor allele did not match between the two populations, adjust pop2's allele frequency (AF) accordingly
#new_frq = pop2.AF.collect()
#new_frq = np.array(new_frq) # convert to numpy array for the next step

# explanation (with an example) for what this does is right below it
#new_frq[switch1] = 1-(new_frq[switch1])
# Example: for pop_1, A1 and A2 are 'T' and 'C' with AF of 0.25
# and for pop_2, A1 and A2 are 'C' and 'T' with AF of 0.25
# since the alleles are not oriented the same way (different reference allele) in this case,
# we subtract the AF of pop_2 from 1 to get the correct allele frequency:
# the AF of pop_2 with A1 and A2 oriented like pop_1 ('T' and 'C') is 1-0.25 = 0.75 (which is the correct AF)

# to convert the array back to a list
#pop2_frq = new_frq.tolist()


# scratch checks below: `p` and `d` are only defined in commented-out code, so the
# lines that use them are commented out too (they raised NameError on a fresh kernel)
#pop2.rows().show(5)

# hl.str() is required to compare a locus expression with a string inside Hail;
# Python's str() would only stringify the expression object itself
#p = pop2.filter_rows(hl.str(pop2.locus) == 'chr10:38960343')
#p.row.show()


# for i in locus:
#     if i == 'chr1:94607079':
#         print("True")

# `d` is the duplicated-locus list built (commented out) in a later cell
#sum(num == dup for num, dup in zip(locus, d))

In [None]:
# Scratch check (disabled): find the entries of `key` that occur more than once,
# print how many there are, then print them. Uncomment to run.
#import collections
#dup = [item for item, count in collections.Counter(key).items() if count > 1]
#print('Num of duplicate loci: ' + str(len(dup))) 
#print(dup)

In [None]:
# Look up the FST inputs (per-population frequency and sample size) by variant key.
# NOTE(review): assumes `key`, FREQpop1/FREQpop2 and n1/n2 are parallel sequences
# built in an earlier cell.
def _index_by_variant_key(values):
    """Return {key[i]: values[i]} for every index i of the global `key` sequence."""
    return {key[i]: values[i] for i in range(len(key))}


key_freq1 = _index_by_variant_key(FREQpop1)
key_freq2 = _index_by_variant_key(FREQpop2)

key_n1 = _index_by_variant_key(n1)
key_n2 = _index_by_variant_key(n2)

In [None]:
# Where are the duplicated loci located? Build {locus: [positions]} for every locus
# that occurs more than once in `locus`.
# Notes from an earlier run: `locus` had 220945 entries, 72 of them duplicated.
from collections import defaultdict

positions_by_locus = defaultdict(list)
for idx, loc in enumerate(locus):
    positions_by_locus[loc].append(idx)

D = {loc: idxs for loc, idxs in positions_by_locus.items() if len(idxs) > 1}

# inspect one specific entry (index 6202) - presumably one of the duplicate positions; TODO confirm
locus[6202]

In [None]:
# Loci whose average allele frequency fell outside the open interval (0, 1), i.e. the
# ones rejected by the `(FREQ>0) & (FREQ<1)` filter in the FST loop.
# NOTE(review): assumes INCLUDE is a boolean/0-1 mask aligned with `locus` and FREQ is
# an array of average frequencies - confirm against the cell that builds them.
# The previous `INCLUDE=='FALSE'` / `FREQ==''` compared numbers to strings and so
# never selected anything.
import numpy as np

include_mask = np.asarray(INCLUDE, dtype=bool)
bad_locus = np.asarray(locus)[~include_mask]

# number of loci excluded (mask entries that are False)
print(np.count_nonzero(~include_mask))

# exact complement of the inclusion test (also counts NaN frequencies as excluded);
# the old `&` of two conditions could never be true for the same value
freq_arr = np.asarray(FREQ, dtype=float)
DONT_INCLUDE = ~((freq_arr > 0) & (freq_arr < 1))
np.count_nonzero(DONT_INCLUDE)

In [None]:
# Export the preimp_qc module's QC'd output (qced.mt) as a block-gzipped (.bgz) VCF.
# NOTE(review): this cell works on the Nepal PTSD dataset, not HGDP+1kG, and rebinds `mt`.
import hail as hl

qced_mt_path = 'gs://nepal-geno/GWASpy/Preimp_QC/Nepal_PTSD_GSA_Updated_May2021_qced.mt'
qced_vcf_path = 'gs://nepal-geno/Nepal_PTSD_GSA_Updated_May2021_qced.vcf.bgz'

mt = hl.read_matrix_table(qced_mt_path)
hl.export_vcf(mt, qced_vcf_path)

In [None]:
# Samples present in `mt` (before filtering) but absent from `mt_filt` (after filtering).
# Uses set difference rather than symmetric difference (^), which would also report
# samples that exist only in `mt_filt`.
# NOTE(review): the previous cell rebinds `mt` to the Nepal dataset - confirm which
# matrix table is meant as the "before" set.
s_pre = mt.s.collect()
s_post = mt_filt.s.collect()
removed_samples = set(s_pre) - set(s_post)