In [15]:
#!/usr/bin/env python3
#
# Moritz Blumer, 2024

"""
WinPCA. A package for windowed PCA analysis.
"""

## IMPORT PACKAGES
import sys
#import pandas as pd


## IMPORT MODULES
from modules import config
from modules.cli import CLI
# from modules.error_handling import ErrorHandling ## add

In [20]:
## DELETE CELL

# simulate command line arguments (development only); keep exactly one
# `command_line` assignment active
# command_line = 'winpca.ipynb pca -n test --vcf test_dataset/input/sample.vcf.gz -s ind_1,ind_2,ind_3,ind_4,ind_5,ind_6 -r chr1:1-33000000 -w 1000000 -i 100000 -m 0.01 -p auto'
# command_line = 'winpca.ipynb pca -n test --vcf test_dataset/input/sample.vcf.gz -s samples.tsv -p auto -r chr1:1-30000000 -w 1000000 -i 100000 -m 0.01'
# command_line = 'winpca.ipynb polarize -n test -p auto -c 2'
command_line = 'winpca.ipynb chromplot -n test -r chr1:1-30000000 -m test_dataset/input/metadata.tsv -i 10 -f PDF,html -g inversion_state -c ancestral:eb4034,inverted:2f35a8,heterozygous:197d34'
# command_line = 'winpca.ipynb pca -h'

# tokenize like a shell would: unlike split(' '), this tolerates repeated
# spaces and quoted arguments
import shlex
sys.argv = shlex.split(command_line)
sys.argv


['winpca.ipynb',
 'chromplot',
 '-n',
 'test',
 '-r',
 'chr1:1-30000000',
 '-m',
 'test_dataset/input/metadata.tsv',
 '-i',
 '10',
 '-f',
 'PDF,html',
 '-g',
 'inversion_state',
 '-c',
 'ancestral:eb4034,inverted:2f35a8,heterozygous:197d34']

In [21]:
## PARSE COMMAND LINE ARGUMENTS

# build the parser and register every subcommand (same order as before)
cli = CLI()
for subcommand in ('pca', 'pcangsd', 'polarize', 'flip', 'chromplot',
                   'genomeplot'):
    getattr(cli, subcommand)()

# parse sys.argv; the parsed settings are exposed as a dict
cli.parse_args()
args_dct = cli.args_dct

# the selected subcommand decides which of the following cells do work
mode = args_dct['winpca']
args_dct

Namespace(color_by='inversion_state', hex_codes='ancestral:eb4034,inverted:2f35a8,heterozygous:197d34', interval=10, metadata_path='test_dataset/input/metadata.tsv', plot_fmt=['pdf', 'html'], prefix='test', region='chr1:1-30000000', winpca='chromplot')


{'winpca': 'chromplot',
 'prefix': 'test',
 'region': 'chr1:1-30000000',
 'metadata_path': 'test_dataset/input/metadata.tsv',
 'color_by': 'inversion_state',
 'hex_codes': 'ancestral:eb4034,inverted:2f35a8,heterozygous:197d34',
 'interval': 10,
 'plot_fmt': ['pdf', 'html'],
 'chrom': 'chr1',
 'start': 1,
 'end': 30000000,
 'hex_code_dct': {'ancestral': '#eb4034',
  'inverted': '#2f35a8',
  'heterozygous': '#197d34'},
 'skip_monomorphic': False,
 'min_var_per_w': 25,
 'n_prev_windows': 5,
 'pol_pc': 'both',
 'flip_pc': '1',
 'chrom_plot_w': 1200,
 'chrom_plot_h': 400}

In [9]:
# WINDOWED PCA FROM CALLED GENOTYPES

# needed by every branch: wraps freshly computed or previously written data
from modules.data import wpca_data

if mode == 'pca':

    # import relevant modules
    from modules.windowed_pca import gt_wpca

    # instantiate windowed PCA; run settings are read from args_dct like all
    # other settings, falling back to the config default if the key is absent
    w_pca = gt_wpca(
        variant_file_path=args_dct['variant_file_path'],
        sample_lst=args_dct['sample_lst'],
        chrom=args_dct['chrom'],
        start=args_dct['start'],
        stop=args_dct['end'],
        w_size=args_dct['w_size'],
        w_step=args_dct['w_step'],
        skip_monomorphic=args_dct.get(
            'skip_monomorphic', config.skip_monomorphic
        ),
    )

    # run
    w_pca.win_vcf_gt()

    # parse run data
    data = wpca_data(args_dct['prefix'], w_pca)


# WINDOWED PCA FROM GENOTYPE LIKELIHOODS (not implemented yet)

# elif mode == 'pcangsd':

#     [...]

# EXISTING DATA: load previously written results for the given prefix

else:
    data = wpca_data(args_dct['prefix'])

[INFO] Processed 1 of 291 windows
[INFO] Processed 2 of 291 windows
[INFO] Processed 3 of 291 windows
[INFO] Processed 4 of 291 windows
[INFO] Processed 5 of 291 windows
[INFO] Processed 6 of 291 windows
[INFO] Processed 7 of 291 windows
[INFO] Processed 8 of 291 windows
[INFO] Processed 9 of 291 windows
[INFO] Processed 10 of 291 windows
[INFO] Processed 11 of 291 windows
[INFO] Processed 12 of 291 windows
[INFO] Processed 13 of 291 windows
[INFO] Processed 14 of 291 windows
[INFO] Processed 15 of 291 windows
[INFO] Processed 16 of 291 windows
[INFO] Processed 17 of 291 windows
[INFO] Processed 18 of 291 windows
[INFO] Processed 19 of 291 windows
[INFO] Processed 20 of 291 windows
[INFO] Processed 21 of 291 windows
[INFO] Processed 22 of 291 windows
[INFO] Processed 23 of 291 windows
[INFO] Processed 24 of 291 windows
[INFO] Processed 25 of 291 windows
[INFO] Processed 26 of 291 windows
[INFO] Processed 27 of 291 windows
[INFO] Processed 28 of 291 windows
[INFO] Processed 29 of 291 wi

In [5]:
# POLARIZE

# polarize freshly computed or existing data unless polarization is disabled
if mode in ['pca', 'pcangsd', 'polarize'] \
    and not args_dct['polarize'] == 'skip':

    from modules.transform_data import Polarize
    polarize = Polarize()

    # PCs to polarize: both, or only the one specified
    pol_pc_lst = ['pc_1', 'pc_2'] if args_dct['pol_pc'] == 'both' \
        else ['pc_' + str(args_dct['pol_pc'])]

    # adaptive
    if args_dct['polarize'] == 'auto':
        for pc_name in pol_pc_lst:
            data.modify_data(
                pc_name, polarize.adaptive, args_dct['n_prev_windows']
            )

    # using guide samples
    elif args_dct['polarize'] == 'guide_samples':
        for pc_name in pol_pc_lst:
            data.modify_data(
                pc_name, polarize.guide_samples, args_dct['guide_sample_lst']
            )



In [8]:
# FLIP

if mode == 'flip':
    from modules.transform_data import Flip
    flip = Flip()

    # PCs to flip: both, or only the one specified
    flip_pc_lst = ['pc_1', 'pc_2'] if args_dct['flip_pc'] == 'both' \
        else ['pc_' + str(args_dct['flip_pc'])]

    for pc_name in flip_pc_lst:
        data.modify_data(pc_name, flip.flip_chrom)

skip


In [82]:
# # WRITE RESULTS

# # create output directory if prefix contains '/'
# if '/' in args_dct['prefix']:
#     if not os.path.exists('/'.join(args_dct['prefix'].split('/')[0:-1]) + '/'):
#         os.makedirs('/'.join(args_dct['prefix'].split('/')[0:-1]) + '/')

# # write results to files
# data.to_files()

In [83]:
# display the parsed arguments (dict produced by the CLI cell)
args_dct

{'winpca': 'chromplot',
 'prefix': 'test',
 'region': 'chr1:1-30000000',
 'metadata_path': 'test_dataset/input/metadata.tsv',
 'color_by': 'inversion_state',
 'hex_codes': 'ancestral:eb4034,inverted:2f35a8,heterozygous:197d34',
 'interval': None,
 'file_format': 'both',
 'chrom': 'chr1',
 'start': 1,
 'end': 30000000,
 'hex_code_dct': {'ancestral': '#eb4034',
  'inverted': '#2f35a8',
  'heterozygous': '#197d34'},
 'skip_monomorphic': False,
 'min_var_per_w': 25,
 'n_prev_windows': 5,
 'pol_pc': 'both',
 'flip_pc': '1',
 'chrom_plot_w': 1200,
 'chrom_plot_h': 400}

In [25]:
'''
Plot PC and associated data chromosome- and genome-wide.
'''

# IMPORT PACKAGES
import pandas as pd
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# IMPORT MODULES

class Plot:
    '''
    Chromosome-wide plot of per-sample values (e.g. PC 1) along windows,
    with a per-window summary statistic shown in a small top panel.
    '''

    def __init__(self, name, data, plot_var, stat_var, metadata_path=None,
                 color_by=None, hex_code_dct=None,
                 interval=config.plot_interval,
                 chrom_plot_w=config.chrom_plot_w,
                 chrom_plot_h=config.chrom_plot_h,
                 plot_fmt=None, start=None, end=None):
        '''
        name: output path prefix; one file per format is written as
            <name>.<fmt>
        data: run data object exposing the attribute named by plot_var
            (windows as rows, samples as columns) and a 'stat' attribute
            (per-window statistics)
        plot_var: attribute of data plotted per sample, e.g. 'pc_1'
        stat_var: column of data.stat shown in the top panel, e.g. 'pc_1_ve'
        metadata_path: optional tab-separated metadata file with sample IDs
            in the first column
        color_by: metadata column used to group and color samples; defaults
            to coloring each sample ID separately
        hex_code_dct: optional {group: color} mapping; must cover all groups
        interval: plot only every n-th window (falsy: plot all windows)
        chrom_plot_w, chrom_plot_h: figure width and height
        plot_fmt: output formats, e.g. ['pdf', 'html']; a single string is
            accepted; defaults to [config.plot_fmt]
        start, end: optional x axis limits; if either is omitted the x axis
            is autoranged
        '''

        # normalize output formats (avoid a shared mutable default and
        # iterating over the characters of a bare string)
        if plot_fmt is None:
            plot_fmt = [config.plot_fmt]
        elif isinstance(plot_fmt, str):
            plot_fmt = [plot_fmt]

        # input variables
        self.name = name
        self.data = data
        self.plot_var = plot_var
        self.stat_var = stat_var
        self.metadata_path = metadata_path
        self.color_by = color_by
        self.hex_code_dct = hex_code_dct
        self.interval = interval
        self.chrom_plot_w = chrom_plot_w
        self.chrom_plot_h = chrom_plot_h
        self.plot_fmt = list(plot_fmt)
        self.start = start
        self.end = end

        # instance variables
        self.data_df = None
        self.stat_df = None
        self.metadata_df = None
        self.group_id = None
        self.group_lst = None
        self.color_dct = None
        self.fig = None

        # initiate data
        self._load_data()


    def _load_data(self):
        '''
        (Re)load the per-sample and per-window dataframes from self.data.
        '''
        self.data_df = getattr(self.data, self.plot_var)
        self.stat_df = getattr(self.data, 'stat')


    @staticmethod
    def subset(df, interval):
        '''
        Subset a dataframe to every interval-th row.
        '''

        return df.iloc[::interval, :]


    def annotate(self):
        '''
        Annotate per-sample data with a metadata file if supplied, otherwise
        just reformat for plotting: transpose to samples x windows, add an
        'id' column (plus metadata columns) and melt to long format with the
        columns id, [metadata...], 'pos' and plot_var.
        Exits with status 1 on duplicate or missing sample IDs in the
        metadata.
        '''

        # fetch sample names and order (=data_df column names)
        sample_lst = list(self.data_df.columns)

        # transpose data_df: samples become rows, window positions columns
        self.data_df = self.data_df.T

        # initiate id_vars to be filled if metadata are supplied
        id_var_lst = ['id']

        # read metadata if provided and do sanity checks
        if self.metadata_path:

            # read metadata and print error message if there are non-unique IDs
            self.metadata_df = pd.read_csv(
                self.metadata_path, sep='\t', index_col=0, dtype=str
            )
            if not self.metadata_df.index.is_unique:
                print('\n[ERROR] The provided metadata file contains non-unique'
                      ' sample IDs.',
                      file=sys.stderr)
                sys.exit(1)

            # check membership explicitly: reindex(...).dropna() would also
            # discard samples that are present but have an empty field, and
            # then misreport them as missing
            if not set(sample_lst).issubset(self.metadata_df.index):
                print('\n[ERROR] One or more sample IDs are missing in the'
                      ' provided metadata file.',
                      file=sys.stderr)
                sys.exit(1)

            # subset and reorder metadata_df to match data_df individuals
            self.metadata_df = self.metadata_df.reindex(sample_lst)

            # add metadata columns to data_df
            for column_name in self.metadata_df.columns:
                self.data_df[column_name] = list(self.metadata_df[column_name])

            # add metadata column names to id_var_lst
            id_var_lst += list(self.metadata_df.columns)

        # copy id from index to column
        self.data_df['id'] = list(self.data_df.index)

        # replace numpy NaN with 'NA' for plotting (hover_data display)
        self.data_df = self.data_df.replace(np.nan, 'NA')

        # convert to long format for plotting
        self.data_df = pd.melt(
            self.data_df,
            id_vars=id_var_lst,
            var_name='pos',
            value_name=self.plot_var,
        )


    def set_colors(self):
        '''
        Determine the grouping column and compile color_dct ({group: color}),
        either from hex_code_dct or from plotly's default color cycle.
        Exits with status 1 if the grouping column does not exist or if
        hex_code_dct lacks a color for any group.
        '''

        # fetch group_id (=color_by) if specified, else default to 'id'
        self.group_id = self.color_by if self.color_by is not None else 'id'

        # a color_by column must come from the metadata file
        if self.group_id not in self.data_df.columns:
            print(f'\n[ERROR] Cannot color by "{self.group_id}": column not'
                  ' found in the provided metadata.',
                  file=sys.stderr)
            sys.exit(1)

        # list of groups/ids in order of first appearance; set() ordering of
        # strings changes between Python sessions, which made the default
        # color assignment non-reproducible
        self.group_lst = list(pd.unique(self.data_df[self.group_id]))

        # define colors based on plotly default colors or specified HEX codes;
        # print error messages if HEX codes are missing for specified groups
        if self.hex_code_dct:
            self.color_dct = self.hex_code_dct
            if not all(group in self.color_dct for group in self.group_lst):
                print('\n[ERROR] HEX codes missing for one or more groups.',
                      file=sys.stderr)
                sys.exit(1)
        else:
            import plotly.colors as pc
            def_col_lst = pc.DEFAULT_PLOTLY_COLORS
            # cycle through the default colors if there are more groups
            self.color_dct = {
                group: def_col_lst[idx % len(def_col_lst)]
                for idx, group in enumerate(self.group_lst)
            }


    def savefig(self):
        '''
        Save the figure as <name>.<fmt> for every format in plot_fmt: HTML
        via write_html, every other format via write_image.
        '''
        for fmt in self.plot_fmt:
            out_path = self.name + '.' + fmt
            if fmt == 'html':
                self.fig.write_html(out_path)
            else:
                self.fig.write_image(out_path)


    def _y_axis_title(self):
        '''
        Axis label for plot_var: 'pc_<n>' -> 'PC <n>', otherwise plot_var.
        '''
        if self.plot_var.startswith('pc_'):
            return 'PC ' + self.plot_var[len('pc_'):]
        return self.plot_var


    def chromplot(self):
        '''
        Build a two-panel chromosome plot, store it in self.fig and save it
        via savefig().

        Top panel: stat_var per window as a filled line.
        Bottom panel: plot_var per sample along the windows, one trace per
        group (color_by, or sample ID if not set); the lines of the samples
        within a group are separated by None gaps.
        '''

        # SETUP

        # start from the unmodified input so that repeated calls do not
        # subset/annotate already processed dataframes again
        self._load_data()

        # subset if interval (-i) is specified
        if self.interval:
            self.data_df = self.subset(self.data_df, self.interval)
            self.stat_df = self.subset(self.stat_df, self.interval)

        # annotate per-sample data if metadata were supplied
        self.annotate()

        # set per-sample plot colors
        self.set_colors()

        # figure setup
        self.fig = make_subplots(
            rows=2, cols=1,
            row_heights=[1, 6],
            vertical_spacing=0.0,
            shared_xaxes=True,
        )


        # TOP PANEL

        # parse display name for top panel: variance explained or n of sites
        display_name = \
            '% heterozygous sites' if self.stat_var == 'hetp' else \
            '% variance explained'

        # compile per-window hover data strings
        hover_data = [
            f'<b>pos</b>: {idx}<br><b>{display_name}</b>: '
            f'{row[self.stat_var]}%<br>'
            for idx, row in self.stat_df.iterrows()
        ]

        # plot
        self.fig.add_trace(
            go.Scatter(
                x=self.stat_df.index,
                y=self.stat_df[self.stat_var],
                name=display_name,
                legendgroup=display_name,
                mode='lines',
                text=hover_data,
                hoverinfo='text',
                line=dict(color='#4d61b0', width=1),
                fill='tozeroy',
                connectgaps=True,
            ),
            row=1, col=1)


        # BOTTOM PANEL

        # plot each specified group (or else by ID) individually
        for group in self.group_lst:

            # subset data to group
            group_df = self.data_df[self.data_df[self.group_id] == group]

            # per-sample-per-window x values, y values and hover strings
            x_val_lst = []
            y_val_lst = []
            hover_str_lst = []

            # iterate through each individual per group
            for sample_id in group_df['id'].unique():

                # subset data to individual
                id_df = group_df[group_df['id'] == sample_id]

                # compile list of hover text strings per window for individual
                hover_data = [
                    ''.join(f'<b>{c}</b>: {row[c]}<br>' for c in id_df.columns)
                    for _, row in id_df.iterrows()
                ]

                # append x, y, hover values, separated by None to separate
                # lines plotted as part of the same trace
                x_val_lst += id_df['pos'].tolist() + [None]
                y_val_lst += id_df[self.plot_var].tolist() + [None]
                hover_str_lst += hover_data + [None]

            # plot
            self.fig.add_trace(
                go.Scatter(
                    x=x_val_lst,
                    y=y_val_lst,
                    text=hover_str_lst,
                    hoverinfo='text',
                    name=group,
                    legendgroup=group,
                    mode='lines',
                    line=dict(color=self.color_dct[group]),
                ),
                row=2, col=1
            )

        # general layout
        self.fig.update_layout(
            template='simple_white',
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)',
            font_family='Arial',
            font_color='black',
            width=self.chrom_plot_w,
            height=self.chrom_plot_h,
            legend=dict(font=dict(size=10)),
        )

        # format lines (note: this overrides the top-panel line color above)
        self.fig.update_traces(
            row=1, col=1,
            line=dict(width=.7, color='lightgrey'),
        )
        self.fig.update_traces(
            row=2, col=1,
            line=dict(width=.7),
        )

        # x axis limits: the requested region if start and end were given,
        # otherwise plotly's autorange
        x_range_kw = {}
        if self.start is not None and self.end is not None:
            x_range_kw['range'] = [self.start, self.end]

        # format x axis
        self.fig.update_xaxes(
            row=1, col=1,
            linewidth=1,
            side='top', mirror=False,
            ticks='', showticklabels=False,
            **x_range_kw,
        )
        self.fig.update_xaxes(
            row=2, col=1,
            linewidth=1,
            side='bottom', mirror=True,
            ticks='outside', tickfont=dict(size=10), tickformat=',.0f',
            title_font=dict(size=12),
            title=dict(text='<b>Genomic position (bp)', standoff=10),
            **x_range_kw,
        )

        # format y axis
        self.fig.update_yaxes(
            row=1, col=1,
            linewidth=1,
            side='left', mirror=True,
            ticks='outside', tickfont=dict(size=10),
        )
        self.fig.update_yaxes(
            row=2, col=1,
            linewidth=1,
            side='left', mirror=True,
            ticks='outside', tickfont=dict(size=10),
            title_font=dict(size=12),
            title=dict(text='<b>' + self._y_axis_title(), standoff=0),
        )

        # save image
        self.savefig()

# plot PC 1 along the chromosome using the parsed command line settings
# (no debug override of the user-supplied interval; custom HEX colors and
# figure size are passed through)
plot = Plot(
    args_dct['prefix'], data, 'pc_1', 'pc_1_ve',
    metadata_path=args_dct['metadata_path'],
    color_by=args_dct['color_by'],
    hex_code_dct=args_dct.get('hex_code_dct'),
    interval=args_dct['interval'],
    chrom_plot_w=args_dct.get('chrom_plot_w', config.chrom_plot_w),
    chrom_plot_h=args_dct.get('chrom_plot_h', config.chrom_plot_h),
    plot_fmt=args_dct['plot_fmt'],
)
plot.chromplot()
fig = plot.fig
fig.show()

In [14]:
# debug probe: prints a marker when custom HEX codes were parsed
custom_hex_codes = args_dct['hex_code_dct']
if custom_hex_codes:
    print('hello')

In [217]:
import plotly.graph_objects as go

# Function to create a plot for a specific chromosome
def create_plot(chromosome_name, x_data, y_data, chromosome_size,
                pixel_per_mbp, height=400):
    '''
    Line/marker plot of values along one chromosome, with the figure width
    proportional to the chromosome length.

    chromosome_name: trace (legend) name
    x_data: positions in bp
    y_data: values, one per position in x_data
    chromosome_size: chromosome length in bp; sets x range and figure width
    pixel_per_mbp: horizontal scale; figure width = length in Mbp * this
    height: figure height (default 400)

    Returns the plotly Figure.
    '''
    size_mbp = chromosome_size / 1_000_000

    # the x axis is in Mbp, so convert the bp positions; plotting raw bp
    # values put every point far outside the axis range
    x_mbp = [pos / 1_000_000 for pos in x_data]

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x_mbp, y=y_data, mode='lines+markers',
                             name=chromosome_name))

    fig.update_layout(
        xaxis=dict(
            title='Position on Chromosome (Mbp)',
            range=[0, size_mbp],
            # ticks at start, middle and end of the chromosome
            tickvals=[0, size_mbp / 2, size_mbp],
            ticktext=['0 Mbp', f'{size_mbp / 2:.1f} Mbp',
                      f'{size_mbp:.1f} Mbp'],
        ),
        yaxis=dict(title='Value'),
        width=int(size_mbp * pixel_per_mbp),
        height=height,
        margin=dict(l=40, r=40, t=40, b=40),
        # horizontal legend above the plot, aligned left
        legend=dict(
            x=0.0,
            y=1.1,
            yanchor='bottom',
            xanchor='left',
            orientation='h',
        ),
    )

    return fig

# Example chromosome sizes in base pairs
chromosome_sizes = {
    "Chromosome 1": 20000000,  # 20 Mbp
    "Chromosome 2": 15000000,  # 15 Mbp
    "Chromosome 3": 30000000,  # 30 Mbp
}

# example (positions in bp, values) per chromosome; deliberately not named
# `data`/`fig` so the windowed PCA data object and figure of the earlier
# cells are not overwritten
example_data = {
    "Chromosome 1": ([1000000, 1500000, 1800000], [1, 3, 2]),
    "Chromosome 2": ([500000, 1200000, 1400000], [2, 4, 3]),
    "Chromosome 3": ([2000000, 2500000, 2800000], [3, 5, 4]),
}

# fixed horizontal scale so that figure widths are comparable
pixel_per_mbp = 50  # 50 pixels per Mbp

# create one plot per chromosome and save it as '<chromosome name>.png'
for chrom_name, size in chromosome_sizes.items():
    x_data, y_data = example_data[chrom_name]
    chrom_fig = create_plot(chrom_name, x_data, y_data, size, pixel_per_mbp)
    chrom_fig.write_image(chrom_name + '.png')

In [146]:
# colors assigned by the chromosome plot ({group: color}); `color_dct` itself
# is not defined at this point on a fresh kernel
plot.color_dct

{'ancestral': '#eb4034', 'inverted': '#2f35a8', 'heterozygous': '#197d34'}

In [150]:
# toy example: assign default colors to groups, cycling through the colors
# if there are more groups than colors
def_col_lst = ['rgb(31, 119, 180)', 'rgb(255, 127, 14)', 'rgb(44, 160, 44)']
group_lst = ['inverted', 'ancestral', 'heterozygous']
n_colors = len(def_col_lst)
color_dct = {
    group: def_col_lst[idx % n_colors] for idx, group in enumerate(group_lst)
}
color_dct

{}