In [123]:
#!/usr/bin/env python3
#
# Moritz Blumer, 2024

"""
WinPCA. A package for windowed PCA analysis.
"""

## IMPORT PACKAGES
import sys, os
#import pandas as pd


## IMPORT MODULES
from modules import config
from modules.cli import CLI
# from modules.error_handling import ErrorHandling ## add

In [130]:
## DELETE CELL

# Simulate command line arguments so the notebook can be driven like the CLI.
# shlex.split tokenizes the string the way a POSIX shell would, so repeated
# spaces do not produce empty arguments and quoted values stay in one piece
# (a plain str.split(' ') gets both of these wrong).
import shlex

# command_line = 'winpca.ipynb pca -n test --vcf test_dataset/input/sample.vcf.gz -s ind_1,ind_2,ind_3,ind_4,ind_5,ind_6 -r chr1:1-33000000 -w 1000000 -i 100000 -m 0.01 -p auto'
command_line = 'winpca.ipynb pca -n test --vcf test_dataset/input/sample.vcf.gz -s samples.tsv -p auto -r chr1:1-30000000 -w 1000000 -i 100000 -m 0.01'
# command_line = 'winpca.ipynb polarize -n test -p auto -c 2'
# command_line = 'winpca.ipynb chromplot -n test -r chr1:1-30000000 -m test_dataset/input/metadata.tsv -g inversion_state -c ancestral:eb4034,inverted:2f35a8,heterozygous:197d34'
#command_line = 'winpca.ipynb chromplot -h'

# overwrite the interpreter's argument vector and display it
sys.argv = shlex.split(command_line)
sys.argv


['winpca.ipynb',
 'pca',
 '-n',
 'test',
 '--vcf',
 'test_dataset/input/sample.vcf.gz',
 '-s',
 'samples.tsv',
 '-p',
 'auto',
 '-r',
 'chr1:1-30000000',
 '-w',
 '1000000',
 '-i',
 '100000',
 '-m',
 '0.01']

In [131]:
## PARSE COMMAND LINE ARGUMENTS

# set up the command line interface
cli = CLI()

# register the subparsers, one call per run mode, in this order
for add_subparser in (
    cli.pca,
    cli.pcangsd,
    cli.polarize,
    cli.flip,
    cli.chromplot,
    cli.genomeplot,
):
    add_subparser()

# parse arguments and fetch the resulting argument dictionary
cli.parse_args()
args_dct = cli.args_dct

# the selected run mode is stored under the 'winpca' key
mode = args_dct['winpca']
args_dct

Namespace(guide_samples=None, min_maf=0.01, polarize='auto', prefix='test', region='chr1:1-30000000', samples='samples.tsv', variant_file_path='test_dataset/input/sample.vcf.gz', w_size=1000000, w_step=100000, winpca='pca')


{'winpca': 'pca',
 'prefix': 'test',
 'variant_file_path': 'test_dataset/input/sample.vcf.gz',
 'region': 'chr1:1-30000000',
 'samples': 'samples.tsv',
 'w_size': 1000000,
 'w_step': 100000,
 'min_maf': 0.01,
 'polarize': 'auto',
 'guide_samples': None,
 'chrom': 'chr1',
 'start': 1,
 'end': 30000000,
 'sample_lst': ['ind_1', 'ind_2', 'ind_3', 'ind_5', 'ind_4', 'ind_6'],
 'guide_sample_lst': None,
 'skip_monomorphic': False,
 'min_var_per_w': 25,
 'n_prev_windows': 5,
 'pol_pc': 'both',
 'flip_pc': '1'}

In [132]:
# WINDOWED PCA FROM CALLED GENOTYPES

# WPCAData is needed in every mode, either to wrap a fresh run or to load
# previously written results for the given prefix
from modules.data import WPCAData

if mode == 'pca':

    # only required when a new windowed PCA is computed
    from modules.windowed_pca import GTWindowedPCA

    # instantiate windowed PCA for the requested samples and region
    w_pca = GTWindowedPCA(
        variant_file_path=args_dct['variant_file_path'],
        sample_lst=args_dct['sample_lst'],
        chrom=args_dct['chrom'],
        start=args_dct['start'],
        stop=args_dct['end'],
        w_size=args_dct['w_size'],
        w_step=args_dct['w_step'],
        skip_monomorphic=config.skip_monomorphic,
    )

    # run the windowed PCA
    w_pca.win_vcf_gt()

    # wrap the run output in a data object
    data = WPCAData(args_dct['prefix'], w_pca)


# WINDOWED PCA WITH PCANGSD (placeholder, not implemented yet)

# elif mode == 'pcangsd':

#     [...]

# EXISTING DATA (every mode other than 'pca'):

else:
    data = WPCAData(args_dct['prefix'])

[INFO] Processed 1 of 291 windows
[INFO] Processed 2 of 291 windows
[INFO] Processed 3 of 291 windows


[INFO] Processed 4 of 291 windows
[INFO] Processed 5 of 291 windows
[INFO] Processed 6 of 291 windows
[INFO] Processed 7 of 291 windows
[INFO] Processed 8 of 291 windows
[INFO] Processed 9 of 291 windows
[INFO] Processed 10 of 291 windows
[INFO] Processed 11 of 291 windows
[INFO] Processed 12 of 291 windows
[INFO] Processed 13 of 291 windows
[INFO] Processed 14 of 291 windows
[INFO] Processed 15 of 291 windows
[INFO] Processed 16 of 291 windows
[INFO] Processed 17 of 291 windows
[INFO] Processed 18 of 291 windows
[INFO] Processed 19 of 291 windows
[INFO] Processed 20 of 291 windows
[INFO] Processed 21 of 291 windows
[INFO] Processed 22 of 291 windows
[INFO] Processed 23 of 291 windows
[INFO] Processed 24 of 291 windows
[INFO] Processed 25 of 291 windows
[INFO] Processed 26 of 291 windows
[INFO] Processed 27 of 291 windows
[INFO] Processed 28 of 291 windows
[INFO] Processed 29 of 291 windows
[INFO] Processed 30 of 291 windows
[INFO] Processed 31 of 291 windows
[INFO] Processed 32 of 291

In [133]:
# POLARIZE

# runs for the 'pca', 'pcangsd' and 'polarize' modes unless polarization is
# explicitly skipped
if mode in ['pca', 'pcangsd', 'polarize'] and args_dct['polarize'] != 'skip':

    from modules.transform_data import Polarize
    polarize = Polarize()

    # each polarization mode maps to a Polarize method and to the args_dct key
    # holding the extra argument passed on to that method
    for pol_mode, pol_func, pol_arg_key in (
        ('auto', polarize.adaptive, 'n_prev_windows'),
        ('guide_samples', polarize.guide_samples, 'guide_sample_lst'),
    ):
        if args_dct['polarize'] != pol_mode:
            continue

        # 'both' expands to PC 1 and PC 2, anything else names a single PC
        if args_dct['pol_pc'] == 'both':
            pol_pc_lst = ['pc_1', 'pc_2']
        else:
            pol_pc_lst = ['pc_' + str(args_dct['pol_pc'])]

        for pol_pc_name in pol_pc_lst:
            data.modify_data(pol_pc_name, pol_func, args_dct[pol_arg_key])



In [51]:
# FLIP

if mode == 'flip':
    from modules.transform_data import Flip
    flip = Flip()

    # 'both' expands to PC 1 and PC 2, anything else names a single PC
    flip_pc_lst = (
        ['pc_1', 'pc_2'] if args_dct['flip_pc'] == 'both'
        else ['pc_' + str(args_dct['flip_pc'])]
    )
    for flip_pc_name in flip_pc_lst:
        data.modify_data(flip_pc_name, flip.flip_chrom)

In [52]:
# WRITE RESULTS

# Create the output directory if the prefix has a directory component.
# os.path.dirname replaces the manual split/join on '/', and exist_ok=True
# removes the check-then-create race of a separate os.path.exists() test.
out_dir = os.path.dirname(args_dct['prefix'])
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

# write results to files
data.to_files()

In [136]:
# IMPORT PACKAGES
import pandas as pd
import numpy as np

# IMPORT MODULES
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# NOTE(review): this re-polarizes pc_1 in place with 3 as the extra argument
# (the POLARIZE cell passes args_dct['n_prev_windows'] in this position) and
# relies on `polarize` from the POLARIZE cell, which is only defined when
# polarization actually ran — confirm this is intended before keeping it
data.modify_data(
    'pc_1', polarize.adaptive, 3
)

plot_var = 'pc_1'
data_df = getattr(data, plot_var)
metadata_path = 'test_dataset/input/metadata.tsv'

# read metadata (first column = sample ID, all values as strings) and exit
# with a non-zero status if there are non-unique IDs
metadata_df = pd.read_csv(metadata_path, sep='\t', index_col=0, dtype=str)
if not metadata_df.index.is_unique:
    print('\n[ERROR] The provided metadata file contains non-unique sample'
          ' IDs.',
          file=sys.stderr)
    sys.exit(1)

# if individuals are missing in the metadata file print error message & exit;
# the check is done on the IDs themselves so that a sample with an empty
# metadata field is not mistaken for a missing sample
missing_id_lst = [x for x in data_df.columns if x not in metadata_df.index]
if missing_id_lst:
    print('\n[ERROR] One or more sample IDs are missing in the provided'
          ' metadata file.',
          file=sys.stderr)
    sys.exit(1)

# then subset and reorder metadata_df to match data_df individuals
metadata_df = metadata_df.loc[list(data_df.columns)].copy()

# copy index to column 'id'
metadata_df['id'] = metadata_df.index

# transpose data_df (samples become rows) and name index
data_df = data_df.T
data_df.index.name = 'id'

# add metadata columns to data_df; the row order of both frames matches
# because metadata_df was reordered to the sample order above
for column_name in metadata_df.columns:
    data_df[column_name] = list(metadata_df[column_name])

# replace numpy NaN with 'NA' for plotting (hover_data display)
data_df = data_df.replace(np.nan, 'NA')

# convert to long format for plotting: one row per sample and position
data_df = pd.melt(
    data_df,
    id_vars=metadata_df.columns,
    var_name='pos',
    value_name=plot_var,
)


# display the long-format frame
data_df


Unnamed: 0,coverage,species,inversion_state,id,pos,pc_1
0,20X,species_1,ancestral,ind_1,500000,-4.762554
1,21X,species_1,inverted,ind_2,500000,-3.849530
2,20X,species_1,heterozygous,ind_3,500000,-0.406797
3,19X,species_1,inverted,ind_5,500000,3.919203
4,21X,species_1,ancestral,ind_4,500000,0.292388
...,...,...,...,...,...,...
1747,21X,species_1,inverted,ind_2,29600000,-9.815533
1748,20X,species_1,heterozygous,ind_3,29600000,-2.631565
1749,19X,species_1,inverted,ind_5,29600000,-10.570201
1750,21X,species_1,ancestral,ind_4,29600000,13.634994


In [137]:
stat_df = data.stat

# plot settings, hard-coded here instead of being taken from the command line
args_dct['hex_code_dct'] = {
    'ancestral': '#eb4034',
    'inverted': '#2f35a8',
    'heterozygous': '#197d34',
}
args_dct['color_by'] = 'inversion_state'

# variables to plot; defined here because nothing else in the notebook sets
# `plot_vars` (the cell failed with a NameError on a fresh kernel)
plot_vars = [plot_var]

# metadata column used for coloring; falls back to one color per sample
group = args_dct['color_by'] if args_dct['color_by'] else 'id'

# unique group names in order of first appearance (a set would make the
# default color assignment differ between runs)
group_lst = list(dict.fromkeys(metadata_df[group]))

# set plotting colors
if args_dct['hex_code_dct']:
    color_dct = args_dct['hex_code_dct']
    if not all(x in color_dct for x in group_lst):
        print('\n[ERROR] HEX codes missing for one or more groups.',
                file=sys.stderr)
        sys.exit(1)
else:
    import plotly.colors as pc
    def_col_lst = pc.DEFAULT_PLOTLY_COLORS

    # cycle through the default plotly colors if there are more groups
    color_dct = {
        group_name: def_col_lst[i % len(def_col_lst)]
        for i, group_name in enumerate(group_lst)
    }

if len(plot_vars) == 1:

    # two stacked panels sharing the x axis: variance explained on top
    # (row 1), per-sample PC values below (row 2)
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True, row_heights=[1, 6], vertical_spacing=0.0)

    # variance explained: one hover label per window
    hover_data = [
        f'<b>pos</b>: {idx}<br>'
        f'<b>var_explained</b>: {row[plot_var + "_ve"]}%<br>'
        for idx, row in stat_df.iterrows()
    ]

    fig.add_trace(
        go.Scatter(
            x=stat_df.index,
            y=stat_df[plot_var + '_ve'],
            name='variance explained',
            legendgroup='variance explained',
            mode='lines',
            text=hover_data,
            hoverinfo='text',
            line=dict(color='#4d61b0', width=1),
            fill='tozeroy',
            connectgaps=True,
        ),
    row=1, col=1)

    # one trace per sample, colored by group; only the first trace of each
    # group gets a legend entry so that groups are not listed repeatedly
    legend_group_set = set()
    for sample_id, id_df in data_df.groupby('id', sort=False):

        group_name = id_df[group].iloc[0]

        # hover label listing every column of the long-format row
        hover_data = [
            ''.join(f'<b>{col}</b>: {row[col]}<br>' for col in id_df.columns)
            for _, row in id_df.iterrows()
        ]

        # add the trace for this sample to the second subplot (row=2, col=1)
        fig.add_trace(
            go.Scatter(
                x=id_df['pos'],
                y=id_df[plot_var],
                text=hover_data,
                hoverinfo='text',
                name=group_name,
                legendgroup=group_name,
                showlegend=group_name not in legend_group_set,
                mode='lines',
                line=dict(color=color_dct[group_name]),
            ),
            row=2, col=1
        )
        legend_group_set.add(group_name)

    # adjust layout
    fig.update_layout(
        template='simple_white',
        font_family='Arial',
        font_color='black',
        xaxis=dict(ticks='outside', mirror=True, showline=True,),
        yaxis=dict(ticks='outside', mirror=True, showline=True,),
        #legend={'traceorder':'normal'}, 
        title={'xanchor': 'center', 'y': 0.9, 'x': 0.45},
        )

    # set line width (this also overrides the line color of the top panel)
    fig.update_traces(line=dict(width=0.5, color='darkgrey'), row=1, col=1)
    fig.update_traces(line=dict(width=0.5), row=2, col=1)

    # set x axis range
    fig.update_xaxes(range=[args_dct['start'], args_dct['end']], showticklabels=False, ticks='', row=1, col=1)
    fig.update_xaxes(range=[args_dct['start'], args_dct['end']], showline=True, linecolor='black', linewidth=1, row=2, col=1)

    # draw mirrored y axis lines on both panels
    fig.update_yaxes(showline=True, mirror=True, linecolor='black', linewidth=1, row=1, col=1, side='left')
    fig.update_yaxes(showline=True, mirror=True, linecolor='black', linewidth=1, row=2, col=1, side='left')

    # show the figure
    fig.show()


In [140]:
stat_df = data.stat

# plot settings, hard-coded here instead of being taken from the command line
args_dct['hex_code_dct'] = {
    'ancestral': '#eb4034',
    'inverted': '#2f35a8',
    'heterozygous': '#197d34',
}
args_dct['color_by'] = 'inversion_state'

# variables to plot; defined here because nothing else in the notebook sets
# `plot_vars` (the cell failed with a NameError on a fresh kernel)
plot_vars = [plot_var]

# metadata column used for coloring; falls back to one color per sample
group = args_dct['color_by'] if args_dct['color_by'] else 'id'

# unique group names in order of first appearance (a set would make trace
# order and default color assignment differ between runs)
group_lst = list(dict.fromkeys(metadata_df[group]))

# set plotting colors
if args_dct['hex_code_dct']:
    color_dct = args_dct['hex_code_dct']
    if not all(x in color_dct for x in group_lst):
        print('\n[ERROR] HEX codes missing for one or more groups.',
                file=sys.stderr)
        sys.exit(1)
else:
    import plotly.colors as pc
    def_col_lst = pc.DEFAULT_PLOTLY_COLORS

    # cycle through the default plotly colors if there are more groups
    color_dct = {
        group_name: def_col_lst[i % len(def_col_lst)]
        for i, group_name in enumerate(group_lst)
    }

if len(plot_vars) == 1:

    # two stacked panels sharing the x axis: variance explained on top
    # (row 1), per-sample PC values below (row 2)
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True, row_heights=[1, 6], vertical_spacing=0.0)

    # variance explained: one hover label per window
    hover_data = [
        f'<b>pos</b>: {idx}<br>'
        f'<b>var_explained</b>: {row[plot_var + "_ve"]}%<br>'
        for idx, row in stat_df.iterrows()
    ]

    fig.add_trace(
        go.Scatter(
            x=stat_df.index,
            y=stat_df[plot_var + '_ve'],
            name='variance explained',
            legendgroup='variance explained',
            mode='lines',
            text=hover_data,
            hoverinfo='text',
            line=dict(color='#4d61b0', width=1),
            fill='tozeroy',
            connectgaps=True,
        ),
    row=1, col=1)

    # one trace per group: the lines of all samples in a group are
    # concatenated, separated by None so that plotly breaks the line between
    # samples; this gives a single legend entry per group
    for group_name in group_lst:
        group_df = data_df[data_df[group] == group_name]

        x_vals = []
        y_vals = []
        hover_texts = []

        for sample_id, id_df in group_df.groupby('id', sort=False):
            x_vals.extend(id_df['pos'].tolist() + [None])  # None = line gap
            y_vals.extend(id_df[plot_var].tolist() + [None])  # None = line gap

            # hover label listing every column of the long-format row
            hover_data = [
                ''.join(
                    f'<b>{col}</b>: {row[col]}<br>' for col in id_df.columns
                )
                for _, row in id_df.iterrows()
            ]
            hover_texts.extend(hover_data + [None])

        # add the trace for this group to the second subplot (row=2, col=1);
        # name and legend group come from group_name itself rather than from
        # whichever sample frame the inner loop processed last
        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=y_vals,
                text=hover_texts,
                hoverinfo='text',
                name=group_name,
                legendgroup=group_name,
                mode='lines',
                line=dict(color=color_dct[group_name]),
            ),
            row=2, col=1
        )

    # adjust layout
    fig.update_layout(
        template='simple_white',
        font_family='Arial',
        font_color='black',
        xaxis=dict(ticks='outside', mirror=True, showline=True,),
        yaxis=dict(ticks='outside', mirror=True, showline=True,),
        #legend={'traceorder':'normal'}, 
        title={'xanchor': 'center', 'y': 0.9, 'x': 0.45},
        )

    # set line width (this also overrides the line color of the top panel)
    fig.update_traces(line=dict(width=.7, color='darkgrey'), row=1, col=1)
    fig.update_traces(line=dict(width=.7), row=2, col=1)

    # set x axis range
    fig.update_xaxes(range=[args_dct['start'], args_dct['end']], showticklabels=False, ticks='', row=1, col=1)
    fig.update_xaxes(range=[args_dct['start'], args_dct['end']], showline=True, linecolor='black', linewidth=1, row=2, col=1)

    # draw mirrored y axis lines on both panels
    fig.update_yaxes(showline=True, mirror=True, linecolor='black', linewidth=1, row=1, col=1, side='left')
    fig.update_yaxes(showline=True, mirror=True, linecolor='black', linewidth=1, row=2, col=1, side='left')

    # show the figure
    fig.show()
