In [1]:
import pandas as pd
import plotly.graph_objects as go
from plotly.offline import iplot, init_notebook_mode
import ipywidgets as widgets
from IPython.display import display
import numpy as np

# Initialize plotly offline mode
init_notebook_mode(connected=True)

print("📊 Loading perovskite data...")

# Load data from Excel
data = pd.read_excel('PlottedInks.xlsx', sheet_name='Sheet2')
print(f"✅ Loaded {len(data)} data points")
print(f"📋 Columns: {list(data.columns)}")
print(f"🧪 Compounds: {list(data['Solute'].unique())}")
print(f"⚖️ Stability types: {list(data['Stability'].unique())}")

# Show data
#display(data.head())

# Get color options 
color_options = ['DN', 'BP', 'heat of Vap', 'hov_temp', 'mw', 'vis_temp', 'Viscosity']
print(f"🎨 Color options: {color_options}")

# Get solute options
all_solutes = sorted(data['Solute'].unique().tolist())
print(f"🧪 Available solutes: {all_solutes}")

# Create 3D plot
def make_3d_plot(color_by='DN'):
    print(f"Creating plot with colorbar for {color_by}...")  # Debug
    
    # Marker symbols mapping
    marker_map = {'Stable': 'circle', 'Semi-stable': 'diamond', 'Not stable': 'x'}
    
    # Get marker symbols for each point
    symbols = [marker_map.get(stability, 'circle') for stability in data['Stability']]
    
    # Adjust marker sizes (make X smaller)
    sizes = [8 if symbol == 'x' else 12 for symbol in symbols]
    
    # Check color data
    print(f"Color data range: {data[color_by].min():.2f} to {data[color_by].max():.2f}")  # Debug
    
    # Create figure with explicit colorbar
    fig = go.Figure()
    
    # Add the main scatter plot
    scatter = go.Scatter3d(
        x=data['D'],
        y=data['H'],
        z=data['P'],
        mode='markers',
        marker=dict(
            size=sizes,
            color=data[color_by],
            colorscale='Agsunset',
            symbol=symbols,
            opacity=0.8,
            cmin=data[color_by].min(),  # Explicit color range
            cmax=data[color_by].max(),
            colorbar=dict(
                title=color_by,
                titleside="right",
                x=1.02,  # Try closer position first
                len=0.8,
                thickness=15
            )
        ),
        text = [create_hover_text(row, color_by) for _, row in data.iterrows()],
        hovertemplate='%{text}<extra></extra>',
        name='Data Points'
    )
    
    fig.add_trace(scatter)
    
    # Add legend traces
    for stability in data['Stability'].unique():
        fig.add_trace(go.Scatter3d(
            x=[None], y=[None], z=[None],
            mode='markers',
            marker=dict(
                size=12 if marker_map.get(stability) != 'x' else 8,
                symbol=marker_map.get(stability, 'circle'),
                color='gray'
            ),
            name=stability,
            showlegend=True
        ))
    
    fig.update_layout(
        title=f'3D Perovskite Plot - Colored by {color_by}',
        scene=dict(
            xaxis_title='D Parameter',
            yaxis_title='H Parameter',
            zaxis_title='P Parameter'
        ),
        width=900,
        height=600,
        margin=dict(r=120),
        legend=dict(x=0.02, y=0.98)
    )
    
    print("Figure created, showing...")  # Debug
    return fig

# Function to create hover text with all solvents for same solute
def create_hover_text(row, color_by):
    solute = row['Solute']
    current_solvent = row['Solvent']
    stability = row['Stability']
    
    # Find all solvents for this solute
    same_solute_data = data[data['Solute'] == solute]
    all_solvents = same_solute_data['Solvent'].tolist()
    
    # Create hover text
    hover_text = f"<b>{solute}</b><br>"
    hover_text += f"<b>Current Sample:</b><br>"
    hover_text += f"• Solvent: {current_solvent}<br>"
    hover_text += f"• Stability: {stability}<br>"
    hover_text += f"• {color_by}: {row[color_by]:.2f}<br><br>"
    
    if len(all_solvents) > 1:
        hover_text += f"<b>All Solvents for {solute} ({len(all_solvents)} samples):</b><br>"
        for i, solvent in enumerate(all_solvents, 1):
            if solvent == current_solvent:
                hover_text += f"• <b>{i}. {solvent}</b> ← Current<br>"
            else:
                hover_text += f"• {i}. {solvent}<br>"
    else:
        hover_text += f"<b>Only sample for {solute}</b>"
    
    return hover_text

# Create 3D plot with multiple solute selection
def make_3d_plot(color_by='DN', selected_solutes=None):
    if selected_solutes is None or len(selected_solutes) == 0:
        # If nothing selected, show all
        selected_solutes = all_solutes
        title_suffix = "All Solutes"
    else:
        # Convert tuple to list if needed
        selected_solutes = list(selected_solutes)
        if len(selected_solutes) == 1:
            title_suffix = selected_solutes[0]
        elif len(selected_solutes) == len(all_solutes):
            title_suffix = "All Solutes"
        else:
            title_suffix = f"{len(selected_solutes)} Selected Solutes"
    
    print(f"Creating plot - Color: {color_by}, Solutes: {selected_solutes}")
    
    # Filter data based on solute selection
    plot_data = data[data['Solute'].isin(selected_solutes)].copy()
    
    print(f"📊 Plotting {len(plot_data)} data points")
    
    if len(plot_data) == 0:
        print("⚠️ No data to plot for selected solutes")
        return go.Figure().update_layout(title="No data for selected solutes")
    
    # Marker symbols mapping
    marker_map = {'Stable': 'circle', 'Semi-stable': 'diamond', 'Not stable': 'x'}
    
    # Get marker symbols for each point
    symbols = [marker_map.get(stability, 'circle') for stability in plot_data['Stability']]
    
    # Adjust marker sizes (make X smaller)
    sizes = [8 if symbol == 'x' else 12 for symbol in symbols]
    
    # Create figure
    fig = go.Figure()
    
    # Add the main scatter plot
    fig.add_trace(go.Scatter3d(
        x=plot_data['D'],
        y=plot_data['H'],
        z=plot_data['P'],
        mode='markers',
        marker=dict(
            size=sizes,
            color=plot_data[color_by],
            colorscale='Agsunset',
            symbol=symbols,
            opacity=0.8,
            colorbar=dict(
                title=color_by,
                titleside="right",
                x=1.02,
                len=0.8,
                thickness=15
            )
        ),
        text=[create_hover_text(row, color_by) for _, row in plot_data.iterrows()],
        hovertemplate='%{text}<extra></extra>',
        name='Data Points'
    ))
    
    # Add legend traces (only for stability types present in filtered data)
    present_stabilities = plot_data['Stability'].unique()
    for stability in present_stabilities:
        fig.add_trace(go.Scatter3d(
            x=[None], y=[None], z=[None],
            mode='markers',
            marker=dict(
                size=12 if marker_map.get(stability) != 'x' else 8,
                symbol=marker_map.get(stability, 'circle'),
                color='gray'
            ),
            name=stability,
            showlegend=True
        ))
    
    fig.update_layout(
        title=f'3D Perovskite Plot - {title_suffix} - Colored by {color_by}',
        scene=dict(
            xaxis_title='D Parameter',
            yaxis_title='H Parameter',
            zaxis_title='P Parameter'
        ),
        width=900,
        height=600,
        margin=dict(r=120),
        legend=dict(x=0.02, y=0.98)
    )
    
    return fig

# Enhanced interactive controls with multi-select
def update_plot(color_param, selected_solutes):
    if len(selected_solutes) == 0:
        print("⚠️ No solutes selected, showing all")
        selected_solutes = all_solutes
    
    print(f"🎨 Updating plot - Color: {color_param}")
    print(f"🧪 Selected solutes ({len(selected_solutes)}): {list(selected_solutes)}")
    
    fig = make_3d_plot(color_param, selected_solutes)
    fig.show()
    
    # Show detailed summary
    filtered_data = data[data['Solute'].isin(selected_solutes)]
    print(f"\n📈 Dataset Summary:")
    print(f"   • Total data points: {len(filtered_data)}")
    print(f"   • Solutes shown: {len(selected_solutes)}")
    print(f"   • Stability breakdown:")
    
    stability_counts = filtered_data['Stability'].value_counts()
    for stability, count in stability_counts.items():
        print(f"     - {stability}: {count} samples")
    
    # Show solvent count per solute
    if len(selected_solutes) <= 3:  # Only show details for few solutes
        print(f"\n🧪 Solvent details:")
        for solute in selected_solutes:
            solute_data = filtered_data[filtered_data['Solute'] == solute]
            solvents = solute_data['Solvent'].tolist()
            print(f"   • {solute}: {len(solvents)} solvent(s)")
            if len(solvents) <= 5:  # Only list if not too many
                for i, solvent in enumerate(solvents, 1):
                    print(f"     {i}. {solvent}")

# Create multi-select widgets
print("🎛️ Interactive Multi-Select Controls:")
print("💡 Hold Ctrl/Cmd to select multiple solutes, or leave empty for all")

# Color dropdown
color_widget = widgets.Dropdown(
    options=color_options,
    value='DN',
    description='Color by:',
    style={'description_width': '80px'}
)

# Multi-select for solutes
solute_widget = widgets.SelectMultiple(
    options=all_solutes,
    value=all_solutes,  # Start with all selected
    description='Solutes:',
    style={'description_width': '80px'},
    layout={'height': '150px', 'width': '400px'}  # Make it taller
)

# Quick selection buttons
def select_all_solutes(b):
    solute_widget.value = all_solutes

def clear_solutes(b):
    solute_widget.value = []

select_all_btn = widgets.Button(description="Select All", button_style='info')
clear_btn = widgets.Button(description="Clear All", button_style='warning')

select_all_btn.on_click(select_all_solutes)
clear_btn.on_click(clear_solutes)

# Display controls
display(widgets.HBox([color_widget]))
display(widgets.VBox([
    solute_widget,
    widgets.HBox([select_all_btn, clear_btn])
]))

# Create interactive plot
widgets.interact(update_plot, 
                color_param=color_widget, 
                selected_solutes=solute_widget)

📊 Loading perovskite data...
✅ Loaded 11 data points
📋 Columns: ['Solute', 'Solvent', 'D', 'H', 'P', 'DN', 'BP', 'heat of Vap', 'hov_temp', 'mw', 'vis_temp', 'Viscosity', 'Stability']
🧪 Compounds: ['MAPbI3', 'MABrMAClCsIPbBr2PbI2FAl', 'CsFAPbI3', 'MaPbI3 0.32', 'FAPbI3']
⚖️ Stability types: ['Stable', 'Not stable', 'Semi-stable']
🎨 Color options: ['DN', 'BP', 'heat of Vap', 'hov_temp', 'mw', 'vis_temp', 'Viscosity']
🧪 Available solutes: ['CsFAPbI3', 'FAPbI3', 'MABrMAClCsIPbBr2PbI2FAl', 'MAPbI3', 'MaPbI3 0.32']
🎛️ Interactive Multi-Select Controls:
💡 Hold Ctrl/Cmd to select multiple solutes, or leave empty for all


HBox(children=(Dropdown(description='Color by:', options=('DN', 'BP', 'heat of Vap', 'hov_temp', 'mw', 'vis_te…

VBox(children=(SelectMultiple(description='Solutes:', index=(0, 1, 2, 3, 4), layout=Layout(height='150px', wid…

interactive(children=(Dropdown(description='Color by:', options=('DN', 'BP', 'heat of Vap', 'hov_temp', 'mw', …

<function __main__.update_plot(color_param, selected_solutes)>