Ramanalysis: Interactive comparison and matching of Raman spectra

Copyright (C) 2025, Peter Methley

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.

In [2]:
import numpy as np
import pandas as pd
import sqlite3

from tqdm import tqdm

from scipy.special import wofz
from IPython.display import display

import ipywidgets as widgets
from IPython.display import display

from plotly_default import go
from plotly.subplots import make_subplots
import plotly.colors

import ast

import io

from Raman_helper_functions import find_peak_positions, arpls

In [3]:
def get_selected_xs(figure):
    """Return the x-extent of every box selection currently drawn on *figure*.

    Parameters
    ----------
    figure : object with a ``to_dict()`` method (the plotly FigureWidget here)

    Returns
    -------
    list of (x0, x1) tuples, one per box selection, in layout order. The two
    values are returned as stored, so x0 may be larger than x1.

    Selections without ``x0``/``x1`` (e.g. lasso selections, which plotly
    stores as a ``path``) are skipped instead of raising ``KeyError``.
    """
    # Workaround: accessing figure.layout.selections directly does not work,
    # so read the serialised layout instead.
    layout = figure.to_dict()["layout"]
    selections = layout.get("selections", [])

    return [(s["x0"], s["x1"]) for s in selections if "x0" in s and "x1" in s]


def in_selection(selected_xs, x_points):
    """Return a boolean mask marking which of *x_points* lie inside any selected region.

    Parameters
    ----------
    selected_xs : iterable of 2-sequences
        Selected x-ranges, e.g. from ``get_selected_xs``. The endpoints of a
        region may be given in either order; bounds are inclusive.
    x_points : iterable of numbers
        The points to test.

    Returns
    -------
    numpy.ndarray of dtype ``bool`` with one entry per point. The dtype is
    forced so that an empty result is still a valid boolean index
    (``np.array([])`` would be float64, which numpy refuses as an index).
    """
    # Normalise each region to (low, high) once instead of once per point.
    bounds = [(min(region), max(region)) for region in selected_xs]

    return np.array(
        [any(low <= x <= high for low, high in bounds) for x in x_points],
        dtype=bool,
    )


def make_peak_traces(peak_xs, peak_ys, peak_prominences):
    """Build x/y/hover-text lists that draw each peak as a vertical line.

    Each peak becomes a segment from (x, 0) to (x, y), followed by a ``None``
    separator so plotly breaks the line between peaks. Both segment ends carry
    the same "Prominence: ..." hover label; the separator's label is ``None``.

    Returns
    -------
    (xs, ys, labels) : three lists of equal length (three entries per peak).
    """
    xs, ys, labels = [], [], []

    for peak_x, peak_y, prominence in zip(peak_xs, peak_ys, peak_prominences):
        label = f"Prominence: {prominence:.3f}"
        xs += [peak_x, peak_x, None]
        ys += [0, peak_y, None]
        labels += [label, label, None]

    return xs, ys, labels

def contains_only_allowed(value, allowed):
    """Return True if every element listed in *value* also appears in *allowed*.

    Registered as the SQLite function ``CONTAINS_ONLY_ALLOWED`` and used to
    filter database rows by their ``elements`` column.

    Parameters
    ----------
    value : str or None
        Element list from the database, e.g. ``"Ca, C, O"``.
    allowed : str
        User-entered element list, e.g. ``"Ca, C, O"``, ``"Ca,C,O,"`` or
        ``"Ca C O"``.

    Both strings are tokenised the same way: commas and/or whitespace separate
    elements, and empty tokens are ignored. Comparison is case-sensitive
    (``"Co"`` and ``"CO"`` differ). A missing (``None``) or empty element list
    never matches.
    """
    if value is None:
        return False

    # Commas -> spaces, then str.split() splits on any whitespace and drops empties.
    elements = set(value.replace(",", " ").split())
    if not elements:
        return False

    allowed_set = set(allowed.replace(",", " ").split())
    return elements <= allowed_set

In [4]:
# Open the reference database (path is relative to the notebook's working
# directory). The queries below read `database_table` and its columns
# filename, elements, x_data, y_data, peak_xs, peak_prominences,
# mineral_name and rruff_id.
conn = sqlite3.connect("Raman_database.db")
# Expose the Python helper to SQL as CONTAINS_ONLY_ALLOWED(elements, allowed)
# so the element filter can be applied inside the WHERE clause.
conn.create_function("CONTAINS_ONLY_ALLOWED", 2, contains_only_allowed)

# Single shared cursor used by all the widget callbacks below.
cursor = conn.cursor()


# Track selected items: filenames of reference spectra chosen for plotting.
# Mutated in place by select_item / deselect_item.
selected_items = []

# Search box: free-text substring matched against database filenames.
search_box = widgets.Text(
    placeholder='Search RRUFF database...',
    description='Search:',
    layout=widgets.Layout(width='95%')
)
# Filter by elements: comma-separated list passed to CONTAINS_ONLY_ALLOWED.
elements_box = widgets.Text(
    placeholder='e.g. Ca, C, O',
    description='Elements:',
    layout=widgets.Layout(width='95%')
)

# Search filters: each ticked option is matched as a substring of the filename
# (OR within a widget, AND between widgets) by update_boxes.
filter_box = widgets.SelectMultiple(
    options=['raw', 'processed'],
    value=['processed'],
    description='Filters',
    layout=widgets.Layout(width='95%', height="50px")
)
polarisation_box = widgets.SelectMultiple(
    options=['ccw', 'depolarised', 'unoriented'],
    value=['unoriented'],
    description='Polarisation',
    layout=widgets.Layout(width='95%', height="75px")
)

# Left box: Available items (search results; clicking one selects it)
available_items_box = widgets.Select(
    options=[],
    description='Database',
    layout=widgets.Layout(width='95%', height='210px')
)

# Right box: Selected items (clicking one removes it)
selected_items_box = widgets.Select(
    options=selected_items,
    description='Selected',
    layout=widgets.Layout(width='95%', height='150px')
)

# Measurement upload: .txt is parsed as whitespace-delimited, anything else
# as comma-delimited (see update_measured_spectra).
file_selector = widgets.FileUpload(
    description="Select measurement file(s)",
    accept=".txt,.csv",
    multiple=True
)

# When unticked, spectra are plotted in their original intensity units.
normalisation_box = widgets.Checkbox(
    value=True,
    description="Normalise"
)

# When ticked, an arpls baseline is subtracted before peak finding.
background_box = widgets.Checkbox(
    value=False,
    description="Remove Background"
)

# arpls parameters: passed as lam= and ratio= to arpls.
lambda_box = widgets.FloatText(value=10000, description="λ")
ratio_box = widgets.FloatText(value=0.05, description="Ratio")
# iterations_box = widgets.IntText(value=100, description="Max. iterations")
# Prominence threshold: used for measured peak finding and for ignoring weak
# reference peaks in calculate_score.
# NOTE(review): FloatText is not believed to enforce `min`; BoundedFloatText
# does - confirm negative values are acceptable here.
prominence_box = widgets.FloatText(value=0.05, description="Peak Prominence", step=0.01, min=0)
# Maximum wavenumber difference for a measured peak to count as a match.
x_tolerance_box = widgets.FloatText(value=2.0, description="Wavenumber tolerance", step=0.01, min=0)

match_button = widgets.Button(description="Match Spectra")

# Toggles the plot drag mode between 'select' and 'zoom'.
selection_box = widgets.Checkbox(
    value=False,
    description="Select specific regions to match"
)

# Receives match messages, the progress bar and the results table.
out_widget = widgets.Output(layout={'font_family': 'Cascadia Code'})

# Graph: two stacked subplots sharing the wavenumber axis. Row 1 holds the
# measured spectra and row 2 the reference spectra; peak marker lines of both
# are drawn in both rows.
# NOTE(review): `go` comes from the project module plotly_default - presumably
# plotly.graph_objects with default styling applied; confirm.
f1 = go.FigureWidget(make_subplots(rows=2, cols=1, shared_xaxes=True))
f1.update_layout(height=700)
f1.update_xaxes(title="Wavenumber (cm⁻¹)", showticklabels=True)
f1.update_yaxes(title="Normalised Intensity")

def build_search_query(search_term, filters, polarisations, elements=""):
    """Build a parameterised query selecting database filenames that match the UI criteria.

    Parameters
    ----------
    search_term : str
        Substring that must occur in the filename.
    filters, polarisations : sequence of str
        Each is a group of substrings; the filename must contain at least one
        substring from every group. An empty group matches nothing (an empty
        OR is false) instead of producing invalid SQL.
    elements : str, optional
        If non-empty, rows are further restricted with the SQL function
        ``CONTAINS_ONLY_ALLOWED(elements, ?)``, which must be registered on the
        connection that executes the query.

    Returns
    -------
    (sql, params) : the SQL text and the list of values bound to its ``?``
    placeholders. User text is never interpolated into the SQL, so quotes in
    the input cannot break the statement.
    """
    clauses = ["filename LIKE ?"]
    params = [f"%{search_term}%"]

    for group in (filters, polarisations):
        if group:
            clauses.append("(" + " OR ".join(["filename LIKE ?"] * len(group)) + ")")
            params.extend(f"%{term}%" for term in group)
        else:
            # Nothing ticked in this group: no filename can satisfy it.
            clauses.append("0")

    if elements:
        clauses.append("CONTAINS_ONLY_ALLOWED(elements, ?)")
        params.append(elements)

    sql = "SELECT filename FROM database_table WHERE " + " AND ".join(clauses)
    return sql, params


# Update both boxes on search
def update_boxes(change=None):
    """Refresh the available/selected lists from the current search widgets.

    Used as an observer callback; *change* is ignored. Items already in
    ``selected_items`` are left out of the available list.
    """
    # Remove selection so items don't get accidentally added or removed
    available_items_box.value = None
    selected_items_box.value = None

    query, params = build_search_query(
        search_box.value.lower(),
        filter_box.value,
        polarisation_box.value,
        elements_box.value,
    )
    cursor.execute(query, params)

    already_selected = set(selected_items)
    available_items_box.options = [
        row[0] for row in cursor.fetchall() if row[0] not in already_selected
    ]
    selected_items_box.options = selected_items

# Move items from available to selected
def select_item(change):
    """Observer for the 'Database' list: move the clicked filename to the selection.

    Ignores changes whose new value is falsy (e.g. the list being reset to
    ``None``). Otherwise appends the filename to ``selected_items`` and
    refreshes both lists and the reference plot.
    """
    clicked = change['new']
    if not clicked:
        return

    # In-place append, so no global rebinding is needed.
    selected_items.append(clicked)
    update_boxes()
    update_reference_spectra()

# Remove items from selected
def deselect_item(change):
    """Observer for the 'Selected' list: drop the clicked filename from the selection.

    Ignores changes whose new value is falsy (e.g. the list being reset to
    ``None``). Otherwise removes the filename from ``selected_items`` and
    refreshes both lists and the reference plot.
    """
    clicked = change['new']
    if not clicked:
        return

    # In-place removal, so no global rebinding is needed.
    selected_items.remove(clicked)
    update_boxes()
    update_reference_spectra()
        

def update_measured_spectra(change=None):
    """Replot every uploaded measurement file together with its detected peaks.

    Used as an observer callback; *change* is accepted for that signature and
    otherwise ignored.

    For each file in ``file_selector.value``:
    - the two columns are read as ``wavenumber`` / ``intensity`` (whitespace-
      delimited for ``.txt``, comma-delimited otherwise) and scaled to a
      maximum of 1;
    - an ``arpls`` background is optionally subtracted;
    - peaks are located with ``find_peak_positions``.
    The spectrum is drawn in row 1 of ``f1`` and its peaks as vertical lines
    in both rows.

    Side effects: all existing 'Measured Spectrum' / 'Measured Peaks' traces
    are removed first. The globals ``measured_peaks`` /
    ``measured_prominences`` are set, as float ndarrays, from the *last* file
    processed, which is the one ``match_spectra`` matches.
    """
    global measured_peaks, measured_prominences

    import os

    # Drop the previous measured traces; reference and dummy traces are kept.
    f1.data = [
        trace for trace in f1.data
        if trace.meta not in ("Measured Spectrum", "Measured Peaks")
    ]

    palette = plotly.colors.qualitative.Bold

    for i, file in enumerate(file_selector.value):
        # splitext removes exactly the extension; str.rstrip(".txt") strips any
        # trailing '.', 't' or 'x' characters ("test.txt" -> "tes").
        sample_name, extension = os.path.splitext(file.name)
        delimiter = r'\s+' if extension.lower() == ".txt" else ','

        df = pd.read_csv(io.BytesIO(file.content), sep=delimiter, names=["wavenumber", "intensity"])

        y_norm = np.array(df["intensity"] / df["intensity"].max())

        if background_box.value:
            with np.errstate(over="ignore"):  # arpls often overflows without visibly affecting the spectrum
                background = arpls(y_norm, lam=lambda_box.value, ratio=ratio_box.value)
        else:
            background = 0

        xx = np.array(df["wavenumber"])
        yy = y_norm - background

        # Either rescale the corrected curve to peak at 1, or restore raw intensity units.
        if normalisation_box.value:
            scale_factor = 1 / yy.max()
        else:
            scale_factor = df["intensity"].max()

        peak_xs, peak_ys, peak_prominences = find_peak_positions(
            xx, yy, prominence_threshold=prominence_box.value, remove_bg=not background_box.value
        )

        # Stored as arrays so match_spectra can boolean-index them.
        measured_peaks = np.asarray(peak_xs, dtype=float)
        measured_prominences = np.asarray(peak_prominences, dtype=float)

        yy *= scale_factor
        peak_ys = np.asarray(peak_ys, dtype=float) * scale_factor

        # Cycle the palette so uploading more files than colours cannot raise IndexError.
        trace_color = palette[i % len(palette)]

        f1.add_scatter(x=xx, y=yy, marker_color=trace_color, name=sample_name, meta="Measured Spectrum", row=1, col=1)

        # Add peaks as lines in both rows so they can be compared with the references.
        peak_xx, peak_yy, peak_texts = make_peak_traces(peak_xs, peak_ys, peak_prominences)
        for row in (1, 2):
            f1.add_scatter(
                x=peak_xx, y=peak_yy, mode='lines', opacity=0.5, marker_color=trace_color,
                line_width=1, showlegend=False, text=peak_texts, meta="Measured Peaks", row=row, col=1,
            )

def update_reference_spectra(change=None):
    """Replot every selected reference spectrum from the database with its peaks.

    Used as an observer callback; *change* is ignored. For each filename in
    ``selected_items`` the stored ``x_data`` / ``y_data`` strings are parsed
    with ``ast.literal_eval``. Each spectrum is then processed like the
    measured spectra: scaled to a maximum of 1, optionally background-
    corrected, and searched for peaks using the current prominence threshold.
    The spectrum is drawn in row 2 of ``f1`` and its peaks in both rows.
    Existing 'Reference Spectrum' / 'Reference Peaks' traces are removed first.
    """
    # Remove existing reference traces; measured and dummy traces are kept.
    f1.data = [
        trace for trace in f1.data
        if trace.meta not in ("Reference Peaks", "Reference Spectrum")
    ]

    palette = plotly.colors.qualitative.D3

    for i, file in enumerate(selected_items):

        # Parameterised so filenames containing quotes cannot break the query.
        cursor.execute("SELECT x_data, y_data FROM database_table WHERE filename = ?", (file,))
        row = cursor.fetchone()
        if row is None:
            continue  # filename no longer in the database; nothing to plot
        x_data_str, y_data_str = row

        x_data = np.array(ast.literal_eval(x_data_str), dtype=float)
        y_data = np.array(ast.literal_eval(y_data_str), dtype=float)

        mineral_name = file.split('__Raman')[0]

        y_norm = np.array(y_data / y_data.max())

        if background_box.value:
            with np.errstate(over="ignore"):  # arpls often overflows without visibly affecting the spectrum
                background = arpls(y_norm, lam=lambda_box.value, ratio=ratio_box.value)
        else:
            background = 0

        yy = y_norm - background
        xx = x_data

        # Either rescale the corrected curve to peak at 1, or restore raw intensity units.
        if normalisation_box.value:
            scale_factor = 1 / yy.max()
        else:
            scale_factor = y_data.max()

        # Same threshold as the measured spectra; prominence_box observes this function.
        peak_xs, peak_ys, peak_prominences = find_peak_positions(
            xx, yy, prominence_threshold=prominence_box.value, remove_bg=not background_box.value
        )

        yy *= scale_factor
        # Scale the peak heights too so the marker lines stay on the plotted curve.
        peak_ys = np.asarray(peak_ys, dtype=float) * scale_factor

        # Cycle the palette so selecting more references than colours cannot raise IndexError.
        trace_color = palette[i % len(palette)]

        f1.add_scatter(x=xx, y=yy, marker_color=trace_color, name=mineral_name, meta="Reference Spectrum", row=2, col=1)

        # Add peaks as lines in both rows so they can be compared with the measurement.
        peak_xx, peak_yy, peak_texts = make_peak_traces(peak_xs, peak_ys, peak_prominences)
        for plot_row in (1, 2):
            f1.add_scatter(
                x=peak_xx, y=peak_yy, mode='lines', opacity=0.5, marker_color=trace_color,
                line_width=1, showlegend=False, text=peak_texts, meta="Reference Peaks", row=plot_row, col=1,
            )

    return None


def change_normalisation(change=None):
    """Observer for the 'Normalise' checkbox: relabel the y-axes and replot everything.

    *change* is ignored; the current checkbox value is read directly.
    """
    axis_title = "Normalised Intensity" if normalisation_box.value else "Intensity"
    f1.update_yaxes(title=axis_title)

    # Both plots depend on the normalisation setting.
    update_measured_spectra()
    update_reference_spectra()
    
    
def enable_disable_selection(change=None):
    """Observer for the region-selection checkbox: switch the plot's drag mode.

    Checked -> 'select' (box-select regions to match); unchecked -> 'zoom'.
    *change* is ignored.
    """
    f1.layout.dragmode = 'select' if selection_box.value else 'zoom'


# x-ranges recorded by the plotly selection callbacks below.
# NOTE(review): nothing in this section reads this list; match_spectra uses
# get_selected_xs(f1) instead - confirm whether it is still needed.
selected_xs = []


def add_selected_xs(trace, points, selector):
    """plotly ``on_selection`` callback: record the x-range of the new selection.

    Appends ``selector.xrange`` to the module-level ``selected_xs`` list.
    """
    # append mutates the existing list, so no `global` declaration is required
    selected_xs.append(selector.xrange)


def clear_selected_xs(trace, points):
    """plotly ``on_deselect`` callback: forget every recorded selection."""
    global selected_xs
    selected_xs = []

def calculate_score(ref_peaks, ref_prominences, measured_peaks, measured_prominences, x_tolerance=2, prominence_threshold=0.05, penalty=1.0):
    """Score how well a reference spectrum's peaks are reproduced by a measurement.

    Only reference peaks whose prominence exceeds *prominence_threshold* are
    considered. For each one:
    - If any measured peak lies strictly within *x_tolerance* of it, the score
      gains ``ref_prominence * (1 - |measured_prominence - ref_prominence|)``,
      using the largest prominence among the matching measured peaks.
    - Otherwise the score loses ``ref_prominence * penalty``.

    Parameters
    ----------
    ref_peaks, ref_prominences : array-like of float
        Reference peak positions and their prominences (same length).
    measured_peaks, measured_prominences : array-like of float
        Measured peak positions and their prominences (same length).
    x_tolerance : float
        Maximum position difference for a match (exclusive).
    prominence_threshold : float
        Reference peaks at or below this prominence are ignored.
    penalty : float
        Multiplier for the deduction applied to unmatched reference peaks.

    Returns
    -------
    The summed score (``0`` if no reference peak passes the threshold).
    Lists and arrays are both accepted.
    """
    ref_peaks = np.asarray(ref_peaks, dtype=float)
    ref_prominences = np.asarray(ref_prominences, dtype=float)
    measured_peaks = np.asarray(measured_peaks, dtype=float)
    measured_prominences = np.asarray(measured_prominences, dtype=float)

    # Ignore reference peaks at or below the prominence threshold.
    keep = ref_prominences > prominence_threshold

    score = 0
    for ref_peak, ref_prominence in zip(ref_peaks[keep], ref_prominences[keep]):
        matches = np.abs(measured_peaks - ref_peak) < x_tolerance

        if matches.any():
            # Several measured peaks may fall in the window; compare with the most prominent.
            measured_prominence = measured_prominences[matches].max()
            score += ref_prominence * (1 - abs(measured_prominence - ref_prominence))
        else:
            score -= ref_prominence * penalty

    return score



def _fetch_reference_rows(filenames, chunk_size=500):
    """Return (filename, peak_xs, peak_prominences, mineral_name, elements, rruff_id) rows for *filenames*.

    Uses parameterised ``IN (?, ...)`` queries so that quotes in filenames
    cannot break the SQL. The names are sent in chunks to stay below SQLite's
    per-statement bound-parameter limit (999 in older builds).
    """
    rows = []
    for start in range(0, len(filenames), chunk_size):
        chunk = list(filenames[start:start + chunk_size])
        placeholders = ", ".join(["?"] * len(chunk))
        cursor.execute(
            "SELECT filename, peak_xs, peak_prominences, mineral_name, elements, rruff_id "
            f"FROM database_table WHERE filename IN ({placeholders})",
            chunk,
        )
        rows.extend(cursor.fetchall())
    return rows


def match_spectra(change=None):
    """Button callback: score every available reference spectrum against the measurement.

    The measured peaks (``measured_peaks`` / ``measured_prominences``, set by
    ``update_measured_spectra``) are optionally restricted to the regions
    box-selected on ``f1``. Within those regions the prominences are rescaled
    so the largest is 1.

    Each reference in ``available_items_box.options`` is scored with
    ``calculate_score``. The five best minerals are displayed (each shown with
    the ID/elements of its best-scoring spectrum), and the overall best match
    is added to the selection via ``select_item``. All messages go to
    ``out_widget``.
    """
    out_widget.clear_output()

    filenames = available_items_box.options

    if len(filenames) == 0:
        with out_widget:
            print("No reference spectra matching given filter criteria available to match.")
        return

    # Check for an upload before touching measured_peaks, which only exists
    # once update_measured_spectra has processed a file.
    if len(file_selector.value) == 0:
        with out_widget:
            print("Please select a measured spectrum to match.")
        return
    elif len(file_selector.value) > 1:
        with out_widget:
            print(f"{len(file_selector.value)} measured spectra selected. Matching only {file_selector.value[-1]['name']}")

    peaks = np.asarray(measured_peaks, dtype=float)
    prominences = np.asarray(measured_prominences, dtype=float)

    # Filter measured peaks by the selected areas (if a selection is present).
    selection_xs = get_selected_xs(f1)
    if selection_xs:
        # Force bool: an empty float array would be rejected as an index.
        peaks_in_selection = np.asarray(in_selection(selection_xs, peaks), dtype=bool)
        peaks = peaks[peaks_in_selection]
        prominences = prominences[peaks_in_selection]
        # Normalise within the selected regions (not in place, so int dtypes are safe too).
        if prominences.size > 0:
            prominences = prominences / prominences.max()

    rows = _fetch_reference_rows(filenames)
    peak_strings_by_filename = {row[0]: row[1:3] for row in rows}

    results_df = pd.DataFrame(rows, columns=["Filename", "Peak_xs", "Peak_prominences", "Mineral", "Elements", "ID"])

    scores = np.zeros(len(filenames))

    with out_widget:  # display progress bar
        for i, filename in enumerate(tqdm(filenames)):
            peak_xs_str, peak_prominences_str = peak_strings_by_filename[filename]

            ref_peaks = np.array(ast.literal_eval(peak_xs_str), dtype=float)
            ref_prominences = np.array(ast.literal_eval(peak_prominences_str), dtype=float)

            scores[i] = calculate_score(
                ref_peaks, ref_prominences, peaks, prominences,
                x_tolerance=x_tolerance_box.value, prominence_threshold=prominence_box.value,
            )

        # Attach scores by filename, not by position: the IN query does not
        # guarantee rows come back in the same order as `filenames`.
        score_by_filename = dict(zip(filenames, scores))
        results_df["Score"] = results_df["Filename"].map(score_by_filename)

        # Keep each mineral's best-scoring row. groupby().max() would take the
        # max of every column independently, pairing the top score with an
        # unrelated ID / Elements value.
        best_per_mineral = (
            results_df.sort_values("Score", ascending=False, kind="stable")
            .drop_duplicates(subset="Mineral")
            .set_index("Mineral")[["Elements", "ID", "Score"]]
        )

        display(best_per_mineral.head(5))

        best_match = filenames[int(np.argmax(scores))]
        select_item({'new': best_match})


# Observers: each callback runs when the named widget trait changes.
# Search/filter widgets refresh the database list.
search_box.observe(update_boxes, names='value')
elements_box.observe(update_boxes, names='value')
filter_box.observe(update_boxes, names='value')
polarisation_box.observe(update_boxes, names='value')
# Clicking an entry moves it between the two lists (and replots references).
available_items_box.observe(select_item, names='value')
selected_items_box.observe(deselect_item, names='value')
file_selector.observe(update_measured_spectra, names='value')
normalisation_box.observe(change_normalisation, names="value")
# Background and prominence settings affect both the measured and reference plots.
background_box.observe(update_measured_spectra, names="value")
background_box.observe(update_reference_spectra, names="value")
prominence_box.observe(update_measured_spectra, names="value")
prominence_box.observe(update_reference_spectra, names="value")
selection_box.observe(enable_disable_selection, names="value")
# NOTE(review): lambda_box and ratio_box have no observers, so new λ/ratio
# values only take effect at the next replot triggered by another widget -
# confirm this is intended.
match_button.on_click(match_spectra)

# Add dummy traces to figure so we can use the on_selection event
# (invisible, one per subplot row; their meta "Dummy" is not removed by the
# update_* functions, so f1.data[0] below stays this first dummy trace)
f1.add_scatter(x=[0], y=[0], visible=False, mode='markers', marker=dict(size=0), meta="Dummy", row=1, col=1)
f1.add_scatter(x=[0], y=[0], visible=False, mode='markers', marker=dict(size=0), meta="Dummy", row=2, col=1)

f1.data[0].on_selection(add_selected_xs)
f1.data[0].on_deselect(clear_selected_xs)

# Layout: search column | filter/selection column | measurement & matching controls
selection_UI = widgets.HBox(
    [widgets.VBox([search_box, elements_box, available_items_box], layout=widgets.Layout(width="100%")),
     widgets.VBox([filter_box, polarisation_box, selected_items_box], layout=widgets.Layout(width="100%")),
     widgets.VBox([file_selector, normalisation_box, lambda_box, ratio_box, background_box, prominence_box, x_tolerance_box, selection_box, match_button], layout=widgets.Layout(width="85%"))],
     layout=widgets.Layout(width='95%'))

# Populate the database list once with the default filters before displaying.
update_boxes()

display(selection_UI, f1, out_widget)

HBox(children=(VBox(children=(Text(value='', description='Search:', layout=Layout(width='95%'), placeholder='S…

FigureWidget({
    'data': [{'marker': {'size': 0},
              'meta': 'Dummy',
              'mode': 'markers',
              'type': 'scatter',
              'uid': '45856a7d-0c70-4ede-9342-0b572dc5de2e',
              'visible': False,
              'x': [0],
              'xaxis': 'x',
              'y': [0],
              'yaxis': 'y'},
             {'marker': {'size': 0},
              'meta': 'Dummy',
              'mode': 'markers',
              'type': 'scatter',
              'uid': '7ecb062e-5c93-4a5c-89e1-edb607a7b4b5',
              'visible': False,
              'x': [0],
              'xaxis': 'x2',
              'y': [0],
              'yaxis': 'y2'}],
    'layout': {'height': 700,
               'template': '...',
               'xaxis': {'anchor': 'y',
                         'domain': [0.0, 1.0],
                         'matches': 'x2',
                         'showticklabels': True,
                         'title': {'text': 'Wavenumber (cm⁻¹)'}},
          

Output()