In [56]:
from fairchem.core.datasets.ase_datasets import AseReadMultiStructureDataset
from fairchem.core.models.model_registry import model_name_to_local_file
from fairchem.core.common.relaxation.ase_utils import OCPCalculator

import yaml
import ase

import torch
from ase import Atoms
from torch_geometric.data import Data
import ase.data
import numpy as np
from fairchem.core.datasets.ase_datasets import AseReadMultiStructureDataset
from fairchem.core.models.model_registry import model_name_to_local_file
from fairchem.core.common.relaxation.ase_utils import OCPCalculator
from fairchem.core.common.relaxation.ase_utils import OCPCalculator

import yaml
import typing

In [80]:
import numpy as np
from rich.console import Console
from rich.table import Table

def calculate_force_metrics(labels, preds):
    """
    Calculate MAE, cosine similarity, and force-magnitude statistics.

    Args:
        labels (list): Reference forces, one 2-D array-like per molecule
            (the code reduces over ``axis=1``, i.e. one row per atom).
        preds (list): Predicted forces, aligned with ``labels`` and with the
            same per-molecule shape.

    Returns:
        dict: Aggregated metrics with the keys
            - ``MAE``: mean over molecules of the mean absolute component error.
            - ``Mean_Pred_Magnitude`` / ``Mean_True_Magnitude``: mean per-atom
              vector norm of the predicted / reference forces.
            - ``MAE_Magnitude``: mean absolute difference of per-atom norms.
            - ``Mean_Cosine_Similarity``: mean per-atom cosine similarity over
              atoms whose predicted and reference norms both exceed 1e-10
              (NaN when no atom qualifies).
        Every value is NaN when no molecules are supplied.

    Raises:
        ValueError: If ``labels`` and ``preds`` differ in length.
    """
    if len(labels) != len(preds):
        raise ValueError(
            f"labels and preds must have the same length, got {len(labels)} and {len(preds)}"
        )

    molecule_mae = []
    atom_cosine_sim = []
    atom_pred_magnitudes = []
    atom_true_magnitudes = []

    # Accumulate per-molecule and per-atom quantities
    for true_forces, pred_forces in zip(labels, preds):
        pred_forces = np.asarray(pred_forces)
        true_forces = np.asarray(true_forces)

        # MAE calculation for this molecule
        molecule_mae.append(np.mean(np.abs(pred_forces - true_forces)))

        # Force magnitude calculation (vectorized)
        pred_mags = np.linalg.norm(pred_forces, axis=1)
        true_mags = np.linalg.norm(true_forces, axis=1)
        atom_pred_magnitudes.extend(pred_mags)
        atom_true_magnitudes.extend(true_mags)

        # Cosine similarity, skipping (near-)zero vectors to avoid division by zero
        dot_products = np.sum(pred_forces * true_forces, axis=1)
        valid_indices = (pred_mags > 1e-10) & (true_mags > 1e-10)
        cos_sims = dot_products[valid_indices] / (pred_mags[valid_indices] * true_mags[valid_indices])
        # Store as float64 so the final mean is accumulated in double precision
        atom_cosine_sim.extend(cos_sims.astype(np.float64))

    metric_names = (
        "MAE",
        "Mean_Pred_Magnitude",
        "Mean_True_Magnitude",
        "MAE_Magnitude",
        "Mean_Cosine_Similarity",
    )
    if not molecule_mae:
        # Nothing to aggregate: return a well-formed result instead of failing
        return {name: np.nan for name in metric_names}

    # Compute the final metrics once, after every molecule has been processed
    pred_magnitudes = np.asarray(atom_pred_magnitudes)
    true_magnitudes = np.asarray(atom_true_magnitudes)
    return {
        "MAE": np.mean(molecule_mae),
        "Mean_Pred_Magnitude": np.mean(pred_magnitudes),
        "Mean_True_Magnitude": np.mean(true_magnitudes),
        "MAE_Magnitude": np.mean(np.abs(pred_magnitudes - true_magnitudes)),
        "Mean_Cosine_Similarity": np.mean(atom_cosine_sim) if atom_cosine_sim else np.nan,
    }

# Display results in a nice table
def display_metrics_table(metrics):
    """
    Print force-prediction metrics as a rich table.

    Args:
        metrics (dict): Either a mapping ``{model_name: metrics_dict}`` or a
            single flat metrics dict such as the one returned by
            ``calculate_force_metrics`` (rendered as one row labelled
            ``"model"``). Each metrics dict must provide the keys ``"MAE"``,
            ``"Mean_Cosine_Similarity"`` and ``"Mean_Pred_Magnitude"``.
    """
    from collections.abc import Mapping

    # A flat {metric_name: value} dict has no nested mappings; wrap it so the
    # per-model loop below does not try to index into a float.
    if metrics and not any(isinstance(entry, Mapping) for entry in metrics.values()):
        metrics = {"model": metrics}

    console = Console()
    table = Table(title="Force Prediction Metrics")

    table.add_column("Model", style="cyan")
    table.add_column("MAE (Hartree)", style="green")
    table.add_column("Mean Cosine Sim", style="yellow")
    table.add_column("Mean Force Mag", style="magenta")

    # One row per model, values formatted to four decimals
    for model_name, model_metrics in metrics.items():
        table.add_row(
            model_name,
            f"{model_metrics['MAE']:.4f}",
            f"{model_metrics['Mean_Cosine_Similarity']:.4f}",
            f"{model_metrics['Mean_Pred_Magnitude']:.4f}"
        )

    console.print(table)

In [57]:
# Load the dataset configuration inside a context manager so the file handle
# is closed as soon as parsing finishes.
with open("configs/Experiment-1/dataset.yml", "r") as config_file:
    config = yaml.safe_load(config_file)
ckpt_path = "checkpoints/2025-07-25-16-34-24/checkpoint.pt"

In [75]:
# Build a calculator from the experiment's model config and a saved checkpoint.
# NOTE(review): `model` is not referenced by any later cell visible here.
model = OCPCalculator(
    checkpoint_path="checkpoints/2025-07-25-16-34-24/checkpoint.pt",
    config_yml="configs/Experiment-1/equiformer2.yml",
)


INFO:root:amp: false
cmd:
  checkpoint_dir: /home/c23125717/Projects/fairchem/checkpoints/2025-08-06-09-46-56
  commit: core:63249b01,experimental:NA
  identifier: ''
  logs_dir: /home/c23125717/Projects/fairchem/logs/wandb/2025-08-06-09-46-56
  print_every: 100
  results_dir: /home/c23125717/Projects/fairchem/results/2025-08-06-09-46-56
  seed: null
  timestamp_id: 2025-08-06-09-46-56
  version: ''
dataset:
  a2g_args:
    r_energy: true
    r_forces: true
  format: ase_read_multi
  keep_in_memory: true
  key_mapping:
    atoms: atomic_numbers
  pattern: '*.traj'
  use_tqdm: true
evaluation_metrics:
  metrics:
    energy:
    - mae
    forces:
    - mae
    - cosine_similarity
    - magnitude_error
    misc:
    - energy_forces_within_threshold
  primary_metric: forces_mae
gp_gpus: null
gpus: 0
logger: wandb
loss_functions:
- energy:
    coefficient: 4
    fn: mae
- forces:
    coefficient: 100
    fn: l2mae
model:
  alpha_drop: 0.1
  attn_activation: silu
  attn_alpha_channels: 64
  

INFO:root:Loaded EquiformerV2 with 31058690 parameters.
INFO:root:Loading checkpoint from: checkpoints/2025-07-25-16-34-24/checkpoint.pt
  checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loading checkpoint in inference-only mode, not loading keys associated with trainer state!


In [76]:
# Load the length-test dataset configuration inside a context manager so the
# file handle is closed as soon as parsing finishes.
with open("configs/Experiment-1-length-test/dataset.yml", "r") as config_file:
    config = yaml.safe_load(config_file)
ckpt_path = "checkpoints/2025-08-05-15-51-44/checkpoint.pt"

In [77]:
# Calculator restored from `ckpt_path`, requested on GPU (cpu=False) with the
# "equiformerv2_forces" trainer; used below to predict forces.
calculators = OCPCalculator(checkpoint_path=ckpt_path, trainer="equiformerv2_forces", cpu=False)


  checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
INFO:root:local rank base: 0
INFO:root:amp: false
cmd:
  checkpoint_dir: /home/c23125717/Projects/fairchem/checkpoints/2025-08-06-09-46-56
  commit: core:63249b01,experimental:NA
  identifier: ''
  logs_dir: /home/c23125717/Projects/fairchem/logs/wandb/2025-08-06-09-46-56
  print_every: 100
  results_dir: /home/c23125717/Projects/fairchem/results/2025-08-06-09-46-56
  seed: null
  timestamp_id: 2025-08-06-09-46-56
  version: ''
dataset:
  a2g_args:
    r_energy: true
    r_forces: true
  format: ase_read_multi
  keep_in_memory: true
  key_mapping:
    atoms: atomic_numbers
  pattern: '*.traj'
  use_tqdm: true
evaluation_metrics:
  metrics:
    energy:
    - mae
    forces:
    - mae
    - cosine_similarity
    - magnitude_error
    misc:
    - energy_forces_within_threshold
  primary_metric: forces_mae
gp_gpus: null
gpus: 1
logger: wandb
loss_functions:
- energy:
    coefficient: 4
    fn: mae
- forces:
    co

In [133]:
# Read every *.traj file under `src`, keeping energies and forces, and expose
# the atomic numbers under the `atomic_numbers` attribute.
dataset_config = {
    "src": "databases/Experiment-1/single",
    'pattern': '*.traj',
    'a2g_args': {'r_energy': True, 'r_forces': True},
    'key_mapping': {'atoms': 'atomic_numbers'},
}
data = AseReadMultiStructureDataset(dataset_config)

# Chemical symbols of the first structure, looked up from its atomic numbers
symbols = [ase.data.chemical_symbols[number] for number in data[0].atomic_numbers.tolist()]

100%|██████████| 1/1 [00:00<00:00,  4.64it/s]


In [134]:
# Collect reference forces (`labels`) and calculator predictions (`preds`),
# one array per structure in `data`.
labels = []
preds = []

for molecule in data:
    # Use each structure's own atomic numbers instead of the symbols of
    # data[0], so a dataset with mixed compositions is rebuilt correctly.
    atoms = Atoms(
        numbers=molecule.atomic_numbers.numpy(),
        positions=molecule.pos.numpy(),
        cell=molecule.cell.numpy().reshape(3, 3),
        pbc=molecule.pbc.numpy(),
    )
    atoms.calc = calculators

    labels.append(molecule.forces.numpy())
    preds.append(atoms.get_forces())

In [135]:
# Aggregate force metrics over the collected reference (`labels`) and predicted (`preds`) forces.
metrics_res = calculate_force_metrics(labels, preds)

#### For C6H14

In [136]:
# Show the dict currently bound to `metrics_res`; it is reassigned on every evaluation run,
# so the output reflects whichever dataset was loaded when this cell executed.
metrics_res

{'MAE': 0.0004525841,
 'Mean_Pred_Magnitude': 1.3642163,
 'Mean_True_Magnitude': 1.3640522,
 'MAE_Magnitude': 0.0005191829,
 'Mean_Cosine_Similarity': 0.9999995698284474}

#### For C5H12

In [132]:
# Show the dict currently bound to `metrics_res`; it is reassigned on every evaluation run,
# so the output reflects whichever dataset was loaded when this cell executed.
metrics_res

{'MAE': 0.00043337065,
 'Mean_Pred_Magnitude': 1.3671523,
 'Mean_True_Magnitude': 1.3669759,
 'MAE_Magnitude': 0.0005000761,
 'Mean_Cosine_Similarity': 0.9999995971094238}

#### For C4H10

In [128]:
# Show the dict currently bound to `metrics_res`; it is reassigned on every evaluation run,
# so the output reflects whichever dataset was loaded when this cell executed.
metrics_res

{'MAE': 0.00040020532,
 'Mean_Pred_Magnitude': 1.3712018,
 'Mean_True_Magnitude': 1.3710089,
 'MAE_Magnitude': 0.000470241,
 'Mean_Cosine_Similarity': 0.9999996724242708}

#### For C3H8

In [124]:
# Show the dict currently bound to `metrics_res`; it is reassigned on every evaluation run,
# so the output reflects whichever dataset was loaded when this cell executed.
metrics_res

{'MAE': 0.00035500422,
 'Mean_Pred_Magnitude': 1.3775953,
 'Mean_True_Magnitude': 1.3773923,
 'MAE_Magnitude': 0.00042585368,
 'Mean_Cosine_Similarity': 0.9999997823046849}

#### For C2H6

In [120]:
# Show the dict currently bound to `metrics_res`; it is reassigned on every evaluation run,
# so the output reflects whichever dataset was loaded when this cell executed.
metrics_res

{'MAE': 0.00031542403,
 'Mean_Pred_Magnitude': 1.4087495,
 'Mean_True_Magnitude': 1.4084563,
 'MAE_Magnitude': 0.0004201115,
 'Mean_Cosine_Similarity': 0.999999863377618}

#### For C13H28

In [116]:
# Show the dict currently bound to `metrics_res`; it is reassigned on every evaluation run,
# so the output reflects whichever dataset was loaded when this cell executed.
metrics_res

{'MAE': 0.0027231583,
 'Mean_Pred_Magnitude': 1.3564223,
 'Mean_True_Magnitude': 1.3556314,
 'MAE_Magnitude': 0.0030262843,
 'Mean_Cosine_Similarity': 0.9999187853132155}

#### For C12H26

In [112]:
# Show the dict currently bound to `metrics_res`; it is reassigned on every evaluation run,
# so the output reflects whichever dataset was loaded when this cell executed.
metrics_res

{'MAE': 0.0026943157,
 'Mean_Pred_Magnitude': 1.3554399,
 'Mean_True_Magnitude': 1.3547657,
 'MAE_Magnitude': 0.0029879145,
 'Mean_Cosine_Similarity': 0.9999291093095682}

#### For C11H24

In [108]:
# Show the dict currently bound to `metrics_res`; it is reassigned on every evaluation run,
# so the output reflects whichever dataset was loaded when this cell executed.
metrics_res

{'MAE': 0.0025007997,
 'Mean_Pred_Magnitude': 1.3608065,
 'Mean_True_Magnitude': 1.3602604,
 'MAE_Magnitude': 0.0027923489,
 'Mean_Cosine_Similarity': 0.9999415297437783}

#### For C10H22

In [104]:
# Show the dict currently bound to `metrics_res`; it is reassigned on every evaluation run,
# so the output reflects whichever dataset was loaded when this cell executed.
metrics_res

{'MAE': 0.0021731923,
 'Mean_Pred_Magnitude': 1.3652018,
 'Mean_True_Magnitude': 1.3646536,
 'MAE_Magnitude': 0.0024131841,
 'Mean_Cosine_Similarity': 0.9999696888670525}

#### For C9H20

In [100]:
# Show the dict currently bound to `metrics_res`; it is reassigned on every evaluation run,
# so the output reflects whichever dataset was loaded when this cell executed.
metrics_res

{'MAE': 0.0020462163,
 'Mean_Pred_Magnitude': 1.3618153,
 'Mean_True_Magnitude': 1.3613834,
 'MAE_Magnitude': 0.0022848016,
 'Mean_Cosine_Similarity': 0.9999754807592536}

#### For C8H18

In [94]:
# Show the dict currently bound to `metrics_res`; it is reassigned on every evaluation run,
# so the output reflects whichever dataset was loaded when this cell executed.
metrics_res

{'MAE': 0.0018281699,
 'Mean_Pred_Magnitude': 1.3654447,
 'Mean_True_Magnitude': 1.3652817,
 'MAE_Magnitude': 0.002042961,
 'Mean_Cosine_Similarity': 0.9999593346513038}

In [96]:
# Append the current molecule's metrics to the results CSV
import os
import pandas as pd

csv_path = "force_metrics_results.csv"
results_df = pd.DataFrame({
    "Molecule": ["C8H18"],
    "MAE (Hartree)": [metrics_res["MAE"]],
    "Mean Cosine Sim": [metrics_res["Mean_Cosine_Similarity"]],
    "Mean Force Mag": [metrics_res["Mean_Pred_Magnitude"]]
})
# to_csv defaults to mode="w", which would replace the rows already stored;
# open in append mode and write the header only for a new or empty file.
needs_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
results_df.to_csv(csv_path, mode="a", header=needs_header, index=False)

#### For C7H16

In [82]:
# Show the dict currently bound to `metrics_res`; it is reassigned on every evaluation run,
# so the output reflects whichever dataset was loaded when this cell executed.
metrics_res

{'MAE': 0.0014768806,
 'Mean_Pred_Magnitude': 1.3588446,
 'Mean_True_Magnitude': 1.3588817,
 'MAE_Magnitude': 0.0016397743,
 'Mean_Cosine_Similarity': 0.9999399784944911}

#### 

In [91]:
# Inspect the atomic numbers stored on the first structure in `data`.
data[0].atomic_numbers

tensor([6, 6, 6, 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1], dtype=torch.uint8)

In [83]:
# Create the results CSV from the current metrics (replaces any existing file)
import pandas as pd

summary_columns = {
    "Molecule": ["C7H16"],
    "MAE (Hartree)": [metrics_res["MAE"]],
    "Mean Cosine Sim": [metrics_res["Mean_Cosine_Similarity"]],
    "Mean Force Mag": [metrics_res["Mean_Pred_Magnitude"]],
}
results_df = pd.DataFrame(summary_columns)
results_df.to_csv("force_metrics_results.csv", index=False)

In [44]:
from torch import nn

# Cross-check the MAE with torch's L1 loss, one value per structure.
MAE = nn.L1Loss()
CS = nn.CosineSimilarity(dim=1, eps=1e-6)  # NOTE(review): defined but not used in this cell
mae_list = []
for prd, gt in zip(preds, labels):
    # Convert both arrays to float32 tensors before comparing
    prd, gt = (torch.tensor(forces, dtype=torch.float32) for forces in (prd, gt))
    mae = MAE(prd, gt)
    mae_list.append(mae.item())


In [46]:
# Inspect the per-structure MAE values collected in `mae_list`.
mae_list

[0.0007576398202218115,
 0.8377722501754761,
 1.0636065006256104,
 1.034834623336792,
 0.6447876691818237,
 0.7229059338569641,
 0.9090290665626526,
 1.0366417169570923,
 1.1340786218643188,
 0.9195564985275269,
 0.9236732125282288,
 0.9725053310394287,
 0.812979519367218,
 0.7796967029571533,
 0.8663342595100403,
 0.9388853311538696,
 1.059939980506897,
 0.9767672419548035,
 0.7958314418792725,
 0.969752848148346,
 0.9110025763511658,
 0.9462427496910095,
 0.8719473481178284,
 0.8269702196121216,
 0.8987808227539062,
 0.9142641425132751,
 1.0519920587539673,
 0.9974897503852844,
 0.8796082139015198,
 1.0370166301727295,
 0.89495849609375,
 0.9257460832595825,
 0.9799578785896301,
 0.9651706218719482,
 0.9126549363136292,
 1.0877214670181274,
 1.0899158716201782,
 0.9702037572860718,
 0.8011153340339661,
 0.9351211786270142,
 0.9239380359649658,
 0.9949817061424255,
 0.9574052095413208,
 0.8925915360450745,
 0.8459885716438293,
 0.9109943509101868,
 0.9419719576835632,
 0.7789093255996

In [49]:
# Element-wise difference between the position arrays of structures 0 and 10 in `data`.
data[0].pos.numpy() - data[10].pos.numpy()

array([[ 0.4264918 ,  0.29288927, -1.0815362 ],
       [ 1.3345457 ,  2.5796535 , -0.84776354],
       [ 0.29947388,  0.541071  , -0.30678672],
       [ 1.3772286 ,  2.3902788 , -1.5461023 ],
       [ 0.29516244,  0.58793974, -0.02107689],
       [ 2.1882806 ,  4.4701834 , -0.7977887 ],
       [ 0.19925714,  0.582987  ,  0.23256782],
       [ 2.034242  ,  4.0076685 , -1.3852389 ],
       [ 0.2080574 ,  0.46520504,  0.6791181 ],
       [ 2.394621  ,  3.6018586 , -1.678426  ],
       [ 0.25667763,  0.5071036 ,  0.8460924 ],
       [ 2.2385964 ,  3.2163734 , -2.1254668 ],
       [-0.19900084,  0.24914762,  1.9576911 ],
       [-0.45220846, -0.9360124 , -0.02621263],
       [ 0.7260671 , -0.39092606, -3.0441835 ],
       [ 1.1366374 ,  3.2367404 ,  1.0779711 ],
       [ 2.2049713 ,  3.9087734 , -1.7275147 ],
       [ 0.68066   ,  0.32262206,  0.00703704],
       [-0.02740335,  0.8739431 , -0.19751358],
       [ 1.4473557 ,  1.9312046 , -3.5150032 ],
       [ 0.52098775,  1.1022935 , -0.667

### Plotting

In [173]:
# Line plot of the per-molecule MAE stored in the results CSV
import pandas as pd
import plotly.express as px

results_df = pd.read_csv("force_metrics_results.csv")

# Number of rows, used to place the right edge of the second shaded band
num_molecules = len(results_df)

# Fixed y-axis window shared by the axis range and the shaded bands
y_min = 0.0002
y_max = 0.0030

fig = px.line(
    results_df,
    x="Molecule",
    y=["MAE (Hartree)"],  # , "Mean Cosine Sim", "Mean Force Mag"
    title="Force Prediction Metrics per Molecule",
    labels={"MAE (Hartree)": "Mean Absolute Error (Hartree)", "Molecule": "Molecule"},
)

# Show markers on top of the line
fig.update_traces(mode='markers+lines', marker=dict(size=10))
fig.update_yaxes(tickformat=".5f", range=[y_min, y_max])

# Shade the first five x positions (C2H6 to C6H14) and the remaining ones
# (C7H16 to C13H28) with different background colours
for x_start, x_end, fill in (
    (-0.5, 4.5, "LightSkyBlue"),
    (4.5, num_molecules-0.5, "LightCoral"),
):
    fig.add_shape(type="rect",
                  x0=x_start, x1=x_end, y0=y_min, y1=y_max,
                  fillcolor=fill, opacity=0.3,
                  layer="below", line_width=0)

# Axis titles plus a plain white theme for publication figures
fig.update_layout(
    xaxis_title="Molecule",
    yaxis_title="Mean Absolute Error (Hartree)",
    template="plotly_white",
    font=dict(size=14),
)

# Rotate x-axis labels for better readability
fig.update_xaxes(tickangle=45)

fig.show()
# fig.write_html("force_metrics_results.html")
# fig.write_image("force_metrics_results.png")

In [161]:
# Inspect the "Mean Cosine Sim" column of `results_df`.
results_df["Mean Cosine Sim"]

0     1.000000
1     1.000000
2     1.000000
3     1.000000
4     1.000000
5     0.999940
6     0.999959
7     0.999975
8     0.999970
9     0.999942
10    0.999929
11    0.999919
Name: Mean Cosine Sim, dtype: float64

In [176]:
# Option 1: MAE and cosine similarity as two vertically stacked subplots
import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

results_df = pd.read_csv("force_metrics_results.csv")

# Number of rows, used to place the right edge of the second shaded band
num_molecules = len(results_df)

fig = make_subplots(
    rows=2, cols=1,
    subplot_titles=('Mean Absolute Error (Hartree)', 'Mean Cosine Similarity'),
    vertical_spacing=0.1,
)

# Y-axis windows: fixed for MAE, a narrow margin around the data for cosine similarity
y_min_mae = 0.0002
y_max_mae = 0.0030
y_min_cos = results_df["Mean Cosine Sim"].min() * 0.999995
y_max_cos = results_df["Mean Cosine Sim"].max() * 1.000005

# One entry per subplot row: (row, CSV column, trace name, line colour, tick format, y low, y high)
panels = [
    (1, "MAE (Hartree)", "MAE", 'blue', ".5f", y_min_mae, y_max_mae),
    (2, "Mean Cosine Sim", "Cosine Sim", 'red', ".8f", y_min_cos, y_max_cos),
]
# Background bands: (x0, x1, fill colour)
bands = [
    (-0.5, 4.5, "LightSkyBlue"),
    (4.5, num_molecules-0.5, "LightCoral"),
]

# Traces
for row, column, trace_name, colour, _, _, _ in panels:
    fig.add_trace(
        go.Scatter(x=results_df["Molecule"], y=results_df[column],
                   mode='markers+lines', marker=dict(size=10),
                   name=trace_name, line=dict(color=colour)),
        row=row, col=1
    )

# Y-axis range and tick format per subplot
for row, _, _, _, tick_format, y_low, y_high in panels:
    fig.update_yaxes(tickformat=tick_format, range=[y_low, y_high], row=row, col=1)

# Background rectangles, spanning each subplot's full y window
for row, _, _, _, _, y_low, y_high in panels:
    for x_start, x_end, fill in bands:
        fig.add_shape(type="rect",
                      x0=x_start, x1=x_end, y0=y_low, y1=y_high,
                      fillcolor=fill, opacity=0.3,
                      layer="below", line_width=0, row=row, col=1)

fig.update_layout(
    title="Force Prediction Metrics per Molecule",
    template="plotly_white",
    font=dict(size=14),
    height=800,  # Increase height for better visibility
    showlegend=False,
)

# Rotate x-axis labels; only the bottom subplot carries the axis title
fig.update_xaxes(tickangle=45, row=1, col=1)
fig.update_xaxes(tickangle=45, row=2, col=1, title_text="Molecule")

fig.show()
# fig.write_html("force_metrics_results_subplots.html")
# fig.write_image("force_metrics_results_subplots.png")

In [196]:
# Option 2: one plot, MAE on the primary y-axis and cosine similarity on a secondary y-axis
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

results_df = pd.read_csv("force_metrics_results.csv")

# Number of rows, used to place the right edge of the second shaded band
num_molecules = len(results_df)

fig = make_subplots(specs=[[{"secondary_y": True}]])

# One trace per metric: (CSV column, which doubles as the legend name; colour; secondary axis?)
for column, colour, on_secondary in (
    ("MAE (Hartree)", 'blue', False),
    ("Mean Cosine Sim", 'red', True),
):
    fig.add_trace(
        go.Scatter(x=results_df["Molecule"], y=results_df[column],
                   mode='markers+lines', marker=dict(size=10, color=colour),
                   name=column, line=dict(color=colour)),
        secondary_y=on_secondary,
    )

# Axis titles and ranges: fixed window for MAE, narrow margin around the data for cosine similarity
cos_range = [
    results_df["Mean Cosine Sim"].min() * 0.999995,
    results_df["Mean Cosine Sim"].max() * 1.000005,
]
fig.update_yaxes(title_text="MAE (Hartree)", tickformat=".4f",
                 range=[0.0002, 0.0030], secondary_y=False)
fig.update_yaxes(title_text="Mean Cosine Similarity", tickformat=".5f",
                 range=cos_range, secondary_y=True)

# Background bands; yref="paper" makes y0/y1 span the full plot height
for x_start, x_end, fill in (
    (-0.5, 4.5, "LightSkyBlue"),
    (4.5, num_molecules-0.5, "LightCoral"),
):
    fig.add_shape(type="rect",
                  x0=x_start, x1=x_end, y0=0, y1=1,
                  fillcolor=fill, opacity=0.2,
                  layer="below", line_width=0, yref="paper")

# Legend placed inside the plot area, anchored by its bottom-left corner
legend_style = dict(
    x=0.06,
    y=0.48,
    xanchor='left',
    yanchor='bottom',
    bgcolor='rgba(255, 255, 255, 0.8)',  # Semi-transparent white background
    bordercolor='rgba(0, 0, 0, 0.2)',    # Light border
    borderwidth=1,
)
fig.update_layout(
    title="Force Prediction Metrics per Molecule (Dual Y-Axis)",
    template="plotly_white",
    font=dict(size=14),
    xaxis_title="Molecule",
    legend=legend_style,
)

# Rotate x-axis labels
fig.update_xaxes(tickangle=45)

fig.show()
# fig.write_html("force_metrics_results_dual_axis.html")
# fig.write_image("force_metrics_results_dual_axis.png")

In [202]:
# Option 3: MAE and cosine similarity as two side-by-side subplots
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

results_df = pd.read_csv("force_metrics_results.csv")

# Number of rows, used to place the right edge of the second shaded band
num_molecules = len(results_df)

fig = make_subplots(
    rows=1, cols=2,
    subplot_titles=('Mean Absolute Error (Hartree)', 'Mean Cosine Similarity'),
    horizontal_spacing=0.1,
)

# Y-axis windows: fixed for MAE, a narrow margin around the data for cosine similarity
y_min_mae = 0.0002
y_max_mae = 0.0030
y_min_cos = results_df["Mean Cosine Sim"].min() * 0.999995
y_max_cos = results_df["Mean Cosine Sim"].max() * 1.000005

# One entry per subplot column: (col, CSV column, trace name, line colour, tick format, y low, y high)
panels = [
    (1, "MAE (Hartree)", "MAE", 'blue', ".4f", y_min_mae, y_max_mae),
    (2, "Mean Cosine Sim", "Cosine Sim", 'red', ".5f", y_min_cos, y_max_cos),
]
# Background bands: (x0, x1, fill colour)
bands = [
    (-0.5, 4.5, "LightSkyBlue"),
    (4.5, num_molecules-0.5, "LightCoral"),
]

# Traces
for col, column, trace_name, colour, _, _, _ in panels:
    fig.add_trace(
        go.Scatter(x=results_df["Molecule"], y=results_df[column],
                   mode='markers+lines', marker=dict(size=8),
                   name=trace_name, line=dict(color=colour)),
        row=1, col=col
    )

# Y-axis range and tick format per subplot
for col, _, _, _, tick_format, y_low, y_high in panels:
    fig.update_yaxes(tickformat=tick_format, range=[y_low, y_high], row=1, col=col)

# Background rectangles, spanning each subplot's full y window
for col, _, _, _, _, y_low, y_high in panels:
    for x_start, x_end, fill in bands:
        fig.add_shape(type="rect",
                      x0=x_start, x1=x_end, y0=y_low, y1=y_high,
                      fillcolor=fill, opacity=0.3,
                      layer="below", line_width=0, row=1, col=col)

fig.update_layout(
    title="Force Prediction Metrics per Molecule",
    template="plotly_white",
    font=dict(size=14),
    width=1200,  # Increase width for better visibility
    height=500,
    showlegend=False,
)

# Rotate x-axis labels and title both x-axes
for col in (1, 2):
    fig.update_xaxes(tickangle=45, title_text="Molecule", row=1, col=col)

fig.show()
# fig.write_html("force_metrics_results_sidebyside.html")
# fig.write_image("force_metrics_results_sidebyside.png")