# Periodicity Metrics
This notebook aims to explore and compare a number of metrics that may be used to the strength of the 3 nucleotide periodicity expected from Ribo-Seq data. 

There are two classes of metric explored here. Those that look at the degree to which decoding centers of Ribo-Seq reads map to the coding reading frame and those that look at the frequency of the Ribo-Seq signal.

## Setup

Install required packages

In [1]:
!pip uninstall RiboMetric -y
!pip install git+https://github.com/JackCurragh/RiboMetric.git -q 
!pip install plotly pandas numpy scikit-learn -q

Found existing installation: RiboMetric 0.1.9
Uninstalling RiboMetric-0.1.9:
  Successfully uninstalled RiboMetric-0.1.9


Import required packages

In [4]:
from RiboMetric.metrics import (
    fourier_transform,
    multitaper,
    read_frame_information_content,
    periodicity_dominance,
    periodicity_autocorrelation,
    
    )

from RiboMetric.modules import (
    read_frame_score_trips_viz
)

import plotly.graph_objects as go
import pandas as pd

from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler, MaxAbsScaler


In [5]:
!pip install multitaper



In [6]:

import numpy as np

from scipy import signal, stats

import numpy as np
import multitaper.mtspec as spec

def compute_multitaper_spectrum(metagene_profile, read_lengths=[28, 29, 30, 31, 32], nw=4.5, kspec=6, twin=40.0, olap=0.8, fmax=25.0):
    """
    Compute the multitaper spectrum and periodicity score for the metagene profile.

    Args:
        metagene_profile (dict): The metagene profile to compute the multitaper transform of.
        read_lengths (list, optional): The list of read lengths to calculate the multitaper transform for.
        nw (float, optional): The time-bandwidth product.
        kspec (int, optional): The number of tapers to use.
        twin (float, optional): The window length.
        olap (float, optional): The overlap between segments.
        fmax (float, optional): The maximum frequency to consider.

    Returns:
        dict: The multitaper spectrum and periodicity score for each read length and the global aggregated counts.
    """
    multitaper_scores = {}
    max_length = max(len(counts) for counts in metagene_profile['start'].values())
    global_counts = [0] * max_length
    dt = 1.0  # Assuming a time step of 1.0 for simplicity

    for read_len in read_lengths:
        print(f"Running: {read_len}/{max(read_lengths)}")
        counts = list(metagene_profile['start'][read_len].values())
        noisy_signal = np.array(counts)

        # Compute the multitaper spectrum using Prieto's method
        t_vals, freq_vals, quad_vals, thomp_vals = spec.spectrogram(noisy_signal, dt, twin=twin, olap=olap, nw=nw, fmax=fmax, kspec=kspec)

        # Compute the maximum quadratic spectrum values
        max_qi_vals = np.max(quad_vals, axis=1)
        log_psd_db = 10 * np.log10(max_qi_vals)

        # Compute the periodicity score
        max_psd = np.max(log_psd_db)
        periodicity_score = max_psd / np.sum(log_psd_db)
        multitaper_scores[read_len] = periodicity_score

        # # Sum the counts at each position for the current read length
        # for i, count in enumerate(counts):
        #     global_counts[i] += count

    # Compute the multitaper spectrum and periodicity score for the global aggregated counts
    noisy_signal = np.array(global_counts)
    t_vals, freq_vals, quad_vals, thomp_vals = spec.spectrogram(noisy_signal, dt, twin=twin, olap=olap, nw=nw, fmax=fmax, kspec=kspec)
    max_qi_vals = np.max(quad_vals, axis=1)
    log_psd_db = 10 * np.log10(max_qi_vals)
    max_psd = np.max(log_psd_db)
    global_periodicity_score = max_psd / np.sum(log_psd_db)
    multitaper_scores["global"] = global_periodicity_score

    return multitaper_scores

## Simulate Inputs 
Generate inputs for each metric for an array of different periodicity strengths. Signal based metric work off a metagene and the frame metrics work off a read frame dictionary that counts the number of P sites in each frame of each CDS relative to the start codon (initiation codon is always frame 0). Read frame dictionary can be calculated from the metagene.

### Simulated Read Frame Dict

In [7]:
import itertools

def generate_read_frame_distribution_permutations():
    """
    Generate all possible permutations of read frame distributions that sum to 100.
    
    Args:
        None

    Returns:
        simulated_read_frame_proportions (dict): Dictionary containing all possible read frame distributions that sum to 100.
    """
    # Generate permutations
    numbers = range(1, 101)
    permutations = []
    for perm in itertools.permutations(numbers, 3):
        if sum(perm) == 100:
            permutations.append(perm)

    # Simulate read frame proportions
    simulated_read_frame_proportions = {}
    for i, perm in enumerate(permutations):
        simulated_read_frame_proportions[i] = {0: perm[0], 1: perm[1], 2: perm[2]}

    return simulated_read_frame_proportions


### Simulated Metagene

In [8]:
import random

def generate_metagene(frame_ratios, start, stop, noise_factor=0.8, max_count=100):
    """
    Generate a metagene dictionary with varying degrees of periodicity based on frame ratios.
    
    Args:
        frame_ratios (dict): A dictionary representing the global ratio of reads per frame.
                             Keys are frame positions (0, 1, 2), and values are the corresponding ratios.
        start (int): The start position of the metagene.
        stop (int): The stop position of the metagene.
    
    Returns:
        dict: A dictionary representing the metagene, where keys are positions and values are counts.
    """
    metagene = {}
    
    # Initialize all positions with a small constant value
    for pos in range(start, stop):
        metagene[pos] = 1
    
    # Adjust counts based on frame ratios
    for pos in range(start, stop):
        frame = (pos - start) % 3
        metagene[pos] += int(frame_ratios[frame] * max_count)
    
    # Introduce some random noise
    for pos in range(start, stop):
        metagene[pos] += int(random.uniform(0, noise_factor * max_count))
    
    return metagene


In [9]:
start = 10
stop = 110
frame_ratios = {0: 0.98, 1: 0.01, 2: 0.01}  # High periodicity
metagene = generate_metagene(frame_ratios, start, stop)

# Get the positions and counts from the metagene dictionary
positions = list(metagene.keys())
counts = list(metagene.values())

# Create the bar plot
fig = go.Figure(data=go.Bar(x=positions, y=counts))

# Set the axis labels
fig.update_layout(xaxis_title='Position', yaxis_title='Count', title='Periodic Metagene')

# Show the plot
fig.show()


In [10]:

frame_ratios = {0: 0.33, 1: 0.33, 2: 0.33}  # No periodicity
metagene = generate_metagene(frame_ratios, start, stop, noise_factor=0)

# Get the positions and counts from the metagene dictionary
positions = list(metagene.keys())
counts = list(metagene.values())

# Create the bar plot
fig = go.Figure(data=go.Bar(x=positions, y=counts))

# Set the axis labels
fig.update_layout(xaxis_title='Position', yaxis_title='Count', title='Metagene with no periodicity')

# Show the plot
fig.show()

Simulate range of metagenes using permutations

In [11]:
start = 10
stop = 110
simulated_metagenes = {'start':{}, 'stop':{}}
simulated_metagenes['start'] = {
    i: generate_metagene(frame_ratios, start, stop) for i, frame_ratios in generate_read_frame_distribution_permutations().items()
}

Get read frame distribution for each metagene to ensure comparisons are of same data. Otherwise it may differ due to the introduced noise factor

In [12]:
simulated_read_frame_dict = {}
for i, metagene in simulated_metagenes['start'].items():
    metagene_total = sum(metagene.values())
    simulated_read_frame_dict[i] = {
        0: round(sum([metagene[pos] for pos in metagene if pos % 3 == 0])/metagene_total, 4),
        1: round(sum([metagene[pos] for pos in metagene if pos % 3 == 1])/metagene_total, 4),
        2: round(sum([metagene[pos] for pos in metagene if pos % 3 == 2])/metagene_total, 4),
    }
    

## Metrics on Simulated data


In [13]:
!pip install pywavelets



In [14]:
def counts_to_codon_proportions(counts: list) -> list:
    """
    Convert a list of counts to proportions of codon.
    Codons are windows of 3 nucleotides and there is no overlap between windows

    Inputs:
        counts: list
            A list of counts for each position

    Returns:
        dict: A dictionary where keys represent positions,
        and values represent codon proportions.
    """
    codon_proportions = []
    for i in range(0, len(counts), 3):
        codon_counts = counts[i:i+3]
        total_count = sum(codon_counts)
        for count in codon_counts:
            if total_count != 0:
                codon_proportions.append(count / total_count)
            else:
                codon_proportions.append(0)
    return codon_proportions

import pywt

def wavelet_transform(metagene_profile, read_lengths=[25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], wavelet='db4'):
    """
    Calculate the discrete wavelet transform of the metagene profile.

    Inputs:
        metagene_profile: dict
            The metagene profile to compute the wavelet transform of.
        read_lengths: list, optional
            The read lengths to consider.
        wavelet: str, optional
            The wavelet family to use for the DWT.

    Returns:
        wavelet_scores: dict
            The wavelet transform scores for each read length.
    """
    wavelet_scores = {}
    global_counts = []

    for read_len in read_lengths:
        if not global_counts:
            global_counts = list(metagene_profile['start'][read_len].values())
        else:
            global_counts = [i + j for i, j in zip(global_counts, list(metagene_profile['start'][read_len].values()))]

        counts = list(metagene_profile['start'][read_len].values())
        wavelet_coeffs = pywt.wavedec(counts, wavelet, mode='smooth')
        wavelet_scores[read_len] = np.sum(np.abs(wavelet_coeffs[-1])) / np.sum(np.abs(wavelet_coeffs[0]))

    global_wavelet_coeffs = pywt.wavedec(global_counts, wavelet, mode='smooth')
    wavelet_scores["global"] = np.sum(np.abs(global_wavelet_coeffs[-1])) / np.sum(np.abs(global_wavelet_coeffs[0]))

    return wavelet_scores

In [15]:
def denoise_and_convert_to_pf_p_sites(signal, wavelet='sym4', frequency_band=(0.2, 0.5), target_frequency_range=(0.328125, 0.34375)):
    # Perform DWPT
    wp = pywt.WaveletPacket(data=signal, wavelet=wavelet, mode='symmetric')
    
    # Get nodes corresponding to the frequency band of interest
    nodes = [node for node in wp.get_level(wp.maxlevel, 'freq') if node.path.count('/') == 2]

    # Get coefficients within the target frequency range
    target_coefficients = []
    for node in nodes:
        # Calculate frequency based on the node's path
        node_frequency = 2 ** (wp.maxlevel - node.level) / len(signal)
        if node_frequency >= target_frequency_range[0] and node_frequency <= target_frequency_range[1]:
            target_coefficients.extend(node.data)

    # Denoise by eliminating signals with 3-nt frequency's coefficient lower than others
    denoised_coefficients = []
    for i, coeff in enumerate(target_coefficients):
        if all(coeff > other_coeff for j, other_coeff in enumerate(target_coefficients) if j != i):
            denoised_coefficients.append(coeff)

    # Convert to PF P-sites
    pf_p_sites = [index for index, coeff in enumerate(target_coefficients) if coeff in denoised_coefficients]

    return pf_p_sites

# Example usage:
signal = np.random.rand(1000)  # Example signal (replace with your actual signal)
pf_p_sites = denoise_and_convert_to_pf_p_sites(signal)
print("PF P-sites:", pf_p_sites)


PF P-sites: []


In [16]:
def wavelet_decomposition(metagene_profiles, read_lengths=[25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], wavelet='db4', levels=3, target_frequency=0.33):
    """
    Plots the wavelet decomposition of a metagene profile for a given read length.

    Args:
        metagene_profile (dict): The metagene profile dictionary.
        read_length (int): The read length to decompose.
        wavelet (str, optional): The wavelet family to use for the DWT. Default is 'db4'.
        levels (int, optional): The number of levels to decompose the signal. Default is 5.

    Returns:
        None
    """
    wavelet_scores = {}
    global_counts = []

    for read_length in read_lengths:
        if not global_counts:
            global_counts = list(metagene_profiles['start'][read_length].values())
        else:
            global_counts = [i + j for i, j in zip(global_counts, list(metagene_profiles['start'][read_length].values()))]

        signal_data = np.array(list((metagene_profiles['start'][read_length].values())))

        # Compute wavelet decomposition
        coeffs = pywt.wavedec(signal_data, wavelet, mode='smooth', level=levels)

        power_spectrum = signal.welch(coeffs[-1], fs=1.0, nperseg=256)  # Power spectrum of Level 4 coefficients
        total_power = np.sum(power_spectrum)

        # Find indices of frequency closest to the target frequency
        target_idx = np.argmin(np.abs(power_spectrum[0] - target_frequency))

        # Calculate score as the amplitude of the peak at the target frequency
        score = power_spectrum[1][target_idx]
        wavelet_scores[read_length] = 1 - (score / total_power)

    signal_data = np.array(global_counts)

    # Compute wavelet decomposition
    coeffs = pywt.wavedec(signal_data, wavelet, mode='smooth', level=levels)

    power_spectrum = signal_data.welch(coeffs[-1], fs=1.0, nperseg=256)  # Power spectrum of Level 4 coefficients
    total_power = np.sum(power_spectrum)

    # Find indices of frequency closest to the target frequency
    target_idx = np.argmin(np.abs(power_spectrum[0] - target_frequency))

    # Calculate score as the amplitude of the peak at the target frequency
    score = power_spectrum[1][target_idx]
    # wavelet_scores['global'] = 1 - (score / total_power)
    return wavelet_scores



In [17]:
def dominance_score(simulated_metagenes):
    metagene_dominance_scores = {}
    for read_len in simulated_metagenes['start']:
        frame_1_counts = sum([count for i, count in enumerate(simulated_metagenes['start'][read_len].values()) if i % 3 == 0])
        frame_2_counts = sum([count for i, count in enumerate(simulated_metagenes['start'][read_len].values()) if i % 3 == 1])
        frame_3_counts = sum([count for i, count in enumerate(simulated_metagenes['start'][read_len].values()) if i % 3 == 2])
        total_counts = sum(simulated_metagenes['start'][read_len].values())

        dominance_score = max(frame_1_counts, frame_2_counts, frame_3_counts) / total_counts
        metagene_dominance_scores[read_len] = dominance_score
    
    return metagene_dominance_scores

In [18]:

metagene_dominance_scores = dominance_score(simulated_metagenes)
# Sort metagenes by dominance score
sorted_metagenes = dict(sorted(metagene_dominance_scores.items(), key=lambda item: item[1]))

# Plot the metagenes with the highest and lowest dominance scores
lowest_dominance_metagene = list(sorted_metagenes.keys())[0]
low_mid_dominance_metagene = list(sorted_metagenes.keys())[int(len(sorted_metagenes)/4)]
high_mid_dominance_metagene = list(sorted_metagenes.keys())[int(len(sorted_metagenes)*3/4)]
mid_mid_dominance_metagene = list(sorted_metagenes.keys())[int(len(sorted_metagenes)/2)]
highest_dominance_metagene = list(sorted_metagenes.keys())[-1]

ordered_subset_metagenes = { 'start': {
    lowest_dominance_metagene: simulated_metagenes['start'][lowest_dominance_metagene],
    low_mid_dominance_metagene: simulated_metagenes['start'][low_mid_dominance_metagene],
    high_mid_dominance_metagene: simulated_metagenes['start'][high_mid_dominance_metagene],
    mid_mid_dominance_metagene: simulated_metagenes['start'][mid_mid_dominance_metagene],
    highest_dominance_metagene: simulated_metagenes['start'][highest_dominance_metagene]
}}


In [19]:
for read_len in ordered_subset_metagenes['start']:
    # Get the positions and counts from the metagene dictionary
    positions = list(ordered_subset_metagenes['start'][read_len].keys())
    counts = list(ordered_subset_metagenes['start'][read_len].values())

    # Create the bar plot
    fig = go.Figure(data=go.Bar(x=positions, y=counts))

    # Set the axis labels
    fig.update_layout(xaxis_title='Position', yaxis_title='Count', title=f'Metagene with dominance score: {metagene_dominance_scores[read_len]}')

    # Show the plot
    fig.show()

In [46]:
def autocorrelate_periodicity(signal, expected_period):
    """
    Calculate a periodicity score based on the autocorrelation function.

    Args:
        signal (numpy.ndarray): The ribo-seq signal or count data.
        expected_period (int): The expected period of the signal (e.g., 3 for codon reading frame).

    Returns:
        float: The periodicity score based on the autocorrelation function.
    """
    # Compute the autocorrelation function
    autocorr = np.correlate(signal, signal, mode='full')
    autocorr = autocorr[autocorr.size // 2:]  # Keep only the positive lags

    autocorr = [float(val) / float(autocorr[0]) for val in autocorr]

    # Find the peak corresponding to the expected period
    peak_idx = np.argmax(autocorr[expected_period:]) + expected_period

    # Calculate the periodicity score
    peak_value = autocorr[peak_idx]
    mean_autocorr = np.mean(autocorr[1:])  # Exclude the value at lag 0
    periodicity_score = (peak_value - mean_autocorr) / mean_autocorr
    return periodicity_score


def autocorrelate_periodicity_metric(metagene_profile: dict, lag: int) -> dict:
    """
    Computes the autocorrelation of the ribosome counts at a given lag.

    Parameters:
    -----------
    metagene_profile: dict
        The metagene profile to compute the autocorrelation of.

    lag: int
        The lag to compute the autocorrelation at.

    Returns:
    --------
    read_length_scores: dict
        The autocorrelation scores at the given lag.
    """
    read_length_scores = {}
    global_counts = []

    for read_length in metagene_profile:
        if not global_counts:
            global_counts = list(metagene_profile[read_length].values())
        else:
            global_counts = [
                i + j for i, j in zip(
                    global_counts,
                    list(metagene_profile[read_length].values())
                    )
                    ]
        count_list = np.array(list(metagene_profile[read_length].values()))
        if count_list[0] is not None:
            read_length_scores[read_length] = autocorrelate_periodicity(count_list, lag)
        else:
            read_length_scores[read_length] = 0
    read_length_scores['global'] = autocorrelate_periodicity(np.array(global_counts), lag)
    return read_length_scores


def autocorrelate_uniformity_metric(metagene_profile: dict) -> dict:
    """
    Computes the autocorrelation of the ribosome counts looking
    for uniformity in the signal. Done by computing the autocorrelation
    of the signal and taking the inverse of the mean autocorrelation.

    Parameters:
    -----------
    metagene_profile: dict
        The metagene profile to compute the autocorrelation of.

    lag: int
        The lag to compute the autocorrelation at.

    Returns:
    --------
    read_length_scores: dict
        The autocorrelation scores at the given lag.
    """
    read_length_scores = {}
    global_counts = []

    for read_length in metagene_profile:
        if not global_counts:
            global_counts = list(metagene_profile[read_length].values())
        else:
            global_counts = [
                i + j for i, j in zip(
                    global_counts,
                    list(metagene_profile[read_length].values())
                    )
                    ]
        count_list = np.array(list(metagene_profile[read_length].values()))
        if count_list[0] is not None:
            read_length_scores[read_length] = autocorrelate_uniformity(count_list)
        else:
            read_length_scores[read_length] = 0
    read_length_scores['global'] = autocorrelate_uniformity(np.array(global_counts))
    return read_length_scores


In [44]:
read_frame_dominance_simulated = read_frame_dominance(simulated_read_frame_dict)
read_frame_score_trips_viz_simulated = read_frame_score_trips_viz(simulated_read_frame_dict)
read_frame_information_content_simulated = read_frame_information_content(simulated_read_frame_dict)

fourier_transform_simulated = fourier_transform(simulated_metagenes, read_lengths=simulated_metagenes['start'].keys())
multitaper_simulated = multitaper(simulated_metagenes, read_lengths=simulated_metagenes['start'].keys())
# print("Running")
# multitaper_thompson_simulated = compute_multitaper_spectrum(simulated_metagenes, read_lengths=simulated_metagenes['start'].keys())
# print("Done Running")
# wavelet_transform_simulated = wavelet_decomposition(simulated_metagenes, read_lengths=simulated_metagenes['start'].keys())
auto_periodicity_simulated = autocorrelate_periodicity_metric(simulated_metagenes['start'], 3)
auto_uniformity_simulated = autocorrelate_uniformity_metric(simulated_metagenes['start'])


simulated_metrics = pd.DataFrame({
    'Simulated Read Frame Proportions 0': {i: simulated_read_frame_dict[i][0] for i in simulated_read_frame_dict},
    'Simulated Read Frame Proportions 1': {i: simulated_read_frame_dict[i][1] for i in simulated_read_frame_dict},
    'Simulated Read Frame Proportions 2': {i: simulated_read_frame_dict[i][2] for i in simulated_read_frame_dict},
    'Read Frame Dominance': read_frame_dominance_simulated,
    'Read Frame Score Trips Viz': read_frame_score_trips_viz_simulated,
    'Read Frame Information Content Score': read_frame_information_content_simulated,
    'Fourier Transform': fourier_transform_simulated,
    'Multitaper': multitaper_simulated,
    # 'Multitaper Thompson': multitaper_thompson_simulated,

    # 'Wavelet Transform': wavelet_transform_simulated,
    'Auto Periodicity': auto_periodicity_simulated,
    'Auto Uniformity': auto_uniformity_simulated,
})

# split read frame information content into 2 columns. Score and Total count

simulated_metrics[['Read Frame Information Content', 'Read Frame Information Content Total Count']] = simulated_metrics['Read Frame Information Content Score'].apply(pd.Series)

simulated_metrics['max_min_fourier'] = (simulated_metrics['Fourier Transform'] - simulated_metrics['Fourier Transform'].min()) / (simulated_metrics['Fourier Transform'].max() - simulated_metrics['Fourier Transform'].min())
# simulated_metrics['max_min_multitaper'] = (simulated_metrics['Multitaper'] - simulated_metrics['Multitaper'].min()) / (simulated_metrics['Multitaper'].max() - simulated_metrics['Multitaper'].min())
# simulated_metrics['max_min_wavelet'] = (simulated_metrics['Wavelet Transform'] - simulated_metrics['Wavelet Transform'].min()) / (simulated_metrics['Wavelet Transform'].max() - simulated_metrics['Wavelet Transform'].min())
# simulated_metrics['max_min_thompson'] = (simulated_metrics['Multitaper Thompson'] - simulated_metrics['Multitaper Thompson'].min()) / (simulated_metrics['Multitaper Thompson'].max() - simulated_metrics['Multitaper Thompson'].min())
simulated_metrics['composite_score'] = simulated_metrics['Read Frame Dominance'] * simulated_metrics['max_min_fourier']

simulated_metrics = simulated_metrics.drop(columns=['Read Frame Information Content Score','Read Frame Information Content Total Count'])

simulated_metrics.head()

Unnamed: 0,Simulated Read Frame Proportions 0,Simulated Read Frame Proportions 1,Simulated Read Frame Proportions 2,Read Frame Dominance,Read Frame Score Trips Viz,Fourier Transform,Multitaper,Auto Periodicity,Auto Uniformity,Read Frame Information Content,max_min_fourier,composite_score
0,0.9614,0.0147,0.0239,0.9614,0.97514,0.64328,96846270.0,4.538178,-0.212285,0.909867,0.9993,0.960727
1,0.9531,0.0137,0.0332,0.9531,0.965166,0.637249,93948990.0,4.443306,-0.216778,0.8955,0.989932,0.943504
2,0.9421,0.0136,0.0442,0.942194,0.953084,0.629106,91209950.0,4.319714,-0.222897,0.877629,0.977282,0.920789
3,0.9322,0.0143,0.0535,0.9322,0.942609,0.62156,88308170.0,4.210221,-0.228623,0.861713,0.96556,0.900095
4,0.923,0.0142,0.0628,0.923,0.931961,0.614521,85790780.0,4.113794,-0.234013,0.848219,0.954625,0.881119


In [41]:
import plotly.graph_objects as go

def plot_simulated_metrics(simulated_metrics, x_col_name="Read Frame Dominance", y_col_name='Read Frame Dominance'):
    
    fig = go.Figure()

    # Create three traces, one for each column
    text = [f'0: {row[0]}, 1: {row[1]}, 2: {row[2]}' for row in simulated_metrics[['Simulated Read Frame Proportions 0', 'Simulated Read Frame Proportions 1', 'Simulated Read Frame Proportions 2']].values]
    fig.add_trace(go.Scatter(x=simulated_metrics[x_col_name],
                                    y=simulated_metrics[y_col_name],
                                    mode='markers',
                                    name=x_col_name,
                                    text=text,))

        # Update the layout with axis titles
    fig.update_layout(
        title='Simulated Read Frame Proportions',
        xaxis_title=x_col_name,
        yaxis_title=y_col_name
    )

    fig.show()

In [45]:
plot_simulated_metrics(simulated_metrics, 'Read Frame Dominance')
plot_simulated_metrics(simulated_metrics, 'Auto Periodicity')
plot_simulated_metrics(simulated_metrics, 'Auto Uniformity')
plot_simulated_metrics(simulated_metrics, 'Read Frame Score Trips Viz')
plot_simulated_metrics(simulated_metrics, 'Read Frame Information Content')
plot_simulated_metrics(simulated_metrics, 'Fourier Transform')
plot_simulated_metrics(simulated_metrics, 'Multitaper')
plot_simulated_metrics(simulated_metrics, 'Wavelet Transform')
plot_simulated_metrics(simulated_metrics, 'composite_score')


KeyError: 'Wavelet Transform'

In [None]:
plot_simulated_metrics(simulated_metrics, 'Read Frame Dominance', 'Fourier Transform')
plot_simulated_metrics(simulated_metrics, 'Read Frame Score Trips Viz', 'Fourier Transform')
plot_simulated_metrics(simulated_metrics, 'Read Frame Information Content', 'Fourier Transform')
plot_simulated_metrics(simulated_metrics, 'Fourier Transform', 'Fourier Transform')
plot_simulated_metrics(simulated_metrics, 'Multitaper', 'Fourier Transform')
plot_simulated_metrics(simulated_metrics, 'Wavelet Transform', 'Fourier Transform')

In [None]:
plot_simulated_metrics(simulated_metrics, 'Multitaper Test', 'Read Frame Dominance')

KeyError: 'Multitaper Test'

In [None]:
def plot_metrics(metrics, scaler=None, title='Metrics for Simulated Metagenes'):
    """
    Plot the metrics for the simulated metagenes.
    
    Args:
        metrics (pd.DataFrame): A DataFrame containing the metrics for the simulated metagenes.
        scaler (sklearn.preprocessing): A scaler object to scale the metrics.
    
    Returns:
        None
    """
    # Scale the metrics
    if scaler:
        metrics_scaled = pd.DataFrame(scaler.fit_transform(metrics), columns=metrics.columns)
    else:
        metrics_scaled = metrics
    
    # Create the figure
    fig = go.Figure()
    
    # Add traces for each metric
    for col in metrics_scaled.columns:
        fig.add_trace(go.Box(y=metrics_scaled[col], name=col))

    # Set the axis labels
    fig.update_layout(xaxis_title='Metagene', yaxis_title='Scaled Value', title=title)

    # Show the plot
    fig.show()

In [None]:
plot_metrics(simulated_metrics, scaler=None, title='Metrics for Simulated Metagenes (Standard Scaler)')

plot_metrics(simulated_metrics, scaler=StandardScaler(), title='Metrics for Simulated Metagenes (Standard Scaler)')
plot_metrics(simulated_metrics, scaler=MinMaxScaler(), title='Metrics for Simulated Metagenes (MinMax Scaler)')
plot_metrics(simulated_metrics, scaler=RobustScaler(), title='Metrics for Simulated Metagenes (Robust Scaler)')
plot_metrics(simulated_metrics, scaler=MaxAbsScaler(), title='Metrics for Simulated Metagenes (MaxAbs Scaler)')


In [None]:
# sort df by reads in frame 0
simulated_metrics_sampled = simulated_metrics.sort_values(by='max_min_fourier', ascending=False).sample()

# plot bar chart of the metrics grouped by sample 
simulated_metrics_sampled.head()


# metrics = ['composite_score', 'Read Frame Dominance', 'max_min_fourier', 'max_min_multitaper', 'max_min_wavelet']

# fig = go.Figure(
#     data=[
#         go.Bar(name='composite_score', x=[f'Sample {i}' for i in range(len(simulated_metrics_sampled))], y=simulated_metrics_sampled['composite_score']),
#         go.Bar(name='read_frame_score_trips_viz', x=[f'Sample {i}' for i in range(len(simulated_metrics_sampled))], y=simulated_metrics_sampled['Read Frame Score Trips Viz']),
#         go.Bar(name='Read Frame Information Content', x=[f'Sample {i}' for i in range(len(simulated_metrics_sampled))], y=simulated_metrics_sampled['Read Frame Information Content']),
#         go.Bar(name='Read Frame Dominance', x=[f'Sample {i}' for i in range(len(simulated_metrics_sampled))], y=simulated_metrics_sampled['Read Frame Dominance']),
#         go.Bar(name='max_min_fourier', x=[f'Sample {i}' for i in range(len(simulated_metrics_sampled))], y=simulated_metrics_sampled['max_min_fourier']),
#         go.Bar(name='max_min_multitaper', x=[f'Sample {i}' for i in range(len(simulated_metrics_sampled))], y=simulated_metrics_sampled['max_min_multitaper']),
#         go.Bar(name='max_min_wavelet', x=[f'Sample {i}' for i in range(len(simulated_metrics_sampled))], y=simulated_metrics_sampled['max_min_wavelet']),
#     ]
# )

# fig.update_layout(
#     barmode='group',
#     title='Simulated Read Frame Proportions, Composite Score, and Read Frame Dominance',
#     xaxis_title='Sample',
#     yaxis_title='Value',
#     xaxis=dict(tickvals=[i for i in range(len(simulated_metrics_sampled))], ticktext=[f'Sample {i}' for i in range(len(simulated_metrics_sampled))]),
# )

# fig.show()


In [None]:
# plot metagene of the sample 1

metagene = simulated_metagenes['start'][simulated_metrics_sampled.index[24]]

# Get the positions and counts from the metagene dictionary
positions = list(metagene.keys())
counts = list(metagene.values())

# Create the bar plot
fig = go.Figure(data=go.Bar(x=positions, y=counts))

# Set the axis labels
fig.update_layout(xaxis_title='Position', yaxis_title='Count', title='Periodic Metagene')

# Show the plot
fig.show()

In [None]:
# plot fourier value in histogram

fig = go.Figure(data=[go.Histogram(x=simulated_metrics['Fourier Transform'])])

fig.show()