In [None]:
#!/usr/bin/env python
# -*-coding:utf-8 -*-
# File    :   barcode_plot.py
# Time    :   22.06.2022
# Author  :   Daniel Weston
# Contact :   dtw545@bham.ac.uk

# imports and hard coding of locations
import numpy as np
from pathlib import Path
import pandas as pd
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import scipy.interpolate as interp
from scipy.signal import savgol_filter

import plotly.express as px
from tqdm import tqdm

# Sub-folders of `data_dir` read by the data-import cell; the same names are reused
# as axis labels for the confusion matrix and as target names in the classification report.
folders = ["c-glutamicum", "m-bovis", "r-erythropolis"]
# Root directory; spectra are globbed as <data_dir>/<folder>/*.txt.
data_dir = "split-data"
# Directory name intended for figures (not referenced by the cells visible in this notebook).
plot_dir = "figs"

# https://m-clark.github.io/data-processing-and-visualization/thinking_vis.html
# Shared colour sequence; the plotting cells pick colour[2 * i] for the i-th bacterium.
colour = px.colors.sequential.Plasma


In [None]:
# define function to get proper dpi
def dpi_scaler(fig: go.Figure, dpi: int = 600, print_width_mm: float = 170.0) -> float:
    """
    Compute the ``scale`` to use when saving a plotly figure so that it reaches a target dpi.

    The figure is assumed to be printed ``print_width_mm`` wide. The default of 170 mm
    is an A4 page (210 mm) with a 20 mm border on each side.

    Parameters
    ----------
    fig : go.Figure
        Figure to save. When its layout has no explicit width, 700 px is used.
    dpi : int, optional
        Target dpi.
    print_width_mm : float, optional
        Printed width of the figure in millimetres.

    Returns
    -------
    float
        Scale factor for figure to achieve target dpi.
    """
    width = fig.layout.width if fig.layout.width is not None else 700
    # 25.4 mm per inch: printed width in inches divided by the width the figure
    # would occupy at `dpi` without any scaling
    scale = (print_width_mm / 25.4) / (width / dpi)
    return scale


In [None]:
# Data import
# One entry per spectrum file is appended to each list below (kept in lock-step);
# the lists are then assembled into a DataFrame indexed by (bacteria, sample, repeat).
sample = []
repeat = []
bacteria = []
wavenumber = []
intensity = []
length = []
for folder in (pbar := tqdm(folders)):
    pbar.set_description(f"Reading {folder}...")
    # glob order is filesystem dependent; sort so that the row order (and everything
    # derived from it, e.g. the seeded train/test split) is reproducible
    files = sorted(Path(data_dir, folder).glob("*.txt"))
    for file in (pbar2 := tqdm(files, leave=False)):
        pbar2.set_description(f"Reading {file}")
        # the stem has three "_"-separated fields: <prefix>_<sample>_<repeat>
        _, sample_number, repeat_number = file.stem.split("_")
        # cast these numbers to int (only the part of the sample field before "-" is used)
        sample_number = int(sample_number.split("-")[0])
        repeat_number = int(repeat_number)
        # column 0 is stored as the x axis, column 1 as the intensity
        data = np.loadtxt(file)
        wavenumber.append(list(data[:, 0]))
        intensity.append(data[:, 1])
        sample.append(sample_number)
        repeat.append(repeat_number)
        # label with the folder being read; file.parts[1] only equals the folder
        # name while data_dir is a single path component
        bacteria.append(folder)
        length.append(data.shape[0])

if not intensity:
    # fail here with a clear message instead of deep inside the scaler/PCA cells
    raise FileNotFoundError(
        f"No *.txt spectra found under {data_dir!r} for folders {folders}"
    )

labels = pd.MultiIndex.from_arrays(
    [bacteria, sample, repeat], names=["bacteria", "sample", "repeat"]
)
# prune every column that contains a NaN
df = pd.DataFrame(intensity, index=labels).dropna(axis=1)

In [None]:
# Standardise every column of `df` to zero mean and unit variance.
# StandardScaler copies its input by default, so `df` is left untouched and no
# defensive deep copy is needed before scaling.
scaler = StandardScaler()
scaled_df = pd.DataFrame(
    scaler.fit_transform(df), columns=df.columns, index=df.index
)
scaled_df


In [None]:
# Use PCA for optimal component selection
# Every fit keeps as many components as needed to reach `target_var` explained variance.
target_var = 0.999

# 1) all spectra in a single fit (X_all is the fitted estimator; later cells read
#    its components_ and call fit_transform on it)
pca_all = PCA(n_components=target_var)
X_all = pca_all.fit(scaled_df)
explained_variance_all = X_all.explained_variance_

# 2) one independent fit per repeat; repeats are selected as 1, 2, 3
pca_repeats = [PCA(n_components=target_var) for _ in range(3)]
X_repeat = [
    estimator.fit(scaled_df.xs(repeat_id, level="repeat"))
    for repeat_id, estimator in enumerate(pca_repeats, start=1)
]
explained_variance_repeats = [fitted.explained_variance_ for fitted in X_repeat]

# 3) repeats averaged per (bacteria, sample) before fitting
avg_df = scaled_df.groupby(["bacteria", "sample"]).mean()
pca_avg = PCA(n_components=target_var)
X_avg = pca_avg.fit(avg_df)
explained_variance_avg = X_avg.explained_variance_

# npts: the largest number of retained components over all of the fits above
npts = max(
    explained_variance_all.shape[0],
    *(values.shape[0] for values in explained_variance_repeats),
    explained_variance_avg.shape[0],
)


In [None]:
# Scree plot
# Eigenvalue of each retained component for the three PCA variants.
scree_plot = go.Figure()
all_trace = go.Scatter(
    y=explained_variance_all, name="all values", mode="lines + markers"
)
# repeats are numbered from 1 (the per-repeat PCA selects level "repeat" == i + 1),
# so start the enumeration at 1 to keep the legend consistent with the data
repeat_traces = [
    go.Scatter(y=var, name=f"repeat {i} values", mode="lines + markers")
    for i, var in enumerate(explained_variance_repeats, start=1)
]
averaged_trace = go.Scatter(
    y=explained_variance_avg, name="averaged values", mode="lines + markers"
)
# horizontal reference at eigenvalue = 1, spanning the longest curve (npts points)
decision_line = go.Scatter(
    y=[1] * npts, mode="lines", line=dict(dash="dash", color="red"), showlegend=False
)
scree_plot.add_traces([all_trace, *repeat_traces, averaged_trace, decision_line])
scree_plot.update_xaxes(title=dict(text="Component"))
scree_plot.update_yaxes(title=dict(text="Eigenvalue"))
scree_plot


In [None]:
# variance ratio plot
# Cumulative share of the TOTAL variance explained by the first k components.
# explained_variance_ratio_ is already normalised by the total variance; dividing
# the retained eigenvalues by their own sum would force every curve to end at 100 %.
ratio_plot = go.Figure()
all_trace = go.Scatter(
    y=100 * np.cumsum(X_all.explained_variance_ratio_),
    name="all values",
    mode="lines + markers",
)
# repeats are numbered from 1 in the data, so label the traces the same way
repeat_traces = [
    go.Scatter(
        y=100 * np.cumsum(fitted.explained_variance_ratio_),
        name=f"repeat {i} values",
        mode="lines + markers",
    )
    for i, fitted in enumerate(X_repeat, start=1)
]
averaged_trace = go.Scatter(
    y=100 * np.cumsum(X_avg.explained_variance_ratio_),
    name="averaged values",
    mode="lines + markers",
)
ratio_plot.add_traces([all_trace, *repeat_traces, averaged_trace])
ratio_plot.update_xaxes(title=dict(text="Component"))
ratio_plot.update_yaxes(title=dict(text="Cumulative Variance Explained (%)"))
ratio_plot


In [None]:
# covariance heatmap
# Three figures: the all-data PCA, one panel per per-repeat PCA, and the averaged PCA.
n_repeats = len(X_repeat)
covariance_plots = [go.Figure(), make_subplots(rows=1, cols=n_repeats), go.Figure()]
all_trace = go.Heatmap(z=X_all.get_covariance(), zmin=-1, zmax=1)
repeat_traces = [go.Heatmap(z=X.get_covariance(), zmin=-1, zmax=1) for X in X_repeat]
averaged_trace = go.Heatmap(z=X_avg.get_covariance(), zmin=-1, zmax=1)
# the middle figure holds the per-repeat fits, so title it accordingly
titles = ["All", "Per Repeat", "Avg"]
# (height, width) per figure: 600 px square panels, side by side for the repeats
sizes = [(600, 600), (600, 600 * n_repeats), (600, 600)]
covariance_plots[0].add_trace(all_trace)
for i, trace in enumerate(repeat_traces):
    covariance_plots[1].add_trace(trace, row=1, col=i + 1)
covariance_plots[2].add_trace(averaged_trace)
for plot, title, (height, width) in zip(covariance_plots, titles, sizes):
    plot.update_layout(height=height, width=width, title_text=title)
    plot.show()


In [None]:
# Reduce the dataframe dimensionality
# Transposed PCA components: one row per original column of the scaled data,
# one column per retained principal component (PC1, PC2, ...).
pc_names = [f"PC{k}" for k in range(1, X_all.n_components_ + 1)]
reduced_components = pd.DataFrame(X_all.components_.T, columns=pc_names)

reduced_components


In [None]:
# Loadings plot
eigenvalues = X_all.explained_variance_.reshape(X_all.n_components_, 1)
# Loadings = components scaled by the square root of their eigenvalue.
# Rebuilt from the fitted PCA every time instead of multiplying the existing frame
# in place, so re-running this cell does not scale the loadings a second time.
reduced_components = pd.DataFrame(
    X_all.components_.T * np.sqrt(eigenvalues).T,
    columns=[f"PC{i+1}" for i in range(X_all.n_components_)],
)

loading_plot = go.Figure()
pc1_trace = go.Scatter(
    x=wavenumber[0], y=reduced_components["PC1"], mode="lines", name="PC1"
)

pc2_trace = go.Scatter(
    x=wavenumber[0], y=reduced_components["PC2"], mode="lines", name="PC2"
)

loading_plot.add_traces([pc1_trace, pc2_trace])
loading_plot.update_xaxes(title=dict(text="Raman Shift"))
loading_plot.update_yaxes(title=dict(text="Loadings"))
loading_plot


In [None]:
# PCA scatter
# Project every spectrum onto the principal components and plot PC1 against PC2.
score_values = X_all.fit_transform(scaled_df)
scores = pd.DataFrame(
    score_values,
    index=labels,
    columns=[f"PC{k}" for k in range(1, X_all.n_components_ + 1)],
)
markers = ["circle", "square", "triangle-up"]
scatter_plot = go.Figure()

# colour encodes the bacterium, marker symbol encodes the repeat
for i, bacterium in enumerate(np.unique(bacteria)):
    for j, rep in enumerate(np.unique(repeat)):
        subset = scores.xs((bacterium, rep), level=("bacteria", "repeat"))
        scatter_plot.add_trace(
            go.Scatter(
                x=subset["PC1"],
                y=subset["PC2"],
                mode="markers",
                marker={"color": colour[2 * i], "symbol": markers[j]},
                name=f"{bacterium} - rep {rep}",
            )
        )
scatter_plot.update_xaxes(title={"text": "PC1"})
scatter_plot.update_yaxes(title={"text": "PC2"})

scatter_plot


In [None]:
# LDA accuracy check
# 4-fold cross-validated accuracy of LDA on the PCA scores.
# NOTE(review): the scaler and the PCA are fitted on all spectra before the folds
# are split, so the held-out folds are not fully unseen - consider a Pipeline.
y = bacteria
X = X_all.fit_transform(scaled_df)
print(X.shape)
lda_scores = cross_val_score(LDA(), X, y, cv=4)
mean_accuracy = lda_scores.mean()
two_sigma = lda_scores.std() * 2
print("Accuracy: %0.4f (+/- %0.4f)" % (mean_accuracy, two_sigma))


In [None]:
# "Naive" LDA: fitted directly on the scaled spectra, without the PCA step
lda_naive = LDA()
X_lda_naive = lda_naive.fit_transform(scaled_df, bacteria)
# one discriminant axis fewer than there are classes
naive_axes = [f"LD{k}" for k in range(1, lda_naive.classes_.shape[0])]
lda_naive_df = pd.DataFrame(X_lda_naive, index=labels, columns=naive_axes)
lda_naive_df


In [None]:
lda_naive_plot = go.Figure()

# one marker cloud per bacterium in the LD1/LD2 plane of the naive LDA
for i, bacterium in enumerate(np.unique(bacteria)):
    points = lda_naive_df.xs(bacterium, level="bacteria")
    lda_naive_plot.add_trace(
        go.Scatter(
            x=points["LD1"],
            y=points["LD2"],
            mode="markers",
            marker={"color": colour[2 * i]},
            name=f"{bacterium}",
        )
    )
lda_naive_plot.update_xaxes(title={"text": "LD1"})
lda_naive_plot.update_yaxes(title={"text": "LD2"})

lda_naive_plot


In [None]:
# Okay now let's do some LDA!
# PCA-LDA: the discriminant analysis is fitted on the PCA scores X with labels y
lda = LDA()
X_lda = lda.fit_transform(X, y)
# one discriminant axis fewer than there are classes
ld_axes = [f"LD{k}" for k in range(1, lda.classes_.shape[0])]
lda_df = pd.DataFrame(X_lda, index=labels, columns=ld_axes)
lda_df


In [None]:
lda_plot = go.Figure()

# one marker cloud per bacterium in the LD1/LD2 plane of the PCA-LDA
for i, bacterium in enumerate(np.unique(bacteria)):
    points = lda_df.xs(bacterium, level="bacteria")
    lda_plot.add_trace(
        go.Scatter(
            x=points["LD1"],
            y=points["LD2"],
            mode="markers",
            marker={"color": colour[2 * i]},
            name=f"{bacterium}",
        )
    )
lda_plot.update_xaxes(title={"text": "LD1"})
lda_plot.update_yaxes(title={"text": "LD2"})

lda_plot


Applying LDA directly to the scaled spectra (the "naive" approach) gives much the same picture, *but* in general PCA-LDA is more robust because the principal components are orthogonal by construction, so we use the PCA-LDA approach for the confusion matrix and classification report below.

In [None]:
# test prediction quality of the model
# Hold out 40 % of the PCA scores, fit LDA on the rest and compare predictions
# with the true labels of the held-out part.
lda_test = LDA()
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.4, random_state=42
)
y_train = np.array(y_train)
y_test = np.array(y_test)
lda_test.fit(X_train, y_train)
# accuracy on the training split (not used by the plot below)
score = lda_test.score(X_train, y_train)
y_pred = lda_test.predict(X_test)
# pin the class order explicitly so rows/columns always line up with the
# `folders` tick labels, even if a class is missing from the test split
conf_mat = confusion_matrix(y_true=y_test, y_pred=y_pred, labels=folders)
confusion_plot = px.imshow(
    conf_mat, y=folders, x=folders, text_auto=True, aspect="auto"
)
confusion_plot.update_xaxes(title_text="Predicted")
confusion_plot.update_yaxes(title_text="Actual")

confusion_plot


In [None]:
# classification report
# `labels` fixes which class each entry of `target_names` refers to, instead of
# relying on the sorted label order matching the order of `folders`
print(
    classification_report(
        y_true=y_test, y_pred=y_pred, labels=folders, target_names=folders
    )
)


In [None]:
# Display the unscaled intensity DataFrame built in the data-import cell
df


In [None]:
# Per-bacterium mean and standard deviation of the unscaled spectra.
# Note: this rebinds `avg_df`, which the PCA cell used for the per-sample averages.
by_bacteria = df.groupby("bacteria")
avg_df = by_bacteria.mean()
std_df = by_bacteria.std()
# transposed so that each spectrum is a column
intensity_data = df.to_numpy().T


In [None]:
# spacing for evenly fitted spline is ~1.07 and scipy assumes delta = 1.0
# doesn't affect derivative size too much (compared to threshold)
# this cell plots one spectrum to check filter performance (0 - 299 for c-gluta, 300 - 599 for r-eryth and 600+ for m-bovis)
# Deliberately not named `X`: that name holds the PCA scores used by the LDA cells,
# and rebinding it here would break them if they are re-run afterwards.
spectrum = intensity_data[:, 601]
# These parameters are used later on so make sure this cell has run!!
order = 2
window = 5
fig = go.Figure()
intensity_trace = go.Scatter(
    x=wavenumber[1],
    y=spectrum,
    name="Raw",
)
# smoothed signal plus its first and second Savitzky-Golay derivatives
sg_trace = go.Scatter(
    x=wavenumber[1],
    y=savgol_filter(spectrum, window, polyorder=order),
    mode="lines",
    name="SG",
)
sg_traced1 = go.Scatter(
    x=wavenumber[1],
    y=savgol_filter(spectrum, window, polyorder=order, deriv=1),
    mode="lines",
    name="SG 1st derivative",
)
sg_traced2 = go.Scatter(
    x=wavenumber[1],
    y=savgol_filter(spectrum, window, polyorder=order, deriv=2),
    mode="lines",
    name="SG 2nd derivative",
)
fig.add_traces([intensity_trace, sg_trace, sg_traced1, sg_traced2])

fig


In [None]:
# apply savgol filter on data and determine thresholds for bacteria
# Row-wise second derivative of every spectrum (window/order come from the
# filter-check cell).
sg_df = df.apply(
    savgol_filter,
    axis=1,
    result_type="broadcast",
    window_length=window,
    polyorder=order,
    deriv=2,
)
# per-bacterium threshold: 7.5 % of the largest second-derivative value in the group
# (groupby sorts its keys, which matches the np.unique order used in the loop)
thresholds = sg_df.groupby("bacteria").max().to_numpy()
thresholds = (7.5 / 100) * np.max(thresholds, axis=1)
print(thresholds)
barcode_freq = make_subplots(
    rows=3,
    cols=1,
    subplot_titles=("c-glutamicum", "m-bovis", "r-erythropolis"),
    vertical_spacing=0.05,
)
for i, bacterium in enumerate(np.unique(bacteria)):
    # mean spectrum of the bacterium
    intensity_trace = go.Scatter(
        x=wavenumber[1],
        y=avg_df.loc[bacterium],
        mode="lines",
        line=dict(color="black"),
        showlegend=False,
    )

    # mean +/- one standard deviation, drawn as invisible lines with a grey fill
    upper = go.Scatter(
        x=wavenumber[1],
        y=(avg_df.loc[bacterium] + std_df.loc[bacterium]),
        mode="lines",
        fill="tonexty",
        fillcolor="rgba(68, 68, 68, 0.3)",
        opacity=0.3,
        line=dict(color="rgba(255,255,255,0)", width=0),
        hoverinfo="skip",
        showlegend=False,
    )

    lower = go.Scatter(
        x=wavenumber[1],
        y=(avg_df.loc[bacterium] - std_df.loc[bacterium]),
        mode="lines",
        fill="tonexty",
        fillcolor="rgba(68, 68, 68, 0.3)",
        opacity=0.3,
        line=dict(color="rgba(255,255,255,0)", width=0),
        hoverinfo="skip",
        showlegend=False,
    )

    barcode_freq.add_trace(intensity_trace, row=i + 1, col=1)
    barcode_freq.add_trace(upper, row=i + 1, col=1)
    barcode_freq.add_trace(lower, row=i + 1, col=1)
    # generate histogram data - https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2732026/
    # count rows in each column that's above threshold
    # Build the mask in one step instead of chained `sg_df.loc[b][mask] = 1/0`
    # assignments: those write to a temporary selection (no effect under pandas
    # Copy-on-Write), and the second pass could undo the first whenever the
    # threshold is larger than 1.
    above_threshold = sg_df.loc[bacterium] >= thresholds[i]
    hits = above_threshold.sum()

    # fraction of the maximum per-column hit count
    counts = hits / hits.max()
    hist_trace = go.Bar(
        x=wavenumber[1],
        y=counts,
        offset=0,
        showlegend=False,
        marker_color=colour[2 * i],
        marker_line_width=0,
    )
    barcode_freq.add_trace(hist_trace, row=i + 1, col=1)
barcode_freq.update_xaxes(title=dict(text="Raman Shift"), row=3, col=1)
barcode_freq.update_yaxes(title=dict(text="Intensity"), row=2, col=1)
barcode_freq.update_yaxes(range=[0, 1])
barcode_freq.update_layout(bargap=0.0, height=1000)

barcode_freq


In [None]:
# apply savgol filter on data and determine thresholds for bacteria
# Row-wise second derivative of every spectrum (window/order come from the
# filter-check cell).
sg_df = df.apply(
    savgol_filter,
    axis=1,
    result_type="broadcast",
    window_length=window,
    polyorder=order,
    deriv=2,
)
# per-bacterium threshold: 5 % of the largest second-derivative value in the group
# (groupby sorts its keys, which matches the np.unique order used in the loop)
# NOTE(review): the maximum is taken over signed values while the comparison below
# uses absolute values - confirm this is intended.
thresholds = sg_df.groupby("bacteria").max().to_numpy()
thresholds = (5 / 100) * np.max(thresholds, axis=1)
print(thresholds)
barcode_col = make_subplots(
    rows=3,
    cols=1,
    subplot_titles=("c-glutamicum", "m-bovis", "r-erythropolis"),
    vertical_spacing=0.05,
)
for i, bacterium in enumerate(np.unique(bacteria)):
    # mean spectrum of the bacterium, drawn in the bacterium's colour
    intensity_trace = go.Scatter(
        x=wavenumber[1],
        y=avg_df.loc[bacterium],
        mode="lines",
        line=dict(color=colour[2 * i]),
        showlegend=False,
    )

    # mean +/- one standard deviation, drawn as zero-width lines with a grey fill
    upper = go.Scatter(
        x=wavenumber[1],
        y=(avg_df.loc[bacterium] + std_df.loc[bacterium]),
        mode="lines",
        fill="tonexty",
        fillcolor="rgba(68, 68, 68, 0.3)",
        opacity=0.1,
        line=dict(color=colour[2 * i], width=0),
        hoverinfo="skip",
        showlegend=False,
    )

    lower = go.Scatter(
        x=wavenumber[1],
        y=(avg_df.loc[bacterium] - std_df.loc[bacterium]),
        mode="lines",
        fill="tonexty",
        fillcolor="rgba(68, 68, 68, 0.3)",
        opacity=0.3,
        line=dict(color=colour[2 * i], width=0),
        hoverinfo="skip",
        showlegend=False,
    )
    # here we use the `counts` value to determine bar colour for wavenumber - https://aocs.onlinelibrary.wiley.com/doi/full/10.1007/s11746-016-2808-7?saml_referrer
    # a key difference here is that they used the absolute second derivative value
    # their choice of 1% is a bit mental since our data isn't as exciting, so i used 5%
    # Build the mask in one step instead of chained `sg_df.loc[b][mask] = 1/0`
    # assignments: those write to a temporary selection (no effect under pandas
    # Copy-on-Write), and the second pass could undo the first whenever the
    # threshold is larger than 1.
    above_threshold = sg_df.loc[bacterium].abs() >= thresholds[i]
    hits = above_threshold.sum()

    # fraction of the maximum per-column hit count, mapped onto the grey scale
    counts = hits / hits.max()
    hist_trace = go.Bar(
        x=wavenumber[1],
        y=[1] * len(wavenumber[1]),
        offset=0,
        showlegend=False,
        marker_color=counts,
        marker_line_width=0,
        marker_colorscale="Greys",
    )
    barcode_col.add_trace(intensity_trace, row=i + 1, col=1)
    barcode_col.add_trace(upper, row=i + 1, col=1)
    barcode_col.add_trace(lower, row=i + 1, col=1)
    barcode_col.add_trace(hist_trace, row=i + 1, col=1)
barcode_col.update_xaxes(title=dict(text="Raman Shift"), row=3, col=1)
barcode_col.update_yaxes(title=dict(text="Intensity"), range=[0, 1], row=2, col=1)
barcode_col.update_yaxes(range=[0, 1])
barcode_col.update_layout(height=1000)
barcode_col
