In [1]:
#!/usr/bin/env python3
#
# Moritz Blumer, 2024

"""
WinPCA. A package for windowed PCA analysis.
"""

## IMPORT PACKAGES
import sys, os
#import pandas as pd


## IMPORT MODULES
from modules import config
from modules.cli import CLI
# from modules.error_handling import ErrorHandling ## add

0.01


In [17]:
## DELETE CELL

# Simulate command line arguments for interactive development. The override
# only happens inside a Jupyter kernel, so if this cell is ever left in an
# exported script it does not clobber the real command line.
# command_line = 'winpca.ipynb pca -n test --vcf test_dataset/input/sample.vcf.gz -s ind_1,ind_2,ind_3,ind_4,ind_5,ind_6 -r chr1:1-33000000 -w 1000000 -i 100000 -m 0.01 -p auto'
# command_line = 'winpca.ipynb pca -n test --vcf test_dataset/input/sample.vcf.gz -s samples.tsv -p auto -r chr1:1-30000000 -w 1000000 -i 100000 -m 0.01'
# command_line = 'winpca.ipynb polarize -n test -p auto -c 2'
command_line = 'winpca.ipynb chromplot -n test -r chr1:1-30000000 -m test_dataset/input/metadata.tsv'  # -g inversion_state -c ancestral:eb4034,inverted:2f35a8,heterozygous:197d34
# command_line = 'winpca.ipynb pca -h'
if 'ipykernel' in sys.modules:
    sys.argv = command_line.split()
sys.argv


['winpca.ipynb',
 'chromplot',
 '-n',
 'test',
 '-r',
 'chr1:1-30000000',
 '-m',
 'test_dataset/input/metadata.tsv']

In [18]:
## PARSE COMMAND LINE ARGUMENTS

# build the CLI and register every sub-command parser, in a fixed order
cli = CLI()
for register_subparser in (
    cli.pca,
    cli.pcangsd,
    cli.polarize,
    cli.flip,
    cli.chromplot,
    cli.genomeplot,
):
    register_subparser()

# parse the arguments; the CLI exposes them as a plain dict
cli.parse_args()
args_dct = cli.args_dct

# the chosen sub-command selects the run mode
mode = args_dct['winpca']
args_dct

Namespace(color_by=None, file_format='both', hex_codes=None, interval=None, metadata_path='test_dataset/input/metadata.tsv', prefix='test', region='chr1:1-30000000', winpca='chromplot')


{'winpca': 'chromplot',
 'prefix': 'test',
 'region': 'chr1:1-30000000',
 'metadata_path': 'test_dataset/input/metadata.tsv',
 'color_by': None,
 'hex_codes': None,
 'interval': None,
 'file_format': 'both',
 'chrom': 'chr1',
 'start': 1,
 'end': 30000000,
 'hex_code_dct': None,
 'skip_monomorphic': False,
 'min_var_per_w': 25,
 'n_prev_windows': 5,
 'pol_pc': 'both',
 'flip_pc': '1',
 'chrom_plot_w': 1200,
 'chrom_plot_h': 400}

In [19]:
# WINDOWED PCA FROM CALLED GENOTYPES

# TODO: a 'pcangsd' branch is planned but not implemented yet; until then that
# mode falls through to reading existing data below.
if mode != 'pca':
    # EXISTING DATA: load results previously written under the prefix
    from modules.data import wpca_data
    data = wpca_data(args_dct['prefix'])
else:
    # import relevant modules
    from modules.windowed_pca import gt_wpca
    from modules.data import wpca_data

    # set up windowed PCA over the requested region and samples
    w_pca = gt_wpca(
        variant_file_path=args_dct['variant_file_path'],
        sample_lst=args_dct['sample_lst'],
        chrom=args_dct['chrom'],
        start=args_dct['start'],
        stop=args_dct['end'],
        w_size=args_dct['w_size'],
        w_step=args_dct['w_step'],
        skip_monomorphic=config.skip_monomorphic,
    )

    # run it, then wrap the per-window results under the prefix
    w_pca.win_vcf_gt()
    data = wpca_data(args_dct['prefix'], w_pca)


[INFO] Reading data from prefix "test*".


In [20]:
# POLARIZE

# polarize after windowed PCA or in 'polarize' mode, unless the chosen
# polarization method is 'skip'
if mode in ['pca', 'pcangsd', 'polarize'] \
    and not args_dct['polarize'] == 'skip':

    from modules.transform_data import Polarize
    polarize = Polarize()

    # PCs to polarize: PC 1 and PC 2, or only the requested one
    # (named pc_name below to avoid clashing with `plotly.colors as pc`)
    if args_dct['pol_pc'] == 'both':
        pol_pc_lst = ['pc_1', 'pc_2']
    else:
        pol_pc_lst = ['pc_' + str(args_dct['pol_pc'])]

    # adaptive (parameterized by the n_prev_windows setting)
    if args_dct['polarize'] == 'auto':
        for pc_name in pol_pc_lst:
            data.modify_data(
                pc_name, polarize.adaptive, args_dct['n_prev_windows']
            )

    # using guide samples
    elif args_dct['polarize'] == 'guide_samples':
        for pc_name in pol_pc_lst:
            data.modify_data(
                pc_name, polarize.guide_samples, args_dct['guide_sample_lst']
            )



skip


In [21]:
# FLIP

# in 'flip' mode, apply Flip.flip_chrom to PC 1, PC 2 or both
if mode == 'flip':
    from modules.transform_data import Flip
    flip = Flip()

    # PCs to flip: PC 1 and PC 2, or only the requested one
    if args_dct['flip_pc'] == 'both':
        flip_pc_lst = ['pc_1', 'pc_2']
    else:
        flip_pc_lst = ['pc_' + str(args_dct['flip_pc'])]

    for pc_name in flip_pc_lst:
        data.modify_data(pc_name, flip.flip_chrom)

skip


In [22]:
# # WRITE RESULTS

# # create output directory if prefix contains '/'
# if '/' in args_dct['prefix']:
#     if not os.path.exists('/'.join(args_dct['prefix'].split('/')[0:-1]) + '/'):
#         os.makedirs('/'.join(args_dct['prefix'].split('/')[0:-1]) + '/')

# # write results to files
# data.to_files()

In [23]:
# IMPORT PACKAGES
import pandas as pd
import numpy as np

# IMPORT MODULES
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# data.modify_data(
#     'pc_1', polarize.adaptive, 3
# )

# per-sample variable to plot and the frame holding it (one column per sample)
plot_var = 'hetp'
data_df = getattr(data, plot_var)

# use the metadata file passed on the command line; fall back to the bundled
# test dataset when none was given (e.g. in modes without a metadata option)
metadata_path = args_dct.get('metadata_path') or 'test_dataset/input/metadata.tsv'


def annotate(metadata_path, data_df, value_name=None):
    '''
    Annotate per-sample data with sample metadata and convert it to long
    format for plotting.

    Parameters
    ----------
    metadata_path : str or file-like
        Tab-separated metadata table; the first column holds the sample IDs,
        all values are read as strings.
    data_df : pandas.DataFrame
        Per-sample data with one column per sample ID; its index becomes the
        'pos' column of the output.
    value_name : str, optional
        Name of the value column in the output. Defaults to the notebook
        global `plot_var` (the previous behavior).

    Returns
    -------
    pandas.DataFrame
        Long-format frame: the metadata columns, 'id', 'pos' and
        `value_name`; NaN values are replaced by the string 'NA'.

    Exits with status 1 if the metadata contains duplicate sample IDs or
    lacks a sample present in data_df.
    '''
    if value_name is None:
        value_name = plot_var  # notebook-global default, kept for callers

    # read metadata and print error message if there are non-unique IDs
    metadata_df = pd.read_csv(metadata_path, sep='\t', index_col=0, dtype=str)
    if metadata_df.index.has_duplicates:
        print('\n[ERROR] The provided metadata file contains non-unique sample'
              ' IDs.',
              file=sys.stderr)
        sys.exit(1)

    # check for missing samples *before* reindexing: dropping NaN rows after
    # the reindex would also drop samples that only have an empty metadata
    # field and report them as missing
    if not data_df.columns.isin(metadata_df.index).all():
        print('\n[ERROR] One or more sample IDs are missing in the provided'
              ' metadata file.',
              file=sys.stderr)
        sys.exit(1)

    # subset and reorder metadata_df to match data_df samples; copy the
    # sample IDs to column 'id'
    metadata_df = metadata_df.reindex(data_df.columns)
    metadata_df['id'] = metadata_df.index

    # transpose to one row per sample; rename_axis returns a new object, so
    # the caller's column index is not renamed as a side effect
    anno_df = data_df.T.rename_axis('id')

    # add metadata columns (rows are aligned by the reindex above)
    for column_name in metadata_df.columns:
        anno_df[column_name] = metadata_df[column_name].to_list()

    # replace numpy NaN with 'NA' for plotting (hover_data display)
    anno_df = anno_df.replace(np.nan, 'NA')

    # convert to long format for plotting: one row per sample and window
    return pd.melt(
        anno_df,
        id_vars=metadata_df.columns,
        var_name='pos',
        value_name=value_name,
    )


# annotate the raw per-sample frame; it is re-read from `data` so re-running
# this cell never tries to annotate an already annotated (long-format) frame
data_df = annotate(metadata_path, getattr(data, plot_var))
data_df


Unnamed: 0,coverage,species,inversion_state,id,pos,hetp
0,20X,species_1,ancestral,ind_1,500000,0.200000
1,21X,species_1,inverted,ind_2,500000,0.200000
2,20X,species_1,heterozygous,ind_3,500000,0.200000
3,19X,species_1,inverted,ind_5,500000,0.400000
4,21X,species_1,ancestral,ind_4,500000,0.280000
...,...,...,...,...,...,...
1747,21X,species_1,inverted,ind_2,29600000,0.255814
1748,20X,species_1,heterozygous,ind_3,29600000,0.310078
1749,19X,species_1,inverted,ind_5,29600000,0.232558
1750,21X,species_1,ancestral,ind_4,29600000,0.279070


In [24]:
# per-window statistics shown in the top panel
stat_df = data.stat

# custom colors can be tried out by uncommenting:
# args_dct['hex_code_dct'] = {
#     'ancestral': '#eb4034',
#     'inverted': '#2f35a8',
#     'heterozygous': '#197d34',
# }
# args_dct['color_by'] = 'inversion_state'

# TODO: define chrom_plot_w and chrom_plot_h in class/__init__()

# fetch group_id (=color_by) if specified, else default to 'id'
group_id = args_dct['color_by'] if args_dct['color_by'] else 'id'

# list of groups/ids in order of first appearance; a set would make the order,
# and thus the default color assignment, change between kernel sessions
group_lst = list(dict.fromkeys(data_df[group_id]))

# define colors based on plotly default colors or specified HEX codes; print
# error messages if HEX codes are missing for specified groups
if args_dct['hex_code_dct']:
    color_dct = args_dct['hex_code_dct']
    if not all(group in color_dct for group in group_lst):
        print('\n[ERROR] HEX codes missing for one or more groups.',
              file=sys.stderr)
        sys.exit(1)
else:
    import plotly.colors as pc
    def_col_lst = pc.DEFAULT_PLOTLY_COLORS
    color_dct = {
        group: def_col_lst[i % len(def_col_lst)]
        for i, group in enumerate(group_lst)
    }


def plot_pc(ind_var, stat_var, data, group_id, stat_df=None, color_dct=None,
            x_range=None, width=1200, height=400):
    '''
    Plot a per-window statistic (top panel) above per-sample values along the
    chromosome (bottom panel, one line per sample, one trace per group).

    Parameters
    ----------
    ind_var : str
        Column of `data` holding the per-sample values (e.g. 'hetp', 'pc_1').
    stat_var : str
        Column of `stat_df` shown in the top panel; 'hetp' is labelled
        '% heterozygous sites', anything else '% variance explained'.
    data : pandas.DataFrame
        Long-format annotated data as returned by annotate(); must contain
        the columns 'id', 'pos', `ind_var` and `group_id`.
    group_id : str
        Column used to group samples; each group becomes one colored trace.
    stat_df : pandas.DataFrame, optional
        Per-window statistics indexed by window position; the top panel is
        left empty when None.
    color_dct : dict, optional
        Maps each group to a color; defaults to cycling through plotly's
        default colors over the groups in order of first appearance.
    x_range : sequence of two numbers, optional
        x-axis limits in bp; plotly autoranges when None.
    width, height : int
        Figure size in pixels.

    Returns
    -------
    plotly.graph_objects.Figure
    '''
    fig = make_subplots(
        rows=2, cols=1, shared_xaxes=True, row_heights=[1, 6],
        vertical_spacing=0.0,
    )

    # Panel A: per-window statistic
    if stat_df is not None:
        display_name = (
            '% heterozygous sites' if stat_var == 'hetp'
            else '% variance explained'
        )
        hover_data = [
            f'<b>pos</b>: {idx}<br><b>{display_name}</b>: {value}%<br>'
            for idx, value in stat_df[stat_var].items()
        ]
        fig.add_trace(
            go.Scatter(
                x=stat_df.index,
                y=stat_df[stat_var],
                name=display_name,
                legendgroup=display_name,
                mode='lines',
                text=hover_data,
                hoverinfo='text',
                line=dict(color='lightgrey', width=.7),
                fill='tozeroy',
                connectgaps=True,
            ),
            row=1, col=1,
        )

    # Panel B: per-sample values, one trace per group (or per ID)
    group_lst = list(dict.fromkeys(data[group_id]))
    if color_dct is None:
        import plotly.colors as plotly_colors
        def_col_lst = plotly_colors.DEFAULT_PLOTLY_COLORS
        color_dct = {
            group: def_col_lst[i % len(def_col_lst)]
            for i, group in enumerate(group_lst)
        }

    for group_name in group_lst:

        # subset data to group
        group_df = data[data[group_id] == group_name]

        # per-sample-per-window x values, y values and hover text strings
        x_vals, y_vals, hover_texts = [], [], []
        for sample_id in dict.fromkeys(group_df['id']):
            id_df = group_df[group_df['id'] == sample_id]
            hover_texts += [
                ''.join(f'<b>{col}</b>: {row[col]}<br>' for col in id_df.columns)
                for _, row in id_df.iterrows()
            ] + [None]
            # a trailing None breaks the line between samples that are
            # plotted as part of the same trace
            x_vals += id_df['pos'].tolist() + [None]
            y_vals += id_df[ind_var].tolist() + [None]

        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=y_vals,
                text=hover_texts,
                hoverinfo='text',
                name=str(group_name),
                legendgroup=str(group_name),
                mode='lines',
                line=dict(color=color_dct[group_name], width=.7),
            ),
            row=2, col=1,
        )

    # general layout
    fig.update_layout(
        template='simple_white',
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        font_family='Arial',
        font_color='black',
        width=width,
        height=height,
        legend=dict(font=dict(size=10)),
    )

    # format x axes (restricted to x_range if given)
    x_range_kw = {} if x_range is None else {'range': list(x_range)}
    fig.update_xaxes(
        row=1, col=1,
        linewidth=1,
        side='top', mirror=False,
        ticks='', showticklabels=False,
        **x_range_kw,
    )
    fig.update_xaxes(
        row=2, col=1,
        linewidth=1,
        side='bottom', mirror=True,
        ticks='outside', tickfont=dict(size=10), tickformat=',.0f',
        title_font=dict(size=12),
        title=dict(text='<b>Genomic position (bp)', standoff=10),
        **x_range_kw,
    )

    # format y axes; the bottom title is derived from ind_var ('pc_1' -> 'PC 1')
    fig.update_yaxes(
        row=1, col=1,
        linewidth=1,
        side='left', mirror=True,
        ticks='outside', tickfont=dict(size=10),
    )
    fig.update_yaxes(
        row=2, col=1,
        linewidth=1,
        side='left', mirror=True,
        ticks='outside', tickfont=dict(size=10),
        title_font=dict(size=12),
        title=dict(text='<b>' + ind_var.replace('pc_', 'PC '), standoff=0),
    )

    return fig

# NOTE(review): the top panel assumes data.stat has a column named like
# plot_var — confirm against the data module
fig = plot_pc(
    plot_var, plot_var, data_df, group_id,
    stat_df=stat_df,
    color_dct=color_dct,
    x_range=[args_dct['start'], args_dct['end']],
)
fig.show()



In [25]:
'''
Plot PC and associated data chromosome- and genome-wide.
'''

# IMPORT PACKAGES
import pandas as pd
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# IMPORT MODULES

class Plot:
    '''
    Plot per-sample windowed data (e.g. PCs or heterozygosity) together with
    a per-window summary statistic.
    '''

    def __init__(self, name, data, plot_var, stat_var, metadata_path=None,
                 color_by=None, hex_code_dct=None, start=None, end=None,
                 width=1200, height=400):
        '''
        Parameters
        ----------
        name : str
            Plot name/prefix (stored; not used by the methods below).
        data : object
            Data container; getattr(data, plot_var) must return a per-sample
            DataFrame (one column per sample, indexed by window position) and
            data.stat the per-window statistics.
        plot_var : str
            Attribute of `data` with the per-sample values (e.g. 'hetp').
        stat_var : str
            Column of data.stat shown in the top panel.
        metadata_path : str or file-like, optional
            Tab-separated metadata (first column = sample IDs); without it,
            samples are annotated with their ID only.
        color_by : str, optional
            Metadata column used to group/color samples (default: 'id').
        hex_code_dct : dict, optional
            Maps each group to a color; default plotly colors otherwise.
        start, end : int, optional
            x-axis limits in bp; plotly autoranges unless both are given.
        width, height : int
            Figure size in pixels.
        '''

        # input variables
        self.name = name
        self.data = data
        self.plot_var = plot_var
        self.stat_var = stat_var
        self.metadata_path = metadata_path
        self.color_by = color_by
        self.hex_code_dct = hex_code_dct
        self.start = start
        self.end = end
        self.width = width
        self.height = height

        # instance variables, set by annotate(), set_colors() and chrom_plot()
        self.anno_df = None
        self.group_id = None
        self.group_lst = None
        self.color_dct = None
        self.fig = None


    def annotate(self):
        '''
        Annotate per-sample data with the metadata file and store the
        long-format result (metadata columns, 'id', 'pos', self.plot_var) in
        self.anno_df; NaN values are replaced by 'NA'.

        Exits with status 1 if the metadata file contains duplicate sample
        IDs or lacks a sample present in the data.
        '''

        # fetch per-sample data (one column per sample)
        data_df = getattr(self.data, self.plot_var)

        if self.metadata_path is None:
            # no metadata: samples are annotated with their ID only
            metadata_df = pd.DataFrame(index=data_df.columns)
        else:
            # read metadata and print error message if there are non-unique IDs
            metadata_df = pd.read_csv(
                self.metadata_path, sep='\t', index_col=0, dtype=str
            )
            if metadata_df.index.has_duplicates:
                print('\n[ERROR] The provided metadata file contains non-unique'
                      ' sample IDs.',
                      file=sys.stderr)
                sys.exit(1)

            # check for missing samples before reindexing, so samples that
            # merely have an empty metadata field are not reported as missing
            if not data_df.columns.isin(metadata_df.index).all():
                print('\n[ERROR] One or more sample IDs are missing in the'
                      ' provided metadata file.',
                      file=sys.stderr)
                sys.exit(1)

            # subset and reorder metadata_df to match data_df samples
            metadata_df = metadata_df.reindex(data_df.columns)

        # copy index to column 'id'
        metadata_df['id'] = metadata_df.index

        # transpose to one row per sample; rename_axis returns a new object,
        # so the caller's column index is not renamed as a side effect
        anno_df = data_df.T.rename_axis('id')

        # add metadata columns (rows are aligned by the reindex above)
        for column_name in metadata_df.columns:
            anno_df[column_name] = metadata_df[column_name].to_list()

        # replace numpy NaN with 'NA' for plotting (hover_data display)
        anno_df = anno_df.replace(np.nan, 'NA')

        # convert to long format for plotting
        self.anno_df = pd.melt(
            anno_df,
            id_vars=metadata_df.columns,
            var_name='pos',
            value_name=self.plot_var,
        )

    def set_colors(self):
        '''
        Parse per-sample plot color specifications and compile
        self.group_id, self.group_lst and self.color_dct.

        Groups come from column self.color_by of self.anno_df (or 'id'), in
        order of first appearance. Exits with status 1 if hex_code_dct lacks
        a color for any group.
        '''

        # fetch group_id (=color_by) if specified, else default to 'id'
        self.group_id = self.color_by if self.color_by else 'id'

        # dict.fromkeys keeps first-appearance order; a set would make the
        # group order, and thus default color assignment, vary between runs
        self.group_lst = list(dict.fromkeys(self.anno_df[self.group_id]))

        # define colors based on plotly default colors or specified HEX codes;
        # print error messages if HEX codes are missing for specified groups
        if self.hex_code_dct:
            self.color_dct = self.hex_code_dct
            if not all(group in self.color_dct for group in self.group_lst):
                print('\n[ERROR] HEX codes missing for one or more groups.',
                      file=sys.stderr)
                sys.exit(1)
        else:
            import plotly.colors as pc
            def_col_lst = pc.DEFAULT_PLOTLY_COLORS
            self.color_dct = {
                group: def_col_lst[i % len(def_col_lst)]
                for i, group in enumerate(self.group_lst)
            }


    def chrom_plot(self):
        '''
        Build the two-panel chromosome plot: the per-window statistic on top
        and one line per sample, colored by group, below.

        Returns
        -------
        plotly.graph_objects.Figure
            The figure, also stored in self.fig.
        '''

        # SETUP
        self.annotate()
        self.set_colors()

        self.fig = make_subplots(
            rows=2, cols=1,
            shared_xaxes=True,
            row_heights=[1, 6],
            vertical_spacing=0.0,
        )

        # PANELS AND FORMATTING
        self._plot_stat_panel()
        self._plot_sample_panel()
        self._format_layout()

        return self.fig

    def _plot_stat_panel(self):
        '''Add the per-window summary statistic to the top panel.'''

        # NOTE(review): assumes data.stat holds one column per statistic,
        # named like self.stat_var — confirm against the data module
        stat_df = self.data.stat

        # parse display name for top panel: variance explained or % het sites
        display_name = (
            '% heterozygous sites' if self.stat_var == 'hetp'
            else '% variance explained'
        )

        # compile per-window hover data strings
        hover_data = [
            f'<b>pos</b>: {idx}<br><b>{display_name}</b>: {value}%<br>'
            for idx, value in stat_df[self.stat_var].items()
        ]

        self.fig.add_trace(
            go.Scatter(
                x=stat_df.index,
                y=stat_df[self.stat_var],
                name=display_name,
                legendgroup=display_name,
                mode='lines',
                text=hover_data,
                hoverinfo='text',
                line=dict(color='lightgrey', width=.7),
                fill='tozeroy',
                connectgaps=True,
            ),
            row=1, col=1,
        )

    def _plot_sample_panel(self):
        '''Add one trace per group (one line per sample) to the bottom panel.'''

        for group in self.group_lst:

            # subset data to group
            group_df = self.anno_df[self.anno_df[self.group_id] == group]

            # per-sample-per-window x values, y values and hover strings
            x_val_lst, y_val_lst, hover_str_lst = [], [], []
            for sample_id in dict.fromkeys(group_df['id']):
                id_df = group_df[group_df['id'] == sample_id]
                hover_str_lst += [
                    ''.join(f'<b>{c}</b>: {row[c]}<br>' for c in id_df.columns)
                    for _, row in id_df.iterrows()
                ] + [None]
                # a trailing None breaks the line between samples plotted as
                # part of the same trace
                x_val_lst += id_df['pos'].tolist() + [None]
                y_val_lst += id_df[self.plot_var].tolist() + [None]

            self.fig.add_trace(
                go.Scatter(
                    x=x_val_lst,
                    y=y_val_lst,
                    text=hover_str_lst,
                    hoverinfo='text',
                    name=str(group),
                    legendgroup=str(group),
                    mode='lines',
                    line=dict(color=self.color_dct[group], width=.7),
                ),
                row=2, col=1,
            )

    def _format_layout(self):
        '''Apply the general layout and format both panels' axes.'''

        self.fig.update_layout(
            template='simple_white',
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)',
            font_family='Arial',
            font_color='black',
            width=self.width,
            height=self.height,
            legend=dict(font=dict(size=10)),
        )

        # restrict the x axes to the requested region if both bounds are set
        if self.start is not None and self.end is not None:
            x_range_kw = {'range': [self.start, self.end]}
        else:
            x_range_kw = {}

        self.fig.update_xaxes(
            row=1, col=1,
            linewidth=1,
            side='top', mirror=False,
            ticks='', showticklabels=False,
            **x_range_kw,
        )
        self.fig.update_xaxes(
            row=2, col=1,
            linewidth=1,
            side='bottom', mirror=True,
            ticks='outside', tickfont=dict(size=10), tickformat=',.0f',
            title_font=dict(size=12),
            title=dict(text='<b>Genomic position (bp)', standoff=10),
            **x_range_kw,
        )

        # y axes; the bottom title is derived from plot_var ('pc_1' -> 'PC 1')
        self.fig.update_yaxes(
            row=1, col=1,
            linewidth=1,
            side='left', mirror=True,
            ticks='outside', tickfont=dict(size=10),
        )
        self.fig.update_yaxes(
            row=2, col=1,
            linewidth=1,
            side='left', mirror=True,
            ticks='outside', tickfont=dict(size=10),
            title_font=dict(size=12),
            title=dict(
                text='<b>' + self.plot_var.replace('pc_', 'PC '), standoff=0
            ),
        )
    


In [14]:
# scratch debug check: prints 'hello' when custom HEX codes were parsed
hex_code_dct = args_dct['hex_code_dct']
if hex_code_dct:
    print('hello')

In [217]:
import plotly.graph_objects as go

def create_plot(chromosome_name, x_data, y_data, chromosome_size, pixel_per_mbp):
    '''
    Create a lines+markers plot of values along one chromosome, with the
    figure width scaled to the chromosome length.

    Parameters
    ----------
    chromosome_name : str
        Trace/legend name.
    x_data : sequence of numbers
        Positions in bp; converted to Mbp to match the x axis.
    y_data : sequence of numbers
        Values at those positions.
    chromosome_size : int
        Chromosome length in bp; sets the x-axis limits and figure width.
    pixel_per_mbp : float
        Figure width in pixels per Mbp of chromosome length.

    Returns
    -------
    plotly.graph_objects.Figure
    '''
    size_mbp = chromosome_size / 1_000_000

    # the x axis is in Mbp, so bp positions must be scaled as well; unscaled
    # positions would fall far outside the axis range
    x_mbp = [x / 1_000_000 for x in x_data]

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(x=x_mbp, y=y_data, mode='lines+markers', name=chromosome_name)
    )

    fig.update_layout(
        xaxis=dict(
            title='Position on Chromosome (Mbp)',
            range=[0, size_mbp],  # x-axis limits in Mbp
            tickvals=[0, size_mbp / 2, size_mbp],  # start, middle, end
            ticktext=['0 Mbp', f'{size_mbp / 2:.1f} Mbp', f'{size_mbp:.1f} Mbp'],
        ),
        yaxis=dict(title='Value'),
        # width scales with chromosome size; plotly requires at least 10 px
        width=max(10, int(size_mbp * pixel_per_mbp)),
        height=400,  # fixed height
        margin=dict(l=40, r=40, t=40, b=40),
        legend=dict(
            x=0.0,              # align the legend to the left
            y=1.1,              # position above the plot
            yanchor='bottom',   # anchor to the bottom of the legend
            xanchor='left',     # anchor to the left of the legend
            orientation='h',    # horizontal legend
        ),
    )

    return fig

# Example chromosome sizes in base pairs
chromosome_sizes = {
    "Chromosome 1": 20000000,  # 20 Mbp
    "Chromosome 2": 15000000,  # 15 Mbp
    "Chromosome 3": 30000000,  # 30 Mbp
}

# Example data per chromosome: (positions in bp, values). Named example_data
# so the windowed-PCA `data` object from earlier cells is not overwritten.
example_data = {
    "Chromosome 1": ([1000000, 1500000, 1800000], [1, 3, 2]),
    "Chromosome 2": ([500000, 1200000, 1400000], [2, 4, 3]),
    "Chromosome 3": ([2000000, 2500000, 2800000], [3, 5, 4]),
}

# Fixed number of pixels per Mbp for scaling (a 20 Mbp chromosome -> 1000 px)
pixel_per_mbp = 50

# Create one figure per chromosome and save it as '<chromosome name>.png'
# (uses local names so the `fig` from the plotting cells is kept)
for chrom_name, chrom_size in chromosome_sizes.items():
    x_data, y_data = example_data[chrom_name]
    chrom_fig = create_plot(chrom_name, x_data, y_data, chrom_size, pixel_per_mbp)

    # static image export via plotly's write_image
    chrom_fig.write_image(chrom_name + '.png')

In [146]:
# scratch: inspect the group -> color mapping currently held by the kernel
color_dct

{'ancestral': '#eb4034', 'inverted': '#2f35a8', 'heterozygous': '#197d34'}

In [150]:
# Scratch check of the default color assignment (colors cycle when there are
# more groups than colors). Uses demo_* names so the def_col_lst, group_lst
# and color_dct globals read by the plotting cells are not overwritten.
demo_col_lst = ['rgb(31, 119, 180)', 'rgb(255, 127, 14)', 'rgb(44, 160, 44)']
demo_group_lst = ['inverted', 'ancestral', 'heterozygous']
demo_color_dct = {
    group: demo_col_lst[i % len(demo_col_lst)]
    for i, group in enumerate(demo_group_lst)
}
demo_color_dct

{}

In [148]:
# scratch: inspect the default color list currently held by the kernel
def_col_lst

['rgb(31, 119, 180)',
 'rgb(255, 127, 14)',
 'rgb(44, 160, 44)',
 'rgb(214, 39, 40)',
 'rgb(148, 103, 189)',
 'rgb(140, 86, 75)',
 'rgb(227, 119, 194)',
 'rgb(127, 127, 127)',
 'rgb(188, 189, 34)',
 'rgb(23, 190, 207)']

In [149]:
# scratch: inspect the group list currently held by the kernel
group_lst

['inverted', 'ancestral', 'heterozygous']