In [180]:
import os
import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Star import from the project's NLS fitting utilities. Several names used later in this
# notebook are not defined or imported anywhere else in it (load_data, fit_polarization,
# polarization_lorentzian, fit_tau, f_tau, r2_score, and clr/cm used for colour mapping);
# they presumably come from here. NOTE(review): verify and prefer explicit imports.
from msc_project.utils.fit_NLS import *

# Apply matplotlib's 'ggplot' style sheet to every figure created after this point.
plt.style.use('ggplot')

In [181]:
# TODO: Multithreading

def fit_voltages(data_path, n_vals, n_tol, output_file):
    """Fit the Lorentzian polarization model to every voltage column of every CSV file.

    For each ``*.csv`` file in ``data_path`` the sample size is the text before the first
    space of the filename. For every nominal exponent ``n`` in ``n_vals``, each column after
    the first is fitted against ``data['Pulse Width']``. The fitted ``n`` is constrained to
    ``[n - n_tol, n + n_tol]``. Fits that raise are reported and skipped.

    One row per successful fit is written to ``output_file``, with columns
    size, voltage, n, A, omega, log_tlorentz, r2_score. Here ``n`` is the *fitted* exponent,
    not the nominal one.
    """
    columns = ["size", "voltage", "n", "A", "omega", "log_tlorentz", "r2_score"]
    records = []

    for filename in [name for name in os.listdir(data_path) if name.endswith('.csv')]:
        sample_size = filename.split(' ')[0]
        data = load_data(os.path.join(data_path, filename))

        for n in n_vals:
            # Parameter order: A, omega, log_tlorentz, n; only n is bounded.
            initial_guess = [0.2, 0.1, -7, n]
            lower = [-np.inf, -np.inf, -np.inf, n - n_tol]
            upper = [np.inf, np.inf, np.inf, n + n_tol]
            limits = (lower, upper)

            for voltage in data.columns[1:]:
                try:
                    popt = fit_polarization(data['Pulse Width'], data[voltage], type="lorentzian",
                                            p0=initial_guess, bounds=limits)
                except Exception as exc:
                    print(f"Couldn't fit {voltage}V (size {sample_size}): {exc}")
                    continue

                predicted = polarization_lorentzian(data['Pulse Width'], *popt)
                records.append({
                    "size": sample_size,
                    "voltage": voltage,
                    "n": popt[3],
                    "A": popt[0],
                    "omega": popt[1],
                    "log_tlorentz": popt[2],
                    "r2_score": r2_score(data[voltage], predicted),
                })
                print(f"Fit {voltage}V (size {sample_size}): {popt}")

    pd.DataFrame(records, columns=columns).to_csv(output_file, index=False)

def _tau_fit_setup(n, n_tol, fix_n_tau, fix_V0):
    """Return ``(p0, bounds)`` for fit_tau's parameters (t0, V0, n_tau) at nominal exponent ``n``."""
    p0 = [1e-5, 5, n]
    if fix_n_tau:
        # n_tau pinned to the nominal n within n_tol.
        bounds = ([-np.inf, -np.inf, n - n_tol], [np.inf, np.inf, n + n_tol])
    else:
        bounds = ([-np.inf, -np.inf, 0], [np.inf, np.inf, 4])
    # Explicit None/False check instead of truthiness, so that fix_V0=0 is honoured
    # rather than silently treated as "don't fix".
    if fix_V0 is not None and fix_V0 is not False:
        p0[1] = fix_V0
        bounds[0][1] = fix_V0 - 1e-5
        bounds[1][1] = fix_V0 + 1e-5
    return p0, bounds


def fit_tau_lorentz(data_path, n_vals, n_tol, fix_n_tau, exclude_voltages, output_file, fix_V0=False):
    """Fit the tau(V) model to the per-voltage log(tau_Lorentz) values for each size and n.

    Args:
        data_path: CSV written by ``fit_voltages`` (needs columns size, voltage, n, log_tlorentz).
        n_vals: nominal exponents. Voltage-fit rows are matched on their *fitted* n within
            ``1.1 * n_tol``.
        n_tol: tolerance the voltage fits were run with; also the half-width of the n_tau
            bound when ``fix_n_tau`` is set.
        fix_n_tau: if true, constrain n_tau to ``n +/- n_tol``; otherwise bound it to [0, 4].
        exclude_voltages: voltages to drop before fitting, or None to keep all.
        output_file: destination CSV with columns size, n, t0, V0, n_tau, r2_score
            (``n`` is the nominal exponent).
        fix_V0: None/False to fit V0 freely, otherwise a value V0 is held at (+/- 1e-5).

    Sizes/n with no matching rows are skipped. Fits that raise are reported and skipped.
    """
    fit_df = pd.read_csv(data_path)

    tau_fit_data = {"size": [], "n": [], "t0": [], "V0": [], "n_tau": [], "r2_score": []}
    for size in fit_df['size'].unique():
        size_df = fit_df[fit_df['size'] == size]
        print(f"Processing size: {size}")

        for n in n_vals:
            size_n_df = size_df[np.isclose(size_df['n'], n, atol=n_tol*1.1)]
            if size_n_df.empty:
                continue
            print(f"  n = {n}: {len(size_n_df)} entries")
            p0, bounds = _tau_fit_setup(n, n_tol, fix_n_tau, fix_V0)

            x = np.array(size_n_df['voltage'])
            y = np.array(size_n_df['log_tlorentz'])
            if exclude_voltages is not None:
                keep = np.isin(x, exclude_voltages, invert=True)
                x, y = x[keep], y[keep]

            try:
                popt = fit_tau(x, y, p0=p0, bounds=bounds)
            except Exception as e:
                print(f"  Couldn't fit {size} (n={n}): {e}")
                continue

            r2 = r2_score(y, np.log(f_tau(x, *popt)))

            tau_fit_data["size"].append(size)
            tau_fit_data["n"].append(n)
            tau_fit_data["t0"].append(popt[0])
            tau_fit_data["V0"].append(popt[1])
            tau_fit_data["n_tau"].append(popt[2])
            tau_fit_data["r2_score"].append(r2)

    tau_fit_df = pd.DataFrame(tau_fit_data)
    tau_fit_df.to_csv(output_file, index=False)

In [182]:
def plot_size_n(data_path, voltage_fit_path, tau_fit_path, size, n, n_tol, savefig=None,
                tau_fit_fixn_path=None, tau_fit_exclude_path=None, tau_fit_fixn_exclude_path=None):
    """Plot polarization fits and log(tau_Lorentz)-vs-voltage fits for one size and nominal n.

    Left panel: measured partial polarization vs. pulse duration for every voltage column of
    the size's data file, overlaid with the Lorentzian fit stored in ``voltage_fit_path``.
    Right panel: fitted log(tau_Lorentz) vs. voltage, with the tau-model curves from
    ``tau_fit_path`` and from whichever optional variant files are given.

    Args:
        data_path: directory of raw ``*.csv`` files. The file whose name starts with
            "<size> nm" is used.
        voltage_fit_path: CSV written by ``fit_voltages``.
        tau_fit_path: CSV written by ``fit_tau_lorentz``.
        tau_fit_fixn_path, tau_fit_exclude_path, tau_fit_fixn_exclude_path: optional further
            ``fit_tau_lorentz`` CSVs (None to omit).
        size: sample size in nm, matched against the ``size`` columns.
        n: nominal exponent. Rows are matched within ``1.1 * n_tol``.
        n_tol: tolerance the fits were run with.
        savefig: if given, save the figure there and close it; otherwise show it.

    Raises:
        ValueError: if no data file for ``size`` exists in ``data_path``.
    """
    # Prefix match: fit_voltages takes the size from the text before the first space, and a
    # substring test would let size 500 pick up a "1500 nm" file.
    size_csv_file = next(
        (f for f in sorted(os.listdir(data_path)) if f.endswith('.csv') and f.startswith(f"{size} nm")),
        None,
    )
    if size_csv_file is None:
        raise ValueError(f"Couldn't find data for size {size}")
    data = load_data(os.path.join(data_path, size_csv_file))

    # clr/cm are not imported at the top of this notebook; import them explicitly here.
    from matplotlib import cm, colors as clr

    fit_df = pd.read_csv(voltage_fit_path)
    tau_dfs = [
        pd.read_csv(tau_fit_path),
        None if tau_fit_fixn_path is None else pd.read_csv(tau_fit_fixn_path),
        None if tau_fit_exclude_path is None else pd.read_csv(tau_fit_exclude_path),
        None if tau_fit_fixn_exclude_path is None else pd.read_csv(tau_fit_fixn_exclude_path),
    ]
    notes = ["", ", fixed n", ", exclude low/high voltages", ", fixed n & exclude low/high voltages"]

    fig, axs = plt.subplots(1, 2, figsize=(20, 10))
    fig.suptitle(f"Size: {size} nm, n={n}")

    # Voltage-fit rows store the fitted n, which lies within n_tol of the nominal n.
    size_df = fit_df[fit_df['size'] == size]
    size_n_df = size_df[np.isclose(size_df['n'], n, atol=n_tol*1.1)]

    t = np.logspace(np.log10(data['Pulse Width'].min()), np.log10(data['Pulse Width'].max()), 1000)

    if not size_n_df.empty:
        voltage_cols = data.columns[1:]
        norm = clr.Normalize()
        cmap = plt.cm.plasma
        colors = cmap(norm(voltage_cols.astype(float)))
        cbar = plt.colorbar(cm.ScalarMappable(cmap=cmap, norm=norm), ax=axs[0])
        cbar.set_label("V")

        # Left: polarization vs. pulse duration with the Lorentzian fit per voltage.
        for col, c in zip(voltage_cols, colors):
            axs[0].scatter(data["Pulse Width"], data[col], color=c)
            row = size_n_df[size_n_df['voltage'] == float(col)].values
            if len(row) == 0:
                continue
            r2 = row[0][-1]
            # Columns between voltage and r2_score: n, A, omega, log_tlorentz (fit_voltages order).
            # Bound to n_fit so the nominal `n` argument is not overwritten.
            n_fit, A, omega, log_tlorentz = row[0][2:-1]
            p = polarization_lorentzian(t, A, omega, log_tlorentz, n_fit)
            label = f'{col}V: $A$={A:.2f}, $\\omega$={omega:.2f}, $\\log(\\tau_{{Lorentz}})$={log_tlorentz:.2f}, $n$={n_fit:.2f}, ($R^2$={r2:.3f})'
            axs[0].plot(t, p, color=c, label=label)

        # Right: log(tau_Lorentz) vs. voltage with each available tau-model fit.
        axs[1].scatter(size_n_df['voltage'], size_n_df['log_tlorentz'], label="$\\tau_{Lorentz}$ (exp.)")
        v = np.linspace(size_n_df['voltage'].min(), size_n_df['voltage'].max(), 1000)
        for df, note in zip(tau_dfs, notes):
            if df is None:
                continue
            row = df[(df['size'] == size) & np.isclose(df['n'], n, atol=n_tol*1.1)].values
            if len(row) == 0:
                continue
            r2 = row[0][-1]
            t0, V0, n_tau = row[0][2:-1]
            label = f'$\\log(\\tau_{{Lorentz}})$ (fit{note}): $\\tau_0$={t0:.2e}, $V_0$={V0:.2f}, $n$={n_tau:.2f}, ($R^2$={r2:.3f})'
            axs[1].plot(v, np.log(f_tau(v, t0, V0, n_tau)), label=label)

    axs[0].set(xlabel='Pulse duration', ylabel='Partial polarization', xscale='log', ylim=(0, 1))
    axs[0].legend()
    axs[1].set(xlabel='Voltage', ylabel='$\\log(\\tau_{Lorentz})$')
    axs[1].legend()
    plt.tight_layout()

    if savefig is not None:
        fig.savefig(savefig)
        # This function is called in a loop; close saved figures so they don't accumulate
        # (memory, matplotlib's too-many-open-figures warning, dumped inline output).
        plt.close(fig)
    else:
        plt.show()

In [183]:
# Paths may be overridden through environment variables when running on another machine.
DATA_PATH = os.environ.get('NLS_DATA_PATH', '/scratch/msc24h18/nls_fit/data/Data_for_NLS')
RESULTS_DIR = os.environ.get('NLS_RESULTS_DIR', '/scratch/msc24h18/nls_fit/results/nls_model')

n_vals = [1.5, 1.75, 2, 2.25, 2.5]
n_tol = 0.01

# The voltage fits take 20-40 minutes to run, depending on the computer.
# When they are not re-run, reuse the most recent results directory that already contains
# them. Always creating a fresh timestamped directory would leave voltage_fits_path (and
# every cell that reads it) pointing at a file that does not exist.
RUN_VOLTAGE_FITS = False

if RUN_VOLTAGE_FITS:
    results_subdir = os.path.join(RESULTS_DIR, datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
else:
    existing_entries = os.listdir(RESULTS_DIR) if os.path.isdir(RESULTS_DIR) else []
    previous_runs = sorted(
        d for d in existing_entries
        if os.path.isfile(os.path.join(RESULTS_DIR, d, 'voltage_fits.csv'))
    )
    if not previous_runs:
        raise FileNotFoundError(
            f"No existing voltage_fits.csv under {RESULTS_DIR}; set RUN_VOLTAGE_FITS = True")
    # Run directories are named %Y%m%d_%H%M%S, so lexical order is chronological.
    results_subdir = os.path.join(RESULTS_DIR, previous_runs[-1])

voltage_fits_path = os.path.join(results_subdir, 'voltage_fits.csv')
tau_fits_path = os.path.join(results_subdir, 'tau_fits.csv')
tau_fits_path_fixn = os.path.join(results_subdir, 'tau_fits_fixn.csv')
tau_fits_path_exclude3 = os.path.join(results_subdir, 'tau_fits_exclude3.csv')
tau_fits_path_fixn_exclude3 = os.path.join(results_subdir, 'tau_fits_fixn_exclude3.csv')

os.makedirs(results_subdir, exist_ok=True)

if RUN_VOLTAGE_FITS:
    fit_voltages(DATA_PATH, n_vals, n_tol, voltage_fits_path)

In [None]:
# None lets V0 float in every tau fit; set a number (e.g. 5) to hold V0 at that value.
fix_V0 = None
# Lowest/highest voltages, dropped in the *_exclude3 variants.
exclude_voltages = [1, 1.25, 1.5, 4.5, 4.75, 5]

# plot_size_n below reads all four variant files, so compute all of them with the same
# fix_V0. The three variants used to be commented out, with a different hard-coded V0.
tau_fit_variants = [
    (False, None, tau_fits_path),
    (True, None, tau_fits_path_fixn),
    (False, exclude_voltages, tau_fits_path_exclude3),
    (True, exclude_voltages, tau_fits_path_fixn_exclude3),
]
for fixed_n, excluded, output_path in tau_fit_variants:
    fit_tau_lorentz(voltage_fits_path, n_vals, n_tol, fix_n_tau=fixed_n,
                    exclude_voltages=excluded, output_file=output_path, fix_V0=fix_V0)

In [None]:
sizes = [500, 600, 800, 1000, 5000, 10000, 20000, 30000]
for size in sizes:
    # Store per-size figures with the run that produced them, like the summary figures.
    # A shared directory under RESULTS_DIR would be overwritten by every run.
    size_dir = os.path.join(results_subdir, f"{size}nm")
    os.makedirs(size_dir, exist_ok=True)

    for n in n_vals:
        plot_size_n(DATA_PATH, voltage_fits_path, tau_fits_path, size, n, n_tol,
                    savefig=os.path.join(size_dir, f"{size}nm_n{n}.png"),
                    tau_fit_fixn_path=tau_fits_path_fixn,
                    tau_fit_exclude_path=tau_fits_path_exclude3,
                    tau_fit_fixn_exclude_path=tau_fits_path_fixn_exclude3)
        print(f"Plotted {size}nm n={n}")

In [185]:
def plot_tau_vs_size(filename, suptitle="", minsize=0, savefig=None):
    """Plot fitted tau_0 (left, log scale) and V_0 (right) against size, one line per n.

    Args:
        filename: CSV written by ``fit_tau_lorentz`` (uses columns size, n, t0, V0).
        suptitle: figure title.
        minsize: only rows with size >= minsize are plotted.
        savefig: optional path the figure is saved to before it is shown.
    """
    tau_df = pd.read_csv(filename).sort_values(by='size', ignore_index=True)

    fig, (ax_tau, ax_v0) = plt.subplots(1, 2, figsize=(14, 7))
    fig.suptitle(suptitle)

    large_enough = tau_df['size'] >= minsize
    for n in sorted(tau_df['n'].unique()):
        subset = tau_df[(tau_df['n'] == n) & large_enough]
        label = f"n={n}"
        ax_tau.plot(subset['size'], subset['t0'], label=label, marker='o')
        # V0 is drawn rounded to the nearest integer.
        ax_v0.plot(subset['size'], round(subset['V0']), label=label, marker='o')

    ax_tau.set(xlabel='Size (nm)', ylabel='$\\tau_0$', yscale='log')
    ax_tau.legend()
    ax_v0.set(xlabel='Size (nm)', ylabel='$V_0$')
    ax_v0.legend()
    plt.tight_layout()

    if savefig is not None:
        plt.savefig(savefig)
    plt.show()

In [None]:
# One tau_0/V_0-vs-size summary per tau-fit variant; variants without a results file are skipped.
tau_vs_size_variants = [
    (tau_fits_path, "", ""),
    (tau_fits_path_fixn, " (fixed n)", "_fixn"),
    (tau_fits_path_exclude3, " (exclude low/high voltages)", "_exclude3"),
    (tau_fits_path_fixn_exclude3, " (fixed n, exclude low/high voltages)", "_fixn_exclude3"),
]
for variant_path, title_note, file_suffix in tau_vs_size_variants:
    if not os.path.exists(variant_path):
        print(f"Skipping {variant_path}: file not found")
        continue
    plot_tau_vs_size(variant_path, suptitle=f"Optimal $\\tau_0$ and $V_0$ vs. size{title_note}", minsize=1000,
                     savefig=os.path.join(results_subdir, f"tauV0_vs_size{file_suffix}.png"))

In [None]:
def plot_r2_vs_n(voltages_filename, tau_filename, exclude_voltages=None, minsize=0, suptitle="", savefig=None,
                 n_tol=0.01):
    """Plot fit quality against n, one line per sample size.

    Left panel: average R^2 of the polarization (voltage) fits per nominal n, computed as the
    sum of scores divided by the number of fits. Fits are matched on their fitted n within
    ``1.1 * n_tol``. Right panel: R^2 of the log(tau_Lorentz)-vs-voltage fit per n.

    Args:
        voltages_filename: CSV written by ``fit_voltages``.
        tau_filename: CSV written by ``fit_tau_lorentz``. Its n values and sizes define what
            is plotted.
        exclude_voltages: voltages left out of the polarization-fit average (None keeps all).
        minsize: only sizes >= minsize are plotted.
        suptitle: figure title.
        savefig: optional path the figure is saved to before it is shown.
        n_tol: tolerance the voltage fits were run with. This used to be read implicitly from
            a notebook-global ``n_tol``; the default equals the value set in the config cell.
    """
    volt_df = pd.read_csv(voltages_filename)
    if exclude_voltages is not None:
        volt_df = volt_df[~volt_df['voltage'].isin(exclude_voltages)]
    volt_df = volt_df[volt_df['size'] >= minsize]

    tau_df = pd.read_csv(tau_filename)
    # Sort by n so each size's right-panel line is drawn left to right (no inplace on a slice).
    tau_df = tau_df[tau_df['size'] >= minsize].sort_values(by='n', ignore_index=True)

    n_vals = sorted(tau_df['n'].unique())
    sizes = sorted(tau_df['size'].unique())

    fig, axs = plt.subplots(1, 2, figsize=(14, 7))
    fig.suptitle(suptitle)

    for i, size in enumerate(sizes):
        color = f"C{i}"

        # Left: average polarization-fit R^2 per n.
        size_volt_df = volt_df[volt_df['size'] == size]
        r2_avg = []
        for n in n_vals:
            scores = size_volt_df.loc[np.isclose(size_volt_df['n'], n, atol=n_tol*1.1), 'r2_score']
            # NaN when there are no fits for this n (avoids a 0/0 RuntimeWarning). NaN scores
            # count in the denominator but not the sum, as before.
            r2_avg.append(scores.sum() / len(scores) if len(scores) else np.nan)
        axs[0].plot(n_vals, r2_avg, label=f"{size} nm", marker='o', color=color)

        # Right: tau-fit R^2 per n.
        size_tau_df = tau_df[tau_df['size'] == size]
        axs[1].plot(size_tau_df['n'], size_tau_df['r2_score'], label=f"{size} nm", marker='o', color=color)

    axs[0].set(xlabel = 'n', ylabel = '[Normalized] Sum of $R^2$ scores for polarization fit')
    axs[0].legend()
    axs[1].set(xlabel = 'n', ylabel = '$R^2$ score for $\\log(\\tau_{{Lorentz}})=f(1/V)$ fit')
    axs[1].legend()
    plt.tight_layout()
    if savefig is not None:
        plt.savefig(savefig)
    plt.show()

In [None]:
exclude_voltages = [1, 1.25, 1.5, 4.5, 4.75, 5]
# Right panel uses the fixed-n, excluded-voltage tau fits, matching this figure's title.
# The V0 note is derived from fix_V0 instead of a hard-coded value.
v0_note = "V0 free" if (fix_V0 is None or fix_V0 is False) else f"V0={fix_V0}V"
plot_r2_vs_n(voltage_fits_path, tau_fits_path_fixn_exclude3, exclude_voltages, minsize=1000,
             suptitle=f"Quality of fit vs n (excluding low/high voltages) ({v0_note})",
             savefig=os.path.join(results_subdir, "r2_vs_n.png"))