In [42]:
from bokeh.layouts import gridplot
from bokeh.models import ColumnDataSource, DataTable, TableColumn, BoxSelectTool, LassoSelectTool
from bokeh.plotting import figure, show, output_notebook
from bokeh.layouts import row, column
from bokeh.transform import factor_cmap
from bokeh.palettes import Spectral11, Category20
from bokeh.io import push_notebook
from bokeh.models import Legend, LegendItem, CustomJS
import pandas as pd
import numpy as np
from bokeh.models import CDSView, GroupFilter

output_notebook()

def create_scatter_dict(x_col, y_col, color_column=None,title="Projection"):
    """
    Create a dictionary with details for a scatter plot.
    
    Parameters:
    - x_col: Column name for x-axis
    - y_col: Column name for y-axis
    - color_column: Optional column name for color mapping
    
    Returns:
    - Dictionary with scatter plot details.
    """
    scatter_dict = {
        "x": x_col,
        "y": y_col,
        "title": title,
        "x_axis_label": x_col,
        "y_axis_label": y_col,
        "color_column": color_column
    }
    return scatter_dict

def plot_data_with_table(df, scatter_plots, time_series_x, time_series_y, table_columns,  time_series_color_column=None):
    # Setting up the main data source
    source = ColumnDataSource(df)
    
    # List to store individual plots
    plots = []

    # Loop over scatter_plots to create individual figures
    for plot_info in scatter_plots:
        x_col = plot_info["x"]
        y_col = plot_info["y"]
        color_column = plot_info.get("color_column", None)
        x_axis_label = plot_info.get("x_axis_label", "X Axis")
        y_axis_label = plot_info.get("y_axis_label", "Y Axis")
        plot_title = plot_info.get("title", "Scatter Plot")
        
        # Generate palette for the color column
        if color_column:
            factors = df[color_column].unique().tolist()
            palette = Category20[20][:len(factors)]
            mapper = factor_cmap(color_column, palette=palette, factors=factors)
        else:
            mapper = None

        # Initialize individual plot
        p = figure(width=600, height=600, title=plot_title, tools="lasso_select,box_select,reset")

        r = p.scatter(x=x_col, y=y_col, source=source, color=mapper if color_column else "navy", legend_field=color_column)

        p.xaxis.axis_label = x_axis_label
        p.yaxis.axis_label = y_axis_label

        # Adding selection tools
        #lasso_select = LassoSelectTool()
        #box_select = BoxSelectTool()
        #p.add_tools(lasso_select, box_select)

        # Setting up the legend
        legend_items = [LegendItem(label=color_column, renderers=[r])]
        legend = Legend(items=legend_items, location="center")
        p.add_layout(legend, "right")
        p.legend.visible = False

        plots.append(p)

    groups = df['Group'].unique()

    if time_series_color_column:
            tsfactors = df[time_series_color_column].unique().tolist()
            tspalette = Category20[20][:len(factors)]
            tsmapper = factor_cmap(time_series_color_column, palette=tspalette, factors=tsfactors)
    else:
            tsmapper = None

    
    for group in groups:
        # Create a view for the current group
        view = CDSView(source=source, filters=[GroupFilter(column_name='Group', group=group)])

        # Create time series plot for the current group using the same source but with the group-specific view
        p_time_series = figure(width=600, height=600, title=f'Time Series for Group: {group}', tools='lasso_select,box_select,reset')
        p_time_series.multi_line(xs=time_series_x, ys=time_series_y, source=source, view=view, color=tsmapper if time_series_color_column else "navy", legend_field=time_series_color_column)  # or use a color mapper if needed
        p_time_series.xaxis.axis_label = time_series_x
        p_time_series.yaxis.axis_label = time_series_y
        p_time_series.legend.visible = False
        plots.append(p_time_series)

    
    
    # Setting up the table
    table_source = ColumnDataSource(df[table_columns])
    fresh_table_source = ColumnDataSource(df[table_columns])
    columns = [TableColumn(field=col, title=col) for col in table_columns]
    data_table = DataTable(source=table_source, columns=columns, width=600, height=600, fit_columns=False, index_position=None, selectable=True)
    
    summary_table_source=ColumnDataSource(df[table_columns]);
    summary_data_table = DataTable(source=summary_table_source, columns=[TableColumn(field=col, title=col) for col in table_columns], width=600, height=80, fit_columns=False, index_position=None, selectable=True)
    
    
    change = CustomJS(args=dict(source=source, table_source=table_source, original=fresh_table_source, table_columns=table_columns, summary_table_source=summary_table_source), code="""
    

function computeStatistics(dataObj) {
    let result = {};

    for (let col in dataObj) {
        let values = dataObj[col];
        let isNumeric = typeof values[0] === "number";

        if (isNumeric) {
            values.sort((a, b) => a - b); // Sort numerically
            let mid = Math.floor(values.length / 2);
            result[col] = (values.length % 2) === 0 ? [(values[mid - 1] + values[mid]) / 2.0] : [values[mid]];
        } else {
            let freqMap = {};

            for (let val of values) {
                if (!freqMap[val]) {
                    freqMap[val] = 1;
                } else {
                    freqMap[val]++;
                }
            }

            // Sort the unique values by their frequencies in descending order
            let sortedCategories = Object.keys(freqMap).sort((a, b) => freqMap[b] - freqMap[a]);
            
            let topCategories = [];
            let numCategoriesToInclude = Math.min(3, sortedCategories.length);

            for (let i = 0; i < numCategoriesToInclude; i++) {
                let category = sortedCategories[i];
                let percentage = (freqMap[category] / values.length * 100).toFixed(2);
                topCategories.push(`${category} (${percentage} %)`);
            }
            
            result[col] = [topCategories.join(", ")];
        }
    }
    
    return result;
}

function createFieldToIndexMap(arr) {
    let resultMap = {};
    
    for (let i = 0; i < arr.length; i++) {
        let obj = arr[i];
        resultMap[obj.field] = i;
    }
    
    return resultMap;
}

console.log("tc",table_columns);
const inds = cb_obj.indices;
console.log(cb_obj);
const ts_d = table_source.data;
const fs_d = original.data;
table_columns.forEach(col => {
    ts_d[col] = [];
    inds.forEach(index => {
        ts_d[col].push(fs_d[col][index]); 
    })
});
table_source.change.emit();

const stats = computeStatistics(ts_d);
Object.keys(stats).forEach(stat_key => {
    summary_table_source.data[stat_key] = stats[stat_key];
})
summary_table_source.change.emit();
debugger;

        """)
    source.selected.js_on_change("indices", change)

    # Combine plots and table in a grid layout
    plots.append(data_table)
    layout = gridplot([[plots[0],plots[1],plots[5]],
                       [None, None, summary_data_table],
                       [plots[3],plots[4],plots[2]]])

    handle = show(layout, notebook_handle=True)

In [34]:
demo_projection = pd.read_csv('derived_data/demographics-with-projection.csv')
sc_projection = pd.read_csv("derived_data/subject-chars-with-projection.csv");


# Custom aggregation functions to handle NaN values
def custom_list(series):
    return [x if not pd.isnull(x) else 'NaN' for x in series]
    
agg_funcs_pi = {
    'Pain Interference': custom_list,
    'Pain Interference (Smoothed)': custom_list,
    'Visit Number': custom_list,
    'Visit Count': 'first'
}

agg_funcs_peg = {
    'Peg Score': custom_list,
    'Peg Score (Smoothed)': custom_list,
    'Visit Number': custom_list,
    'Visit Count': 'first'
}

pain_interference = (pd
                     .read_csv('derived_data/pain-interference-smoothed.csv')
                     .sort_values(by='Visit Number')
                     .groupby('USUBJID')
                     .agg(agg_funcs_pi)
                     .reset_index()
                     .rename(columns={"Visit Number":"Visit Number (Pain Interference)"}))

# Function to check if the last 4 values of a list are constant
def is_last_4_constant(series):
    if len(series) < 4:
        return False
    return len(set(series[-4:])) == 1
            
            # Apply the function to the 'Pain Interference (Smoothed)' column of the pain_interference dataframe
pain_interference['is_constant_last_4'] = pain_interference['Pain Interference (Smoothed)'].apply(is_last_4_constant)
pain_interference = pain_interference[pain_interference['Visit Count']>7]


peg = pd.read_csv('derived_data/peg_score_ts.csv').sort_values(by='Visit Number')
peg = peg[peg['Peg Score']<=10];

peg = peg.groupby('USUBJID').agg(agg_funcs_peg).reset_index()
peg.rename(columns={'Visit Number': 'Visit Number (Peg Score)'}, inplace=True)

# Function to check if the last 4 values of a list are constant
def is_last_4_constant(series):
    if len(series) < 4:
        return False
    return len(set(series[-4:])) == 1
            
            # Apply the function to the 'Pain Interference (Smoothed)' column of the peg dataframe
peg['is_constant_last_4'] = peg['Peg Score (Smoothed)'].apply(is_last_4_constant)
peg = peg[peg['Visit Count']>1]

total_df = demo_projection.merge(sc_projection, on="USUBJID", how="left");
total_df = total_df.merge(pain_interference, on="USUBJID", how="left");
total_df = total_df.merge(peg, on="USUBJID", how="left");



In [43]:
plot_data_with_table(total_df, [create_scatter_dict('E1','E2','Gender, Race, Ethnicity',title="Demographics (DM) Projection"),
                                create_scatter_dict('SCAE1','SCAE2','Gender, Race, Ethnicity',title="Subject Characteristics (SC) Projection")],
                     "Visit Number (Peg Score)",
                     'Peg Score',
                     ['Group','AGE','GENDER','RACE','ETHNIC','EMPSTAT', 'MARISTAT', 'HHNUM', 'PAINDUR', 'BPMORE', 'BPSURG',
                       'BPSURGTM', 'BPSURGSF', 'BPUNEMP', 'BPWKCOMP', 'BPLWSUIT', 'BPDISAB',
                       'HEIGHT', 'WEIGHT', 'HHINCOME'], 'Gender, Race, Ethnicity')

In [30]:

peg = pd.read_csv('derived_data/peg_score_ts.csv').sort_values(by='Visit Number').groupby('USUBJID').agg(agg_funcs_peg).reset_index()
peg.rename(columns={'Visit Number': 'Visit Number (Peg Score)'}, inplace=True)

# Function to check if the last 4 values of a list are constant
def is_last_4_constant(series):
    if len(series) < 4:
        return False
    return len(set(series[-4:])) == 1
            
            # Apply the function to the 'Pain Interference (Smoothed)' column of the peg dataframe
peg['is_constant_last_4'] = peg['Peg Score (Smoothed)'].apply(is_last_4_constant)
peg['Visit Count'][peg['Visit Count']>1].mean()

3.8940316686967114

In [127]:
total_df[["USUBJID","Visit Number_x","Visit Number_y"]]


Unnamed: 0,USUBJID,Visit Number_x,Visit Number_y
0,P2CS-0401-00001,"[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,...",
1,P2CS-0401-00002,"[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,...",
2,P2CS-0401-00003,"[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,...",
3,P2CS-0401-00004,"[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,...",
4,P2CS-0401-00005,"[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,...",
...,...,...,...
1802,LB3P-0301-00591,,
1803,LB3P-0301-00595,,
1804,LB3P-0301-00592,,
1805,LB3P-0301-00603,,


In [84]:
peg

Unnamed: 0,USUBJID,Pain Interference,Pain Interference (Smoothed),Visit Number
0,BEST-1401-00209,"[NaN, 61.2, 60.73333333333333, 60.266666666666...","[NaN, 61.2, 60.90833333333333, 60.580952380952...","[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,..."
1,BEST-1401-00225,"[NaN, 52.0, 52.3, 52.6, 52.9, 53.2, 53.5, 53.8...","[NaN, 52.0, 52.1875, 52.39795918367347, 52.628...","[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,..."
2,BEST-1401-00331,"[NaN, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5...","[NaN, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5...","[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,..."
3,BEST-1401-00343,"[NaN, 49.6, 49.6, 49.6, 49.6, 49.6, 49.6, 49.6...","[NaN, 49.6, 49.6, 49.6, 49.6, 49.6, 49.6, 49.6...","[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,..."
4,BEST-1401-00370,"[NaN, 59.9, 59.9, 59.9, 59.9, 59.9, 59.9, 59.9...","[NaN, 59.9, 59.9, 59.9, 59.9, 59.9, 59.9, 59.9...","[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,..."
...,...,...,...,...
1815,PHENO-0901-00786,"[NaN, 61.2, 61.2, 61.2, 61.2, 61.2, 61.2, 61.2...","[NaN, 61.2, 61.2, 61.2, 61.2, 61.2, 61.2, 61.2...","[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,..."
1816,PHENO-0901-00787,"[NaN, 62.5, 62.5, 62.5, 62.5, 62.5, 62.5, 62.5...","[NaN, 62.5, 62.5, 62.5, 62.5, 62.5, 62.5, 62.5...","[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,..."
1817,PHENO-0901-00788,"[NaN, 55.6, 55.6, 55.6, 55.6, 55.6, 55.6, 55.6...","[NaN, 55.6, 55.6, 55.6, 55.6, 55.6, 55.6, 55.6...","[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,..."
1818,PHENO-0901-00790,"[NaN, 57.1, 57.1, 57.1, 57.1, 57.1, 57.1, 57.1...","[NaN, 57.1, 57.1, 57.1, 57.1, 57.1, 57.1, 57.1...","[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,..."


In [87]:
    pd.read_csv('derived_data/pain-interference-smoothed.csv')

Unnamed: 0,USUBJID,Visit Number,Pain Interference,Pain Interference (Smoothed),Visit Count,Pain Interference Start,Pain Interference End,Change,Group
0,BEST-1412-00101,-1,,,4,64.766667,58.5,6.266667,Improved
1,BEST-1412-00101,0,66.600000,66.600000,4,64.766667,58.5,6.266667,Improved
2,BEST-1412-00101,1,65.683333,66.027083,4,64.766667,58.5,6.266667,Improved
3,BEST-1412-00101,2,64.766667,65.384014,4,64.766667,58.5,6.266667,Improved
4,BEST-1412-00101,3,63.850000,64.679044,4,64.766667,58.5,6.266667,Improved
...,...,...,...,...,...,...,...,...,...
81895,BPCR01-0201-00282,43,52.000000,52.000014,3,69.000000,52.0,17.000000,Improved
81896,BPCR01-0201-00282,44,52.000000,52.000008,3,69.000000,52.0,17.000000,Improved
81897,BPCR01-0201-00282,45,52.000000,52.000005,3,69.000000,52.0,17.000000,Improved
81898,BPCR01-0201-00282,49,52.000000,52.000003,3,69.000000,52.0,17.000000,Improved


In [104]:
total_df.columns

Index(['Unnamed: 0', 'E1', 'E2', 'USUBJID', 'STUDYID_x', 'AGE', 'Age Group',
       'tag', 'count', 'Gender, Race, Ethnicity',
       'Gender, Race, Ethnicity (Count)', 'GENDER', 'RACE', 'ETHNIC',
       'STUDYID (Count)', 'count_right', 'RACE, ETHNICITY',
       'Pain Interference Start', 'Pain Interference End', 'Visit Count_x',
       'Change', 'Group', 'SCAE1', 'SCAE2', 'STUDYID_y', 'GENIDENT', 'EDLEVEL',
       'EMPSTAT', 'MARISTAT', 'HHNUM', 'PAINDUR', 'BPMORE', 'BPSURG',
       'BPSURGTM', 'BPSURGSF', 'BPUNEMP', 'BPWKCOMP', 'BPLWSUIT', 'BPDISAB',
       'HEIGHT', 'WEIGHT', 'HHINCOME', 'Pain Interference',
       'Pain Interference (Smoothed)', 'Visit Number', 'Visit Count_y',
       'is_constant_last_4'],
      dtype='object')

In [113]:
peg = pd.read_csv('derived_data/peg_score_ts.csv')
peg

Unnamed: 0,USUBJID,Visit Number,Peg Score,Peg Score (Smoothed),Visit Count,Peg Score Start,Peg Score End,Change,Group
0,P2CS-0401-00204,-1,2.11110,2.111100,6,5.43332,5.3333,0.10002,Static
1,P2CS-0401-00204,0,5.05555,3.746906,6,5.43332,5.3333,0.10002,Static
2,P2CS-0401-00204,2,8.00000,5.489977,6,5.43332,5.3333,0.10002,Static
3,P2CS-0401-00204,3,6.66665,5.888579,6,5.43332,5.3333,0.10002,Static
4,P2CS-0401-00204,4,5.33330,5.723396,6,5.43332,5.3333,0.10002,Static
...,...,...,...,...,...,...,...,...,...
83071,P2CS-0401-00161,42,1.88890,1.888900,1,1.88890,1.8889,0.00000,Static
83072,P2CS-0401-00161,43,1.88890,1.888900,1,1.88890,1.8889,0.00000,Static
83073,P2CS-0401-00161,44,1.88890,1.888900,1,1.88890,1.8889,0.00000,Static
83074,P2CS-0401-00161,45,1.88890,1.888900,1,1.88890,1.8889,0.00000,Static


In [14]:
total_df.columns

Index(['Unnamed: 0', 'E1', 'E2', 'USUBJID', 'STUDYID_x', 'AGE', 'Age Group',
       'tag', 'count', 'Gender, Race, Ethnicity',
       'Gender, Race, Ethnicity (Count)', 'GENDER', 'RACE', 'ETHNIC',
       'STUDYID (Count)', 'count_right', 'RACE, ETHNICITY',
       'Pain Interference Start', 'Pain Interference End', 'Visit Count_x',
       'Change', 'Group', 'SCAE1', 'SCAE2', 'STUDYID_y', 'GENIDENT', 'EDLEVEL',
       'EMPSTAT', 'MARISTAT', 'HHNUM', 'PAINDUR', 'BPMORE', 'BPSURG',
       'BPSURGTM', 'BPSURGSF', 'BPUNEMP', 'BPWKCOMP', 'BPLWSUIT', 'BPDISAB',
       'HEIGHT', 'WEIGHT', 'HHINCOME', 'Pain Interference',
       'Pain Interference (Smoothed)', 'Visit Number (Pain Interference)',
       'Visit Count_y', 'is_constant_last_4_x', 'Peg Score',
       'Peg Score (Smoothed)', 'Visit Number Peg', 'Visit Count',
       'is_constant_last_4_y'],
      dtype='object')