In [3]:
import dash
from dash import dcc, html, Input, Output, State
import dash_bootstrap_components as dbc
import plotly.express as px
from dash.dependencies import MATCH, ALL, Input, Output, State
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from sklearn.cluster import KMeans, AgglomerativeClustering, DBSCAN
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from sklearn.exceptions import ConvergenceWarning
from sklearn.impute import KNNImputer
import warnings
import umap.umap_ as umap
from sklearn.feature_selection import f_classif
import json
import dash_table

warnings.filterwarnings("ignore")

The dash_table package is deprecated. Please replace
`import dash_table` with `from dash import dash_table`

Also, if you're using any of the table format helpers (e.g. Group), replace 
`from dash_table.Format import Group` with 
`from dash.dash_table.Format import Group`
  import dash_table


In [4]:
# Step 1: Load the Dataset

# Number of leading CSV columns treated as clinical attributes; every column after
# them is treated as genetic data.
N_CLINICAL_COLUMNS = 31

# Load the entire dataset
full_data = pd.read_csv('METABRIC_RNA_Mutation.csv')

# Take an explicit copy of the clinical columns. `data` is modified by the
# missing-value handling further down; without the copy those writes target a
# selection of `full_data`, which pandas flags as chained assignment.
data = full_data.iloc[:, :N_CLINICAL_COLUMNS].copy()
clinical_attributes = data.columns.tolist()

# All remaining columns are genetic data (read-only selection of `full_data`)
df_genetics = full_data.iloc[:, N_CLINICAL_COLUMNS:]

print("Clinical DataFrame shape:", data.shape)
print("Genetics DataFrame shape:", df_genetics.shape)

Clinical DataFrame shape: (1904, 31)
Genetics DataFrame shape: (1904, 662)


In [5]:
# Step 2: Identify Feature Types

# Split the clinical columns by dtype: numeric columns on one side, everything
# else (treated as categorical) on the other. Column order follows `data`.
numeric_cols = data.select_dtypes(include=np.number).columns.tolist()
categorical_cols = data.select_dtypes(exclude=np.number).columns.tolist()

In [6]:
# ---------------------------
# Step 3: Handle Missing Data by Feature Category
# ---------------------------
# Every imputation below assigns the filled column back
# (`data[col] = data[col].fillna(...)`) instead of calling
# `fillna(..., inplace=True)` on a column selection. The in-place form operates
# on a temporary object: pandas 2.x warns about it, and with Copy-on-Write
# enabled it no longer updates `data` at all.

# Numeric columns with at most this many missing values are median-imputed (3.3).
LOW_MISSING_MAX_COUNT = 80

# 3.1: Handle Categorical Features with Low Missingness
# For categorical features with few missing values, impute using the mode.
for col in categorical_cols:
    if data[col].isnull().any():
        mode_val = data[col].mode(dropna=True)
        # `mode()` is empty only when the whole column is missing; fall back to "Unknown".
        mode_val = mode_val.iloc[0] if len(mode_val) > 0 else "Unknown"
        data[col] = data[col].fillna(mode_val)

# 3.2: Specific Handling for `3-gene_classifier_subtype` (~5% missing)
# Kept for readers who prefer an explicit "Unknown" category. If this column is
# one of `categorical_cols`, step 3.1 has already filled it, so this branch
# does nothing; to actually get "Unknown", skip the column in 3.1.
if '3-gene_classifier_subtype' in data.columns:
    missing_count = data['3-gene_classifier_subtype'].isnull().sum()
    if missing_count > 0:
        data['3-gene_classifier_subtype'] = data['3-gene_classifier_subtype'].fillna("Unknown")

# 3.3: Handle Numeric Features with Low Missingness using Median
low_missing_numeric_cols = [
    col for col in numeric_cols
    if 0 < data[col].isnull().sum() <= LOW_MISSING_MAX_COUNT
]

for col in low_missing_numeric_cols:
    median_val = data[col].median()
    data[col] = data[col].fillna(median_val)

# 3.4: Handle `tumor_stage` (~26% missing)
# For a large proportion of missing data, use KNN Imputer:
#   - select numeric features that may correlate with tumor_stage
#   - run the imputer on that sub-frame only and write back just `tumor_stage`.
if 'tumor_stage' in data.columns:
    # Candidate predictors (adjust based on domain knowledge); only those that
    # exist in `data` and are numeric are used.
    potential_predictors = ['tumor_size', 'mutation_count', 'age_at_diagnosis', 'neoplasm_histologic_grade']
    knn_features = ['tumor_stage'] + [
        f for f in potential_predictors
        if f in data.columns and pd.api.types.is_numeric_dtype(data[f])
    ]

    # Work on an explicit copy so the predictor pre-fill below neither touches
    # `data` nor writes into a selection of it.
    knn_data = data[knn_features].copy()

    # Median-impute predictors that still have gaps; `tumor_stage` itself is
    # left for the KNN imputer.
    for col in knn_features[1:]:
        if knn_data[col].isnull().any():
            knn_data[col] = knn_data[col].fillna(knn_data[col].median())

    # Now apply KNN imputer for tumor_stage
    imputer = KNNImputer(n_neighbors=5)
    knn_imputed = imputer.fit_transform(knn_data)
    # Replace the original column with imputed values
    data['tumor_stage'] = knn_imputed[:, knn_features.index('tumor_stage')]

    # Imputed values are neighbour averages and may be fractional: round them
    # back to whole stages.
    data['tumor_stage'] = data['tumor_stage'].round().astype(int)

    # Force values into [1, 4].
    # NOTE(review): this also rewrites any *observed* stage outside that range
    # (e.g. a stage 0, if the dataset has one) — confirm that is intended.
    data['tumor_stage'] = data['tumor_stage'].clip(lower=1, upper=4)


# For features like `death_from_cancer` or others with very low missingness,
# treat them similarly by mode (if categorical) or median (if numeric).

In [7]:
# ---------------------------
# Step 4: Verify that Missing Data is Handled
# ---------------------------
# Print the number of remaining missing values for every clinical column.
print(data.isnull().sum())

# Every count is expected to be 0 here. Note that Step 3 only imputes categorical
# columns, numeric columns with at most 80 missing values, and `tumor_stage`; any
# other numeric column with more gaps would still show a non-zero count above.

patient_id                        0
age_at_diagnosis                  0
type_of_breast_surgery            0
cancer_type                       0
cancer_type_detailed              0
cellularity                       0
chemotherapy                      0
pam50_+_claudin-low_subtype       0
cohort                            0
er_status_measured_by_ihc         0
er_status                         0
neoplasm_histologic_grade         0
her2_status_measured_by_snp6      0
her2_status                       0
tumor_other_histologic_subtype    0
hormone_therapy                   0
inferred_menopausal_state         0
integrative_cluster               0
primary_tumor_laterality          0
lymph_nodes_examined_positive     0
mutation_count                    0
nottingham_prognostic_index       0
oncotree_code                     0
overall_survival_months           0
overall_survival                  0
pr_status                         0
radio_therapy                     0
3-gene_classifier_subtype   

In [8]:
# Recompute the column-type lists now that imputation has run (dtypes may have
# changed, e.g. `tumor_stage` was cast to int in Step 3).
numeric_cols = data.select_dtypes(include=np.number).columns.tolist()
# NOTE(review): `data` holds only the leading clinical columns of `full_data`, so
# this list is presumably empty — mutation columns would be in `df_genetics`.
# Confirm which frame was intended before relying on `mut_cols`.
mut_cols = [c for c in data.columns if c.endswith('_mut')]
categorical_cols = [c for c in data.columns if c not in numeric_cols]

# Example descriptions (adjust as needed)
# Human-readable explanations shown on page 1 for the selected dropdown values.
# Keys are looked up by the `show_description` callback; note that the
# 'Claudin-low' entry is the one displayed for the dropdown value "claudin-low".
descriptions = {
    'MASTECTOMY': "Mastectomy: A surgical procedure that involves removing the entire breast to treat or prevent breast cancer.",
    'BREAST CONSERVING': "Breast-conserving: A surgical procedure that removes only the cancerous tissue and a margin of surrounding normal tissue from the breast.",
    'Breast Cancer': "Breast cancer: A malignant tumor that arises from the cells of the breast, most commonly originating in the ducts or lobules.",
    'Breast Sarcoma': "Breast sarcoma: A rare type of cancer that arises from the connective or supportive tissues of the breast, such as fat, muscle, or blood vessels.",
    'Breast Invasive Ductal Carcinoma': "Breast Invasive Ductal Carcinoma: Cancer that begins in the milk ducts and invades surrounding breast tissue.",
    'Breast Mixed Ductal and Lobular Carcinoma': "Breast Mixed Ductal and Lobular Carcinoma: A combination of ductal and lobular breast cancer characteristics in the same tumor.",
    'Breast Invasive Lobular Carcinoma': "Breast Invasive Lobular Carcinoma: Cancer that starts in the milk-producing lobules and spreads to nearby tissues.",
    'Breast Invasive Mixed Mucinous Carcinoma': "Breast Invasive Mixed Mucinous Carcinoma: A rare type of breast cancer featuring both invasive and mucin-producing cancer cells.",
    'Breast': "Breast: Likely referring to unspecified breast cancer or general breast tissue conditions.",
    'Metaplastic Breast Cancer': "Metaplastic Breast Cancer: A rare, aggressive form of breast cancer involving changes in cell type and structure.",
    'Claudin-low': "Claudin-low: A subtype characterized by low expression of cell adhesion and tight junction proteins, associated with immune cell infiltration and stem cell-like features.",
    'LumA': "LumA: A hormone receptor-positive subtype with high expression of estrogen-related genes and low proliferation rates, often associated with better prognosis.",
    'LumB': "LumB: A more aggressive hormone receptor-positive subtype with higher proliferation rates and lower estrogen signaling compared to LumA.",
    'Her2': "Her2: A subtype defined by overexpression of the HER2 protein, often associated with more aggressive tumor behavior and targeted therapy options.",
    'Normal': "Normal: A group resembling normal breast tissue with low tumor cell content, used as a reference in classification.",
    'Basal': "Basal: A highly aggressive triple-negative subtype often associated with poor differentiation and higher proliferation rates.",
    'NC': "NC: A category for tumors that do not distinctly fit into any of the defined PAM50 subtypes."
}

# Shared colour palette passed to the Plotly Express figures on page 1.
color_discrete_sequence = px.colors.qualitative.Safe
# Bootstrap theme for the whole app.
external_stylesheets = [dbc.themes.FLATLY]
# `suppress_callback_exceptions=True`: callbacks reference component ids that are
# defined in page layouts which are not part of the initial `app.layout`.
app = dash.Dash(__name__, external_stylesheets=external_stylesheets, suppress_callback_exceptions=True)
# Underlying server object exposed by Dash.
server = app.server

In [9]:
########################################
# PAGE 1 LAYOUT (Modified)
########################################
# Page 1 ("Public Interface") is built from three rows:
#   1. general-information dropdowns (left) and basic statistics (right)
#   2. visualization settings (left) and the dynamic filter area (right)
#   3. primary graph (left) and comparison graph (right)
# followed by a dcc.Store that carries the (filtered) records as a list of dicts.

page1_layout = dbc.Container([
    html.H2("Public Interface: Data Exploration", className="mt-4 mb-4"),
    html.P("Explore and visualize the METABRIC dataset with various filters and plot types. Adjust filters, pick plot styles, and compare two plots side-by-side."),

    dbc.Row([
        # General information panel: single-select dropdowns whose values drive
        # the text rendered into 'description_div' (see `show_description`).
        dbc.Col([
            dbc.Card([
                dbc.CardHeader("General Information", className="bg-light"),
                dbc.CardBody([
                    # Options are the distinct non-null values of each column.
                    html.Label("Type of Breast Surgery"),
                    dcc.Dropdown(
                        id='type_of_breast_surgery_dd',
                        options=[{'label': i, 'value': i} for i in data['type_of_breast_surgery'].dropna().unique()],
                        multi=False,
                        placeholder="Select Surgery Type"
                    ),
                    html.Br(),

                    html.Label("Cancer Type"),
                    dcc.Dropdown(
                        id='cancer_type_dd',
                        options=[{'label': i, 'value': i} for i in data['cancer_type'].dropna().unique()],
                        multi=False,
                        placeholder="Select Cancer Type"
                    ),
                    html.Br(),

                    html.Label("Cancer Type Detailed"),
                    dcc.Dropdown(
                        id='cancer_type_detailed_dd',
                        options=[{'label': i, 'value': i} for i in data['cancer_type_detailed'].dropna().unique()],
                        multi=False,
                        placeholder="Select Detailed Cancer Type"
                    ),
                    html.Br(),

                    html.Label("Pam50+Claudin Subtype"),
                    dcc.Dropdown(
                        id='pam50_dd',
                        options=[{'label': i, 'value': i} for i in data['pam50_+_claudin-low_subtype'].dropna().unique()],
                        multi=False,
                        placeholder="Select Pam50+Claudin Subtype"
                    ),
                    # 'pre-wrap' keeps the blank lines between joined descriptions.
                    html.Div(id='description_div', style={'marginTop':'20px', 'whiteSpace': 'pre-wrap'})
                ])
            ])
        ], width=6),

        # Basic Statistics
        dbc.Col([
            dbc.Card([
                dbc.CardHeader("Basic Statistics", className="bg-light"),
                dbc.CardBody([
                    # Feature + statistic pair consumed by `update_stats_result`.
                    html.Label("Data Type (Feature)"),
                    dcc.Dropdown(
                        id='stats_feature',
                        options=[{'label': c, 'value': c} for c in clinical_attributes],
                        value=clinical_attributes[0],
                        placeholder="Select Feature for Statistics"
                    ),
                    html.Br(),
                    html.Label("Needed Information"),
                    dcc.Dropdown(
                        id='stats_info',
                        options=[
                            {'label': 'Most Common', 'value': 'most_common'},
                            {'label': 'Least Common', 'value': 'least_common'}, 
                            {'label': 'Minimum', 'value': 'min'},
                            {'label': 'Maximum', 'value': 'max'}, 
                            {'label': 'Mean', 'value': 'mean'}
                        ],
                        value='most_common',
                        placeholder="Select Statistic Type"
                    ),
                    html.Div(id='stats_result', style={'marginTop':'20px'}),
                    html.Hr(),

                    # New Counts Display Section
                    # (empty placeholder; presumably filled by a callback defined elsewhere)
                    html.H5("Feature Counts"),
                    html.Div(id='counts_display')
                ])   
            ])
        ], width=6)
    ], className="mt-4"),

    html.Br(),

       dbc.Row([
        # Visualization Settings: plot type and axis features for the primary
        # plot, then the same three controls (ids suffixed '-2') for the comparison plot.
        dbc.Col([
            dbc.Card([
                dbc.CardHeader("Visualization Settings", className="bg-light"),
                dbc.CardBody([
                    html.Label("Primary Plot Type:"),
                    dcc.Dropdown(
                        id="plot-type",
                        options=[
                            {"label": "Scatter Plot", "value": "scatter"},
                            {"label": "Histogram", "value": "histogram"},
                            {"label": "Box Plot", "value": "box"},
                            {"label": "Violin Plot", "value": "violin"},
                            {"label": "Pie Chart", "value": "pie"},
                            {"label": "Heatmap (Correlation)", "value": "heatmap"}
                        ],
                        value="scatter",
                        placeholder="Select Primary Plot Type"
                    ),
                    html.Br(),

                    html.Label("X Value:"),
                    dcc.Dropdown(
                        id='x-axis-feature',
                        options=[{'label': c, 'value': c} for c in clinical_attributes],
                        value=clinical_attributes[0],
                        placeholder="Select X-axis Feature"
                    ),
                    html.Br(),

                    html.Label("Y Value (if applicable):"),
                    dcc.Dropdown(
                        id='y-axis-feature',
                        options=[{'label': c, 'value': c} for c in clinical_attributes],
                        value=clinical_attributes[1],
                        placeholder="Select Y-axis Feature"
                    ),
                    html.Br(),
                    html.Hr(),

                    html.Label("Comparison Plot Type:"),
                    dcc.Dropdown(
                        id="plot-type-2",
                        options=[
                            {"label": "Scatter Plot", "value": "scatter"},
                            {"label": "Histogram", "value": "histogram"},
                            {"label": "Box Plot", "value": "box"},
                            {"label": "Violin Plot", "value": "violin"},
                            {"label": "Pie Chart", "value": "pie"},
                            {"label": "Heatmap (Correlation)", "value": "heatmap"}
                        ],
                        value="scatter",
                        placeholder="Select Comparison Plot Type"
                    ),
                    html.Br(),

                    html.Label("X Value (Comparison):"),
                    dcc.Dropdown(
                        id='x-axis-feature-2',
                        options=[{'label': c, 'value': c} for c in clinical_attributes],
                        value=clinical_attributes[2],
                        placeholder="Select X-axis Feature for Comparison"
                    ),
                    html.Br(),

                    html.Label("Y Value (Comparison, if applicable):"),
                    dcc.Dropdown(
                        id='y-axis-feature-2',
                        options=[{'label': c, 'value': c} for c in clinical_attributes],
                        value=clinical_attributes[3],
                        placeholder="Select Y-axis Feature for Comparison"
                    ),
                ])
            ])
          
        ], width=6),

        # Dynamic feature filters (this column holds the filter area, not plots)
        dbc.Col([
            dbc.Card([
                dbc.CardHeader("Dynamic Feature Filters", className="bg-light"),
                dbc.CardBody([   
                 # Dynamic Filters Section
                 # (empty placeholder; presumably filled by a callback defined elsewhere)
                    html.H5("Selected Filters:"),
                    html.Div(id='dynamic_filters')
                ])
            ])
           

        ], width=6),
    ]),

    html.Br(),

    dbc.Row([
        # Primary graph (first output of `update_page1_graph`)
        dbc.Col([
            dbc.Card([
                        dbc.CardHeader("Primary Graph", className="bg-light"),
                        dbc.CardBody([
                            dcc.Graph(id='main-plot-page1', clear_on_unhover=True)
                        ])
                    ])
          
          
        ], width=6),

        # Comparison graph (second output of `update_page1_graph`)
        dbc.Col([
             dbc.Card([
                        dbc.CardHeader("Comparison Graph", className="bg-light"),
                        dbc.CardBody([
                            dcc.Graph(id='second-plot-page1', clear_on_unhover=True)
                        ])
                    ])
          
           

        ], width=6),
    ]),


    html.Br(),

    # Store component to hold filtered data; initialised with every row of
    # `data` serialised as a list of record dicts.
    dcc.Store(id='filtered_data', data=data.to_dict('records')),
], fluid=True)

In [10]:
########################################
# PAGE 2 LAYOUT
########################################
# Page 2 ("Scientific Interface") has a first row of three equal-width columns:
#   - feature selection + dimensionality-reduction settings
#   - feature-pair pickers + clustering settings
#   - the feature relationship scatter plot
# and a second, full-width row with the clustering plot.
# The callbacks that consume these component ids are not part of this cell.

page2_layout = dbc.Container([
    html.H2("Scientific Interface: Mutation Analysis", className="mt-4 mb-4"),
    html.P("Apply feature selection, dimensionality reduction and clustering. Clustering is done on high-dimensional or feature-selected space, and DR is for visualization."),


        dbc.Row([
        # Column 1: feature selection and DR settings
        dbc.Col([
             dbc.Row([
                dbc.Card([
                dbc.CardHeader("Feature Selection (Intermediate DR)", className="bg-light"),
                dbc.CardBody([
                    html.P("Pre-select a subset of genes based on their relevance to a chosen clinical variable, and click the button. (Mandatory otherwise refresh the page!)"),
                    html.Label("Select a reference clinical variable:"),
                    # No default: the reference variable starts unselected.
                    dcc.Dropdown(
                        id='ref_variable',
                        options=[{'label': c, 'value': c} for c in clinical_attributes],
                        value=None
                    ),
                    html.Br(),
                    html.Label("Number of top genes to select (N):"),
                    dcc.Input(id='num_features', type='number', value=50, min=10, step=10),
                    html.Br(), html.Br(),
                    dbc.Button("Select Features", id='select_features_btn', color='secondary'),
                    html.Div(id='feature_selection_message', style={'marginTop':'10px'}),
                    # Hidden storage for selected features
                    dcc.Store(id='selected_features_storage')
                ])
            ])
       
            ]),

            html.Br(),

            dbc.Row([
                dbc.Card([
                dbc.CardHeader("Dimensionality Reduction Settings", className="bg-light"),
                dbc.CardBody([
                    html.Label("Select DR Method for Visualization:"),
                    dcc.Dropdown(
                        id='dr_method',
                        options=[
                            {"label": "PCA (Linear)", "value": "pca"},
                            {"label": "t-SNE (Non-Linear)", "value": "tsne"},
                            {"label": "UMAP (Non-Linear)", "value": "umap"}
                        ],
                        value="pca"
                    ),
                    html.Br(),
                    html.Label("Number of Components for Visualization (2 or 3):"),
                    # The input itself restricts the value to 2..3 in steps of 1.
                    dcc.Input(id='dr_n_components', type='number', value=2, min=2, max=3, step=1),
                    html.Br(),
                    html.Small("Note: DR is used here mainly for visualization, not for clustering directly."),
                ])
            ])
       

       
            ]),
            
        ],  width=4),


         # Column 2: feature-pair pickers and clustering settings
         dbc.Col([
            dbc.Row([
                dbc.Card([
                dbc.CardHeader("Relationship between two selected features", className="bg-light"),
                dbc.CardBody([
                    # Created without options; presumably populated by a callback
                    # defined elsewhere.
                    html.Label("Feature X:"),
                    dcc.Dropdown(id='scatter_feature_x', multi=False),
                    html.Label("Feature Y:"),
                    dcc.Dropdown(id='scatter_feature_y', multi=False),
                ])
            ])
       
            ]),

            html.Br(),

            dbc.Row([
                dbc.Card([
                dbc.CardHeader("Clustering Settings", className="bg-light"),
                dbc.CardBody([
                    html.Label("Select Clustering Method and press on Run Analysis:"),
                    dcc.Dropdown(
                        id='cluster_method',
                        options=[
                            {"label": "K-Means", "value": "kmeans"},
                            {"label": "Hierarchical (Agglomerative)", "value": "hierarchical"},
                            {"label": "DBSCAN", "value": "dbscan"}
                        ],
                        value="kmeans"
                    ),
                    html.Br(),

                    # One parameter group per clustering method. Only the K-Means
                    # group is visible initially (matching the default method);
                    # the others start with display 'none'.
                    html.Div([
                        html.Div([
                            html.Label("Number of Clusters (K-Means):"),
                            dcc.Input(id='kmeans_n_clusters', type='number', value=3, min=2, step=1)
                        ], id='kmeans_params', style={'display': 'block'}),

                        html.Div([
                            html.Label("Number of Clusters (Hierarchical):"),
                            dcc.Input(id='agg_n_clusters', type='number', value=3, min=2, step=1),
                            html.Br(),
                            html.Label("Linkage:"),
                            dcc.Dropdown(
                                id='agg_linkage',
                                options=[{'label': l, 'value': l} for l in ['ward', 'complete', 'average', 'single']],
                                value='ward'
                            )
                        ], id='hierarchical_params', style={'display': 'none'}),

                        html.Div([
                            html.Label("DBSCAN eps:"),
                            dcc.Input(id='dbscan_eps', type='number', value=0.5, min=0.1, step=0.1),
                            html.Br(),
                            html.Label("DBSCAN min_samples:"),
                            dcc.Input(id='dbscan_min_samples', type='number', value=5, min=1, step=1)
                        ], id='dbscan_params', style={'display': 'none'}),
                    ]),

                    html.Br(),
                    html.Label("Clustering Space:"),
                    dcc.RadioItems(
                        id='clustering_space',
                        options=[
                            {'label': 'High-dimensional (or feature-selected) space', 'value': 'original'},
                            {'label': 'Dimension-reduced space', 'value': 'reduced'}
                        ],
                        value='original'
                    ),
                    html.Br(),
                    dbc.Button("Run Analysis", id='run_analysis_btn', color='primary')
                ])
            ])
       
            ]),
        ], width=4),

        # Column 3: scatter plot of the two selected features
        dbc.Col([
            dbc.Card([
                dbc.CardHeader("Relationship Plot", className="bg-light"),
                dbc.CardBody([
                    dcc.Graph(id='feature_scatter_plot')
                ])
            ])

            
        ], width=4),

       
    ]),

    
    html.Br(),

    # Full-width row: clustering result plot
    dbc.Row([
        dbc.Col([
            dbc.Card([
                dbc.CardHeader("Clustering Visualization (DR Projection)", className="bg-light"),
                dbc.CardBody([
                    dcc.Graph(id='cluster_plot')
                ])
            ])
        ], width=12)
    ]),
], fluid=True)

In [11]:
########################################
# MAIN LAYOUT (Navigation)
########################################

# Top-level shell shared by every page: URL location component, navigation bar,
# a spacer, and an (initially empty) container with id 'page-content'.
app.layout = dbc.Container(
    fluid=True,
    children=[
        dcc.Location(id='url', refresh=False),
        dbc.NavbarSimple(
            brand="METABRIC Dashboard",
            brand_href="#",
            color="primary",
            dark=True,
            # One nav item per page, in display order.
            children=[
                dbc.NavItem(dbc.NavLink(link_label, href=link_href))
                for link_label, link_href in (
                    ("Public Page (Page 1)", "/page1"),
                    ("Scientific Page (Page 2)", "/page2"),
                )
            ],
        ),
        html.Br(),
        dbc.Container(id='page-content', fluid=True),
    ],
)

In [12]:
########################################
# PAGE 1 CALLBACKS (Already included above)
########################################

@app.callback(
    Output('description_div', 'children'),
    Input('type_of_breast_surgery_dd', 'value'),
    Input('cancer_type_dd', 'value'),
    Input('cancer_type_detailed_dd', 'value'),
    Input('pam50_dd', 'value')
)
def show_description(surgery_val, ctype_val, cdetail_val, pam50_val):
    """Return the description text for the values chosen in the four info dropdowns.

    Each dropdown value is compared against the values that have an entry in
    `descriptions`; matching entries are collected in the fixed order surgery,
    cancer type, detailed cancer type, PAM50 subtype, and joined with a blank
    line between them. If nothing matches, a prompt asking the user to pick a
    value is returned instead.
    """
    # For every dropdown: the (accepted dropdown value, key in `descriptions`)
    # pairs, in the order they are checked. Only "claudin-low" maps to a key
    # that is spelled differently from the dropdown value.
    lookup_plan = (
        (surgery_val, (
            ("MASTECTOMY", 'MASTECTOMY'),
            ("BREAST CONSERVING", 'BREAST CONSERVING'),
        )),
        (ctype_val, (
            ("Breast Cancer", 'Breast Cancer'),
            ("Breast Sarcoma", 'Breast Sarcoma'),
        )),
        (cdetail_val, (
            ("Breast Invasive Ductal Carcinoma", 'Breast Invasive Ductal Carcinoma'),
            ("Breast Mixed Ductal and Lobular Carcinoma", 'Breast Mixed Ductal and Lobular Carcinoma'),
            ("Breast Invasive Lobular Carcinoma", 'Breast Invasive Lobular Carcinoma'),
            ("Breast Invasive Mixed Mucinous Carcinoma", 'Breast Invasive Mixed Mucinous Carcinoma'),
            ("Breast", 'Breast'),
            ("Metaplastic Breast Cancer", 'Metaplastic Breast Cancer'),
        )),
        (pam50_val, (
            ("claudin-low", 'Claudin-low'),
            ("LumA", 'LumA'),
            ("LumB", 'LumB'),
            ("Her2", 'Her2'),
            ("Normal", 'Normal'),
            ("Basal", 'Basal'),
            ("NC", 'NC'),
        )),
    )

    matched_texts = [
        descriptions[description_key]
        for selected, candidates in lookup_plan
        for accepted_value, description_key in candidates
        if selected == accepted_value
    ]

    if not matched_texts:
        return "Select a value from the dropdowns above to see a description."
    return "\n\n".join(matched_texts)


@app.callback(
    Output('stats_result', 'children'),
    Input('stats_feature', 'value'),
    Input('stats_info', 'value'),
    # Input rather than State: the statistic must be recomputed whenever the
    # filtered records change, otherwise the panel keeps showing a stale value.
    Input('filtered_data', 'data')
)
def update_stats_result(selected_feature, selected_info, filtered_data):
    """Compute the requested summary statistic for one feature of the filtered data.

    Parameters
    ----------
    selected_feature : str or None
        Column name chosen in the 'stats_feature' dropdown.
    selected_info : str or None
        One of 'most_common', 'least_common' (categorical features only) or
        'min', 'max', 'mean' (numeric features only).
    filtered_data : list of dict or None
        Records from the 'filtered_data' store.

    Returns
    -------
    str
        The formatted statistic, or an explanatory message when a selection is
        missing, the feature has no values in the filtered data, or the
        statistic does not apply to the feature's type.
    """
    if not selected_feature or not selected_info:
        return "Please select a feature and the type of statistic."

    dff = pd.DataFrame(filtered_data)

    # Guard against an empty store or a column that is absent from the records:
    # indexing would raise KeyError, and mode()/idxmin() raise on empty input.
    if dff.empty or selected_feature not in dff.columns:
        return "No data available for the selected feature."

    # All statistics below ignore missing values, so dropping them first does
    # not change any result; it only lets the all-missing case be detected.
    values = dff[selected_feature].dropna()
    if values.empty:
        return "No data available for the selected feature."

    is_categorical = selected_feature in categorical_cols
    is_numeric_feature = selected_feature in numeric_cols

    if selected_info == 'most_common' and is_categorical:
        most_common = values.mode().iloc[0]
        return f"Most Common {selected_feature}: {most_common}"

    if selected_info == 'least_common' and is_categorical:
        least_common = values.value_counts().idxmin()
        return f"Least Common {selected_feature}: {least_common}"

    if selected_info == 'min' and is_numeric_feature:
        return f"Minimum {selected_feature}: {values.min()}"

    if selected_info == 'max' and is_numeric_feature:
        return f"Maximum {selected_feature}: {values.max()}"

    if selected_info == 'mean' and is_numeric_feature:
        return f"Mean {selected_feature}: {values.mean():.2f}"

    return "Selected statistic is not applicable for the chosen feature type."



# HELPER FUNCTIONS
def is_numeric(series):
    """Tell whether `series` has a numeric dtype according to pandas."""
    dtype_check = pd.api.types.is_numeric_dtype
    return dtype_check(series)


def _message_figure(message):
    """Return an otherwise empty 'simple_white' figure showing `message` as an annotation."""
    fig = go.Figure()
    fig.add_annotation(text=message, showarrow=False)
    fig.update_layout(template='simple_white')
    return fig


def _numeric_y_or_default(dff, y_feature):
    """Return `y_feature` if it is a numeric column of `dff`, else the first numeric column present in `dff` (or None)."""
    if y_feature in dff.columns and y_feature in numeric_cols:
        return y_feature
    return next((c for c in numeric_cols if c in dff.columns), None)


def create_figure_page1(dff, plot_type, x_feature, y_feature):
    """Build one page-1 figure for the given data frame and plot settings.

    Parameters
    ----------
    dff : pandas.DataFrame
        Data to plot (the filtered records).
    plot_type : str or None
        'scatter', 'histogram', 'box', 'violin', 'pie' or 'heatmap'. Any other
        value (including None) yields an empty figure.
    x_feature, y_feature : str or None
        Column names for the axes. Plot types that do not use an axis ignore it.
        For box/violin an unusable `y_feature` falls back to the first numeric
        column; for pie an unusable `x_feature` falls back to 'cancer_type'.

    Returns
    -------
    plotly.graph_objects.Figure
        The requested figure, or a figure carrying only an explanatory message
        when the data is empty or the selection cannot be plotted.
    """
    if dff.empty:
        return _message_figure("No data available after applying filters.")

    layout_opts = dict(template='simple_white', color_discrete_sequence=color_discrete_sequence)
    # Colour by cancer type only when that column is actually present; passing
    # a missing column name to Plotly Express raises.
    color_col = 'cancer_type' if 'cancer_type' in dff.columns else None

    if plot_type == "scatter":
        if x_feature not in dff.columns or y_feature not in dff.columns:
            return _message_figure("Please select valid X and Y features for scatter plot.")
        fig = px.scatter(
            dff, x=x_feature, y=y_feature, color=color_col,
            title=f"Scatter Plot: {x_feature} vs {y_feature}",
            **layout_opts
        )

    elif plot_type == "histogram":
        if x_feature not in dff.columns:
            return _message_figure("Please select a valid X feature for histogram.")
        fig = px.histogram(
            dff, x=x_feature, color=color_col,
            title=f"Histogram of {x_feature}",
            **layout_opts
        )

    elif plot_type in ("box", "violin"):
        y_feature = _numeric_y_or_default(dff, y_feature)
        if y_feature is None:
            return _message_figure(f"No numeric columns available for {plot_type} plot.")
        # An unknown grouping column is dropped (ungrouped plot) instead of
        # being handed to Plotly Express, which would raise.
        group_col = x_feature if x_feature in dff.columns else None
        if plot_type == "box":
            fig = px.box(
                dff, x=group_col, y=y_feature,
                title=f"Box Plot: {y_feature} by {group_col}",
                **layout_opts
            )
        else:
            fig = px.violin(
                dff, x=group_col, y=y_feature, box=True, points='all',
                title=f"Violin Plot: {y_feature} by {group_col}",
                **layout_opts
            )

    elif plot_type == "pie":
        if x_feature not in dff.columns or x_feature not in categorical_cols:
            has_default = 'cancer_type' in categorical_cols and 'cancer_type' in dff.columns
            if not has_default:
                return _message_figure("No suitable categorical feature for pie chart.")
            x_feature = 'cancer_type'
        fig = px.pie(
            dff, names=x_feature,
            title=f"Pie Chart of {x_feature}",
            **layout_opts
        )

    elif plot_type == "heatmap":
        num_dff = dff.select_dtypes(include=[np.number])
        if num_dff.shape[1] <= 1:
            return _message_figure("Not enough numeric features for heatmap.")
        fig = px.imshow(
            num_dff.corr(),
            title="Correlation Heatmap",
            template='simple_white',
            color_continuous_scale='RdBu',
            zmin=-1, zmax=1
        )

    else:
        # No / unknown plot type: return a blank figure. `template` is a layout
        # property, not a Figure constructor argument, so it is applied through
        # update_layout (go.Figure(template=...) raises TypeError).
        fig = go.Figure()
        fig.update_layout(template='simple_white')

    fig.update_layout(margin=dict(l=20, r=20, t=40, b=20))
    return fig

@app.callback(
    Output('main-plot-page1', 'figure'),
    Output('second-plot-page1', 'figure'),
    [
        Input('plot-type', 'value'),
        Input('x-axis-feature', 'value'),
        Input('y-axis-feature', 'value'),
        Input('plot-type-2', 'value'),
        Input('x-axis-feature-2', 'value'),
        Input('y-axis-feature-2', 'value'),
        Input('filtered_data', 'data')  # New Input
    ]
)
def update_page1_graph(
    plot_type, x_feature, y_feature,
    plot_type_2, x_feature_2, y_feature_2,
    filtered_data
):
    """Redraw both page-1 figures from the current selections and filtered data.

    Args:
        plot_type, x_feature, y_feature: Selections for the first figure.
        plot_type_2, x_feature_2, y_feature_2: Selections for the second figure.
        filtered_data: Contents of the ``filtered_data`` store, a list of row
            records (as written by ``DataFrame.to_dict('records')``) or
            ``None`` if the store has not been written yet.

    Returns:
        A ``(fig, fig2)`` tuple built by ``create_figure_page1``.
    """
    if filtered_data is None:
        # Store not populated yet: plot the unfiltered clinical data instead
        # of an empty, column-less frame.
        dff = data.copy()
    else:
        # Passing the column list keeps the schema even when the filters match
        # no rows, so the figure builder sees the expected column names.
        dff = pd.DataFrame(filtered_data, columns=data.columns)

    fig = create_figure_page1(dff, plot_type, x_feature, y_feature)
    fig2 = create_figure_page1(dff, plot_type_2, x_feature_2, y_feature_2)
    return fig, fig2

    

def get_axis_options(plot_type):
    """Build the axis dropdown options that suit a given plot type.

    Returns:
        A ``(x_options, y_options, y_disabled)`` tuple. The two option lists
        hold ``{"label": ..., "value": ...}`` dicts; ``y_disabled`` is True for
        plot types that take no y axis (histogram, pie, heatmap).
    """
    def as_options(columns):
        # One dropdown entry per column, labelled with the column name itself.
        return [{"label": name, "value": name} for name in columns]

    if plot_type == "heatmap":
        return [], [], True
    if plot_type in ("histogram", "pie"):
        return as_options(categorical_cols), [], True
    if plot_type in ("box", "violin"):
        return as_options(categorical_cols), as_options(numeric_cols), False
    if plot_type == "scatter":
        return as_options(numeric_cols), as_options(numeric_cols), False

    # Unknown plot type: offer every column on x and the numeric ones on y.
    return as_options(data.columns), as_options(numeric_cols), False

def _axis_options_with_defaults(plot_type):
    """Return axis options plus default selections for one plot's controls.

    Shared by the callbacks of both page-1 plots so they cannot drift apart.

    Returns:
        ``(x_opts, y_opts, y_disabled, default_x, default_y)`` where the first
        three items come straight from ``get_axis_options``.
    """
    x_opts, y_opts, y_disabled = get_axis_options(plot_type)
    x_values = [opt['value'] for opt in x_opts]
    first_x = x_values[0] if x_values else None
    first_numeric = numeric_cols[0] if numeric_cols else None

    if plot_type == 'heatmap':
        # The heatmap takes no axis selection at all.
        default_x, default_y = None, None
    elif plot_type in ('histogram', 'pie'):
        # Prefer 'cancer_type', but only when it is really one of the offered
        # options; otherwise the dropdown would hold a value it cannot display.
        default_x = 'cancer_type' if 'cancer_type' in x_values else first_x
        default_y = None
    else:
        # scatter, box, violin and any unrecognised plot type.
        default_x, default_y = first_x, first_numeric

    return x_opts, y_opts, y_disabled, default_x, default_y


@app.callback(
    Output("x-axis-feature", "options"),
    Output("y-axis-feature", "options"),
    Output("y-axis-feature", "disabled"),
    Output("x-axis-feature", "value"),
    Output("y-axis-feature", "value"),
    Input("plot-type", "value")
)
def update_axis_options_for_plot_type(plot_type):
    """Refresh the first plot's axis dropdowns when its plot type changes."""
    return _axis_options_with_defaults(plot_type)


@app.callback(
    Output("x-axis-feature-2", "options"),
    Output("y-axis-feature-2", "options"),
    Output("y-axis-feature-2", "disabled"),
    Output("x-axis-feature-2", "value"),
    Output("y-axis-feature-2", "value"),
    Input("plot-type-2", "value")
)
def update_axis_options_for_plot_type_2(plot_type_2):
    """Refresh the second plot's axis dropdowns when its plot type changes."""
    return _axis_options_with_defaults(plot_type_2)



def _numeric_filter_control(feature):
    """Build a labelled RangeSlider covering the full range of a numeric column."""
    min_val = data[feature].min()
    max_val = data[feature].max()
    span = max_val - min_val
    step = span / 100 if span != 0 else 1  # Prevent step=0

    return dbc.Form([
        dbc.Row([
            dbc.Col([
                html.Label(f"{feature} Range:"),
                dcc.RangeSlider(
                    id={'type': 'filter-slider', 'index': feature},
                    min=min_val,
                    max=max_val,
                    step=step,
                    value=[min_val, max_val],
                    marks={
                        float(f"{min_val:.2f}"): f"{min_val:.2f}",
                        float(f"{max_val:.2f}"): f"{max_val:.2f}"
                    },
                    tooltip={"placement": "bottom", "always_visible": False},
                    allowCross=False
                ),
                html.Br()
            ])
        ])
    ])


def _categorical_filter_control(feature):
    """Build a labelled multi-select dropdown with every value pre-selected."""
    unique_vals = data[feature].unique()

    return dbc.Form([
        dbc.Row([
            dbc.Col([
                html.Label(f"{feature} Selection:"),
                dcc.Dropdown(
                    id={'type': 'filter-dropdown', 'index': feature},
                    options=[{'label': val, 'value': val} for val in unique_vals],
                    value=list(unique_vals),
                    multi=True,
                    placeholder=f"Select values for {feature}"
                ),
                html.Br()
            ])
        ])
    ])


@app.callback(
    Output('dynamic_filters', 'children'),
    [
        Input('x-axis-feature', 'value'),
        Input('y-axis-feature', 'value'),
        Input('x-axis-feature-2', 'value'),
        Input('y-axis-feature-2', 'value')
    ]
)
def update_dynamic_filters(x1, y1, x2, y2):
    """Create one filter control per distinct feature used by the two plots.

    Numeric features get a RangeSlider and all other features a multi-select
    dropdown. Controls use pattern-matching ids (``filter-slider`` /
    ``filter-dropdown``) keyed by the feature name.

    Args:
        x1, y1: Axis selections of the first plot (may be None).
        x2, y2: Axis selections of the second plot (may be None).

    Returns:
        A list of components, or a single explanatory paragraph when no
        feature is selected.
    """
    # dict.fromkeys de-duplicates while keeping selection order, so the
    # controls appear in a stable order (a set would order them arbitrarily).
    # Names that are not columns of `data` are skipped instead of raising.
    selected_features = [
        feature for feature in dict.fromkeys([x1, y1, x2, y2])
        if feature and feature in data.columns
    ]

    filter_controls = [
        _numeric_filter_control(feature) if feature in numeric_cols
        else _categorical_filter_control(feature)
        for feature in selected_features
    ]

    if not filter_controls:
        filter_controls = [
            html.P("No dynamic filters available based on selected features.")
        ]

    return filter_controls



@app.callback(
    Output('filtered_data', 'data'),
    [
        Input({'type': 'filter-slider', 'index': ALL}, 'value'),
        Input({'type': 'filter-dropdown', 'index': ALL}, 'value')
    ],
    [
        State({'type': 'filter-slider', 'index': ALL}, 'id'),
        State({'type': 'filter-dropdown', 'index': ALL}, 'id')
    ]
)
def apply_filters(slider_values, dropdown_values, slider_ids, dropdown_ids):
    """Filter the clinical data with every active slider and dropdown.

    Each component id carries the column name under its ``'index'`` key.
    Sliders keep rows whose value lies inside the inclusive range; dropdowns
    keep rows whose value is among the selected ones. A dropdown with nothing
    selected is ignored rather than filtering everything out.

    Returns:
        The surviving rows as a list of record dicts.
    """
    remaining = data.copy()

    # Range filters
    for bounds, component_id in zip(slider_values, slider_ids):
        column = component_id['index']
        if bounds is None or len(bounds) != 2:
            continue
        low, high = bounds
        remaining = remaining.loc[remaining[column].between(low, high)]

    # Membership filters
    for chosen, component_id in zip(dropdown_values, dropdown_ids):
        column = component_id['index']
        if not chosen:
            continue
        remaining = remaining.loc[remaining[column].isin(chosen)]

    return remaining.to_dict('records')

@app.callback(
    Output('counts_display', 'children'),
    Input('stats_feature', 'value'),
    State('filtered_data', 'data')
)
def update_counts_display(selected_feature, filtered_data):
    """Render a count summary for the chosen feature over the filtered data.

    Categorical features produce a table of value counts. Numeric features
    produce a RangeSlider (``count_range_slider``) plus a placeholder div
    (``range_count_display``) that another callback fills in.

    Args:
        selected_feature: Column name picked in the ``stats_feature`` control.
        filtered_data: Row records from the ``filtered_data`` store; may be
            None or empty.

    Returns:
        A Dash component describing the counts, or a paragraph with a message
        when nothing can be shown.
    """
    if not selected_feature:
        return html.P("Please select a feature to view counts.")

    if selected_feature not in categorical_cols and selected_feature not in numeric_cols:
        return html.P("Selected feature type is not supported for counts.")

    dff = pd.DataFrame(filtered_data)

    # An empty or missing store yields a frame without columns; indexing it
    # would raise KeyError, so report the situation instead.
    if selected_feature not in dff.columns:
        return html.P("No records available for the selected feature under the current filters.")

    if selected_feature in categorical_cols:
        # Categorical Feature: Display a table of unique values with counts
        value_counts = dff[selected_feature].value_counts().reset_index()
        value_counts.columns = [selected_feature, 'Count']

        return dash_table.DataTable(
            data=value_counts.to_dict('records'),
            columns=[{"name": i, "id": i} for i in value_counts.columns],
            style_cell={'textAlign': 'left'},
            style_header={
                'backgroundColor': 'rgb(230, 230, 230)',
                'fontWeight': 'bold'
            },
            page_size=10,
            style_table={'overflowX': 'auto'},
        )

    # Numerical Feature: Display a RangeSlider and count
    if dff[selected_feature].dropna().empty:
        # min()/max() would be NaN and produce an unusable slider.
        return html.P("No records available for the selected feature under the current filters.")

    min_val = dff[selected_feature].min()
    max_val = dff[selected_feature].max()
    step = (max_val - min_val) / 100 if (max_val - min_val) != 0 else 1  # Prevent step=0

    return html.Div([
        html.Label(f"{selected_feature} Range:"),
        dcc.RangeSlider(
            id='count_range_slider',
            min=min_val,
            max=max_val,
            step=step,
            value=[min_val, max_val],
            marks={
                float(f"{min_val:.2f}"): f"{min_val:.2f}",
                float(f"{max_val:.2f}"): f"{max_val:.2f}"
            },
            tooltip={"placement": "bottom", "always_visible": False},
            allowCross=False
        ),
        html.Br(),
        html.Div(id='range_count_display', style={'fontWeight': 'bold'})
    ])

@app.callback(
    Output('range_count_display', 'children'),
    Input('count_range_slider', 'value'),
    State('stats_feature', 'value'),
    State('filtered_data', 'data')
)
def update_range_count(selected_range, selected_feature, filtered_data):
    """Count the filtered records whose value lies in the selected range.

    Args:
        selected_range: ``[lower, upper]`` from ``count_range_slider``.
        selected_feature: Column name from the ``stats_feature`` control.
        filtered_data: Row records from the ``filtered_data`` store; may be
            None or empty.

    Returns:
        A message string: empty when the feature is not numeric, a prompt when
        the range is invalid, otherwise the inclusive-range count.
    """
    if not selected_feature or selected_feature not in numeric_cols:
        return ""

    if not selected_range or len(selected_range) != 2:
        return "Please select a valid range."

    lower, upper = selected_range
    dff = pd.DataFrame(filtered_data)

    if selected_feature in dff.columns:
        count = int(dff[selected_feature].between(lower, upper).sum())
    else:
        # The store is read at slider-move time and may be empty by then
        # (no rows match the filters); that means zero matches, not an error.
        count = 0

    return f"Count of records within range {lower:.2f} to {upper:.2f}: {count}"


In [13]:
########################################
# PAGE 2 CALLBACKS
########################################

# Show/hide clustering parameters
@app.callback(
    Output('kmeans_params', 'style'),
    Output('hierarchical_params', 'style'),
    Output('dbscan_params', 'style'),
    Input('cluster_method', 'value')
)
def show_hide_cluster_params(method):
    """Show only the parameter panel that belongs to the chosen clustering method.

    Returns:
        Three style dicts, in the order kmeans, hierarchical, dbscan. The
        panel matching ``method`` gets ``display: block`` and the others
        ``display: none``; an unrecognised method hides all three.
    """
    panel_methods = ('kmeans', 'hierarchical', 'dbscan')
    return tuple(
        {'display': 'block' if method == panel else 'none'}
        for panel in panel_methods
    )


def is_categorical(series):
    """Tell whether a Series should be treated as a categorical variable.

    A Series counts as categorical when its dtype is object, pandas
    ``category`` or a string dtype, or when it has fewer than 10 distinct
    non-null values regardless of dtype.

    Args:
        series: The pandas Series to inspect.

    Returns:
        bool: True if the Series is considered categorical.
    """
    dtype = series.dtype
    # Explicitly non-numeric dtypes are categorical no matter how many levels
    # they have; comparing the dtype to 'object' alone misses these.
    if (
        pd.api.types.is_object_dtype(dtype)
        or isinstance(dtype, pd.CategoricalDtype)
        or pd.api.types.is_string_dtype(dtype)
    ):
        return True
    # Low-cardinality numeric columns (e.g. grades, flags) act as categories.
    return bool(series.nunique() < 10)

@app.callback(
    Output('feature_selection_message', 'children'),
    Output('selected_features_storage', 'data'),
    Input('select_features_btn', 'n_clicks'),
    State('ref_variable', 'value'),
    State('num_features', 'value')
)
def select_features(n_clicks, ref_var, num_features):
    """Pick the genetic features most associated with a clinical reference variable.

    Categorical references (per ``is_categorical``) are scored with the ANOVA
    F-test; numeric references with absolute correlation.

    Args:
        n_clicks: Click count of the select button; None before the first click.
        ref_var: Name of the clinical column used as reference.
        num_features: How many top-scoring features to keep.

    Returns:
        ``(message, features_json)``. On any validation failure the message
        explains the problem and ``features_json`` lists every genetic column.
    """
    all_features_json = json.dumps(df_genetics.columns.tolist())

    if n_clicks is None:
        return "", all_features_json

    if ref_var is None:
        return "Please select a reference variable.", all_features_json

    if ref_var not in data.columns:
        return "Reference variable not found in clinical data.", all_features_json

    # A numeric input can deliver a float; nlargest needs an integer count.
    try:
        n_wanted = int(num_features)
    except (TypeError, ValueError):
        n_wanted = 0
    if n_wanted <= 0:
        return "Please provide a positive number of features.", all_features_json

    # Align the genetic rows with the rows where the reference is known
    ref_series = data[ref_var].dropna()
    common_idx = df_genetics.index.intersection(ref_series.index)
    y = ref_series.loc[common_idx]
    # Only complete numeric columns can be scored
    X = df_genetics.loc[common_idx].select_dtypes(include=[np.number]).dropna(axis=1)

    if X.shape[0] < 2:
        return "Not enough data after alignment.", all_features_json

    if X.shape[1] == 0:
        return "No numeric genetic features without missing values are available.", all_features_json

    if is_categorical(y):
        if y.nunique() < 2:
            # The F-test compares groups; a single group gives no valid score.
            return "Reference variable needs at least two distinct values.", all_features_json
        # Convert y to integer codes and rank features by ANOVA F score
        y_codes, _ = pd.factorize(y)
        f_scores, _ = f_classif(X, y_codes)
        feature_scores = pd.Series(f_scores, index=X.columns)
    else:
        # Numeric variable: rank by absolute correlation.
        # Constant columns have undefined correlation, so drop them first.
        X = X.loc[:, X.var() > 1e-12]
        feature_scores = X.corrwith(y).abs()

    selected = feature_scores.nlargest(n_wanted).index.tolist()

    # Report the number actually selected, which can be below the request.
    return f"Selected top {len(selected)} features based on {ref_var}.", json.dumps(selected)


@app.callback(
    Output('scatter_feature_x', 'options'),
    Output('scatter_feature_y', 'options'),
    Input('selected_features_storage', 'data')
)
def update_scatter_feature_options(selected_features_json):
    """Offer the currently selected features in both scatter-axis dropdowns.

    Falls back to every genetic column when the store holds nothing.
    """
    if selected_features_json:
        feature_names = json.loads(selected_features_json)
    else:
        feature_names = df_genetics.columns.tolist()

    dropdown_options = [
        {'label': name, 'value': name} for name in feature_names
    ]
    return dropdown_options, dropdown_options


@app.callback(
    Output('feature_scatter_plot', 'figure'),
    Input('scatter_feature_x', 'value'),
    Input('scatter_feature_y', 'value'),
    State('selected_features_storage', 'data')
)
def update_feature_scatter_plot(x_feat, y_feat, selected_features_json):
    """Scatter two selected genetic features against each other.

    Args:
        x_feat, y_feat: Feature names chosen for the axes (may be None).
        selected_features_json: JSON list of the currently selected features;
            falls back to every genetic column when empty.

    Returns:
        A scatter figure, or an annotated empty figure when either feature is
        missing or not part of the current selection.
    """
    selected_features = json.loads(selected_features_json) if selected_features_json else df_genetics.columns.tolist()
    # Names must be both currently selected and real columns; this also
    # tolerates stale dropdown values after a new feature selection.
    available = set(selected_features).intersection(df_genetics.columns)

    if x_feat is None or y_feat is None or x_feat not in available or y_feat not in available:
        fig = go.Figure()
        fig.add_annotation(text="Select two features to visualize their relationship.", showarrow=False)
        return fig

    # Slice only the plotted columns (de-duplicated when x == y) so rows are
    # dropped for missing values in these two features only, not in any of
    # the other selected features.
    plot_columns = list(dict.fromkeys([x_feat, y_feat]))
    df_sub = df_genetics[plot_columns].dropna()

    return px.scatter(df_sub, x=x_feat, y=y_feat, template='simple_white')


def _analysis_message_figure(text):
    """Return an empty figure carrying a centred explanatory annotation."""
    fig = go.Figure()
    fig.add_annotation(text=text, showarrow=False)
    fig.update_layout(template='simple_white')
    return fig


def _reduce_dimensions(X, dr_method, n_components):
    """Project the scaled matrix with the chosen method; unknown methods return X."""
    if dr_method == 'pca':
        # PCA cannot produce more components than min(n_samples, n_features).
        return PCA(n_components=min(n_components, X.shape[0], X.shape[1])).fit_transform(X)
    if dr_method == 'tsne':
        # The iteration count is left at scikit-learn's default of 1000 (the
        # value previously passed explicitly): the keyword was renamed from
        # n_iter to max_iter, so naming it breaks on one version or the other.
        # Perplexity must stay below the number of samples.
        perplexity = min(30, X.shape[0] - 1)
        return TSNE(n_components=n_components, perplexity=perplexity, random_state=42).fit_transform(X)
    if dr_method == 'umap':
        return umap.UMAP(n_components=n_components, random_state=42).fit_transform(X)
    return X  # fallback


def _assign_clusters(cluster_input, cluster_method, kmeans_n, agg_n, agg_linkage,
                     dbscan_eps, dbscan_min_samples):
    """Fit the chosen clustering model; unknown methods put every row in cluster 0."""
    if cluster_method == 'kmeans':
        model = KMeans(n_clusters=kmeans_n if kmeans_n is not None else 3, random_state=42)
    elif cluster_method == 'hierarchical':
        model = AgglomerativeClustering(
            n_clusters=agg_n if agg_n is not None else 3,
            linkage=agg_linkage if agg_linkage else 'ward'
        )
    elif cluster_method == 'dbscan':
        model = DBSCAN(
            eps=dbscan_eps if dbscan_eps is not None else 0.5,
            min_samples=dbscan_min_samples if dbscan_min_samples is not None else 5
        )
    else:
        # Integer zeros so the legend reads "0" like real cluster labels.
        return np.zeros(cluster_input.shape[0], dtype=int)
    return model.fit_predict(cluster_input)


@app.callback(
    Output('cluster_plot', 'figure'),
    Input('run_analysis_btn', 'n_clicks'),
    State('dr_method', 'value'),
    State('dr_n_components', 'value'),
    State('selected_features_storage', 'data'),
    State('cluster_method', 'value'),
    State('kmeans_n_clusters', 'value'),
    State('agg_n_clusters', 'value'),
    State('agg_linkage', 'value'),
    State('dbscan_eps', 'value'),
    State('dbscan_min_samples', 'value'),
    State('clustering_space', 'value')
)
def run_analysis(n_clicks, dr_method, n_components, selected_features_json,
                 cluster_method, kmeans_n, agg_n, agg_linkage, dbscan_eps, dbscan_min_samples,
                 clustering_space):
    """Scale the selected features, reduce dimensions, cluster and plot.

    Args:
        n_clicks: Click count of the run button; None before the first click.
        dr_method: 'pca', 'tsne' or 'umap'; anything else skips reduction.
        n_components: Requested number of reduced dimensions (defaults to 2
            when missing). 2 gives a 2D plot, other values a 3D plot when at
            least three dimensions are available.
        selected_features_json: JSON list of feature names; all genetic
            columns when empty.
        cluster_method: 'kmeans', 'hierarchical' or 'dbscan'.
        kmeans_n, agg_n, agg_linkage, dbscan_eps, dbscan_min_samples:
            Method parameters, each with a fallback when None.
        clustering_space: 'reduced' clusters on the projection; any other
            value clusters on the scaled original features.

    Returns:
        A Plotly figure coloured by cluster label, an empty figure before the
        first click, or an annotated figure when the data cannot be analysed.
    """
    if n_clicks is None:
        return go.Figure()

    selected_features = json.loads(selected_features_json) if selected_features_json else df_genetics.columns.tolist()

    # Prepare data: only existing, numeric columns can be scaled. The default
    # selection is every genetic column, which may include non-numeric ones.
    known_features = [name for name in selected_features if name in df_genetics.columns]
    df = df_genetics[known_features].select_dtypes(include=[np.number]).dropna()
    if df.shape[0] < 2 or df.shape[1] == 0:
        return _analysis_message_figure("Not enough numeric data to run the analysis.")

    X = StandardScaler().fit_transform(df)

    n_components = int(n_components) if n_components else 2

    # Dimensionality reduction is always performed because the plot uses it,
    # even when clustering runs on the original feature space.
    X_reduced = _reduce_dimensions(X, dr_method, n_components)
    cluster_input = X_reduced if clustering_space == 'reduced' else X

    labels = _assign_clusters(cluster_input, cluster_method, kmeans_n, agg_n,
                              agg_linkage, dbscan_eps, dbscan_min_samples)

    # Visualization
    available_dims = X_reduced.shape[1]
    if available_dims < 2:
        return _analysis_message_figure("At least two dimensions are required to plot the clusters.")

    # str() keeps the title buildable when a control has no value selected.
    title = f"Clusters via {str(cluster_method).upper()} (Clustering on {clustering_space})\nDR: {str(dr_method).upper()}"
    cluster_names = labels.astype(str)

    if n_components == 2 or available_dims < 3:
        fig = px.scatter(x=X_reduced[:, 0], y=X_reduced[:, 1], color=cluster_names,
                         title=title,
                         template='simple_white',
                         color_discrete_sequence=color_discrete_sequence)
    else:
        fig = px.scatter_3d(x=X_reduced[:, 0], y=X_reduced[:, 1], z=X_reduced[:, 2],
                            color=cluster_names,
                            title=title,
                            template='simple_white',
                            color_discrete_sequence=color_discrete_sequence)
    fig.update_layout(margin=dict(l=20, r=20, t=40, b=20))
    return fig

In [14]:
########################################
# PAGE NAVIGATION
########################################
@app.callback(
    Output('page-content', 'children'),
    Input('url', 'pathname')
)
def display_page(pathname):
    """Pick the layout for the current URL; page 1 is the default for any other path."""
    return page2_layout if pathname == '/page2' else page1_layout

if __name__ == '__main__':
    # Newer Dash releases expose `app.run` and have dropped `run_server`;
    # older ones only have `run_server`. Prefer `run` when it exists so the
    # notebook starts on either version.
    start_server = getattr(app, 'run', None) or app.run_server
    start_server(debug=True, port=8056)
