<a href="https://colab.research.google.com/github/May-BG/SuPreMo-Enformer/blob/visualization/dash_variant_prioritization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import pandas as pd
import dash
from dash import dcc, html, Input, Output, callback
import plotly.express as px
import random

In [3]:
pip install flask dash pandas plotly




In [12]:
# Synthetic demo table: five columns, each holding ten random integers in 1..100
data = {
    f"Column_{number}": [random.randint(1, 100) for _ in range(10)]
    for number in range(1, 6)
}
# Expose each row's index as a regular 'Row' column so it can identify points
df = pd.DataFrame(data).assign(Row=lambda frame: frame.index)

# Create the Dash application object
app = dash.Dash(__name__)


def _axis_dropdown(component_id, default_column):
    """Build a dropdown listing every column of ``df`` except 'Row'."""
    selectable = [col for col in df.columns if col != 'Row']
    return dcc.Dropdown(
        id=component_id,
        options=[{'label': col, 'value': col} for col in selectable],
        value=default_column,  # Initially selected column
    )


# Page layout: title, one dropdown per axis, the graph, and a text area that
# reports the most recently clicked point
app.layout = html.Div([
    html.H1("Interactive Scatter Plot Viewer", style={"textAlign": "center"}),

    html.Label("Select a Column for X-axis"),
    _axis_dropdown('x-column', 'Column_1'),

    html.Label("Select a Column for Y-axis"),
    _axis_dropdown('y-column', 'Column_2'),

    dcc.Graph(id='scatter-plot'),
    html.Div(
        id='clicked-point',
        style={"textAlign": "center", "marginTop": "20px", "fontSize": "18px"},
    ),
])

# Redraw the scatter plot whenever either axis dropdown changes
@app.callback(
    Output('scatter-plot', 'figure'),
    Input('x-column', 'value'),
    Input('y-column', 'value')
)
def update_scatter_plot(x_col, y_col):
    """Return a scatter figure of ``df`` using the two selected columns.

    Args:
        x_col: Value of the 'x-column' dropdown, used as the x-axis column.
        y_col: Value of the 'y-column' dropdown, used as the y-axis column.

    Returns:
        The Plotly Express scatter figure, with 'Row' added to the hover data.
    """
    plot_title = f"Scatter Plot of {x_col} vs {y_col}"
    # 'Row' is passed as hover data so each point carries its row number
    return px.scatter(df, x=x_col, y=y_col, hover_data=['Row'], title=plot_title)

@app.callback(
    Output('clicked-point', 'children'),
    Input('scatter-plot', 'clickData')
)
def display_clicked_point(clickData):
    """Return the message shown below the graph for the latest click.

    Args:
        clickData: The 'clickData' property of the 'scatter-plot' graph;
            ``None`` until a point has been clicked.

    Returns:
        A prompt when nothing has been clicked, otherwise a message naming
        the first entry of the clicked point's ``customdata``.
    """
    if clickData is not None:
        first_point = clickData['points'][0]
        row_num = first_point['customdata'][0]  # First customdata entry is the Row number
        return f"You clicked on Row Number: {row_num}"
    return "Click on a point to see its row number."

# Launch the Dash development server when this cell/script is run directly.
if __name__ == '__main__':
    # `run_server` was deprecated in favour of `run` and is no longer
    # available in Dash 3; `run` takes the same `debug` argument.
    app.run(debug=False)


<IPython.core.display.Javascript object>

In [31]:
# Load the tab-separated data file into the module-level DataFrame
df = pd.read_csv('/concatenated_data.txt', sep='\t')
# Add a Row column for indexing
df['Row'] = df.index

Unnamed: 0,"corr_CAGE:Fibroblast - skin dystrophia myotonica,_0"
0,1.000000
1,1.000000
2,1.000000
3,1.000000
4,
...,...
22904,1.000000
22905,0.981862
22906,0.983416
22907,1.000000


In [34]:
# Clean column names by stripping whitespace
df.columns = df.columns.str.strip()

# Add a Row column for indexing
df['Row'] = df.index

# Keep only the score columns, i.e. those named 'mse_<key>' or 'corr_<key>'
value_columns = [col for col in df.columns if col.startswith(("mse_", "corr_"))]

# Collect the distinct <key> parts in sorted order. Only the leading prefix is
# removed (by slicing) so that a key which itself contains 'mse_' or 'corr_'
# is kept intact instead of being cut at the second occurrence.
unique_keys = sorted({
    col[len("mse_"):] if col.startswith("mse_") else col[len("corr_"):]
    for col in value_columns
})

# Start the Dash app
app = dash.Dash(__name__)

# Default selection: the first key, or nothing when the data has no
# 'mse_'/'corr_' columns (indexing an empty list would raise IndexError
# before the app could even be displayed).
_default_key = unique_keys[0] if unique_keys else None

# Layout of the app: a single-key scatter plot on top, and below the rule a
# multi-key selector driving the mean-MSE vs mean-CORR plot
app.layout = html.Div([
    html.H1("Interactive MSE and CORR Scatter Plot Viewer", style={"textAlign": "center"}),

    html.Label("Select a Unique Key for Scatter Plot"),
    dcc.Dropdown(
        id='key-selector',
        options=[{'label': key, 'value': key} for key in unique_keys],
        value=_default_key  # Default to the first key
    ),
    dcc.Graph(id='scatter-plot'),

    html.Hr(),

    html.Label("Select Multiple Unique Keys to Plot Mean MSE vs Mean CORR"),
    dcc.Dropdown(
        id='multi-key-selector',
        options=[{'label': key, 'value': key} for key in unique_keys],
        multi=True,
        value=[_default_key] if unique_keys else []  # Default to first key
    ),
    dcc.Graph(id='mean-scatter-plot')
])

# Callback for scatter plot of a single mse_ vs corr_
@app.callback(
    Output('scatter-plot', 'figure'),
    Input('key-selector', 'value')
)
def update_single_scatter_plot(selected_key):
    """Return a scatter figure of ``mse_<key>`` against ``corr_<key>``.

    Args:
        selected_key: Value of the 'key-selector' dropdown.

    Returns:
        A Plotly Express scatter figure of the two columns, or an empty
        titled figure when either column is absent from ``df``.
    """
    mse_col = f"mse_{selected_key}"
    corr_col = f"corr_{selected_key}"
    if mse_col not in df.columns or corr_col not in df.columns:
        return px.scatter(title="No Data Found for Selected Key")

    # Only annotate with the metadata columns that are actually present, so a
    # file lacking one of them still plots instead of raising KeyError.
    hover_cols = [
        col
        for col in ('CHROM', 'POS', 'END', 'REF', 'ALT', 'ALT2', 'SVTYPE', 'SVLEN', 'var_index')
        if col in df.columns
    ]
    # Require values only in the two plotted columns: a NaN in a hover-only
    # metadata column must not remove an otherwise valid point from the plot.
    df_filtered = df[[mse_col, corr_col, 'Row'] + hover_cols].dropna(subset=[mse_col, corr_col])
    print(f"Plotting: {mse_col} vs {corr_col}, Rows: {len(df_filtered)}")
    # Axis titles default to the column names, so no `labels` mapping is needed
    return px.scatter(
        df_filtered,
        x=mse_col,
        y=corr_col,
        hover_data=hover_cols + ['Row'],
        title=f"Scatter Plot: {mse_col} vs {corr_col}",
    )

# Callback for scatter plot of mean MSE vs mean CORR
@app.callback(
    Output('mean-scatter-plot', 'figure'),
    Input('multi-key-selector', 'value')
)
def update_mean_scatter_plot(selected_keys):
    """Return a scatter figure of row-wise mean MSE against mean CORR.

    Args:
        selected_keys: Value of the 'multi-key-selector' dropdown (a list of
            keys, possibly empty or ``None``).

    Returns:
        A Plotly Express scatter figure of the per-row means over the
        selected keys' ``mse_``/``corr_`` columns, or an empty titled figure
        when no key is selected or no matching column exists.
    """
    if not selected_keys:
        return px.scatter(title="No Keys Selected")

    mse_cols = [f"mse_{key}" for key in selected_keys if f"mse_{key}" in df.columns]
    corr_cols = [f"corr_{key}" for key in selected_keys if f"corr_{key}" in df.columns]
    if not (mse_cols and corr_cols):
        return px.scatter(title="No Data Found for Selected Keys")

    # Only annotate with the metadata columns that are actually present
    hover_cols = [
        col
        for col in ('CHROM', 'POS', 'END', 'REF', 'ALT', 'ALT2', 'SVTYPE', 'SVLEN', 'var_index')
        if col in df.columns
    ]

    # Copy only the relevant columns. 'Row' must be selected as well because
    # it is listed in hover_data below; px.scatter rejects hover columns that
    # are missing from the frame it is given.
    df_filtered = df[mse_cols + corr_cols + hover_cols + ['Row']].copy()
    # Row-wise means over the selected keys, ignoring NaN entries
    df_filtered['mean_mse'] = df_filtered[mse_cols].mean(axis=1, skipna=True)
    df_filtered['mean_corr'] = df_filtered[corr_cols].mean(axis=1, skipna=True)

    # Drop rows where either mean is NaN (all inputs for that mean were missing)
    df_filtered = df_filtered.dropna(subset=['mean_mse', 'mean_corr'])
    print(f"Mean Plot Columns: MSE: {mse_cols}, CORR: {corr_cols}, Rows: {len(df_filtered)}")
    return px.scatter(
        df_filtered,
        x='mean_mse',
        y='mean_corr',
        hover_data=hover_cols + ['Row'],
        title="Scatter Plot: Mean MSE vs Mean CORR",
        # `labels` is keyed by column name, not by axis letter
        labels={"mean_mse": "Mean MSE", "mean_corr": "Mean CORR"},
    )

# Launch the Dash development server when this cell/script is run directly.
if __name__ == '__main__':
    # `run_server` was deprecated in favour of `run` and is no longer
    # available in Dash 3; `run` takes the same `debug` argument.
    app.run(debug=False)



<IPython.core.display.Javascript object>