In [18]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import re

# Function to read the data
def read_data(file_path):
    data = pd.read_csv(file_path, delimiter='\t')
    data.columns = data.columns.str.lower().str.strip()
    return data

# Function to filter correlations
def filter_correlation(data, target='kj_charge_pc', threshold=0.2):
    if target not in data.columns:
        raise ValueError(f"Target column '{target}' not found in the data. Available columns are: {data.columns.tolist()}")

    correlations = data.corr()[target].reset_index()
    correlations.columns = ['Parameter', 'Correlation']
    correlations['AbsCorrelation'] = correlations['Correlation'].abs()
    filtered_correlations = correlations[correlations['AbsCorrelation'] > threshold]
    high_corr_columns = filtered_correlations['Parameter'].tolist()

    if target not in high_corr_columns:
        high_corr_columns.append(target)

    filtered_data = data[high_corr_columns]
    return filtered_correlations, filtered_data

# Function to plot correlations
def plot_correlations(correlations):
    if correlations.empty:
        raise ValueError("No correlations above the specified threshold to plot.")

    plt.figure(figsize=(10, 8))
    correlations.sort_values('AbsCorrelation', inplace=True)
    plt.barh(correlations['Parameter'], correlations['AbsCorrelation'], color='skyblue')
    plt.xlabel('Absolute Correlation Coefficient')
    plt.title('Correlation with Target Variable')
    plt.tight_layout()
    plt.show()

# Function to filter columns by name and clean data
def filter_and_clean_data(data, exclude_keywords, target='kj_charge_pC'):
    target = target.lower().strip()

    if target not in data.columns:
        raise ValueError(f"Target column '{target}' not found in the data. Available columns are: {data.columns.tolist()}")

    filtered_columns = []
    for col in data.columns:
        if col == target:
            filtered_columns.append(col)
            continue

        if not any(re.search(keyword, col, re.IGNORECASE) for keyword in exclude_keywords):
            filtered_columns.append(col)

    filtered_data = data[filtered_columns]
    filtered_data = filtered_data.replace([np.inf, -np.inf], np.nan)
    cleaned_data = filtered_data.dropna()

    if target not in cleaned_data.columns:
        raise ValueError(f"Target column '{target}' was lost during cleaning. It's possible all rows with target were invalid.")

    return cleaned_data


file_path = "/Users/samuelbarber/Desktop/s6.txt"
data = read_data(file_path)

target_column = 'kj_charge_pC'  # make sure this is the exact name, considering case sensitivity
filtered_correlations, data_high_corr = filter_correlation(data, target=target_column)

plot_correlations(filtered_correlations)

exclude_keywords = ['phosphor', 'magspec', 'ebeam', 'shot', 'magcam', 'picoscope',]
cleaned_data = filter_and_clean_data(data_high_corr, exclude_keywords, target='kj_charge_pC')
# print(cleaned_data.head())



ValueError: Target column 'kj_charge_pC' not found in the data. Available columns are: ['u_dg645_shotcontrol shot #', 'u_dg645_shotcontrol delay.ch a', 'u_dg645_shotcontrol amplitude.ch ab', 'u_dg645_shotcontrol delay.ch b', 'u_dg645_shotcontrol delay.ch g', 'u_modeimageresp shot #', 'u_modeimageresp position.axis 1 alias:modeimager', 'u_modeimageresp position.axis 3 alias:probedelay', 'u_modeimageresp position.axis 2 alias:jetblade', 'u_bcavehallprobe shot #', 'u_bcavehallprobe field', 'u_bcavehallprobe rawfield', 'u_compaerotech shot #', 'u_compaerotech position.axis1 alias:grating separation (um)', 'u_compaerotech position.axis2 alias:grating1 angle', 'u_compaerotech position.axis3 alias:grating2 angle', 'u_hamaspectro shot #', 'u_hamaspectro lambda_b', 'u_hamaspectro lambda_r', 'u_hamaspectro lambda_width (sd)', 'u_hamaspectro meanwavelength', 'u_esp_jetxyz shot #', 'u_esp_jetxyz position.axis 1 alias:jet_x (mm)', 'u_esp_jetxyz position.axis 2 alias:jet_y (mm)', 'u_esp_jetxyz position.axis 3 alias:jet_z (mm)', 'u_hp_daq shot #', 'u_hp_daq analogoutput.channel 1 alias:pressurecontrolvoltage', 'u_pxi_slow shot #', 'u_pxi_slow measurements.analysis1 alias:ontargetenergy', 'u_pxi_slow measurements.analysis2 alias:gaia energy', 'u_pxi_slow calculations.linked.analysis1', 'u_pxi_slow calculations.linked.analysis2', 'u_pxi_slow calculations.method.analysis1', 'u_pxi_slow calculations.method.analysis2', 'u_pxi_slow measurements.analysis4', 'u_traserver02 shot #', 'u_traserver02 position.axis 3 alias:beamline_out-x', 'u_traserver02 position.axis 4 alias:beamline_out-y', 'u_traserver02 position.axis 5 alias:beamline_in-x', 'u_traserver02 position.axis 6 alias:beamline_in-y', 'u_hexapod shot #', 'u_hexapod uangle', 'u_hexapod vangle', 'u_hexapod wangle', 'u_hexapod xpos', 'u_hexapod ypos', 'u_hexapod zpos', 'u_pxi_fast2 shot #', 'u_pxi_fast2 calculations.method.analysis1', 'u_pxi_fast2 calculations.method.analysis2', 'u_pxi_fast2 calculations.linked.analysis1', 'u_pxi_fast2 calculations.linked.analysis2', 'u_pxi_fast2 measurements.analysis1', 'u_pxi_fast2 measurements.analysis2', 'u_bcavepicoserver shot #', 'u_bcavepicoserver position.axis 111 alias:jetbladeheight', 'uc_alineebeam3 shot #', 'uc_alineebeam3 exposure', 'uc_alineebeam3 triggerdelay', 'uc_alineebeam3 2ndmomw0x45', 'uc_alineebeam3 2ndmomw0x', 'uc_alineebeam3 2ndmomw0y45', 'uc_alineebeam3 2ndmomw0y', 'uc_alineebeam3 centroidx', 'uc_alineebeam3 centroidy', 'uc_alineebeam3 maxcounts', 'uc_alineebeam3 meancounts', 'uc_chicaneslit shot #', 'uc_chicaneslit exposure', 'uc_chicaneslit triggerdelay', 'uc_probe shot #', 'uc_probe exposure', 'uc_probe triggerdelay', 'uc_alineebeam1 shot #', 'uc_alineebeam1 exposure', 'uc_alineebeam1 triggerdelay', 'uc_diagnosticsphosphor shot #', 'uc_diagnosticsphosphor triggerdelay', 'uc_diagnosticsphosphor exposure', 'uc_diagnosticsphosphor 2ndmomw0x45', 'uc_diagnosticsphosphor 2ndmomw0x', 'uc_diagnosticsphosphor 2ndmomw0y45', 'uc_diagnosticsphosphor 2ndmomw0y', 'uc_diagnosticsphosphor centroidx', 'uc_diagnosticsphosphor centroidy', 'uc_diagnosticsphosphor maxcounts', 'uc_diagnosticsphosphor meancounts', 'uc_modeimager shot #', 'uc_modeimager exposure alias:exposure', 'uc_modeimager triggerdelay', 'uc_modeimager centroidx', 'uc_modeimager centroidy', 'uc_modeimager 2ndmomw0x45', 'uc_modeimager 2ndmomw0x', 'uc_modeimager 2ndmomw0y45', 'uc_modeimager 2ndmomw0y', 'uc_modeimager maxcounts', 'uc_tc_phosphor shot #', 'uc_tc_phosphor exposure', 'uc_tc_phosphor triggerdelay', 'uc_tc_phosphor 2ndmomw0x45', 'uc_tc_phosphor 2ndmomw0x', 'uc_tc_phosphor 2ndmomw0y45', 'uc_tc_phosphor 2ndmomw0y', 'uc_tc_phosphor centroidx', 'uc_tc_phosphor centroidy', 'uc_tc_phosphor maxcounts', 'uc_tc_phosphor meancounts', 'uc_amp3_ir_input shot #', 'uc_amp3_ir_input 2ndmomw0x45', 'uc_amp3_ir_input 2ndmomw0x', 'uc_amp3_ir_input 2ndmomw0y45', 'uc_amp3_ir_input 2ndmomw0y', 'uc_amp3_ir_input centroidx alias:uc_amp3_ir_input_x', 'uc_amp3_ir_input centroidy alias:uc_amp3_ir_input_y', 'uc_amp3_ir_input maxcounts alias:uc_amp3_ir_input_max', 'uc_amp3_ir_input exposure', 'uc_amp3_ir_input meancounts alias:uc_amp3_ir_input_mean', 'uc_amp4_ir_input shot #', 'uc_amp4_ir_input 2ndmomw0x45', 'uc_amp4_ir_input 2ndmomw0x', 'uc_amp4_ir_input 2ndmomw0y45', 'uc_amp4_ir_input 2ndmomw0y', 'uc_amp4_ir_input centroidx alias:uc_amp4_ir_input_x', 'uc_amp4_ir_input centroidy alias:uc_amp4_ir_input_y', 'uc_amp4_ir_input maxcounts alias:uc_amp4_ir_input_max', 'uc_amp4_ir_input exposure', 'uc_amp4_ir_input meancounts alias:uc_amp4_ir_input_mean', 'uc_amp4_ir_output shot #', 'uc_amp4_ir_output 2ndmomw0x', 'uc_amp4_ir_output meancounts', 'uc_amp4_ir_output 2ndmomw0x45', 'uc_amp4_ir_output 2ndmomw0y45', 'uc_amp4_ir_output 2ndmomw0y', 'uc_amp4_ir_output centroidx', 'uc_amp4_ir_output centroidy', 'uc_amp4_ir_output maxcounts', 'uc_expanderin1_pulsed shot #', 'uc_expanderin1_pulsed 2ndmomw0x45', 'uc_expanderin1_pulsed 2ndmomw0x', 'uc_expanderin1_pulsed 2ndmomw0y45', 'uc_expanderin1_pulsed 2ndmomw0y', 'uc_expanderin1_pulsed centroidx', 'uc_expanderin1_pulsed centroidy', 'uc_expanderin1_pulsed exposure', 'uc_expanderin1_pulsed maxcounts', 'uc_expanderin1_pulsed meancounts', 'u_velmex shot #', 'u_velmex position', 'u_ssr_2 shot #', 'u_ssr_2 digitaloutput.channel 6 alias:visaalignmentdiode', 'u_ssr_2 digitaloutput.channel 7 alias:alineebeam1', 'u_ssr_2 digitaloutput.channel 0 alias:emq1polarityswitch', 'u_visaplungers shot #', 'u_visaplungers digitaloutput.channel 0 alias:visaplunger8', 'u_visaplungers digitaloutput.channel 1 alias:visaplunger7', 'u_visaplungers digitaloutput.channel 2 alias:visaplunger6', 'u_visaplungers digitaloutput.channel 3 alias:visapunger5', 'u_visaplungers digitaloutput.channel 4 alias:visaplunger4', 'u_visaplungers digitaloutput.channel 5 alias:visaplunger3', 'u_visaplungers digitaloutput.channel 6 alias:visaplunger2', 'u_visaplungers digitaloutput.channel 7 alias:visaplunger1', 'uc_visaebeam1 shot #', 'uc_visaebeam1 exposure', 'uc_visaebeam2 shot #', 'uc_visaebeam2 exposure', 'uc_visaebeam4 shot #', 'uc_visaebeam4 exposure', 'uc_visaebeam3 shot #', 'uc_visaebeam3 exposure', 'uc_visaebeam6 shot #', 'uc_visaebeam6 exposure', 'uc_visaebeam5 shot #', 'uc_visaebeam5 exposure', 'uc_visaebeam7 shot #', 'uc_visaebeam7 exposure', 'uc_visaebeam8 shot #', 'uc_visaebeam8 exposure', 'uc_amp2_ir_input shot #', 'uc_amp2_ir_input exposure', 'uc_amp2_ir_input centroidx alias:uc_amp2_ir_input_x', 'uc_amp2_ir_input centroidy alias:uc_amp2_ir_input_y', 'uc_amp2_ir_input fwhmx', 'uc_amp2_ir_input fwhmy', 'uc_amp2_ir_input 2ndmomw0x45', 'uc_amp2_ir_input 2ndmomw0x', 'uc_amp2_ir_input 2ndmomw0y45', 'uc_amp2_ir_input 2ndmomw0y', 'uc_amp2_ir_input maxcounts alias:uc_amp2_ir_input_max', 'uc_amp2_ir_input meancounts alias:uc_amp2_ir_input_mean', 'uc_amp2depletion_south shot #', 'uc_amp2depletion_south exposure', 'uc_amp2depletion_south 2ndmomw0x', 'uc_amp2depletion_south 2ndmomw0y45', 'uc_amp2depletion_south 2ndmomw0y', 'uc_amp2depletion_south 2ndmomw0x45', 'uc_amp2depletion_south centroidx', 'uc_amp2depletion_south centroidy', 'uc_amp2depletion_south maxcounts', 'uc_amp2depletion_south meancounts', 'uc_amp3depletion_south shot #', 'uc_amp3depletion_south exposure', 'uc_amp3depletion_south meancounts', 'uc_amp3depletion_south 2ndmomw0x45', 'uc_amp3depletion_south 2ndmomw0x', 'uc_amp3depletion_south 2ndmomw0y45', 'uc_amp3depletion_south 2ndmomw0y', 'uc_amp3depletion_south centroidx', 'uc_amp3depletion_south centroidy', 'uc_amp3depletion_south maxcounts', 'uc_amp4depletion_south shot #', 'uc_amp4depletion_south exposure', 'uc_amp4depletion_south 2ndmomw0x45', 'uc_amp4depletion_south 2ndmomw0y45', 'uc_amp4depletion_south 2ndmomw0y', 'uc_amp4depletion_south centroidx', 'uc_amp4depletion_south centroidy', 'uc_amp4depletion_south maxcounts', 'uc_amp4depletion_south meancounts', 'uc_amp4depletion_south 2ndmomw0x', 'uc_phosphor1 shot #', 'uc_phosphor1 exposure', 'uc_phosphor1 2ndmomw0x', 'uc_phosphor1 2ndmomw0x45', 'uc_phosphor1 2ndmomw0y45', 'uc_phosphor1 2ndmomw0y', 'uc_phosphor1 centroidx', 'uc_phosphor1 centroidy', 'uc_phosphor1 triggerdelay', 'uc_phosphor1 meancounts', 'uc_phosphor1 maxcounts', 'uc_undulatorrad1 shot #', 'uc_undulatorrad1 exposure', 'uc_undulatorrad2 shot #', 'uc_undulatorrad2 exposure', 'uc_ghostfocus shot #', 'uc_ghostfocus centroidx', 'uc_ghostfocus centroidy', 'uc_ghostfocus 2ndmomw0x45', 'uc_ghostfocus 2ndmomw0x', 'uc_ghostfocus 2ndmomw0y45', 'uc_ghostfocus 2ndmomw0y', 'uc_ghostfocus fwhmx', 'uc_ghostfocus fwhmy', 'uc_ghostfocus maxcounts', 'uc_ghostfocus meancounts', 'uc_ghostupstream shot #', 'uc_ghostupstream 2ndmomw0x45', 'uc_ghostupstream 2ndmomw0x', 'uc_ghostupstream 2ndmomw0y45', 'uc_ghostupstream 2ndmomw0y', 'uc_ghostupstream centroidx', 'uc_ghostupstream centroidy', 'uc_ghostupstream fwhmx', 'uc_ghostupstream fwhmy', 'uc_ghostupstream maxcounts', 'uc_ghostupstream meancounts', 'u_zaber shot #', 'u_zaber position.ch1 alias:phasicsfocusstage', 'u_zaber position.ch2 alias:ghostfocusstage', 'u_highland shot #', 'u_highland channel1.delay', 'u-picoscope5245d shot #', 'u-picoscope5245d measurements.analysis1', 'u-picoscope5245d measurements.analysis2', 'u-picoscope5245d measurements.analysis3', 'u-picoscope5245d scopetrace.channel0', 'uc_alineebeam2 shot #', 'uc_alineebeam2 exposure', 'u_probecamstage shot #', 'u_probecamstage position', 'u_plc shot #', 'u_plc do.ch1 alias:oap-in-2-filter', 'u_traserver03 shot #', 'u_traserver03 position.axis 1 alias:bcavein-x', 'u_traserver03 position.axis 2 alias:bcavein-y', 'u_traserver03 position.axis 3 alias:oapin1-x', 'u_traserver03 position.axis 4 alias:oapin1-y', 'u_traserver03 position.axis 5 alias:dmsurface-x', 'u_traserver03 position.axis 6 alias:dmsurface-y', 'u_vacuumgauge shot #', 'u_vacuumgauge ai_mean.channel 0 alias:b-cave', 'u_vacuumgauge ai_mean.channel 1 alias:a-cave', 'u_vacuumgauge ai_mean.channel 2', 'u_vacuumgauge ai_mean.channel 3', 'u_chicanemotors shot #', 'u_chicanemotors ch1ax1stepsize', 'uc_topview shot #', 'uc_topview centroidx', 'uc_topview centroidy', 'uc_topview maxcounts', 'uc_topview meancounts', 'u_phasicsfilecopy shot #', 'u_phasicsfilecopy numfiles', 'u_1hzshiftedbox shot #', 'u_1hzshiftedbox delay.channel_c alias:gaia lamp timing', 'uc_gaiamode shot #', 'uc_gaiamode maxcounts', 'uc_gaiamode meancounts', 'uc_gaiamode centroidx', 'uc_gaiamode centroidy', 'uc_oapin2 shot #', 'uc_oapin2 2ndmomw0x45', 'uc_oapin2 2ndmomw0x', 'uc_oapin2 2ndmomw0y45', 'uc_oapin2 2ndmomw0y', 'uc_oapin2 centroidx', 'uc_oapin2 centroidy', 'uc_oapin2 maxcounts', 'uc_oapin2 meancounts', 'uc_tubein shot #', 'uc_tubein centroidx', 'uc_tubein centroidy', 'u_traserver01 shot #', 'u_traserver01 position.axis 1 alias:tubein-x', 'u_traserver01 position.axis 2 alias:tubein-y', 'u_traserver01 position.axis 3 alias:compin-x', 'u_traserver01 position.axis 4 alias:compin-y', 'uc_gratingmode shot #', 'uc_gratingmode centroidx', 'uc_gratingmode centroidy', 'uc_oapin1 shot #', 'uc_oapin1 centroidx', 'uc_oapin1 centroidy', 'uc_dmsurface shot #', 'uc_dmsurface centroidx', 'uc_dmsurface centroidy', 'uc_bcavein shot #', 'uc_bcavein centroidx', 'uc_bcavein centroidy', 'u_ghostwfs shot #', 'u_ghostwfs pupilxcent', 'u_ghostwfs beamcentroidx', 'u_ghostwfs beamcentroidy', 'u_ghostwfs beamdiameterx', 'u_ghostwfs beamdiametery', 'u_ghostwfs pupilxdiam', 'u_ghostwfs pupilycent', 'u_ghostwfs pupilydiam', 'u_ghostwfs radiusofcurvature', 'u_ghostwfs wavefrontmax', 'u_ghostwfs wavefrontmin', 'u_ghostwfs zernike1', 'u_ghostwfs zernike2', 'u_ghostwfs zernike3', 'u_ghostwfs zernike4', 'u_ghostwfs zernike5', 'u_ghostwfs zernike6', 'u_ghostwfs zernike7', 'u_ghostwfs zernike8', 'u_ghostwfs zernike9', 'u_ghostwfs zernike10', 'u_stretchterxmcc shot #', 'u_stretchterxmcc position.ch1 alias:stretcher-out-y', 'u_stretchterxmcc position.ch2 alias:stretcher-out-x', 'u_hiresmagcam shot #', 'u_hiresmagcam exposure', 'u_hiresmagcam maxpicocoulomb', 'u_hiresmagcam triggerdelay', 'u_hiresmagcam charge', 'u_gaiadaq shot #', 'u_gaiadaq ai_mean.channel 0', 'uc_bcavemagspeccam1 shot #', 'uc_bcavemagspeccam1 charge', 'uc_bcavemagspeccam1 maxpicocoulomb', 'uc_bcavemagspeccam2 shot #', 'uc_bcavemagspeccam2 charge', 'uc_bcavemagspeccam2 maxpicocoulomb', 'uc_bcavemagspeccam3 shot #', 'uc_bcavemagspeccam3 charge', 'uc_bcavemagspeccam3 maxpicocoulomb', 'u_bcavemagspec shot #', 'u_bcavemagspec deltaangle', 'datetime shot #', 'datetime timestamp', 'rep rate shot #', 'rep rate hz', 'shotnumber', 'bin #', 'scan', 'kj_charge_pc', 's_upramp fit param1 (arb)', 's_upramp fit param2 (mm)', 's_upramp fit param2 (mm).1', 's_plateau gauss fit amplitude (x10^18 cm^-3)', 's_plateau gauss fit sigma (mm)', 's_max downramp gradient (10^18 cm^-3 / mm)', 's_peak density (10^18 cm^-3)', 'shotnum']

In [17]:
filtered_cols

NameError: name 'filtered_cols' is not defined

In [14]:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score

def predict_charge(cleaned_data, target='kj_charge_pC'):
    # Ensure the target column is in the cleaned data
    if target not in cleaned_data.columns:
        raise ValueError(f"Target column '{target}' not found in the cleaned data. Available columns are: {cleaned_data.columns.tolist()}")

    # Split the data into two halves
    middle_idx = len(cleaned_data) // 2
    train_data = cleaned_data[:middle_idx]
    test_data = cleaned_data[middle_idx:]

    # Separate the features from the target variable for both training and testing
    X_train = train_data.drop(columns=[target])
    y_train = train_data[target]
    X_test = test_data.drop(columns=[target])
    y_test = test_data[target]

    # Initialize and train the model
    model = LinearRegression()
    model.fit(X_train, y_train)

    # Predict the target variable for the test set
    y_pred = model.predict(X_test)

    # Calculate the model evaluation metrics
    mse = mean_squared_error(y_test, y_pred)
    r2 = r2_score(y_test, y_pred)

    # Prepare the result
    result = {
        "Mean Squared Error": mse,
        "R2 Score": r2,
        "Predictions": y_pred,
        "Actual Values": y_test.values
    }

    return result

# Usage
target_column = 'kj_charge_pC'  # ensure this is the exact name, considering case sensitivity
prediction_results = predict_charge(cleaned_data, target=target_column)

print("Mean Squared Error:", prediction_results["Mean Squared Error"])
print("R2 Score:", prediction_results["R2 Score"])

# If you want to visualize the results or do further analysis, you can access 'Predictions' and 'Actual Values' from the result.


ValueError: Target column 'kj_charge_pC' not found in the cleaned data. Available columns are: ['uc_amp3_ir_input meancounts alias:uc_amp3_ir_input_mean', 'uc_amp4_ir_input meancounts alias:uc_amp4_ir_input_mean', 'uc_expanderin1_pulsed meancounts', 'uc_amp2depletion_south meancounts', 'uc_amp3depletion_south 2ndmomw0x45', 'uc_amp3depletion_south 2ndmomw0y45', 'uc_amp4depletion_south 2ndmomw0y45', 'uc_amp4depletion_south meancounts', 'uc_ghostfocus 2ndmomw0x45', 'uc_ghostfocus 2ndmomw0x', 'uc_ghostfocus 2ndmomw0y45', 'uc_ghostfocus 2ndmomw0y', 'uc_ghostfocus fwhmy', 'uc_ghostfocus meancounts', 'uc_ghostupstream 2ndmomw0x45', 'uc_ghostupstream 2ndmomw0x', 'uc_ghostupstream 2ndmomw0y45', 'uc_ghostupstream 2ndmomw0y', 'uc_ghostupstream meancounts', 'uc_topview centroidx', 'uc_topview maxcounts', 'uc_topview meancounts', 'uc_oapin2 meancounts', 'u_ghostwfs radiusofcurvature', 'u_ghostwfs wavefrontmin', 'u_ghostwfs zernike4', 'u_ghostwfs zernike5', 'kj_charge_pc']