In [1]:
#!/usr/bin/env python
#
# Moritz Blumer | 2021-12-07 (2023-03-15)
#
# Conduct sliding window PCA with scikit-allel 



# produces identical results to previous version (except window boundaries)

#### ADD PC SELECTION, AND REMOVE PC2 FUNCTIONALITY (--> to be run with whatever PC)

#### IMPROVE READING EXISTING OUTPUT (SAVE UNANNOTATED?)

#### NUMBA?

#### FIX PCT INCLUDED SITES SCALE; ROUND PC % EXPLAINED VALUES

# To add PC2 output, uncomment lines tagged with '# uncomment for PC 2'

# add check whether input samples are unique (if not len(set(input_samples)) == ...)

# change documentation: 'primary_id' --> 'id' (maybe change to 'sample_id')

# PLOTTING SHOULDN'T START AT 0 BUT WINDOW MID!




# Minimum number of variants a window needs to be analysed (change the value here if needed);
# windows with fewer variants get NA output instead of PCA results.
min_var_per_window: int = 50

# import packages

import allel
import gzip
import sys
import numpy as np
import pandas as pd
#import numba as nb




In [2]:
# set arguments
variant_file_path = '../test_dataset/input/genotype_file.tsv.gz'
#variant_file_path = '../test_dataset/input/sample.vcf'
metadata_path = '../test_dataset/input/metadata.tsv'
var_threshold = 9 # make set inbsteat of option
mean_threshold = 3
chrom = 'chr1'
pc=1
skip_monomorphic = True
output_prefix = 'test_out'
output_prefix = output_prefix.lower()
color_taxon = 'inversion_state'
w_size = 1000000
w_step = 10000
region = '10000000-20000000'
start = int(region.split('-')[0])
stop = int(region.split('-')[1])

In [4]:
## GET RID OF ORDER BY GT FILE, OR NOT? --> TEST IF CURRENT VERSION CAUSES PROBLEMS

# read metadata

def read_metadata(variant_file_path, metadata_path, taxon=None, group=None):
    '''
    Read in metadata, optionally filter it by taxon, and drop samples absent from the variant file.

    variant_file_path:  genotype file (.tsv/.tsv.gz; sample names are the header fields from the
                        3rd column on) or VCF (.vcf/.vcf.gz; sample names are the '#CHROM' header
                        fields from the 10th column on); '.gz' files are read with gzip
    metadata_path:      tab-separated metadata file; its first column is renamed to 'id' and must
                        hold the sample ids (the only required column, ids must be unique)
    taxon, group:       optional filter, applied only if both are set: keep rows whose value in
                        column `taxon` is one of the comma-separated values in `group`

    Returns the filtered metadata DataFrame. Rows keep the metadata-file order; they are NOT
    re-sorted to the variant-file sample order (see the commented-out code below).
    '''

    # fetch sample names from the variant file header (set for fast membership tests)
    variant_samples = set(_read_sample_names(variant_file_path))

    # read in metadata
    metadata_df = pd.read_csv(metadata_path, sep='\t')

    # re-name first column to 'id'; assign a fresh column list instead of writing into
    # columns.values, because pandas Index objects are meant to be immutable
    metadata_df.columns = ['id'] + list(metadata_df.columns[1:])

    # subset input samples to match taxon group specification if specified
    if taxon and group:
        metadata_df = metadata_df.loc[metadata_df[taxon].isin(group.split(','))]

    # remove individuals that are not in the variant file (one vectorized filter
    # instead of one drop() call per missing id)
    metadata_df = metadata_df.loc[metadata_df['id'].isin(variant_samples)]

    # # get index of samples kept after filtering
    # sample_idx_lst = sorted([samples_lst.index(x) for x in list(metadata_df['id'])])
    # keep_id_lst = [samples_lst[x] for x in sample_idx_lst]

    # # sort metadata by VCF sample order
    # metadata_df['id'] = pd.Categorical(metadata_df['id'], categories = keep_id_lst, ordered = True)
    # metadata_df.sort_values('id', inplace=True)

    return metadata_df


def _read_sample_names(variant_file_path):
    '''
    Return the list of sample names from the header of a genotype TSV or a VCF (optionally gzipped).

    Raises ValueError if a VCF has no '#CHROM' header line.
    '''
    read_func = gzip.open if variant_file_path.endswith('.gz') else open
    is_vcf = variant_file_path.endswith(('.vcf', '.vcf.gz'))
    with read_func(variant_file_path, 'rt') as variant_file:
        if not is_vcf:
            # genotype file: the first line is the header; its first two columns are not samples
            return variant_file.readline().strip().split('\t')[2:]
        # VCF: skip the '##' meta lines; samples follow the 9 fixed columns
        # (CHROM POS ID REF ALT QUAL FILTER INFO FORMAT) of the '#CHROM' line
        for line in variant_file:
            if line.startswith('#CHROM'):
                return line.strip().split('\t')[9:]
    raise ValueError('No #CHROM header line found in VCF file: ' + variant_file_path)

# keep only the samples of species_1 and species_2 (filter column: 'species')
metadata_df = read_metadata(
    variant_file_path,
    metadata_path,
    taxon='species',
    group='species_1,species_2',
)


def pca(win, w_start, w_size):
    '''
    Conduct a PCA on one window and append the result to the global result lists.

    If the window holds fewer than min_var_per_window variants, dummy output (one None per
    sample, None for the explained variance) is recorded instead, so that every window
    still contributes a column.

    win:      list of variant rows, each [pos, genotype_sample_1, genotype_sample_2, ...];
              the position column is dropped here. Presumably genotypes are alt-allele
              counts (ploidy=2 is passed to allel.pca) — TODO confirm against window_parser
    w_start:  first position of the window (bp)
    w_size:   window size (bp)

    Side effects: appends to the notebook-level lists w_mid_lst, w_pca_lst and w_stats_lst
    (initialized in windowed_pca()); reads the globals pc and min_var_per_window, and, for
    an empty window only, metadata_df. Returns None.
    '''

    # trim off pos info
    win = [x[1:] for x in win]

    # get window mid for X value
    w_mid = int(w_start + w_size/2-1)

    # count variants
    n_variants = len(win)

    # analyse the window if it reaches the minimum number of variants
    if n_variants >= min_var_per_window:
        w_gt_arr = np.array(win, dtype=np.int8)
        coords, model = allel.pca(w_gt_arr, n_components=2, copy=True, scaler='patterson', ploidy=2)
        pct_explained = model.explained_variance_ratio_ * 100
        out = [coords[:, 0], coords[:, 1], pct_explained[0], pct_explained[1], n_variants]

    # else create empty output
    else:
        print('[INFO] Skipped window ' + str(w_start) + '-' + str(w_start + w_size-1) + ' with ' + str(n_variants) + ' variants (threshold is ' + str(min_var_per_window) + ' variants per window)', file=sys.stderr, flush=True)
        # an empty window has no row to count samples from; fall back to the notebook-level
        # metadata_df (the frame whose ids are handed to the window parser)
        n_samples = len(win[0]) if win else len(metadata_df)
        empty_array = [None] * n_samples
        out = [empty_array, empty_array, None, None, n_variants]

    # append output; the conditional must sit INSIDE append(), otherwise nothing is
    # appended when pc != 1
    w_mid_lst.append(w_mid)
    w_pca_lst.append(out[0] if pc == 1 else out[1])
    w_stats_lst.append([out[2], out[3], out[4]])


# frame for window-by-window analysis
def windowed_pca(variant_file_path, chrom, start, stop, metadata_df, pca, w_size, w_step, skip_monomorphic=True):
    '''
    Run a PCA in sliding windows along one chromosome region and collect the results.

    variant_file_path:  .vcf/.vcf.gz (parsed with window_parser.win_vcf) or .tsv/.tsv.gz
                        genotype file (parsed with window_parser.win_gt_file)
    chrom, start, stop: region to analyse, forwarded to the window parser
    metadata_df:        metadata with an 'id' column; the ids are handed to the parser
                        (presumably to select sample columns) and label the output rows
    pca:                per-window callback handed to the parser; it must append to the
                        global lists w_mid_lst, w_pca_lst and w_stats_lst (re-initialized here)
    w_size, w_step:     window size and step size in bp
    skip_monomorphic:   forwarded to the window parser

    Returns (w_pca_df, w_stats_df):
        w_pca_df    one row per sample id, one column per window mid
        w_stats_df  one row per window mid: % variance explained by PC 1 and PC 2, number of
                    variants, and pct_included_sites (variants per window bp, in percent)

    Raises ValueError for an unsupported variant file extension.
    '''

    global w_mid_lst, w_pca_lst, w_stats_lst

    # (re-)initialize the result containers that the pca callback appends to
    w_mid_lst = []
    w_pca_lst = []
    w_stats_lst = []

    # pick the window parser matching the input format
    if variant_file_path.endswith(('.vcf', '.vcf.gz')):
        from window_parser import win_vcf as parse_windows
    elif variant_file_path.endswith(('.tsv', '.tsv.gz')):
        from window_parser import win_gt_file as parse_windows
    else:
        raise ValueError('Unsupported variant file (expected .vcf, .vcf.gz, .tsv or .tsv.gz): ' + variant_file_path)

    # conduct windowed PCA: the parser calls pca() once per window
    parse_windows(variant_file_path, chrom, start, stop, metadata_df['id'], w_size, w_step, pca, skip_monomorphic=skip_monomorphic)

    # compile output dataframe for windowed PCA
    w_pca_df = pd.DataFrame(np.transpose(w_pca_lst), index=list(metadata_df['id']), columns=w_mid_lst)
    w_pca_df.index.names = ['id']

    # compile output dataframe for supplementary info (% variance explained, # sites per window)
    w_stats_df = pd.DataFrame(w_stats_lst, index=w_mid_lst, columns=['pct_explained_pc_1', 'pct_explained_pc_2', 'num_variants'])
    w_stats_df.index.names = ['window_mid']
    # percentage of the window's bp that carry an included variant (previously divided
    # by 100 instead of multiplying, which made the values 10^4 times too small)
    w_stats_df['pct_included_sites'] = w_stats_df['num_variants'] / w_size * 100

    return w_pca_df, w_stats_df

# run the windowed PCA over the configured region
w_pca_df, w_stats_df = windowed_pca(
    variant_file_path, chrom, start, stop, metadata_df, pca, w_size, w_step,
    skip_monomorphic=skip_monomorphic,
)


# polarize windowed PCA output
def polarize(w_pca_df, var_threshold=9, mean_threshold=3): ## IMPROVE GUIDESAMPLE SETTINGS

    '''
    Harmonize the orientation (sign) of all windows of a windowed PCA output.

    PC axes have an arbitrary sign, so a window can come out mirrored relative to its
    neighbour. A few guide samples vote, window by window, on whether the window has to
    be multiplied by -1 to match the (already corrected) previous window.

    w_pca_df:        DataFrame, one row per sample, one column per window; windows without
                     PCA output hold None/NaN
    var_threshold:   number of samples with the lowest variance of absolute PC values (over
                     windows without missing values) pre-selected as guide candidates.
                     Hack: set to False (or the string 'False', as passed on the command line)
                     to use an explicit list of guide samples given as mean_threshold
    mean_threshold:  number of candidates with the highest summed absolute PC value that become
                     guide samples; or, if var_threshold is False/'False', a comma-separated
                     string of sample ids (e.g. "cichlid7050764,cichlid7050776,cichlid7050768")

    Returns a new, polarized DataFrame; the input frame is left unchanged.
    '''

    # work on a copy so the caller's frame is not flipped column by column in place
    w_pca_df = w_pca_df.copy()

    # select guide samples
    if var_threshold is False or var_threshold == 'False':
        guide_samples = mean_threshold.split(',')
    else:
        # the var_threshold samples with the least variance, and from those the
        # mean_threshold samples with the highest absolute values across all windows
        low_var_samples = list(w_pca_df.dropna(axis=1).abs().var(axis=1).sort_values(ascending=True).index[0:var_threshold])
        candidates_df = w_pca_df.loc[low_var_samples]
        guide_samples = list(candidates_df.dropna(axis=1).abs().sum(axis=1).sort_values(ascending=False).index[0:mean_threshold])

    guide_samples_df = w_pca_df.loc[guide_samples]

    # for each guide sample, determine whether the window value is closer to the previous
    # (corrected) window value or to its negative. Votes: 1 --> switch, -1 --> keep,
    # 0 --> no vote (missing window, or the first window, which serves as the reference)
    votes_lst = []
    for _, guide_row in guide_samples_df.iterrows():
        values = list(guide_row)
        # a missing first window leaves nothing to compare with; use 0 as a neutral reference
        last_window = 0 if pd.isna(values[0]) else values[0]
        votes = [0]

        for window in values[1:]:
            # pd.isna covers both None and NaN (NaN would otherwise poison last_window)
            if pd.isna(window):
                votes.append(0)
            elif abs(window - last_window) > abs(window + last_window):
                votes.append(1)
                last_window = -window
            else:
                votes.append(-1)
                last_window = window

        votes_lst.append(votes)

    # sum up the votes of all guide samples per window
    switch_lst = list(np.array(votes_lst, dtype=int).sum(axis=0))

    # switch windows where the majority of guide samples voted to switch
    for window_mid, vote in zip(list(w_pca_df.columns), switch_lst):
        if vote > 0:
            w_pca_df[window_mid] = w_pca_df[window_mid] * -1

    # switch Y axis if largest absolute value is negative
    if abs(w_pca_df.to_numpy(na_value=0).min()) > abs(w_pca_df.to_numpy(na_value=0).max()):
        w_pca_df = w_pca_df * -1

    return w_pca_df

# polarize using the guide-sample thresholds from the argument cell (previously hard-coded
# to 9/3, so changing var_threshold/mean_threshold in the argument cell had no effect)
w_pca_df = polarize(w_pca_df, var_threshold=var_threshold, mean_threshold=mean_threshold)


# save output data before annotation
w_pca_df.to_csv(output_prefix + '.w_pc_' + str(pc) + '.tsv.gz', sep='\t', compression='gzip')
w_stats_df.to_csv(output_prefix + '.w_stats.tsv.gz', sep='\t', compression='gzip')

# pivot & annotate windowed pca output
def annotate(w_pca_df, metadata_df, pc):
    '''
    Attach metadata to the windowed PCA output and convert it to long format for plotting.

    w_pca_df:     DataFrame, one row per sample, one column per window mid
    metadata_df:  metadata with one row per sample, in the same row order as w_pca_df
                  (rows are matched by position, not by id)
    pc:           name of the value column in the output, e.g. 'pc_1'

    Returns a long-format DataFrame with all metadata columns plus 'window_mid' and `pc`;
    missing values are replaced by the string 'NA'. The input frames are not modified.
    '''

    # work on a copy: adding metadata columns to the caller's frame would leave non-numeric
    # columns in it and break re-running polarize()/annotate() on the same variable
    annotated_df = w_pca_df.copy()

    # annotate with metadata
    for column_name in metadata_df.columns:
        annotated_df[column_name] = list(metadata_df[column_name])

    # replace numpy NaN with 'NA' for plotting (hover_data display)
    annotated_df = annotated_df.replace(np.nan, 'NA')

    # convert to long format for plotting
    return pd.melt(annotated_df, id_vars=metadata_df.columns, var_name='window_mid', value_name=pc)

pc_anno_df = annotate(w_pca_df, metadata_df, 'pc_' + str(pc))

# plot windowed PCA output & save
from utils import plot_w_pca
# pass the selected PC instead of a hard-coded 1 so the plot matches the 'pc_<pc>' value column
# NOTE(review): assumes the 2nd positional parameter of utils.plot_w_pca is the PC number — confirm in utils
w_pca_fig = plot_w_pca(pc_anno_df, pc, color_taxon, chrom, start, stop, w_size, w_step)
w_pca_fig.write_html(output_prefix + '.w_pc_' + str(pc) + '.html')
w_pca_fig.write_image(output_prefix + '.w_pc_' + str(pc) + '.pdf', engine='kaleido', scale=2.4)

# plot window stats & save
from utils import plot_w_stats
w_stats_fig = plot_w_stats(w_stats_df, chrom, start, stop, w_size, w_step)
w_stats_fig.write_html(output_prefix + '.w_stats' + '.html')
w_stats_fig.write_image(output_prefix + '.w_stats' + '.pdf', engine='kaleido', scale=2.4)


[INFO] Processed 1 of 901 windows
[INFO] Processed 2 of 901 windows
[INFO] Processed 3 of 901 windows
[INFO] Processed 4 of 901 windows
[INFO] Processed 5 of 901 windows
[INFO] Processed 6 of 901 windows
[INFO] Processed 7 of 901 windows
[INFO] Processed 8 of 901 windows
[INFO] Processed 9 of 901 windows
[INFO] Processed 10 of 901 windows
[INFO] Processed 11 of 901 windows
[INFO] Processed 12 of 901 windows
[INFO] Processed 13 of 901 windows
[INFO] Processed 14 of 901 windows
[INFO] Processed 15 of 901 windows
[INFO] Processed 16 of 901 windows
[INFO] Processed 17 of 901 windows
[INFO] Processed 18 of 901 windows
[INFO] Processed 19 of 901 windows
[INFO] Processed 20 of 901 windows
[INFO] Processed 21 of 901 windows
[INFO] Processed 22 of 901 windows
[INFO] Processed 23 of 901 windows
[INFO] Processed 24 of 901 windows
[INFO] Processed 25 of 901 windows
[INFO] Processed 26 of 901 windows
[INFO] Processed 27 of 901 windows
[INFO] Processed 28 of 901 windows
[INFO] Processed 29 of 901 wi

In [5]:
def parse_arguments():

    '''
    Parse the 12 positional command line arguments into module-level globals.

    Order: genotype file, metadata, output prefix, chromosome name, chromosome length,
    window size, window step, filter column name, filter column value, color column name,
    variance threshold, mean threshold.

    If the number of arguments is wrong, a usage message is printed to stderr and the
    program exits with status 1. 'None' for the filter column name/value disables the
    sample filter. If the variance threshold is 'False', it and the mean threshold stay
    strings (the mean threshold is then a comma-separated list of guide sample ids);
    otherwise both are converted to int. The output prefix is lower-cased.
    '''

    # declare all variables global

    global gt_file_path, metadata_path, output_prefix, chrom, chrom_len, w_size, w_step, taxon, group, color_taxon, var_threshold, mean_threshold


    # print help message if incorrect number of arguments was specified; this must happen
    # BEFORE unpacking, which would otherwise raise ValueError and never show the help

    if len(sys.argv) != 13:
        print('\nUsage:', file=sys.stderr)
        print('\tpython windowed_pca.py <genotype matrix> <metadata> <output prefix> <chromosome name> <chromosome length> <window size> <window step size> <filter column name> \\\n                         <filter column value> <color column name> <variance threshold> <mean threshold>\n', file=sys.stderr)
        print('\t\t<genotype matrix>\tstr\tpath to the genotype matrix file produced as described in the README', file=sys.stderr)
        print('\t\t<metadata>\t\tstr\tpath to the metadata file produced as described in the README', file=sys.stderr)
        print('\t\t<output prefix>\t\tstr\tprefix that will be used for all output files, can also be a directory to be created', file=sys.stderr)
        print('\t\t<chromosome name>\tstr\tname of the chromosome, e.g. "chr1"', file=sys.stderr)
        print('\t\t<chromosome length>\tint\tlength of the chromosome in bp, e.g. "32123123"', file=sys.stderr)
        print('\t\t<window size>\t\tint\tsize of the sliding window in bp, e.g. "1000000"', file=sys.stderr)
        print('\t\t<window step>\t\tint\tstep size of the sliding window in bp, e.g. "10000"', file=sys.stderr)
        print('\t\t<filter column name>\tstr\tset a metadata column name to be used to select individuals to be included in the analysis e.g. "genus" (see filter column value)', file=sys.stderr)
        print('\t\t<filter column value>\t\tstr\tselect a value to be filtered for in the defined filter column. Setting <filter column name> to "genus" and <filter column value> to "Homo" would include all individuals of the genus Homo in the output, and ignore all others. A comma-separated list of include values can be provided, to include for example a specific subset of genera ("Homo,Pan") ', file=sys.stderr)
        print('\t\t<color column name>\tstr\tselect a metadata column that will serve to partition included individuals into color groups in the output plots. If selecting e.g. "genus", all individuals from the same genus will have the same color in the output plots. If specifying a comma-separated list of column names (e.g. "genus,species"), two versions of each output plot will be produced, that differ only in the color scheme', file=sys.stderr)
        print('\t\t<variance threshold>\tint\trelevant to correct random switching along PC axes, see code for details, if unsure, use "9"', file=sys.stderr)
        print('\t\t<mean threshold>\tint\trelevant to correct random switching along PC axes, see code for details, if unsure, use "3"\n', file=sys.stderr)
        # non-zero exit status: a usage error is a failure for calling scripts/pipelines
        sys.exit(1)


    # fetch arguments

    _, gt_file_path, metadata_path, output_prefix, chrom, chrom_len, w_size, w_step, taxon, group, color_taxon, var_threshold, mean_threshold = sys.argv

    # allow to specify no input sample filter
    if taxon == 'None': taxon = None
    if group == 'None': group = None

    # convert str to int where necessary (if var_threshold == False, assumes list of specific guide samples as mean_threshold)

    if not var_threshold == 'False':
        chrom_len, w_size, w_step, var_threshold, mean_threshold = int(chrom_len), int(w_size), int(w_step), int(var_threshold), int(mean_threshold)

    else:
        chrom_len, w_size, w_step = int(chrom_len), int(w_size), int(w_step)

    output_prefix = output_prefix.lower()

############################

