In [None]:
import pandas as pd
import plotly.graph_objects as go

# ─── Configuration ─────────────────────────────────────────────────────────────
# Location of the test-data CSV (mlp_test_full_data); edit to point elsewhere.
PATH = "../results/mlp_test_full_data.csv"

# ─── Load ──────────────────────────────────────────────────────────────────────
df = pd.read_csv(filepath_or_buffer=PATH)

# Display the available columns so the coordinate column naming can be checked.
df.columns

In [None]:
# ─── 3D scatter of every sample's TRUE ("T_") atom positions ──────────────────
# Colour assigned to each atom type; dict order is the trace/legend order.
atom_colors = {
    'C': 'red',
    'Br': 'blue',
    'Cl': 'green',
    'F': 'orange',
    'H': 'purple'
}

# ─── Ensure C columns exist and are zero if missing ────────────────────────────
# Mutates `df` in place, but only adds absent columns, so re-running is safe.
for label in ['T', 'P']:
    for coord in ['x0', 'y0', 'z0']:
        col_name = f'{label}_C_{coord}'
        if col_name not in df.columns:
            df[col_name] = 0.0  # Assign zeros if missing

fig = go.Figure()

# One trace per atom type. An atom type whose coordinate columns are absent is
# skipped with a message instead of aborting the whole cell with a KeyError.
for atom, color in atom_colors.items():
    coord_cols = [f'T_{atom}_{axis}' for axis in ('x0', 'y0', 'z0')]
    missing = [c for c in coord_cols if c not in df.columns]
    if missing:
        print(f"Skipping {atom}: missing column(s) {missing}")
        continue

    x_col, y_col, z_col = coord_cols
    fig.add_trace(go.Scatter3d(
        x=df[x_col],
        y=df[y_col],
        z=df[z_col],
        mode='markers',
        marker=dict(size=4, color=color, opacity=0.8),
        name=f'Atom: {atom}'
    ))

# Centred title, light-grey scene panes, tight margins.
fig.update_layout(
    title=dict(text='3D Atom True Positions', x=0.5),
    scene=dict(
        xaxis=dict(title='X-axis', backgroundcolor="rgb(240,240,240)"),
        yaxis=dict(title='Y-axis', backgroundcolor="rgb(240,240,240)"),
        zaxis=dict(title='Z-axis', backgroundcolor="rgb(240,240,240)"),
    ),
    legend_title='Atom Types',
    margin=dict(l=0, r=0, b=0, t=30),
)

fig.show()
# fig.write_html("Normal_True.html")  # uncomment to export an interactive HTML



In [None]:
import plotly.graph_objects as go

# ─── Mean TRUE ("T_") position of each atom type, with C–X bonds ─────────────
# Colour assigned to each atom type; dict order is the trace/legend order.
atom_colors = {
    'C': 'red',
    'Br': 'blue',
    'Cl': 'green',
    'F': 'orange',
    'H': 'purple'
}

# ─── Ensure C columns exist and are zero if missing ────────────────────────────
# Mutates `df` in place, but only adds absent columns, so re-running is safe.
for label in ['T', 'P']:
    for coord in ['x0', 'y0', 'z0']:
        col_name = f'{label}_C_{coord}'
        if col_name not in df.columns:
            df[col_name] = 0.0  # Assign zeros if missing

fig = go.Figure()

# atom symbol -> (mean x, mean y, mean z); only atoms actually plotted are stored,
# so the bond loop below never looks up an atom that was skipped.
avg_positions = {}

# ─── Plot average atom positions ───────────────────────────────────────────────
for atom, color in atom_colors.items():
    coord_cols = [f'T_{atom}_{axis}' for axis in ('x0', 'y0', 'z0')]
    missing = [c for c in coord_cols if c not in df.columns]
    if missing:
        print(f"Skipping {atom}: missing column(s) {missing}")
        continue

    avg_x, avg_y, avg_z = (df[c].mean() for c in coord_cols)
    avg_positions[atom] = (avg_x, avg_y, avg_z)

    fig.add_trace(go.Scatter3d(
        x=[avg_x],
        y=[avg_y],
        z=[avg_z],
        mode='markers',
        marker=dict(size=18, color=color, opacity=0.95, symbol='circle'),
        name=f'{atom}'
    ))

# ─── Add bonds from Carbon to every other plotted atom ─────────────────────────
if 'C' in avg_positions:
    c_pos = avg_positions['C']
    for atom, pos in avg_positions.items():
        if atom == 'C':
            continue
        fig.add_trace(go.Scatter3d(
            x=[c_pos[0], pos[0]],
            y=[c_pos[1], pos[1]],
            z=[c_pos[2], pos[2]],
            mode='lines',
            line=dict(color=atom_colors[atom], width=6),  # match atom color
            showlegend=False
        ))

# ─── Layout: white scene, equal-proportion axes ────────────────────────────────
fig.update_layout(
    title=dict(text='Average 3D Atom True Positions with Bonds',
               x=0.5, font=dict(size=24, color="black")),
    scene=dict(
        xaxis=dict(title='X', backgroundcolor="white", gridcolor="lightgray", zeroline=False),
        yaxis=dict(title='Y', backgroundcolor="white", gridcolor="lightgray", zeroline=False),
        zaxis=dict(title='Z', backgroundcolor="white", gridcolor="lightgray", zeroline=False),
        aspectmode="data"  # keep proportions
    ),
    legend=dict(title="Atom Types", font=dict(size=14)),
    margin=dict(l=20, r=20, b=20, t=50),
    width=1000,
    height=900
)

fig.show()
# fig.write_html("Normal_Real_Avg.html")  # uncomment to export an interactive HTML


In [None]:
# ─── 3D scatter of every sample's PREDICTED ("P_") atom positions ─────────────
# Colour per atom type; insertion order is the trace/legend order.
atom_colors = dict(C='red', Br='blue', Cl='green', F='orange', H='purple')

fig = go.Figure()

for atom_type, marker_color in atom_colors.items():
    pred_cols = [f'P_{atom_type}_{axis}0' for axis in 'xyz']

    # Atom types without a full set of predicted coordinates are left out.
    if not set(pred_cols) <= set(df.columns):
        continue

    px_col, py_col, pz_col = pred_cols
    fig.add_trace(go.Scatter3d(
        x=df[px_col], y=df[py_col], z=df[pz_col],
        mode='markers',
        marker=dict(size=4, color=marker_color),
        name=atom_type,
    ))

# Simple titled scene with X/Y/Z axis labels.
fig.update_layout(
    title='3D Atom Predicted Positions',
    scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
    legend_title='Atoms',
    margin=dict(l=0, r=0, b=0, t=50),
)

fig.show()
# fig.write_html("Normal_Oriented_Pred.html")  # uncomment to export an interactive HTML

In [None]:
import pandas as pd
import plotly.graph_objects as go

# ─── Input ─────────────────────────────────────────────────────────────────────
# Uses the `df` loaded at the top of this notebook. Coordinate columns follow
# the pattern <T|P>_<atom>_<x|y|z>0, e.g. T_C_x0, T_C_y0, T_C_z0, P_C_x0, ...

# ─── Styling: colour per atom type, marker style per source ────────────────────
atom_colors = dict(C='red', Br='blue', Cl='green', F='orange', H='purple')

styles = {
    'T': dict(symbol='circle', opacity=1.0, size=4),  # true positions
    'P': dict(symbol='x', opacity=0.6, size=3),       # predicted positions
}

# ─── Build the figure: one trace per (atom, source) pair present in df ─────────
fig = go.Figure()

for atom_sym, atom_color in atom_colors.items():
    for source, marker_style in styles.items():
        cols = [f'{source}_{atom_sym}_{axis}0' for axis in ('x', 'y', 'z')]
        if not set(cols).issubset(df.columns):
            continue  # this atom/source combination is not in the data

        xs, ys, zs = (df[c] for c in cols)
        fig.add_trace(go.Scatter3d(
            x=xs, y=ys, z=zs,
            mode='markers',
            marker=dict(
                size=marker_style['size'],
                color=atom_color,
                symbol=marker_style['symbol'],
                opacity=marker_style['opacity'],
            ),
            name=f'{source} {atom_sym}',
            legendgroup=atom_sym,  # toggle true+predicted of one atom together
        ))

# ─── Layout ────────────────────────────────────────────────────────────────────
grey_pane = "rgb(240,240,240)"
fig.update_layout(
    title=dict(text='True vs Predicted Atom Positions in 3D', x=0.5),
    scene=dict(
        xaxis=dict(title='X (Å)', backgroundcolor=grey_pane),
        yaxis=dict(title='Y (Å)', backgroundcolor=grey_pane),
        zaxis=dict(title='Z (Å)', backgroundcolor=grey_pane),
    ),
    legend_title='Atom & Source',
    margin=dict(l=0, r=0, b=0, t=40),
    width=800,
    height=700,
)

fig.show()
# fig.write_html("Normal_Oriented_ALL.html")  # uncomment to export an interactive HTML


In [None]:
import plotly.graph_objects as go

# ─── Mean TRUE vs PREDICTED position of each atom type, with C–X bonds ───────
# Colour assigned to each atom type; dict order is the trace/legend order.
atom_colors = {
    'C': 'red',
    'Br': 'blue',
    'Cl': 'green',
    'F': 'orange',
    'H': 'purple'
}

fig = go.Figure()

# ─── Ensure C columns exist and are zero if missing ────────────────────────────
# Mutates `df` in place, but only adds absent columns, so re-running is safe.
for label in ['T', 'P']:
    for coord in ['x0', 'y0', 'z0']:
        col_name = f'{label}_C_{coord}'
        if col_name not in df.columns:
            df[col_name] = 0.0  # Assign zeros if missing

# Per-source styling for markers and bonds (true: large circles + solid bonds,
# predicted: smaller diamonds + dashed bonds).
source_styles = {
    'T': dict(prefix='True', size=18, opacity=0.95, symbol='circle',
              bond_width=6, bond_dash='solid'),
    'P': dict(prefix='Predicted', size=14, opacity=0.9, symbol='diamond',
              bond_width=4, bond_dash='dash'),
}

# atom symbol -> (mean x, mean y, mean z), one dict per source. Only atoms that
# were actually plotted are stored, so the bond loop never hits a skipped atom.
avg_positions_true = {}
avg_positions_pred = {}
avg_by_source = {'T': avg_positions_true, 'P': avg_positions_pred}

# ─── Average atom markers: all TRUE first, then all PREDICTED ──────────────────
for label, spec in source_styles.items():
    for atom, color in atom_colors.items():
        coord_cols = [f'{label}_{atom}_{axis}' for axis in ('x0', 'y0', 'z0')]
        missing = [c for c in coord_cols if c not in df.columns]
        if missing:
            print(f"Skipping {spec['prefix']} {atom}: missing column(s) {missing}")
            continue

        avg_x, avg_y, avg_z = (df[c].mean() for c in coord_cols)
        avg_by_source[label][atom] = (avg_x, avg_y, avg_z)

        fig.add_trace(go.Scatter3d(
            x=[avg_x], y=[avg_y], z=[avg_z],
            mode='markers',
            marker=dict(size=spec['size'], color=color,
                        opacity=spec['opacity'], symbol=spec['symbol']),
            name=f"{spec['prefix']} {atom}"
        ))

# ─── Bonds from Carbon to every other plotted atom (solid true, dashed pred) ───
for label, spec in source_styles.items():
    positions = avg_by_source[label]
    if 'C' not in positions:
        continue
    c_pos = positions['C']
    for atom, pos in positions.items():
        if atom == 'C':
            continue
        fig.add_trace(go.Scatter3d(
            x=[c_pos[0], pos[0]],
            y=[c_pos[1], pos[1]],
            z=[c_pos[2], pos[2]],
            mode='lines',
            line=dict(color=atom_colors[atom], width=spec['bond_width'],
                      dash=spec['bond_dash']),
            showlegend=False
        ))

# ─── Layout: white scene, equal-proportion axes ────────────────────────────────
fig.update_layout(
    title=dict(text='Average 3D Atom Initial Positions: True vs Predicted',
               x=0.5, font=dict(size=24, color="black")),
    scene=dict(
        xaxis=dict(title='X', backgroundcolor="white", gridcolor="lightgray", zeroline=False),
        yaxis=dict(title='Y', backgroundcolor="white", gridcolor="lightgray", zeroline=False),
        zaxis=dict(title='Z', backgroundcolor="white", gridcolor="lightgray", zeroline=False),
        aspectmode="data"  # keep proportions
    ),
    legend=dict(title="Atom Types", font=dict(size=14)),
    margin=dict(l=20, r=20, b=20, t=50),
    width=1000,
    height=900
)

fig.show()
