In [1]:
import numpy as np
import plotly.graph_objs as go
from ipywidgets import interact, IntSlider, Checkbox, Layout, Dropdown
from IPython.display import display
import pandas as pd
import os
# Work from the project root so the relative 'data/...' paths below resolve.
# Set the RCANE_DIR environment variable to run on a machine with a different layout.
os.chdir(os.environ.get('RCANE_DIR', '/Volumes/gech/cna/RCANE'))

In [4]:
# Segment annotations (the 'chr', 'start' and 'end' columns feed the hover text).
seg_info = pd.read_csv('data/start_end_chr_in_segs.csv')
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code -- only load .npz files from a trusted source.
file = np.load('data/predict/TCGA_test.npz', allow_pickle=True)
cna = file['cna']
profile = file['profile']

# Number of segments on each chromosome, in the order 1..22, X.
chr_seg_nums = np.loadtxt('data/chr_index.csv', delimiter=',', dtype=int)
# Chromosome label of every segment column of `cna`, as an (n_segs, 1) string array.
chr_index = np.repeat(
    np.array([str(i) for i in range(1, 23)] + ['X']), chr_seg_nums
).reshape(-1, 1)

# Fail fast if the annotation files disagree with the CNA matrix; a mismatch would
# otherwise make the per-chromosome column selection silently wrong.
assert chr_index.shape[0] == cna.shape[1], 'chr_index.csv segment counts do not match the CNA matrix'
assert len(seg_info) == cna.shape[1], 'start_end_chr_in_segs.csv rows do not match the CNA matrix'

# Replace NaN with None (turning cna into an object array) so missing values are
# shown as gaps rather than coloured cells in the plotly heatmap.
cna = np.where(np.isnan(cna), None, cna)

In [5]:
# Run this cell to build the interactive copy-number heatmap.

# Colour scale: diverging RdBu, clipped to [vmin, vmax] so it is symmetric around 0.
vmin = -0.7
vmax = 0.7
colormap = 'RdBu'


def _build_hover_data(cna, profile, seg_info):
    """Build the per-cell hover metadata for the heatmap.

    Returns an array of shape (n_samples, n_segs, 4), where
    ``n_samples, n_segs = cna.shape``, whose last axis holds
    [sample id, chromosome, segment start, segment end].
    ``profile`` must have n_samples entries and ``seg_info`` must have n_segs
    rows with 'chr', 'start' and 'end' columns.
    """
    n_samples, n_segs = cna.shape
    # (n_samples, n_segs, 1): each sample id repeated along the segment axis.
    sample_ids = np.repeat(np.asarray(profile).reshape(-1, 1, 1), n_segs, axis=1)
    # (n_segs, 3): coordinates of every segment; reshape(n_segs) enforces the length.
    seg_coords = np.stack(
        [seg_info[col].to_numpy().reshape(n_segs) for col in ('chr', 'start', 'end')],
        axis=-1,
    )
    # (n_samples, n_segs, 3): the same coordinates for every sample.
    seg_coords = np.repeat(seg_coords[np.newaxis], n_samples, axis=0)
    return np.concatenate((sample_ids, seg_coords), axis=-1)


custom_chr_index = _build_hover_data(cna, profile, seg_info)

# Initial figure: the first sample as a 1-row heatmap. Slicing with [:1] keeps z
# 2-D and customdata 3-D; plotly only accepts a 1-D heatmap z together with x/y.
fig = go.FigureWidget(data=go.Heatmap(
    z=cna[:1],
    colorscale=colormap,
    zmin=vmin, zmax=vmax,  # pin the colour range to [vmin, vmax]
    customdata=custom_chr_index[:1],
    hovertemplate='<b>Sample</b>: %{customdata[0]}<br><b>Chromosome</b>: %{customdata[1]}<br><b>Start</b>: %{customdata[2]}<br><b>End</b>: %{customdata[3]}<br><b>Intensity</b>: %{z:.3f}<extra></extra>'
))

# Title with the sample id; hide tick labels, ticks, grid and titles on both axes
# (individual samples/segments are identified through the hover text instead).
fig.update_layout(
    title=f'{profile[0]}',
    xaxis=dict(showticklabels=False, ticks="", showgrid=False, title=""),
    yaxis=dict(showticklabels=False, ticks="", showgrid=False, title=""),
)

# Function to update the plot based on the selected data slice and colormap
def update_plot(index, chr, display):
    """Redraw the heatmap for the current widget values.

    Parameters
    ----------
    index : int
        Row of ``cna`` to show; ignored when ``display`` is true.
    chr : str
        Chromosome label matched against ``chr_index``, or 'All' for every
        segment column.
    display : bool
        If true, show all rows of ``cna`` instead of a single sample.

    Updates ``z``, ``customdata`` and the title of the global ``fig`` in place
    and returns None. The parameter names must not change: they are the
    keyword names ``interact`` binds the widgets to.
    """
    if display:
        z = cna
        hover = custom_chr_index
    else:
        # Add a leading axis so a single sample is still a 2-D (1 x n_segs) grid
        # with matching 3-D hover data.
        z = cna[index][np.newaxis]
        hover = custom_chr_index[index][np.newaxis]

    if chr != 'All':
        # Segment columns on the selected chromosome, computed once and applied
        # to both the values and the hover data.
        cols = np.where(chr_index == chr)[0]
        z = np.take(z, cols, axis=-1)
        hover = np.take(hover, cols, axis=1)

    with fig.batch_update():
        fig.data[0].z = z
        fig.data[0].customdata = hover
        fig.layout.title = 'All samples' if display else f'{profile[index]}'

# Sample picker: one integer step per row of `cna`, starting at the first sample.
slider = IntSlider(
    min=0,
    max=cna.shape[0] - 1,
    step=1,
    value=0,
    description='Sample index',
    style={'description_width': '100px'},
)

def on_checkbox_change(change):
    """Observer for the "Display all samples" checkbox.

    ``change['new']`` is the checkbox's new value: the global sample slider is
    disabled while it is truthy and re-enabled otherwise.
    """
    slider.disabled = bool(change['new'])

# Dropdown restricting the heatmap to one chromosome's segments ('All' shows every segment).
chr_Dropdown = Dropdown(
    options=['All'] + [str(i) for i in range(1, 23)] + ['X'],
    value='All',
    description='Chromosome',
    layout=Layout(width='200px'),
    style={'description_width': '100px'},
)
# Checkbox switching between one sample (chosen with the slider) and all samples at once.
Display_all_samples = Checkbox(
    value=False,  # unchecked by default, so the slider starts enabled
    description='Display all samples',
    disabled=False,
    style={'description_width': '100px'},
)
# Disable the sample slider whenever all samples are displayed.
Display_all_samples.observe(on_checkbox_change, names='value')



# Bind the widgets to update_plot's parameters by name; interact renders the
# controls and re-runs update_plot whenever one of them changes.
interact(
    update_plot,
    index=slider,
    chr=chr_Dropdown,
    display=Display_all_samples,
)

# Show the figure widget once; update_plot then modifies it in place.
display(fig)

interactive(children=(IntSlider(value=0, description='Sample index', max=2225, style=SliderStyle(description_w…

FigureWidget({
    'data': [{'colorscale': [[0.0, 'rgb(103,0,31)'], [0.1, 'rgb(178,24,43)'],
                             [0.2, 'rgb(214,96,77)'], [0.3, 'rgb(244,165,130)'],
                             [0.4, 'rgb(253,219,199)'], [0.5, 'rgb(247,247,247)'],
                             [0.6, 'rgb(209,229,240)'], [0.7, 'rgb(146,197,222)'],
                             [0.8, 'rgb(67,147,195)'], [0.9, 'rgb(33,102,172)'],
                             [1.0, 'rgb(5,48,97)']],
              'customdata': array([[['TCGA.HT.A615.01A.11R.A29R.07', '1', 14404, 778626],
                                    ['TCGA.HT.A615.01A.11R.A29R.07', '1', 778770, 1063288],
                                    ['TCGA.HT.A615.01A.11R.A29R.07', '1', 1070966, 1311677],
                                    ...,
                                    ['TCGA.HT.A615.01A.11R.A29R.07', 'X', 154010500, 154518631],
                                    ['TCGA.HT.A615.01A.11R.A29R.07', 'X', 154531391, 155669944],
                