In [113]:
import plotly.graph_objects as go
import numpy as np
from plotly.subplots import make_subplots
# Directory where exported figures are written. File names are appended by plain
# string concatenation (e.g. figs_folder + "Intensity-weights.png"), so the trailing '/' is required.
# NOTE(review): absolute, machine-specific path — consider making this configurable or relative.
figs_folder = '/Users/alfredoparra/Library/CloudStorage/GoogleDrive-parra.h.alfredo@gmail.com/My Drive/My Documents/Career & Work/Research/2024-08 Cluster headaches/Figures/'


In [114]:
def transform_intensity(intensities, method='linear', max_value=100, power=2, base=np.e, scaling_factor=1.0):
    """Map pain intensities on a 0-10 scale to weights on a 0-``max_value`` scale.

    Every method sends intensity 0 to 0 and intensity 10 to ``max_value``;
    they differ only in the shape of the curve between those endpoints.

    Parameters
    ----------
    intensities : array_like or scalar
        Pain intensities, nominally in [0, 10]. Lists, arrays and scalars are
        accepted (converted with ``np.asarray``).
    method : {'linear', 'piecewise_linear', 'power', 'exponential'}
        - 'linear': proportional scaling, ``intensity * max_value / 10``.
        - 'piecewise_linear': two linear segments meeting at intensity 8, which
          maps to ``max_value / 2``; the upper segment (8-10) is steeper than the
          lower one (0-8) for positive ``max_value``.
        - 'power': ``(intensity / 10) ** power * max_value``.
        - 'exponential': ``base ** (intensity / scaling_factor) - 1``, rescaled so
          that intensity 10 maps to ``max_value``.
    max_value : float
        Weight assigned to intensity 10.
    power : float
        Exponent; used by the 'power' method only.
    base, scaling_factor : float
        Used by the 'exponential' method only. ``scaling_factor`` must be non-zero
        and ``base ** (10 / scaling_factor)`` must differ from 1, otherwise the
        normalizing denominator is zero.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Transformed values with the same shape as ``intensities``.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    x = np.asarray(intensities)

    if method == 'linear':
        return x * (max_value / 10)

    if method == 'piecewise_linear':
        # Named `knee` rather than `breakpoint` to avoid shadowing the builtin.
        knee = 8
        half = max_value / 2
        lower_slope = half / knee
        upper_slope = half / (10 - knee)
        return np.where(x <= knee, lower_slope * x, half + upper_slope * (x - knee))

    if method == 'power':
        return (x / 10) ** power * max_value

    if method == 'exponential':
        # Subtract 1 so intensity 0 maps to 0, then normalize so intensity 10 maps to max_value.
        norm = max_value / (base ** (10 / scaling_factor) - 1)
        return (base ** (x / scaling_factor) - 1) * norm

    raise ValueError(
        f"Invalid method {method!r}; expected one of "
        "'linear', 'piecewise_linear', 'power', 'exponential'."
    )

def plot_intensity_transformations(max_value=1, save_path=None):
    """Plot the 'linear', 'power' and 'exponential' intensity transformations and export the figure.

    Each curve is ``transform_intensity`` evaluated on a 0-10 grid with step 0.1
    (``power=2`` for the power curve); markers are drawn only at integer
    intensities. The figure is displayed with ``fig.show()`` and written as a
    PNG with ``fig.write_image(..., scale=4)``.

    Parameters
    ----------
    max_value : float
        Weight assigned to intensity 10. Also sets the upper limit of the
        y-axis, which is divided into ten equal tick intervals.
    save_path : str, optional
        Destination of the PNG. Defaults to ``figs_folder + "Intensity-weights.png"``,
        resolved at call time from the notebook-level ``figs_folder``.
    """
    intensities = np.linspace(0, 10, 101)  # 0 to 10 in steps of 0.1

    # Markers only at whole-number intensities. Compare against the rounded value
    # instead of float.is_integer() so grid points with float error still qualify.
    # Computed once because it is the same for every trace.
    marker_sizes = np.where(np.isclose(intensities, np.round(intensities)), 10, 0).tolist()

    # y-axis ticks follow max_value so the curves are not clipped when max_value != 1.
    y_tickvals = [max_value * i / 10 for i in range(11)]
    # Labels are built by concatenation: list.append returns None, which would leave ticktext unset.
    y_ticktext = ['0'] + [f'{v:.1f}' for v in y_tickvals[1:]]

    fig = go.Figure()

    methods = ['linear', 'power', 'exponential']
    labels = ['Linear', 'Power', 'Exponential']
    colors = ['rgb(31, 119, 180)', 'rgb(255, 127, 14)', 'rgb(44, 160, 44)']
    markers = ['circle', 'square', 'diamond']

    for method, label, color, marker in zip(methods, labels, colors, markers):
        transformed = transform_intensity(intensities, method=method, max_value=max_value, power=2)
        fig.add_trace(go.Scatter(
            x=intensities,
            y=transformed,
            mode='lines+markers',
            name=label,
            line=dict(color=color, width=2),
            marker=dict(
                color=color,
                size=marker_sizes,
                symbol=marker,
                line=dict(width=2, color='white')  # Add a white border to markers
            )
        ))

    fig.update_layout(
        xaxis_title='Original Pain Intensity',
        yaxis_title='Transformed Intensity (Weight)',
        template='plotly_white',
        width=500,
        height=500,
        legend=dict(
            orientation="v",
            yanchor="top",
            y=.9,
            xanchor="left",
            x=0.1,
            itemsizing='constant',  # This ensures all legend items are the same size
            font=dict(size=12),
            bordercolor="Grey",
            borderwidth=1,
        ),
        xaxis=dict(
            range=[0, 10],
            showgrid=True,
            gridwidth=1,
            gridcolor='rgba(0, 0, 0, 0.1)',
            color='black',
            linecolor='black',
            ticks='outside',
            tickcolor='black',
            ticklen=6,
            tickmode='array',
            tickvals=list(range(11)),  # 0 to 10
            ticktext=[str(i) for i in range(11)],
        ),
        yaxis=dict(
            range=[0, max_value],
            tickmode='array',
            tickvals=y_tickvals,  # 0 to max_value in ten equal steps
            ticktext=y_ticktext,
            showgrid=True,
            gridwidth=1,
            gridcolor='rgba(0, 0, 0, 0.1)',
            color='black',
            linecolor='black',
            ticks='outside',
            tickcolor='black',
            ticklen=6,
        ),
        margin=dict(l=50, r=50, t=10, b=50),
    )

    fig.show()
    if save_path is None:
        save_path = figs_folder + "Intensity-weights.png"
    fig.write_image(save_path, scale=4)

# Draw the three transformation curves with weights on a 0-1 scale, then export the PNG.
plot_intensity_transformations(1)